├── .gitignore ├── LICENSE ├── README.md ├── algo ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── behavior_clone.cpython-35.pyc │ ├── ppo.cpython-35.pyc │ └── ppo.cpython-36.pyc ├── behavior_clone.py └── ppo.py ├── images ├── graph.png └── legend.png ├── log ├── test │ ├── bc │ │ └── events.out.tfevents.1516152758.nakata-System-Product-Name │ ├── gail │ │ └── events.out.tfevents.1516152538.nakata-System-Product-Name │ └── ppo │ │ └── events.out.tfevents.1516152554.nakata-System-Product-Name └── train │ ├── bc │ └── events.out.tfevents.1516152197.nakata-System-Product-Name │ ├── gail │ └── events.out.tfevents.1516152230.nakata-System-Product-Name │ └── ppo │ └── events.out.tfevents.1542158042.YusukenoMacBook-pro.local ├── network_models ├── __init__.py ├── discriminator.py └── policy_net.py ├── run_behavior_clone.py ├── run_gail.py ├── run_ppo.py ├── sample_trajectory.py ├── test_policy.py ├── trained_models ├── bc │ ├── checkpoint │ ├── model.ckpt-100.data-00000-of-00001 │ ├── model.ckpt-100.index │ ├── model.ckpt-100.meta │ ├── model.ckpt-1000.data-00000-of-00001 │ ├── model.ckpt-1000.index │ ├── model.ckpt-1000.meta │ ├── model.ckpt-200.data-00000-of-00001 │ ├── model.ckpt-200.index │ ├── model.ckpt-200.meta │ ├── model.ckpt-300.data-00000-of-00001 │ ├── model.ckpt-300.index │ ├── model.ckpt-300.meta │ ├── model.ckpt-400.data-00000-of-00001 │ ├── model.ckpt-400.index │ ├── model.ckpt-400.meta │ ├── model.ckpt-500.data-00000-of-00001 │ ├── model.ckpt-500.index │ ├── model.ckpt-500.meta │ ├── model.ckpt-600.data-00000-of-00001 │ ├── model.ckpt-600.index │ ├── model.ckpt-600.meta │ ├── model.ckpt-700.data-00000-of-00001 │ ├── model.ckpt-700.index │ ├── model.ckpt-700.meta │ ├── model.ckpt-800.data-00000-of-00001 │ ├── model.ckpt-800.index │ ├── model.ckpt-800.meta │ ├── model.ckpt-900.data-00000-of-00001 │ ├── model.ckpt-900.index │ └── model.ckpt-900.meta ├── gail │ ├── checkpoint │ ├── model.ckpt.data-00000-of-00001 │ ├── model.ckpt.index │ └── 
model.ckpt.meta └── ppo │ ├── checkpoint │ ├── model.ckpt.data-00000-of-00001 │ ├── model.ckpt.index │ └── model.ckpt.meta └── trajectory ├── actions.csv └── observations.csv /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | __pycache__/ 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Yusuke Nakata 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generative Adversarial Imitation Learning 2 | Implementation of Generative Adversarial Imitation Learning (GAIL) using TensorFlow 3 | 4 | ## Dependencies 5 | python>=3.5 6 | tensorflow>=1.4 7 | gym>=0.9.3 8 | 9 | ## Gym environment 10 | 11 | Env==CartPole-v0 12 | State==Continuous 13 | Action==Discrete 14 | 15 | ## Usage 16 | 17 | **Train experts** 18 | ``` 19 | python3 run_ppo.py 20 | ``` 21 | **Sample trajectory using expert** 22 | ``` 23 | python3 sample_trajectory.py 24 | ``` 25 | **Run GAIL** 26 | ``` 27 | python3 run_gail.py 28 | ``` 29 | **Run supervised learning (behavioral cloning)** 30 | ``` 31 | python3 run_behavior_clone.py 32 | ``` 33 | **Test trained policy** 34 | ``` 35 | python3 test_policy.py 36 | ``` 37 | By default, the policy trained with GAIL is tested. 38 | Pass --alg=bc or --alg=ppo to test a different policy. 39 | 40 | To test a BC policy, also pass --model with the _number_ of a model.ckpt-_number_ checkpoint in trained_models/bc. 41 | Example: 42 | ``` 43 | python3 test_policy.py --alg=bc --model=1000 44 | ``` 45 | **Tensorboard** 46 | ``` 47 | tensorboard --logdir=log 48 | ``` 49 | 50 | ## Results 51 | 52 | | ![](./images/graph.png) | ![](./images/legend.png) | 53 | | :---: | :---: | 54 | | Fig.1 Training results | legend | 55 | 56 | ## LICENSE 57 | MIT LICENSE 58 | -------------------------------------------------------------------------------- /algo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/algo/__init__.py -------------------------------------------------------------------------------- /algo/__pycache__/__init__.cpython-35.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/algo/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /algo/__pycache__/behavior_clone.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/algo/__pycache__/behavior_clone.cpython-35.pyc -------------------------------------------------------------------------------- /algo/__pycache__/ppo.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/algo/__pycache__/ppo.cpython-35.pyc -------------------------------------------------------------------------------- /algo/__pycache__/ppo.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/algo/__pycache__/ppo.cpython-36.pyc -------------------------------------------------------------------------------- /algo/behavior_clone.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class BehavioralCloning: 5 | def __init__(self, Policy): 6 | self.Policy = Policy 7 | 8 | self.actions_expert = tf.placeholder(tf.int32, shape=[None], name='actions_expert') 9 | 10 | actions_vec = tf.one_hot(self.actions_expert, depth=self.Policy.act_probs.shape[1], dtype=tf.float32) 11 | 12 | loss = tf.reduce_sum(actions_vec * tf.log(tf.clip_by_value(self.Policy.act_probs, 1e-10, 1.0)), 1) 13 | loss = - tf.reduce_mean(loss) 14 | tf.summary.scalar('loss/cross_entropy', loss) 15 | 16 | optimizer = tf.train.AdamOptimizer() 17 | self.train_op = optimizer.minimize(loss) 18 | 19 | self.merged = tf.summary.merge_all() 20 | 
21 | def train(self, obs, actions): 22 | return tf.get_default_session().run(self.train_op, feed_dict={self.Policy.obs: obs, 23 | self.actions_expert: actions}) 24 | 25 | def get_summary(self, obs, actions): 26 | return tf.get_default_session().run(self.merged, feed_dict={self.Policy.obs: obs, 27 | self.actions_expert: actions}) 28 | 29 | -------------------------------------------------------------------------------- /algo/ppo.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import copy 3 | 4 | 5 | class PPOTrain: 6 | def __init__(self, Policy, Old_Policy, gamma=0.95, clip_value=0.2, c_1=1, c_2=0.01): 7 | """ 8 | :param Policy: 9 | :param Old_Policy: 10 | :param gamma: 11 | :param clip_value: 12 | :param c_1: parameter for value difference 13 | :param c_2: parameter for entropy bonus 14 | """ 15 | 16 | self.Policy = Policy 17 | self.Old_Policy = Old_Policy 18 | self.gamma = gamma 19 | 20 | pi_trainable = self.Policy.get_trainable_variables() 21 | old_pi_trainable = self.Old_Policy.get_trainable_variables() 22 | 23 | # assign_operations for policy parameter values to old policy parameters 24 | with tf.variable_scope('assign_op'): 25 | self.assign_ops = [] 26 | for v_old, v in zip(old_pi_trainable, pi_trainable): 27 | self.assign_ops.append(tf.assign(v_old, v)) 28 | 29 | # inputs for train_op 30 | with tf.variable_scope('train_inp'): 31 | self.actions = tf.placeholder(dtype=tf.int32, shape=[None], name='actions') 32 | self.rewards = tf.placeholder(dtype=tf.float32, shape=[None], name='rewards') 33 | self.v_preds_next = tf.placeholder(dtype=tf.float32, shape=[None], name='v_preds_next') 34 | self.gaes = tf.placeholder(dtype=tf.float32, shape=[None], name='gaes') 35 | 36 | act_probs = self.Policy.act_probs 37 | act_probs_old = self.Old_Policy.act_probs 38 | 39 | # probabilities of actions which agent took with policy 40 | act_probs = act_probs * tf.one_hot(indices=self.actions, 
depth=act_probs.shape[1]) 41 | act_probs = tf.reduce_sum(act_probs, axis=1) 42 | 43 | # probabilities of actions which agent took with old policy 44 | act_probs_old = act_probs_old * tf.one_hot(indices=self.actions, depth=act_probs_old.shape[1]) 45 | act_probs_old = tf.reduce_sum(act_probs_old, axis=1) 46 | 47 | with tf.variable_scope('loss'): 48 | # construct computation graph for loss_clip 49 | # ratios = tf.divide(act_probs, act_probs_old) 50 | ratios = tf.exp(tf.log(tf.clip_by_value(act_probs, 1e-10, 1.0)) 51 | - tf.log(tf.clip_by_value(act_probs_old, 1e-10, 1.0))) 52 | clipped_ratios = tf.clip_by_value(ratios, clip_value_min=1 - clip_value, clip_value_max=1 + clip_value) 53 | loss_clip = tf.minimum(tf.multiply(self.gaes, ratios), tf.multiply(self.gaes, clipped_ratios)) 54 | loss_clip = tf.reduce_mean(loss_clip) 55 | tf.summary.scalar('loss_clip', loss_clip) 56 | 57 | # construct computation graph for loss of entropy bonus 58 | entropy = -tf.reduce_sum(self.Policy.act_probs * 59 | tf.log(tf.clip_by_value(self.Policy.act_probs, 1e-10, 1.0)), axis=1) 60 | entropy = tf.reduce_mean(entropy, axis=0) # mean of entropy of pi(obs) 61 | tf.summary.scalar('entropy', entropy) 62 | 63 | # construct computation graph for loss of value function 64 | v_preds = self.Policy.v_preds 65 | loss_vf = tf.squared_difference(self.rewards + self.gamma * self.v_preds_next, v_preds) 66 | loss_vf = tf.reduce_mean(loss_vf) 67 | tf.summary.scalar('value_difference', loss_vf) 68 | 69 | # construct computation graph for loss 70 | loss = loss_clip - c_1 * loss_vf + c_2 * entropy 71 | 72 | # minimize -loss == maximize loss 73 | loss = -loss 74 | tf.summary.scalar('total', loss) 75 | 76 | self.merged = tf.summary.merge_all() 77 | optimizer = tf.train.AdamOptimizer(learning_rate=5e-5, epsilon=1e-5) 78 | self.gradients = optimizer.compute_gradients(loss, var_list=pi_trainable) 79 | self.train_op = optimizer.minimize(loss, var_list=pi_trainable) 80 | 81 | def train(self, obs, actions, gaes, 
rewards, v_preds_next): 82 | tf.get_default_session().run(self.train_op, feed_dict={self.Policy.obs: obs, 83 | self.Old_Policy.obs: obs, 84 | self.actions: actions, 85 | self.rewards: rewards, 86 | self.v_preds_next: v_preds_next, 87 | self.gaes: gaes}) 88 | 89 | def get_summary(self, obs, actions, gaes, rewards, v_preds_next): 90 | return tf.get_default_session().run(self.merged, feed_dict={self.Policy.obs: obs, 91 | self.Old_Policy.obs: obs, 92 | self.actions: actions, 93 | self.rewards: rewards, 94 | self.v_preds_next: v_preds_next, 95 | self.gaes: gaes}) 96 | 97 | def assign_policy_parameters(self): 98 | # assign policy parameter values to old policy parameters 99 | return tf.get_default_session().run(self.assign_ops) 100 | 101 | def get_gaes(self, rewards, v_preds, v_preds_next): 102 | deltas = [r_t + self.gamma * v_next - v for r_t, v_next, v in zip(rewards, v_preds_next, v_preds)] 103 | # generalized advantage estimation (GAE) with lambda = 1: discounted sum of TD residuals, see PPO paper eq. (11) 104 | gaes = copy.deepcopy(deltas) 105 | for t in reversed(range(len(gaes) - 1)): # t = T-2, ..., 0 where T = number of collected steps; gaes[T-1] stays deltas[T-1] 106 | gaes[t] = gaes[t] + self.gamma * gaes[t + 1] 107 | return gaes 108 | 109 | def get_grad(self, obs, actions, gaes, rewards, v_preds_next): 110 | return tf.get_default_session().run(self.gradients, feed_dict={self.Policy.obs: obs, 111 | self.Old_Policy.obs: obs, 112 | self.actions: actions, 113 | self.rewards: rewards, 114 | self.v_preds_next: v_preds_next, 115 | self.gaes: gaes}) 116 | -------------------------------------------------------------------------------- /images/graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/images/graph.png -------------------------------------------------------------------------------- /images/legend.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/images/legend.png -------------------------------------------------------------------------------- /log/test/bc/events.out.tfevents.1516152758.nakata-System-Product-Name: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/log/test/bc/events.out.tfevents.1516152758.nakata-System-Product-Name -------------------------------------------------------------------------------- /log/test/gail/events.out.tfevents.1516152538.nakata-System-Product-Name: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/log/test/gail/events.out.tfevents.1516152538.nakata-System-Product-Name -------------------------------------------------------------------------------- /log/test/ppo/events.out.tfevents.1516152554.nakata-System-Product-Name: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/log/test/ppo/events.out.tfevents.1516152554.nakata-System-Product-Name -------------------------------------------------------------------------------- /log/train/bc/events.out.tfevents.1516152197.nakata-System-Product-Name: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/log/train/bc/events.out.tfevents.1516152197.nakata-System-Product-Name -------------------------------------------------------------------------------- /log/train/gail/events.out.tfevents.1516152230.nakata-System-Product-Name: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/log/train/gail/events.out.tfevents.1516152230.nakata-System-Product-Name -------------------------------------------------------------------------------- /log/train/ppo/events.out.tfevents.1542158042.YusukenoMacBook-pro.local: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/log/train/ppo/events.out.tfevents.1542158042.YusukenoMacBook-pro.local -------------------------------------------------------------------------------- /network_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/network_models/__init__.py -------------------------------------------------------------------------------- /network_models/discriminator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class Discriminator: 5 | def __init__(self, env): 6 | """ 7 | :param env: 8 | Output of this Discriminator is reward for learning agent. Not the cost. 9 | Because discriminator predicts P(expert|s,a) = 1 - P(agent|s,a). 
10 | """ 11 | 12 | with tf.variable_scope('discriminator'): 13 | self.scope = tf.get_variable_scope().name 14 | self.expert_s = tf.placeholder(dtype=tf.float32, shape=[None] + list(env.observation_space.shape)) 15 | self.expert_a = tf.placeholder(dtype=tf.int32, shape=[None]) 16 | expert_a_one_hot = tf.one_hot(self.expert_a, depth=env.action_space.n) 17 | # add noise for stabilise training 18 | expert_a_one_hot += tf.random_normal(tf.shape(expert_a_one_hot), mean=0.2, stddev=0.1, dtype=tf.float32)/1.2 19 | expert_s_a = tf.concat([self.expert_s, expert_a_one_hot], axis=1) 20 | 21 | self.agent_s = tf.placeholder(dtype=tf.float32, shape=[None] + list(env.observation_space.shape)) 22 | self.agent_a = tf.placeholder(dtype=tf.int32, shape=[None]) 23 | agent_a_one_hot = tf.one_hot(self.agent_a, depth=env.action_space.n) 24 | # add noise for stabilise training 25 | agent_a_one_hot += tf.random_normal(tf.shape(agent_a_one_hot), mean=0.2, stddev=0.1, dtype=tf.float32)/1.2 26 | agent_s_a = tf.concat([self.agent_s, agent_a_one_hot], axis=1) 27 | 28 | with tf.variable_scope('network') as network_scope: 29 | prob_1 = self.construct_network(input=expert_s_a) 30 | network_scope.reuse_variables() # share parameter 31 | prob_2 = self.construct_network(input=agent_s_a) 32 | 33 | with tf.variable_scope('loss'): 34 | loss_expert = tf.reduce_mean(tf.log(tf.clip_by_value(prob_1, 0.01, 1))) 35 | loss_agent = tf.reduce_mean(tf.log(tf.clip_by_value(1 - prob_2, 0.01, 1))) 36 | loss = loss_expert + loss_agent 37 | loss = -loss 38 | tf.summary.scalar('discriminator', loss) 39 | 40 | optimizer = tf.train.AdamOptimizer() 41 | self.train_op = optimizer.minimize(loss) 42 | 43 | self.rewards = tf.log(tf.clip_by_value(prob_2, 1e-10, 1)) # log(P(expert|s,a)) larger is better for agent 44 | 45 | def construct_network(self, input): 46 | layer_1 = tf.layers.dense(inputs=input, units=20, activation=tf.nn.leaky_relu, name='layer1') 47 | layer_2 = tf.layers.dense(inputs=layer_1, units=20, 
activation=tf.nn.leaky_relu, name='layer2') 48 | layer_3 = tf.layers.dense(inputs=layer_2, units=20, activation=tf.nn.leaky_relu, name='layer3') 49 | prob = tf.layers.dense(inputs=layer_3, units=1, activation=tf.sigmoid, name='prob') 50 | return prob 51 | 52 | def train(self, expert_s, expert_a, agent_s, agent_a): 53 | return tf.get_default_session().run(self.train_op, feed_dict={self.expert_s: expert_s, 54 | self.expert_a: expert_a, 55 | self.agent_s: agent_s, 56 | self.agent_a: agent_a}) 57 | 58 | def get_rewards(self, agent_s, agent_a): 59 | return tf.get_default_session().run(self.rewards, feed_dict={self.agent_s: agent_s, 60 | self.agent_a: agent_a}) 61 | 62 | def get_trainable_variables(self): 63 | return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope) 64 | 65 | -------------------------------------------------------------------------------- /network_models/policy_net.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class Policy_net: 5 | def __init__(self, name: str, env): 6 | """ 7 | :param name: string 8 | :param env: gym env 9 | """ 10 | 11 | ob_space = env.observation_space 12 | act_space = env.action_space 13 | 14 | with tf.variable_scope(name): 15 | self.obs = tf.placeholder(dtype=tf.float32, shape=[None] + list(ob_space.shape), name='obs') 16 | 17 | with tf.variable_scope('policy_net'): 18 | layer_1 = tf.layers.dense(inputs=self.obs, units=20, activation=tf.tanh) 19 | layer_2 = tf.layers.dense(inputs=layer_1, units=20, activation=tf.tanh) 20 | layer_3 = tf.layers.dense(inputs=layer_2, units=act_space.n, activation=tf.tanh) 21 | self.act_probs = tf.layers.dense(inputs=layer_3, units=act_space.n, activation=tf.nn.softmax) 22 | 23 | with tf.variable_scope('value_net'): 24 | layer_1 = tf.layers.dense(inputs=self.obs, units=20, activation=tf.tanh) 25 | layer_2 = tf.layers.dense(inputs=layer_1, units=20, activation=tf.tanh) 26 | self.v_preds = 
tf.layers.dense(inputs=layer_2, units=1, activation=None) 27 | 28 | self.act_stochastic = tf.multinomial(tf.log(self.act_probs), num_samples=1) 29 | self.act_stochastic = tf.reshape(self.act_stochastic, shape=[-1]) 30 | 31 | self.act_deterministic = tf.argmax(self.act_probs, axis=1) 32 | 33 | self.scope = tf.get_variable_scope().name 34 | 35 | def act(self, obs, stochastic=True): 36 | if stochastic: 37 | return tf.get_default_session().run([self.act_stochastic, self.v_preds], feed_dict={self.obs: obs}) 38 | else: 39 | return tf.get_default_session().run([self.act_deterministic, self.v_preds], feed_dict={self.obs: obs}) 40 | 41 | def get_action_prob(self, obs): 42 | return tf.get_default_session().run(self.act_probs, feed_dict={self.obs: obs}) 43 | 44 | def get_variables(self): 45 | return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.scope) 46 | 47 | def get_trainable_variables(self): 48 | return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope) 49 | 50 | -------------------------------------------------------------------------------- /run_behavior_clone.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gym 3 | import numpy as np 4 | import tensorflow as tf 5 | from network_models.policy_net import Policy_net 6 | from algo.behavior_clone import BehavioralCloning 7 | 8 | 9 | def argparser(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--savedir', help='name of directory to save model', default='trained_models/bc') 12 | parser.add_argument('--max_to_keep', help='number of models to save', default=10, type=int) 13 | parser.add_argument('--logdir', help='log directory', default='log/train/bc') 14 | parser.add_argument('--iteration', default=int(1e3), type=int) 15 | parser.add_argument('--interval', help='save interval', default=int(1e2), type=int) 16 | parser.add_argument('--minibatch_size', default=128, type=int) 17 | parser.add_argument('--epoch_num', 
default=10, type=int) 18 | return parser.parse_args() 19 | 20 | 21 | def main(args): 22 | env = gym.make('CartPole-v0') 23 | Policy = Policy_net('policy', env) 24 | BC = BehavioralCloning(Policy) 25 | saver = tf.train.Saver(max_to_keep=args.max_to_keep) 26 | 27 | observations = np.genfromtxt('trajectory/observations.csv') 28 | actions = np.genfromtxt('trajectory/actions.csv', dtype=np.int32) 29 | 30 | with tf.Session() as sess: 31 | writer = tf.summary.FileWriter(args.logdir, sess.graph) 32 | sess.run(tf.global_variables_initializer()) 33 | 34 | inp = [observations, actions] 35 | 36 | for iteration in range(args.iteration): # episode 37 | 38 | # train 39 | for epoch in range(args.epoch_num): 40 | # select sample indices in [low, high) 41 | sample_indices = np.random.randint(low=0, high=observations.shape[0], size=args.minibatch_size) 42 | 43 | sampled_inp = [np.take(a=a, indices=sample_indices, axis=0) for a in inp] # sample training data 44 | BC.train(obs=sampled_inp[0], actions=sampled_inp[1]) 45 | 46 | summary = BC.get_summary(obs=inp[0], actions=inp[1]) 47 | 48 | if (iteration+1) % args.interval == 0: 49 | saver.save(sess, args.savedir + '/model.ckpt', global_step=iteration+1) 50 | 51 | writer.add_summary(summary, iteration) 52 | writer.close() 53 | 54 | 55 | if __name__ == '__main__': 56 | args = argparser() 57 | main(args) 58 | -------------------------------------------------------------------------------- /run_gail.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | import argparse 3 | import gym 4 | import numpy as np 5 | import tensorflow as tf 6 | from network_models.policy_net import Policy_net 7 | from network_models.discriminator import Discriminator 8 | from algo.ppo import PPOTrain 9 | 10 | 11 | def argparser(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--logdir', help='log directory', default='log/train/gail') 14 | parser.add_argument('--savedir', help='save 
directory', default='trained_models/gail') 15 | parser.add_argument('--gamma', default=0.95) 16 | parser.add_argument('--iteration', default=int(1e4)) 17 | return parser.parse_args() 18 | 19 | 20 | def main(args): 21 | env = gym.make('CartPole-v0') 22 | env.seed(0) 23 | ob_space = env.observation_space 24 | Policy = Policy_net('policy', env) 25 | Old_Policy = Policy_net('old_policy', env) 26 | PPO = PPOTrain(Policy, Old_Policy, gamma=args.gamma) 27 | D = Discriminator(env) 28 | 29 | expert_observations = np.genfromtxt('trajectory/observations.csv') 30 | expert_actions = np.genfromtxt('trajectory/actions.csv', dtype=np.int32) 31 | 32 | saver = tf.train.Saver() 33 | 34 | with tf.Session() as sess: 35 | writer = tf.summary.FileWriter(args.logdir, sess.graph) 36 | sess.run(tf.global_variables_initializer()) 37 | 38 | obs = env.reset() 39 | success_num = 0 40 | 41 | for iteration in range(args.iteration): 42 | observations = [] 43 | actions = [] 44 | # do NOT use rewards to update policy 45 | rewards = [] 46 | v_preds = [] 47 | run_policy_steps = 0 48 | while True: 49 | run_policy_steps += 1 50 | obs = np.stack([obs]).astype(dtype=np.float32) # prepare to feed placeholder Policy.obs 51 | act, v_pred = Policy.act(obs=obs, stochastic=True) 52 | 53 | act = np.asscalar(act) 54 | v_pred = np.asscalar(v_pred) 55 | next_obs, reward, done, info = env.step(act) 56 | 57 | observations.append(obs) 58 | actions.append(act) 59 | rewards.append(reward) 60 | v_preds.append(v_pred) 61 | 62 | if done: 63 | next_obs = np.stack([next_obs]).astype(dtype=np.float32) # prepare to feed placeholder Policy.obs 64 | _, v_pred = Policy.act(obs=next_obs, stochastic=True) 65 | v_preds_next = v_preds[1:] + [np.asscalar(v_pred)] 66 | obs = env.reset() 67 | break 68 | else: 69 | obs = next_obs 70 | 71 | writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='episode_length', simple_value=run_policy_steps)]) 72 | , iteration) 73 | 
writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='episode_reward', simple_value=sum(rewards))]) 74 | , iteration) 75 | 76 | if sum(rewards) >= 195: 77 | success_num += 1 78 | if success_num >= 100: 79 | saver.save(sess, args.savedir + '/model.ckpt') 80 | print('Clear!! Model saved.') 81 | break 82 | else: 83 | success_num = 0 84 | 85 | # convert list to numpy array for feeding tf.placeholder 86 | observations = np.reshape(observations, newshape=[-1] + list(ob_space.shape)) 87 | actions = np.array(actions).astype(dtype=np.int32) 88 | 89 | # train discriminator 90 | for i in range(2): 91 | D.train(expert_s=expert_observations, 92 | expert_a=expert_actions, 93 | agent_s=observations, 94 | agent_a=actions) 95 | 96 | # output of this discriminator is reward 97 | d_rewards = D.get_rewards(agent_s=observations, agent_a=actions) 98 | d_rewards = np.reshape(d_rewards, newshape=[-1]).astype(dtype=np.float32) 99 | 100 | gaes = PPO.get_gaes(rewards=d_rewards, v_preds=v_preds, v_preds_next=v_preds_next) 101 | gaes = np.array(gaes).astype(dtype=np.float32) 102 | # gaes = (gaes - gaes.mean()) / gaes.std() 103 | v_preds_next = np.array(v_preds_next).astype(dtype=np.float32) 104 | 105 | # train policy 106 | inp = [observations, actions, gaes, d_rewards, v_preds_next] 107 | PPO.assign_policy_parameters() 108 | for epoch in range(6): 109 | sample_indices = np.random.randint(low=0, high=observations.shape[0], 110 | size=32) # indices are in [low, high) 111 | sampled_inp = [np.take(a=a, indices=sample_indices, axis=0) for a in inp] # sample training data 112 | PPO.train(obs=sampled_inp[0], 113 | actions=sampled_inp[1], 114 | gaes=sampled_inp[2], 115 | rewards=sampled_inp[3], 116 | v_preds_next=sampled_inp[4]) 117 | 118 | summary = PPO.get_summary(obs=inp[0], 119 | actions=inp[1], 120 | gaes=inp[2], 121 | rewards=inp[3], 122 | v_preds_next=inp[4]) 123 | 124 | writer.add_summary(summary, iteration) 125 | writer.close() 126 | 127 | 128 | if __name__ == '__main__': 129 | args = 
argparser() 130 | main(args) 131 | -------------------------------------------------------------------------------- /run_ppo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | import argparse 3 | import gym 4 | import numpy as np 5 | import tensorflow as tf 6 | from network_models.policy_net import Policy_net 7 | from algo.ppo import PPOTrain 8 | 9 | 10 | def argparser(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--logdir', help='log directory', default='log/train/ppo') 13 | parser.add_argument('--savedir', help='save directory', default='trained_models/ppo') 14 | parser.add_argument('--gamma', default=0.95, type=float) 15 | parser.add_argument('--iteration', default=int(1e4), type=int) 16 | return parser.parse_args() 17 | 18 | 19 | def main(args): 20 | env = gym.make('CartPole-v0') 21 | env.seed(0) 22 | ob_space = env.observation_space 23 | Policy = Policy_net('policy', env) 24 | Old_Policy = Policy_net('old_policy', env) 25 | PPO = PPOTrain(Policy, Old_Policy, gamma=args.gamma) 26 | saver = tf.train.Saver() 27 | 28 | with tf.Session() as sess: 29 | writer = tf.summary.FileWriter(args.logdir, sess.graph) 30 | sess.run(tf.global_variables_initializer()) 31 | obs = env.reset() 32 | success_num = 0 33 | 34 | for iteration in range(args.iteration): 35 | observations = [] 36 | actions = [] 37 | rewards = [] 38 | v_preds = [] 39 | episode_length = 0 40 | while True: # run policy RUN_POLICY_STEPS which is much less than episode length 41 | episode_length += 1 42 | obs = np.stack([obs]).astype(dtype=np.float32) # prepare to feed placeholder Policy.obs 43 | act, v_pred = Policy.act(obs=obs, stochastic=True) 44 | 45 | act = np.asscalar(act) 46 | v_pred = np.asscalar(v_pred) 47 | 48 | next_obs, reward, done, info = env.step(act) 49 | 50 | observations.append(obs) 51 | actions.append(act) 52 | rewards.append(reward) 53 | v_preds.append(v_pred) 54 | 55 | if done: 56 | next_obs = 
np.stack([next_obs]).astype(dtype=np.float32) # prepare to feed placeholder Policy.obs 57 | _, v_pred = Policy.act(obs=next_obs, stochastic=True) 58 | v_preds_next = v_preds[1:] + [np.asscalar(v_pred)] 59 | obs = env.reset() 60 | break 61 | else: 62 | obs = next_obs 63 | 64 | writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='episode_length', simple_value=episode_length)]) 65 | , iteration) 66 | writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='episode_reward', simple_value=sum(rewards))]) 67 | , iteration) 68 | 69 | if sum(rewards) >= 195: 70 | success_num += 1 71 | if success_num >= 100: 72 | saver.save(sess, args.savedir+'/model.ckpt') 73 | print('Clear!! Model saved.') 74 | break 75 | else: 76 | success_num = 0 77 | 78 | gaes = PPO.get_gaes(rewards=rewards, v_preds=v_preds, v_preds_next=v_preds_next) 79 | 80 | # convert list to numpy array for feeding tf.placeholder 81 | observations = np.reshape(observations, newshape=(-1,) + ob_space.shape) 82 | actions = np.array(actions).astype(dtype=np.int32) 83 | gaes = np.array(gaes).astype(dtype=np.float32) 84 | gaes = (gaes - gaes.mean()) / gaes.std() 85 | rewards = np.array(rewards).astype(dtype=np.float32) 86 | v_preds_next = np.array(v_preds_next).astype(dtype=np.float32) 87 | 88 | PPO.assign_policy_parameters() 89 | 90 | inp = [observations, actions, gaes, rewards, v_preds_next] 91 | 92 | # train 93 | for epoch in range(6): 94 | # sample indices from [low, high) 95 | sample_indices = np.random.randint(low=0, high=observations.shape[0], size=32) 96 | sampled_inp = [np.take(a=a, indices=sample_indices, axis=0) for a in inp] # sample training data 97 | PPO.train(obs=sampled_inp[0], 98 | actions=sampled_inp[1], 99 | gaes=sampled_inp[2], 100 | rewards=sampled_inp[3], 101 | v_preds_next=sampled_inp[4]) 102 | 103 | summary = PPO.get_summary(obs=inp[0], 104 | actions=inp[1], 105 | gaes=inp[2], 106 | rewards=inp[3], 107 | v_preds_next=inp[4]) 108 | 109 | writer.add_summary(summary, iteration) 110 | 
writer.close() 111 | 112 | 113 | if __name__ == '__main__': 114 | args = argparser() 115 | main(args) 116 | -------------------------------------------------------------------------------- /sample_trajectory.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gym 3 | import numpy as np 4 | from network_models.policy_net import Policy_net 5 | import tensorflow as tf 6 | 7 | 8 | # noinspection PyTypeChecker 9 | def open_file_and_save(file_path, data): 10 | """ 11 | :param file_path: type==string 12 | :param data: 13 | """ 14 | try: 15 | with open(file_path, 'ab') as f_handle: 16 | np.savetxt(f_handle, data, fmt='%s') 17 | except FileNotFoundError: 18 | with open(file_path, 'wb') as f_handle: 19 | np.savetxt(f_handle, data, fmt='%s') 20 | 21 | 22 | def argparser(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--model', help='filename of model to test', default='trained_models/ppo/model.ckpt') 25 | parser.add_argument('--iteration', default=10, type=int) 26 | 27 | return parser.parse_args() 28 | 29 | 30 | def main(args): 31 | env = gym.make('CartPole-v0') 32 | env.seed(0) 33 | ob_space = env.observation_space 34 | Policy = Policy_net('policy', env) 35 | saver = tf.train.Saver() 36 | 37 | with tf.Session() as sess: 38 | sess.run(tf.global_variables_initializer()) 39 | saver.restore(sess, args.model) 40 | obs = env.reset() 41 | 42 | for iteration in range(args.iteration): # episode 43 | observations = [] 44 | actions = [] 45 | run_steps = 0 46 | while True: 47 | run_steps += 1 48 | # prepare to feed placeholder Policy.obs 49 | obs = np.stack([obs]).astype(dtype=np.float32) 50 | 51 | act, _ = Policy.act(obs=obs, stochastic=True) 52 | act = np.asscalar(act) 53 | 54 | observations.append(obs) 55 | actions.append(act) 56 | 57 | next_obs, reward, done, info = env.step(act) 58 | 59 | if done: 60 | print(run_steps) 61 | obs = env.reset() 62 | break 63 | else: 64 | obs = next_obs 65 | 66 | 
observations = np.reshape(observations, newshape=[-1] + list(ob_space.shape)) 67 | actions = np.array(actions).astype(dtype=np.int32) 68 | 69 | open_file_and_save('trajectory/observations.csv', observations) 70 | open_file_and_save('trajectory/actions.csv', actions) 71 | 72 | 73 | if __name__ == '__main__': 74 | args = argparser() 75 | main(args) 76 | -------------------------------------------------------------------------------- /test_policy.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import tensorflow as tf 4 | import argparse 5 | from network_models.policy_net import Policy_net 6 | 7 | 8 | def argparser(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--modeldir', help='directory of model', default='trained_models') 11 | parser.add_argument('--alg', help='chose algorithm one of gail, ppo, bc', default='gail') 12 | parser.add_argument('--model', help='number of model to test. model.ckpt-number', default='') 13 | parser.add_argument('--logdir', help='log directory', default='log/test') 14 | parser.add_argument('--iteration', default=int(1e3)) 15 | parser.add_argument('--stochastic', action='store_false') 16 | return parser.parse_args() 17 | 18 | 19 | def main(args): 20 | env = gym.make('CartPole-v0') 21 | env.seed(0) 22 | Policy = Policy_net('policy', env) 23 | saver = tf.train.Saver() 24 | 25 | with tf.Session() as sess: 26 | writer = tf.summary.FileWriter(args.logdir+'/'+args.alg, sess.graph) 27 | sess.run(tf.global_variables_initializer()) 28 | if args.model == '': 29 | saver.restore(sess, args.modeldir+'/'+args.alg+'/'+'model.ckpt') 30 | else: 31 | saver.restore(sess, args.modeldir+'/'+args.alg+'/'+'model.ckpt-'+args.model) 32 | obs = env.reset() 33 | reward = 0 34 | success_num = 0 35 | 36 | for iteration in range(args.iteration): 37 | rewards = [] 38 | run_policy_steps = 0 39 | while True: # run policy RUN_POLICY_STEPS which is much less than episode length 
40 | run_policy_steps += 1 41 | obs = np.stack([obs]).astype(dtype=np.float32) # prepare to feed placeholder Policy.obs 42 | act, _ = Policy.act(obs=obs, stochastic=args.stochastic) 43 | 44 | act = np.asscalar(act) 45 | 46 | rewards.append(reward) 47 | 48 | next_obs, reward, done, info = env.step(act) 49 | 50 | if done: 51 | obs = env.reset() 52 | reward = -1 53 | break 54 | else: 55 | obs = next_obs 56 | 57 | writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='episode_length', simple_value=run_policy_steps)]) 58 | , iteration) 59 | writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='episode_reward', simple_value=sum(rewards))]) 60 | , iteration) 61 | 62 | # end condition of test 63 | if sum(rewards) >= 195: 64 | success_num += 1 65 | if success_num >= 100: 66 | print('Iteration: ', iteration) 67 | print('Clear!!') 68 | break 69 | else: 70 | success_num = 0 71 | 72 | writer.close() 73 | 74 | 75 | if __name__ == '__main__': 76 | args = argparser() 77 | main(args) 78 | -------------------------------------------------------------------------------- /trained_models/bc/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model.ckpt-1000" 2 | all_model_checkpoint_paths: "model.ckpt-100" 3 | all_model_checkpoint_paths: "model.ckpt-200" 4 | all_model_checkpoint_paths: "model.ckpt-300" 5 | all_model_checkpoint_paths: "model.ckpt-400" 6 | all_model_checkpoint_paths: "model.ckpt-500" 7 | all_model_checkpoint_paths: "model.ckpt-600" 8 | all_model_checkpoint_paths: "model.ckpt-700" 9 | all_model_checkpoint_paths: "model.ckpt-800" 10 | all_model_checkpoint_paths: "model.ckpt-900" 11 | all_model_checkpoint_paths: "model.ckpt-1000" 12 | -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-100.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-100.data-00000-of-00001 -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-100.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-100.index -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-100.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-100.meta -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-1000.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-1000.data-00000-of-00001 -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-1000.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-1000.index -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-1000.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-1000.meta -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-200.data-00000-of-00001: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-200.data-00000-of-00001 -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-200.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-200.index -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-200.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-200.meta -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-300.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-300.data-00000-of-00001 -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-300.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-300.index -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-300.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-300.meta 
-------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-400.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-400.data-00000-of-00001 -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-400.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-400.index -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-400.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-400.meta -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-500.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-500.data-00000-of-00001 -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-500.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-500.index -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-500.meta: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-500.meta -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-600.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-600.data-00000-of-00001 -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-600.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-600.index -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-600.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-600.meta -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-700.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-700.data-00000-of-00001 -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-700.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-700.index -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-700.meta: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-700.meta -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-800.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-800.data-00000-of-00001 -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-800.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-800.index -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-800.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-800.meta -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-900.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-900.data-00000-of-00001 -------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-900.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-900.index 
-------------------------------------------------------------------------------- /trained_models/bc/model.ckpt-900.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/bc/model.ckpt-900.meta -------------------------------------------------------------------------------- /trained_models/gail/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model.ckpt" 2 | all_model_checkpoint_paths: "model.ckpt" 3 | -------------------------------------------------------------------------------- /trained_models/gail/model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/gail/model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /trained_models/gail/model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/gail/model.ckpt.index -------------------------------------------------------------------------------- /trained_models/gail/model.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/gail/model.ckpt.meta -------------------------------------------------------------------------------- /trained_models/ppo/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model.ckpt" 2 | all_model_checkpoint_paths: "model.ckpt" 3 | 
-------------------------------------------------------------------------------- /trained_models/ppo/model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/ppo/model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /trained_models/ppo/model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/ppo/model.ckpt.index -------------------------------------------------------------------------------- /trained_models/ppo/model.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uidilr/gail_ppo_tf/0bdee5ac8e15455e0e48c8a99f72ec1e10510855/trained_models/ppo/model.ckpt.meta -------------------------------------------------------------------------------- /trajectory/actions.csv: -------------------------------------------------------------------------------- 1 | 1 2 | 0 3 | 1 4 | 0 5 | 1 6 | 0 7 | 1 8 | 1 9 | 0 10 | 0 11 | 1 12 | 1 13 | 0 14 | 0 15 | 1 16 | 0 17 | 1 18 | 0 19 | 0 20 | 0 21 | 1 22 | 1 23 | 0 24 | 0 25 | 1 26 | 0 27 | 0 28 | 1 29 | 1 30 | 0 31 | 0 32 | 0 33 | 1 34 | 1 35 | 1 36 | 1 37 | 0 38 | 0 39 | 0 40 | 1 41 | 1 42 | 0 43 | 1 44 | 0 45 | 0 46 | 1 47 | 0 48 | 1 49 | 1 50 | 0 51 | 1 52 | 0 53 | 1 54 | 0 55 | 1 56 | 1 57 | 0 58 | 0 59 | 0 60 | 1 61 | 1 62 | 0 63 | 0 64 | 1 65 | 1 66 | 0 67 | 0 68 | 1 69 | 0 70 | 1 71 | 0 72 | 0 73 | 1 74 | 1 75 | 0 76 | 1 77 | 0 78 | 0 79 | 1 80 | 0 81 | 0 82 | 1 83 | 1 84 | 1 85 | 0 86 | 1 87 | 1 88 | 0 89 | 0 90 | 1 91 | 0 92 | 1 93 | 0 94 | 1 95 | 1 96 | 0 97 | 1 98 | 0 99 | 1 100 | 0 101 | 0 102 | 1 103 | 1 104 | 0 105 | 0 106 | 1 107 | 0 108 | 1 109 | 1 110 | 0 111 | 0 112 | 1 113 | 1 
114 | 0 115 | 1 116 | 0 117 | 0 118 | 1 119 | 0 120 | 1 121 | 1 122 | 0 123 | 0 124 | 1 125 | 1 126 | 0 127 | 0 128 | 1 129 | 1 130 | 1 131 | 0 132 | 0 133 | 0 134 | 1 135 | 1 136 | 0 137 | 0 138 | 1 139 | 1 140 | 0 141 | 1 142 | 0 143 | 0 144 | 1 145 | 0 146 | 0 147 | 1 148 | 1 149 | 1 150 | 0 151 | 1 152 | 0 153 | 0 154 | 1 155 | 0 156 | 0 157 | 1 158 | 1 159 | 0 160 | 1 161 | 0 162 | 1 163 | 0 164 | 1 165 | 1 166 | 0 167 | 0 168 | 1 169 | 0 170 | 0 171 | 1 172 | 1 173 | 0 174 | 0 175 | 1 176 | 0 177 | 0 178 | 1 179 | 1 180 | 0 181 | 0 182 | 0 183 | 1 184 | 1 185 | 0 186 | 0 187 | 1 188 | 0 189 | 1 190 | 0 191 | 1 192 | 0 193 | 0 194 | 1 195 | 0 196 | 1 197 | 1 198 | 0 199 | 0 200 | 1 201 | 1 202 | 1 203 | 0 204 | 0 205 | 1 206 | 0 207 | 1 208 | 1 209 | 0 210 | 0 211 | 0 212 | 1 213 | 1 214 | 1 215 | 0 216 | 0 217 | 0 218 | 0 219 | 1 220 | 0 221 | 0 222 | 1 223 | 1 224 | 1 225 | 0 226 | 1 227 | 0 228 | 0 229 | 1 230 | 0 231 | 0 232 | 1 233 | 0 234 | 1 235 | 0 236 | 1 237 | 0 238 | 0 239 | 0 240 | 0 241 | 1 242 | 0 243 | 1 244 | 1 245 | 0 246 | 1 247 | 0 248 | 0 249 | 0 250 | 1 251 | 0 252 | 1 253 | 0 254 | 1 255 | 0 256 | 1 257 | 0 258 | 1 259 | 1 260 | 1 261 | 0 262 | 0 263 | 1 264 | 0 265 | 1 266 | 1 267 | 1 268 | 0 269 | 1 270 | 1 271 | 0 272 | 0 273 | 1 274 | 1 275 | 1 276 | 0 277 | 0 278 | 1 279 | 0 280 | 1 281 | 1 282 | 0 283 | 1 284 | 1 285 | 0 286 | 1 287 | 1 288 | 0 289 | 0 290 | 1 291 | 1 292 | 0 293 | 1 294 | 1 295 | 1 296 | 0 297 | 1 298 | 0 299 | 0 300 | 0 301 | 1 302 | 0 303 | 1 304 | 0 305 | 1 306 | 0 307 | 0 308 | 0 309 | 1 310 | 1 311 | 0 312 | 0 313 | 0 314 | 1 315 | 1 316 | 0 317 | 1 318 | 0 319 | 0 320 | 1 321 | 1 322 | 0 323 | 0 324 | 1 325 | 0 326 | 0 327 | 1 328 | 1 329 | 0 330 | 1 331 | 0 332 | 1 333 | 0 334 | 0 335 | 1 336 | 0 337 | 0 338 | 1 339 | 0 340 | 1 341 | 0 342 | 0 343 | 1 344 | 1 345 | 1 346 | 0 347 | 0 348 | 1 349 | 0 350 | 1 351 | 1 352 | 0 353 | 0 354 | 1 355 | 1 356 | 0 357 | 0 358 | 0 359 | 1 360 | 0 361 | 1 362 | 0 363 | 0 
364 | 1 365 | 1 366 | 0 367 | 1 368 | 1 369 | 1 370 | 0 371 | 1 372 | 0 373 | 1 374 | 0 375 | 1 376 | 0 377 | 1 378 | 1 379 | 0 380 | 1 381 | 0 382 | 1 383 | 0 384 | 1 385 | 0 386 | 0 387 | 1 388 | 1 389 | 1 390 | 0 391 | 0 392 | 0 393 | 0 394 | 1 395 | 1 396 | 0 397 | 1 398 | 1 399 | 0 400 | 0 401 | 0 402 | 1 403 | 1 404 | 1 405 | 0 406 | 0 407 | 1 408 | 0 409 | 0 410 | 1 411 | 0 412 | 1 413 | 0 414 | 1 415 | 0 416 | 0 417 | 1 418 | 0 419 | 1 420 | 0 421 | 1 422 | 1 423 | 0 424 | 0 425 | 1 426 | 0 427 | 0 428 | 1 429 | 1 430 | 0 431 | 1 432 | 0 433 | 1 434 | 0 435 | 0 436 | 1 437 | 0 438 | 0 439 | 1 440 | 0 441 | 1 442 | 1 443 | 0 444 | 0 445 | 1 446 | 1 447 | 0 448 | 0 449 | 1 450 | 0 451 | 1 452 | 1 453 | 0 454 | 0 455 | 0 456 | 1 457 | 1 458 | 1 459 | 0 460 | 1 461 | 0 462 | 1 463 | 1 464 | 0 465 | 0 466 | 1 467 | 1 468 | 1 469 | 0 470 | 1 471 | 0 472 | 1 473 | 0 474 | 0 475 | 1 476 | 0 477 | 1 478 | 0 479 | 1 480 | 1 481 | 0 482 | 0 483 | 1 484 | 1 485 | 0 486 | 0 487 | 0 488 | 1 489 | 1 490 | 1 491 | 0 492 | 0 493 | 0 494 | 0 495 | 1 496 | 1 497 | 1 498 | 0 499 | 1 500 | 0 501 | 0 502 | 1 503 | 1 504 | 0 505 | 1 506 | 0 507 | 0 508 | 0 509 | 0 510 | 1 511 | 1 512 | 0 513 | 0 514 | 1 515 | 1 516 | 1 517 | 0 518 | 1 519 | 0 520 | 0 521 | 0 522 | 1 523 | 0 524 | 1 525 | 1 526 | 0 527 | 1 528 | 0 529 | 1 530 | 1 531 | 1 532 | 0 533 | 1 534 | 0 535 | 0 536 | 1 537 | 0 538 | 1 539 | 0 540 | 1 541 | 0 542 | 1 543 | 1 544 | 0 545 | 1 546 | 0 547 | 0 548 | 0 549 | 1 550 | 1 551 | 0 552 | 1 553 | 1 554 | 0 555 | 1 556 | 1 557 | 0 558 | 0 559 | 1 560 | 1 561 | 1 562 | 0 563 | 0 564 | 0 565 | 0 566 | 1 567 | 0 568 | 1 569 | 1 570 | 0 571 | 1 572 | 0 573 | 0 574 | 1 575 | 1 576 | 0 577 | 0 578 | 1 579 | 0 580 | 1 581 | 1 582 | 0 583 | 0 584 | 1 585 | 1 586 | 1 587 | 0 588 | 0 589 | 1 590 | 0 591 | 1 592 | 0 593 | 1 594 | 0 595 | 1 596 | 0 597 | 0 598 | 1 599 | 0 600 | 0 601 | 1 602 | 1 603 | 0 604 | 0 605 | 1 606 | 0 607 | 1 608 | 1 609 | 0 610 | 0 611 | 1 612 | 1 613 | 0 
614 | 0 615 | 1 616 | 1 617 | 0 618 | 1 619 | 0 620 | 1 621 | 0 622 | 0 623 | 1 624 | 0 625 | 1 626 | 0 627 | 1 628 | 0 629 | 1 630 | 0 631 | 0 632 | 1 633 | 0 634 | 0 635 | 1 636 | 1 637 | 0 638 | 0 639 | 1 640 | 0 641 | 1 642 | 1 643 | 0 644 | 0 645 | 1 646 | 0 647 | 0 648 | 1 649 | 1 650 | 0 651 | 0 652 | 1 653 | 0 654 | 1 655 | 1 656 | 0 657 | 1 658 | 0 659 | 1 660 | 0 661 | 0 662 | 0 663 | 0 664 | 1 665 | 1 666 | 0 667 | 1 668 | 1 669 | 0 670 | 1 671 | 0 672 | 0 673 | 1 674 | 1 675 | 0 676 | 1 677 | 0 678 | 1 679 | 0 680 | 1 681 | 0 682 | 1 683 | 1 684 | 0 685 | 0 686 | 0 687 | 1 688 | 1 689 | 1 690 | 0 691 | 0 692 | 1 693 | 0 694 | 1 695 | 0 696 | 1 697 | 0 698 | 0 699 | 1 700 | 1 701 | 1 702 | 0 703 | 1 704 | 0 705 | 0 706 | 1 707 | 1 708 | 0 709 | 1 710 | 0 711 | 1 712 | 1 713 | 0 714 | 0 715 | 1 716 | 0 717 | 1 718 | 1 719 | 0 720 | 0 721 | 1 722 | 0 723 | 0 724 | 0 725 | 1 726 | 1 727 | 0 728 | 1 729 | 1 730 | 1 731 | 0 732 | 0 733 | 1 734 | 0 735 | 1 736 | 1 737 | 0 738 | 0 739 | 1 740 | 1 741 | 0 742 | 1 743 | 0 744 | 1 745 | 1 746 | 1 747 | 0 748 | 0 749 | 0 750 | 0 751 | 1 752 | 0 753 | 1 754 | 1 755 | 0 756 | 1 757 | 1 758 | 0 759 | 1 760 | 0 761 | 1 762 | 0 763 | 1 764 | 0 765 | 0 766 | 1 767 | 1 768 | 0 769 | 0 770 | 0 771 | 1 772 | 0 773 | 0 774 | 1 775 | 1 776 | 0 777 | 1 778 | 0 779 | 0 780 | 1 781 | 0 782 | 0 783 | 0 784 | 1 785 | 0 786 | 0 787 | 1 788 | 1 789 | 0 790 | 1 791 | 1 792 | 0 793 | 0 794 | 1 795 | 0 796 | 1 797 | 1 798 | 0 799 | 0 800 | 1 801 | 0 802 | 1 803 | 0 804 | 1 805 | 1 806 | 0 807 | 0 808 | 1 809 | 0 810 | 1 811 | 1 812 | 0 813 | 0 814 | 1 815 | 1 816 | 0 817 | 0 818 | 1 819 | 0 820 | 1 821 | 0 822 | 0 823 | 1 824 | 0 825 | 1 826 | 0 827 | 1 828 | 1 829 | 0 830 | 1 831 | 0 832 | 0 833 | 1 834 | 0 835 | 1 836 | 1 837 | 0 838 | 0 839 | 1 840 | 0 841 | 1 842 | 0 843 | 1 844 | 0 845 | 0 846 | 1 847 | 0 848 | 0 849 | 1 850 | 1 851 | 0 852 | 0 853 | 1 854 | 0 855 | 0 856 | 1 857 | 1 858 | 0 859 | 1 860 | 1 861 | 0 862 | 0 863 | 1 
864 | 0 865 | 1 866 | 0 867 | 1 868 | 0 869 | 0 870 | 1 871 | 0 872 | 0 873 | 1 874 | 1 875 | 1 876 | 0 877 | 1 878 | 0 879 | 0 880 | 1 881 | 0 882 | 1 883 | 1 884 | 0 885 | 1 886 | 1 887 | 0 888 | 1 889 | 1 890 | 0 891 | 1 892 | 0 893 | 1 894 | 0 895 | 1 896 | 0 897 | 1 898 | 0 899 | 1 900 | 1 901 | 1 902 | 0 903 | 0 904 | 0 905 | 1 906 | 1 907 | 1 908 | 0 909 | 0 910 | 1 911 | 0 912 | 1 913 | 0 914 | 1 915 | 1 916 | 0 917 | 1 918 | 0 919 | 1 920 | 1 921 | 0 922 | 0 923 | 0 924 | 1 925 | 0 926 | 1 927 | 0 928 | 1 929 | 0 930 | 0 931 | 1 932 | 0 933 | 0 934 | 1 935 | 0 936 | 1 937 | 1 938 | 0 939 | 0 940 | 1 941 | 0 942 | 1 943 | 0 944 | 1 945 | 1 946 | 0 947 | 0 948 | 1 949 | 1 950 | 0 951 | 0 952 | 0 953 | 0 954 | 1 955 | 1 956 | 1 957 | 0 958 | 0 959 | 1 960 | 1 961 | 0 962 | 1 963 | 1 964 | 0 965 | 0 966 | 1 967 | 0 968 | 1 969 | 1 970 | 0 971 | 0 972 | 1 973 | 1 974 | 1 975 | 0 976 | 1 977 | 0 978 | 0 979 | 0 980 | 1 981 | 1 982 | 1 983 | 0 984 | 1 985 | 0 986 | 0 987 | 1 988 | 0 989 | 1 990 | 0 991 | 1 992 | 0 993 | 1 994 | 0 995 | 0 996 | 1 997 | 1 998 | 0 999 | 0 1000 | 0 1001 | 1 1002 | 1 1003 | 0 1004 | 0 1005 | 1 1006 | 0 1007 | 1 1008 | 1 1009 | 0 1010 | 0 1011 | 1 1012 | 0 1013 | 0 1014 | 1 1015 | 0 1016 | 1 1017 | 0 1018 | 1 1019 | 0 1020 | 0 1021 | 1 1022 | 1 1023 | 0 1024 | 0 1025 | 0 1026 | 1 1027 | 0 1028 | 1 1029 | 1 1030 | 0 1031 | 0 1032 | 1 1033 | 0 1034 | 1 1035 | 0 1036 | 0 1037 | 1 1038 | 0 1039 | 1 1040 | 0 1041 | 1 1042 | 0 1043 | 1 1044 | 0 1045 | 1 1046 | 1 1047 | 0 1048 | 1 1049 | 1 1050 | 0 1051 | 0 1052 | 1 1053 | 1 1054 | 0 1055 | 0 1056 | 1 1057 | 1 1058 | 1 1059 | 1 1060 | 0 1061 | 1 1062 | 0 1063 | 0 1064 | 1 1065 | 0 1066 | 1 1067 | 0 1068 | 1 1069 | 0 1070 | 1 1071 | 0 1072 | 1 1073 | 0 1074 | 1 1075 | 0 1076 | 1 1077 | 1 1078 | 0 1079 | 0 1080 | 1 1081 | 0 1082 | 1 1083 | 0 1084 | 1 1085 | 1 1086 | 0 1087 | 0 1088 | 1 1089 | 0 1090 | 0 1091 | 1 1092 | 1 1093 | 0 1094 | 1 1095 | 0 1096 | 0 1097 | 0 1098 | 1 1099 | 1 1100 | 0 
1101 | 1 1102 | 0 1103 | 1 1104 | 0 1105 | 0 1106 | 1 1107 | 0 1108 | 1 1109 | 0 1110 | 1 1111 | 0 1112 | 1 1113 | 1 1114 | 0 1115 | 1 1116 | 0 1117 | 0 1118 | 1 1119 | 0 1120 | 1 1121 | 0 1122 | 1 1123 | 0 1124 | 1 1125 | 0 1126 | 1 1127 | 1 1128 | 0 1129 | 0 1130 | 0 1131 | 1 1132 | 0 1133 | 1 1134 | 1 1135 | 1 1136 | 0 1137 | 1 1138 | 0 1139 | 0 1140 | 0 1141 | 0 1142 | 1 1143 | 1 1144 | 1 1145 | 0 1146 | 1 1147 | 1 1148 | 0 1149 | 1 1150 | 1 1151 | 0 1152 | 1 1153 | 1 1154 | 0 1155 | 0 1156 | 1 1157 | 1 1158 | 1 1159 | 0 1160 | 0 1161 | 1 1162 | 0 1163 | 1 1164 | 0 1165 | 1 1166 | 0 1167 | 1 1168 | 1 1169 | 0 1170 | 0 1171 | 0 1172 | 1 1173 | 1 1174 | 0 1175 | 1 1176 | 1 1177 | 0 1178 | 1 1179 | 0 1180 | 1 1181 | 0 1182 | 0 1183 | 0 1184 | 1 1185 | 1 1186 | 0 1187 | 1 1188 | 0 1189 | 0 1190 | 1 1191 | 0 1192 | 0 1193 | 1 1194 | 0 1195 | 0 1196 | 1 1197 | 0 1198 | 1 1199 | 0 1200 | 0 1201 | 1 1202 | 0 1203 | 1 1204 | 0 1205 | 1 1206 | 1 1207 | 0 1208 | 0 1209 | 1 1210 | 0 1211 | 1 1212 | 0 1213 | 1 1214 | 0 1215 | 1 1216 | 0 1217 | 1 1218 | 0 1219 | 1 1220 | 0 1221 | 0 1222 | 1 1223 | 1 1224 | 0 1225 | 0 1226 | 1 1227 | 1 1228 | 0 1229 | 0 1230 | 0 1231 | 1 1232 | 0 1233 | 1 1234 | 0 1235 | 0 1236 | 1 1237 | 0 1238 | 1 1239 | 0 1240 | 0 1241 | 1 1242 | 1 1243 | 0 1244 | 0 1245 | 1 1246 | 1 1247 | 1 1248 | 0 1249 | 0 1250 | 0 1251 | 1 1252 | 1 1253 | 0 1254 | 0 1255 | 1 1256 | 1 1257 | 0 1258 | 0 1259 | 0 1260 | 1 1261 | 0 1262 | 1 1263 | 0 1264 | 1 1265 | 1 1266 | 0 1267 | 1 1268 | 0 1269 | 0 1270 | 0 1271 | 1 1272 | 1 1273 | 1 1274 | 0 1275 | 1 1276 | 0 1277 | 1 1278 | 1 1279 | 0 1280 | 0 1281 | 1 1282 | 0 1283 | 1 1284 | 1 1285 | 1 1286 | 1 1287 | 0 1288 | 1 1289 | 0 1290 | 1 1291 | 1 1292 | 0 1293 | 1 1294 | 1 1295 | 1 1296 | 0 1297 | 1 1298 | 0 1299 | 0 1300 | 1 1301 | 1 1302 | 0 1303 | 0 1304 | 1 1305 | 0 1306 | 0 1307 | 1 1308 | 1 1309 | 0 1310 | 0 1311 | 1 1312 | 1 1313 | 0 1314 | 1 1315 | 0 1316 | 1 1317 | 0 1318 | 0 1319 | 1 1320 | 0 1321 | 1 1322 | 0 
1323 | 0 1324 | 1 1325 | 1 1326 | 0 1327 | 0 1328 | 0 1329 | 1 1330 | 1 1331 | 1 1332 | 0 1333 | 0 1334 | 0 1335 | 1 1336 | 0 1337 | 0 1338 | 1 1339 | 1 1340 | 0 1341 | 0 1342 | 1 1343 | 0 1344 | 0 1345 | 1 1346 | 1 1347 | 0 1348 | 1 1349 | 0 1350 | 0 1351 | 1 1352 | 1 1353 | 0 1354 | 0 1355 | 0 1356 | 1 1357 | 0 1358 | 0 1359 | 1 1360 | 0 1361 | 1 1362 | 1 1363 | 0 1364 | 0 1365 | 1 1366 | 1 1367 | 0 1368 | 0 1369 | 1 1370 | 1 1371 | 1 1372 | 0 1373 | 0 1374 | 1 1375 | 1 1376 | 0 1377 | 1 1378 | 1 1379 | 0 1380 | 0 1381 | 1 1382 | 1 1383 | 1 1384 | 0 1385 | 1 1386 | 1 1387 | 0 1388 | 0 1389 | 1 1390 | 0 1391 | 0 1392 | 1 1393 | 1 1394 | 1 1395 | 0 1396 | 1 1397 | 0 1398 | 1 1399 | 0 1400 | 1 1401 | 1 1402 | 1 1403 | 0 1404 | 0 1405 | 1 1406 | 0 1407 | 1 1408 | 0 1409 | 1 1410 | 0 1411 | 0 1412 | 1 1413 | 0 1414 | 1 1415 | 0 1416 | 1 1417 | 0 1418 | 0 1419 | 0 1420 | 1 1421 | 1 1422 | 0 1423 | 0 1424 | 1 1425 | 0 1426 | 0 1427 | 1 1428 | 0 1429 | 0 1430 | 1 1431 | 1 1432 | 1 1433 | 0 1434 | 0 1435 | 1 1436 | 1 1437 | 0 1438 | 1 1439 | 1 1440 | 1 1441 | 0 1442 | 1 1443 | 0 1444 | 1 1445 | 0 1446 | 1 1447 | 0 1448 | 0 1449 | 0 1450 | 1 1451 | 1 1452 | 0 1453 | 1 1454 | 0 1455 | 1 1456 | 0 1457 | 1 1458 | 1 1459 | 0 1460 | 1 1461 | 0 1462 | 1 1463 | 0 1464 | 1 1465 | 0 1466 | 1 1467 | 0 1468 | 1 1469 | 0 1470 | 1 1471 | 1 1472 | 0 1473 | 0 1474 | 1 1475 | 0 1476 | 1 1477 | 1 1478 | 1 1479 | 0 1480 | 0 1481 | 1 1482 | 0 1483 | 0 1484 | 1 1485 | 1 1486 | 0 1487 | 1 1488 | 1 1489 | 0 1490 | 0 1491 | 0 1492 | 1 1493 | 0 1494 | 1 1495 | 0 1496 | 0 1497 | 1 1498 | 1 1499 | 0 1500 | 0 1501 | 0 1502 | 0 1503 | 1 1504 | 1 1505 | 1 1506 | 0 1507 | 0 1508 | 1 1509 | 0 1510 | 0 1511 | 0 1512 | 1 1513 | 1 1514 | 1 1515 | 0 1516 | 1 1517 | 0 1518 | 1 1519 | 0 1520 | 1 1521 | 0 1522 | 0 1523 | 1 1524 | 1 1525 | 0 1526 | 0 1527 | 0 1528 | 1 1529 | 1 1530 | 0 1531 | 1 1532 | 0 1533 | 1 1534 | 1 1535 | 0 1536 | 0 1537 | 1 1538 | 1 1539 | 0 1540 | 0 1541 | 0 1542 | 1 1543 | 1 1544 | 1 
1545 | 0 1546 | 1 1547 | 0 1548 | 0 1549 | 1 1550 | 0 1551 | 1 1552 | 0 1553 | 1 1554 | 0 1555 | 1 1556 | 0 1557 | 1 1558 | 0 1559 | 0 1560 | 1 1561 | 1 1562 | 0 1563 | 1 1564 | 1 1565 | 0 1566 | 1 1567 | 1 1568 | 0 1569 | 0 1570 | 0 1571 | 1 1572 | 0 1573 | 1 1574 | 1 1575 | 1 1576 | 0 1577 | 1 1578 | 0 1579 | 0 1580 | 1 1581 | 1 1582 | 1 1583 | 0 1584 | 0 1585 | 0 1586 | 1 1587 | 1 1588 | 1 1589 | 0 1590 | 0 1591 | 1 1592 | 1 1593 | 0 1594 | 0 1595 | 1 1596 | 0 1597 | 1 1598 | 0 1599 | 1 1600 | 0 1601 | 1 1602 | 1 1603 | 0 1604 | 1 1605 | 0 1606 | 0 1607 | 0 1608 | 0 1609 | 1 1610 | 1 1611 | 1 1612 | 1 1613 | 0 1614 | 0 1615 | 1 1616 | 0 1617 | 1 1618 | 0 1619 | 0 1620 | 1 1621 | 0 1622 | 0 1623 | 0 1624 | 1 1625 | 1 1626 | 0 1627 | 0 1628 | 0 1629 | 1 1630 | 1 1631 | 0 1632 | 0 1633 | 1 1634 | 0 1635 | 0 1636 | 1 1637 | 1 1638 | 1 1639 | 0 1640 | 1 1641 | 0 1642 | 1 1643 | 0 1644 | 1 1645 | 0 1646 | 0 1647 | 0 1648 | 0 1649 | 1 1650 | 1 1651 | 1 1652 | 1 1653 | 0 1654 | 0 1655 | 1 1656 | 0 1657 | 0 1658 | 1 1659 | 0 1660 | 1 1661 | 1 1662 | 1 1663 | 0 1664 | 0 1665 | 0 1666 | 0 1667 | 1 1668 | 1 1669 | 1 1670 | 0 1671 | 0 1672 | 1 1673 | 1 1674 | 1 1675 | 0 1676 | 0 1677 | 0 1678 | 1 1679 | 1 1680 | 0 1681 | 1 1682 | 1 1683 | 0 1684 | 1 1685 | 0 1686 | 1 1687 | 0 1688 | 1 1689 | 0 1690 | 1 1691 | 0 1692 | 1 1693 | 1 1694 | 0 1695 | 0 1696 | 1 1697 | 1 1698 | 0 1699 | 1 1700 | 0 1701 | 1 1702 | 0 1703 | 0 1704 | 1 1705 | 1 1706 | 0 1707 | 1 1708 | 0 1709 | 1 1710 | 0 1711 | 1 1712 | 0 1713 | 1 1714 | 1 1715 | 1 1716 | 0 1717 | 1 1718 | 1 1719 | 0 1720 | 0 1721 | 0 1722 | 1 1723 | 0 1724 | 1 1725 | 0 1726 | 0 1727 | 1 1728 | 1 1729 | 0 1730 | 0 1731 | 1 1732 | 0 1733 | 0 1734 | 1 1735 | 0 1736 | 0 1737 | 1 1738 | 1 1739 | 1 1740 | 0 1741 | 1 1742 | 0 1743 | 1 1744 | 1 1745 | 0 1746 | 1 1747 | 1 1748 | 0 1749 | 0 1750 | 0 1751 | 1 1752 | 1 1753 | 1 1754 | 0 1755 | 1 1756 | 0 1757 | 0 1758 | 0 1759 | 1 1760 | 0 1761 | 1 1762 | 0 1763 | 0 1764 | 1 1765 | 0 1766 | 1 
1767 | 1 1768 | 0 1769 | 0 1770 | 0 1771 | 1 1772 | 0 1773 | 0 1774 | 1 1775 | 1 1776 | 0 1777 | 0 1778 | 1 1779 | 0 1780 | 1 1781 | 0 1782 | 0 1783 | 1 1784 | 0 1785 | 1 1786 | 1 1787 | 0 1788 | 1 1789 | 0 1790 | 1 1791 | 0 1792 | 1 1793 | 1 1794 | 0 1795 | 0 1796 | 1 1797 | 1 1798 | 0 1799 | 0 1800 | 1 1801 | 1 1802 | 1 1803 | 0 1804 | 0 1805 | 1 1806 | 0 1807 | 0 1808 | 1 1809 | 1 1810 | 1 1811 | 0 1812 | 0 1813 | 0 1814 | 1 1815 | 0 1816 | 1 1817 | 0 1818 | 0 1819 | 1 1820 | 1 1821 | 0 1822 | 1 1823 | 0 1824 | 0 1825 | 0 1826 | 1 1827 | 1 1828 | 0 1829 | 0 1830 | 1 1831 | 1 1832 | 0 1833 | 1 1834 | 1 1835 | 0 1836 | 0 1837 | 0 1838 | 1 1839 | 0 1840 | 0 1841 | 0 1842 | 1 1843 | 1 1844 | 1 1845 | 0 1846 | 1 1847 | 0 1848 | 1 1849 | 0 1850 | 1 1851 | 0 1852 | 1 1853 | 0 1854 | 1 1855 | 0 1856 | 1 1857 | 1 1858 | 0 1859 | 1 1860 | 1 1861 | 0 1862 | 0 1863 | 0 1864 | 1 1865 | 0 1866 | 1 1867 | 0 1868 | 1 1869 | 0 1870 | 1 1871 | 1 1872 | 0 1873 | 1 1874 | 1 1875 | 0 1876 | 1 1877 | 0 1878 | 0 1879 | 1 1880 | 1 1881 | 0 1882 | 1 1883 | 1 1884 | 0 1885 | 0 1886 | 0 1887 | 1 1888 | 1 1889 | 0 1890 | 0 1891 | 0 1892 | 1 1893 | 1 1894 | 1 1895 | 0 1896 | 0 1897 | 1 1898 | 1 1899 | 0 1900 | 1 1901 | 1 1902 | 0 1903 | 0 1904 | 0 1905 | 1 1906 | 0 1907 | 1 1908 | 1 1909 | 0 1910 | 0 1911 | 1 1912 | 0 1913 | 1 1914 | 1 1915 | 0 1916 | 0 1917 | 1 1918 | 0 1919 | 0 1920 | 1 1921 | 0 1922 | 1 1923 | 0 1924 | 0 1925 | 1 1926 | 0 1927 | 1 1928 | 1 1929 | 1 1930 | 0 1931 | 1 1932 | 0 1933 | 0 1934 | 1 1935 | 0 1936 | 0 1937 | 0 1938 | 1 1939 | 1 1940 | 1 1941 | 0 1942 | 0 1943 | 0 1944 | 1 1945 | 0 1946 | 1 1947 | 1 1948 | 1 1949 | 0 1950 | 1 1951 | 0 1952 | 1 1953 | 0 1954 | 1 1955 | 0 1956 | 1 1957 | 1 1958 | 0 1959 | 1 1960 | 0 1961 | 1 1962 | 0 1963 | 0 1964 | 1 1965 | 1 1966 | 0 1967 | 1 1968 | 1 1969 | 0 1970 | 1 1971 | 1 1972 | 0 1973 | 0 1974 | 0 1975 | 1 1976 | 1 1977 | 0 1978 | 1 1979 | 1 1980 | 1 1981 | 0 1982 | 0 1983 | 1 1984 | 0 1985 | 0 1986 | 0 1987 | 1 1988 | 1 
1989 | 0 1990 | 0 1991 | 1 1992 | 0 1993 | 1 1994 | 1 1995 | 1 1996 | 1 1997 | 0 1998 | 0 1999 | 0 2000 | 0 2001 | 1 2002 | 0 2003 | 1 2004 | 0 2005 | 1 2006 | 0 2007 | 0 2008 | 1 2009 | 0 2010 | 1 2011 | 1 2012 | 1 2013 | 0 2014 | 0 2015 | 1 2016 | 1 2017 | 1 2018 | 0 2019 | 0 2020 | 1 2021 | 0 2022 | 0 2023 | 0 2024 | 0 2025 | 1 2026 | 0 2027 | 0 2028 | 1 2029 | 1 2030 | 0 2031 | 1 2032 | 0 2033 | 1 2034 | 0 2035 | 0 2036 | 1 2037 | 1 2038 | 0 2039 | 1 2040 | 0 2041 | 1 2042 | 0 2043 | 1 2044 | 0 2045 | 0 2046 | 1 2047 | 0 2048 | 1 2049 | 0 2050 | 1 2051 | 0 2052 | 0 2053 | 1 2054 | 1 2055 | 0 2056 | 0 2057 | 1 2058 | 0 2059 | 0 2060 | 1 2061 | 0 2062 | 1 2063 | 0 2064 | 1 2065 | 1 2066 | 0 2067 | 1 2068 | 0 2069 | 1 2070 | 0 2071 | 0 2072 | 0 2073 | 1 2074 | 1 2075 | 1 2076 | 0 2077 | 1 2078 | 1 2079 | 1 2080 | 0 2081 | 0 2082 | 0 2083 | 1 2084 | 0 2085 | 1 2086 | 1 2087 | 1 2088 | 0 2089 | 0 2090 | 1 2091 | 0 2092 | 1 2093 | 0 2094 | 1 2095 | 1 2096 | 0 2097 | 1 2098 | 1 2099 | 0 2100 | 0 2101 | 0 2102 | 1 2103 | 0 2104 | 1 2105 | 1 2106 | 0 2107 | 0 2108 | 1 2109 | 1 2110 | 0 2111 | 0 2112 | 1 2113 | 1 2114 | 1 2115 | 0 2116 | 1 2117 | 1 2118 | 0 2119 | 1 2120 | 1 2121 | 0 2122 | 1 2123 | 0 2124 | 1 2125 | 1 2126 | 0 2127 | 0 2128 | 1 2129 | 0 2130 | 1 2131 | 0 2132 | 1 2133 | 0 2134 | 1 2135 | 0 2136 | 1 2137 | 1 2138 | 0 2139 | 0 2140 | 0 2141 | 1 2142 | 1 2143 | 0 2144 | 0 2145 | 1 2146 | 1 2147 | 0 2148 | 1 2149 | 0 2150 | 0 2151 | 1 2152 | 1 2153 | 0 2154 | 1 2155 | 0 2156 | 0 2157 | 1 2158 | 0 2159 | 1 2160 | 1 2161 | 0 2162 | 0 2163 | 1 2164 | 0 2165 | 1 2166 | 0 2167 | 1 2168 | 0 2169 | 1 2170 | 0 2171 | 0 2172 | 1 2173 | 1 2174 | 0 2175 | 1 2176 | 1 2177 | 0 2178 | 0 2179 | 0 2180 | 1 2181 | 0 2182 | 1 2183 | 0 2184 | 1 2185 | 1 2186 | 1 2187 | 0 2188 | 1 2189 | 1 2190 | 0 2191 | 1 2192 | 0 2193 | 0 2194 | 1 2195 | 0 2196 | 0 2197 | 1 2198 | 1 2199 | 0 2200 | 1 2201 | 1 2202 | 1 2203 | 0 2204 | 0 2205 | 0 2206 | 1 2207 | 0 2208 | 1 2209 | 0 2210 | 1 
2211 | 0 2212 | 0 2213 | 1 2214 | 1 2215 | 1 2216 | 0 2217 | 1 2218 | 0 2219 | 0 2220 | 1 2221 | 1 2222 | 0 2223 | 1 2224 | 0 2225 | 1 2226 | 0 2227 | 0 2228 | 1 2229 | 1 2230 | 0 2231 | 0 2232 | 0 2233 | 1 2234 | 0 2235 | 1 2236 | 0 2237 | 1 2238 | 0 2239 | 1 2240 | 0 2241 | 0 2242 | 1 2243 | 1 2244 | 0 2245 | 1 2246 | 1 2247 | 0 2248 | 0 2249 | 0 2250 | 1 2251 | 1 2252 | 0 2253 | 1 2254 | 0 2255 | 0 2256 | 0 2257 | 1 2258 | 1 2259 | 0 2260 | 1 2261 | 1 2262 | 0 2263 | 0 2264 | 0 2265 | 1 2266 | 0 2267 | 1 2268 | 1 2269 | 1 2270 | 0 2271 | 0 2272 | 1 2273 | 1 2274 | 0 2275 | 0 2276 | 1 2277 | 0 2278 | 1 2279 | 1 2280 | 0 2281 | 1 2282 | 0 2283 | 1 2284 | 1 2285 | 0 2286 | 1 2287 | 0 2288 | 1 2289 | 1 2290 | 0 2291 | 0 2292 | 0 2293 | 1 2294 | 1 2295 | 0 2296 | 1 2297 | 0 2298 | 1 2299 | 0 2300 | 1 2301 | 1 2302 | 0 2303 | 0 2304 | 1 2305 | 1 2306 | 1 2307 | 0 2308 | 0 2309 | 1 2310 | 1 2311 | 0 2312 | 0 2313 | 0 2314 | 0 2315 | 1 2316 | 1 2317 | 0 2318 | 0 2319 | 1 2320 | 0 2321 | 1 2322 | 0 2323 | 1 2324 | 0 2325 | 1 2326 | 0 2327 | 1 2328 | 0 2329 | 1 2330 | 1 2331 | 0 2332 | 1 2333 | 0 2334 | 1 2335 | 0 2336 | 0 2337 | 1 2338 | 0 2339 | 1 2340 | 0 2341 | 0 2342 | 1 2343 | 0 2344 | 0 2345 | 1 2346 | 1 2347 | 0 2348 | 1 2349 | 0 2350 | 0 2351 | 1 2352 | 1 2353 | 1 2354 | 0 2355 | 0 2356 | 1 2357 | 1 2358 | 1 2359 | 1 2360 | 0 2361 | 1 2362 | 0 2363 | 0 2364 | 0 2365 | 1 2366 | 1 2367 | 0 2368 | 1 2369 | 1 2370 | 0 2371 | 1 2372 | 0 2373 | 0 2374 | 0 2375 | 0 2376 | 1 2377 | 1 2378 | 1 2379 | 0 2380 | 1 2381 | 0 2382 | 1 2383 | 0 2384 | 0 2385 | 1 2386 | 1 2387 | 1 2388 | 0 2389 | 1 2390 | 0 2391 | 0 2392 | 1 2393 | 1 2394 | 0 2395 | 1 2396 | 0 2397 | 1 2398 | 1 2399 | 1 2400 | 1 2401 | 0 2402 | 1 2403 | 1 2404 | 0 2405 | 1 2406 | 0 2407 | 0 2408 | 1 2409 | 0 2410 | 1 2411 | 1 2412 | 0 2413 | 1 2414 | 0 2415 | 0 2416 | 0 2417 | 1 2418 | 0 2419 | 1 2420 | 0 2421 | 1 2422 | 1 2423 | 1 2424 | 0 2425 | 1 2426 | 0 2427 | 0 2428 | 0 2429 | 1 2430 | 0 2431 | 0 2432 | 1 
2433 | 1 2434 | 0 2435 | 0 2436 | 1 2437 | 0 2438 | 1 2439 | 1 2440 | 0 2441 | 0 2442 | 1 2443 | 1 2444 | 0 2445 | 1 2446 | 0 2447 | 0 2448 | 0 2449 | 0 2450 | 1 2451 | 0 2452 | 0 2453 | 1 2454 | 0 2455 | 0 2456 | 1 2457 | 1 2458 | 0 2459 | 1 2460 | 1 2461 | 0 2462 | 1 2463 | 1 2464 | 0 2465 | 0 2466 | 1 2467 | 0 2468 | 0 2469 | 1 2470 | 0 2471 | 0 2472 | 1 2473 | 1 2474 | 0 2475 | 0 2476 | 1 2477 | 1 2478 | 0 2479 | 1 2480 | 0 2481 | 1 2482 | 0 2483 | 1 2484 | 1 2485 | 0 2486 | 1 2487 | 1 2488 | 1 2489 | 0 2490 | 1 2491 | 1 2492 | 0 2493 | 0 2494 | 1 2495 | 0 2496 | 1 2497 | 1 2498 | 0 2499 | 1 2500 | 0 2501 | 1 2502 | 0 2503 | 1 2504 | 1 2505 | 1 2506 | 0 2507 | 0 2508 | 1 2509 | 0 2510 | 1 2511 | 1 2512 | 1 2513 | 0 2514 | 1 2515 | 0 2516 | 0 2517 | 1 2518 | 1 2519 | 1 2520 | 0 2521 | 1 2522 | 1 2523 | 1 2524 | 0 2525 | 0 2526 | 0 2527 | 0 2528 | 1 2529 | 1 2530 | 1 2531 | 0 2532 | 1 2533 | 0 2534 | 1 2535 | 0 2536 | 0 2537 | 0 2538 | 1 2539 | 1 2540 | 1 2541 | 1 2542 | 0 2543 | 0 2544 | 0 2545 | 1 2546 | 1 2547 | 1 2548 | 1 2549 | 0 2550 | 0 2551 | 1 2552 | 0 2553 | 1 2554 | 0 2555 | 1 2556 | 0 2557 | 1 2558 | 0 2559 | 1 2560 | 1 2561 | 0 2562 | 0 2563 | 0 2564 | 1 2565 | 0 2566 | 1 2567 | 1 2568 | 0 2569 | 1 2570 | 1 2571 | 0 2572 | 0 2573 | 1 2574 | 1 2575 | 0 2576 | 0 2577 | 0 2578 | 1 2579 | 1 2580 | 0 2581 | 1 2582 | 1 2583 | 0 2584 | 0 2585 | 0 2586 | 1 2587 | 0 2588 | 1 2589 | 0 2590 | 1 2591 | 0 2592 | 1 2593 | 1 2594 | 0 2595 | 0 2596 | 0 2597 | 1 2598 | 1 2599 | 0 2600 | 1 2601 | 1 2602 | 1 2603 | 0 2604 | 1 2605 | 0 2606 | 0 2607 | 1 2608 | 1 2609 | 1 2610 | 0 2611 | 1 2612 | 0 2613 | 0 2614 | 0 2615 | 1 2616 | 0 2617 | 1 2618 | 0 2619 | 0 2620 | 1 2621 | 1 2622 | 0 2623 | 0 2624 | 1 2625 | 0 2626 | 1 2627 | 0 2628 | 1 2629 | 0 2630 | 1 2631 | 0 2632 | 0 2633 | 0 2634 | 1 2635 | 1 2636 | 1 2637 | 0 2638 | 0 2639 | 1 2640 | 0 2641 | 0 2642 | 1 2643 | 0 2644 | 1 2645 | 0 2646 | 1 2647 | 0 2648 | 1 2649 | 0 2650 | 0 2651 | 1 2652 | 0 2653 | 0 2654 | 1 
2655 | 0 2656 | 1 2657 | 1 2658 | 1 2659 | 0 2660 | 0 2661 | 1 2662 | 0 2663 | 1 2664 | 0 2665 | 0 2666 | 1 2667 | 0 2668 | 1 2669 | 1 2670 | 0 2671 | 0 2672 | 1 2673 | 0 2674 | 0 2675 | 1 2676 | 0 2677 | 1 2678 | 1 2679 | 0 2680 | 0 2681 | 1 2682 | 0 2683 | 1 2684 | 0 2685 | 1 2686 | 0 2687 | 1 2688 | 0 2689 | 1 2690 | 0 2691 | 1 2692 | 1 2693 | 1 2694 | 0 2695 | 0 2696 | 0 2697 | 1 2698 | 0 2699 | 0 2700 | 1 2701 | 1 2702 | 0 2703 | 1 2704 | 1 2705 | 0 2706 | 0 2707 | 1 2708 | 0 2709 | 1 2710 | 0 2711 | 1 2712 | 1 2713 | 0 2714 | 1 2715 | 0 2716 | 1 2717 | 1 2718 | 1 2719 | 1 2720 | 0 2721 | 0 2722 | 0 2723 | 0 2724 | 1 2725 | 0 2726 | 1 2727 | 1 2728 | 0 2729 | 1 2730 | 0 2731 | 1 2732 | 1 2733 | 0 2734 | 1 2735 | 0 2736 | 1 2737 | 1 2738 | 0 2739 | 1 2740 | 1 2741 | 1 2742 | 0 2743 | 0 2744 | 1 2745 | 0 2746 | 1 2747 | 0 2748 | 1 2749 | 1 2750 | 0 2751 | 0 2752 | 1 2753 | 1 2754 | 0 2755 | 1 2756 | 0 2757 | 0 2758 | 1 2759 | 1 2760 | 1 2761 | 0 2762 | 1 2763 | 0 2764 | 1 2765 | 0 2766 | 1 2767 | 1 2768 | 1 2769 | 0 2770 | 1 2771 | 1 2772 | 0 2773 | 1 2774 | 1 2775 | 0 2776 | 0 2777 | 1 2778 | 1 2779 | 0 2780 | 1 2781 | 1 2782 | 0 2783 | 1 2784 | 1 2785 | 0 2786 | 0 2787 | 0 2788 | 1 2789 | 0 2790 | 0 2791 | 1 2792 | 1 2793 | 1 2794 | 0 2795 | 1 2796 | 0 2797 | 0 2798 | 1 2799 | 1 2800 | 1 2801 | 0 2802 | 1 2803 | 1 2804 | 0 2805 | 0 2806 | 1 2807 | 1 2808 | 1 2809 | 0 2810 | 0 2811 | 0 2812 | 0 2813 | 0 2814 | 1 2815 | 1 2816 | 1 2817 | 0 2818 | 1 2819 | 1 2820 | 0 2821 | 0 2822 | 1 2823 | 0 2824 | 0 2825 | 0 2826 | 1 2827 | 0 2828 | 1 2829 | 1 2830 | 1 2831 | 1 2832 | 0 2833 | 0 2834 | 0 2835 | 1 2836 | 0 2837 | 0 2838 | 0 2839 | 1 2840 | 1 2841 | 0 2842 | 1 2843 | 0 2844 | 1 2845 | 0 2846 | 0 2847 | 0 2848 | 1 2849 | 1 2850 | 1 2851 | 1 2852 | 0 2853 | 0 2854 | 0 2855 | 1 2856 | 1 2857 | 0 2858 | 1 2859 | 1 2860 | 0 2861 | 1 2862 | 1 2863 | 0 2864 | 0 2865 | 1 2866 | 0 2867 | 0 2868 | 1 2869 | 1 2870 | 0 2871 | 1 2872 | 0 2873 | 0 2874 | 1 2875 | 1 2876 | 1 
2877 | 1 2878 | 0 2879 | 1 2880 | 0 2881 | 1 2882 | 1 2883 | 0 2884 | 0 2885 | 0 2886 | 1 2887 | 0 2888 | 1 2889 | 1 2890 | 0 2891 | 1 2892 | 1 2893 | 0 2894 | 1 2895 | 0 2896 | 0 2897 | 1 2898 | 0 2899 | 1 2900 | 0 2901 | 0 2902 | 1 2903 | 1 2904 | 1 2905 | 0 2906 | 0 2907 | 1 2908 | 0 2909 | 1 2910 | 0 2911 | 0 2912 | 1 2913 | 1 2914 | 1 2915 | 0 2916 | 1 2917 | 0 2918 | 0 2919 | 1 2920 | 0 2921 | 0 2922 | 1 2923 | 1 2924 | 1 2925 | 0 2926 | 1 2927 | 0 2928 | 0 2929 | 1 2930 | 0 2931 | 1 2932 | 0 2933 | 1 2934 | 1 2935 | 0 2936 | 0 2937 | 0 2938 | 1 2939 | 0 2940 | 1 2941 | 0 2942 | 1 2943 | 0 2944 | 0 2945 | 0 2946 | 1 2947 | 1 2948 | 1 2949 | 1 2950 | 1 2951 | 0 2952 | 0 2953 | 0 2954 | 0 2955 | 0 2956 | 1 2957 | 0 2958 | 0 2959 | 0 2960 | 1 2961 | 0 2962 | 0 2963 | 1 2964 | 1 2965 | 0 2966 | 0 2967 | 1 2968 | 0 2969 | 1 2970 | 0 2971 | 1 2972 | 1 2973 | 0 2974 | 1 2975 | 0 2976 | 1 2977 | 0 2978 | 1 2979 | 0 2980 | 1 2981 | 0 2982 | 1 2983 | 0 2984 | 1 2985 | 0 2986 | 1 2987 | 0 2988 | 1 2989 | 0 2990 | 1 2991 | 1 2992 | 1 2993 | 0 2994 | 0 2995 | 1 2996 | 0 2997 | 0 2998 | 1 2999 | 1 3000 | 0 3001 | 1 3002 | 0 3003 | 1 3004 | 0 3005 | 1 3006 | 0 3007 | 1 3008 | 0 3009 | 1 3010 | 0 3011 | 1 3012 | 0 3013 | 0 3014 | 1 3015 | 0 3016 | 1 3017 | 0 3018 | 0 3019 | 1 3020 | 0 3021 | 1 3022 | 0 3023 | 1 3024 | 0 3025 | 1 3026 | 1 3027 | 0 3028 | 0 3029 | 1 3030 | 1 3031 | 0 3032 | 1 3033 | 1 3034 | 0 3035 | 1 3036 | 1 3037 | 0 3038 | 0 3039 | 0 3040 | 1 3041 | 0 3042 | 1 3043 | 1 3044 | 1 3045 | 0 3046 | 0 3047 | 0 3048 | 0 3049 | 0 3050 | 1 3051 | 1 3052 | 0 3053 | 1 3054 | 0 3055 | 0 3056 | 0 3057 | 1 3058 | 0 3059 | 1 3060 | 0 3061 | 1 3062 | 0 3063 | 1 3064 | 0 3065 | 0 3066 | 1 3067 | 1 3068 | 1 3069 | 0 3070 | 1 3071 | 1 3072 | 0 3073 | 1 3074 | 1 3075 | 0 3076 | 0 3077 | 0 3078 | 1 3079 | 1 3080 | 0 3081 | 1 3082 | 1 3083 | 1 3084 | 0 3085 | 1 3086 | 0 3087 | 1 3088 | 0 3089 | 0 3090 | 0 3091 | 1 3092 | 0 3093 | 1 3094 | 1 3095 | 0 3096 | 1 3097 | 0 3098 | 1 
3099 | 1 3100 | 0 3101 | 0 3102 | 1 3103 | 1 3104 | 1 3105 | 0 3106 | 1 3107 | 1 3108 | 0 3109 | 1 3110 | 0 3111 | 0 3112 | 1 3113 | 0 3114 | 1 3115 | 0 3116 | 1 3117 | 0 3118 | 0 3119 | 0 3120 | 0 3121 | 1 3122 | 1 3123 | 0 3124 | 0 3125 | 0 3126 | 1 3127 | 0 3128 | 1 3129 | 1 3130 | 1 3131 | 1 3132 | 1 3133 | 1 3134 | 0 3135 | 1 3136 | 0 3137 | 0 3138 | 1 3139 | 0 3140 | 0 3141 | 0 3142 | 0 3143 | 1 3144 | 1 3145 | 1 3146 | 1 3147 | 0 3148 | 0 3149 | 1 3150 | 1 3151 | 1 3152 | 0 3153 | 0 3154 | 0 3155 | 0 3156 | 1 3157 | 0 3158 | 1 3159 | 1 3160 | 1 3161 | 0 3162 | 1 3163 | 0 3164 | 0 3165 | 1 3166 | 0 3167 | 0 3168 | 1 3169 | 1 3170 | 0 3171 | 1 3172 | 0 3173 | 0 3174 | 1 3175 | 0 3176 | 1 3177 | 0 3178 | 0 3179 | 1 3180 | 1 3181 | 1 3182 | 0 3183 | 1 3184 | 0 3185 | 1 3186 | 1 3187 | 1 3188 | 0 3189 | 0 3190 | 0 3191 | 1 3192 | 0 3193 | 1 3194 | 0 3195 | 1 3196 | 0 3197 | 1 3198 | 0 3199 | 0 3200 | 0 3201 | 1 3202 | 0 3203 | 0 3204 | 1 3205 | 1 3206 | 1 3207 | 0 3208 | 1 3209 | 0 3210 | 0 3211 | 1 3212 | 0 3213 | 0 3214 | 1 3215 | 1 3216 | 0 3217 | 1 3218 | 0 3219 | 0 3220 | 1 3221 | 0 3222 | 1 3223 | 0 3224 | 1 3225 | 1 3226 | 0 3227 | 0 3228 | 1 3229 | 0 3230 | 1 3231 | 1 3232 | 1 3233 | 0 3234 | 0 3235 | 0 3236 | 0 3237 | 1 3238 | 0 3239 | 1 3240 | 0 3241 | 1 3242 | 1 3243 | 0 3244 | 0 3245 | 1 3246 | 0 3247 | 1 3248 | 0 3249 | 0 3250 | 1 3251 | 0 3252 | 0 3253 | 1 3254 | 0 3255 | 1 3256 | 1 3257 | 0 3258 | 1 3259 | 0 3260 | 1 3261 | 0 3262 | 1 3263 | 0 3264 | 0 3265 | 1 3266 | 1 3267 | 0 3268 | 1 3269 | 1 3270 | 0 3271 | 0 3272 | 1 3273 | 0 3274 | 0 3275 | 1 3276 | 0 3277 | 1 3278 | 1 3279 | 1 3280 | 0 3281 | 1 3282 | 0 3283 | 0 3284 | 1 3285 | 0 3286 | 0 3287 | 1 3288 | 1 3289 | 0 3290 | 0 3291 | 1 3292 | 0 3293 | 1 3294 | 0 3295 | 0 3296 | 0 3297 | 1 3298 | 1 3299 | 1 3300 | 1 3301 | 0 3302 | 1 3303 | 0 3304 | 0 3305 | 1 3306 | 1 3307 | 0 3308 | 1 3309 | 0 3310 | 1 3311 | 0 3312 | 1 3313 | 0 3314 | 0 3315 | 1 3316 | 1 3317 | 1 3318 | 1 3319 | 1 3320 | 0 
3321 | 1 3322 | 1 3323 | 1 3324 | 1 3325 | 1 3326 | 0 3327 | 0 3328 | 1 3329 | 0 3330 | 0 3331 | 0 3332 | 1 3333 | 0 3334 | 1 3335 | 1 3336 | 0 3337 | 0 3338 | 1 3339 | 1 3340 | 1 3341 | 1 3342 | 0 3343 | 1 3344 | 0 3345 | 1 3346 | 0 3347 | 0 3348 | 0 3349 | 1 3350 | 1 3351 | 0 3352 | 0 3353 | 1 3354 | 0 3355 | 0 3356 | 0 3357 | 1 3358 | 0 3359 | 0 3360 | 1 3361 | 1 3362 | 1 3363 | 0 3364 | 0 3365 | 0 3366 | 1 3367 | 0 3368 | 1 3369 | 0 3370 | 0 3371 | 1 3372 | 1 3373 | 1 3374 | 1 3375 | 0 3376 | 1 3377 | 0 3378 | 0 3379 | 1 3380 | 0 3381 | 1 3382 | 1 3383 | 0 3384 | 0 3385 | 1 3386 | 1 3387 | 1 3388 | 0 3389 | 0 3390 | 0 3391 | 1 3392 | 0 3393 | 0 3394 | 1 3395 | 0 3396 | 0 3397 | 1 3398 | 1 3399 | 0 3400 | 1 3401 | 1 3402 | 0 3403 | 1 3404 | 0 3405 | 0 3406 | 1 3407 | 1 3408 | 0 3409 | 0 3410 | 1 3411 | 0 3412 | 0 3413 | 1 3414 | 0 3415 | 1 3416 | 0 3417 | 0 3418 | 1 3419 | 1 3420 | 1 3421 | 1 3422 | 0 3423 | 0 3424 | 0 3425 | 0 3426 | 1 3427 | 1 3428 | 0 3429 | 1 3430 | 1 3431 | 1 3432 | 0 3433 | 1 3434 | 0 3435 | 1 3436 | 0 3437 | 0 3438 | 1 3439 | 1 3440 | 0 3441 | 1 3442 | 1 3443 | 0 3444 | 0 3445 | 1 3446 | 1 3447 | 0 3448 | 1 3449 | 0 3450 | 1 3451 | 1 3452 | 0 3453 | 0 3454 | 1 3455 | 0 3456 | 0 3457 | 1 3458 | 0 3459 | 1 3460 | 1 3461 | 0 3462 | 0 3463 | 0 3464 | 1 3465 | 1 3466 | 0 3467 | 1 3468 | 1 3469 | 0 3470 | 0 3471 | 1 3472 | 1 3473 | 1 3474 | 1 3475 | 0 3476 | 0 3477 | 0 3478 | 0 3479 | 1 3480 | 1 3481 | 0 3482 | 0 3483 | 1 3484 | 0 3485 | 0 3486 | 1 3487 | 0 3488 | 1 3489 | 1 3490 | 0 3491 | 1 3492 | 0 3493 | 0 3494 | 0 3495 | 1 3496 | 0 3497 | 1 3498 | 1 3499 | 1 3500 | 0 3501 | 0 3502 | 0 3503 | 1 3504 | 1 3505 | 0 3506 | 0 3507 | 1 3508 | 0 3509 | 0 3510 | 1 3511 | 1 3512 | 0 3513 | 1 3514 | 1 3515 | 0 3516 | 0 3517 | 0 3518 | 1 3519 | 1 3520 | 1 3521 | 0 3522 | 0 3523 | 1 3524 | 0 3525 | 0 3526 | 1 3527 | 0 3528 | 0 3529 | 1 3530 | 1 3531 | 0 3532 | 1 3533 | 1 3534 | 0 3535 | 1 3536 | 0 3537 | 1 3538 | 0 3539 | 0 3540 | 1 3541 | 1 3542 | 0 
3543 | 1 3544 | 0 3545 | 0 3546 | 0 3547 | 1 3548 | 1 3549 | 1 3550 | 1 3551 | 0 3552 | 0 3553 | 0 3554 | 0 3555 | 0 3556 | 1 3557 | 1 3558 | 0 3559 | 1 3560 | 0 3561 | 1 3562 | 0 3563 | 1 3564 | 0 3565 | 1 3566 | 1 3567 | 1 3568 | 0 3569 | 1 3570 | 0 3571 | 1 3572 | 0 3573 | 0 3574 | 1 3575 | 1 3576 | 0 3577 | 1 3578 | 1 3579 | 0 3580 | 1 3581 | 0 3582 | 0 3583 | 1 3584 | 1 3585 | 0 3586 | 1 3587 | 0 3588 | 1 3589 | 1 3590 | 0 3591 | 1 3592 | 1 3593 | 0 3594 | 0 3595 | 0 3596 | 1 3597 | 0 3598 | 0 3599 | 1 3600 | 1 3601 | 1 3602 | 0 3603 | 1 3604 | 1 3605 | 0 3606 | 0 3607 | 1 3608 | 0 3609 | 1 3610 | 1 3611 | 0 3612 | 1 3613 | 0 3614 | 0 3615 | 0 3616 | 0 3617 | 1 3618 | 1 3619 | 0 3620 | 0 3621 | 1 3622 | 0 3623 | 0 3624 | 1 3625 | 1 3626 | 0 3627 | 1 3628 | 0 3629 | 0 3630 | 0 3631 | 1 3632 | 1 3633 | 0 3634 | 0 3635 | 1 3636 | 0 3637 | 1 3638 | 1 3639 | 0 3640 | 0 3641 | 1 3642 | 0 3643 | 0 3644 | 0 3645 | 0 3646 | 1 3647 | 1 3648 | 1 3649 | 0 3650 | 0 3651 | 1 3652 | 1 3653 | 0 3654 | 1 3655 | 0 3656 | 0 3657 | 1 3658 | 1 3659 | 0 3660 | 0 3661 | 1 3662 | 1 3663 | 0 3664 | 1 3665 | 0 3666 | 0 3667 | 1 3668 | 0 3669 | 1 3670 | 1 3671 | 0 3672 | 1 3673 | 0 3674 | 0 3675 | 0 3676 | 1 3677 | 1 3678 | 1 3679 | 1 3680 | 0 3681 | 1 3682 | 0 3683 | 1 3684 | 0 3685 | 1 3686 | 1 3687 | 0 3688 | 1 3689 | 0 3690 | 1 3691 | 0 3692 | 1 3693 | 1 3694 | 1 3695 | 0 3696 | 0 3697 | 0 3698 | 0 3699 | 1 3700 | 1 3701 | 1 3702 | 1 3703 | 0 3704 | 0 3705 | 1 3706 | 0 3707 | 0 3708 | 1 3709 | 1 3710 | 1 3711 | 1 3712 | 0 3713 | 1 3714 | 1 3715 | 1 3716 | 0 3717 | 0 3718 | 1 3719 | 1 3720 | 0 3721 | 0 3722 | 1 3723 | 0 3724 | 1 3725 | 1 3726 | 0 3727 | 1 3728 | 0 3729 | 1 3730 | 1 3731 | 0 3732 | 0 3733 | 1 3734 | 1 3735 | 0 3736 | 1 3737 | 0 3738 | 1 3739 | 0 3740 | 1 3741 | 1 3742 | 1 3743 | 0 3744 | 0 3745 | 1 3746 | 1 3747 | 1 3748 | 1 3749 | 1 3750 | 0 3751 | 0 3752 | 0 3753 | 1 3754 | 1 3755 | 0 3756 | 1 3757 | 0 3758 | 1 3759 | 0 3760 | 0 3761 | 1 3762 | 1 3763 | 0 3764 | 1 
3765 | 0 3766 | 0 3767 | 0 3768 | 1 3769 | 1 3770 | 0 3771 | 1 3772 | 0 3773 | 0 3774 | 1 3775 | 1 3776 | 1 3777 | 0 3778 | 0 3779 | 1 3780 | 0 3781 | 0 3782 | 1 3783 | 0 3784 | 1 3785 | 0 3786 | 1 3787 | 1 3788 | 0 3789 | 0 3790 | 1 3791 | 1 3792 | 0 3793 | 0 3794 | 0 3795 | 0 3796 | 1 3797 | 1 3798 | 1 3799 | 0 3800 | 1 3801 | 0 3802 | 0 3803 | 1 3804 | 1 3805 | 1 3806 | 0 3807 | 0 3808 | 1 3809 | 0 3810 | 1 3811 | 0 3812 | 1 3813 | 0 3814 | 1 3815 | 0 3816 | 1 3817 | 0 3818 | 0 3819 | 1 3820 | 0 3821 | 0 3822 | 1 3823 | 0 3824 | 0 3825 | 0 3826 | 0 3827 | 1 3828 | 1 3829 | 0 3830 | 0 3831 | 1 3832 | 1 3833 | 1 3834 | 1 3835 | 0 3836 | 1 3837 | 0 3838 | 1 3839 | 0 3840 | 0 3841 | 1 3842 | 1 3843 | 1 3844 | 0 3845 | 0 3846 | 1 3847 | 0 3848 | 1 3849 | 0 3850 | 0 3851 | 0 3852 | 1 3853 | 1 3854 | 0 3855 | 1 3856 | 0 3857 | 0 3858 | 1 3859 | 1 3860 | 0 3861 | 1 3862 | 0 3863 | 0 3864 | 1 3865 | 0 3866 | 1 3867 | 1 3868 | 0 3869 | 0 3870 | 0 3871 | 1 3872 | 1 3873 | 0 3874 | 1 3875 | 0 3876 | 1 3877 | 1 3878 | 0 3879 | 1 3880 | 1 3881 | 1 3882 | 0 3883 | 0 3884 | 1 3885 | 0 3886 | 1 3887 | 1 3888 | 0 3889 | 1 3890 | 0 3891 | 0 3892 | 0 3893 | 1 3894 | 0 3895 | 1 3896 | 1 3897 | 0 3898 | 1 3899 | 0 3900 | 0 3901 | 1 3902 | 1 3903 | 0 3904 | 1 3905 | 0 3906 | 1 3907 | 0 3908 | 1 3909 | 1 3910 | 0 3911 | 1 3912 | 0 3913 | 1 3914 | 0 3915 | 1 3916 | 0 3917 | 1 3918 | 0 3919 | 1 3920 | 0 3921 | 0 3922 | 0 3923 | 1 3924 | 1 3925 | 0 3926 | 1 3927 | 0 3928 | 1 3929 | 0 3930 | 1 3931 | 1 3932 | 0 3933 | 1 3934 | 1 3935 | 0 3936 | 0 3937 | 0 3938 | 1 3939 | 1 3940 | 0 3941 | 1 3942 | 0 3943 | 0 3944 | 0 3945 | 1 3946 | 0 3947 | 1 3948 | 0 3949 | 1 3950 | 1 3951 | 0 3952 | 0 3953 | 1 3954 | 1 3955 | 1 3956 | 0 3957 | 1 3958 | 0 3959 | 1 3960 | 1 3961 | 1 3962 | 0 3963 | 0 3964 | 0 3965 | 1 3966 | 0 3967 | 1 3968 | 1 3969 | 0 3970 | 0 3971 | 1 3972 | 0 3973 | 1 3974 | 1 3975 | 0 3976 | 1 3977 | 0 3978 | 1 3979 | 0 3980 | 0 3981 | 1 3982 | 0 3983 | 1 3984 | 0 3985 | 1 3986 | 1 
3987 | 0 3988 | 1 3989 | 1 3990 | 0 3991 | 0 3992 | 1 3993 | 1 3994 | 0 3995 | 1 3996 | 0 3997 | 0 3998 | --------------------------------------------------------------------------------