├── .gitignore
├── README.md
├── caps_attn_flatten
│   ├── Capsule_masked.py
│   ├── Config.py
│   ├── TfUtils.py
│   ├── __init__.py
│   ├── data_iterator.py
│   ├── dataprocess
│   │   ├── __init__.py
│   │   ├── dataprocess.py
│   │   ├── dataprocess_sentence.py
│   │   └── vocab.py
│   ├── model.py
│   ├── nest.py
│   ├── train_test.py
│   └── utils.py
├── caps_attn_hierarchical
│   ├── Capsule_masked.py
│   ├── Config.py
│   ├── TfUtils.py
│   ├── __init__.py
│   ├── data_iterator.py
│   ├── dataprocess
│   │   ├── __init__.py
│   │   ├── dataprocess.py
│   │   ├── dataprocess_sentence.py
│   │   └── vocab.py
│   ├── model.py
│   ├── nest.py
│   ├── train_test.py
│   └── utils.py
├── data
│   └── downloadDataset.md
└── savings
    ├── imdb
    │   └── config
    ├── sst01
    │   └── config
    ├── sst02
    │   └── config
    ├── yelp2013
    │   └── config
    └── yelp2014
        └── config
/.gitignore:
--------------------------------------------------------------------------------
 1 | data/smallset
 2 | data/dataset.tar.gz
 3 | *.pyc
 4 | */_gen/
 5 | */gen_cmd.py
 6 | .DS_Store
 7 | ._.DS_Store
 8 | git_bak
 9 | /savings/imdb/*
10 | /savings/yelp2013/*
11 | /savings/yelp2014/*
12 | /savings/sst01/*
13 | /savings/sst02/*
14 | 
15 | !/savings/imdb/config
16 | !/savings/yelp2013/config
17 | !/savings/yelp2014/config
18 | !/savings/sst01/config
19 | !/savings/sst02/config
20 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Capsule4TextClassification
 2 | Implementation of our paper
 3 | ["Information Aggregation via Dynamic Routing for Sequence Encoding"](https://arxiv.org/pdf/1806.01501.pdf)
 4 | 
 5 | # System Requirements
 6 | OS: Linux (Ubuntu 14.04+)
 7 | 
 8 | Python: v3.6
 9 | 
10 | Tensorflow: v1.4.0
11 | 
12 | Numpy: v1.14
13 | 
14 | CUDA: v8.0
15 | 
16 | CUDNN: v6.0
17 | 
18 | 
19 | # Data Download
20 | Refer to [downloadDataset](./data/downloadDataset.md) for data download instructions.
21 | 
22 | # Quick start
23 | Please first refer to [Data Download](./data/downloadDataset.md) and download all the data needed.
24 | Then go to the root of this project, `Capsule4TextClassification`, and run one of the following
25 | commands to start the training process on the corresponding dataset.
26 | 
27 | ```bash
28 | # for sentence-level datasets, i.e. SST-1 and SST-2
29 | python ./caps_attn_flatten/train_test.py --load-config --weight-path ./savings/sst01
30 | python ./caps_attn_flatten/train_test.py --load-config --weight-path ./savings/sst02
31 | 
32 | # for document-level datasets, i.e. IMDB, Yelp-2013 and Yelp-2014
33 | python ./caps_attn_hierarchical/train_test.py --load-config --weight-path ./savings/imdb
34 | python ./caps_attn_hierarchical/train_test.py --load-config --weight-path ./savings/yelp2013
35 | python ./caps_attn_hierarchical/train_test.py --load-config --weight-path ./savings/yelp2014
36 | 
37 | ```
38 | #### Further explanation of the commands:
39 | Note that we provide a sentence-level model (caps_attn_flatten) and a document-level model (caps_attn_hierarchical).
40 | 
41 | Take the first command as an example: `--load-config` indicates that, before the computational graph is constructed,
42 | a config file is loaded from whichever directory `--weight-path` specifies, in this case `./savings/sst01`.
43 | The file `./savings/sst01/config` controls the configuration of the model. If you ever want to run
44 | another configuration, simply copy the `./savings/sst01` directory, modify its config file,
45 | and then run a command similar to the ones above.
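The config file itself is a plain INI file written by `Config.saveConfig` (a single `[General]` section whose values are JSON-encoded; see `caps_attn_flatten/Config.py`). If you prefer to inspect or tweak a configuration programmatically instead of editing it by hand, a minimal sketch follows; the copied directory `./savings/sst01_copy` is only an example and must exist before saving:

```python
# Illustrative sketch only; not part of the training pipeline.
import sys
sys.path.insert(0, './caps_attn_flatten')          # make Config.py importable from the repo root

from Config import Config

config = Config()
config.loadConfig('./savings/sst01/config')        # parse the [General] section (JSON-encoded values)
config.printall()                                  # print every hyperparameter, its value and type

config.rout_iter = 5                               # e.g. try more routing iterations
config.saveConfig('./savings/sst01_copy/config')   # example path; create the directory first
```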
46 | 47 | `--weight-path` specifies in which directory we want to store our config file, and most importantly the model checkpoint. 48 | Also a `status` file, which is originally used to avoid conflict. 49 | **note that if you ever encountered "process running or finished" error, you should remove `status` file**. 50 | 51 | 52 | -------------------------------------------------------------------------------- /caps_attn_flatten/Capsule_masked.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.layers import base as base_layer 2 | import tensorflow as tf 3 | import numpy as np 4 | from TfUtils import mkMask 5 | 6 | _EPSILON = 1e-9 7 | _MIN_NUM = -np.Inf 8 | 9 | class Capusule(base_layer.Layer): 10 | def __init__(self, out_caps_num, out_caps_dim, iter_num=3, reuse=None): 11 | super(Capusule, self).__init__(_reuse=reuse) 12 | self.out_caps_num = out_caps_num 13 | self.out_caps_dim = out_caps_dim 14 | self.iter_num = iter_num 15 | 16 | 17 | def call(self, in_caps, seqLen, reverse_routing=False): 18 | caps_uhat = shared_routing_uhat(in_caps, self.out_caps_num, self.out_caps_dim, scope='rnn_caps_uhat') 19 | if not reverse_routing: 20 | V, S = masked_routing_iter(caps_uhat, seqLen, self.iter_num) 21 | else: 22 | V, S = masked_reverse_routing_iter(caps_uhat, seqLen, self.iter_num) 23 | return V 24 | 25 | 26 | def shared_routing_uhat(caps, out_caps_num, out_caps_dim, scope=None): 27 | ''' 28 | 29 | Args: 30 | caps: # shape(b_sz, caps_num, caps_dim) 31 | out_caps_num: #number of output capsule 32 | out_caps_dim: #dimension of output capsule 33 | Returns: 34 | caps_uhat: shape(b_sz, caps_num, out_caps_num, out_caps_dim) 35 | ''' 36 | b_sz = tf.shape(caps)[0] 37 | tstp = tf.shape(caps)[1] 38 | 39 | with tf.variable_scope(scope or 'shared_routing_uhat'): 40 | '''shape(b_sz, caps_num, out_caps_num*out_caps_dim)''' 41 | caps_uhat = tf.layers.dense(caps, out_caps_num * out_caps_dim, activation=tf.tanh) 42 | caps_uhat = tf.reshape(caps_uhat, shape=[b_sz, tstp, out_caps_num, out_caps_dim]) 43 | 44 | return caps_uhat 45 | 46 | 47 | def masked_routing_iter(caps_uhat, seqLen, iter_num): 48 | ''' 49 | 50 | Args: 51 | caps_uhat: shape(b_sz, tstp, out_caps_num, out_caps_dim) 52 | seqLen: shape(b_sz) 53 | iter_num: number of iteration 54 | 55 | Returns: 56 | V_ret: #shape(b_sz, out_caps_num, out_caps_dim) 57 | ''' 58 | assert iter_num > 0 59 | b_sz = tf.shape(caps_uhat)[0] 60 | tstp = tf.shape(caps_uhat)[1] 61 | out_caps_num = int(caps_uhat.get_shape()[2]) 62 | seqLen = tf.where(tf.equal(seqLen, 0), tf.ones_like(seqLen), seqLen) 63 | mask = mkMask(seqLen, tstp) # shape(b_sz, tstp) 64 | floatmask = tf.cast(tf.expand_dims(mask, axis=-1), dtype=tf.float32) # shape(b_sz, tstp, 1) 65 | 66 | # shape(b_sz, tstp, out_caps_num) 67 | B = tf.zeros([b_sz, tstp, out_caps_num], dtype=tf.float32) 68 | for i in range(iter_num): 69 | C = tf.nn.softmax(B, dim=2) # shape(b_sz, tstp, out_caps_num) 70 | C = tf.expand_dims(C*floatmask, axis=-1) # shape(b_sz, tstp, out_caps_num, 1) 71 | weighted_uhat = C * caps_uhat # shape(b_sz, tstp, out_caps_num, out_caps_dim) 72 | 73 | S = tf.reduce_sum(weighted_uhat, axis=1) # shape(b_sz, out_caps_num, out_caps_dim) 74 | 75 | V = _squash(S, axes=[2]) # shape(b_sz, out_caps_num, out_caps_dim) 76 | V = tf.expand_dims(V, axis=1) # shape(b_sz, 1, out_caps_num, out_caps_dim) 77 | B = tf.reduce_sum(caps_uhat * V, axis=-1) + B # shape(b_sz, tstp, out_caps_num) 78 | 79 | V_ret = tf.squeeze(V, axis=[1]) # shape(b_sz, out_caps_num, out_caps_dim) 80 | 
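    # Routing recap: each iteration softmax-normalizes the logits B over the
    # out_caps_num output capsules (dim=2), zeroes padded timesteps with
    # floatmask, sums the weighted caps_uhat into S, squashes S into V, and
    # raises B by the agreement term <caps_uhat, V>.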
S_ret = S 81 | return V_ret, S_ret 82 | 83 | 84 | def masked_reverse_routing_iter(caps_uhat, seqLen, iter_num): 85 | ''' 86 | 87 | Args: 88 | caps_uhat: shape(b_sz, tstp, out_caps_num, out_caps_dim) 89 | seqLen: shape(b_sz) 90 | iter_num: number of iteration 91 | 92 | Returns: 93 | V_ret: #shape(b_sz, out_caps_num, out_caps_dim) 94 | ''' 95 | assert iter_num > 0 96 | b_sz = tf.shape(caps_uhat)[0] 97 | tstp = tf.shape(caps_uhat)[1] 98 | out_caps_num = int(caps_uhat.get_shape()[2]) 99 | 100 | seqLen = tf.where(tf.equal(seqLen, 0), tf.ones_like(seqLen), seqLen) 101 | mask = mkMask(seqLen, tstp) # shape(b_sz, tstp) 102 | mask = tf.tile(tf.expand_dims(mask, axis=-1), # shape(b_sz, tstp, out_caps_num) 103 | multiples=[1, 1, out_caps_num]) 104 | # shape(b_sz, tstp, out_caps_num) 105 | B = tf.zeros([b_sz, tstp, out_caps_num], dtype=tf.float32) 106 | B = tf.where(mask, B, tf.ones_like(B) * _MIN_NUM) 107 | for i in range(iter_num): 108 | C = tf.nn.softmax(B, dim=1) # shape(b_sz, tstp, out_caps_num) 109 | C = tf.expand_dims(C, axis=-1) # shape(b_sz, tstp, out_caps_num, 1) 110 | weighted_uhat = C * caps_uhat # shape(b_sz, tstp, out_caps_num, out_caps_dim) 111 | 112 | S = tf.reduce_sum(weighted_uhat, axis=1) # shape(b_sz, out_caps_num, out_caps_dim) 113 | 114 | V = _squash(S, axes=[2]) # shape(b_sz, out_caps_num, out_caps_dim) 115 | V = tf.expand_dims(V, axis=1) # shape(b_sz, 1, out_caps_num, out_caps_dim) 116 | B = tf.reduce_sum(caps_uhat * V, axis=-1) + B # shape(b_sz, tstp, out_caps_num) 117 | 118 | V_ret = tf.squeeze(V, axis=[1]) # shape(b_sz, out_caps_num, out_caps_dim) 119 | S_ret = S 120 | return V_ret, S_ret 121 | 122 | 123 | def margin_loss(y_true, y_pred): 124 | """ 125 | :param y_true: [None, n_classes] 126 | :param y_pred: [None, n_classes] 127 | :return: a scalar loss value. 128 | """ 129 | L = y_true * tf.square(tf.maximum(0., 0.9 - y_pred)) + \ 130 | 0.5 * (1 - y_true) * tf.square(tf.maximum(0., y_pred - 0.1)) 131 | 132 | assert_inf_L = tf.Assert(tf.logical_not(tf.reduce_any(tf.is_inf(L))), 133 | ['assert_inf_L', L], summarize=100) 134 | assert_nan_L = tf.Assert(tf.logical_not(tf.reduce_any(tf.is_nan(L))), 135 | ['assert_nan_L', L], summarize=100) 136 | with tf.control_dependencies([assert_inf_L, assert_nan_L]): 137 | ret = tf.reduce_mean(tf.reduce_sum(L, axis=1)) 138 | 139 | return ret 140 | 141 | 142 | def _squash(in_caps, axes): 143 | ''' 144 | Squashing function corresponding to Eq. 1 145 | Args: 146 | in_caps: a tensor 147 | axes: dimensions along which to apply squash 148 | 149 | Returns: 150 | vec_squashed: squashed tensor 151 | 152 | ''' 153 | vec_squared_norm = tf.reduce_sum(tf.square(in_caps), axis=axes, keep_dims=True) 154 | scalar_factor = vec_squared_norm / (1 + vec_squared_norm) / tf.sqrt(vec_squared_norm + _EPSILON) 155 | vec_squashed = scalar_factor * in_caps # element-wise 156 | return vec_squashed 157 | 158 | 159 | -------------------------------------------------------------------------------- /caps_attn_flatten/Config.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import traceback 3 | import json 4 | 5 | 6 | class Config(object): 7 | """Holds model hyperparams and data information. 8 | 9 | The config class is used to store various hyperparameters and dataset 10 | information parameters. Model objects are passed a Config() object at 11 | instantiation. 
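    A minimal usage sketch (illustrative; the path below is only an example):

        config = Config()
        config.loadConfig('./savings/sst01/config')
        config.printall()
        # ...or config.saveConfig(path) to persist the current settings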
12 | 13 | """ 14 | 15 | """General""" 16 | revision = 'None' 17 | datapath = './data/smallset/' 18 | embed_path = './data/embedding.txt' 19 | 20 | optimizer = 'adam' 21 | attn_mode = 'attn' 22 | seq_encoder = 'bigru' 23 | 24 | out_caps_num = 5 25 | rout_iter = 3 26 | 27 | max_snt_num = 30 28 | max_wd_num = 30 29 | max_epochs = 50 30 | pre_trained = True 31 | batch_sz = 64 32 | batch_sz_min = 32 33 | bucket_sz = 5000 34 | partial_update_until_epoch = 1 35 | 36 | embed_size = 300 37 | hidden_size = 200 38 | dense_hidden = [300, 5] 39 | 40 | lr = 0.0001 41 | decay_steps = 1000 42 | decay_rate = 0.9 43 | 44 | dropout = 0.2 45 | early_stopping = 7 46 | reg = 0. 47 | 48 | def __init__(self): 49 | self.attr_list = [i for i in list(Config.__dict__.keys()) if 50 | not callable(getattr(self, i)) and not i.startswith("__")] 51 | 52 | def printall(self): 53 | for attr in self.attr_list: 54 | print(attr, getattr(self, attr), type(getattr(self, attr))) 55 | 56 | def saveConfig(self, filePath): 57 | 58 | cfg = configparser.ConfigParser() 59 | cfg['General'] = {} 60 | gen_sec = cfg['General'] 61 | for attr in self.attr_list: 62 | try: 63 | gen_sec[attr] = json.dumps(getattr(self, attr)) 64 | except Exception as e: 65 | traceback.print_exc() 66 | raise ValueError('something wrong in “%s” entry' % attr) 67 | 68 | with open(filePath, 'w') as fd: 69 | cfg.write(fd) 70 | 71 | def loadConfig(self, filePath): 72 | 73 | cfg = configparser.ConfigParser() 74 | cfg.read(filePath) 75 | gen_sec = cfg['General'] 76 | for attr in self.attr_list: 77 | try: 78 | val = json.loads(gen_sec[attr]) 79 | assert type(val) == type(getattr(self, attr)), \ 80 | 'type not match, expect %s got %s' % \ 81 | (type(getattr(self, attr)), type(val)) 82 | 83 | setattr(self, attr, val) 84 | except Exception as e: 85 | traceback.print_exc() 86 | raise ValueError('something wrong in “%s” entry' % attr) 87 | 88 | with open(filePath, 'w') as fd: 89 | cfg.write(fd) -------------------------------------------------------------------------------- /caps_attn_flatten/TfUtils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import nest 3 | import numpy as np 4 | 5 | def mkMask(input_tensor, maxLen): 6 | shape_of_input = tf.shape(input_tensor) 7 | shape_of_output = tf.concat(axis=0, values=[shape_of_input, [maxLen]]) 8 | 9 | oneDtensor = tf.reshape(input_tensor, shape=(-1,)) 10 | flat_mask = tf.sequence_mask(oneDtensor, maxlen=maxLen) 11 | return tf.reshape(flat_mask, shape_of_output) 12 | 13 | 14 | def reduce_avg(reduce_target, lengths, dim): 15 | """ 16 | Args: 17 | reduce_target : shape(d_0, d_1,..,d_dim, .., d_k) 18 | lengths : shape(d0, .., d_(dim-1)) 19 | dim : which dimension to average, should be a python number 20 | """ 21 | shape_of_lengths = lengths.get_shape() 22 | shape_of_target = reduce_target.get_shape() 23 | if len(shape_of_lengths) != dim: 24 | raise ValueError(('Second input tensor should be rank %d, ' + 25 | 'while it got rank %d') % (dim, len(shape_of_lengths))) 26 | if len(shape_of_target) < dim+1 : 27 | raise ValueError(('First input tensor should be at least rank %d, ' + 28 | 'while it got rank %d') % (dim+1, len(shape_of_target))) 29 | 30 | rank_diff = len(shape_of_target) - len(shape_of_lengths) - 1 31 | mxlen = tf.shape(reduce_target)[dim] 32 | mask = mkMask(lengths, mxlen) 33 | if rank_diff!=0: 34 | len_shape = tf.concat(axis=0, values=[tf.shape(lengths), [1]*rank_diff]) 35 | mask_shape = tf.concat(axis=0, values=[tf.shape(mask), [1]*rank_diff]) 36 | else: 
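        # rank_diff == 0: `lengths` and `mask` already line up with the target
        # tensor, so no trailing singleton dimensions need to be appended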
37 | len_shape = tf.shape(lengths) 38 | mask_shape = tf.shape(mask) 39 | lengths_reshape = tf.reshape(lengths, shape=len_shape) 40 | mask = tf.reshape(mask, shape=mask_shape) 41 | 42 | mask_target = reduce_target * tf.cast(mask, dtype=reduce_target.dtype) 43 | 44 | red_sum = tf.reduce_sum(mask_target, axis=[dim], keep_dims=False) 45 | red_avg = red_sum / (tf.to_float(lengths_reshape) + 1e-30) 46 | return red_avg 47 | 48 | 49 | def reduce_sum(reduce_target, lengths, dim): 50 | """ 51 | Args: 52 | reduce_target : shape(d_0, d_1,..,d_dim, .., d_k) 53 | lengths : shape(d0, .., d_(dim-1)) 54 | dim : which dimension to average, should be a python number 55 | """ 56 | shape_of_lengths = lengths.get_shape() 57 | shape_of_target = reduce_target.get_shape() 58 | if len(shape_of_lengths) != dim: 59 | raise ValueError(('Second input tensor should be rank %d, ' + 60 | 'while it got rank %d') % (dim, len(shape_of_lengths))) 61 | if len(shape_of_target) < dim+1 : 62 | raise ValueError(('First input tensor should be at least rank %d, ' + 63 | 'while it got rank %d') % (dim+1, len(shape_of_target))) 64 | 65 | rank_diff = len(shape_of_target) - len(shape_of_lengths) - 1 66 | mxlen = tf.shape(reduce_target)[dim] 67 | mask = mkMask(lengths, mxlen) 68 | if rank_diff!=0: 69 | len_shape = tf.concat(axis=0, values=[tf.shape(lengths), [1]*rank_diff]) 70 | mask_shape = tf.concat(axis=0, values=[tf.shape(mask), [1]*rank_diff]) 71 | else: 72 | len_shape = tf.shape(lengths) 73 | mask_shape = tf.shape(mask) 74 | lengths_reshape = tf.reshape(lengths, shape=len_shape) 75 | mask = tf.reshape(mask, shape=mask_shape) 76 | 77 | mask_target = reduce_target * tf.cast(mask, dtype=reduce_target.dtype) 78 | 79 | red_sum = tf.reduce_sum(mask_target, axis=[dim], keep_dims=False) 80 | 81 | return red_sum 82 | 83 | 84 | def embed_lookup_last_dim(embedding, ids): 85 | ''' 86 | embedding: shape(b_sz, tstp, emb_sz) 87 | ids : shape(b_sz, tstp) 88 | ''' 89 | input_shape = tf.shape(embedding) 90 | time_steps = input_shape[0] 91 | def _create_ta(name, dtype): 92 | return tf.TensorArray(dtype=dtype, 93 | size=time_steps, 94 | tensor_array_name=name) 95 | input_ta = _create_ta('input_ta', embedding.dtype) 96 | fetch_ta = _create_ta('fetch_ta', ids.dtype) 97 | output_ta = _create_ta('output_ta', embedding.dtype) 98 | input_ta = input_ta.unpack(embedding) 99 | fetch_ta = fetch_ta.unpack(ids) 100 | 101 | def loop_body(time, output_ta): 102 | embed = input_ta.read(time) #shape(tstp, emb_sz) type of float32 103 | fetch_id = fetch_ta.read(time) #shape(tstp) type of int32 104 | out_emb = tf.nn.embedding_lookup(embed, fetch_id) 105 | output_ta = output_ta.write(time, out_emb) 106 | 107 | next_time = time+1 108 | return next_time, output_ta 109 | time = tf.constant(0) 110 | _, output_ta = tf.while_loop(cond=lambda time, *_: time < time_steps, 111 | body=loop_body, loop_vars=(time, output_ta), 112 | swap_memory=True) 113 | ret_t = output_ta.pack() #shape(b_sz, tstp, embd_sz) 114 | return ret_t 115 | 116 | 117 | def entry_stop_gradients(target, mask): 118 | ''' 119 | Args: 120 | target: a tensor 121 | mask: a boolean tensor that broadcast to the rank of that to target tensor 122 | Returns: 123 | ret: a tensor have the same value of target, 124 | but some entry will have no gradient during backprop 125 | ''' 126 | mask_h = tf.logical_not(mask) 127 | 128 | mask = tf.cast(mask, dtype=target.dtype) 129 | mask_h = tf.cast(mask_h, dtype=target.dtype) 130 | ret = tf.stop_gradient(mask_h * target) + mask * target 131 | 132 | return ret 133 | 134 | 135 | def 
last_dim_linear(inputs, output_size, bias, scope): 136 | ''' 137 | Args: 138 | input: shape(b_sz, ..., rep_sz) 139 | output_size: a scalar, python number 140 | ''' 141 | bias_start=0.0 142 | input_shape = tf.shape(inputs) 143 | out_shape = tf.concat(axis=0, values=[input_shape[:-1], [output_size]]) 144 | input_size = int(inputs.get_shape()[-1]) 145 | unbatch_input = tf.reshape(inputs, shape=[-1, input_size]) 146 | 147 | unbatch_output = linear(unbatch_input, output_size, bias=bias, 148 | bias_start=bias_start, scope=scope) 149 | batch_output = tf.reshape(unbatch_output, shape=out_shape) 150 | 151 | return batch_output # shape(b_sz, ..., output_size) 152 | 153 | 154 | def linear(args, output_size, bias, bias_start=0.0, scope=None): 155 | """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. 156 | 157 | Args: 158 | args: a 2D Tensor or a list of 2D, batch x n, Tensors. 159 | output_size: int, second dimension of W[i]. 160 | bias: boolean, whether to add a bias term or not. 161 | bias_start: starting value to initialize the bias; 0 by default. 162 | scope: (optional) Variable scope to create parameters in. 163 | 164 | Returns: 165 | A 2D Tensor with shape [batch x output_size] equal to 166 | sum_i(args[i] * W[i]), where W[i]s are newly created matrices. 167 | 168 | Raises: 169 | ValueError: if some of the arguments has unspecified or wrong shape. 170 | """ 171 | if args is None or (nest.is_sequence(args) and not args): 172 | raise ValueError("`args` must be specified") 173 | if not nest.is_sequence(args): 174 | args = [args] 175 | 176 | # Calculate the total size of arguments on dimension 1. 177 | total_arg_size = 0 178 | shapes = [a.get_shape() for a in args] 179 | for shape in shapes: 180 | if shape.ndims != 2: 181 | raise ValueError("linear is expecting 2D arguments: %s" % shapes) 182 | if shape[1].value is None: 183 | raise ValueError("linear expects shape[1] to be provided for shape %s, " 184 | "but saw %s" % (shape, shape[1])) 185 | else: 186 | total_arg_size += shape[1].value 187 | 188 | dtype = [a.dtype for a in args][0] 189 | 190 | # Now the computation. 
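  # A single weight matrix of shape [total_arg_size, output_size] is shared by
  # all (concatenated) inputs; the optional bias is added below in a child scope
  # whose partitioner is cleared.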
191 | with tf.variable_scope(scope or 'Linear') as outer_scope: 192 | weights = tf.get_variable( 193 | "weights", [total_arg_size, output_size], dtype=dtype) 194 | if len(args) == 1: 195 | res = tf.matmul(args[0], weights) 196 | else: 197 | res = tf.matmul(tf.concat(args, 1), weights) 198 | if not bias: 199 | return res 200 | with tf.variable_scope(outer_scope) as inner_scope: 201 | inner_scope.set_partitioner(None) 202 | biases = tf.get_variable( 203 | "biases", [output_size], 204 | dtype=dtype, 205 | initializer=tf.constant_initializer(bias_start, dtype=dtype)) 206 | return tf.nn.bias_add(res, biases) 207 | 208 | 209 | def masked_softmax(inp, seqLen): 210 | seqLen = tf.where(tf.equal(seqLen, 0), tf.ones_like(seqLen), seqLen) 211 | if len(inp.get_shape()) != len(seqLen.get_shape())+1: 212 | raise ValueError('rank of seqLen should be %d, but have the rank %d.\n' 213 | % (len(inp.get_shape())-1, len(seqLen.get_shape()))) 214 | mask = mkMask(seqLen, tf.shape(inp)[-1]) 215 | masked_inp = tf.where(mask, inp, tf.ones_like(inp) * (-np.Inf)) 216 | ret = tf.nn.softmax(masked_inp) 217 | return ret 218 | 219 | from tensorflow.python.client import device_lib 220 | def get_available_gpus(): 221 | local_device_protos = device_lib.list_local_devices() 222 | return [x.name for x in local_device_protos if x.device_type == 'GPU'] 223 | -------------------------------------------------------------------------------- /caps_attn_flatten/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jingjing-gong/Capsule4TextClassification/35a5e0d65f45c7810c082c1046535de9207eb12b/caps_attn_flatten/__init__.py -------------------------------------------------------------------------------- /caps_attn_flatten/data_iterator.py: -------------------------------------------------------------------------------- 1 | import _pickle as pkl 2 | import numpy as np 3 | from multiprocessing import Queue 4 | 5 | class TextIterator: 6 | """Simple Bitext iterator.""" 7 | def __init__(self, datapath, batch_size=128, bucket_sz=1000, shuffle=False, sample_balance=False, id2weight=None): 8 | 9 | with open(datapath, 'rb') as fd: 10 | data = pkl.load(fd) 11 | '''data==> [(labe, doc),]''' 12 | example_num = len(data) 13 | '''shape(example_num)''' 14 | doc_sz = np.array([len(doc) for _, doc in data], dtype=np.int32) 15 | 16 | if shuffle: 17 | self.tidx = np.argsort(doc_sz) 18 | else: 19 | self.tidx = np.arange(example_num) 20 | 21 | self.num_example = example_num 22 | self.shuffle = shuffle 23 | self.bucket_sz = bucket_sz 24 | self.batch_sz = batch_size 25 | self.data = data 26 | 27 | self.sample_balance = sample_balance 28 | self.id2weight = id2weight 29 | 30 | def __iter__(self): 31 | if self.bucket_sz < self.batch_sz: 32 | self.bucket_sz = self.batch_sz 33 | if self.bucket_sz > self.num_example: 34 | self.bucket_sz = self.num_example 35 | self.startpoint = 0 36 | return self 37 | 38 | def __next__(self): 39 | if self.startpoint >= self.num_example: 40 | raise StopIteration 41 | 42 | if self.shuffle: 43 | bucket_start = np.random.randint(0, self.num_example) 44 | bucket_end = (bucket_start + self.bucket_sz) % self.num_example 45 | if bucket_end - bucket_start < self.bucket_sz: 46 | candidate = np.concatenate([self.tidx[bucket_start:], self.tidx[:bucket_end]]) 47 | else: 48 | candidate = self.tidx[bucket_start: bucket_end] 49 | candidate_p = None 50 | if self.sample_balance and self.id2weight: 51 | candidate_label = [self.data[c][0] for c in candidate] 52 | candidate_p 
= np.array([self.id2weight[l] for l in candidate_label]) 53 | candidate_p = candidate_p/np.sum(candidate_p) 54 | target_idx = np.random.choice(candidate, size=self.batch_sz, p=candidate_p) 55 | else: 56 | target_idx = self.tidx[self.startpoint:self.startpoint+self.batch_sz] 57 | self.startpoint += self.batch_sz 58 | 59 | labels = [] 60 | data_x = [] 61 | for idx in target_idx: 62 | l, d = self.data[idx] 63 | labels.append(l) 64 | data_x.append(d) 65 | return labels, data_x 66 | 67 | 68 | def preparedata(dataset: list, q: Queue, max_wd_num: int, class_freq: dict): 69 | for labels, data_x in dataset: 70 | example_weight = np.array([class_freq[i] for i in labels]) #(b_sz) 71 | data_batch, wNum = paddata(data_x, max_wd_num=max_wd_num) 72 | labels = np.array(labels) 73 | q.put((data_batch, labels, wNum, example_weight)) 74 | q.put(None) 75 | 76 | 77 | def paddata(data_x: list, max_wd_num: int): 78 | ''' 79 | 80 | :param data_x: (b_sz, wd_num) 81 | :param max_wd_num: 82 | :return: b shape(b_sz, wd_sz), wNum shape(b_sz,) 83 | ''' 84 | 85 | b_sz = len(data_x) 86 | 87 | wd_num = np.array([len(doc) if len(doc) < max_wd_num else max_wd_num for doc in data_x], dtype=np.int32) 88 | wd_sz = np.max(wd_num) 89 | 90 | b = np.zeros(shape=[b_sz, wd_sz], dtype=np.int32) # == PAD 91 | wNum = wd_num 92 | for i, snt in enumerate(data_x): 93 | for j, wd in enumerate(snt): 94 | if j >= wd_sz: 95 | continue 96 | b[i, j] = wd 97 | 98 | 99 | return b, wNum -------------------------------------------------------------------------------- /caps_attn_flatten/dataprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jingjing-gong/Capsule4TextClassification/35a5e0d65f45c7810c082c1046535de9207eb12b/caps_attn_flatten/dataprocess/__init__.py -------------------------------------------------------------------------------- /caps_attn_flatten/dataprocess/dataprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import _pickle as pkl 3 | import os, operator 4 | from collections import defaultdict 5 | from tensorflow.python.util import nest 6 | from vocab import Vocab 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser(description="datasets") 10 | 11 | parser.add_argument('--train-set', action='store', dest='train_set', default=None) 12 | parser.add_argument('--dev-set', action='store', dest='dev_set', default=None) 13 | parser.add_argument('--test-set', action='store', dest='test_set', default=None) 14 | parser.add_argument('--ref-embedding', action='store', dest='ref_emb', default='/home/jjgong/data/glove300d/glove.840B.300d.txt') 15 | parser.add_argument('--dest-dir', action='store', dest='dest_dir', default='./') 16 | parser.add_argument('--label2id', action='store', dest='label2id', default=None) 17 | parser.add_argument('--base-wd-freq', action='store', dest='base_wd_freq', default=3) 18 | 19 | args = parser.parse_args() 20 | 21 | def extract(fn): 22 | label_collect = [] 23 | doc_tok_collect = [] 24 | with open(fn, 'r') as fd: 25 | for line in fd: 26 | item = line.strip().split('\t\t') 27 | try: 28 | label = item[2] 29 | doc = item[3] 30 | except: 31 | print(line) 32 | print(item) 33 | raise ValueError 34 | 35 | doc2snt = doc.strip().split('') 36 | doc2snt2wd = [snt.strip().split(' ') for snt in doc2snt] 37 | label_collect.append(label) 38 | doc_tok_collect.append(doc2snt2wd) 39 | return label_collect, doc_tok_collect 40 | 41 | def constructLabel_dict(labels, savepath): 
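    # Builds and pickles four lookup tables: label2id, id2label, id2revfreq
    # (inverse class frequency, total_count / class_count) and id2weight
    # (class_num * revfreq / sum of revfreqs), which data_iterator.TextIterator
    # can use for balanced sampling.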
42 | label_freq = defaultdict(int) 43 | for i in labels: 44 | label_freq[i] += 1 45 | class_num = len(label_freq.values()) 46 | id2revfreq = {} 47 | dinominator = float(sum(label_freq.values())) 48 | if args.label2id is None: 49 | label2id = dict(list(zip(label_freq.keys(), [int(o)-1 for o in label_freq.keys()]))) 50 | else: 51 | with open(args.label2id, 'rb') as fd: 52 | label2id = pkl.load(fd) 53 | id2label = {idx: label for label, idx in label2id.items()} 54 | for item in id2label: 55 | label = id2label[item] 56 | freq = label_freq[label] 57 | id2revfreq[item] = float(dinominator)/float(freq) 58 | dino = float(sum(id2revfreq.values())) 59 | id2weight = {idx: class_num * revfreq/dino for idx, revfreq in id2revfreq.items()} 60 | 61 | with open(savepath, 'wb') as fd: 62 | pkl.dump(label2id, fd) 63 | pkl.dump(id2label, fd) 64 | pkl.dump(id2revfreq, fd) 65 | pkl.dump(id2weight, fd) 66 | 67 | def loadLabel_dict(savepath): 68 | with open(savepath, 'rb') as fd: 69 | label2id = pkl.load(fd) 70 | id2label = pkl.load(fd) 71 | id2revfreq = pkl.load(fd) 72 | id2weight = pkl.load(fd) 73 | return label2id, id2label, id2revfreq, id2weight 74 | 75 | 76 | def readEmbedding(fileName): 77 | """ 78 | Read Embedding Function 79 | 80 | Args: 81 | fileName : file which stores the embedding 82 | Returns: 83 | embeddings_index : a dictionary contains the mapping from word to vector 84 | """ 85 | embeddings_index = {} 86 | with open(fileName, 'r') as f: 87 | for line in f: 88 | line_uni = line.strip() 89 | values = line_uni.split(' ') 90 | if len(values) != 301: 91 | continue 92 | word = values[0] 93 | w2v_line = ' '.join(values) 94 | embeddings_index[word] = w2v_line 95 | return embeddings_index 96 | 97 | def buildEmbedding(src_embed_file, tgt_embed_file, word_dict): 98 | emb_dict = readEmbedding(src_embed_file) 99 | with open(tgt_embed_file, 'w') as fd: 100 | for word in word_dict: 101 | if word in emb_dict: 102 | fd.writelines(emb_dict[word]+'\n') 103 | return None 104 | 105 | if __name__ == '__main__': 106 | vocab = Vocab() 107 | tok_collect = [] 108 | labels_collect = [] 109 | if args.train_set: 110 | train_label, train_toks = extract(args.train_set) 111 | tok_collect.append(train_toks) 112 | labels_collect.append(train_label) 113 | if args.dev_set: 114 | dev_label, dev_toks = extract(args.dev_set) 115 | tok_collect.append(dev_toks) 116 | labels_collect.append(dev_label) 117 | if args.test_set: 118 | test_label, test_toks = extract(args.test_set) 119 | 120 | vocab.construct(nest.flatten(tok_collect)) 121 | vocab.limit_vocab_length(base_freq=args.base_wd_freq) 122 | vocab.save_vocab(os.path.join(args.dest_dir, 'vocab.pkl')) 123 | 124 | constructLabel_dict(nest.flatten(labels_collect), os.path.join(args.dest_dir, 'label2id.pkl')) 125 | 126 | vocab = Vocab() 127 | vocab.load_vocab_from_file(os.path.join(args.dest_dir, 'vocab.pkl')) 128 | 129 | buildEmbedding(args.ref_emb, os.path.join(args.dest_dir, 'embedding.txt'), vocab.word_to_index) 130 | 131 | label2id, id2label, id2revfreq, id2weight = loadLabel_dict(os.path.join(args.dest_dir, 'label2id.pkl')) 132 | 133 | if args.train_set: 134 | train_label = [label2id[o] for o in train_label] 135 | train_toks = nest.map_structure(lambda x: vocab.encode(x), train_toks) 136 | train_set = [o for o in zip(train_label, train_toks)] 137 | with open(os.path.join(args.dest_dir, 'trainset.pkl'), 'wb') as fd: 138 | pkl.dump(train_set, fd) 139 | if args.dev_set: 140 | dev_label = [label2id[o] for o in dev_label] 141 | dev_toks = nest.map_structure(lambda x: vocab.encode(x), 
dev_toks) 142 | dev_set = [o for o in zip(dev_label, dev_toks)] 143 | with open(os.path.join(args.dest_dir, 'devset.pkl'), 'wb') as fd: 144 | pkl.dump(dev_set, fd) 145 | if args.test_set: 146 | test_label = [label2id[o] for o in test_label] 147 | test_toks = nest.map_structure(lambda x: vocab.encode(x), test_toks) 148 | test_set = [o for o in zip(test_label, test_toks)] 149 | with open(os.path.join(args.dest_dir, 'testset.pkl'), 'wb') as fd: 150 | pkl.dump(test_set, fd) 151 | 152 | 153 | 154 | -------------------------------------------------------------------------------- /caps_attn_flatten/dataprocess/dataprocess_sentence.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import _pickle as pkl 3 | import os, operator 4 | from collections import defaultdict 5 | from tensorflow.python.util import nest 6 | from vocab import Vocab 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser(description="datasets") 10 | 11 | parser.add_argument('--train-set', action='store', dest='train_set', default=None) 12 | parser.add_argument('--dev-set', action='store', dest='dev_set', default=None) 13 | parser.add_argument('--test-set', action='store', dest='test_set', default=None) 14 | parser.add_argument('--ref-embedding', action='store', dest='ref_emb', default='/home/jjgong/data/glove300d/glove.840B.300d.txt') 15 | parser.add_argument('--dest-dir', action='store', dest='dest_dir', default='./') 16 | parser.add_argument('--label2id', action='store', dest='label2id', default=None) 17 | parser.add_argument('--base-wd-freq', type=int, action='store', dest='base_wd_freq', default=3) 18 | parser.add_argument('--labelshift', type=int, action='store', dest='shift', default=1) 19 | 20 | args = parser.parse_args() 21 | 22 | def extract(fn): 23 | label_collect = [] 24 | snt_tok_collect = [] 25 | with open(fn, 'r') as fd: 26 | for line in fd: 27 | item = line.strip().split('\t\t') 28 | try: 29 | label = item[0] 30 | snt = item[1] 31 | except: 32 | print(line) 33 | print(item) 34 | raise ValueError 35 | 36 | snt2wd = snt.strip().split(' ') 37 | label_collect.append(label) 38 | snt_tok_collect.append(snt2wd) 39 | return label_collect, snt_tok_collect 40 | 41 | def constructLabel_dict(labels, savepath): 42 | label_freq = defaultdict(int) 43 | for i in labels: 44 | label_freq[i] += 1 45 | class_num = len(label_freq.values()) 46 | id2revfreq = {} 47 | dinominator = float(sum(label_freq.values())) 48 | if args.label2id is None: 49 | label2id = dict(list(zip(label_freq.keys(), [int(o)-args.shift for o in label_freq.keys()]))) 50 | else: 51 | with open(args.label2id, 'rb') as fd: 52 | label2id = pkl.load(fd) 53 | id2label = {idx: label for label, idx in label2id.items()} 54 | for item in id2label: 55 | label = id2label[item] 56 | freq = label_freq[label] 57 | id2revfreq[item] = float(dinominator)/float(freq) 58 | dino = float(sum(id2revfreq.values())) 59 | id2weight = {idx: class_num * revfreq/dino for idx, revfreq in id2revfreq.items()} 60 | 61 | with open(savepath, 'wb') as fd: 62 | pkl.dump(label2id, fd) 63 | pkl.dump(id2label, fd) 64 | pkl.dump(id2revfreq, fd) 65 | pkl.dump(id2weight, fd) 66 | 67 | def loadLabel_dict(savepath): 68 | with open(savepath, 'rb') as fd: 69 | label2id = pkl.load(fd) 70 | id2label = pkl.load(fd) 71 | id2revfreq = pkl.load(fd) 72 | id2weight = pkl.load(fd) 73 | return label2id, id2label, id2revfreq, id2weight 74 | 75 | 76 | def readEmbedding(fileName): 77 | """ 78 | Read Embedding Function 79 | 80 | Args: 81 | fileName : file 
which stores the embedding 82 | Returns: 83 | embeddings_index : a dictionary contains the mapping from word to vector 84 | """ 85 | embeddings_index = {} 86 | with open(fileName, 'r') as f: 87 | for line in f: 88 | line_uni = line.strip() 89 | values = line_uni.split(' ') 90 | if len(values) != 301: 91 | continue 92 | word = values[0] 93 | w2v_line = ' '.join(values) 94 | embeddings_index[word] = w2v_line 95 | return embeddings_index 96 | 97 | def buildEmbedding(src_embed_file, tgt_embed_file, word_dict): 98 | emb_dict = readEmbedding(src_embed_file) 99 | with open(tgt_embed_file, 'w') as fd: 100 | for word in word_dict: 101 | if word in emb_dict: 102 | fd.writelines(emb_dict[word]+'\n') 103 | return None 104 | 105 | if __name__ == '__main__': 106 | vocab = Vocab() 107 | tok_collect = [] 108 | labels_collect = [] 109 | if args.train_set: 110 | train_label, train_toks = extract(args.train_set) 111 | tok_collect.append(train_toks) 112 | labels_collect.append(train_label) 113 | if args.dev_set: 114 | dev_label, dev_toks = extract(args.dev_set) 115 | tok_collect.append(dev_toks) 116 | labels_collect.append(dev_label) 117 | if args.test_set: 118 | test_label, test_toks = extract(args.test_set) 119 | 120 | vocab.construct(nest.flatten(tok_collect)) 121 | vocab.limit_vocab_length(base_freq=args.base_wd_freq) 122 | vocab.save_vocab(os.path.join(args.dest_dir, 'vocab.pkl')) 123 | 124 | constructLabel_dict(nest.flatten(labels_collect), os.path.join(args.dest_dir, 'label2id.pkl')) 125 | 126 | vocab = Vocab() 127 | vocab.load_vocab_from_file(os.path.join(args.dest_dir, 'vocab.pkl')) 128 | 129 | buildEmbedding(args.ref_emb, os.path.join(args.dest_dir, 'embedding.txt'), vocab.word_to_index) 130 | 131 | label2id, id2label, id2revfreq, id2weight = loadLabel_dict(os.path.join(args.dest_dir, 'label2id.pkl')) 132 | 133 | if args.train_set: 134 | train_label = [label2id[o] for o in train_label] 135 | train_toks = nest.map_structure(lambda x: vocab.encode(x), train_toks) 136 | train_set = [o for o in zip(train_label, train_toks)] 137 | with open(os.path.join(args.dest_dir, 'trainset.pkl'), 'wb') as fd: 138 | pkl.dump(train_set, fd) 139 | if args.dev_set: 140 | dev_label = [label2id[o] for o in dev_label] 141 | dev_toks = nest.map_structure(lambda x: vocab.encode(x), dev_toks) 142 | dev_set = [o for o in zip(dev_label, dev_toks)] 143 | with open(os.path.join(args.dest_dir, 'devset.pkl'), 'wb') as fd: 144 | pkl.dump(dev_set, fd) 145 | if args.test_set: 146 | test_label = [label2id[o] for o in test_label] 147 | test_toks = nest.map_structure(lambda x: vocab.encode(x), test_toks) 148 | test_set = [o for o in zip(test_label, test_toks)] 149 | with open(os.path.join(args.dest_dir, 'testset.pkl'), 'wb') as fd: 150 | pkl.dump(test_set, fd) 151 | 152 | 153 | 154 | -------------------------------------------------------------------------------- /caps_attn_flatten/dataprocess/vocab.py: -------------------------------------------------------------------------------- 1 | import operator 2 | import _pickle as pkl 3 | from collections import defaultdict 4 | 5 | class Vocab(object): 6 | def __init__(self, id_start=5): 7 | self.word_to_index = {} 8 | self.index_to_word = {} 9 | self.word_freq = defaultdict(int) 10 | self.id_start = id_start 11 | 12 | def add_word(self, word, count=1): 13 | word = word.strip() 14 | if len(word) == 0: 15 | return 16 | elif word.isspace(): 17 | return 18 | if word not in self.word_to_index: 19 | index = len(self.word_to_index) 20 | self.word_to_index[word] = index 21 | self.index_to_word[index] 
= word 22 | self.word_freq[word] += count 23 | 24 | def construct(self, words): 25 | for word in words: 26 | self.add_word(word) 27 | total_words = float(sum(self.word_freq.values())) 28 | 29 | '''sort by word frequency''' 30 | new_word_to_index = {} 31 | new_index_to_word = {} 32 | sorted_tup = sorted(self.word_freq.items(), key=operator.itemgetter(1)) 33 | sorted_tup.reverse() 34 | self.word_freq = dict(sorted_tup) 35 | for idx, (word, freq) in enumerate(sorted_tup): 36 | index = self.id_start + idx 37 | new_word_to_index[word] = index 38 | new_index_to_word[index] = word 39 | 40 | self.word_to_index = new_word_to_index 41 | self.index_to_word = new_index_to_word 42 | 43 | print('{} total words with {} uniques'.format(total_words, len(self.word_freq))) 44 | 45 | def limit_vocab_length(self, base_freq): 46 | """ 47 | Truncate vocabulary to keep most frequent words 48 | 49 | Args: 50 | None 51 | 52 | Returns: 53 | None 54 | """ 55 | 56 | new_word_to_index = {} 57 | new_index_to_word = {} 58 | sorted_tup = sorted(self.word_freq.items(), key=operator.itemgetter(1)) 59 | sorted_tup.reverse() 60 | vocab_tup = [item for item in sorted_tup if item[1] > base_freq] 61 | self.word_freq = dict(vocab_tup) 62 | for idx, (word, freq) in enumerate(vocab_tup): 63 | index = self.id_start + idx 64 | new_word_to_index[word] = index 65 | new_index_to_word[index] = word 66 | self.word_to_index = new_word_to_index 67 | self.index_to_word = new_index_to_word 68 | 69 | def save_vocab(self, filePath): 70 | """ 71 | Save vocabulary a offline file 72 | 73 | Args: 74 | filePath: where you want to save your vocabulary, every line in the 75 | file represents a word with a tab seperating word and it's frequency 76 | 77 | Returns: 78 | None 79 | """ 80 | with open(filePath, 'wb') as fd: 81 | pkl.dump(self.word_to_index, fd) 82 | pkl.dump(self.index_to_word, fd) 83 | pkl.dump(self.word_freq, fd) 84 | 85 | def load_vocab_from_file(self, filePath): 86 | """ 87 | Truncate vocabulary to keep most frequent words 88 | 89 | Args: 90 | filePath: vocabulary file path, every line in the file represents 91 | a word with a tab seperating word and it's frequency 92 | 93 | Returns: 94 | None 95 | """ 96 | with open(filePath, 'rb') as fd: 97 | self.word_to_index = pkl.load(fd) 98 | self.index_to_word = pkl.load(fd) 99 | self.word_freq = pkl.load(fd) 100 | 101 | print('load from <' + filePath + '>, there are {} words in dictionary'.format(len(self.word_freq))) 102 | 103 | def encode(self, word): 104 | if word not in self.word_to_index: 105 | return 1 #unk 106 | else: 107 | return self.word_to_index[word] 108 | 109 | def decode(self, index): 110 | if index not in self.index_to_word: 111 | return 'pad/unk' 112 | return self.index_to_word[index] 113 | 114 | def __len__(self): 115 | return len(self.word_to_index) -------------------------------------------------------------------------------- /caps_attn_flatten/model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Sep 21, 2016 3 | 4 | @author: jerrik 5 | ''' 6 | 7 | import os 8 | import sys 9 | import time 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | import utils, nest 14 | from TfUtils import entry_stop_gradients, mkMask, reduce_avg, masked_softmax 15 | from Capsule_masked import Capusule 16 | 17 | class model(object): 18 | """Abstracts a Tensorflow graph for a learning task. 19 | 20 | We use various Model classes as usual abstractions to encapsulate tensorflow 21 | computational graphs. 
Each algorithm you will construct in this homework will 22 | inherit from a Model object. 23 | """ 24 | def __init__(self, config): 25 | """options in this function""" 26 | self.config = config 27 | self.EX_REG_SCOPE = [] 28 | 29 | self.on_epoch = tf.Variable(0, name='epoch_count', trainable=False) 30 | self.on_epoch_accu = tf.assign_add(self.on_epoch, 1) 31 | 32 | self.build() 33 | 34 | def add_placeholders(self): 35 | # shape(b_sz, sNum, wNum) 36 | self.ph_input = tf.placeholder(shape=(None, None), dtype=tf.int32, name='ph_input') 37 | 38 | # shape(bsz) 39 | self.ph_labels = tf.placeholder(shape=(None,), dtype=tf.int32, name='ph_labels') 40 | 41 | # [b_sz] 42 | self.ph_wNum = tf.placeholder(shape=(None,), dtype=tf.int32, name='ph_wNum') 43 | 44 | self.ph_sample_weights = tf.placeholder(shape=(None,), dtype=tf.float32, name='ph_sample_weights') 45 | self.ph_train = tf.placeholder(dtype=tf.bool, name='ph_train') 46 | 47 | def create_feed_dict(self, data_batch, train): 48 | '''data_batch: label_ids, snt1_matrix, snt2_matrix, snt1_len, snt2_len''' 49 | 50 | phs = (self.ph_input, self.ph_labels, self.ph_wNum, self.ph_sample_weights, self.ph_train) 51 | feed_dict = dict(zip(phs, data_batch+(train,))) 52 | return feed_dict 53 | 54 | def add_embedding(self): 55 | """Add embedding layer. that maps from vocabulary to vectors. 56 | inputs: a list of tensors each of which have a size of [batch_size, embed_size] 57 | """ 58 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 59 | vocab_sz = max(self.config.vocab_dict.values()) 60 | with tf.variable_scope('embedding') as scp: 61 | self.exclude_reg_scope(scp) 62 | if self.config.pre_trained: 63 | embed = utils.readEmbedding(self.config.embed_path) 64 | embed_matrix, valid_mask = utils.mkEmbedMatrix(embed, dict(self.config.vocab_dict)) 65 | embedding = tf.Variable(embed_matrix, 'Embedding') 66 | partial_update_embedding = entry_stop_gradients(embedding, tf.expand_dims(valid_mask, 1)) 67 | embedding = tf.cond(self.on_epoch < self.config.partial_update_until_epoch, 68 | lambda: partial_update_embedding, lambda: embedding) 69 | else: 70 | embedding = tf.get_variable( 71 | 'Embedding', 72 | [vocab_sz, self.config.embed_size], trainable=True) 73 | return embedding 74 | 75 | def embed_lookup(self, embedding, batch_x, dropout=None, is_train=False): 76 | ''' 77 | 78 | :param embedding: shape(v_sz, emb_sz) 79 | :param batch_x: shape(b_sz, wNum) 80 | :return: shape(b_sz, wNum, emb_sz) 81 | ''' 82 | inputs = tf.nn.embedding_lookup(embedding, batch_x) 83 | if dropout is not None: 84 | inputs = tf.layers.dropout(inputs, rate=dropout, training=is_train) 85 | return inputs 86 | 87 | def flatten_attention(self, in_x, wNum, scope=None): 88 | ''' 89 | 90 | :param in_x: shape(b_sz, wtstp, emb_sz) 91 | :param sNum: shape(b_sz, ) 92 | :param wNum: shape(b_sz,) 93 | :param scope: 94 | :return: 95 | ''' 96 | b_sz, wtstp, _ = tf.unstack(tf.shape(in_x)) 97 | emb_sz = int(in_x.get_shape()[-1]) 98 | with tf.variable_scope(scope or 'encoding_attention'): 99 | 100 | with tf.variable_scope('snt_enc'): 101 | if self.config.seq_encoder == 'bigru': 102 | birnn_wd = self.biGRU(in_x, wNum, self.config.hidden_size, scope='biGRU') 103 | elif self.config.seq_encoder == 'bilstm': 104 | birnn_wd = self.biLSTM(in_x, wNum, self.config.hidden_size, scope='biLSTM') 105 | else: 106 | raise ValueError('no such encoder %s'%self.config.seq_encoder) 107 | 108 | '''shape(b_sz, dim)''' 109 | if self.config.attn_mode == 'avg': 110 | snt_rep = reduce_avg(birnn_wd, wNum, dim=1) 111 | 
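            # 'attn' scores every timestep against a trained context vector
            # (task_specific_attention); 'rout' and 'Rrout' below aggregate the
            # timesteps into out_caps_num capsules via (reverse) dynamic routing.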
elif self.config.attn_mode == 'attn': 112 | snt_rep = self.task_specific_attention(birnn_wd, wNum, 113 | int(birnn_wd.get_shape()[-1]), 114 | dropout=self.config.dropout, 115 | is_train=self.ph_train, scope='attention') 116 | elif self.config.attn_mode == 'rout': 117 | snt_rep = self.routing_masked(birnn_wd, wNum, 118 | int(birnn_wd.get_shape()[-1]), 119 | self.config.out_caps_num, 120 | iter=self.config.rout_iter, 121 | dropout=self.config.dropout, 122 | is_train=self.ph_train, scope='attention') 123 | elif self.config.attn_mode == 'Rrout': 124 | snt_rep = self.reverse_routing_masked(birnn_wd, wNum, 125 | int(birnn_wd.get_shape()[-1]), 126 | self.config.out_caps_num, 127 | iter=self.config.rout_iter, 128 | dropout=self.config.dropout, 129 | is_train=self.ph_train, scope='attention') 130 | else: 131 | raise ValueError('no such attn mode %s' % self.config.attn_mode) 132 | return snt_rep 133 | 134 | def build(self): 135 | self.add_placeholders() 136 | self.embedding = self.add_embedding() 137 | '''shape(b_sz, ststp, wtstp, emb_sz)''' 138 | in_x = self.embed_lookup(self.embedding, self.ph_input, 139 | dropout=self.config.dropout, is_train=self.ph_train) 140 | snt_reps = self.flatten_attention(in_x, self.ph_wNum, scope='enc_attn') 141 | 142 | with tf.variable_scope('classifier'): 143 | logits = self.Dense(snt_reps, dropout=self.config.dropout, 144 | is_train=self.ph_train, activation=tf.nn.tanh) 145 | opt_loss = self.add_loss_op(logits, self.ph_labels) 146 | train_op = self.add_train_op(opt_loss) 147 | self.train_op = train_op 148 | self.opt_loss = opt_loss 149 | tf.summary.scalar('accuracy', self.accuracy) 150 | tf.summary.scalar('ce_loss', self.ce_loss) 151 | tf.summary.scalar('opt_loss', self.opt_loss) 152 | tf.summary.scalar('w_loss', self.w_loss) 153 | 154 | def Dense(self, inputs, dropout=None, is_train=False, activation=None): 155 | loop_input = inputs 156 | if self.config.dense_hidden[-1] != self.config.class_num: 157 | raise ValueError('last hidden layer should be %d, but get %d' % 158 | (self.config.class_num, 159 | self.config.dense_hidden[-1])) 160 | for i, hid_num in enumerate(self.config.dense_hidden): 161 | with tf.variable_scope('dense-layer-%d' % i): 162 | loop_input = tf.layers.dense(loop_input, units=hid_num) 163 | 164 | if i < len(self.config.dense_hidden) - 1: 165 | if dropout is not None: 166 | loop_input = tf.layers.dropout(loop_input, rate=dropout, training=is_train) 167 | loop_input = activation(loop_input) 168 | 169 | logits = loop_input 170 | return logits 171 | 172 | def add_loss_op(self, logits, labels): 173 | ''' 174 | 175 | :param logits: shape(b_sz, c_num) type(float) 176 | :param labels: shape(b_sz,) type(int) 177 | :return: 178 | ''' 179 | 180 | self.prediction = tf.argmax(logits, axis=-1, output_type=labels.dtype) 181 | 182 | self.accuracy = tf.reduce_mean(tf.cast(tf.equal(self.prediction, labels), tf.float32)) 183 | 184 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels) 185 | ce_loss = tf.reduce_mean(loss) 186 | 187 | exclude_vars = nest.flatten([[v for v in tf.trainable_variables(o.name)] for o in self.EX_REG_SCOPE]) 188 | exclude_vars_2 = [v for v in tf.trainable_variables() if '/bias:' in v.name] 189 | exclude_vars = exclude_vars + exclude_vars_2 190 | 191 | reg_var_list = [v for v in tf.trainable_variables() if v not in exclude_vars] 192 | reg_loss = tf.add_n([tf.nn.l2_loss(v) for v in reg_var_list]) 193 | self.param_cnt = np.sum([np.prod(v.get_shape().as_list()) for v in reg_var_list]) 194 | 195 | print('===' * 20) 196 
| print('total reg parameter count: %.3f M' % (self.param_cnt / 1000000.)) 197 | print('excluded variables from regularization') 198 | print([v.name for v in exclude_vars]) 199 | print('===' * 20) 200 | 201 | print('regularized variables') 202 | print(['%s:%.3fM' % (v.name, np.prod(v.get_shape().as_list()) / 1000000.) for v in reg_var_list]) 203 | print('===' * 20) 204 | '''shape(b_sz,)''' 205 | self.ce_loss = ce_loss 206 | self.w_loss = tf.reduce_mean(tf.multiply(loss, self.ph_sample_weights)) 207 | reg = self.config.reg 208 | 209 | return self.ce_loss + reg * reg_loss 210 | 211 | def add_train_op(self, loss): 212 | 213 | lr = tf.train.exponential_decay(self.config.lr, self.global_step, 214 | self.config.decay_steps, 215 | self.config.decay_rate, staircase=True) 216 | self.learning_rate = tf.maximum(lr, 1e-5) 217 | if self.config.optimizer == 'adam': 218 | optimizer = tf.train.AdamOptimizer(self.learning_rate) 219 | elif self.config.optimizer == 'grad': 220 | optimizer = tf.train.GradientDescentOptimizer(self.learning_rate) 221 | elif self.config.optimizer == 'adgrad': 222 | optimizer = tf.train.AdagradOptimizer(self.learning_rate) 223 | elif self.config.optimizer == 'adadelta': 224 | optimizer = tf.train.AdadeltaOptimizer(self.learning_rate) 225 | else: 226 | raise ValueError('No such Optimizer: %s' % self.config.optimizer) 227 | 228 | gvs = optimizer.compute_gradients(loss=loss) 229 | 230 | capped_gvs = [(tf.clip_by_value(grad, -2., 2.), var) for grad, var in gvs] 231 | train_op = optimizer.apply_gradients(capped_gvs, global_step=self.global_step) 232 | return train_op 233 | 234 | def exclude_reg_scope(self, scope): 235 | if scope not in self.EX_REG_SCOPE: 236 | self.EX_REG_SCOPE.append(scope) 237 | 238 | @staticmethod 239 | def biLSTM(in_x, xLen, h_sz, dropout=None, is_train=False, scope=None): 240 | 241 | with tf.variable_scope(scope or 'biLSTM'): 242 | cell_fwd = tf.nn.rnn_cell.BasicLSTMCell(h_sz) 243 | cell_bwd = tf.nn.rnn_cell.BasicLSTMCell(h_sz) 244 | x_out, _ = tf.nn.bidirectional_dynamic_rnn(cell_fwd, cell_bwd, in_x, xLen, 245 | dtype=tf.float32, swap_memory=True, 246 | scope='birnn') 247 | 248 | x_out = tf.concat(x_out, axis=2) 249 | if dropout is not None: 250 | x_out = tf.layers.dropout(x_out, rate=dropout, training=is_train) 251 | return x_out 252 | 253 | @staticmethod 254 | def biGRU(in_x, xLen, h_sz, dropout=None, is_train=False, scope=None): 255 | 256 | with tf.variable_scope(scope or 'biGRU'): 257 | cell_fwd = tf.nn.rnn_cell.GRUCell(h_sz) 258 | cell_bwd = tf.nn.rnn_cell.GRUCell(h_sz) 259 | x_out, _ = tf.nn.bidirectional_dynamic_rnn(cell_fwd, cell_bwd, in_x, xLen, 260 | dtype=tf.float32, swap_memory=True, 261 | scope='birnn') 262 | 263 | x_out = tf.concat(x_out, axis=2) 264 | if dropout is not None: 265 | x_out = tf.layers.dropout(x_out, rate=dropout, training=is_train) 266 | return x_out 267 | 268 | @staticmethod 269 | def task_specific_attention(in_x, xLen, out_sz, activation_fn=tf.tanh, 270 | dropout=None, is_train=False, scope=None): 271 | ''' 272 | 273 | :param in_x: shape(b_sz, tstp, dim) 274 | :param xLen: shape(b_sz,) 275 | :param out_sz: scalar 276 | :param activation_fn: activation 277 | :param dropout: 278 | :param is_train: 279 | :param scope: 280 | :return: 281 | ''' 282 | 283 | assert len(in_x.get_shape()) == 3 and in_x.get_shape()[-1].value is not None 284 | 285 | with tf.variable_scope(scope or 'attention') as scope: 286 | context_vector = tf.get_variable(name='context_vector', shape=[out_sz], 287 | dtype=tf.float32) 288 | in_x_mlp = tf.layers.dense(in_x, 
out_sz, activation=activation_fn, name='mlp') 289 | 290 | attn = tf.tensordot(in_x_mlp, context_vector, axes=[[2], [0]]) # shape(b_sz, tstp) 291 | attn_normed = masked_softmax(attn, xLen) 292 | 293 | attn_normed = tf.expand_dims(attn_normed, axis=-1) 294 | attn_ctx = tf.matmul(in_x_mlp, attn_normed, transpose_a=True) # shape(b_sz, dim, 1) 295 | attn_ctx = tf.squeeze(attn_ctx, axis=[2]) # shape(b_sz, dim) 296 | if dropout is not None: 297 | attn_ctx = tf.layers.dropout(attn_ctx, rate=dropout, training=is_train) 298 | return attn_ctx 299 | 300 | @staticmethod 301 | def routing_masked(in_x, xLen, out_sz, out_caps_num, iter=3, 302 | dropout=None, is_train=False, scope=None): 303 | ''' 304 | 305 | :param in_x: shape(b_sz, tstp, dim) 306 | :param xLen: shape(b_sz,) 307 | :param out_sz: scalar 308 | :param dropout: 309 | :param is_train: 310 | :param scope: 311 | :return: 312 | ''' 313 | 314 | 315 | assert len(in_x.get_shape()) == 3 and in_x.get_shape()[-1].value is not None 316 | b_sz = tf.shape(in_x)[0] 317 | with tf.variable_scope(scope or 'routing'): 318 | attn_ctx = Capusule(out_caps_num, out_sz, iter)(in_x, xLen) # shape(b_sz, out_caps_num, out_sz) 319 | attn_ctx = tf.reshape(attn_ctx, shape=[b_sz, out_caps_num*out_sz]) 320 | if dropout is not None: 321 | attn_ctx = tf.layers.dropout(attn_ctx, rate=dropout, training=is_train) 322 | return attn_ctx 323 | 324 | @staticmethod 325 | def reverse_routing_masked(in_x, xLen, out_sz, out_caps_num, iter=3, 326 | dropout=None, is_train=False, scope=None): 327 | ''' 328 | 329 | :param in_x: shape(b_sz, tstp, dim) 330 | :param xLen: shape(b_sz,) 331 | :param out_sz: scalar 332 | :param dropout: 333 | :param is_train: 334 | :param scope: 335 | :return: 336 | ''' 337 | 338 | assert len(in_x.get_shape()) == 3 and in_x.get_shape()[-1].value is not None 339 | b_sz = tf.shape(in_x)[0] 340 | with tf.variable_scope(scope or 'routing'): 341 | '''shape(b_sz, out_caps_num, out_sz)''' 342 | attn_ctx = Capusule(out_caps_num, out_sz, iter)(in_x, xLen, reverse_routing=True) 343 | attn_ctx = tf.reshape(attn_ctx, shape=[b_sz, out_caps_num * out_sz]) 344 | if dropout is not None: 345 | attn_ctx = tf.layers.dropout(attn_ctx, rate=dropout, training=is_train) 346 | return attn_ctx -------------------------------------------------------------------------------- /caps_attn_flatten/nest.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """## Functions for working with arbitrarily nested sequences of elements. 17 | 18 | This module can perform operations on nested structures. A nested structure is a 19 | Python sequence, tuple (including `namedtuple`), or dict that can contain 20 | further sequences, tuples, and dicts. 
21 | 22 | The utilities here assume (and do not check) that the nested structures form a 23 | 'tree', i.e., no references in the structure of the input of these functions 24 | should be recursive. 25 | 26 | Example structures: `((3, 4), 5, (6, 7, (9, 10), 8))`, `(np.array(0), 27 | (np.array([3, 4]), tf.constant([3, 4])))` 28 | """ 29 | 30 | 31 | import collections as _collections 32 | 33 | import six as _six 34 | 35 | 36 | def _sorted(dict_): 37 | """Returns a sorted list of the dict keys, with error if keys not sortable.""" 38 | try: 39 | return sorted(_six.iterkeys(dict_)) 40 | except TypeError: 41 | raise TypeError("nest only supports dicts with sortable keys.") 42 | 43 | 44 | def _sequence_like(instance, args): 45 | """Converts the sequence `args` to the same type as `instance`. 46 | 47 | Args: 48 | instance: an instance of `tuple`, `list`, `namedtuple`, `dict`, or 49 | `collections.OrderedDict`. 50 | args: elements to be converted to the `instance` type. 51 | 52 | Returns: 53 | `args` with the type of `instance`. 54 | """ 55 | if isinstance(instance, dict): 56 | # Pack dictionaries in a deterministic order by sorting the keys. 57 | # Notice this means that we ignore the original order of `OrderedDict` 58 | # instances. This is intentional, to avoid potential bugs caused by mixing 59 | # ordered and plain dicts (e.g., flattening a dict but using a 60 | # corresponding `OrderedDict` to pack it back). 61 | result = dict(zip(_sorted(instance), args)) 62 | return type(instance)((key, result[key]) for key in _six.iterkeys(instance)) 63 | elif (isinstance(instance, tuple) and 64 | hasattr(instance, "_fields") and 65 | isinstance(instance._fields, _collections.Sequence) and 66 | all(isinstance(f, _six.string_types) for f in instance._fields)): 67 | # This is a namedtuple 68 | return type(instance)(*args) 69 | else: 70 | # Not a namedtuple 71 | return type(instance)(args) 72 | 73 | 74 | def _yield_value(iterable): 75 | if isinstance(iterable, dict): 76 | # Iterate through dictionaries in a deterministic order by sorting the 77 | # keys. Notice this means that we ignore the original order of `OrderedDict` 78 | # instances. This is intentional, to avoid potential bugs caused by mixing 79 | # ordered and plain dicts (e.g., flattening a dict but using a 80 | # corresponding `OrderedDict` to pack it back). 81 | for key in _sorted(iterable): 82 | yield iterable[key] 83 | else: 84 | for value in iterable: 85 | yield value 86 | 87 | 88 | def _yield_flat_nest(nest): 89 | for n in _yield_value(nest): 90 | if is_sequence(n): 91 | for ni in _yield_flat_nest(n): 92 | yield ni 93 | else: 94 | yield n 95 | 96 | 97 | # Used by `_warn_once` to remember which warning messages have been given. 98 | _ALREADY_WARNED = {} 99 | 100 | 101 | def _warn_once(message): 102 | """Logs a warning message, once per unique string.""" 103 | if message not in _ALREADY_WARNED: 104 | _ALREADY_WARNED[message] = True 105 | 106 | 107 | def is_sequence(seq): 108 | """Returns a true if its input is a collections.Sequence (except strings). 109 | 110 | Args: 111 | seq: an input sequence. 112 | 113 | Returns: 114 | True if the sequence is a not a string and is a collections.Sequence or a 115 | dict. 
116 | """ 117 | if isinstance(seq, dict): 118 | return True 119 | if isinstance(seq, set): 120 | _warn_once("Sets are not currently considered sequences, but this may " 121 | "change in the future, so consider avoiding using them.") 122 | return (isinstance(seq, _collections.Sequence) 123 | and not isinstance(seq, _six.string_types)) 124 | 125 | 126 | def flatten(nest): 127 | """Returns a flat list from a given nested structure. 128 | 129 | If `nest` is not a sequence, tuple, or dict, then returns a single-element 130 | list: `[nest]`. 131 | 132 | In the case of dict instances, the sequence consists of the values, sorted by 133 | key to ensure deterministic behavior. This is true also for `OrderedDict` 134 | instances: their sequence order is ignored, the sorting order of keys is 135 | used instead. The same convention is followed in `pack_sequence_as`. This 136 | correctly repacks dicts and `OrderedDict`s after they have been flattened, 137 | and also allows flattening an `OrderedDict` and then repacking it back using 138 | a correponding plain dict, or vice-versa. 139 | Dictionaries with non-sortable keys cannot be flattened. 140 | 141 | Args: 142 | nest: an arbitrarily nested structure or a scalar object. Note, numpy 143 | arrays are considered scalars. 144 | 145 | Returns: 146 | A Python list, the flattened version of the input. 147 | 148 | Raises: 149 | TypeError: The nest is or contains a dict with non-sortable keys. 150 | """ 151 | if is_sequence(nest): 152 | return list(_yield_flat_nest(nest)) 153 | else: 154 | return [nest] 155 | 156 | 157 | def _recursive_assert_same_structure(nest1, nest2, check_types): 158 | """Helper function for `assert_same_structure`.""" 159 | is_sequence_nest1 = is_sequence(nest1) 160 | if is_sequence_nest1 != is_sequence(nest2): 161 | raise ValueError( 162 | "The two structures don't have the same nested structure.\n\n" 163 | "First structure: %s\n\nSecond structure: %s." % (nest1, nest2)) 164 | 165 | if not is_sequence_nest1: 166 | return # finished checking 167 | 168 | if check_types: 169 | type_nest1 = type(nest1) 170 | type_nest2 = type(nest2) 171 | if type_nest1 != type_nest2: 172 | raise TypeError( 173 | "The two structures don't have the same sequence type. First " 174 | "structure has type %s, while second structure has type %s." 175 | % (type_nest1, type_nest2)) 176 | 177 | if isinstance(nest1, dict): 178 | keys1 = set(_six.iterkeys(nest1)) 179 | keys2 = set(_six.iterkeys(nest2)) 180 | if keys1 != keys2: 181 | raise ValueError( 182 | "The two dictionaries don't have the same set of keys. First " 183 | "structure has keys {}, while second structure has keys {}." 184 | .format(keys1, keys2)) 185 | 186 | nest1_as_sequence = [n for n in _yield_value(nest1)] 187 | nest2_as_sequence = [n for n in _yield_value(nest2)] 188 | for n1, n2 in zip(nest1_as_sequence, nest2_as_sequence): 189 | _recursive_assert_same_structure(n1, n2, check_types) 190 | 191 | 192 | def assert_same_structure(nest1, nest2, check_types=True): 193 | """Asserts that two structures are nested in the same way. 194 | 195 | Args: 196 | nest1: an arbitrarily nested structure. 197 | nest2: an arbitrarily nested structure. 198 | check_types: if `True` (default) types of sequences are checked as 199 | well, including the keys of dictionaries. If set to `False`, for example 200 | a list and a tuple of objects will look the same if they have the same 201 | size. 
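  Example (illustrative; outcomes follow from the checks below):

    assert_same_structure((1, (2, 3)), ("a", ("b", "c")))     # passes
    assert_same_structure((1, 2), [1, 2])                     # TypeError, tuple vs list
    assert_same_structure((1, 2), [1, 2], check_types=False)  # passes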
202 | 203 | Raises: 204 | ValueError: If the two structures do not have the same number of elements or 205 | if the two structures are not nested in the same way. 206 | TypeError: If the two structures differ in the type of sequence in any of 207 | their substructures. Only possible if `check_types` is `True`. 208 | """ 209 | len_nest1 = len(flatten(nest1)) if is_sequence(nest1) else 1 210 | len_nest2 = len(flatten(nest2)) if is_sequence(nest2) else 1 211 | if len_nest1 != len_nest2: 212 | raise ValueError("The two structures don't have the same number of " 213 | "elements.\n\nFirst structure (%i elements): %s\n\n" 214 | "Second structure (%i elements): %s" 215 | % (len_nest1, nest1, len_nest2, nest2)) 216 | _recursive_assert_same_structure(nest1, nest2, check_types) 217 | 218 | 219 | def flatten_dict_items(dictionary): 220 | """Returns a dictionary with flattened keys and values. 221 | 222 | This function flattens the keys and values of a dictionary, which can be 223 | arbitrarily nested structures, and returns the flattened version of such 224 | structures: 225 | 226 | ```python 227 | example_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))} 228 | result = {4: "a", 5: "b", 6: "c", 8: "d"} 229 | flatten_dict_items(example_dictionary) == result 230 | ``` 231 | 232 | The input dictionary must satisfy two properties: 233 | 234 | 1. Its keys and values should have the same exact nested structure. 235 | 2. The set of all flattened keys of the dictionary must not contain repeated 236 | keys. 237 | 238 | Args: 239 | dictionary: the dictionary to zip 240 | 241 | Returns: 242 | The zipped dictionary. 243 | 244 | Raises: 245 | TypeError: If the input is not a dictionary. 246 | ValueError: If any key and value have not the same structure, or if keys are 247 | not unique. 248 | """ 249 | if not isinstance(dictionary, dict): 250 | raise TypeError("input must be a dictionary") 251 | flat_dictionary = {} 252 | for i, v in _six.iteritems(dictionary): 253 | if not is_sequence(i): 254 | if i in flat_dictionary: 255 | raise ValueError( 256 | "Could not flatten dictionary: key %s is not unique." % i) 257 | flat_dictionary[i] = v 258 | else: 259 | flat_i = flatten(i) 260 | flat_v = flatten(v) 261 | if len(flat_i) != len(flat_v): 262 | raise ValueError( 263 | "Could not flatten dictionary. Key had %d elements, but value had " 264 | "%d elements. Key: %s, value: %s." 265 | % (len(flat_i), len(flat_v), flat_i, flat_v)) 266 | for new_i, new_v in zip(flat_i, flat_v): 267 | if new_i in flat_dictionary: 268 | raise ValueError( 269 | "Could not flatten dictionary: key %s is not unique." 270 | % (new_i)) 271 | flat_dictionary[new_i] = new_v 272 | return flat_dictionary 273 | 274 | 275 | def _packed_nest_with_indices(structure, flat, index): 276 | """Helper function for pack_sequence_as. 277 | 278 | Args: 279 | structure: Substructure (list / tuple / dict) to mimic. 280 | flat: Flattened values to output substructure for. 281 | index: Index at which to start reading from flat. 282 | 283 | Returns: 284 | The tuple (new_index, child), where: 285 | * new_index - the updated index into `flat` having processed `structure`. 286 | * packed - the subset of `flat` corresponding to `structure`, 287 | having started at `index`, and packed into the same nested 288 | format. 289 | 290 | Raises: 291 | ValueError: if `structure` contains more elements than `flat` 292 | (assuming indexing starts from `index`). 
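  Example (illustrative): the structure ((3, 4), 5) consumes three leaves of
  `flat` starting at `index`:

    _packed_nest_with_indices(((3, 4), 5), ["a", "b", "c"], 0)
    # -> (3, [("a", "b"), "c"])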
293 | """ 294 | packed = [] 295 | for s in _yield_value(structure): 296 | if is_sequence(s): 297 | new_index, child = _packed_nest_with_indices(s, flat, index) 298 | packed.append(_sequence_like(s, child)) 299 | index = new_index 300 | else: 301 | packed.append(flat[index]) 302 | index += 1 303 | return index, packed 304 | 305 | 306 | def pack_sequence_as(structure, flat_sequence): 307 | """Returns a given flattened sequence packed into a given structure. 308 | 309 | If `structure` is a scalar, `flat_sequence` must be a single-element list; 310 | in this case the return value is `flat_sequence[0]`. 311 | 312 | If `structure` is or contains a dict instance, the keys will be sorted to 313 | pack the flat sequence in deterministic order. This is true also for 314 | `OrderedDict` instances: their sequence order is ignored, the sorting order of 315 | keys is used instead. The same convention is followed in `pack_sequence_as`. 316 | This correctly repacks dicts and `OrderedDict`s after they have been 317 | flattened, and also allows flattening an `OrderedDict` and then repacking it 318 | back using a correponding plain dict, or vice-versa. 319 | Dictionaries with non-sortable keys cannot be flattened. 320 | 321 | Args: 322 | structure: Nested structure, whose structure is given by nested lists, 323 | tuples, and dicts. Note: numpy arrays and strings are considered 324 | scalars. 325 | flat_sequence: flat sequence to pack. 326 | 327 | Returns: 328 | packed: `flat_sequence` converted to have the same recursive structure as 329 | `structure`. 330 | 331 | Raises: 332 | ValueError: If `flat_sequence` and `structure` have different 333 | element counts. 334 | TypeError: `structure` is or contains a dict with non-sortable keys. 335 | """ 336 | if not is_sequence(flat_sequence): 337 | raise TypeError("flat_sequence must be a sequence") 338 | 339 | if not is_sequence(structure): 340 | if len(flat_sequence) != 1: 341 | raise ValueError("Structure is a scalar but len(flat_sequence) == %d > 1" 342 | % len(flat_sequence)) 343 | return flat_sequence[0] 344 | 345 | flat_structure = flatten(structure) 346 | if len(flat_structure) != len(flat_sequence): 347 | raise ValueError( 348 | "Could not pack sequence. Structure had %d elements, but flat_sequence " 349 | "had %d elements. Structure: %s, flat_sequence: %s." 350 | % (len(flat_structure), len(flat_sequence), structure, flat_sequence)) 351 | 352 | _, packed = _packed_nest_with_indices(structure, flat_sequence, 0) 353 | return _sequence_like(structure, packed) 354 | 355 | 356 | def map_structure(func, *structure, **check_types_dict): 357 | """Applies `func` to each entry in `structure` and returns a new structure. 358 | 359 | Applies `func(x[0], x[1], ...)` where x[i] is an entry in 360 | `structure[i]`. All structures in `structure` must have the same arity, 361 | and the return value will contain the results in the same structure. 362 | 363 | Args: 364 | func: A callable that accepts as many arguments as there are structures. 365 | *structure: scalar, or tuple or list of constructed scalars and/or other 366 | tuples/lists, or scalars. Note: numpy arrays are considered as scalars. 367 | **check_types_dict: only valid keyword argument is `check_types`. If set to 368 | `True` (default) the types of iterables within the structures have to be 369 | same (e.g. `map_structure(func, [1], (1,))` raises a `TypeError` 370 | exception). To allow this set this argument to `False`. 
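  Example (illustrative; behaviour follows from flatten/pack_sequence_as below):

    map_structure(lambda x: x * 2, (1, (2, 3)))          # -> (2, (4, 6))
    map_structure(lambda a, b: a + b, [1, 2], [10, 20])  # -> [11, 22]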
371 | 372 | Returns: 373 | A new structure with the same arity as `structure`, whose values correspond 374 | to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding 375 | location in `structure[i]`. If there are different sequence types and 376 | `check_types` is `False` the sequence types of the first structure will be 377 | used. 378 | 379 | Raises: 380 | TypeError: If `func` is not callable or if the structures do not match 381 | each other by depth tree. 382 | ValueError: If no structure is provided or if the structures do not match 383 | each other by type. 384 | ValueError: If wrong keyword arguments are provided. 385 | """ 386 | if not callable(func): 387 | raise TypeError("func must be callable, got: %s" % func) 388 | 389 | if not structure: 390 | raise ValueError("Must provide at least one structure") 391 | 392 | if check_types_dict: 393 | if "check_types" not in check_types_dict or len(check_types_dict) > 1: 394 | raise ValueError("Only valid keyword argument is check_types") 395 | check_types = check_types_dict["check_types"] 396 | else: 397 | check_types = True 398 | 399 | for other in structure[1:]: 400 | assert_same_structure(structure[0], other, check_types=check_types) 401 | 402 | flat_structure = [flatten(s) for s in structure] 403 | entries = zip(*flat_structure) 404 | 405 | return pack_sequence_as( 406 | structure[0], [func(*x) for x in entries]) 407 | 408 | 409 | def _yield_flat_up_to(shallow_tree, input_tree): 410 | """Yields elements `input_tree` partially flattened up to `shallow_tree`.""" 411 | if is_sequence(shallow_tree): 412 | for shallow_branch, input_branch in zip(_yield_value(shallow_tree), 413 | _yield_value(input_tree)): 414 | for input_leaf in _yield_flat_up_to(shallow_branch, input_branch): 415 | yield input_leaf 416 | else: 417 | yield input_tree 418 | 419 | 420 | def assert_shallow_structure(shallow_tree, input_tree, check_types=True): 421 | """Asserts that `shallow_tree` is a shallow structure of `input_tree`. 422 | 423 | That is, this function tests if the `input_tree` structure can be created from 424 | the `shallow_tree` structure by replacing its leaf nodes with deeper 425 | tree structures. 426 | 427 | Examples: 428 | 429 | The following code will raise an exception: 430 | ```python 431 | shallow_tree = ["a", "b"] 432 | input_tree = ["c", ["d", "e"], "f"] 433 | assert_shallow_structure(shallow_tree, input_tree) 434 | ``` 435 | 436 | The following code will not raise an exception: 437 | ```python 438 | shallow_tree = ["a", "b"] 439 | input_tree = ["c", ["d", "e"]] 440 | assert_shallow_structure(shallow_tree, input_tree) 441 | ``` 442 | 443 | Args: 444 | shallow_tree: an arbitrarily nested structure. 445 | input_tree: an arbitrarily nested structure. 446 | check_types: if `True` (default) the sequence types of `shallow_tree` and 447 | `input_tree` have to be the same. 448 | 449 | Raises: 450 | TypeError: If `shallow_tree` is a sequence but `input_tree` is not. 451 | TypeError: If the sequence types of `shallow_tree` are different from 452 | `input_tree`. Only raised if `check_types` is `True`. 453 | ValueError: If the sequence lengths of `shallow_tree` are different from 454 | `input_tree`. 455 | """ 456 | if is_sequence(shallow_tree): 457 | if not is_sequence(input_tree): 458 | raise TypeError( 459 | "If shallow structure is a sequence, input must also be a sequence. " 460 | "Input has type: %s." 
% type(input_tree)) 461 | 462 | if check_types and not isinstance(input_tree, type(shallow_tree)): 463 | raise TypeError( 464 | "The two structures don't have the same sequence type. Input " 465 | "structure has type %s, while shallow structure has type %s." 466 | % (type(input_tree), type(shallow_tree))) 467 | 468 | if len(input_tree) != len(shallow_tree): 469 | raise ValueError( 470 | "The two structures don't have the same sequence length. Input " 471 | "structure has length %s, while shallow structure has length %s." 472 | % (len(input_tree), len(shallow_tree))) 473 | 474 | for shallow_branch, input_branch in zip(shallow_tree, input_tree): 475 | assert_shallow_structure(shallow_branch, input_branch, 476 | check_types=check_types) 477 | 478 | 479 | def flatten_up_to(shallow_tree, input_tree): 480 | """Flattens `input_tree` up to `shallow_tree`. 481 | 482 | Any further depth in structure in `input_tree` is retained as elements in the 483 | partially flatten output. 484 | 485 | If `shallow_tree` and `input_tree` are not sequences, this returns a 486 | single-element list: `[input_tree]`. 487 | 488 | Use Case: 489 | 490 | Sometimes we may wish to partially flatten a nested sequence, retaining some 491 | of the nested structure. We achieve this by specifying a shallow structure, 492 | `shallow_tree`, we wish to flatten up to. 493 | 494 | The input, `input_tree`, can be thought of as having the same structure as 495 | `shallow_tree`, but with leaf nodes that are themselves tree structures. 496 | 497 | Examples: 498 | 499 | ```python 500 | input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] 501 | shallow_tree = [[True, True], [False, True]] 502 | 503 | flattened_input_tree = flatten_up_to(shallow_tree, input_tree) 504 | flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree) 505 | 506 | # Output is: 507 | # [[2, 2], [3, 3], [4, 9], [5, 5]] 508 | # [True, True, False, True] 509 | ``` 510 | 511 | ```python 512 | input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]] 513 | shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]] 514 | 515 | input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree) 516 | input_tree_flattened = flatten(input_tree) 517 | 518 | # Output is: 519 | # [('a', 1), ('b', 2), ('c', 3), ('d', 4)] 520 | # ['a', 1, 'b', 2, 'c', 3, 'd', 4] 521 | ``` 522 | 523 | Non-Sequence Edge Cases: 524 | 525 | ```python 526 | flatten_up_to(0, 0) # Output: [0] 527 | flatten_up_to(0, [0, 1, 2]) # Output: [[0, 1, 2]] 528 | flatten_up_to([0, 1, 2], 0) # Output: TypeError 529 | flatten_up_to([0, 1, 2], [0, 1, 2]) # Output: [0, 1, 2] 530 | ``` 531 | 532 | Args: 533 | shallow_tree: a possibly pruned structure of input_tree. 534 | input_tree: an arbitrarily nested structure or a scalar object. 535 | Note, numpy arrays are considered scalars. 536 | 537 | Returns: 538 | A Python list, the partially flattened version of `input_tree` according to 539 | the structure of `shallow_tree`. 540 | 541 | Raises: 542 | TypeError: If `shallow_tree` is a sequence but `input_tree` is not. 543 | TypeError: If the sequence types of `shallow_tree` are different from 544 | `input_tree`. 545 | ValueError: If the sequence lengths of `shallow_tree` are different from 546 | `input_tree`. 547 | """ 548 | assert_shallow_structure(shallow_tree, input_tree) 549 | return list(_yield_flat_up_to(shallow_tree, input_tree)) 550 | 551 | 552 | def map_structure_up_to(shallow_tree, func, *inputs): 553 | """Applies a function or op to a number of partially flattened inputs. 
554 | 555 | The `inputs` are flattened up to `shallow_tree` before being mapped. 556 | 557 | Use Case: 558 | 559 | Sometimes we wish to apply a function to a partially flattened 560 | sequence (for example when the function itself takes sequence inputs). We 561 | achieve this by specifying a shallow structure, `shallow_tree` we wish to 562 | flatten up to. 563 | 564 | The `inputs`, can be thought of as having the same structure as 565 | `shallow_tree`, but with leaf nodes that are themselves tree structures. 566 | 567 | This function therefore will return something with the same base structure as 568 | `shallow_tree`. 569 | 570 | Examples: 571 | 572 | ```python 573 | ab_tuple = collections.namedtuple("ab_tuple", "a, b") 574 | op_tuple = collections.namedtuple("op_tuple", "add, mul") 575 | inp_val = ab_tuple(a=2, b=3) 576 | inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) 577 | out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul, 578 | inp_val, inp_ops) 579 | 580 | # Output is: ab_tuple(a=6, b=15) 581 | ``` 582 | 583 | ```python 584 | data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] 585 | name_list = ['evens', ['odds', 'primes']] 586 | out = map_structure_up_to( 587 | name_list, 588 | lambda name, sec: "first_{}_{}".format(len(sec), name), 589 | name_list, data_list) 590 | 591 | # Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']] 592 | ``` 593 | 594 | Args: 595 | shallow_tree: a shallow tree, common to all the inputs. 596 | func: callable which will be applied to each input individually. 597 | *inputs: arbitrarily nested combination of objects that are compatible with 598 | shallow_tree. The function `func` is applied to corresponding 599 | partially flattened elements of each input, so the function must support 600 | arity of `len(inputs)`. 601 | 602 | Raises: 603 | TypeError: If `shallow_tree` is a sequence but `input_tree` is not. 604 | TypeError: If the sequence types of `shallow_tree` are different from 605 | `input_tree`. 606 | ValueError: If the sequence lengths of `shallow_tree` are different from 607 | `input_tree`. 608 | 609 | Returns: 610 | result of repeatedly applying `func`, with same structure as 611 | `shallow_tree`. 612 | """ 613 | if not inputs: 614 | raise ValueError("Cannot map over no sequences") 615 | for input_tree in inputs: 616 | assert_shallow_structure(shallow_tree, input_tree) 617 | 618 | # Flatten each input separately, apply the function to corresponding elements, 619 | # then repack based on the structure of the first input. 620 | all_flattened_up_to = [flatten_up_to(shallow_tree, input_tree) 621 | for input_tree in inputs] 622 | results = [func(*tensors) for tensors in zip(*all_flattened_up_to)] 623 | return pack_sequence_as(structure=shallow_tree, flat_sequence=results) 624 | 625 | 626 | def get_traverse_shallow_structure(traverse_fn, structure): 627 | """Generates a shallow structure from a `traverse_fn` and `structure`. 628 | 629 | `traverse_fn` must accept any possible subtree of `structure` and return 630 | a depth=1 structure containing `True` or `False` values, describing which 631 | of the top-level subtrees may be traversed. It may also 632 | return scalar `True` or `False` "traversal is OK / not OK for all subtrees." 633 | 634 | Examples are available in the unit tests (nest_test.py). 
635 | 636 | Args: 637 | traverse_fn: Function taking a substructure and returning either a scalar 638 | `bool` (whether to traverse that substructure or not) or a depth=1 639 | shallow structure of the same type, describing which parts of the 640 | substructure to traverse. 641 | structure: The structure to traverse. 642 | 643 | Returns: 644 | A shallow structure containing python bools, which can be passed to 645 | `map_structure_up_to` and `flatten_up_to`. 646 | 647 | Raises: 648 | TypeError: if `traverse_fn` returns a sequence for a non-sequence input, 649 | or a structure with depth higher than 1 for a sequence input, 650 | or if any leaf values in the returned structure or scalar are not type 651 | `bool`. 652 | """ 653 | to_traverse = traverse_fn(structure) 654 | if not is_sequence(structure): 655 | if not isinstance(to_traverse, bool): 656 | raise TypeError("traverse_fn returned structure: %s for non-structure: %s" 657 | % (to_traverse, structure)) 658 | return to_traverse 659 | level_traverse = [] 660 | if isinstance(to_traverse, bool): 661 | if not to_traverse: 662 | # Do not traverse this substructure at all. Exit early. 663 | return False 664 | else: 665 | # Traverse the entire substructure. 666 | for branch in _yield_value(structure): 667 | level_traverse.append( 668 | get_traverse_shallow_structure(traverse_fn, branch)) 669 | elif not is_sequence(to_traverse): 670 | raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s" 671 | % (to_traverse, structure)) 672 | else: 673 | # Traverse some subset of this substructure. 674 | assert_shallow_structure(to_traverse, structure) 675 | for t, branch in zip(_yield_value(to_traverse), _yield_value(structure)): 676 | if not isinstance(t, bool): 677 | raise TypeError( 678 | "traverse_fn didn't return a depth=1 structure of bools. 
saw: %s " 679 | " for structure: %s" % (to_traverse, structure)) 680 | if t: 681 | level_traverse.append( 682 | get_traverse_shallow_structure(traverse_fn, branch)) 683 | else: 684 | level_traverse.append(False) 685 | return _sequence_like(structure, level_traverse) 686 | 687 | -------------------------------------------------------------------------------- /caps_attn_flatten/train_test.py: -------------------------------------------------------------------------------- 1 | import argparse, sys, os, time, logging, threading, traceback 2 | import numpy as np 3 | import tensorflow as tf 4 | import _pickle as pkl 5 | import sys 6 | from multiprocessing import Queue, Process 7 | 8 | from Config import Config 9 | from model import model 10 | from data_iterator import TextIterator, preparedata 11 | from dataprocess.vocab import Vocab 12 | import utils 13 | 14 | _REVISION = 'flatten' 15 | 16 | parser = argparse.ArgumentParser(description="training options") 17 | 18 | parser.add_argument('--load-config', action='store_true', dest='load_config', default=False) 19 | parser.add_argument('--gpu-num', action='store', dest='gpu_num', default=0, type=int) 20 | parser.add_argument('--train-test', action='store', dest='train_test', default='train', choices=['train', 'test']) 21 | parser.add_argument('--weight-path', action='store', dest='weight_path', required=True) 22 | parser.add_argument('--restore-ckpt', action='store_true', dest='restore_ckpt', default=False) 23 | parser.add_argument('--retain-gpu', action='store_true', dest='retain_gpu', default=False) 24 | 25 | parser.add_argument('--debug-enable', action='store_true', dest='debug_enable', default=False) 26 | 27 | args = parser.parse_args() 28 | 29 | DEBUG = args.debug_enable 30 | if not DEBUG: 31 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 32 | 33 | def debug(s): 34 | if DEBUG: 35 | print(s) 36 | pass 37 | 38 | class Train: 39 | 40 | def __init__(self, args): 41 | if utils.valid_entry(args.weight_path) and not args.restore_ckpt\ 42 | and args.train_test != 'test': 43 | raise ValueError('process running or finished') 44 | 45 | gpu_lock = threading.Lock() 46 | gpu_lock.acquire() 47 | def retain_gpu(): 48 | if args.retain_gpu: 49 | with tf.Session(): 50 | gpu_lock.acquire() 51 | else: 52 | pass 53 | 54 | lockThread = threading.Thread(target=retain_gpu) 55 | lockThread.start() 56 | try: 57 | self.args = args 58 | config = Config() 59 | 60 | self.args = args 61 | self.weight_path = args.weight_path 62 | 63 | if args.load_config == False: 64 | config.saveConfig(self.weight_path + '/config') 65 | print('default configuration generated, please specify --load-config and run again.') 66 | gpu_lock.release() 67 | lockThread.join() 68 | sys.exit() 69 | else: 70 | if os.path.exists(self.weight_path + '/config'): 71 | config.loadConfig(self.weight_path + '/config') 72 | else: 73 | raise ValueError('No config file in %s' % self.weight_path) 74 | 75 | if config.revision != _REVISION: 76 | raise ValueError('revision dont match: %s over %s' % (config.revision, _REVISION)) 77 | 78 | vocab = Vocab() 79 | vocab.load_vocab_from_file(os.path.join(config.datapath, 'vocab.pkl')) 80 | config.vocab_dict = vocab.word_to_index 81 | with open(os.path.join(config.datapath, 'label2id.pkl'), 'rb') as fd: 82 | _ = pkl.load(fd) 83 | config.id2label = pkl.load(fd) 84 | _ = pkl.load(fd) 85 | config.id2weight = pkl.load(fd) 86 | 87 | config.class_num = len(config.id2label) 88 | self.config = config 89 | 90 | self.train_data = TextIterator(os.path.join(config.datapath, 'trainset.pkl'), 
self.config.batch_sz, 91 | bucket_sz=self.config.bucket_sz, shuffle=True) 92 | config.n_samples = self.train_data.num_example 93 | self.dev_data = TextIterator(os.path.join(config.datapath, 'devset.pkl'), self.config.batch_sz, 94 | bucket_sz=self.config.bucket_sz, shuffle=False) 95 | 96 | self.test_data = TextIterator(os.path.join(config.datapath, 'testset.pkl'), self.config.batch_sz, 97 | bucket_sz=self.config.bucket_sz, shuffle=False) 98 | 99 | self.data_q = Queue(10) 100 | 101 | self.model = model(config) 102 | 103 | except Exception as e: 104 | traceback.print_exc() 105 | gpu_lock.release() 106 | lockThread.join() 107 | exit() 108 | 109 | gpu_lock.release() 110 | lockThread.join() 111 | if utils.valid_entry(args.weight_path) and not args.restore_ckpt\ 112 | and args.train_test != 'test': 113 | raise ValueError('process running or finished') 114 | 115 | def get_epoch(self, sess): 116 | epoch = sess.run(self.model.on_epoch) 117 | return epoch 118 | 119 | def run_epoch(self, sess, input_data: TextIterator, verbose=10): 120 | """Runs an epoch of training. 121 | 122 | Trains the model for one-epoch. 123 | 124 | Args: 125 | sess: tf.Session() object 126 | Returns: 127 | average_loss: scalar. Average minibatch loss of model on epoch. 128 | """ 129 | total_steps = input_data.num_example // input_data.batch_sz 130 | total_loss = [] 131 | total_w_loss = [] 132 | total_ce_loss = [] 133 | collect_time = [] 134 | collect_data_time = [] 135 | accuracy_collect = [] 136 | step = -1 137 | dataset = [o for o in input_data] 138 | producer = Process(target=preparedata, 139 | args=(dataset, self.data_q, self.config.max_wd_num, self.config.id2weight)) 140 | producer.start() 141 | try: 142 | while True: 143 | step += 1 144 | start_stamp = time.time() 145 | data_batch = self.data_q.get() 146 | if data_batch is None: 147 | break 148 | feed_dict = self.model.create_feed_dict(data_batch=data_batch, train=True) 149 | 150 | data_stamp = time.time() 151 | (accuracy, global_step, summary, opt_loss, w_loss, ce_loss, lr, _ 152 | ) = sess.run([self.model.accuracy, self.model.global_step, self.merged, 153 | self.model.opt_loss, self.model.w_loss, self.model.ce_loss, 154 | self.model.learning_rate, self.model.train_op], 155 | feed_dict=feed_dict) 156 | self.train_writer.add_summary(summary, global_step) 157 | self.train_writer.flush() 158 | end_stamp = time.time() 159 | 160 | collect_time.append(end_stamp-start_stamp) 161 | collect_data_time.append(data_stamp-start_stamp) 162 | accuracy_collect.append(accuracy) 163 | total_loss.append(opt_loss) 164 | total_w_loss.append(w_loss) 165 | total_ce_loss.append(ce_loss) 166 | 167 | if verbose and step % verbose == 0: 168 | sys.stdout.write('\r%d / %d : opt_loss = %.4f, w_loss = %.4f, ce_loss = %.4f, %.3fs/iter, %.3fs/batch' 169 | 'lr = %f, accu = %.4f, b_sz = %d' % ( 170 | step, total_steps, np.mean(total_loss[-verbose:]),np.mean(total_w_loss[-verbose:]), 171 | np.mean(total_ce_loss[-verbose:]), np.mean(collect_time), np.mean(collect_data_time), lr, 172 | np.mean(accuracy_collect[-verbose:]), input_data.batch_sz)) 173 | collect_time = [] 174 | sys.stdout.flush() 175 | utils.write_status(self.weight_path) 176 | except: 177 | traceback.print_exc() 178 | producer.terminate() 179 | exit() 180 | 181 | producer.join() 182 | 183 | sess.run(self.model.on_epoch_accu) 184 | 185 | return np.mean(total_ce_loss), np.mean(total_loss), np.mean(accuracy_collect) 186 | 187 | def fit(self, sess, input_data :TextIterator, verbose=10): 188 | """ 189 | Fit the model. 
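        Despite the name, no parameters are updated here: every batch is fed
        with train=False and only the summaries, cross-entropy loss and total
        loss are collected, so this is an evaluation pass over `input_data`.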
190 | 191 | Args: 192 | sess: tf.Session() object 193 | Returns: 194 | average_loss: scalar. Average minibatch loss of model on epoch. 195 | """ 196 | 197 | total_steps = input_data.num_example // input_data.batch_sz 198 | total_loss = [] 199 | total_ce_loss = [] 200 | collect_time = [] 201 | step = -1 202 | dataset = [o for o in input_data] 203 | producer = Process(target=preparedata, 204 | args=(dataset, self.data_q, self.config.max_wd_num, self.config.id2weight)) 205 | producer.start() 206 | try: 207 | while True: 208 | step += 1 209 | data_batch = self.data_q.get() 210 | if data_batch is None: 211 | break 212 | feed_dict = self.model.create_feed_dict(data_batch=data_batch, train=False) 213 | 214 | start_stamp = time.time() 215 | (global_step, summary, ce_loss, opt_loss, 216 | ) = sess.run([self.model.global_step, self.merged, self.model.ce_loss, 217 | self.model.opt_loss], feed_dict=feed_dict) 218 | 219 | self.test_writer.add_summary(summary, step+global_step) 220 | self.test_writer.flush() 221 | 222 | end_stamp = time.time() 223 | collect_time.append(end_stamp - start_stamp) 224 | total_ce_loss.append(ce_loss) 225 | total_loss.append(opt_loss) 226 | 227 | if verbose and step % verbose == 0: 228 | sys.stdout.write('\r%d / %d: ce_loss = %f, opt_loss = %f, %.3fs/iter' % ( 229 | step, total_steps, np.mean(total_ce_loss[-verbose:]), 230 | np.mean(total_loss[-verbose:]), np.mean(collect_time))) 231 | collect_time = [] 232 | sys.stdout.flush() 233 | print('\n') 234 | except: 235 | traceback.print_exc() 236 | producer.terminate() 237 | exit() 238 | producer.join() 239 | return np.mean(total_ce_loss), np.mean(total_loss) 240 | 241 | def predict(self, sess, input_data: TextIterator, verbose=10): 242 | """ 243 | Args: 244 | sess: tf.Session() object 245 | Returns: 246 | average_loss: scalar. Average minibatch loss of model on epoch. 
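        Note: despite the "average_loss" entry above, no loss is computed here.
        The method returns (res_pred, label_id): the concatenated outputs of
        self.model.prediction over all batches, plus the gold label ids, which
        test_case feeds into the confusion/accuracy helpers in utils.py.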
247 | """ 248 | total_steps = input_data.num_example // input_data.batch_sz 249 | collect_time = [] 250 | collect_pred = [] 251 | label_id = [] 252 | step = -1 253 | dataset = [o for o in input_data] 254 | producer = Process(target=preparedata, 255 | args=(dataset, self.data_q, self.config.max_wd_num, self.config.id2weight)) 256 | producer.start() 257 | try: 258 | while True: 259 | step += 1 260 | data_batch = self.data_q.get() 261 | if data_batch is None: 262 | break 263 | feed_dict = self.model.create_feed_dict(data_batch=data_batch, train=False) 264 | 265 | start_stamp = time.time() 266 | pred = sess.run(self.model.prediction, feed_dict=feed_dict) 267 | end_stamp = time.time() 268 | collect_time.append(end_stamp - start_stamp) 269 | 270 | collect_pred.append(pred) 271 | label_id += data_batch[1].tolist() 272 | if verbose and step % verbose == 0: 273 | sys.stdout.write('\r%d / %d: , %.3fs/iter' % ( 274 | step, total_steps, np.mean(collect_time))) 275 | collect_time = [] 276 | sys.stdout.flush() 277 | print('\n') 278 | except: 279 | traceback.print_exc() 280 | producer.terminate() 281 | exit() 282 | producer.join() 283 | res_pred = np.concatenate(collect_pred, axis=0) 284 | return res_pred, label_id 285 | 286 | def test_case(self, sess, data, onset='VALIDATION'): 287 | print('#' * 20, 'ON ' + onset + ' SET START ', '#' * 20) 288 | print("=" * 10 + ' '.join(sys.argv) + "=" * 10) 289 | epoch = self.get_epoch(sess) 290 | ce_loss, opt_loss = self.fit(sess, data) 291 | pred, label = self.predict(sess, data) 292 | 293 | (prec, recall, overall_prec, overall_recall, _ 294 | ) = utils.calculate_confusion_single(pred, label, len(self.config.id2label)) 295 | 296 | utils.print_confusion_single(prec, recall, overall_prec, overall_recall, self.config.id2label) 297 | accuracy = utils.calculate_accuracy_single(pred, label) 298 | 299 | print('%d th Epoch -- Overall %s accuracy is: %f' % (epoch, onset, accuracy)) 300 | logging.info('%d th Epoch -- Overall %s accuracy is: %f' % (epoch, onset, accuracy)) 301 | 302 | print('%d th Epoch -- Overall %s ce_loss is: %f, opt_loss is: %f' % (epoch, onset, ce_loss, opt_loss)) 303 | logging.info('%d th Epoch -- Overall %s ce_loss is: %f, opt_loss is: %f' % (epoch, onset, ce_loss, opt_loss)) 304 | print('#' * 20, 'ON ' + onset + ' SET END ', '#' * 20) 305 | return accuracy, ce_loss 306 | 307 | def train_run(self): 308 | logging.info('Training start') 309 | logging.info("Parameter count is: %d" % self.model.param_cnt) 310 | if not args.restore_ckpt: 311 | self.remove_file(self.args.weight_path + '/summary.log') 312 | saver = tf.train.Saver(max_to_keep=30) 313 | 314 | config = tf.ConfigProto() 315 | config.gpu_options.allow_growth = True 316 | config.allow_soft_placement = True 317 | with tf.Session(config=config) as sess: 318 | 319 | self.merged = tf.summary.merge_all() 320 | self.train_writer = tf.summary.FileWriter(self.args.weight_path + '/summary_train', 321 | sess.graph) 322 | self.test_writer = tf.summary.FileWriter(self.args.weight_path + '/summary_test') 323 | 324 | sess.run(tf.global_variables_initializer()) 325 | if args.restore_ckpt: 326 | saver.restore(sess, self.args.weight_path + '/classifier.weights') 327 | best_loss = np.Inf 328 | best_accuracy = 0 329 | best_val_epoch = self.get_epoch(sess) 330 | 331 | for _ in range(self.config.max_epochs): 332 | 333 | epoch = self.get_epoch(sess) 334 | print("=" * 20 + "Epoch ", epoch, "=" * 20) 335 | ce_loss, opt_loss, accuracy = self.run_epoch(sess, self.train_data, verbose=10) 336 | print('') 337 | print("Mean 
ce_loss in %dth epoch is: %f, Mean ce_loss is: %f,"%(epoch, ce_loss, opt_loss)) 338 | print('Mean training accuracy is : %.4f' % accuracy) 339 | logging.info('Mean training accuracy is : %.4f' % accuracy) 340 | logging.info("Mean ce_loss in %dth epoch is: %f, Mean ce_loss is: %f,"%(epoch, ce_loss, opt_loss)) 341 | print('=' * 50) 342 | val_accuracy, val_loss = self.test_case(sess, self.dev_data, onset='VALIDATION') 343 | test_accuracy, test_loss = self.test_case(sess, self.test_data, onset='TEST') 344 | self.save_loss_accu(self.args.weight_path + '/summary.log', train_loss=ce_loss, 345 | valid_loss=val_loss, test_loss=test_loss, 346 | valid_accu=val_accuracy, test_accu=test_accuracy, epoch=epoch) 347 | if best_accuracy < val_accuracy: 348 | best_accuracy = val_accuracy 349 | best_val_epoch = epoch 350 | if not os.path.exists(self.args.weight_path): 351 | os.makedirs(self.args.weight_path) 352 | logging.info('best epoch is %dth epoch' % best_val_epoch) 353 | saver.save(sess, self.args.weight_path + '/classifier.weights') 354 | else: 355 | b_sz = self.train_data.batch_sz//2 356 | max_b_sz = max([b_sz, self.config.batch_sz_min]) 357 | buck_sz = self.train_data.bucket_sz * 2 358 | buck_sz = min([self.train_data.num_example, buck_sz]) 359 | self.train_data.batch_sz = max_b_sz 360 | self.train_data.bucket_sz = buck_sz 361 | 362 | if epoch - best_val_epoch > self.config.early_stopping: 363 | logging.info("Normal Early stop") 364 | break 365 | utils.write_status(self.weight_path, finished=True) 366 | logging.info("Training complete") 367 | 368 | def test_run(self): 369 | 370 | saver = tf.train.Saver(max_to_keep=30) 371 | 372 | config = tf.ConfigProto() 373 | config.gpu_options.allow_growth = True 374 | config.allow_soft_placement = True 375 | with tf.Session(config=config) as sess: 376 | self.merged = tf.summary.merge_all() 377 | self.test_writer = tf.summary.FileWriter(self.args.weight_path + '/summary_test') 378 | 379 | sess.run(tf.global_variables_initializer()) 380 | saver.restore(sess, self.args.weight_path + '/classifier.weights') 381 | 382 | self.test_case(sess, self.test_data, onset='TEST') 383 | 384 | def main_run(self): 385 | 386 | if not os.path.exists(self.args.weight_path): 387 | os.makedirs(self.args.weight_path) 388 | logFile = self.args.weight_path + '/run.log' 389 | 390 | if self.args.train_test == "train": 391 | 392 | try: 393 | os.remove(logFile) 394 | except OSError: 395 | pass 396 | logging.basicConfig(filename=logFile, format='%(levelname)s %(asctime)s %(message)s', level=logging.INFO) 397 | debug('_main_run_') 398 | self.train_run() 399 | self.test_run() 400 | else: 401 | logging.basicConfig(filename=logFile, format='%(levelname)s %(asctime)s %(message)s', level=logging.INFO) 402 | self.test_run() 403 | 404 | @staticmethod 405 | def save_loss_accu(fileName, train_loss, valid_loss, 406 | test_loss, valid_accu, test_accu, epoch): 407 | with open(fileName, 'a') as fd: 408 | fd.write('%3d\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\n' % 409 | (epoch, train_loss, valid_loss, 410 | test_loss, valid_accu, test_accu)) 411 | 412 | @staticmethod 413 | def remove_file(fileName): 414 | if os.path.exists(fileName): 415 | os.remove(fileName) 416 | 417 | if __name__ == '__main__': 418 | trainer = Train(args) 419 | trainer.main_run() 420 | 421 | -------------------------------------------------------------------------------- /caps_attn_flatten/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import _pickle as pkl 3 | import pdb 4 | import numpy as np 5 | 
import copy 6 | 7 | import os 8 | import warnings 9 | import sys 10 | from time import time 11 | import pprint 12 | import logging 13 | from collections import OrderedDict 14 | 15 | '''check alive''' 16 | def write_status(path, finished=False): 17 | full_path = path+'/status' 18 | if not finished: 19 | fd = open(full_path, 'w') 20 | fd.write(str(time())) 21 | fd.flush() 22 | fd.close() 23 | else: 24 | fd = open(full_path, 'w') 25 | fd.write('0.1') 26 | fd.flush() 27 | fd.close() 28 | 29 | def read_status(status_path): 30 | if not os.path.exists(status_path): 31 | return 'error' 32 | fd = open(status_path, 'r') 33 | time_stamp = float(fd.read().strip()) 34 | fd.close() 35 | if time_stamp < 10.: 36 | return 'finished' 37 | cur_time = time() 38 | if cur_time - time_stamp < 1000.: 39 | return 'running' 40 | else: 41 | return 'error' 42 | 43 | def valid_entry(save_path): 44 | 45 | if not os.path.exists(save_path): 46 | return False 47 | if read_status(save_path + '/status') == 'running': 48 | return True 49 | if read_status(save_path + '/status') == 'finished': 50 | return True 51 | if read_status(save_path + '/status') == 'error': 52 | return False 53 | 54 | raise ValueError('unknown error') 55 | 56 | def pad(x, len_x): 57 | if len(x) > len_x: 58 | return x[:len_x] 59 | return x+[0]* (len_x-len(x)) 60 | # batch preparation 61 | def prepare_data(seqs_x, seqs_char_x, seqs_pos_x, seqs_em_x, 62 | seqs_y, seqs_char_y, seqs_pos_y, seqs_em_y, 63 | labels, max_char_len): 64 | 65 | lengths_x = [len(s) for s in seqs_x] 66 | lengths_char_x = [len(s) for s in seqs_char_x] 67 | lengths_pos_x = [len(s) for s in seqs_pos_x] 68 | lengths_em_x = [len(s) for s in seqs_em_x] 69 | 70 | lengths_y = [len(s) for s in seqs_y] 71 | lengths_char_y = [len(s) for s in seqs_char_y] 72 | lengths_pos_y = [len(s) for s in seqs_pos_y] 73 | lengths_em_y = [len(s) for s in seqs_em_y] 74 | 75 | assert np.all(np.equal(lengths_x, lengths_pos_x)) 76 | assert np.all(np.equal(lengths_x, lengths_char_x)) 77 | assert np.all(np.equal(lengths_x, lengths_em_x)) 78 | 79 | assert np.all(np.equal(lengths_y, lengths_pos_y)) 80 | assert np.all(np.equal(lengths_y, lengths_char_y)) 81 | assert np.all(np.equal(lengths_y, lengths_em_y)) 82 | 83 | n_samples = len(seqs_x) 84 | maxlen_x = np.max(lengths_x) 85 | maxlen_y = np.max(lengths_y) 86 | 87 | seqs_char_x = [[pad(w_lst, max_char_len) for w_lst in snt] for snt in seqs_char_x] 88 | seqs_char_y = [[pad(w_lst, max_char_len) for w_lst in snt] for snt in seqs_char_y] 89 | 90 | x = np.zeros((n_samples, maxlen_x)).astype('int32') 91 | x_pos = np.zeros((n_samples, maxlen_x)).astype('int32') 92 | x_em = np.zeros((n_samples, maxlen_x)).astype('int32') 93 | x_char = np.zeros((n_samples, maxlen_x, max_char_len)).astype('int32') 94 | 95 | y = np.zeros((n_samples, maxlen_y)).astype('int32') 96 | y_pos = np.zeros((n_samples, maxlen_y)).astype('int32') 97 | y_em = np.zeros((n_samples, maxlen_y)).astype('int32') 98 | y_char = np.zeros((n_samples, maxlen_y, max_char_len)).astype('int32') 99 | 100 | l = np.zeros((n_samples,)).astype('int32') 101 | for idx, [s_x, s_char_x, s_pos_x, s_em_x, s_y, s_char_y, s_pos_y, s_em_y, ll] in enumerate(zip( 102 | seqs_x, seqs_char_x, seqs_pos_x, seqs_em_x, 103 | seqs_y, seqs_char_y, seqs_pos_y, seqs_em_y, labels)): 104 | 105 | x[idx, :lengths_x[idx]] = s_x 106 | x_char[idx, :lengths_x[idx]] = s_char_x 107 | x_pos[idx, :lengths_x[idx]] = s_pos_x 108 | x_em[idx, :lengths_x[idx]] = s_em_x 109 | 110 | y[idx, :lengths_y[idx]] = s_y 111 | y_char[idx, :lengths_y[idx]] = s_char_y 112 | 
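        # remaining target-side features; positions beyond lengths_y[idx]
        # keep their zero padding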
y_pos[idx, :lengths_y[idx]] = s_pos_y 113 | y_em[idx, :lengths_y[idx]] = s_em_y 114 | 115 | l[idx] = ll 116 | 117 | return x, x_char, x_pos, x_em, lengths_x, y, y_char, y_pos, y_em, lengths_y, l 118 | 119 | '''===============================================================''' 120 | 121 | '''Read and make embedding''' 122 | 123 | def readEmbedding(fileName): 124 | """ 125 | Read Embedding Function 126 | 127 | Args: 128 | fileName : file which stores the embedding 129 | Returns: 130 | embeddings_index : a dictionary contains the mapping from word to vector 131 | """ 132 | embeddings_index = {} 133 | with open(fileName, 'rb') as f: 134 | for line in f: 135 | line_uni = line.strip() 136 | line_uni = line_uni.decode('utf-8') 137 | values = line_uni.split(' ') 138 | word = values[0] 139 | try: 140 | coefs = np.asarray(values[1:], dtype='float32') 141 | except: 142 | print(values, len(values)) 143 | embeddings_index[word] = coefs 144 | return embeddings_index 145 | 146 | def mkEmbedMatrix(embed_dic, vocab_dic): 147 | """ 148 | Construct embedding matrix 149 | 150 | Args: 151 | embed_dic : word-embedding dictionary 152 | vocab_dic : word-index dictionary 153 | Returns: 154 | embedding_matrix: return embedding matrix 155 | """ 156 | if type(embed_dic) is not dict or type(vocab_dic) is not dict: 157 | raise TypeError('Inputs are not dictionary') 158 | if len(embed_dic) < 1 or len(vocab_dic) < 1: 159 | raise ValueError('Input dimension less than 1') 160 | vocab_sz = max(vocab_dic.values()) + 1 161 | EMBEDDING_DIM = len(list(embed_dic.values())[0]) 162 | # embedding_matrix = np.zeros((len(vocab_dic), EMBEDDING_DIM), dtype=np.float32) 163 | embedding_matrix = np.random.rand(vocab_sz, EMBEDDING_DIM).astype(np.float32) * 0.05 164 | valid_mask = np.ones(vocab_sz, dtype=np.bool) 165 | for word, i in vocab_dic.items(): 166 | embedding_vector = embed_dic.get(word) 167 | if embedding_vector is not None: 168 | # words not found in embedding index will be all-zeros. 169 | embedding_matrix[i] = embedding_vector 170 | else: 171 | valid_mask[i] = False 172 | return embedding_matrix, valid_mask 173 | 174 | '''evaluation''' 175 | 176 | def pred_from_prob_single(prob_matrix): 177 | """ 178 | 179 | Args: 180 | prob_matrix: probability matrix have the shape of (data_num, class_num), 181 | type of float. Generated from softmax activation 182 | 183 | Returns: 184 | ret: return class ids, shape of(data_num,) 185 | """ 186 | ret = np.argmax(prob_matrix, axis=1) 187 | return ret 188 | 189 | 190 | def calculate_accuracy_single(pred_ids, label_ids): 191 | """ 192 | Args: 193 | pred_ids: prediction id list shape of (data_num, ), type of int 194 | label_ids: true label id list, same shape and type as pred_ids 195 | 196 | Returns: 197 | accuracy: accuracy of the prediction, type float 198 | """ 199 | if np.ndim(pred_ids) != 1 or np.ndim(label_ids) != 1: 200 | raise TypeError('require rank 1, 1. 
get {}, {}'.format(np.rank(pred_ids), np.rank(label_ids))) 201 | if len(pred_ids) != len(label_ids): 202 | raise TypeError('first argument and second argument have different length') 203 | 204 | accuracy = np.mean(np.equal(pred_ids, label_ids)) 205 | return accuracy 206 | 207 | 208 | def calculate_confusion_single(pred_list, label_list, label_size): 209 | """Helper method that calculates confusion matrix.""" 210 | confusion = np.zeros((label_size, label_size), dtype=np.int32) 211 | for i in range(len(label_list)): 212 | confusion[label_list[i], pred_list[i]] += 1 213 | 214 | tp_fp = np.sum(confusion, axis=0) 215 | tp_fn = np.sum(confusion, axis=1) 216 | tp = np.array([confusion[i, i] for i in range(len(confusion))]) 217 | 218 | precision = tp.astype(np.float32)/(tp_fp+1e-40) 219 | recall = tp.astype(np.float32)/(tp_fn+1e-40) 220 | overall_prec = np.float(np.sum(tp))/(np.sum(tp_fp)+1e-40) 221 | overall_recall = np.float(np.sum(tp))/(np.sum(tp_fn)+1e-40) 222 | 223 | return precision, recall, overall_prec, overall_recall, confusion 224 | 225 | 226 | def print_confusion_single(prec, recall, overall_prec, overall_recall, num_to_tag): 227 | """Helper method that prints confusion matrix.""" 228 | logstr="\n" 229 | logstr += '{:15}\t{:7}\t{:7}\n'.format('TAG', 'Prec', 'Recall') 230 | for i, tag in sorted(num_to_tag.items()): 231 | logstr += '{:15}\t{:2.4f}\t{:2.4f}\n'.format(tag, prec[i], recall[i]) 232 | logstr += '{:15}\t{:2.4f}\t{:2.4f}\n'.format('OVERALL', overall_prec, overall_recall) 233 | logging.info(logstr) 234 | print(logstr) 235 | 236 | 237 | def save_objs(obj, path): 238 | with open(path, 'wb') as fd: 239 | pkl.dump(obj, fd) 240 | 241 | 242 | def load_objs(path): 243 | with open(path, 'rb') as fd: 244 | ret = pkl.load(fd) 245 | return ret -------------------------------------------------------------------------------- /caps_attn_hierarchical/Capsule_masked.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.layers import base as base_layer 2 | import tensorflow as tf 3 | import numpy as np 4 | from TfUtils import mkMask 5 | 6 | _EPSILON = 1e-9 7 | _MIN_NUM = -np.Inf 8 | 9 | class Capusule(base_layer.Layer): 10 | def __init__(self, out_caps_num, out_caps_dim, iter_num=3, reuse=None): 11 | super(Capusule, self).__init__(_reuse=reuse) 12 | self.out_caps_num = out_caps_num 13 | self.out_caps_dim = out_caps_dim 14 | self.iter_num = iter_num 15 | 16 | 17 | def call(self, in_caps, seqLen, reverse_routing=False): 18 | caps_uhat = shared_routing_uhat(in_caps, self.out_caps_num, self.out_caps_dim, scope='rnn_caps_uhat') 19 | if not reverse_routing: 20 | V, S = masked_routing_iter(caps_uhat, seqLen, self.iter_num) 21 | else: 22 | V, S = masked_reverse_routing_iter(caps_uhat, seqLen, self.iter_num) 23 | return V 24 | 25 | 26 | def shared_routing_uhat(caps, out_caps_num, out_caps_dim, scope=None): 27 | ''' 28 | 29 | Args: 30 | caps: # shape(b_sz, caps_num, caps_dim) 31 | out_caps_num: #number of output capsule 32 | out_caps_dim: #dimension of output capsule 33 | Returns: 34 | caps_uhat: shape(b_sz, caps_num, out_caps_num, out_caps_dim) 35 | ''' 36 | b_sz = tf.shape(caps)[0] 37 | tstp = tf.shape(caps)[1] 38 | 39 | with tf.variable_scope(scope or 'shared_routing_uhat'): 40 | '''shape(b_sz, caps_num, out_caps_num*out_caps_dim)''' 41 | caps_uhat = tf.layers.dense(caps, out_caps_num * out_caps_dim, activation=tf.tanh) 42 | caps_uhat = tf.reshape(caps_uhat, shape=[b_sz, tstp, out_caps_num, out_caps_dim]) 43 | 44 | return caps_uhat 45 | 46 | 47 | 
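# Usage sketch (shapes only; out_caps_num=5 and out_caps_dim=100 are illustrative):
#   caps_uhat = shared_routing_uhat(in_caps, 5, 100)            # (b_sz, tstp, 5, 100)
#   V, S = masked_routing_iter(caps_uhat, seqLen, iter_num=3)   # V, S: (b_sz, 5, 100)
# The routing iterations below weight the per-timestep predictions caps_uhat with
# coefficients that are zeroed beyond seqLen, so padded timesteps never contribute
# to the output capsules.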
def masked_routing_iter(caps_uhat, seqLen, iter_num): 48 | ''' 49 | 50 | Args: 51 | caps_uhat: shape(b_sz, tstp, out_caps_num, out_caps_dim) 52 | seqLen: shape(b_sz) 53 | iter_num: number of iteration 54 | 55 | Returns: 56 | V_ret: #shape(b_sz, out_caps_num, out_caps_dim) 57 | ''' 58 | assert iter_num > 0 59 | b_sz = tf.shape(caps_uhat)[0] 60 | tstp = tf.shape(caps_uhat)[1] 61 | out_caps_num = int(caps_uhat.get_shape()[2]) 62 | seqLen = tf.where(tf.equal(seqLen, 0), tf.ones_like(seqLen), seqLen) 63 | mask = mkMask(seqLen, tstp) # shape(b_sz, tstp) 64 | floatmask = tf.cast(tf.expand_dims(mask, axis=-1), dtype=tf.float32) # shape(b_sz, tstp, 1) 65 | 66 | # shape(b_sz, tstp, out_caps_num) 67 | B = tf.zeros([b_sz, tstp, out_caps_num], dtype=tf.float32) 68 | for i in range(iter_num): 69 | C = tf.nn.softmax(B, dim=2) # shape(b_sz, tstp, out_caps_num) 70 | C = tf.expand_dims(C*floatmask, axis=-1) # shape(b_sz, tstp, out_caps_num, 1) 71 | weighted_uhat = C * caps_uhat # shape(b_sz, tstp, out_caps_num, out_caps_dim) 72 | 73 | S = tf.reduce_sum(weighted_uhat, axis=1) # shape(b_sz, out_caps_num, out_caps_dim) 74 | 75 | V = _squash(S, axes=[2]) # shape(b_sz, out_caps_num, out_caps_dim) 76 | V = tf.expand_dims(V, axis=1) # shape(b_sz, 1, out_caps_num, out_caps_dim) 77 | B = tf.reduce_sum(caps_uhat * V, axis=-1) + B # shape(b_sz, tstp, out_caps_num) 78 | 79 | V_ret = tf.squeeze(V, axis=[1]) # shape(b_sz, out_caps_num, out_caps_dim) 80 | S_ret = S 81 | return V_ret, S_ret 82 | 83 | 84 | def masked_reverse_routing_iter(caps_uhat, seqLen, iter_num): 85 | ''' 86 | 87 | Args: 88 | caps_uhat: shape(b_sz, tstp, out_caps_num, out_caps_dim) 89 | seqLen: shape(b_sz) 90 | iter_num: number of iteration 91 | 92 | Returns: 93 | V_ret: #shape(b_sz, out_caps_num, out_caps_dim) 94 | ''' 95 | assert iter_num > 0 96 | b_sz = tf.shape(caps_uhat)[0] 97 | tstp = tf.shape(caps_uhat)[1] 98 | out_caps_num = int(caps_uhat.get_shape()[2]) 99 | 100 | seqLen = tf.where(tf.equal(seqLen, 0), tf.ones_like(seqLen), seqLen) 101 | mask = mkMask(seqLen, tstp) # shape(b_sz, tstp) 102 | mask = tf.tile(tf.expand_dims(mask, axis=-1), # shape(b_sz, tstp, out_caps_num) 103 | multiples=[1, 1, out_caps_num]) 104 | # shape(b_sz, tstp, out_caps_num) 105 | B = tf.zeros([b_sz, tstp, out_caps_num], dtype=tf.float32) 106 | B = tf.where(mask, B, tf.ones_like(B) * _MIN_NUM) 107 | for i in range(iter_num): 108 | C = tf.nn.softmax(B, dim=1) # shape(b_sz, tstp, out_caps_num) 109 | C = tf.expand_dims(C, axis=-1) # shape(b_sz, tstp, out_caps_num, 1) 110 | weighted_uhat = C * caps_uhat # shape(b_sz, tstp, out_caps_num, out_caps_dim) 111 | 112 | S = tf.reduce_sum(weighted_uhat, axis=1) # shape(b_sz, out_caps_num, out_caps_dim) 113 | 114 | V = _squash(S, axes=[2]) # shape(b_sz, out_caps_num, out_caps_dim) 115 | V = tf.expand_dims(V, axis=1) # shape(b_sz, 1, out_caps_num, out_caps_dim) 116 | B = tf.reduce_sum(caps_uhat * V, axis=-1) + B # shape(b_sz, tstp, out_caps_num) 117 | 118 | V_ret = tf.squeeze(V, axis=[1]) # shape(b_sz, out_caps_num, out_caps_dim) 119 | S_ret = S 120 | return V_ret, S_ret 121 | 122 | 123 | def margin_loss(y_true, y_pred): 124 | """ 125 | :param y_true: [None, n_classes] 126 | :param y_pred: [None, n_classes] 127 | :return: a scalar loss value. 
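    Implements the capsule margin loss of Sabour et al. (2017) with m+ = 0.9,
    m- = 0.1 and lambda = 0.5:

        L_c = T_c * max(0, 0.9 - y_pred_c)^2
              + 0.5 * (1 - T_c) * max(0, y_pred_c - 0.1)^2

    where T_c = y_true[:, c]; the per-class losses are summed over classes and
    averaged over the batch.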
128 | """ 129 | L = y_true * tf.square(tf.maximum(0., 0.9 - y_pred)) + \ 130 | 0.5 * (1 - y_true) * tf.square(tf.maximum(0., y_pred - 0.1)) 131 | 132 | assert_inf_L = tf.Assert(tf.logical_not(tf.reduce_any(tf.is_inf(L))), 133 | ['assert_inf_L', L], summarize=100) 134 | assert_nan_L = tf.Assert(tf.logical_not(tf.reduce_any(tf.is_nan(L))), 135 | ['assert_nan_L', L], summarize=100) 136 | with tf.control_dependencies([assert_inf_L, assert_nan_L]): 137 | ret = tf.reduce_mean(tf.reduce_sum(L, axis=1)) 138 | 139 | return ret 140 | 141 | 142 | def _squash(in_caps, axes): 143 | ''' 144 | Squashing function corresponding to Eq. 1 145 | Args: 146 | in_caps: a tensor 147 | axes: dimensions along which to apply squash 148 | 149 | Returns: 150 | vec_squashed: squashed tensor 151 | 152 | ''' 153 | vec_squared_norm = tf.reduce_sum(tf.square(in_caps), axis=axes, keep_dims=True) 154 | scalar_factor = vec_squared_norm / (1 + vec_squared_norm) / tf.sqrt(vec_squared_norm + _EPSILON) 155 | vec_squashed = scalar_factor * in_caps # element-wise 156 | return vec_squashed 157 | 158 | 159 | -------------------------------------------------------------------------------- /caps_attn_hierarchical/Config.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import traceback 3 | import json 4 | 5 | 6 | class Config(object): 7 | """Holds model hyperparams and data information. 8 | 9 | The config class is used to store various hyperparameters and dataset 10 | information parameters. Model objects are passed a Config() object at 11 | instantiation. 12 | 13 | """ 14 | 15 | """General""" 16 | revision = 'None' 17 | datapath = './data/smallset/' 18 | embed_path = './data/embedding.txt' 19 | 20 | optimizer = 'adam' 21 | attn_mode = 'attn' 22 | seq_encoder = 'bigru' 23 | 24 | out_caps_num = 5 25 | rout_iter = 3 26 | 27 | max_snt_num = 30 28 | max_wd_num = 30 29 | max_epochs = 50 30 | pre_trained = True 31 | batch_sz = 64 32 | batch_sz_min = 32 33 | bucket_sz = 5000 34 | partial_update_until_epoch = 1 35 | 36 | embed_size = 300 37 | hidden_size = 200 38 | dense_hidden = [300, 5] 39 | 40 | lr = 0.0001 41 | decay_steps = 1000 42 | decay_rate = 0.9 43 | 44 | dropout = 0.2 45 | early_stopping = 7 46 | reg = 0. 
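    # Every class-level attribute above is collected into attr_list by __init__
    # and serialized as a JSON value under the [General] section by saveConfig /
    # loadConfig below; the files in ./savings/*/config follow the same layout.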
47 | 48 | def __init__(self): 49 | self.attr_list = [i for i in list(Config.__dict__.keys()) if 50 | not callable(getattr(self, i)) and not i.startswith("__")] 51 | 52 | def printall(self): 53 | for attr in self.attr_list: 54 | print(attr, getattr(self, attr), type(getattr(self, attr))) 55 | 56 | def saveConfig(self, filePath): 57 | 58 | cfg = configparser.ConfigParser() 59 | cfg['General'] = {} 60 | gen_sec = cfg['General'] 61 | for attr in self.attr_list: 62 | try: 63 | gen_sec[attr] = json.dumps(getattr(self, attr)) 64 | except Exception as e: 65 | traceback.print_exc() 66 | raise ValueError('something wrong in “%s” entry' % attr) 67 | 68 | with open(filePath, 'w') as fd: 69 | cfg.write(fd) 70 | 71 | def loadConfig(self, filePath): 72 | 73 | cfg = configparser.ConfigParser() 74 | cfg.read(filePath) 75 | gen_sec = cfg['General'] 76 | for attr in self.attr_list: 77 | try: 78 | val = json.loads(gen_sec[attr]) 79 | assert type(val) == type(getattr(self, attr)), \ 80 | 'type not match, expect %s got %s' % \ 81 | (type(getattr(self, attr)), type(val)) 82 | 83 | setattr(self, attr, val) 84 | except Exception as e: 85 | traceback.print_exc() 86 | raise ValueError('something wrong in “%s” entry' % attr) 87 | 88 | with open(filePath, 'w') as fd: 89 | cfg.write(fd) -------------------------------------------------------------------------------- /caps_attn_hierarchical/TfUtils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import nest 3 | import numpy as np 4 | 5 | def mkMask(input_tensor, maxLen): 6 | shape_of_input = tf.shape(input_tensor) 7 | shape_of_output = tf.concat(axis=0, values=[shape_of_input, [maxLen]]) 8 | 9 | oneDtensor = tf.reshape(input_tensor, shape=(-1,)) 10 | flat_mask = tf.sequence_mask(oneDtensor, maxlen=maxLen) 11 | return tf.reshape(flat_mask, shape_of_output) 12 | 13 | 14 | def reduce_avg(reduce_target, lengths, dim): 15 | """ 16 | Args: 17 | reduce_target : shape(d_0, d_1,..,d_dim, .., d_k) 18 | lengths : shape(d0, .., d_(dim-1)) 19 | dim : which dimension to average, should be a python number 20 | """ 21 | shape_of_lengths = lengths.get_shape() 22 | shape_of_target = reduce_target.get_shape() 23 | if len(shape_of_lengths) != dim: 24 | raise ValueError(('Second input tensor should be rank %d, ' + 25 | 'while it got rank %d') % (dim, len(shape_of_lengths))) 26 | if len(shape_of_target) < dim+1 : 27 | raise ValueError(('First input tensor should be at least rank %d, ' + 28 | 'while it got rank %d') % (dim+1, len(shape_of_target))) 29 | 30 | rank_diff = len(shape_of_target) - len(shape_of_lengths) - 1 31 | mxlen = tf.shape(reduce_target)[dim] 32 | mask = mkMask(lengths, mxlen) 33 | if rank_diff!=0: 34 | len_shape = tf.concat(axis=0, values=[tf.shape(lengths), [1]*rank_diff]) 35 | mask_shape = tf.concat(axis=0, values=[tf.shape(mask), [1]*rank_diff]) 36 | else: 37 | len_shape = tf.shape(lengths) 38 | mask_shape = tf.shape(mask) 39 | lengths_reshape = tf.reshape(lengths, shape=len_shape) 40 | mask = tf.reshape(mask, shape=mask_shape) 41 | 42 | mask_target = reduce_target * tf.cast(mask, dtype=reduce_target.dtype) 43 | 44 | red_sum = tf.reduce_sum(mask_target, axis=[dim], keep_dims=False) 45 | red_avg = red_sum / (tf.to_float(lengths_reshape) + 1e-30) 46 | return red_avg 47 | 48 | 49 | def reduce_sum(reduce_target, lengths, dim): 50 | """ 51 | Args: 52 | reduce_target : shape(d_0, d_1,..,d_dim, .., d_k) 53 | lengths : shape(d0, .., d_(dim-1)) 54 | dim : which dimension to average, should be a python number 55 | 
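    Returns:
        red_sum : shape(d_0, .., d_(dim-1), d_(dim+1), .., d_k). Note that,
            despite the wording above, this is a masked sum over dimension
            `dim`; entries past `lengths` are zeroed out (contrast reduce_avg).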
""" 56 | shape_of_lengths = lengths.get_shape() 57 | shape_of_target = reduce_target.get_shape() 58 | if len(shape_of_lengths) != dim: 59 | raise ValueError(('Second input tensor should be rank %d, ' + 60 | 'while it got rank %d') % (dim, len(shape_of_lengths))) 61 | if len(shape_of_target) < dim+1 : 62 | raise ValueError(('First input tensor should be at least rank %d, ' + 63 | 'while it got rank %d') % (dim+1, len(shape_of_target))) 64 | 65 | rank_diff = len(shape_of_target) - len(shape_of_lengths) - 1 66 | mxlen = tf.shape(reduce_target)[dim] 67 | mask = mkMask(lengths, mxlen) 68 | if rank_diff!=0: 69 | len_shape = tf.concat(axis=0, values=[tf.shape(lengths), [1]*rank_diff]) 70 | mask_shape = tf.concat(axis=0, values=[tf.shape(mask), [1]*rank_diff]) 71 | else: 72 | len_shape = tf.shape(lengths) 73 | mask_shape = tf.shape(mask) 74 | lengths_reshape = tf.reshape(lengths, shape=len_shape) 75 | mask = tf.reshape(mask, shape=mask_shape) 76 | 77 | mask_target = reduce_target * tf.cast(mask, dtype=reduce_target.dtype) 78 | 79 | red_sum = tf.reduce_sum(mask_target, axis=[dim], keep_dims=False) 80 | 81 | return red_sum 82 | 83 | 84 | def embed_lookup_last_dim(embedding, ids): 85 | ''' 86 | embedding: shape(b_sz, tstp, emb_sz) 87 | ids : shape(b_sz, tstp) 88 | ''' 89 | input_shape = tf.shape(embedding) 90 | time_steps = input_shape[0] 91 | def _create_ta(name, dtype): 92 | return tf.TensorArray(dtype=dtype, 93 | size=time_steps, 94 | tensor_array_name=name) 95 | input_ta = _create_ta('input_ta', embedding.dtype) 96 | fetch_ta = _create_ta('fetch_ta', ids.dtype) 97 | output_ta = _create_ta('output_ta', embedding.dtype) 98 | input_ta = input_ta.unpack(embedding) 99 | fetch_ta = fetch_ta.unpack(ids) 100 | 101 | def loop_body(time, output_ta): 102 | embed = input_ta.read(time) #shape(tstp, emb_sz) type of float32 103 | fetch_id = fetch_ta.read(time) #shape(tstp) type of int32 104 | out_emb = tf.nn.embedding_lookup(embed, fetch_id) 105 | output_ta = output_ta.write(time, out_emb) 106 | 107 | next_time = time+1 108 | return next_time, output_ta 109 | time = tf.constant(0) 110 | _, output_ta = tf.while_loop(cond=lambda time, *_: time < time_steps, 111 | body=loop_body, loop_vars=(time, output_ta), 112 | swap_memory=True) 113 | ret_t = output_ta.pack() #shape(b_sz, tstp, embd_sz) 114 | return ret_t 115 | 116 | 117 | def entry_stop_gradients(target, mask): 118 | ''' 119 | Args: 120 | target: a tensor 121 | mask: a boolean tensor that broadcast to the rank of that to target tensor 122 | Returns: 123 | ret: a tensor have the same value of target, 124 | but some entry will have no gradient during backprop 125 | ''' 126 | mask_h = tf.logical_not(mask) 127 | 128 | mask = tf.cast(mask, dtype=target.dtype) 129 | mask_h = tf.cast(mask_h, dtype=target.dtype) 130 | ret = tf.stop_gradient(mask_h * target) + mask * target 131 | 132 | return ret 133 | 134 | 135 | def last_dim_linear(inputs, output_size, bias, scope): 136 | ''' 137 | Args: 138 | input: shape(b_sz, ..., rep_sz) 139 | output_size: a scalar, python number 140 | ''' 141 | bias_start=0.0 142 | input_shape = tf.shape(inputs) 143 | out_shape = tf.concat(axis=0, values=[input_shape[:-1], [output_size]]) 144 | input_size = int(inputs.get_shape()[-1]) 145 | unbatch_input = tf.reshape(inputs, shape=[-1, input_size]) 146 | 147 | unbatch_output = linear(unbatch_input, output_size, bias=bias, 148 | bias_start=bias_start, scope=scope) 149 | batch_output = tf.reshape(unbatch_output, shape=out_shape) 150 | 151 | return batch_output # shape(b_sz, ..., output_size) 152 | 
153 | 154 | def linear(args, output_size, bias, bias_start=0.0, scope=None): 155 | """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. 156 | 157 | Args: 158 | args: a 2D Tensor or a list of 2D, batch x n, Tensors. 159 | output_size: int, second dimension of W[i]. 160 | bias: boolean, whether to add a bias term or not. 161 | bias_start: starting value to initialize the bias; 0 by default. 162 | scope: (optional) Variable scope to create parameters in. 163 | 164 | Returns: 165 | A 2D Tensor with shape [batch x output_size] equal to 166 | sum_i(args[i] * W[i]), where W[i]s are newly created matrices. 167 | 168 | Raises: 169 | ValueError: if some of the arguments has unspecified or wrong shape. 170 | """ 171 | if args is None or (nest.is_sequence(args) and not args): 172 | raise ValueError("`args` must be specified") 173 | if not nest.is_sequence(args): 174 | args = [args] 175 | 176 | # Calculate the total size of arguments on dimension 1. 177 | total_arg_size = 0 178 | shapes = [a.get_shape() for a in args] 179 | for shape in shapes: 180 | if shape.ndims != 2: 181 | raise ValueError("linear is expecting 2D arguments: %s" % shapes) 182 | if shape[1].value is None: 183 | raise ValueError("linear expects shape[1] to be provided for shape %s, " 184 | "but saw %s" % (shape, shape[1])) 185 | else: 186 | total_arg_size += shape[1].value 187 | 188 | dtype = [a.dtype for a in args][0] 189 | 190 | # Now the computation. 191 | with tf.variable_scope(scope or 'Linear') as outer_scope: 192 | weights = tf.get_variable( 193 | "weights", [total_arg_size, output_size], dtype=dtype) 194 | if len(args) == 1: 195 | res = tf.matmul(args[0], weights) 196 | else: 197 | res = tf.matmul(tf.concat(args, 1), weights) 198 | if not bias: 199 | return res 200 | with tf.variable_scope(outer_scope) as inner_scope: 201 | inner_scope.set_partitioner(None) 202 | biases = tf.get_variable( 203 | "biases", [output_size], 204 | dtype=dtype, 205 | initializer=tf.constant_initializer(bias_start, dtype=dtype)) 206 | return tf.nn.bias_add(res, biases) 207 | 208 | 209 | def masked_softmax(inp, seqLen): 210 | seqLen = tf.where(tf.equal(seqLen, 0), tf.ones_like(seqLen), seqLen) 211 | if len(inp.get_shape()) != len(seqLen.get_shape())+1: 212 | raise ValueError('rank of seqLen should be %d, but have the rank %d.\n' 213 | % (len(inp.get_shape())-1, len(seqLen.get_shape()))) 214 | mask = mkMask(seqLen, tf.shape(inp)[-1]) 215 | masked_inp = tf.where(mask, inp, tf.ones_like(inp) * (-np.Inf)) 216 | ret = tf.nn.softmax(masked_inp) 217 | return ret 218 | 219 | from tensorflow.python.client import device_lib 220 | def get_available_gpus(): 221 | local_device_protos = device_lib.list_local_devices() 222 | return [x.name for x in local_device_protos if x.device_type == 'GPU'] 223 | -------------------------------------------------------------------------------- /caps_attn_hierarchical/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jingjing-gong/Capsule4TextClassification/35a5e0d65f45c7810c082c1046535de9207eb12b/caps_attn_hierarchical/__init__.py -------------------------------------------------------------------------------- /caps_attn_hierarchical/data_iterator.py: -------------------------------------------------------------------------------- 1 | import _pickle as pkl 2 | import numpy as np 3 | from multiprocessing import Queue 4 | 5 | class TextIterator: 6 | """Simple Bitext iterator.""" 7 | def __init__(self, datapath, batch_size=128, bucket_sz=1000, 
shuffle=False, sample_balance=False, id2weight=None): 8 | 9 | with open(datapath, 'rb') as fd: 10 | data = pkl.load(fd) 11 | '''data==> [(labe, doc),]''' 12 | example_num = len(data) 13 | '''shape(example_num)''' 14 | doc_sz = np.array([len(doc) for _, doc in data], dtype=np.int32) 15 | 16 | if shuffle: 17 | self.tidx = np.argsort(doc_sz) 18 | else: 19 | self.tidx = np.arange(example_num) 20 | 21 | self.num_example = example_num 22 | self.shuffle = shuffle 23 | self.bucket_sz = bucket_sz 24 | self.batch_sz = batch_size 25 | self.data = data 26 | 27 | self.sample_balance = sample_balance 28 | self.id2weight = id2weight 29 | 30 | def __iter__(self): 31 | if self.bucket_sz < self.batch_sz: 32 | self.bucket_sz = self.batch_sz 33 | if self.bucket_sz > self.num_example: 34 | self.bucket_sz = self.num_example 35 | self.startpoint = 0 36 | return self 37 | 38 | def __next__(self): 39 | if self.startpoint >= self.num_example: 40 | raise StopIteration 41 | 42 | if self.shuffle: 43 | bucket_start = np.random.randint(0, self.num_example) 44 | bucket_end = (bucket_start + self.bucket_sz) % self.num_example 45 | if bucket_end - bucket_start < self.bucket_sz: 46 | candidate = np.concatenate([self.tidx[bucket_start:], self.tidx[:bucket_end]]) 47 | else: 48 | candidate = self.tidx[bucket_start: bucket_end] 49 | candidate_p = None 50 | if self.sample_balance and self.id2weight: 51 | candidate_label = [self.data[c][0] for c in candidate] 52 | candidate_p = np.array([self.id2weight[l] for l in candidate_label]) 53 | candidate_p = candidate_p/np.sum(candidate_p) 54 | target_idx = np.random.choice(candidate, size=self.batch_sz, p=candidate_p) 55 | else: 56 | target_idx = self.tidx[self.startpoint:self.startpoint+self.batch_sz] 57 | self.startpoint += self.batch_sz 58 | 59 | labels = [] 60 | data_x = [] 61 | for idx in target_idx: 62 | l, d = self.data[idx] 63 | labels.append(l) 64 | data_x.append(d) 65 | return labels, data_x 66 | 67 | 68 | def preparedata(dataset: list, q: Queue, max_snt_num: int, max_wd_num: int, class_freq: dict): 69 | for labels, data_x in dataset: 70 | example_weight = np.array([class_freq[i] for i in labels]) #(b_sz) 71 | data_batch, sNum, wNum = paddata(data_x, max_snt_num=max_snt_num, max_wd_num=max_wd_num) 72 | labels = np.array(labels) 73 | q.put((data_batch, labels, sNum, wNum, example_weight)) 74 | q.put(None) 75 | 76 | 77 | def paddata(data_x: list, max_snt_num: int, max_wd_num: int): 78 | ''' 79 | 80 | :param data_x: (b_sz, snt_num, wd_num) 81 | :param max_snt_num: 82 | :param max_wd_num: 83 | :return: 84 | ''' 85 | 86 | b_sz = len(data_x) 87 | 88 | snt_num = np.array([len(doc) if len(doc) < max_snt_num else max_snt_num for doc in data_x], dtype=np.int32) 89 | snt_sz = np.max(snt_num) 90 | 91 | wd_num = [[len(sent) if len(sent) < max_wd_num else max_wd_num for sent in doc] for doc in data_x] 92 | wd_sz = min(max(map(max, wd_num)), max_wd_num) 93 | 94 | b = np.zeros(shape=[b_sz, snt_sz, wd_sz], dtype=np.int32) # == PAD 95 | 96 | sNum = snt_num 97 | wNum = np.zeros(shape=[b_sz, snt_sz], dtype=np.int32) 98 | 99 | for i, document in enumerate(data_x): 100 | for j, sentence in enumerate(document): 101 | if j >= snt_sz: 102 | continue 103 | wNum[i, j] = wd_num[i][j] 104 | for k, word in enumerate(sentence): 105 | if k >= wd_sz: 106 | continue 107 | b[i, j, k] = word 108 | 109 | return b, sNum, wNum -------------------------------------------------------------------------------- /caps_attn_hierarchical/dataprocess/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jingjing-gong/Capsule4TextClassification/35a5e0d65f45c7810c082c1046535de9207eb12b/caps_attn_hierarchical/dataprocess/__init__.py -------------------------------------------------------------------------------- /caps_attn_hierarchical/dataprocess/dataprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import _pickle as pkl 3 | import os, operator 4 | from collections import defaultdict 5 | from tensorflow.python.util import nest 6 | from vocab import Vocab 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser(description="datasets") 10 | 11 | parser.add_argument('--train-set', action='store', dest='train_set', default=None) 12 | parser.add_argument('--dev-set', action='store', dest='dev_set', default=None) 13 | parser.add_argument('--test-set', action='store', dest='test_set', default=None) 14 | parser.add_argument('--ref-embedding', action='store', dest='ref_emb', default='/home/jjgong/data/glove300d/glove.840B.300d.txt') 15 | parser.add_argument('--dest-dir', action='store', dest='dest_dir', default='./') 16 | parser.add_argument('--label2id', action='store', dest='label2id', default=None) 17 | 18 | args = parser.parse_args() 19 | 20 | def extract(fn): 21 | label_collect = [] 22 | doc_tok_collect = [] 23 | with open(fn, 'r') as fd: 24 | for line in fd: 25 | item = line.strip().split('\t\t') 26 | try: 27 | label = item[2] 28 | doc = item[3] 29 | except: 30 | print(line) 31 | print(item) 32 | raise ValueError 33 | 34 | doc2snt = doc.strip().split('') 35 | doc2snt2wd = [snt.strip().split(' ') for snt in doc2snt] 36 | label_collect.append(label) 37 | doc_tok_collect.append(doc2snt2wd) 38 | return label_collect, doc_tok_collect 39 | 40 | def constructLabel_dict(labels, savepath): 41 | label_freq = defaultdict(int) 42 | for i in labels: 43 | label_freq[i] += 1 44 | class_num = len(label_freq.values()) 45 | id2revfreq = {} 46 | dinominator = float(sum(label_freq.values())) 47 | if args.label2id is None: 48 | label2id = dict(list(zip(label_freq.keys(), [int(o)-1 for o in label_freq.keys()]))) 49 | else: 50 | with open(args.label2id, 'rb') as fd: 51 | label2id = pkl.load(fd) 52 | id2label = {idx: label for label, idx in label2id.items()} 53 | for item in id2label: 54 | label = id2label[item] 55 | freq = label_freq[label] 56 | id2revfreq[item] = float(dinominator)/float(freq) 57 | dino = float(sum(id2revfreq.values())) 58 | id2weight = {idx: class_num * revfreq/dino for idx, revfreq in id2revfreq.items()} 59 | 60 | with open(savepath, 'wb') as fd: 61 | pkl.dump(label2id, fd) 62 | pkl.dump(id2label, fd) 63 | pkl.dump(id2revfreq, fd) 64 | pkl.dump(id2weight, fd) 65 | 66 | def loadLabel_dict(savepath): 67 | with open(savepath, 'rb') as fd: 68 | label2id = pkl.load(fd) 69 | id2label = pkl.load(fd) 70 | id2revfreq = pkl.load(fd) 71 | id2weight = pkl.load(fd) 72 | return label2id, id2label, id2revfreq, id2weight 73 | 74 | 75 | def readEmbedding(fileName): 76 | """ 77 | Read Embedding Function 78 | 79 | Args: 80 | fileName : file which stores the embedding 81 | Returns: 82 | embeddings_index : a dictionary contains the mapping from word to vector 83 | """ 84 | embeddings_index = {} 85 | with open(fileName, 'r') as f: 86 | for line in f: 87 | line_uni = line.strip() 88 | values = line_uni.split(' ') 89 | if len(values) != 301: 90 | continue 91 | word = values[0] 92 | w2v_line = ' '.join(values) 93 | 
embeddings_index[word] = w2v_line 94 | return embeddings_index 95 | 96 | def buildEmbedding(src_embed_file, tgt_embed_file, word_dict): 97 | emb_dict = readEmbedding(src_embed_file) 98 | with open(tgt_embed_file, 'w') as fd: 99 | for word in word_dict: 100 | if word in emb_dict: 101 | fd.writelines(emb_dict[word]+'\n') 102 | return None 103 | 104 | if __name__ == '__main__': 105 | vocab = Vocab() 106 | tok_collect = [] 107 | labels_collect = [] 108 | if args.train_set: 109 | train_label, train_toks = extract(args.train_set) 110 | tok_collect.append(train_toks) 111 | labels_collect.append(train_label) 112 | if args.dev_set: 113 | dev_label, dev_toks = extract(args.dev_set) 114 | tok_collect.append(dev_toks) 115 | labels_collect.append(dev_label) 116 | if args.test_set: 117 | test_label, test_toks = extract(args.test_set) 118 | 119 | vocab.construct(nest.flatten(tok_collect)) 120 | vocab.limit_vocab_length(base_freq=3) 121 | vocab.save_vocab(os.path.join(args.dest_dir, 'vocab.pkl')) 122 | 123 | constructLabel_dict(nest.flatten(labels_collect), os.path.join(args.dest_dir, 'label2id.pkl')) 124 | 125 | vocab = Vocab() 126 | vocab.load_vocab_from_file(os.path.join(args.dest_dir, 'vocab.pkl')) 127 | 128 | buildEmbedding(args.ref_emb, os.path.join(args.dest_dir, 'embedding.txt'), vocab.word_to_index) 129 | 130 | label2id, id2label, id2revfreq, id2weight = loadLabel_dict(os.path.join(args.dest_dir, 'label2id.pkl')) 131 | 132 | if args.train_set: 133 | train_label = [label2id[o] for o in train_label] 134 | train_toks = nest.map_structure(lambda x: vocab.encode(x), train_toks) 135 | train_set = [o for o in zip(train_label, train_toks)] 136 | with open(os.path.join(args.dest_dir, 'trainset.pkl'), 'wb') as fd: 137 | pkl.dump(train_set, fd) 138 | if args.dev_set: 139 | dev_label = [label2id[o] for o in dev_label] 140 | dev_toks = nest.map_structure(lambda x: vocab.encode(x), dev_toks) 141 | dev_set = [o for o in zip(dev_label, dev_toks)] 142 | with open(os.path.join(args.dest_dir, 'devset.pkl'), 'wb') as fd: 143 | pkl.dump(dev_set, fd) 144 | if args.test_set: 145 | test_label = [label2id[o] for o in test_label] 146 | test_toks = nest.map_structure(lambda x: vocab.encode(x), test_toks) 147 | test_set = [o for o in zip(test_label, test_toks)] 148 | with open(os.path.join(args.dest_dir, 'testset.pkl'), 'wb') as fd: 149 | pkl.dump(test_set, fd) 150 | 151 | 152 | 153 | -------------------------------------------------------------------------------- /caps_attn_hierarchical/dataprocess/dataprocess_sentence.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import _pickle as pkl 3 | import os, operator 4 | from collections import defaultdict 5 | from tensorflow.python.util import nest 6 | from vocab import Vocab 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser(description="datasets") 10 | 11 | parser.add_argument('--train-set', action='store', dest='train_set', default=None) 12 | parser.add_argument('--dev-set', action='store', dest='dev_set', default=None) 13 | parser.add_argument('--test-set', action='store', dest='test_set', default=None) 14 | parser.add_argument('--ref-embedding', action='store', dest='ref_emb', default='/home/jjgong/data/glove300d/glove.840B.300d.txt') 15 | parser.add_argument('--dest-dir', action='store', dest='dest_dir', default='./') 16 | parser.add_argument('--label2id', action='store', dest='label2id', default=None) 17 | 18 | args = parser.parse_args() 19 | 20 | def extract(fn): 21 | label_collect = [] 22 | 
snt_tok_collect = [] 23 | with open(fn, 'r') as fd: 24 | for line in fd: 25 | item = line.strip().split('\t\t') 26 | try: 27 | label = item[0] 28 | snt = item[1] 29 | except: 30 | print(line) 31 | print(item) 32 | raise ValueError 33 | 34 | snt2wd = snt.strip().split(' ') 35 | label_collect.append(label) 36 | snt_tok_collect.append(snt2wd) 37 | return label_collect, snt_tok_collect 38 | 39 | def constructLabel_dict(labels, savepath): 40 | label_freq = defaultdict(int) 41 | for i in labels: 42 | label_freq[i] += 1 43 | class_num = len(label_freq.values()) 44 | id2revfreq = {} 45 | dinominator = float(sum(label_freq.values())) 46 | if args.label2id is None: 47 | label2id = dict(list(zip(label_freq.keys(), [int(o)-1 for o in label_freq.keys()]))) 48 | else: 49 | with open(args.label2id, 'rb') as fd: 50 | label2id = pkl.load(fd) 51 | id2label = {idx: label for label, idx in label2id.items()} 52 | for item in id2label: 53 | label = id2label[item] 54 | freq = label_freq[label] 55 | id2revfreq[item] = float(dinominator)/float(freq) 56 | dino = float(sum(id2revfreq.values())) 57 | id2weight = {idx: class_num * revfreq/dino for idx, revfreq in id2revfreq.items()} 58 | 59 | with open(savepath, 'wb') as fd: 60 | pkl.dump(label2id, fd) 61 | pkl.dump(id2label, fd) 62 | pkl.dump(id2revfreq, fd) 63 | pkl.dump(id2weight, fd) 64 | 65 | def loadLabel_dict(savepath): 66 | with open(savepath, 'rb') as fd: 67 | label2id = pkl.load(fd) 68 | id2label = pkl.load(fd) 69 | id2revfreq = pkl.load(fd) 70 | id2weight = pkl.load(fd) 71 | return label2id, id2label, id2revfreq, id2weight 72 | 73 | 74 | def readEmbedding(fileName): 75 | """ 76 | Read Embedding Function 77 | 78 | Args: 79 | fileName : file which stores the embedding 80 | Returns: 81 | embeddings_index : a dictionary contains the mapping from word to vector 82 | """ 83 | embeddings_index = {} 84 | with open(fileName, 'r') as f: 85 | for line in f: 86 | line_uni = line.strip() 87 | values = line_uni.split(' ') 88 | if len(values) != 301: 89 | continue 90 | word = values[0] 91 | w2v_line = ' '.join(values) 92 | embeddings_index[word] = w2v_line 93 | return embeddings_index 94 | 95 | def buildEmbedding(src_embed_file, tgt_embed_file, word_dict): 96 | emb_dict = readEmbedding(src_embed_file) 97 | with open(tgt_embed_file, 'w') as fd: 98 | for word in word_dict: 99 | if word in emb_dict: 100 | fd.writelines(emb_dict[word]+'\n') 101 | return None 102 | 103 | if __name__ == '__main__': 104 | vocab = Vocab() 105 | tok_collect = [] 106 | labels_collect = [] 107 | if args.train_set: 108 | train_label, train_toks = extract(args.train_set) 109 | tok_collect.append(train_toks) 110 | labels_collect.append(train_label) 111 | if args.dev_set: 112 | dev_label, dev_toks = extract(args.dev_set) 113 | tok_collect.append(dev_toks) 114 | labels_collect.append(dev_label) 115 | if args.test_set: 116 | test_label, test_toks = extract(args.test_set) 117 | 118 | vocab.construct(nest.flatten(tok_collect)) 119 | vocab.limit_vocab_length(base_freq=3) 120 | vocab.save_vocab(os.path.join(args.dest_dir, 'vocab.pkl')) 121 | 122 | constructLabel_dict(nest.flatten(labels_collect), os.path.join(args.dest_dir, 'label2id.pkl')) 123 | 124 | vocab = Vocab() 125 | vocab.load_vocab_from_file(os.path.join(args.dest_dir, 'vocab.pkl')) 126 | 127 | buildEmbedding(args.ref_emb, os.path.join(args.dest_dir, 'embedding.txt'), vocab.word_to_index) 128 | 129 | label2id, id2label, id2revfreq, id2weight = loadLabel_dict(os.path.join(args.dest_dir, 'label2id.pkl')) 130 | 131 | if args.train_set: 132 | train_label 
= [label2id[o] for o in train_label] 133 | train_toks = nest.map_structure(lambda x: vocab.encode(x), train_toks) 134 | train_set = [o for o in zip(train_label, train_toks)] 135 | with open(os.path.join(args.dest_dir, 'trainset.pkl'), 'wb') as fd: 136 | pkl.dump(train_set, fd) 137 | if args.dev_set: 138 | dev_label = [label2id[o] for o in dev_label] 139 | dev_toks = nest.map_structure(lambda x: vocab.encode(x), dev_toks) 140 | dev_set = [o for o in zip(dev_label, dev_toks)] 141 | with open(os.path.join(args.dest_dir, 'devset.pkl'), 'wb') as fd: 142 | pkl.dump(dev_set, fd) 143 | if args.test_set: 144 | test_label = [label2id[o] for o in test_label] 145 | test_toks = nest.map_structure(lambda x: vocab.encode(x), test_toks) 146 | test_set = [o for o in zip(test_label, test_toks)] 147 | with open(os.path.join(args.dest_dir, 'testset.pkl'), 'wb') as fd: 148 | pkl.dump(test_set, fd) 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /caps_attn_hierarchical/dataprocess/vocab.py: -------------------------------------------------------------------------------- 1 | import operator 2 | import _pickle as pkl 3 | from collections import defaultdict 4 | 5 | class Vocab(object): 6 | def __init__(self, id_start=5): 7 | self.word_to_index = {} 8 | self.index_to_word = {} 9 | self.word_freq = defaultdict(int) 10 | self.id_start = id_start 11 | 12 | def add_word(self, word, count=1): 13 | word = word.strip() 14 | if len(word) == 0: 15 | return 16 | elif word.isspace(): 17 | return 18 | if word not in self.word_to_index: 19 | index = len(self.word_to_index) 20 | self.word_to_index[word] = index 21 | self.index_to_word[index] = word 22 | self.word_freq[word] += count 23 | 24 | def construct(self, words): 25 | for word in words: 26 | self.add_word(word) 27 | total_words = float(sum(self.word_freq.values())) 28 | 29 | '''sort by word frequency''' 30 | new_word_to_index = {} 31 | new_index_to_word = {} 32 | sorted_tup = sorted(self.word_freq.items(), key=operator.itemgetter(1)) 33 | sorted_tup.reverse() 34 | self.word_freq = dict(sorted_tup) 35 | for idx, (word, freq) in enumerate(sorted_tup): 36 | index = self.id_start + idx 37 | new_word_to_index[word] = index 38 | new_index_to_word[index] = word 39 | 40 | self.word_to_index = new_word_to_index 41 | self.index_to_word = new_index_to_word 42 | 43 | print('{} total words with {} uniques'.format(total_words, len(self.word_freq))) 44 | 45 | def limit_vocab_length(self, base_freq): 46 | """ 47 | Truncate vocabulary to keep most frequent words 48 | 49 | Args: 50 | None 51 | 52 | Returns: 53 | None 54 | """ 55 | 56 | new_word_to_index = {} 57 | new_index_to_word = {} 58 | sorted_tup = sorted(self.word_freq.items(), key=operator.itemgetter(1)) 59 | sorted_tup.reverse() 60 | vocab_tup = [item for item in sorted_tup if item[1] > base_freq] 61 | self.word_freq = dict(vocab_tup) 62 | for idx, (word, freq) in enumerate(vocab_tup): 63 | index = self.id_start + idx 64 | new_word_to_index[word] = index 65 | new_index_to_word[index] = word 66 | self.word_to_index = new_word_to_index 67 | self.index_to_word = new_index_to_word 68 | 69 | def save_vocab(self, filePath): 70 | """ 71 | Save vocabulary a offline file 72 | 73 | Args: 74 | filePath: where you want to save your vocabulary, every line in the 75 | file represents a word with a tab seperating word and it's frequency 76 | 77 | Returns: 78 | None 79 | """ 80 | with open(filePath, 'wb') as fd: 81 | pkl.dump(self.word_to_index, fd) 82 | pkl.dump(self.index_to_word, fd) 83 | 
pkl.dump(self.word_freq, fd) 84 | 85 | def load_vocab_from_file(self, filePath): 86 | """ 87 | Truncate vocabulary to keep most frequent words 88 | 89 | Args: 90 | filePath: vocabulary file path, every line in the file represents 91 | a word with a tab seperating word and it's frequency 92 | 93 | Returns: 94 | None 95 | """ 96 | with open(filePath, 'rb') as fd: 97 | self.word_to_index = pkl.load(fd) 98 | self.index_to_word = pkl.load(fd) 99 | self.word_freq = pkl.load(fd) 100 | 101 | print('load from <' + filePath + '>, there are {} words in dictionary'.format(len(self.word_freq))) 102 | 103 | def encode(self, word): 104 | if word not in self.word_to_index: 105 | return 1 #unk 106 | else: 107 | return self.word_to_index[word] 108 | 109 | def decode(self, index): 110 | if index not in self.index_to_word: 111 | return 'pad/unk' 112 | return self.index_to_word[index] 113 | 114 | def __len__(self): 115 | return len(self.word_to_index) -------------------------------------------------------------------------------- /caps_attn_hierarchical/model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Sep 21, 2016 3 | 4 | @author: jerrik 5 | ''' 6 | 7 | import os 8 | import sys 9 | import time 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | import utils, nest 14 | from TfUtils import entry_stop_gradients, mkMask, reduce_avg, masked_softmax 15 | from Capsule_masked import Capusule 16 | 17 | class model(object): 18 | """Abstracts a Tensorflow graph for a learning task. 19 | 20 | We use various Model classes as usual abstractions to encapsulate tensorflow 21 | computational graphs. Each algorithm you will construct in this homework will 22 | inherit from a Model object. 23 | """ 24 | def __init__(self, config): 25 | """options in this function""" 26 | self.config = config 27 | self.EX_REG_SCOPE = [] 28 | 29 | self.on_epoch = tf.Variable(0, name='epoch_count', trainable=False) 30 | self.on_epoch_accu = tf.assign_add(self.on_epoch, 1) 31 | 32 | self.build() 33 | 34 | def add_placeholders(self): 35 | # shape(b_sz, sNum, wNum) 36 | self.ph_input = tf.placeholder(shape=(None, None, None), dtype=tf.int32, name='ph_input') 37 | 38 | # shape(bsz) 39 | self.ph_labels = tf.placeholder(shape=(None,), dtype=tf.int32, name='ph_labels') 40 | 41 | # [b_sz] 42 | self.ph_sNum = tf.placeholder(shape=(None,), dtype=tf.int32, name='ph_sNum') 43 | 44 | # shape(b_sz, sNum) 45 | self.ph_wNum = tf.placeholder(shape=(None, None), dtype=tf.int32, name='ph_wNum') 46 | 47 | self.ph_sample_weights = tf.placeholder(shape=(None,), dtype=tf.float32, name='ph_sample_weights') 48 | self.ph_train = tf.placeholder(dtype=tf.bool, name='ph_train') 49 | 50 | def create_feed_dict(self, data_batch, train): 51 | '''data_batch: label_ids, snt1_matrix, snt2_matrix, snt1_len, snt2_len''' 52 | 53 | phs = (self.ph_input, self.ph_labels, self.ph_sNum, self.ph_wNum, self.ph_sample_weights, self.ph_train) 54 | feed_dict = dict(zip(phs, data_batch+(train,))) 55 | return feed_dict 56 | 57 | def add_embedding(self): 58 | """Add embedding layer. that maps from vocabulary to vectors. 
59 | inputs: a list of tensors each of which have a size of [batch_size, embed_size] 60 | """ 61 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 62 | vocab_sz = max(self.config.vocab_dict.values()) 63 | with tf.variable_scope('embedding') as scp: 64 | self.exclude_reg_scope(scp) 65 | if self.config.pre_trained: 66 | embed = utils.readEmbedding(self.config.embed_path) 67 | embed_matrix, valid_mask = utils.mkEmbedMatrix(embed, dict(self.config.vocab_dict)) 68 | embedding = tf.Variable(embed_matrix, 'Embedding') 69 | partial_update_embedding = entry_stop_gradients(embedding, tf.expand_dims(valid_mask, 1)) 70 | embedding = tf.cond(self.on_epoch < self.config.partial_update_until_epoch, 71 | lambda: partial_update_embedding, lambda: embedding) 72 | else: 73 | embedding = tf.get_variable( 74 | 'Embedding', 75 | [vocab_sz, self.config.embed_size], trainable=True) 76 | return embedding 77 | 78 | def embed_lookup(self, embedding, batch_x, dropout=None, is_train=False): 79 | ''' 80 | 81 | :param embedding: shape(v_sz, emb_sz) 82 | :param batch_x: shape(b_sz, sNum, wNum) 83 | :return: shape(b_sz, sNum, wNum, emb_sz) 84 | ''' 85 | inputs = tf.nn.embedding_lookup(embedding, batch_x) 86 | if dropout is not None: 87 | inputs = tf.layers.dropout(inputs, rate=dropout, training=is_train) 88 | return inputs 89 | 90 | def hierachical_attention(self, in_x, sNum, wNum, scope=None): 91 | ''' 92 | 93 | :param in_x: shape(b_sz, ststp, wtstp, emb_sz) 94 | :param sNum: shape(b_sz, ) 95 | :param wNum: shape(b_sz, ststp) 96 | :param scope: 97 | :return: 98 | ''' 99 | b_sz, ststp, wtstp, _ = tf.unstack(tf.shape(in_x)) 100 | emb_sz = int(in_x.get_shape()[-1]) 101 | with tf.variable_scope(scope or 'hierachical_attention'): 102 | flatten_in_x = tf.reshape(in_x, [b_sz*ststp, wtstp, emb_sz]) 103 | flatten_wNum = tf.reshape(wNum, [b_sz * ststp]) 104 | 105 | with tf.variable_scope('sentence_enc'): 106 | if self.config.seq_encoder == 'bigru': 107 | flatten_birnn_x = self.biGRU(flatten_in_x, flatten_wNum, 108 | self.config.hidden_size, scope='biGRU') 109 | elif self.config.seq_encoder == 'bilstm': 110 | flatten_birnn_x = self.biLSTM(flatten_in_x, flatten_wNum, 111 | self.config.hidden_size, scope='biLSTM') 112 | else: 113 | raise ValueError('no such encoder %s'%self.config.seq_encoder) 114 | 115 | '''shape(b_sz*sNum, dim)''' 116 | if self.config.attn_mode == 'avg': 117 | flatten_attn_ctx = reduce_avg(flatten_birnn_x, flatten_wNum, dim=1) 118 | elif self.config.attn_mode == 'attn': 119 | flatten_attn_ctx = self.task_specific_attention(flatten_birnn_x, flatten_wNum, 120 | int(flatten_birnn_x.get_shape()[-1]), 121 | dropout=self.config.dropout, 122 | is_train=self.ph_train, scope='attention') 123 | elif self.config.attn_mode == 'rout': 124 | flatten_attn_ctx = self.routing_masked(flatten_birnn_x, flatten_wNum, 125 | int(flatten_birnn_x.get_shape()[-1]), 126 | self.config.out_caps_num, iter=self.config.rout_iter, 127 | dropout=self.config.dropout, 128 | is_train=self.ph_train, scope='rout') 129 | elif self.config.attn_mode == 'Rrout': 130 | flatten_attn_ctx = self.reverse_routing_masked(flatten_birnn_x, flatten_wNum, 131 | int(flatten_birnn_x.get_shape()[-1]), 132 | self.config.out_caps_num, 133 | iter=self.config.rout_iter, 134 | dropout=self.config.dropout, 135 | is_train=self.ph_train, scope='Rrout') 136 | else: 137 | raise ValueError('no such attn mode %s' % self.config.attn_mode) 138 | snt_dim = int(flatten_attn_ctx.get_shape()[-1]) 139 | snt_reps = tf.reshape(flatten_attn_ctx, shape=[b_sz, ststp, 
snt_dim]) 140 | 141 | with tf.variable_scope('doc_enc'): 142 | if self.config.seq_encoder == 'bigru': 143 | birnn_snt = self.biGRU(snt_reps, sNum, self.config.hidden_size, scope='biGRU') 144 | elif self.config.seq_encoder == 'bilstm': 145 | birnn_snt = self.biLSTM(snt_reps, sNum, self.config.hidden_size, scope='biLSTM') 146 | else: 147 | raise ValueError('no such encoder %s'%self.config.seq_encoder) 148 | 149 | '''shape(b_sz, dim)''' 150 | if self.config.attn_mode == 'avg': 151 | doc_rep = reduce_avg(birnn_snt, sNum, dim=1) 152 | elif self.config.attn_mode == 'max': 153 | doc_rep = tf.reduce_max(birnn_snt, axis=1) 154 | elif self.config.attn_mode == 'attn': 155 | doc_rep = self.task_specific_attention(birnn_snt, sNum, 156 | int(birnn_snt.get_shape()[-1]), 157 | dropout=self.config.dropout, 158 | is_train=self.ph_train, scope='attention') 159 | elif self.config.attn_mode == 'rout': 160 | doc_rep = self.routing_masked(birnn_snt, sNum, 161 | int(birnn_snt.get_shape()[-1]), 162 | self.config.out_caps_num, 163 | iter=self.config.rout_iter, 164 | dropout=self.config.dropout, 165 | is_train=self.ph_train, scope='attention') 166 | elif self.config.attn_mode == 'Rrout': 167 | doc_rep = self.reverse_routing_masked(birnn_snt, sNum, 168 | int(birnn_snt.get_shape()[-1]), 169 | self.config.out_caps_num, 170 | iter=self.config.rout_iter, 171 | dropout=self.config.dropout, 172 | is_train=self.ph_train, scope='attention') 173 | else: 174 | raise ValueError('no such attn mode %s' % self.config.attn_mode) 175 | return doc_rep 176 | 177 | def build(self): 178 | self.add_placeholders() 179 | self.embedding = self.add_embedding() 180 | '''shape(b_sz, ststp, wtstp, emb_sz)''' 181 | in_x = self.embed_lookup(self.embedding, self.ph_input, 182 | dropout=self.config.dropout, is_train=self.ph_train) 183 | doc_reps = self.hierachical_attention(in_x, self.ph_sNum, self.ph_wNum, scope='hierachical_attn') 184 | 185 | with tf.variable_scope('classifier'): 186 | logits = self.Dense(doc_reps, dropout=self.config.dropout, 187 | is_train=self.ph_train, activation=tf.nn.tanh) 188 | opt_loss = self.add_loss_op(logits, self.ph_labels) 189 | train_op = self.add_train_op(opt_loss) 190 | self.train_op = train_op 191 | self.opt_loss = opt_loss 192 | tf.summary.scalar('accuracy', self.accuracy) 193 | tf.summary.scalar('ce_loss', self.ce_loss) 194 | tf.summary.scalar('opt_loss', self.opt_loss) 195 | tf.summary.scalar('w_loss', self.w_loss) 196 | 197 | def Dense(self, inputs, dropout=None, is_train=False, activation=None): 198 | loop_input = inputs 199 | if self.config.dense_hidden[-1] != self.config.class_num: 200 | raise ValueError('last hidden layer should be %d, but get %d' % 201 | (self.config.class_num, 202 | self.config.dense_hidden[-1])) 203 | for i, hid_num in enumerate(self.config.dense_hidden): 204 | with tf.variable_scope('dense-layer-%d' % i): 205 | loop_input = tf.layers.dense(loop_input, units=hid_num) 206 | 207 | if i < len(self.config.dense_hidden) - 1: 208 | if dropout is not None: 209 | loop_input = tf.layers.dropout(loop_input, rate=dropout, training=is_train) 210 | loop_input = activation(loop_input) 211 | 212 | logits = loop_input 213 | return logits 214 | 215 | def add_loss_op(self, logits, labels): 216 | ''' 217 | 218 | :param logits: shape(b_sz, c_num) type(float) 219 | :param labels: shape(b_sz,) type(int) 220 | :return: 221 | ''' 222 | 223 | self.prediction = tf.argmax(logits, axis=-1, output_type=labels.dtype) 224 | 225 | self.accuracy = tf.reduce_mean(tf.cast(tf.equal(self.prediction, labels), tf.float32)) 
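# The remainder of add_loss_op assembles three terms:
#   - ce_loss  : mean sparse softmax cross-entropy over the batch
#   - w_loss   : the same per-example loss reweighted by ph_sample_weights
#                (inverse class frequency); tracked in summaries, not optimized directly
#   - reg_loss : L2 over trainable variables, excluding EX_REG_SCOPE (the embedding scope) and biases
# The value returned for training is ce_loss + config.reg * reg_loss.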
226 | 227 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels) 228 | ce_loss = tf.reduce_mean(loss) 229 | 230 | exclude_vars = nest.flatten([[v for v in tf.trainable_variables(o.name)] for o in self.EX_REG_SCOPE]) 231 | exclude_vars_2 = [v for v in tf.trainable_variables() if '/bias:' in v.name] 232 | exclude_vars = exclude_vars + exclude_vars_2 233 | 234 | reg_var_list = [v for v in tf.trainable_variables() if v not in exclude_vars] 235 | reg_loss = tf.add_n([tf.nn.l2_loss(v) for v in reg_var_list]) 236 | self.param_cnt = np.sum([np.prod(v.get_shape().as_list()) for v in reg_var_list]) 237 | 238 | print('===' * 20) 239 | print('total reg parameter count: %.3f M' % (self.param_cnt / 1000000.)) 240 | print('excluded variables from regularization') 241 | print([v.name for v in exclude_vars]) 242 | print('===' * 20) 243 | 244 | print('regularized variables') 245 | print(['%s:%.3fM' % (v.name, np.prod(v.get_shape().as_list()) / 1000000.) for v in reg_var_list]) 246 | print('===' * 20) 247 | '''shape(b_sz,)''' 248 | self.ce_loss = ce_loss 249 | self.w_loss = tf.reduce_mean(tf.multiply(loss, self.ph_sample_weights)) 250 | reg = self.config.reg 251 | 252 | return self.ce_loss + reg * reg_loss 253 | 254 | def add_train_op(self, loss): 255 | 256 | lr = tf.train.exponential_decay(self.config.lr, self.global_step, 257 | self.config.decay_steps, 258 | self.config.decay_rate, staircase=True) 259 | self.learning_rate = tf.maximum(lr, 1e-5) 260 | if self.config.optimizer == 'adam': 261 | optimizer = tf.train.AdamOptimizer(self.learning_rate) 262 | elif self.config.optimizer == 'grad': 263 | optimizer = tf.train.GradientDescentOptimizer(self.learning_rate) 264 | elif self.config.optimizer == 'adgrad': 265 | optimizer = tf.train.AdagradOptimizer(self.learning_rate) 266 | elif self.config.optimizer == 'adadelta': 267 | optimizer = tf.train.AdadeltaOptimizer(self.learning_rate) 268 | else: 269 | raise ValueError('No such Optimizer: %s' % self.config.optimizer) 270 | 271 | gvs = optimizer.compute_gradients(loss=loss) 272 | 273 | capped_gvs = [(tf.clip_by_value(grad, -2., 2.), var) for grad, var in gvs] 274 | train_op = optimizer.apply_gradients(capped_gvs, global_step=self.global_step) 275 | return train_op 276 | 277 | def exclude_reg_scope(self, scope): 278 | if scope not in self.EX_REG_SCOPE: 279 | self.EX_REG_SCOPE.append(scope) 280 | 281 | @staticmethod 282 | def biLSTM(in_x, xLen, h_sz, dropout=None, is_train=False, scope=None): 283 | 284 | with tf.variable_scope(scope or 'biLSTM'): 285 | cell_fwd = tf.nn.rnn_cell.BasicLSTMCell(h_sz) 286 | cell_bwd = tf.nn.rnn_cell.BasicLSTMCell(h_sz) 287 | x_out, _ = tf.nn.bidirectional_dynamic_rnn(cell_fwd, cell_bwd, in_x, xLen, 288 | dtype=tf.float32, swap_memory=True, 289 | scope='birnn') 290 | 291 | x_out = tf.concat(x_out, axis=2) 292 | if dropout is not None: 293 | x_out = tf.layers.dropout(x_out, rate=dropout, training=is_train) 294 | return x_out 295 | 296 | @staticmethod 297 | def biGRU(in_x, xLen, h_sz, dropout=None, is_train=False, scope=None): 298 | 299 | with tf.variable_scope(scope or 'biGRU'): 300 | cell_fwd = tf.nn.rnn_cell.GRUCell(h_sz) 301 | cell_bwd = tf.nn.rnn_cell.GRUCell(h_sz) 302 | x_out, _ = tf.nn.bidirectional_dynamic_rnn(cell_fwd, cell_bwd, in_x, xLen, 303 | dtype=tf.float32, swap_memory=True, 304 | scope='birnn') 305 | 306 | x_out = tf.concat(x_out, axis=2) 307 | if dropout is not None: 308 | x_out = tf.layers.dropout(x_out, rate=dropout, training=is_train) 309 | return x_out 310 | 311 | @staticmethod 312 | def 
task_specific_attention(in_x, xLen, out_sz, activation_fn=tf.tanh, 313 | dropout=None, is_train=False, scope=None): 314 | ''' 315 | 316 | :param in_x: shape(b_sz, tstp, dim) 317 | :param xLen: shape(b_sz,) 318 | :param out_sz: scalar 319 | :param activation_fn: activation 320 | :param dropout: 321 | :param is_train: 322 | :param scope: 323 | :return: 324 | ''' 325 | 326 | assert len(in_x.get_shape()) == 3 and in_x.get_shape()[-1].value is not None 327 | 328 | with tf.variable_scope(scope or 'attention') as scope: 329 | context_vector = tf.get_variable(name='context_vector', shape=[out_sz], 330 | dtype=tf.float32) 331 | in_x_mlp = tf.layers.dense(in_x, out_sz, activation=activation_fn, name='mlp') 332 | 333 | attn = tf.tensordot(in_x_mlp, context_vector, axes=[[2], [0]]) # shape(b_sz, tstp) 334 | attn_normed = masked_softmax(attn, xLen) 335 | 336 | attn_normed = tf.expand_dims(attn_normed, axis=-1) 337 | attn_ctx = tf.matmul(in_x_mlp, attn_normed, transpose_a=True) # shape(b_sz, dim, 1) 338 | attn_ctx = tf.squeeze(attn_ctx, axis=[2]) # shape(b_sz, dim) 339 | if dropout is not None: 340 | attn_ctx = tf.layers.dropout(attn_ctx, rate=dropout, training=is_train) 341 | return attn_ctx 342 | 343 | @staticmethod 344 | def routing_masked(in_x, xLen, out_sz, out_caps_num, iter=3, 345 | dropout=None, is_train=False, scope=None): 346 | ''' 347 | 348 | :param in_x: shape(b_sz, tstp, dim) 349 | :param xLen: shape(b_sz,) 350 | :param out_sz: scalar 351 | :param dropout: 352 | :param is_train: 353 | :param scope: 354 | :return: 355 | ''' 356 | 357 | 358 | assert len(in_x.get_shape()) == 3 and in_x.get_shape()[-1].value is not None 359 | b_sz = tf.shape(in_x)[0] 360 | with tf.variable_scope(scope or 'routing'): 361 | attn_ctx = Capusule(out_caps_num, out_sz, iter)(in_x, xLen) # shape(b_sz, out_caps_num, out_sz) 362 | attn_ctx = tf.reshape(attn_ctx, shape=[b_sz, out_caps_num*out_sz]) 363 | if dropout is not None: 364 | attn_ctx = tf.layers.dropout(attn_ctx, rate=dropout, training=is_train) 365 | return attn_ctx 366 | 367 | @staticmethod 368 | def reverse_routing_masked(in_x, xLen, out_sz, out_caps_num, iter=3, 369 | dropout=None, is_train=False, scope=None): 370 | ''' 371 | 372 | :param in_x: shape(b_sz, tstp, dim) 373 | :param xLen: shape(b_sz,) 374 | :param out_sz: scalar 375 | :param dropout: 376 | :param is_train: 377 | :param scope: 378 | :return: 379 | ''' 380 | 381 | assert len(in_x.get_shape()) == 3 and in_x.get_shape()[-1].value is not None 382 | b_sz = tf.shape(in_x)[0] 383 | with tf.variable_scope(scope or 'routing'): 384 | '''shape(b_sz, out_caps_num, out_sz)''' 385 | attn_ctx = Capusule(out_caps_num, out_sz, iter)(in_x, xLen, reverse_routing=True) 386 | attn_ctx = tf.reshape(attn_ctx, shape=[b_sz, out_caps_num * out_sz]) 387 | if dropout is not None: 388 | attn_ctx = tf.layers.dropout(attn_ctx, rate=dropout, training=is_train) 389 | return attn_ctx -------------------------------------------------------------------------------- /caps_attn_hierarchical/train_test.py: -------------------------------------------------------------------------------- 1 | import argparse, sys, os, time, logging, threading, traceback 2 | import numpy as np 3 | import tensorflow as tf 4 | import _pickle as pkl 5 | import sys 6 | from multiprocessing import Queue, Process 7 | 8 | from Config import Config 9 | from model import model 10 | from data_iterator import TextIterator, preparedata 11 | from dataprocess.vocab import Vocab 12 | import utils 13 | 14 | _REVISION = 'first' 15 | 16 | parser = 
argparse.ArgumentParser(description="training options") 17 | 18 | parser.add_argument('--load-config', action='store_true', dest='load_config', default=False) 19 | parser.add_argument('--gpu-num', action='store', dest='gpu_num', default=0, type=int) 20 | parser.add_argument('--train-test', action='store', dest='train_test', default='train', choices=['train', 'test']) 21 | parser.add_argument('--weight-path', action='store', dest='weight_path', required=True) 22 | parser.add_argument('--restore-ckpt', action='store_true', dest='restore_ckpt', default=False) 23 | parser.add_argument('--retain-gpu', action='store_true', dest='retain_gpu', default=False) 24 | 25 | parser.add_argument('--debug-enable', action='store_true', dest='debug_enable', default=False) 26 | 27 | args = parser.parse_args() 28 | 29 | DEBUG = args.debug_enable 30 | if not DEBUG: 31 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 32 | 33 | def debug(s): 34 | if DEBUG: 35 | print(s) 36 | pass 37 | 38 | class Train: 39 | 40 | def __init__(self, args): 41 | if utils.valid_entry(args.weight_path) and not args.restore_ckpt\ 42 | and args.train_test != 'test': 43 | raise ValueError('process running or finished') 44 | 45 | gpu_lock = threading.Lock() 46 | gpu_lock.acquire() 47 | def retain_gpu(): 48 | if args.retain_gpu: 49 | with tf.Session(): 50 | gpu_lock.acquire() 51 | else: 52 | pass 53 | 54 | lockThread = threading.Thread(target=retain_gpu) 55 | lockThread.start() 56 | try: 57 | self.args = args 58 | config = Config() 59 | 60 | self.args = args 61 | self.weight_path = args.weight_path 62 | 63 | if args.load_config == False: 64 | config.saveConfig(self.weight_path + '/config') 65 | print('default configuration generated, please specify --load-config and run again.') 66 | gpu_lock.release() 67 | lockThread.join() 68 | sys.exit() 69 | else: 70 | if os.path.exists(self.weight_path + '/config'): 71 | config.loadConfig(self.weight_path + '/config') 72 | else: 73 | raise ValueError('No config file in %s' % self.weight_path) 74 | 75 | if config.revision != _REVISION: 76 | raise ValueError('revision dont match: %s over %s' % (config.revision, _REVISION)) 77 | 78 | vocab = Vocab() 79 | vocab.load_vocab_from_file(os.path.join(config.datapath, 'vocab.pkl')) 80 | config.vocab_dict = vocab.word_to_index 81 | with open(os.path.join(config.datapath, 'label2id.pkl'), 'rb') as fd: 82 | _ = pkl.load(fd) 83 | config.id2label = pkl.load(fd) 84 | _ = pkl.load(fd) 85 | config.id2weight = pkl.load(fd) 86 | 87 | config.class_num = len(config.id2label) 88 | self.config = config 89 | 90 | self.train_data = TextIterator(os.path.join(config.datapath, 'trainset.pkl'), self.config.batch_sz, 91 | bucket_sz=self.config.bucket_sz, shuffle=True) 92 | config.n_samples = self.train_data.num_example 93 | self.dev_data = TextIterator(os.path.join(config.datapath, 'devset.pkl'), self.config.batch_sz, 94 | bucket_sz=self.config.bucket_sz, shuffle=False) 95 | 96 | self.test_data = TextIterator(os.path.join(config.datapath, 'testset.pkl'), self.config.batch_sz, 97 | bucket_sz=self.config.bucket_sz, shuffle=False) 98 | 99 | self.data_q = Queue(10) 100 | 101 | self.model = model(config) 102 | 103 | except Exception as e: 104 | traceback.print_exc() 105 | gpu_lock.release() 106 | lockThread.join() 107 | exit() 108 | 109 | gpu_lock.release() 110 | lockThread.join() 111 | if utils.valid_entry(args.weight_path) and not args.restore_ckpt\ 112 | and args.train_test != 'test': 113 | raise ValueError('process running or finished') 114 | 115 | def get_epoch(self, sess): 116 | epoch = 
sess.run(self.model.on_epoch) 117 | return epoch 118 | 119 | def run_epoch(self, sess, input_data: TextIterator, verbose=10): 120 | """Runs an epoch of training. 121 | 122 | Trains the model for one-epoch. 123 | 124 | Args: 125 | sess: tf.Session() object 126 | Returns: 127 | average_loss: scalar. Average minibatch loss of model on epoch. 128 | """ 129 | total_steps = input_data.num_example // input_data.batch_sz 130 | total_loss = [] 131 | total_w_loss = [] 132 | total_ce_loss = [] 133 | collect_time = [] 134 | collect_data_time = [] 135 | accuracy_collect = [] 136 | step = -1 137 | dataset = [o for o in input_data] 138 | producer = Process(target=preparedata, 139 | args=(dataset, self.data_q, self.config.max_snt_num, 140 | self.config.max_wd_num, self.config.id2weight)) 141 | producer.start() 142 | while True: 143 | step += 1 144 | start_stamp = time.time() 145 | data_batch = self.data_q.get() 146 | if data_batch is None: 147 | break 148 | feed_dict = self.model.create_feed_dict(data_batch=data_batch, train=True) 149 | 150 | data_stamp = time.time() 151 | (accuracy, global_step, summary, opt_loss, w_loss, ce_loss, lr, _ 152 | ) = sess.run([self.model.accuracy, self.model.global_step, self.merged, 153 | self.model.opt_loss, self.model.w_loss, self.model.ce_loss, 154 | self.model.learning_rate, self.model.train_op], 155 | feed_dict=feed_dict) 156 | self.train_writer.add_summary(summary, global_step) 157 | self.train_writer.flush() 158 | 159 | end_stamp = time.time() 160 | 161 | collect_time.append(end_stamp-start_stamp) 162 | collect_time.append(data_stamp - start_stamp) 163 | accuracy_collect.append(accuracy) 164 | total_loss.append(opt_loss) 165 | total_w_loss.append(w_loss) 166 | total_ce_loss.append(ce_loss) 167 | 168 | if verbose and step % verbose == 0: 169 | sys.stdout.write('\r%d / %d : opt_loss = %.4f, w_loss = %.4f, ce_loss = %.4f, %.3fs/iter, %.3fs/batch, ' 170 | 'lr = %f, accu = %.4f, b_sz = %d' % ( 171 | step, total_steps, np.mean(total_loss[-verbose:]),np.mean(total_w_loss[-verbose:]), 172 | np.mean(total_ce_loss[-verbose:]), np.mean(collect_time), np.mean(collect_data_time), lr, 173 | np.mean(accuracy_collect[-verbose:]), input_data.batch_sz)) 174 | collect_time = [] 175 | sys.stdout.flush() 176 | utils.write_status(self.weight_path) 177 | producer.join() 178 | 179 | sess.run(self.model.on_epoch_accu) 180 | 181 | return np.mean(total_ce_loss), np.mean(total_loss), np.mean(accuracy_collect) 182 | 183 | def fit(self, sess, input_data :TextIterator, verbose=10): 184 | """ 185 | Fit the model. 186 | 187 | Args: 188 | sess: tf.Session() object 189 | Returns: 190 | average_loss: scalar. Average minibatch loss of model on epoch. 
191 | """ 192 | 193 | total_steps = input_data.num_example // input_data.batch_sz 194 | total_loss = [] 195 | total_ce_loss = [] 196 | collect_time = [] 197 | step = -1 198 | dataset = [o for o in input_data] 199 | producer = Process(target=preparedata, 200 | args=(dataset, self.data_q, self.config.max_snt_num, 201 | self.config.max_wd_num, self.config.id2weight)) 202 | producer.start() 203 | while True: 204 | step += 1 205 | data_batch = self.data_q.get() 206 | if data_batch is None: 207 | break 208 | feed_dict = self.model.create_feed_dict(data_batch=data_batch, train=False) 209 | 210 | start_stamp = time.time() 211 | (global_step, summary, ce_loss, opt_loss, 212 | ) = sess.run([self.model.global_step, self.merged, self.model.ce_loss, 213 | self.model.opt_loss], feed_dict=feed_dict) 214 | 215 | self.test_writer.add_summary(summary, step+global_step) 216 | self.test_writer.flush() 217 | 218 | end_stamp = time.time() 219 | collect_time.append(end_stamp - start_stamp) 220 | total_ce_loss.append(ce_loss) 221 | total_loss.append(opt_loss) 222 | 223 | if verbose and step % verbose == 0: 224 | sys.stdout.write('\r%d / %d: ce_loss = %f, opt_loss = %f, %.3fs/iter' % ( 225 | step, total_steps, np.mean(total_ce_loss[-verbose:]), 226 | np.mean(total_loss[-verbose:]), np.mean(collect_time))) 227 | collect_time = [] 228 | sys.stdout.flush() 229 | print('\n') 230 | producer.join() 231 | return np.mean(total_ce_loss), np.mean(total_loss) 232 | 233 | def predict(self, sess, input_data: TextIterator, verbose=10): 234 | """ 235 | Args: 236 | sess: tf.Session() object 237 | Returns: 238 | average_loss: scalar. Average minibatch loss of model on epoch. 239 | """ 240 | total_steps = input_data.num_example // input_data.batch_sz 241 | collect_time = [] 242 | collect_pred = [] 243 | label_id = [] 244 | step = -1 245 | dataset = [o for o in input_data] 246 | producer = Process(target=preparedata, 247 | args=(dataset, self.data_q, self.config.max_snt_num, 248 | self.config.max_wd_num, self.config.id2weight)) 249 | producer.start() 250 | while True: 251 | step += 1 252 | data_batch = self.data_q.get() 253 | if data_batch is None: 254 | break 255 | feed_dict = self.model.create_feed_dict(data_batch=data_batch, train=False) 256 | 257 | start_stamp = time.time() 258 | pred = sess.run(self.model.prediction, feed_dict=feed_dict) 259 | end_stamp = time.time() 260 | collect_time.append(end_stamp - start_stamp) 261 | 262 | collect_pred.append(pred) 263 | label_id += data_batch[1].tolist() 264 | if verbose and step % verbose == 0: 265 | sys.stdout.write('\r%d / %d: , %.3fs/iter' % ( 266 | step, total_steps, np.mean(collect_time))) 267 | collect_time = [] 268 | sys.stdout.flush() 269 | print('\n') 270 | producer.join() 271 | res_pred = np.concatenate(collect_pred, axis=0) 272 | return res_pred, label_id 273 | 274 | def test_case(self, sess, data, onset='VALIDATION'): 275 | print('#' * 20, 'ON ' + onset + ' SET START ', '#' * 20) 276 | print("=" * 10 + ' '.join(sys.argv) + "=" * 10) 277 | epoch = self.get_epoch(sess) 278 | ce_loss, opt_loss = self.fit(sess, data) 279 | pred, label = self.predict(sess, data) 280 | 281 | (prec, recall, overall_prec, overall_recall, _ 282 | ) = utils.calculate_confusion_single(pred, label, len(self.config.id2label)) 283 | 284 | utils.print_confusion_single(prec, recall, overall_prec, overall_recall, self.config.id2label) 285 | accuracy = utils.calculate_accuracy_single(pred, label) 286 | 287 | print('%d th Epoch -- Overall %s accuracy is: %f' % (epoch, onset, accuracy)) 288 | logging.info('%d 
th Epoch -- Overall %s accuracy is: %f' % (epoch, onset, accuracy)) 289 | 290 | print('%d th Epoch -- Overall %s ce_loss is: %f, opt_loss is: %f' % (epoch, onset, ce_loss, opt_loss)) 291 | logging.info('%d th Epoch -- Overall %s ce_loss is: %f, opt_loss is: %f' % (epoch, onset, ce_loss, opt_loss)) 292 | print('#' * 20, 'ON ' + onset + ' SET END ', '#' * 20) 293 | return accuracy, ce_loss 294 | 295 | def train_run(self): 296 | logging.info('Training start') 297 | logging.info("Parameter count is: %d" % self.model.param_cnt) 298 | if not args.restore_ckpt: 299 | self.remove_file(self.args.weight_path + '/summary.log') 300 | saver = tf.train.Saver(max_to_keep=30) 301 | 302 | config = tf.ConfigProto() 303 | config.gpu_options.allow_growth = True 304 | config.allow_soft_placement = True 305 | with tf.Session(config=config) as sess: 306 | 307 | self.merged = tf.summary.merge_all() 308 | self.train_writer = tf.summary.FileWriter(self.args.weight_path + '/summary_train', 309 | sess.graph) 310 | self.test_writer = tf.summary.FileWriter(self.args.weight_path + '/summary_test') 311 | 312 | sess.run(tf.global_variables_initializer()) 313 | if args.restore_ckpt: 314 | saver.restore(sess, self.args.weight_path + '/classifier.weights') 315 | best_loss = np.Inf 316 | best_accuracy = 0 317 | best_val_epoch = self.get_epoch(sess) 318 | 319 | for _ in range(self.config.max_epochs): 320 | 321 | epoch = self.get_epoch(sess) 322 | print("=" * 20 + "Epoch ", epoch, "=" * 20) 323 | ce_loss, opt_loss, accuracy = self.run_epoch(sess, self.train_data, verbose=10) 324 | print('') 325 | print("Mean ce_loss in %dth epoch is: %f, Mean ce_loss is: %f,"%(epoch, ce_loss, opt_loss)) 326 | print('Mean training accuracy is : %.4f' % accuracy) 327 | logging.info('Mean training accuracy is : %.4f' % accuracy) 328 | logging.info("Mean ce_loss in %dth epoch is: %f, Mean ce_loss is: %f,"%(epoch, ce_loss, opt_loss)) 329 | print('=' * 50) 330 | val_accuracy, val_loss = self.test_case(sess, self.dev_data, onset='VALIDATION') 331 | test_accuracy, test_loss = self.test_case(sess, self.test_data, onset='TEST') 332 | self.save_loss_accu(self.args.weight_path + '/summary.log', train_loss=ce_loss, 333 | valid_loss=val_loss, test_loss=test_loss, 334 | valid_accu=val_accuracy, test_accu=test_accuracy, epoch=epoch) 335 | if best_accuracy < val_accuracy: 336 | best_accuracy = val_accuracy 337 | best_val_epoch = epoch 338 | if not os.path.exists(self.args.weight_path): 339 | os.makedirs(self.args.weight_path) 340 | logging.info('best epoch is %dth epoch' % best_val_epoch) 341 | saver.save(sess, self.args.weight_path + '/classifier.weights') 342 | else: 343 | b_sz = self.train_data.batch_sz//2 344 | max_b_sz = max([b_sz, self.config.batch_sz_min]) 345 | buck_sz = self.train_data.bucket_sz * 2 346 | buck_sz = min([self.train_data.num_example, buck_sz]) 347 | self.train_data.batch_sz = max_b_sz 348 | self.train_data.bucket_sz = buck_sz 349 | 350 | if epoch - best_val_epoch > self.config.early_stopping: 351 | logging.info("Normal Early stop") 352 | break 353 | utils.write_status(self.weight_path, finished=True) 354 | logging.info("Training complete") 355 | 356 | def test_run(self): 357 | 358 | saver = tf.train.Saver(max_to_keep=30) 359 | 360 | config = tf.ConfigProto() 361 | config.gpu_options.allow_growth = True 362 | config.allow_soft_placement = True 363 | with tf.Session(config=config) as sess: 364 | self.merged = tf.summary.merge_all() 365 | self.test_writer = tf.summary.FileWriter(self.args.weight_path + '/summary_test') 366 | 367 | 
sess.run(tf.global_variables_initializer()) 368 | saver.restore(sess, self.args.weight_path + '/classifier.weights') 369 | 370 | self.test_case(sess, self.test_data, onset='TEST') 371 | 372 | def main_run(self): 373 | 374 | if not os.path.exists(self.args.weight_path): 375 | os.makedirs(self.args.weight_path) 376 | logFile = self.args.weight_path + '/run.log' 377 | 378 | if self.args.train_test == "train": 379 | 380 | try: 381 | os.remove(logFile) 382 | except OSError: 383 | pass 384 | logging.basicConfig(filename=logFile, format='%(levelname)s %(asctime)s %(message)s', level=logging.INFO) 385 | debug('_main_run_') 386 | self.train_run() 387 | self.test_run() 388 | else: 389 | logging.basicConfig(filename=logFile, format='%(levelname)s %(asctime)s %(message)s', level=logging.INFO) 390 | self.test_run() 391 | 392 | @staticmethod 393 | def save_loss_accu(fileName, train_loss, valid_loss, 394 | test_loss, valid_accu, test_accu, epoch): 395 | with open(fileName, 'a') as fd: 396 | fd.write('%3d\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\n' % 397 | (epoch, train_loss, valid_loss, 398 | test_loss, valid_accu, test_accu)) 399 | 400 | @staticmethod 401 | def remove_file(fileName): 402 | if os.path.exists(fileName): 403 | os.remove(fileName) 404 | 405 | if __name__ == '__main__': 406 | trainer = Train(args) 407 | trainer.main_run() 408 | 409 | -------------------------------------------------------------------------------- /caps_attn_hierarchical/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import _pickle as pkl 3 | import pdb 4 | import numpy as np 5 | import copy 6 | 7 | import os 8 | import warnings 9 | import sys 10 | from time import time 11 | import pprint 12 | import logging 13 | from collections import OrderedDict 14 | 15 | '''check alive''' 16 | def write_status(path, finished=False): 17 | full_path = path+'/status' 18 | if not finished: 19 | fd = open(full_path, 'w') 20 | fd.write(str(time())) 21 | fd.flush() 22 | fd.close() 23 | else: 24 | fd = open(full_path, 'w') 25 | fd.write('0.1') 26 | fd.flush() 27 | fd.close() 28 | 29 | def read_status(status_path): 30 | if not os.path.exists(status_path): 31 | return 'error' 32 | fd = open(status_path, 'r') 33 | time_stamp = float(fd.read().strip()) 34 | fd.close() 35 | if time_stamp < 10.: 36 | return 'finished' 37 | cur_time = time() 38 | if cur_time - time_stamp < 1000.: 39 | return 'running' 40 | else: 41 | return 'error' 42 | 43 | def valid_entry(save_path): 44 | 45 | if not os.path.exists(save_path): 46 | return False 47 | if read_status(save_path + '/status') == 'running': 48 | return True 49 | if read_status(save_path + '/status') == 'finished': 50 | return True 51 | if read_status(save_path + '/status') == 'error': 52 | return False 53 | 54 | raise ValueError('unknown error') 55 | 56 | def pad(x, len_x): 57 | if len(x) > len_x: 58 | return x[:len_x] 59 | return x+[0]* (len_x-len(x)) 60 | # batch preparation 61 | def prepare_data(seqs_x, seqs_char_x, seqs_pos_x, seqs_em_x, 62 | seqs_y, seqs_char_y, seqs_pos_y, seqs_em_y, 63 | labels, max_char_len): 64 | 65 | lengths_x = [len(s) for s in seqs_x] 66 | lengths_char_x = [len(s) for s in seqs_char_x] 67 | lengths_pos_x = [len(s) for s in seqs_pos_x] 68 | lengths_em_x = [len(s) for s in seqs_em_x] 69 | 70 | lengths_y = [len(s) for s in seqs_y] 71 | lengths_char_y = [len(s) for s in seqs_char_y] 72 | lengths_pos_y = [len(s) for s in seqs_pos_y] 73 | lengths_em_y = [len(s) for s in seqs_em_y] 74 | 75 | assert np.all(np.equal(lengths_x, lengths_pos_x)) 76 | 
assert np.all(np.equal(lengths_x, lengths_char_x)) 77 | assert np.all(np.equal(lengths_x, lengths_em_x)) 78 | 79 | assert np.all(np.equal(lengths_y, lengths_pos_y)) 80 | assert np.all(np.equal(lengths_y, lengths_char_y)) 81 | assert np.all(np.equal(lengths_y, lengths_em_y)) 82 | 83 | n_samples = len(seqs_x) 84 | maxlen_x = np.max(lengths_x) 85 | maxlen_y = np.max(lengths_y) 86 | 87 | seqs_char_x = [[pad(w_lst, max_char_len) for w_lst in snt] for snt in seqs_char_x] 88 | seqs_char_y = [[pad(w_lst, max_char_len) for w_lst in snt] for snt in seqs_char_y] 89 | 90 | x = np.zeros((n_samples, maxlen_x)).astype('int32') 91 | x_pos = np.zeros((n_samples, maxlen_x)).astype('int32') 92 | x_em = np.zeros((n_samples, maxlen_x)).astype('int32') 93 | x_char = np.zeros((n_samples, maxlen_x, max_char_len)).astype('int32') 94 | 95 | y = np.zeros((n_samples, maxlen_y)).astype('int32') 96 | y_pos = np.zeros((n_samples, maxlen_y)).astype('int32') 97 | y_em = np.zeros((n_samples, maxlen_y)).astype('int32') 98 | y_char = np.zeros((n_samples, maxlen_y, max_char_len)).astype('int32') 99 | 100 | l = np.zeros((n_samples,)).astype('int32') 101 | for idx, [s_x, s_char_x, s_pos_x, s_em_x, s_y, s_char_y, s_pos_y, s_em_y, ll] in enumerate(zip( 102 | seqs_x, seqs_char_x, seqs_pos_x, seqs_em_x, 103 | seqs_y, seqs_char_y, seqs_pos_y, seqs_em_y, labels)): 104 | 105 | x[idx, :lengths_x[idx]] = s_x 106 | x_char[idx, :lengths_x[idx]] = s_char_x 107 | x_pos[idx, :lengths_x[idx]] = s_pos_x 108 | x_em[idx, :lengths_x[idx]] = s_em_x 109 | 110 | y[idx, :lengths_y[idx]] = s_y 111 | y_char[idx, :lengths_y[idx]] = s_char_y 112 | y_pos[idx, :lengths_y[idx]] = s_pos_y 113 | y_em[idx, :lengths_y[idx]] = s_em_y 114 | 115 | l[idx] = ll 116 | 117 | return x, x_char, x_pos, x_em, lengths_x, y, y_char, y_pos, y_em, lengths_y, l 118 | 119 | '''===============================================================''' 120 | 121 | '''Read and make embedding''' 122 | 123 | def readEmbedding(fileName): 124 | """ 125 | Read an embedding file 126 | 127 | Args: 128 | fileName : path to the file that stores the embedding 129 | Returns: 130 | embeddings_index : a dictionary mapping each word to its embedding vector 131 | """ 132 | embeddings_index = {} 133 | with open(fileName, 'rb') as f: 134 | for line in f: 135 | line_uni = line.strip() 136 | line_uni = line_uni.decode('utf-8') 137 | values = line_uni.split(' ') 138 | word = values[0] 139 | try: 140 | coefs = np.asarray(values[1:], dtype='float32') 141 | embeddings_index[word] = coefs 142 | except ValueError:  # skip malformed lines instead of reusing a stale vector 143 | print(values, len(values)) 144 | return embeddings_index 145 | 146 | def mkEmbedMatrix(embed_dic, vocab_dic): 147 | """ 148 | Construct embedding matrix 149 | 150 | Args: 151 | embed_dic : word-embedding dictionary 152 | vocab_dic : word-index dictionary 153 | Returns: 154 | embedding_matrix, valid_mask : the embedding matrix and a boolean mask marking which rows received a pretrained vector 155 | """ 156 | if type(embed_dic) is not dict or type(vocab_dic) is not dict: 157 | raise TypeError('Inputs are not dictionaries') 158 | if len(embed_dic) < 1 or len(vocab_dic) < 1: 159 | raise ValueError('Input dictionaries must not be empty') 160 | vocab_sz = max(vocab_dic.values()) + 1 161 | EMBEDDING_DIM = len(list(embed_dic.values())[0]) 162 | # embedding_matrix = np.zeros((len(vocab_dic), EMBEDDING_DIM), dtype=np.float32) 163 | embedding_matrix = np.random.rand(vocab_sz, EMBEDDING_DIM).astype(np.float32) * 0.05 164 | valid_mask = np.ones(vocab_sz, dtype=bool) 165 | for word, i in vocab_dic.items(): 166 | embedding_vector = embed_dic.get(word) 167 | if embedding_vector is not None: 168 | # words not found in
the embedding index keep their random initialization and are marked False in valid_mask. 169 | embedding_matrix[i] = embedding_vector 170 | else: 171 | valid_mask[i] = False 172 | return embedding_matrix, valid_mask 173 | 174 | '''evaluation''' 175 | 176 | def pred_from_prob_single(prob_matrix): 177 | """ 178 | 179 | Args: 180 | prob_matrix: probability matrix of shape (data_num, class_num), 181 | dtype float, generated from a softmax activation 182 | 183 | Returns: 184 | ret: predicted class ids, shape (data_num,) 185 | """ 186 | ret = np.argmax(prob_matrix, axis=1) 187 | return ret 188 | 189 | 190 | def calculate_accuracy_single(pred_ids, label_ids): 191 | """ 192 | Args: 193 | pred_ids: predicted id array of shape (data_num,), dtype int 194 | label_ids: true label id array, same shape and dtype as pred_ids 195 | 196 | Returns: 197 | accuracy: accuracy of the prediction, type float 198 | """ 199 | if np.ndim(pred_ids) != 1 or np.ndim(label_ids) != 1: 200 | raise TypeError('require rank 1, 1. got {}, {}'.format(np.ndim(pred_ids), np.ndim(label_ids))) 201 | if len(pred_ids) != len(label_ids): 202 | raise TypeError('first argument and second argument have different lengths') 203 | 204 | accuracy = np.mean(np.equal(pred_ids, label_ids)) 205 | return accuracy 206 | 207 | 208 | def calculate_confusion_single(pred_list, label_list, label_size): 209 | """Helper that computes the confusion matrix plus per-class and overall precision/recall.""" 210 | confusion = np.zeros((label_size, label_size), dtype=np.int32)  # rows: true label, columns: predicted label 211 | for i in range(len(label_list)): 212 | confusion[label_list[i], pred_list[i]] += 1 213 | 214 | tp_fp = np.sum(confusion, axis=0) 215 | tp_fn = np.sum(confusion, axis=1) 216 | tp = np.array([confusion[i, i] for i in range(len(confusion))]) 217 | 218 | precision = tp.astype(np.float32)/(tp_fp+1e-40)  # per-class precision 219 | recall = tp.astype(np.float32)/(tp_fn+1e-40)  # per-class recall 220 | overall_prec = float(np.sum(tp))/(np.sum(tp_fp)+1e-40) 221 | overall_recall = float(np.sum(tp))/(np.sum(tp_fn)+1e-40) 222 | 223 | return precision, recall, overall_prec, overall_recall, confusion 224 | 225 | 226 | def print_confusion_single(prec, recall, overall_prec, overall_recall, num_to_tag): 227 | """Helper that logs and prints per-tag and overall precision/recall.""" 228 | logstr = "\n" 229 | logstr += '{:15}\t{:7}\t{:7}\n'.format('TAG', 'Prec', 'Recall') 230 | for i, tag in sorted(num_to_tag.items()): 231 | logstr += '{:15}\t{:2.4f}\t{:2.4f}\n'.format(tag, prec[i], recall[i]) 232 | logstr += '{:15}\t{:2.4f}\t{:2.4f}\n'.format('OVERALL', overall_prec, overall_recall) 233 | logging.info(logstr) 234 | print(logstr) 235 | 236 | 237 | def save_objs(obj, path): 238 | with open(path, 'wb') as fd: 239 | pkl.dump(obj, fd) 240 | 241 | 242 | def load_objs(path): 243 | with open(path, 'rb') as fd: 244 | ret = pkl.load(fd) 245 | return ret -------------------------------------------------------------------------------- /data/downloadDataset.md: -------------------------------------------------------------------------------- 1 | Please retrieve the data from [Baidu Cloud Storage](https://pan.baidu.com/s/1x3lU9yrPNOYEpxqY_TtUbQ) or 2 | [Google Drive](https://drive.google.com/open?id=1JmY1CuASBIA8SW-Z5wMyc9ilseU2RHDZ): 3 | download `dataset.tar.gz`, decompress it with `tar zxvf dataset.tar.gz`, and place the resulting 4 | `smallset` directory inside the `data` directory.
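If you prefer to script the extraction step, here is a minimal sketch in Python that mirrors the `tar zxvf` command above. It assumes `dataset.tar.gz` has already been downloaded into the project root and that the archive unpacks to a top-level `smallset/` directory (both are assumptions taken from the instructions above, not from the code base):

```python
import os
import tarfile

# Unpack dataset.tar.gz so that ./data/smallset/... exists afterwards.
os.makedirs('data', exist_ok=True)
with tarfile.open('dataset.tar.gz', 'r:gz') as tar:
    # assumption: the archive contains a single top-level `smallset` directory
    tar.extractall(path='data')
```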
5 | 6 | The structure of `data` directory should be as follows: 7 | ```bash 8 | Capsule4TextClassification 9 | ├── data 10 | │   └── smallset 11 | │   ├── imdb 12 | │   │   ├── devset.pkl 13 | │   │   ├── embedding.txt 14 | │   │   ├── label2id.pkl 15 | │   │   ├── testset.pkl 16 | │   │   ├── trainset.pkl 17 | │   │   └── vocab.pkl 18 | │   ├── SST-1 19 | │   │   ├── devset.pkl 20 | │   │   ├── embedding.txt 21 | │   │   ├── label2id.pkl 22 | │   │   ├── testset.pkl 23 | │   │   ├── trainset.pkl 24 | │   │   └── vocab.pkl 25 | │   ├── SST-2 26 | │   │   ├── devset.pkl 27 | │   │   ├── embedding.txt 28 | │   │   ├── label2id.pkl 29 | │   │   ├── testset.pkl 30 | │   │   ├── trainset.pkl 31 | │   │   └── vocab.pkl 32 | │   ├── yelp-2013 33 | │   │   ├── devset.pkl 34 | │   │   ├── embedding.txt 35 | │   │   ├── label2id.pkl 36 | │   │   ├── testset.pkl 37 | │   │   ├── trainset.pkl 38 | │   │   └── vocab.pkl 39 | │   └── yelp-2014 40 | │   ├── devset.pkl 41 | │   ├── embedding.txt 42 | │   ├── label2id.pkl 43 | │   ├── testset.pkl 44 | │   ├── trainset.pkl 45 | │   └── vocab.pkl 46 | 47 | ``` -------------------------------------------------------------------------------- /savings/imdb/config: -------------------------------------------------------------------------------- 1 | [General] 2 | revision = "first" 3 | datapath = "./data/smallset/imdb/" 4 | embed_path = "./data/smallset/imdb/embedding.txt" 5 | optimizer = "adam" 6 | attn_mode = "rout" 7 | seq_encoder = "bilstm" 8 | out_caps_num = 5 9 | rout_iter = 3 10 | max_snt_num = 40 11 | max_wd_num = 40 12 | max_epochs = 50 13 | pre_trained = true 14 | batch_sz = 32 15 | batch_sz_min = 32 16 | bucket_sz = 5000 17 | partial_update_until_epoch = 2 18 | embed_size = 300 19 | hidden_size = 200 20 | dense_hidden = [300, 10] 21 | lr = 0.0002 22 | decay_steps = 1000 23 | decay_rate = 0.9 24 | dropout = 0.2 25 | early_stopping = 7 26 | reg = 1e-06 27 | 28 | -------------------------------------------------------------------------------- /savings/sst01/config: -------------------------------------------------------------------------------- 1 | [General] 2 | revision = "flatten" 3 | datapath = "./data/smallset/SST-1/" 4 | embed_path = "./data/smallset/SST-1/embedding.txt" 5 | optimizer = "adam" 6 | attn_mode = "rout" 7 | seq_encoder = "bilstm" 8 | out_caps_num = 5 9 | rout_iter = 3 10 | max_snt_num = 50 11 | max_wd_num = 50 12 | max_epochs = 50 13 | pre_trained = true 14 | batch_sz = 64 15 | batch_sz_min = 16 16 | bucket_sz = 5000 17 | partial_update_until_epoch = 2 18 | embed_size = 300 19 | hidden_size = 200 20 | dense_hidden = [300, 5] 21 | lr = 0.0001 22 | decay_steps = 500 23 | decay_rate = 0.95 24 | dropout = 0.2 25 | early_stopping = 7 26 | reg = 1e-06 27 | 28 | -------------------------------------------------------------------------------- /savings/sst02/config: -------------------------------------------------------------------------------- 1 | [General] 2 | revision = "flatten" 3 | datapath = "./data/smallset/SST-2/" 4 | embed_path = "./data/smallset/SST-2/embedding.txt" 5 | optimizer = "adam" 6 | attn_mode = "rout" 7 | seq_encoder = "bilstm" 8 | out_caps_num = 5 9 | rout_iter = 3 10 | max_snt_num = 50 11 | max_wd_num = 50 12 | max_epochs = 50 13 | pre_trained = true 14 | batch_sz = 64 15 | batch_sz_min = 16 16 | bucket_sz = 5000 17 | partial_update_until_epoch = 2 18 | embed_size = 300 19 | hidden_size = 200 20 | dense_hidden = [300, 2] 21 | lr = 0.0003 22 | decay_steps = 1000 23 | decay_rate = 0.9 24 | dropout = 0.5 25 | early_stopping = 
7 26 | reg = 1e-06 27 | 28 | -------------------------------------------------------------------------------- /savings/yelp2013/config: -------------------------------------------------------------------------------- 1 | [General] 2 | revision = "first" 3 | datapath = "./data/smallset/yelp-2013/" 4 | embed_path = "./data/smallset/yelp-2013/embedding.txt" 5 | optimizer = "adam" 6 | attn_mode = "rout" 7 | seq_encoder = "bilstm" 8 | out_caps_num = 5 9 | rout_iter = 3 10 | max_snt_num = 40 11 | max_wd_num = 40 12 | max_epochs = 50 13 | pre_trained = true 14 | batch_sz = 32 15 | batch_sz_min = 32 16 | bucket_sz = 5000 17 | partial_update_until_epoch = 2 18 | embed_size = 300 19 | hidden_size = 200 20 | dense_hidden = [300, 5] 21 | lr = 0.0001 22 | decay_steps = 1000 23 | decay_rate = 0.9 24 | dropout = 0.2 25 | early_stopping = 7 26 | reg = 1e-05 27 | 28 | -------------------------------------------------------------------------------- /savings/yelp2014/config: -------------------------------------------------------------------------------- 1 | [General] 2 | revision = "first" 3 | datapath = "./data/smallset/yelp-2014/" 4 | embed_path = "./data/smallset/yelp-2014/embedding.txt" 5 | optimizer = "adam" 6 | attn_mode = "rout" 7 | seq_encoder = "bilstm" 8 | out_caps_num = 5 9 | rout_iter = 3 10 | max_snt_num = 40 11 | max_wd_num = 40 12 | max_epochs = 50 13 | pre_trained = true 14 | batch_sz = 32 15 | batch_sz_min = 32 16 | bucket_sz = 5000 17 | partial_update_until_epoch = 2 18 | embed_size = 300 19 | hidden_size = 200 20 | dense_hidden = [300, 5] 21 | lr = 0.0002 22 | decay_steps = 1000 23 | decay_rate = 0.9 24 | dropout = 0.2 25 | early_stopping = 7 26 | reg = 1e-05 27 | 28 | --------------------------------------------------------------------------------
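A note on the config files above: each holds a single `[General]` section written in a TOML-compatible `key = value` syntax (quoted strings, `true` booleans, arrays such as `dense_hidden = [300, 5]`, floats such as `1e-06`), and across the five configs the last entry of `dense_hidden` matches the number of classes of the corresponding dataset (10 for imdb, 5 for SST-1 and the yelp sets, 2 for SST-2). How the code base actually parses these files is up to `Config.py`; the sketch below only illustrates inspecting one of them with the third-party `toml` package, which is an assumption made for illustration rather than a documented dependency, and the inline comments are interpretive readings of the parameter names:

```python
import toml  # assumption: `pip install toml`; not necessarily what Config.py itself uses

# Load one of the shipped configurations and print a few hyper-parameters.
cfg = toml.load('./savings/yelp2014/config')['General']

print(cfg['seq_encoder'], cfg['attn_mode'])   # bilstm rout -> BiLSTM encoder, routing-based aggregation
print(cfg['out_caps_num'], cfg['rout_iter'])  # 5 3         -> output capsules and routing iterations
print(cfg['dense_hidden'])                    # [300, 5]    -> classifier layer sizes; last entry = class count
print(cfg['pre_trained'], cfg['lr'])          # True 0.0002 -> use pretrained embeddings; initial learning rate
```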