├── .gitignore ├── LICENSE ├── README.md ├── input_queue.py ├── model.py ├── omniglot_embed.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # data 104 | data 105 | /runs 106 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Donghwa 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TCML-tensorflow 2 | A tensorflow implementation of [Meta-Learning with Temporal Convolutions](https://arxiv.org/abs/1707.03141) 3 | 4 | Embedding for Omniglot dataset is only available now. 5 | 6 | # Prerequisites 7 | * Python 3.6+ 8 | * Tensorflow 1.3 9 | * Image dataset for training/validation (Python dictionary-like object) 10 | 11 | # Usage 12 | ``` 13 | python train.py --dataset omniglot --n 5 --k 1 --dilation 1 2 1 2 4 8 16 --lr 5e-4 --batch_size 64 14 | ``` 15 | 16 | It will print valiation loss and accuracy. 17 | Checkpoints and summaries for other metrics are saved in `./runs/tcml_{input_dim}_{num_dense_filter}_{attention_value_dim}_{lr}` 18 | 19 | # Results 20 | Omniglot | 5-way, 1-shot | 5-way, 5-shot 21 | ------------ | ------------- | ------------- 22 | Accuracy | 95.12% | 95.01% 23 | 24 | Since the paper didn't share exact hyperparameters, it is hard to make same results with the paper. 25 | This code performs ~95% accuracy for 5-way environment with way smaller numbers of dilation stacks. 26 | 27 | When I tried to use same dilation stack with paper(1 2 1 2 1 2 1 2 1 2 4 8 16), it is too deep, thus the loss converges at 1.60 with poor accuracy. 28 | 29 | Still trying to find better hyperparameters to get higher performance. 30 | 31 | # License 32 | MIT 33 | 34 | # Author 35 | Donghwa Kim ([@storykim](https://github.com/storykim)) 36 | -------------------------------------------------------------------------------- /input_queue.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | 5 | class FewShotInputQueue: 6 | def __init__(self, classes, inputs, N, K): 7 | """ 8 | Initialize 9 | 10 | :param capacity: int. capacity of queue 11 | :param inputs: dict that key is class, value is data(d0, ..., dn) 12 | """ 13 | self.classes = classes 14 | self.inputs = inputs 15 | self.N = N 16 | self.K = K 17 | 18 | def make_one_data(self): 19 | """ 20 | Extract K datas from N classes, concat and shuffle them. 21 | Then, add 1 data from random class(in N classes) at tail of data 22 | :return: (NK+1) x data 23 | """ 24 | target_classes = random.sample(self.classes, self.N) 25 | dataset = [] 26 | label_set = [] 27 | 28 | last_data = None 29 | 30 | last_class = random.sample(target_classes, 1)[0] 31 | last_class_idx = None 32 | for i, class_name in enumerate(target_classes): 33 | data_len = self.inputs[class_name].shape[0] 34 | if class_name == last_class: 35 | last_class_idx = i 36 | target_indices = np.random.choice(data_len, self.K + 1, False) 37 | target_datas = self.inputs[class_name][target_indices] 38 | target_datas, last_data, _ = np.split(target_datas, [self.K, self.K + 1]) 39 | else: 40 | target_indices = np.random.choice(data_len, self.K, False) 41 | target_datas = self.inputs[class_name][target_indices] 42 | dataset.append(target_datas) 43 | label_set += [i] * self.K 44 | 45 | dataset_np = np.concatenate(dataset) 46 | perm = np.random.permutation(self.N * self.K) 47 | 48 | dataset_np = dataset_np[perm] 49 | label_set = np.asarray(label_set, np.int32)[perm] 50 | 51 | return np.expand_dims(np.append(dataset_np, last_data, axis=0), -1), \ 52 | np.append(label_set, [last_class_idx], axis=0).astype(np.int32) 53 | 54 | 55 | def _make_dummy_inputs(): 56 | # 3 classes, 4 datas, 2x2 dim. 57 | input_dict = {} 58 | for i in range(3): 59 | input_dict[i] = np.multiply(np.ones([4, 2, 2], dtype=np.float32), i) 60 | 61 | return input_dict 62 | 63 | 64 | def _FewShotInputQueue_test(): 65 | inputs = _make_dummy_inputs() 66 | q = FewShotInputQueue(inputs.keys(), inputs, 3, 3) 67 | 68 | result, result_label = zip(*[q.make_one_data() for _ in range(10)]) 69 | 70 | print(result, result_label) 71 | 72 | 73 | if __name__ == "__main__": 74 | _FewShotInputQueue_test() 75 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | class TCML: 6 | def __init__(self, hparams, input_tensor, label_tensor, is_train): 7 | assert hparams.dilation is not None 8 | self.num_classes = hparams.n 9 | self.batch_size = hparams.batch_size 10 | self.seq_len = hparams.seq_len 11 | self.input_dim = hparams.input_dim 12 | self.num_dense_filter = hparams.num_dense_filter 13 | self.dilation = hparams.dilation 14 | self.attention_value_dim = hparams.attention_value_dim 15 | self.lr = hparams.lr 16 | self.reg_coeff = hparams.reg_coeff 17 | 18 | self.l2_loss = 0 19 | 20 | self.filter_width = 2 21 | 22 | self.input_placeholder = tf.cast(input_tensor, tf.float32) 23 | self.label_placeholder = label_tensor 24 | self.is_train = is_train 25 | if self.is_train: 26 | self.global_step = tf.get_variable("global_step", initializer=0, trainable=False) 27 | else: 28 | self.global_step = None 29 | 30 | self.dense_blocks = [] 31 | 32 | feed_label, target_label = tf.split(self.label_placeholder, [self.seq_len - 1, 1], 33 | axis=1) 34 | self.target_label = target_label 35 | feed_label_one_hot = tf.one_hot(feed_label, 36 | depth=self.num_classes, 37 | dtype=tf.float32) 38 | feed_label_one_hot = tf.concat([feed_label_one_hot, tf.zeros((self.batch_size, 1, self.num_classes))], axis=1) 39 | concated_input = tf.concat([self.input_placeholder, feed_label_one_hot], axis=2) 40 | 41 | last_output = concated_input 42 | d = self.input_dim + self.num_classes 43 | for i, dilation in enumerate(self.dilation): 44 | name = f"dilation{i}_{dilation}" 45 | with tf.variable_scope(name): 46 | last_output = output = self.generate_dense_block(last_output, d, dilation) 47 | self.dense_blocks.append((name, output)) 48 | d += self.num_dense_filter 49 | 50 | # last_output : [B, T, D + 128 * i] 51 | with tf.variable_scope("attention"): 52 | kernel_size = [1, d, self.attention_value_dim] # width, in_channel, out_channel 53 | conv_kernel = tf.get_variable("1x1_conv", kernel_size, 54 | dtype=tf.float32, 55 | initializer=tf.contrib.layers.xavier_initializer_conv2d()) 56 | self.l2_loss += tf.nn.l2_loss(conv_kernel) 57 | 58 | key, query = tf.split(last_output, [self.seq_len - 1, 1], axis=1) 59 | attention_value = tf.nn.conv1d(key, conv_kernel, 1, "SAME") 60 | attention_outputs = self.attention_layer(key, attention_value, query, float(d)) 61 | 62 | # attention_output : [B, 1, d'] 63 | # channel-wise softmax 64 | with tf.variable_scope("softmax"): 65 | kernel_size = [1, self.attention_value_dim, self.num_classes] 66 | conv_kernel = tf.get_variable("1x1_conv", kernel_size, 67 | dtype=tf.float32, 68 | initializer=tf.contrib.layers.xavier_initializer_conv2d()) 69 | self.l2_loss += tf.nn.l2_loss(conv_kernel) 70 | self.last_vector = softmax_vector = tf.nn.conv1d(attention_outputs, conv_kernel, 1, "SAME") 71 | 72 | ce_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target_label, 73 | logits=softmax_vector)) 74 | self.loss = ce_loss + self.reg_coeff * self.l2_loss 75 | self.train_step = tf.train.AdamOptimizer(self.lr).minimize(self.loss, global_step=self.global_step) 76 | 77 | self.accuracy = self._calc_accuracy() 78 | 79 | def _causal_conv(self, x, dilation, in_channel, out_channel): 80 | with tf.variable_scope("causal_conv"): 81 | # input shape : [B, T, D] 82 | # filter_shape : spatial_filter_shape + [in_channels, out_channels] 83 | filter_shape = [self.filter_width, in_channel, out_channel] 84 | initializer = tf.contrib.layers.xavier_initializer_conv2d() 85 | 86 | tanh_filter = tf.get_variable("tanh_filter", shape=filter_shape, dtype=tf.float32, 87 | initializer=initializer) 88 | sigmoid_filter = tf.get_variable("sigmoid_filter", shape=filter_shape, dtype=tf.float32, 89 | initializer=initializer) 90 | 91 | x_reverse = tf.pad(x, [[0, 0], [dilation, 0], [0, 0]]) 92 | 93 | tanh_output = tf.tanh(tf.nn.convolution(x_reverse, tanh_filter, 94 | padding="VALID", 95 | dilation_rate=(dilation,))) 96 | sigmoid_output = tf.sigmoid(tf.nn.convolution(x_reverse, sigmoid_filter, 97 | padding="VALID", 98 | dilation_rate=(dilation,))) 99 | 100 | return tf.multiply(tanh_output, sigmoid_output) 101 | 102 | def _residual_block(self, x, dilation, num_filter): 103 | # input shape : [B, T, D] 104 | # [filter_height, filter_width, in_channels, out_channels] 105 | conv_output = self._causal_conv(x, dilation, num_filter, num_filter) 106 | return x + conv_output 107 | 108 | def generate_dense_block(self, x, input_dim, dilation): 109 | # input shape : [B, T, D] 110 | conv = self._causal_conv(x, dilation, input_dim, self.num_dense_filter) 111 | with tf.variable_scope("residual_block_1"): 112 | residual1 = self._residual_block(conv, dilation, self.num_dense_filter) 113 | with tf.variable_scope("residual_block_2"): 114 | residual2 = self._residual_block(residual1, dilation, self.num_dense_filter) 115 | return tf.concat([x, residual2], axis=2) 116 | 117 | def attention_layer(self, key, value, query, d): 118 | # key : B x T-1 x d 119 | # value : B x T-1 x d' 120 | # query : B x 1 x d 121 | attention = tf.nn.softmax(tf.divide(tf.matmul(query, key, transpose_b=True), tf.sqrt(d))) # 1 x (t-1) 122 | return tf.matmul(attention, value) # B x d' 123 | 124 | def _calc_accuracy(self): 125 | with tf.name_scope("accuracy"): 126 | predictions = tf.argmax(self.last_vector, 2, name="predictions", output_type=tf.int32) 127 | labels = self.target_label 128 | correct_predictions = tf.equal(predictions, labels) 129 | accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy") 130 | 131 | # self.confusion_matrix = tf.confusion_matrix(labels, predictions, num_classes=self.num_classes) 132 | return accuracy 133 | 134 | 135 | def _make_dummy_data(): 136 | # 4 x 20 x 10 input data (float32) 137 | # 4 x 20 label data (int, [0, 4]) 138 | input_data = np.random.randn(4, 20, 10) 139 | label_data = np.random.randint(5, size=(4, 20)) 140 | return input_data, label_data 141 | 142 | 143 | def _TCML_test(): 144 | class Dummy: pass 145 | 146 | hparams = Dummy() 147 | hparams.n = 5 148 | hparams.input_dim = 10 149 | hparams.num_dense_filter = 16 150 | hparams.batch_size = 4 151 | hparams.seq_len = 20 152 | hparams.attention_value_dim = 16 153 | hparams.dilation = [1, 2, 1, 2] 154 | hparams.lr = 1e-3 155 | 156 | with tf.Graph().as_default(): 157 | dummy_input, dummy_label = _make_dummy_data() 158 | model = TCML(hparams, tf.stack(dummy_input), tf.cast(tf.stack(dummy_label), tf.int32), True) 159 | 160 | config = tf.ConfigProto() 161 | config.gpu_options.allow_growth = True 162 | sess = tf.Session(config=config) 163 | 164 | with sess.as_default(): 165 | init = tf.initialize_all_variables() 166 | sess.run(init) 167 | 168 | _, loss, acc = sess.run([model.train_step, model.loss, model.accuracy]) 169 | print(loss, acc) 170 | 171 | 172 | if __name__ == "__main__": 173 | _TCML_test() 174 | -------------------------------------------------------------------------------- /omniglot_embed.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | class OmniglotEmbedNetwork: 6 | def __init__(self, inputs, batch_size): 7 | ''' 8 | 4 blocks of 9 | {3 × 3 conv (64 filters), 10 | batch normalization, 11 | leaky ReLU activation (leak 0.1), 12 | and 2 × 2 max-pooling} 13 | ''' 14 | 15 | self.epsilon = 1e-10 16 | 17 | # input : B x T x H x W x C 18 | # output : B x T x D 19 | self.input_placeholder, self.label_placeholder = inputs 20 | 21 | with tf.variable_scope("omni_embed_0"): 22 | last_output = self.add_block(self.input_placeholder, 1, 64) 23 | 24 | for i in [1, 2, 3]: 25 | with tf.variable_scope(f"omni_embed_{i}"): 26 | last_output = self.add_block(last_output, 64, 64) 27 | 28 | self.output = tf.squeeze(last_output) 29 | 30 | def add_block(self, x, in_channel, out_channel): 31 | kernel_size = [1, 3, 3, in_channel, out_channel] 32 | kernel = tf.get_variable("kernel", kernel_size, dtype=tf.float32, 33 | initializer=tf.contrib.layers.xavier_initializer_conv2d()) 34 | conv_output = tf.nn.conv3d(x, kernel, [1, 1, 1, 1, 1], "SAME") 35 | 36 | beta = tf.get_variable('beta', [out_channel], initializer=tf.constant_initializer(0.0)) 37 | gamma = tf.get_variable('gamma', [out_channel], initializer=tf.constant_initializer(1.0)) 38 | 39 | batch_mean, batch_var = tf.nn.moments(conv_output, [0, 1, 2, 3]) 40 | batch_normalized = tf.nn.batch_normalization(conv_output, batch_mean, batch_var, beta, gamma, self.epsilon) 41 | 42 | relu_output = tf.nn.relu(batch_normalized) - 0.1 * tf.nn.relu(-batch_normalized) 43 | 44 | return tf.nn.max_pool3d(relu_output, [1, 1, 2, 2, 1], [1, 1, 2, 2, 1], "VALID") 45 | 46 | 47 | def _OmniglotEmbed_test(): 48 | with tf.Graph().as_default(): 49 | config = tf.ConfigProto() 50 | config.gpu_options.allow_growth = True 51 | sess = tf.Session(config=config) 52 | 53 | dummy_input = np.random.rand(10, 28, 28, 1) 54 | dummy_label = np.random.randint(5, size=(10, )) 55 | queue = tf.RandomShuffleQueue(20, 56 | min_after_dequeue=2, 57 | shapes=[dummy_input.shape, dummy_label.shape], dtypes=[tf.float32, tf.int32]) 58 | enqueue = queue.enqueue([dummy_input, dummy_label]) 59 | qr = tf.train.QueueRunner(queue, [enqueue] * 2) 60 | tf.train.add_queue_runner(qr) 61 | 62 | coord = tf.train.Coordinator() 63 | enqueue_threads = qr.create_threads(sess, coord=coord, start=True) 64 | 65 | model = OmniglotEmbedNetwork(queue, 5) 66 | 67 | with sess.as_default(): 68 | init = tf.initialize_all_variables() 69 | sess.run(init) 70 | 71 | output, = sess.run([model.output]) 72 | print(output.shape) 73 | 74 | coord.request_stop() 75 | coord.join(enqueue_threads) 76 | 77 | if __name__ == "__main__": 78 | _OmniglotEmbed_test() 79 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import tensorflow as tf 3 | import numpy as np 4 | import os 5 | import time 6 | from model import TCML 7 | from omniglot_embed import OmniglotEmbedNetwork 8 | from input_queue import FewShotInputQueue 9 | 10 | 11 | def define_flags(): 12 | flags = argparse.ArgumentParser() 13 | 14 | flags.add_argument("--n", type=int, default=None, help="N [Required]") 15 | flags.add_argument("--k", type=int, default=None, help="K [Required]") 16 | flags.add_argument("--dataset", type=str, default="omniglot", help="Dataset (omniglot / miniimage) [omniglot]") 17 | 18 | flags.add_argument("--dilation", type=int, nargs='+', help="List of dilation size[Required]") 19 | 20 | flags.add_argument("--batch_size", type=int, default=128, help="Batch size B[128]") 21 | flags.add_argument("--input_dim", type=int, default=64, help="Dimension of input D[64]") 22 | flags.add_argument("--num_dense_filter", type=int, default=128, help="# of filter in Dense block[128]") 23 | flags.add_argument("--attention_value_dim", type=int, default=16, help="Dimension of attension value d'[16]") 24 | flags.add_argument("--lr", type=float, default=1e-3, help="Learning rate[1e-3]") 25 | flags.add_argument("--reg_coeff", type=float, default=1e-3, help="Coefficient for regularization loss[1e-3]") 26 | return flags.parse_args() 27 | 28 | 29 | def train(): 30 | hparams = define_flags() 31 | hparams.seq_len = episode_len = hparams.n * hparams.k + 1 32 | 33 | if hparams.dataset == "omniglot": 34 | input_path = "data/omniglot/train.npz" 35 | valid_path = "data/omniglot/test.npz" 36 | else: 37 | raise NotImplementedError 38 | 39 | if hparams.dataset == "omniglot": 40 | input_size = (episode_len, 28, 28, 1) 41 | else: 42 | raise NotImplementedError 43 | 44 | with open(input_path, "rb") as f: 45 | input_npz = np.load(f) 46 | inputs = {} 47 | for filename in input_npz.files: 48 | inputs[filename] = input_npz[filename] 49 | 50 | with open(valid_path, "rb") as f: 51 | valid_npz = np.load(f) 52 | valid_inputs = {} 53 | for filename in valid_npz.files: 54 | valid_inputs[filename] = valid_npz[filename] 55 | 56 | with tf.Graph().as_default(): 57 | q = FewShotInputQueue(inputs.keys(), inputs, hparams.n, hparams.k) 58 | valid_q = FewShotInputQueue(valid_inputs.keys(), valid_inputs, hparams.n, hparams.k) 59 | 60 | generated_input, generated_label = tf.py_func(q.make_one_data, [], [tf.float32, tf.int32]) 61 | batch_tensors = tf.train.batch([generated_input, generated_label], batch_size=hparams.batch_size, num_threads=4, 62 | shapes=[input_size, (episode_len,)], capacity=hparams.batch_size*5) 63 | valid_input, valid_label = tf.py_func(valid_q.make_one_data, [], [tf.float32, tf.int32]) 64 | valid_batch_tensors = tf.train.batch([valid_input, valid_label], batch_size=hparams.batch_size, num_threads=4, 65 | shapes=[input_size, (episode_len,)], capacity=hparams.batch_size*5) 66 | 67 | with tf.variable_scope("networks"): 68 | embed_network = OmniglotEmbedNetwork(batch_tensors, hparams.batch_size) 69 | tcml = TCML(hparams, embed_network.output, embed_network.label_placeholder, True) 70 | 71 | with tf.variable_scope("networks", reuse=True): 72 | valid_embed_network = OmniglotEmbedNetwork(valid_batch_tensors, hparams.batch_size) 73 | valid_tcml = TCML(hparams, valid_embed_network.output, valid_embed_network.label_placeholder, False) 74 | 75 | params_to_str = f"tcml_{hparams.input_dim}_{hparams.num_dense_filter}_{hparams.attention_value_dim}_{hparams.lr}" 76 | log_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", params_to_str)) 77 | 78 | # Summaries 79 | tf.summary.scalar("train_loss", tcml.loss) 80 | tf.summary.scalar("train_acc", tcml.accuracy) 81 | 82 | tf.summary.scalar("valid_loss", valid_tcml.loss) 83 | tf.summary.scalar("valid_acc", valid_tcml.accuracy) 84 | 85 | tf.summary.image("inputs", valid_embed_network.input_placeholder[0], max_outputs=episode_len) 86 | 87 | # Supervisor 88 | supervisor = tf.train.Supervisor( 89 | logdir=log_dir, 90 | save_summaries_secs=120, 91 | save_model_secs=600, 92 | global_step=tcml.global_step, 93 | ) 94 | 95 | config = tf.ConfigProto() 96 | config.gpu_options.allow_growth = True 97 | 98 | print("Training start") 99 | 100 | with supervisor.managed_session(config=config) as sess: 101 | min_dev_loss = 10000 102 | min_step = -1 103 | 104 | STEP_NUM = 10000000 105 | EARLY_STOP = 30000000 106 | print_every = 500 107 | 108 | HUGE_VALIDATION_CYCLE = print_every * 20 109 | 110 | last_dev = time.time() 111 | 112 | for step in range(STEP_NUM): 113 | if supervisor.should_stop(): 114 | break 115 | 116 | if step - min_step > EARLY_STOP: 117 | print("Early stopping...") 118 | break 119 | 120 | if step % print_every != 0: 121 | _, loss, acc, global_step = sess.run( 122 | [tcml.train_step, tcml.loss, tcml.accuracy, tcml.global_step]) 123 | else: 124 | _, loss, acc, global_step = sess.run( 125 | [tcml.train_step, tcml.loss, tcml.accuracy, tcml.global_step]) 126 | 127 | loss, acc = sess.run([valid_tcml.loss, valid_tcml.accuracy]) 128 | 129 | current_time = time.time() 130 | print( 131 | f'Evaluate(Step {step}/{global_step} : valid loss({loss}), acc({acc}) in {current_time - last_dev} s') 132 | 133 | # HUGE VALIDATION 134 | if step != 0 and step % HUGE_VALIDATION_CYCLE == 0: 135 | total_loss = total_acc = 0. 136 | BATCH_NUM = 30 137 | for _ in range(BATCH_NUM): 138 | loss, acc = sess.run([valid_tcml.loss, valid_tcml.accuracy]) 139 | total_loss += loss * hparams.batch_size 140 | total_acc += acc * hparams.batch_size 141 | 142 | total_loss /= BATCH_NUM * hparams.batch_size 143 | total_acc /= BATCH_NUM * hparams.batch_size 144 | 145 | huge_data_acc_summary = tf.Summary() 146 | huge_data_acc_summary.value.add(tag="huge_data_accuracy", simple_value=total_acc) 147 | supervisor.summary_computed(sess, huge_data_acc_summary, global_step=global_step) 148 | 149 | huge_data_loss_summary = tf.Summary() 150 | huge_data_loss_summary.value.add(tag="huge_data_loss", simple_value=total_loss) 151 | supervisor.summary_computed(sess, huge_data_loss_summary, global_step=global_step) 152 | 153 | last_dev = current_time 154 | 155 | if loss < min_dev_loss: 156 | min_dev_loss = loss 157 | min_step = step 158 | 159 | 160 | if __name__ == "__main__": 161 | train() 162 | --------------------------------------------------------------------------------