├── assets
│   └── model.png
├── README.md
├── download.py
├── LICENSE
├── main.py
├── utils.py
├── .gitignore
├── config.py
├── trainer.py
├── layers.py
├── model.py
└── data_loader.py

/assets/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devsisters/neural-combinatorial-rl-tensorflow/HEAD/assets/model.png
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# Neural Combinatorial Optimization in Tensorflow

TensorFlow implementation of [Neural Combinatorial Optimization with Reinforcement Learning](http://arxiv.org/abs/1611.09940).

![model](./assets/model.png)

(in progress)


## Requirements

- Python 2.7
- [tqdm](https://github.com/tqdm/tqdm)
- [TensorFlow 0.12.1](https://github.com/tensorflow/tensorflow/tree/r0.12)


## Usage

To train a model:

    $ python main.py --task=tsp20 --lr_start=0.001 --min_data_length=5 --max_data_length=20
    $ python main.py --task=tsp50 --lr_start=0.001 --min_data_length=5 --max_data_length=50
    $ python main.py --task=tsp100 --lr_start=0.0001 --min_data_length=5 --max_data_length=100

To train a model with the default settings and monitor it with TensorBoard:

    $ python main.py
    $ tensorboard --logdir=logs --host=0.0.0.0

To test a model:

    $ python main.py --is_train=False

## Results

(in progress)


## Author

Taehoon Kim / [@carpedm20](http://carpedm20.github.io)
--------------------------------------------------------------------------------

/download.py:
--------------------------------------------------------------------------------
# Code based on
# http://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive/39225039#39225039
import requests
from tqdm import tqdm

def download_file_from_google_drive(id, destination):
    URL = "https://docs.google.com/uc?export=download"

    session = requests.Session()

    response = session.get(URL, params={'id': id}, stream=True)
    token = get_confirm_token(response)

    if token:
        params = {'id': id, 'confirm': token}
        response = session.get(URL, params=params, stream=True)

    save_response_content(response, destination)
    return True

def get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value

    return None

def save_response_content(response, destination):
    CHUNK_SIZE = 32768

    with open(destination, "wb") as f:
        for chunk in tqdm(response.iter_content(CHUNK_SIZE)):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
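
if __name__ == '__main__':
    # Usage sketch (not part of the original file): fetch one of the paper's
    # datasets by its Drive id. The id below is taken from GOOGLE_DRIVE_IDS in
    # data_loader.py; the destination path is only an example.
    download_file_from_google_drive(
        '0B2fg8yPGn2TCSW1pNTJMXzFPYTg', 'data/tsp5_train.zip')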
--------------------------------------------------------------------------------

/LICENSE:
--------------------------------------------------------------------------------
The MIT License (MIT)

Copyright (c) 2016 Devsisters corp.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------

/main.py:
--------------------------------------------------------------------------------
import sys
import numpy as np
import tensorflow as tf

from trainer import Trainer
from config import get_config
from utils import prepare_dirs_and_logger, save_config

config = None

def main(_):
    prepare_dirs_and_logger(config)

    if not config.task.lower().startswith('tsp'):
        raise Exception("[!] Task should start with 'tsp'")

    if config.max_enc_length is None:
        config.max_enc_length = config.max_data_length
    if config.max_dec_length is None:
        config.max_dec_length = config.max_data_length

    rng = np.random.RandomState(config.random_seed)
    tf.set_random_seed(config.random_seed)

    trainer = Trainer(config, rng)
    save_config(config.model_dir, config)

    if config.is_train:
        trainer.train()
    else:
        if not config.load_path:
            raise Exception("[!] You should specify `load_path` to load a pretrained model")
        trainer.test()

    tf.logging.info("Run finished.")

if __name__ == "__main__":
    config, unparsed = get_config()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
--------------------------------------------------------------------------------

/utils.py:
--------------------------------------------------------------------------------
import os
import json
import logging
import numpy as np
from datetime import datetime

import tensorflow as tf
import tensorflow.contrib.slim as slim

def prepare_dirs_and_logger(config):
    formatter = logging.Formatter(
        "%(asctime)s:%(levelname)s::%(message)s")
    logger = logging.getLogger('tensorflow')

    for hdlr in logger.handlers:
        logger.removeHandler(hdlr)

    handler = logging.StreamHandler()
    handler.setFormatter(formatter)

    logger.addHandler(handler)
    logger.setLevel(tf.logging.INFO)

    if config.load_path:
        if config.load_path.startswith(config.task):
            config.model_name = config.load_path
        else:
            config.model_name = "{}_{}".format(config.task, config.load_path)
    else:
        config.model_name = "{}_{}".format(config.task, get_time())

    config.model_dir = os.path.join(config.log_dir, config.model_name)

    for path in [config.log_dir, config.data_dir, config.model_dir]:
        if not os.path.exists(path):
            os.makedirs(path)

def get_time():
    return datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

def show_all_variables():
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)

def save_config(model_dir, config):
    param_path = os.path.join(model_dir, "params.json")

    tf.logging.info("MODEL dir: %s" % model_dir)
    tf.logging.info("PARAM path: %s" % param_path)

    with open(param_path, 'w') as fp:
        json.dump(config.__dict__, fp, indent=4, sort_keys=True)
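
def load_config(model_dir):
    # Sketch only (this helper is not in the original code): the inverse of
    # save_config above. Reads params.json back as a plain dict; rebuilding
    # an argparse.Namespace from it is left to the caller.
    param_path = os.path.join(model_dir, "params.json")
    with open(param_path) as fp:
        return json.load(fp)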
--------------------------------------------------------------------------------

/.gitignore:
--------------------------------------------------------------------------------
# Data
data/hand
data/gaze
data/*
samples
outputs

# Log
logs

# ETC
paper.pdf

# Created by https://www.gitignore.io/api/python,vim

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
.venv/
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject


### Vim ###
# swap
[._]*.s[a-v][a-z]
[._]*.sw[a-p]
[._]s[a-v][a-z]
[._]sw[a-p]
# session
Session.vim
# temporary
.netrwhist
*~
# auto-generated tag files
tags

# End of https://www.gitignore.io/api/python,vim
--------------------------------------------------------------------------------

/config.py:
--------------------------------------------------------------------------------
#-*- coding: utf-8 -*-
import argparse

def str2bool(v):
    return v.lower() in ('true', '1')

arg_lists = []
parser = argparse.ArgumentParser()

def add_argument_group(name):
    arg = parser.add_argument_group(name)
    arg_lists.append(arg)
    return arg

# Network
net_arg = add_argument_group('Network')
net_arg.add_argument('--hidden_dim', type=int, default=256, help='')
net_arg.add_argument('--num_layers', type=int, default=1, help='')
net_arg.add_argument('--input_dim', type=int, default=2, help='')
net_arg.add_argument('--max_enc_length', type=int, default=None, help='')
net_arg.add_argument('--max_dec_length', type=int, default=None, help='')
net_arg.add_argument('--init_min_val', type=float, default=-0.08, help='for uniform random initializer')
net_arg.add_argument('--init_max_val', type=float, default=+0.08, help='for uniform random initializer')
net_arg.add_argument('--num_glimpse', type=int, default=1, help='')
net_arg.add_argument('--use_terminal_symbol', type=str2bool, default=True, help='Not implemented yet')

# Data
data_arg = add_argument_group('Data')
data_arg.add_argument('--task', type=str, default='tsp')
data_arg.add_argument('--batch_size', type=int, default=128)
data_arg.add_argument('--min_data_length', type=int, default=5)
data_arg.add_argument('--max_data_length', type=int, default=10)
data_arg.add_argument('--train_num', type=int, default=1000000)
data_arg.add_argument('--valid_num', type=int, default=1000)
data_arg.add_argument('--test_num', type=int, default=1000)

# Training / test parameters
train_arg = add_argument_group('Training')
train_arg.add_argument('--is_train', type=str2bool, default=True, help='')
train_arg.add_argument('--optimizer', type=str, default='rmsprop', help='')
train_arg.add_argument('--max_step', type=int, default=1000000, help='')
train_arg.add_argument('--lr_start', type=float, default=0.001, help='')
train_arg.add_argument('--lr_decay_step', type=int, default=5000, help='')
train_arg.add_argument('--lr_decay_rate', type=float, default=0.96, help='')
train_arg.add_argument('--max_grad_norm', type=float, default=2.0, help='')
train_arg.add_argument('--checkpoint_secs', type=int, default=300, help='')

# Misc
misc_arg = add_argument_group('Misc')
misc_arg.add_argument('--log_step', type=int, default=50, help='')
misc_arg.add_argument('--num_log_samples', type=int, default=3, help='')
misc_arg.add_argument('--log_level', type=str, default='INFO', choices=['INFO', 'DEBUG', 'WARN'], help='')
misc_arg.add_argument('--log_dir', type=str, default='logs')
misc_arg.add_argument('--data_dir', type=str, default='data')
misc_arg.add_argument('--output_dir', type=str, default='outputs')
misc_arg.add_argument('--load_path', type=str, default='')
misc_arg.add_argument('--debug', type=str2bool, default=False)
misc_arg.add_argument('--gpu_memory_fraction', type=float, default=1.0)
misc_arg.add_argument('--random_seed', type=int, default=123, help='')

def get_config():
    config, unparsed = parser.parse_known_args()
    return config, unparsed
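
if __name__ == '__main__':
    # Quick sanity check (illustration only, not in the original file):
    # parse_known_args leaves unrecognized flags in `unparsed`, which main.py
    # forwards to tf.app.run.
    config, unparsed = get_config()
    print(config.task, config.max_data_length, unparsed)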
--------------------------------------------------------------------------------

/trainer.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from tqdm import trange
import tensorflow as tf
from tensorflow.contrib.framework.python.ops import arg_scope

from model import Model
from utils import show_all_variables
from data_loader import TSPDataLoader

class Trainer(object):
    def __init__(self, config, rng):
        self.config = config
        self.rng = rng

        self.task = config.task
        self.model_dir = config.model_dir
        self.gpu_memory_fraction = config.gpu_memory_fraction

        self.log_step = config.log_step
        self.max_step = config.max_step
        self.num_log_samples = config.num_log_samples
        self.checkpoint_secs = config.checkpoint_secs

        self.summary_ops = {}

        if config.task.lower().startswith('tsp'):
            self.data_loader = TSPDataLoader(config, rng=self.rng)
        else:
            raise Exception("[!] Unknown task: {}".format(config.task))

        self.models = {}

        self.model = Model(
            config,
            inputs=self.data_loader.x,
            labels=self.data_loader.y,
            enc_seq_length=self.data_loader.seq_length,
            dec_seq_length=self.data_loader.seq_length,
            mask=self.data_loader.mask)

        self.build_session()
        show_all_variables()

    def build_session(self):
        self.saver = tf.train.Saver()
        self.summary_writer = tf.summary.FileWriter(self.model_dir)

        sv = tf.train.Supervisor(logdir=self.model_dir,
                                 is_chief=True,
                                 saver=self.saver,
                                 summary_op=None,
                                 summary_writer=self.summary_writer,
                                 save_summaries_secs=300,
                                 save_model_secs=self.checkpoint_secs,
                                 global_step=self.model.global_step)

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=self.gpu_memory_fraction,
            allow_growth=True)  # NOTE: allow_growth seems to have no effect here
        sess_config = tf.ConfigProto(allow_soft_placement=True,
                                     gpu_options=gpu_options)

        self.sess = sv.prepare_or_wait_for_session(config=sess_config)

    def train(self):
        tf.logging.info("Training starts...")
        self.data_loader.run_input_queue(self.sess)

        summary_writer = None
        for k in trange(self.max_step, desc="train"):
            fetch = {
                'optim': self.model.optim,
            }
            result = self.model.train(self.sess, fetch, summary_writer)

            if result['step'] % self.log_step == 0:
                self._test(self.summary_writer)

            # write train summaries only on logging steps
            summary_writer = self._get_summary_writer(result)

    def test(self):
        tf.logging.info("Testing starts...")

        for idx in range(10):
            self._test(None)

    def _test(self, summary_writer):
        fetch = {
            'loss': self.model.total_inference_loss,
            'pred': self.model.dec_inference,
            'true': self.model.dec_targets,
        }
        result = self.model.test(self.sess, fetch, summary_writer)

        tf.logging.info("")
        tf.logging.info("test loss: {}".format(result['loss']))
        for idx in range(self.num_log_samples):
            pred, true = result['pred'][idx], result['true'][idx]
            tf.logging.info("test pred: {}".format(pred))
            tf.logging.info("test true: {} ({})".format(true, np.array_equal(pred, true)))

        if summary_writer:
            summary_writer.add_summary(result['summary'], result['step'])

    def _inject_summary(self, tag, feed_dict, step):
        # NOTE: unused by train()/test(); depends on imwrite/img_tile helpers
        # and sample_model_dir/sample_image_grid options not defined here.
        summaries = self.sess.run(self.summary_ops[tag], feed_dict)
        self.summary_writer.add_summary(summaries['summary'], step)

        path = os.path.join(
            self.config.sample_model_dir, "{}.png".format(step))
        imwrite(path, img_tile(summaries['output'],
                               tile_shape=self.config.sample_image_grid)[:, :, 0])

    def _get_summary_writer(self, result):
        if result['step'] % self.log_step == 0:
            return self.summary_writer
        else:
            return None
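
# Usage sketch (illustration only; main.py is the real entry point):
#
#   trainer = Trainer(config, rng)
#   trainer.train()                          # starts the enqueue threads via run_input_queue
#   trainer.data_loader.stop_input_queue()   # requests a stop and joins them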
--------------------------------------------------------------------------------

/layers.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.contrib import layers
from tensorflow.contrib import seq2seq
from tensorflow.python.util import nest

LSTMCell = rnn.LSTMCell
MultiRNNCell = rnn.MultiRNNCell
dynamic_rnn_decoder = seq2seq.dynamic_rnn_decoder
simple_decoder_fn_train = seq2seq.simple_decoder_fn_train

def decoder_rnn(cell, inputs,
                enc_outputs, enc_final_states,
                seq_length, hidden_dim,
                num_glimpse, batch_size, is_train,
                end_of_sequence_id=0, initializer=None,
                max_length=None):
    with tf.variable_scope("decoder_rnn") as scope:
        def attention(ref, query, with_softmax, scope="attention"):
            with tf.variable_scope(scope):
                W_ref = tf.get_variable(
                    "W_ref", [1, hidden_dim, hidden_dim], initializer=initializer)
                W_q = tf.get_variable(
                    "W_q", [hidden_dim, hidden_dim], initializer=initializer)
                v = tf.get_variable(
                    "v", [hidden_dim], initializer=initializer)

                encoded_ref = tf.nn.conv1d(ref, W_ref, 1, "VALID", name="encoded_ref")
                encoded_query = tf.expand_dims(tf.matmul(query, W_q, name="encoded_query"), 1)
                tiled_encoded_query = tf.tile(
                    encoded_query, [1, tf.shape(encoded_ref)[1], 1], name="tiled_encoded_query")
                scores = tf.reduce_sum(v * tf.tanh(encoded_ref + tiled_encoded_query), [-1])

                if with_softmax:
                    return tf.nn.softmax(scores)
                else:
                    return scores

        def glimpse(ref, query, scope="glimpse"):
            p = attention(ref, query, with_softmax=True, scope=scope)
            alignments = tf.expand_dims(p, 2)
            return tf.reduce_sum(alignments * ref, [1])

        def output_fn(ref, query, num_glimpse):
            if query is None:
                return tf.zeros([max_length], tf.float32)  # only used for shape inference
            else:
                for idx in range(num_glimpse):
                    query = glimpse(ref, query, "glimpse_{}".format(idx))
                return attention(ref, query, with_softmax=False, scope="attention")

        def input_fn(sampled_idx):
            return tf.stop_gradient(
                tf.gather_nd(enc_outputs, index_matrix_to_pairs(sampled_idx)))

        if is_train:
            decoder_fn = simple_decoder_fn_train(enc_final_states)
        else:
            maximum_length = tf.convert_to_tensor(max_length, tf.int32)

            def decoder_fn(time, cell_state, cell_input, cell_output, context_state):
                cell_output = output_fn(enc_outputs, cell_output, num_glimpse)
                if cell_state is None:
                    cell_state = enc_final_states
                    next_input = cell_input
                    done = tf.zeros([batch_size], dtype=tf.bool)
                else:
                    sampled_idx = tf.cast(tf.argmax(cell_output, 1), tf.int32)
                    next_input = input_fn(sampled_idx)
                    done = tf.equal(sampled_idx, end_of_sequence_id)

                done = tf.cond(tf.greater(time, maximum_length),
                               lambda: tf.ones([batch_size], dtype=tf.bool),
                               lambda: done)
                return (done, cell_state, next_input, cell_output, context_state)

        outputs, final_state, final_context_state = \
            dynamic_rnn_decoder(cell, decoder_fn, inputs=inputs,
                                sequence_length=seq_length, scope=scope)

        if is_train:
            transposed_outputs = tf.transpose(outputs, [1, 0, 2])
            fn = lambda x: output_fn(enc_outputs, x, num_glimpse)
            outputs = tf.transpose(tf.map_fn(fn, transposed_outputs), [1, 0, 2])

        return outputs, final_state, final_context_state
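
# Shape walk-through of the pointer attention above (comments only):
#   ref:   [batch, seq_len, hidden_dim]  encoder outputs ("reference" states)
#   query: [batch, hidden_dim]           current decoder state
#   encoded_ref   = conv1d(ref, W_ref) -> [batch, seq_len, hidden_dim]
#   encoded_query = matmul(query, W_q) -> [batch, 1, hidden_dim], tiled along time
#   scores = sum(v * tanh(encoded_ref + tiled_encoded_query), -1) -> [batch, seq_len]
# i.e. u_i = v^T tanh(W_ref r_i + W_q q), the scoring rule of the pointer network.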

def trainable_initial_state(batch_size, state_size,
                            initializer=None, name="initial_state"):
    flat_state_size = nest.flatten(state_size)

    if initializer is None:
        flat_initializer = tuple(tf.zeros_initializer() for _ in flat_state_size)
    else:
        # use the given initializer for every state component
        flat_initializer = tuple(initializer for _ in flat_state_size)

    names = ["{}_{}".format(name, i) for i in xrange(len(flat_state_size))]
    tiled_states = []

    for name, size, init in zip(names, flat_state_size, flat_initializer):
        shape_with_batch_dim = [1, size]
        initial_state_variable = tf.get_variable(
            name, shape=shape_with_batch_dim, initializer=init)

        tiled_state = tf.tile(initial_state_variable,
                              [batch_size, 1], name=(name + "_tiled"))
        tiled_states.append(tiled_state)

    return nest.pack_sequence_as(structure=state_size,
                                 flat_sequence=tiled_states)

def index_matrix_to_pairs(index_matrix):
    # [[3,1,2], [2,3,1]] -> [[[0, 3], [1, 1], [2, 2]],
    #                        [[0, 2], [1, 3], [2, 1]]]
    replicated_first_indices = tf.range(tf.shape(index_matrix)[0])
    rank = len(index_matrix.get_shape())
    if rank == 2:
        replicated_first_indices = tf.tile(
            tf.expand_dims(replicated_first_indices, dim=1),
            [1, tf.shape(index_matrix)[1]])
    return tf.stack([replicated_first_indices, index_matrix], axis=rank)
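
if __name__ == '__main__':
    # Quick check of index_matrix_to_pairs (illustration only, not in the
    # original file):
    idx = tf.constant([[3, 1, 2], [2, 3, 1]])
    with tf.Session() as sess:
        print(sess.run(index_matrix_to_pairs(idx)))
        # -> [[[0 3] [1 1] [2 2]]
        #     [[0 2] [1 3] [2 1]]]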
--------------------------------------------------------------------------------

/model.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from tensorflow.contrib.framework import arg_scope

from layers import *

class Model(object):
    def __init__(self, config,
                 inputs, labels, enc_seq_length, dec_seq_length, mask,
                 reuse=False, is_critic=False):
        self.task = config.task
        self.debug = config.debug
        self.config = config

        self.input_dim = config.input_dim
        self.hidden_dim = config.hidden_dim
        self.num_layers = config.num_layers

        self.max_enc_length = config.max_enc_length
        self.max_dec_length = config.max_dec_length
        self.num_glimpse = config.num_glimpse

        self.init_min_val = config.init_min_val
        self.init_max_val = config.init_max_val
        self.initializer = \
            tf.random_uniform_initializer(self.init_min_val, self.init_max_val)

        self.use_terminal_symbol = config.use_terminal_symbol

        self.lr_start = config.lr_start
        self.lr_decay_step = config.lr_decay_step
        self.lr_decay_rate = config.lr_decay_rate
        self.max_grad_norm = config.max_grad_norm

        self.layer_dict = {}

        ##############
        # inputs
        ##############

        self.is_training = tf.placeholder_with_default(
            tf.constant(False, dtype=tf.bool),
            shape=(), name='is_training'
        )

        self.enc_inputs, self.dec_targets, self.enc_seq_length, self.dec_seq_length, self.mask = \
            tf.contrib.layers.utils.smart_cond(
                self.is_training,
                lambda: (inputs['train'], labels['train'], enc_seq_length['train'],
                         dec_seq_length['train'], mask['train']),
                lambda: (inputs['test'], labels['test'], enc_seq_length['test'],
                         dec_seq_length['test'], mask['test'])
            )

        if self.use_terminal_symbol:
            self.dec_seq_length += 1  # terminal symbol

        self._build_model()
        self._build_steps()

        if not reuse:
            self._build_optim()

        self.train_summary = tf.summary.merge([
            tf.summary.scalar("train/total_loss", self.total_loss),
            tf.summary.scalar("train/lr", self.lr),
        ])

        self.test_summary = tf.summary.merge([
            tf.summary.scalar("test/total_loss", self.total_loss),
        ])

    def _build_steps(self):
        def run(sess, fetch, feed_dict, summary_writer, summary):
            fetch['step'] = self.global_step
            if summary is not None:
                fetch['summary'] = summary

            result = sess.run(fetch, feed_dict=feed_dict)
            if summary_writer is not None:
                summary_writer.add_summary(result['summary'], result['step'])
                summary_writer.flush()
            return result
        def train(sess, fetch, summary_writer):
            return run(sess, fetch, feed_dict={self.is_training: True},
                       summary_writer=summary_writer, summary=self.train_summary)

        def test(sess, fetch, summary_writer=None):
            return run(sess, fetch, feed_dict={self.is_training: False},
                       summary_writer=summary_writer, summary=self.test_summary)

        self.train = train
        self.test = test

    def _build_model(self):
        tf.logging.info("Create a model..")
        self.global_step = tf.Variable(0, trainable=False)

        input_embed = tf.get_variable(
            "input_embed", [1, self.input_dim, self.hidden_dim],
            initializer=self.initializer)

        with tf.variable_scope("encoder"):
            self.embedded_enc_inputs = tf.nn.conv1d(
                self.enc_inputs, input_embed, 1, "VALID")

        batch_size = tf.shape(self.enc_inputs)[0]
        with tf.variable_scope("encoder"):
            self.enc_cell = LSTMCell(
                self.hidden_dim,
                initializer=self.initializer)

            if self.num_layers > 1:
                cells = [self.enc_cell] * self.num_layers
                self.enc_cell = MultiRNNCell(cells)
            self.enc_init_state = trainable_initial_state(
                batch_size, self.enc_cell.state_size)

            # self.enc_outputs : [None, max_time, output_size]
            self.enc_outputs, self.enc_final_states = tf.nn.dynamic_rnn(
                self.enc_cell, self.embedded_enc_inputs,
                self.enc_seq_length, self.enc_init_state)

        if self.use_terminal_symbol:
            # 0 index indicates terminal
            self.first_decoder_input = tf.expand_dims(trainable_initial_state(
                batch_size, self.hidden_dim, name="first_decoder_input"), 1)
            self.enc_outputs = tf.concat_v2(
                [self.first_decoder_input, self.enc_outputs], axis=1)

        with tf.variable_scope("decoder"):
            self.idx_pairs = index_matrix_to_pairs(self.dec_targets)
            self.embedded_dec_inputs = tf.stop_gradient(
                tf.gather_nd(self.enc_outputs, self.idx_pairs))

            if self.use_terminal_symbol:
                tiled_zero_idxs = tf.tile(tf.zeros(
                    [1, 1], dtype=tf.int32), [batch_size, 1], name="tiled_zero_idxs")
                self.dec_targets = tf.concat_v2([self.dec_targets, tiled_zero_idxs], axis=1)

                self.embedded_dec_inputs = tf.concat_v2(
                    [self.first_decoder_input, self.embedded_dec_inputs], axis=1)

            self.dec_cell = LSTMCell(
                self.hidden_dim,
                initializer=self.initializer)

            if self.num_layers > 1:
                cells = [self.dec_cell] * self.num_layers
                self.dec_cell = MultiRNNCell(cells)

            self.dec_pred_logits, _, _ = decoder_rnn(
                self.dec_cell, self.embedded_dec_inputs,
                self.enc_outputs, self.enc_final_states,
                self.dec_seq_length, self.hidden_dim,
                self.num_glimpse, batch_size, is_train=True,
                initializer=self.initializer)
            self.dec_pred_prob = tf.nn.softmax(
                self.dec_pred_logits, 2, name="dec_pred_prob")
            self.dec_pred = tf.argmax(
                self.dec_pred_logits, 2, name="dec_pred")

        with tf.variable_scope("decoder", reuse=True):
            self.dec_inference_logits, _, _ = decoder_rnn(
                self.dec_cell, self.first_decoder_input,
                self.enc_outputs, self.enc_final_states,
                self.dec_seq_length, self.hidden_dim,
                self.num_glimpse, batch_size, is_train=False,
                initializer=self.initializer,
                max_length=self.max_dec_length + int(self.use_terminal_symbol))
            self.dec_inference_prob = tf.nn.softmax(
                self.dec_inference_logits, 2, name="dec_inference_prob")
            self.dec_inference = tf.argmax(
                self.dec_inference_logits, 2, name="dec_inference")

    def _build_optim(self):
        losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=self.dec_targets, logits=self.dec_pred_logits)
        inference_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=self.dec_targets, logits=self.dec_inference_logits)

        def apply_mask(op):
            # NOTE: currently unused
            length = tf.cast(op[:1], tf.int32)
            loss = op[1:]
            return tf.multiply(loss, tf.ones(length, dtype=tf.float32))

        batch_loss = tf.div(
            tf.reduce_sum(tf.multiply(losses, self.mask)),
            tf.reduce_sum(self.mask), name="batch_loss")

        batch_inference_loss = tf.div(
            tf.reduce_sum(tf.multiply(inference_losses, self.mask)),
            tf.reduce_sum(self.mask), name="batch_inference_loss")

        tf.losses.add_loss(batch_loss)
        total_loss = tf.losses.get_total_loss()

        self.total_loss = total_loss
        self.target_cross_entropy_losses = losses
        self.total_inference_loss = batch_inference_loss

        self.lr = tf.train.exponential_decay(
            self.lr_start, self.global_step, self.lr_decay_step,
            self.lr_decay_rate, staircase=True, name="learning_rate")

        # NOTE: the --optimizer flag is currently ignored; Adam is always used
        optimizer = tf.train.AdamOptimizer(self.lr)

        if self.max_grad_norm is not None:
            grads_and_vars = optimizer.compute_gradients(self.total_loss)
            for idx, (grad, var) in enumerate(grads_and_vars):
                if grad is not None:
                    grads_and_vars[idx] = (tf.clip_by_norm(grad, self.max_grad_norm), var)
            self.optim = optimizer.apply_gradients(grads_and_vars, global_step=self.global_step)
        else:
            self.optim = optimizer.minimize(self.total_loss, global_step=self.global_step)
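
# Evaluation helper sketch (not part of the original code): turn a predicted
# index sequence from `dec_inference` into a closed-tour length. Assumes the
# data_loader convention of 1-based node indices with 0 as terminal/padding.
def tour_length(nodes, tour):
    import numpy as np  # local import: the module itself only needs tensorflow
    order = [i - 1 for i in tour if i > 0]         # drop terminal/padding, back to 0-based
    ordered = np.asarray(nodes)[order]
    diffs = ordered - np.roll(ordered, 1, axis=0)  # includes the wrap-around edge
    return float(np.sqrt((diffs ** 2).sum(axis=1)).sum())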
--------------------------------------------------------------------------------

/data_loader.py:
--------------------------------------------------------------------------------
# Most of the code is from
# https://github.com/vshallc/PtrNets/blob/master/pointer/misc/tsp.py
import os
import re
import zipfile
import itertools
import threading
import numpy as np
from tqdm import trange, tqdm
from collections import namedtuple

import tensorflow as tf
from download import download_file_from_google_drive

GOOGLE_DRIVE_IDS = {
    'tsp5_train.zip': '0B2fg8yPGn2TCSW1pNTJMXzFPYTg',
    'tsp10_train.zip': '0B2fg8yPGn2TCbHowM0hfOTJCNkU',
    'tsp5-20_train.zip': '0B2fg8yPGn2TCTWNxX21jTDBGeXc',
    'tsp50_train.zip': '0B2fg8yPGn2TCaVQxSl9ab29QajA',
    'tsp20_test.txt': '0B2fg8yPGn2TCdF9TUU5DZVNCNjQ',
    'tsp40_test.txt': '0B2fg8yPGn2TCcjFrYk85SGFVNlU',
    'tsp50_test.txt.zip': '0B2fg8yPGn2TCUVlCQmQtelpZTTQ',
}

TSP = namedtuple('TSP', ['x', 'y', 'name'])

def length(x, y):
    return np.linalg.norm(np.asarray(x) - np.asarray(y))

# https://gist.github.com/mlalevic/6222750
def solve_tsp_dynamic(points):
    # calc all lengths
    all_distances = [[length(x, y) for y in points] for x in points]
    # initial value - just distance from 0 to every other point + keep the track of edges
    A = {(frozenset([0, idx+1]), idx+1): (dist, [0, idx+1])
         for idx, dist in enumerate(all_distances[0][1:])}
    cnt = len(points)
    for m in range(2, cnt):
        B = {}
        for S in [frozenset(C) | {0} for C in itertools.combinations(range(1, cnt), m)]:
            for j in S - {0}:
                # min() compares on the 0th tuple element (the distance),
                # the same as if key=itemgetter(0) were used
                B[(S, j)] = min([(A[(S-{j}, k)][0] + all_distances[k][j], A[(S-{j}, k)][1] + [j])
                                 for k in S if k != 0 and k != j])
        A = B
    res = min([(A[d][0] + all_distances[0][d[1]], A[d][1]) for d in iter(A)])
    return np.asarray(res[1]) + 1  # shift to 1-based indices; 0 is reserved for padding
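
# Example (illustration only):
#   rng = np.random.RandomState(123)
#   solve_tsp_dynamic(rng.rand(6, 2))  # -> optimal visiting order, 1-based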

def generate_one_example(n_nodes, rng):
    nodes = rng.rand(n_nodes, 2).astype(np.float32)
    solutions = solve_tsp_dynamic(nodes)
    return nodes, solutions

def read_paper_dataset(paths, max_length):
    x, y = [], []
    for path in paths:
        tf.logging.info("Read dataset {} which is used in the paper..".format(path))
        with open(path) as f:
            for l in tqdm(f):
                inputs, outputs = l.split(' output ')
                x.append(np.array(inputs.split(), dtype=np.float32).reshape([-1, 2]))
                y.append(np.array(outputs.split(), dtype=np.int32)[:-1])  # skip the last one
    return x, y

class TSPDataLoader(object):
    def __init__(self, config, rng=None):
        self.config = config
        self.rng = rng

        self.task = config.task.lower()
        self.batch_size = config.batch_size
        self.min_length = config.min_data_length
        self.max_length = config.max_data_length

        self.is_train = config.is_train
        self.use_terminal_symbol = config.use_terminal_symbol
        self.random_seed = config.random_seed

        self.data_num = {}
        self.data_num['train'] = config.train_num
        self.data_num['test'] = config.test_num

        self.data_dir = config.data_dir
        self.task_name = "{}_({},{})".format(
            self.task, self.min_length, self.max_length)

        self.data = None
        self.coord = None
        self.threads = None
        self.input_ops, self.target_ops = None, None
        self.queue_ops, self.enqueue_ops = None, None
        self.x, self.y, self.seq_length, self.mask = None, None, None, None

        paths = self.download_google_drive_file()
        if len(paths) != 0:
            self._maybe_generate_and_save(except_list=paths.keys())
            for name, path in paths.items():
                self.read_zip_and_update_data(path, name)
        else:
            self._maybe_generate_and_save()
        self._create_input_queue()

    def _create_input_queue(self, queue_capacity_factor=16):
        self.input_ops, self.target_ops = {}, {}
        self.queue_ops, self.enqueue_ops = {}, {}
        self.x, self.y, self.seq_length, self.mask = {}, {}, {}, {}

        for name in self.data_num.keys():
            self.input_ops[name] = tf.placeholder(tf.float32, shape=[None, None])
            self.target_ops[name] = tf.placeholder(tf.int32, shape=[None])

            min_after_dequeue = 1000
            capacity = min_after_dequeue + 3 * self.batch_size

            self.queue_ops[name] = tf.RandomShuffleQueue(
                capacity=capacity,
                min_after_dequeue=min_after_dequeue,
                dtypes=[tf.float32, tf.int32],
                shapes=[[self.max_length, 2], [self.max_length]],
                seed=self.random_seed,
                name="random_queue_{}".format(name))
            self.enqueue_ops[name] = \
                self.queue_ops[name].enqueue([self.input_ops[name], self.target_ops[name]])

            inputs, labels = self.queue_ops[name].dequeue()

            seq_length = tf.shape(inputs)[0]
            if self.use_terminal_symbol:
                mask = tf.ones([seq_length + 1], dtype=tf.float32)  # terminal symbol
            else:
                mask = tf.ones([seq_length], dtype=tf.float32)

            self.x[name], self.y[name], self.seq_length[name], self.mask[name] = \
                tf.train.batch(
                    [inputs, labels, seq_length, mask],
                    batch_size=self.batch_size,
                    capacity=capacity,
                    dynamic_pad=True,
                    name="batch_and_pad")
name="batch_and_pad") 137 | 138 | def run_input_queue(self, sess): 139 | threads = [] 140 | self.coord = tf.train.Coordinator() 141 | 142 | for name in self.data_num.keys(): 143 | def load_and_enqueue(sess, name, input_ops, target_ops, enqueue_ops, coord): 144 | idx = 0 145 | while not coord.should_stop(): 146 | feed_dict = { 147 | input_ops[name]: self.data[name].x[idx], 148 | target_ops[name]: self.data[name].y[idx], 149 | } 150 | sess.run(self.enqueue_ops[name], feed_dict=feed_dict) 151 | idx = idx+1 if idx+1 <= len(self.data[name].x) - 1 else 0 152 | 153 | args = (sess, name, self.input_ops, self.target_ops, self.enqueue_ops, self.coord) 154 | t = threading.Thread(target=load_and_enqueue, args=args) 155 | t.start() 156 | threads.append(t) 157 | tf.logging.info("Thread start for [{}]".format(name)) 158 | 159 | def stop_input_queue(self): 160 | self.coord.request_stop() 161 | self.coord.join(threads) 162 | 163 | def _maybe_generate_and_save(self, except_list=[]): 164 | self.data = {} 165 | 166 | for name, num in self.data_num.items(): 167 | if name in except_list: 168 | tf.logging.info("Skip creating {} because of given except_list {}".format(name, except_list)) 169 | continue 170 | path = self.get_path(name) 171 | 172 | if not os.path.exists(path): 173 | tf.logging.info("Creating {} for [{}]".format(path, self.task)) 174 | 175 | x = np.zeros([num, self.max_length, 2], dtype=np.float32) 176 | y = np.zeros([num, self.max_length], dtype=np.int32) 177 | 178 | for idx in trange(num, desc="Create {} data".format(name)): 179 | n_nodes = self.rng.randint(self.min_length, self.max_length+ 1) 180 | nodes, res = generate_one_example(n_nodes, self.rng) 181 | x[idx,:len(nodes)] = nodes 182 | y[idx,:len(res)] = res 183 | 184 | np.savez(path, x=x, y=y) 185 | self.data[name] = TSP(x=x, y=y, name=name) 186 | else: 187 | tf.logging.info("Skip creating {} for [{}]".format(path, self.task)) 188 | tmp = np.load(path) 189 | self.data[name] = TSP(x=tmp['x'], y=tmp['y'], name=name) 190 | 191 | def get_path(self, name): 192 | return os.path.join( 193 | self.data_dir, "{}_{}={}.npz".format( 194 | self.task_name, name, self.data_num[name])) 195 | 196 | def download_google_drive_file(self): 197 | paths = {} 198 | for mode in ['train', 'test']: 199 | candidates = [] 200 | candidates.append( 201 | '{}{}_{}'.format(self.task, self.max_length, mode)) 202 | candidates.append( 203 | '{}{}-{}_{}'.format(self.task, self.min_length, self.max_length, mode)) 204 | 205 | for key in candidates: 206 | for search_key in GOOGLE_DRIVE_IDS.keys(): 207 | if search_key.startswith(key): 208 | path = os.path.join(self.data_dir, search_key) 209 | tf.logging.info("Download dataset of the paper to {}".format(path)) 210 | 211 | if not os.path.exists(path): 212 | download_file_from_google_drive(GOOGLE_DRIVE_IDS[search_key], path) 213 | if path.endswith('zip'): 214 | with zipfile.ZipFile(path, 'r') as z: 215 | z.extractall(self.data_dir) 216 | paths[mode] = path 217 | 218 | tf.logging.info("Can't found dataset from the paper!") 219 | return paths 220 | 221 | def read_zip_and_update_data(self, path, name): 222 | if path.endswith('zip'): 223 | filenames = zipfile.ZipFile(path).namelist() 224 | paths = [os.path.join(self.data_dir, filename) for filename in filenames] 225 | else: 226 | paths = [path] 227 | 228 | x_list, y_list = read_paper_dataset(paths, self.max_length) 229 | 230 | x = np.zeros([len(x_list), self.max_length, 2], dtype=np.float32) 231 | y = np.zeros([len(y_list), self.max_length], dtype=np.int32) 232 | 233 | for idx, (nodes, 

    def read_zip_and_update_data(self, path, name):
        if path.endswith('zip'):
            filenames = zipfile.ZipFile(path).namelist()
            paths = [os.path.join(self.data_dir, filename) for filename in filenames]
        else:
            paths = [path]

        x_list, y_list = read_paper_dataset(paths, self.max_length)

        x = np.zeros([len(x_list), self.max_length, 2], dtype=np.float32)
        y = np.zeros([len(y_list), self.max_length], dtype=np.int32)

        for idx, (nodes, res) in enumerate(tqdm(zip(x_list, y_list))):
            x[idx, :len(nodes)] = nodes
            y[idx, :len(res)] = res

        if self.data is None:
            self.data = {}

        tf.logging.info("Update [{}] data with {} used in the paper".format(name, path))
        self.data[name] = TSP(x=x, y=y, name=name)
--------------------------------------------------------------------------------