├── ER.py ├── README.md ├── __init__.py ├── common.py ├── discriminator.py ├── driver.py ├── environment.py ├── expert_trajectories └── hopper_er.bin ├── forward_model.py ├── main.py ├── mgail.py ├── ops.py └── policy.py /ER.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | 5 | class ER(object): 6 | 7 | def __init__(self, memory_size, state_dim, action_dim, reward_dim, qpos_dim, qvel_dim, batch_size, history_length=1): 8 | self.memory_size = memory_size 9 | self.actions = np.random.normal(scale=0.35, size=(self.memory_size, action_dim)) 10 | self.rewards = np.random.normal(scale=0.35, size=(self.memory_size, )) 11 | self.states = np.random.normal(scale=0.35, size=(self.memory_size, state_dim)) 12 | self.qpos = np.random.normal(scale=0.35, size=(self.memory_size, qpos_dim)) 13 | self.qvel = np.random.normal(scale=0.35, size=(self.memory_size, qvel_dim)) 14 | self.terminals = np.zeros(self.memory_size, dtype=np.float32) 15 | self.batch_size = batch_size 16 | self.history_length = history_length 17 | self.count = 0 18 | self.current = 0 19 | self.state_dim = state_dim 20 | self.action_dim = action_dim 21 | 22 | # pre-allocate prestates and poststates for minibatch 23 | self.prestates = np.empty((self.batch_size, self.history_length, state_dim), dtype=np.float32) 24 | self.poststates = np.empty((self.batch_size, self.history_length, state_dim), dtype=np.float32) 25 | self.traj_length = 2 26 | self.traj_states = np.empty((self.batch_size, self.traj_length, state_dim), dtype=np.float32) 27 | self.traj_actions = np.empty((self.batch_size, self.traj_length-1, action_dim), dtype=np.float32) 28 | 29 | def add(self, actions, rewards, next_states, terminals, qposs=[], qvels = []): 30 | # state is post-state, after action and reward 31 | for idx in range(len(actions)): 32 | self.actions[self.current, ...] = actions[idx] 33 | self.rewards[self.current] = rewards[idx] 34 | self.states[self.current, ...] = next_states[idx] 35 | self.terminals[self.current] = terminals[idx] 36 | if len(qposs) == len(actions): 37 | self.qpos[self.current, ...] = qposs[idx] 38 | self.qvel[self.current, ...] = qvels[idx] 39 | self.count = max(self.count, self.current + 1) 40 | self.current = (self.current + 1) % self.memory_size 41 | 42 | def get_state(self, index): 43 | assert self.count > 0, "replay memory is empy" 44 | # normalize index to expected range, allows negative indexes 45 | index = index % self.count 46 | # if is not in the beginning of matrix 47 | if index >= self.history_length - 1: 48 | # use faster slicing 49 | return self.states[(index - (self.history_length - 1)):(index + 1), ...] 50 | else: 51 | # otherwise normalize indexes and use slower list based access 52 | indexes = [(index - i) % self.count for i in reversed(range(self.history_length))] 53 | return self.states[indexes, ...] 54 | 55 | def sample(self): 56 | # memory must include poststate, prestate and history 57 | assert self.count > self.history_length 58 | # sample random indexes 59 | indexes = [] 60 | while len(indexes) < self.batch_size: 61 | # find random index 62 | while True: 63 | # sample one index (ignore states wraping over 64 | index = random.randint(self.history_length, self.count - 1) 65 | # if wraps over current pointer, then get new one 66 | if index >= self.current > index - self.history_length: 67 | continue 68 | # if wraps over episode end, then get new one 69 | # poststate (last screen) can be terminal state! 
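# (only the pre-state frames [index - history_length, index) need to be terminal-free;
# indexes whose history window straddles the write pointer self.current were already
# rejected by the check above)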
70 | if self.terminals[(index - self.history_length):index].any(): 71 | continue 72 | # otherwise use this index 73 | break 74 | 75 | # having index first is fastest in C-order matrices 76 | self.prestates[len(indexes), ...] = self.get_state(index - 1) 77 | self.poststates[len(indexes), ...] = self.get_state(index) 78 | indexes.append(index) 79 | 80 | actions = self.actions[indexes, ...] 81 | rewards = self.rewards[indexes, ...] 82 | if hasattr(self, 'qpos'): 83 | qpos = self.qpos[indexes, ...] 84 | qvels = self.qvel[indexes, ...] 85 | else: 86 | qpos = [] 87 | qvels = [] 88 | terminals = self.terminals[indexes] 89 | 90 | return np.squeeze(self.prestates, axis=1), actions, rewards, \ 91 | np.squeeze(self.poststates, axis=1), terminals, qpos, qvels -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Model-Based Generative Adversarial Imitation Learning 2 | 3 | Code for ICML 2017 paper "End-to-End Differentiable Adversarial Imitation Learning", by Nir Baram, Oron Anschel, Itai Caspi, Shie Mannor. 4 | 5 | ## Dependencies 6 | * Gym >= 0.8.1 7 | * Mujoco-py >= 0.5.7 8 | * Tensorflow >= 1.0.1 9 | 10 | ## Running 11 | Run the following command to train the Mujoco Hopper environment by imitating an expert trained with TRPO 12 | 13 | ```python 14 | python main.py 15 | ``` -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/itaicaspi/mgail/b3b91aa5e0bd47923f726a27522f45146721940d/__init__.py -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | import cPickle 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | 6 | def save_params(fname, saver, session): 7 | saver.save(session, fname) 8 | 9 | 10 | def load_er(fname, batch_size, history_length, traj_length): 11 | f = file(fname, 'rb') 12 | er = cPickle.load(f) 13 | er.batch_size = batch_size 14 | er = set_er_stats(er, history_length, traj_length) 15 | return er 16 | 17 | 18 | def set_er_stats(er, history_length, traj_length): 19 | state_dim = er.states.shape[-1] 20 | action_dim = er.actions.shape[-1] 21 | er.prestates = np.empty((er.batch_size, history_length, state_dim), dtype=np.float32) 22 | er.poststates = np.empty((er.batch_size, history_length, state_dim), dtype=np.float32) 23 | er.traj_states = np.empty((er.batch_size, traj_length, state_dim), dtype=np.float32) 24 | er.traj_actions = np.empty((er.batch_size, traj_length-1, action_dim), dtype=np.float32) 25 | er.states_min = np.min(er.states[:er.count], axis=0) 26 | er.states_max = np.max(er.states[:er.count], axis=0) 27 | er.actions_min = np.min(er.actions[:er.count], axis=0) 28 | er.actions_max = np.max(er.actions[:er.count], axis=0) 29 | er.states_mean = np.mean(er.states[:er.count], axis=0) 30 | er.actions_mean = np.mean(er.actions[:er.count], axis=0) 31 | er.states_std = np.std(er.states[:er.count], axis=0) 32 | er.states_std[er.states_std == 0] = 1 33 | er.actions_std = np.std(er.actions[:er.count], axis=0) 34 | return er 35 | 36 | 37 | def re_parametrization(state_e, state_a): 38 | nu = state_e - state_a 39 | nu = tf.stop_gradient(nu) 40 | return state_a + nu, nu 41 | 42 | 43 | def normalize(x, mean, std): 44 | return (x - mean)/std 45 | 46 | 47 | def denormalize(x, mean, std): 
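# Usage sketch for load_er above (hypothetical standalone call; Python 2 is assumed,
# since load_er relies on cPickle and the file() builtin). mgail.py loads the bundled
# expert buffer this way:
#
#     er_expert = load_er(fname='expert_trajectories/hopper_er.bin',
#                         batch_size=70, history_length=1, traj_length=2)
#     states, actions = er_expert.sample()[:2]   # one minibatch of expert transitions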
48 | return x * std + mean 49 | 50 | 51 | def sample_gumbel(shape, eps=1e-20): 52 | """Sample from Gumbel(0, 1)""" 53 | U = tf.random_uniform(shape,minval=0,maxval=1) 54 | return -tf.log(-tf.log(U + eps) + eps) 55 | 56 | 57 | def gumbel_softmax_sample(logits, temperature): 58 | """ Draw a sample from the Gumbel-Softmax distribution""" 59 | y = logits + sample_gumbel(tf.shape(logits)) 60 | return tf.nn.softmax(y / temperature) 61 | 62 | 63 | def gumbel_softmax(logits, temperature, hard=True): 64 | """Sample from the Gumbel-Softmax distribution and optionally discretize. 65 | Args: 66 | logits: [batch_size, n_class] unnormalized log-probs 67 | temperature: non-negative scalar 68 | hard: if True, take argmax, but differentiate w.r.t. soft sample y 69 | Returns: 70 | [batch_size, n_class] sample from the Gumbel-Softmax distribution. 71 | If hard=True, then the returned sample will be one-hot, otherwise it will 72 | be a probabilitiy distribution that sums to 1 across classes 73 | """ 74 | y = gumbel_softmax_sample(logits, temperature) 75 | if hard: 76 | k = tf.shape(logits)[-1] 77 | #y_hard = tf.cast(tf.one_hot(tf.argmax(y,1),k), y.dtype) 78 | y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keep_dims=True)), y.dtype) 79 | y = tf.stop_gradient(y_hard - y) + y 80 | return y 81 | -------------------------------------------------------------------------------- /discriminator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import ops 3 | 4 | 5 | class Discriminator(object): 6 | def __init__(self, in_dim, out_dim, size, lr, do_keep_prob, weight_decay): 7 | self.arch_params = { 8 | 'in_dim': in_dim, 9 | 'out_dim': out_dim, 10 | 'n_hidden_0': size[0], 11 | 'n_hidden_1': size[1], 12 | 'do_keep_prob': do_keep_prob 13 | } 14 | 15 | self.solver_params = { 16 | 'lr': lr, 17 | 'weight_decay': weight_decay 18 | } 19 | 20 | def forward(self, state, action, reuse=False): 21 | 22 | with tf.variable_scope('discriminator'): 23 | concat = tf.concat(axis=1, values=[state, action]) 24 | h0 = ops.dense(concat, self.arch_params['in_dim'], self.arch_params['n_hidden_0'], tf.nn.relu, 'dense0', reuse) 25 | h1 = ops.dense(h0, self.arch_params['n_hidden_0'], self.arch_params['n_hidden_1'], tf.nn.relu, 'dense1', reuse) 26 | relu1_do = tf.nn.dropout(h1, self.arch_params['do_keep_prob']) 27 | d = ops.dense(relu1_do, self.arch_params['n_hidden_1'], self.arch_params['out_dim'], None, 'dense2', reuse) 28 | 29 | return d 30 | 31 | def backward(self, loss): 32 | self.weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator') 33 | 34 | # create an optimizer 35 | opt = tf.train.AdamOptimizer(learning_rate=self.solver_params['lr']) 36 | 37 | # weight decay 38 | loss += self.solver_params['weight_decay'] * tf.add_n([tf.nn.l2_loss(w) for w in self.weights if 'weights' in w.name]) 39 | 40 | # compute the gradients for a list of variables 41 | grads_and_vars = opt.compute_gradients(loss=loss, var_list=self.weights) 42 | 43 | # apply the gradient 44 | apply_grads = opt.apply_gradients(grads_and_vars) 45 | 46 | return apply_grads 47 | 48 | def train(self, objective): 49 | self.loss = objective 50 | self.minimize = self.backward(self.loss) 51 | -------------------------------------------------------------------------------- /driver.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import numpy as np 4 | import tensorflow as tf 5 | import common 6 | from mgail import MGAIL 7 | 8 
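# A minimal, self-contained usage sketch for common.gumbel_softmax above (TensorFlow 1.x
# assumed, matching the rest of the repo): the sample is one-hot in the forward pass but
# keeps the soft softmax gradient, which is how mgail.py draws discrete actions.
import tensorflow as tf
import common

logits = tf.constant([[2.0, 0.5, -1.0]])               # unnormalized log-probs, shape [1, 3]
action = common.gumbel_softmax(logits, temperature=1.0, hard=True)
grad = tf.gradients(tf.reduce_max(action), logits)[0]  # non-zero: gradient passes through the soft sample

with tf.Session() as sess:
    print(sess.run([action, grad]))                    # e.g. [[1., 0., 0.]] and a dense gradient row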
| 9 | class Driver(object): 10 | def __init__(self, environment): 11 | 12 | self.env = environment 13 | self.algorithm = MGAIL(environment=self.env) 14 | self.init_graph = tf.global_variables_initializer() 15 | self.saver = tf.train.Saver() 16 | self.sess = tf.Session() 17 | if self.env.trained_model: 18 | self.saver.restore(self.sess, self.env.trained_model) 19 | else: 20 | self.sess.run(self.init_graph) 21 | self.run_dir = self.env.run_dir 22 | self.loss = 999. * np.ones(3) 23 | self.reward_mean = 0 24 | self.reward_std = 0 25 | self.run_avg = 0.001 26 | self.discriminator_policy_switch = 0 27 | self.policy_loop_time = 0 28 | self.disc_acc = 0 29 | self.er_count = 0 30 | self.itr = 0 31 | self.best_reward = 0 32 | self.mode = 'Prep' 33 | np.set_printoptions(precision=2) 34 | np.set_printoptions(linewidth=220) 35 | 36 | def update_stats(self, module, attr, value): 37 | v = {'forward_model': 0, 'discriminator': 1, 'policy': 2} 38 | module_ind = v[module] 39 | if attr == 'loss': 40 | self.loss[module_ind] = self.run_avg * self.loss[module_ind] + (1 - self.run_avg) * np.asarray(value) 41 | elif attr == 'accuracy': 42 | self.disc_acc = self.run_avg * self.disc_acc + (1 - self.run_avg) * np.asarray(value) 43 | 44 | def train_forward_model(self): 45 | alg = self.algorithm 46 | states_, actions, _, states = self.algorithm.er_agent.sample()[:4] 47 | fetches = [alg.forward_model.minimize, alg.forward_model.loss] 48 | feed_dict = {alg.states_: states_, alg.states: states, alg.actions: actions, 49 | alg.do_keep_prob: self.env.do_keep_prob} 50 | run_vals = self.sess.run(fetches, feed_dict) 51 | self.update_stats('forward_model', 'loss', run_vals[1]) 52 | 53 | def train_discriminator(self): 54 | alg = self.algorithm 55 | # get states and actions 56 | state_a_, action_a = self.algorithm.er_agent.sample()[:2] 57 | state_e_, action_e = self.algorithm.er_expert.sample()[:2] 58 | states = np.concatenate([state_a_, state_e_]) 59 | actions = np.concatenate([action_a, action_e]) 60 | # labels (policy/expert) : 0/1, and in 1-hot form: policy-[1,0], expert-[0,1] 61 | labels_a = np.zeros(shape=(state_a_.shape[0],)) 62 | labels_e = np.ones(shape=(state_e_.shape[0],)) 63 | labels = np.expand_dims(np.concatenate([labels_a, labels_e]), axis=1) 64 | fetches = [alg.discriminator.minimize, alg.discriminator.loss, alg.discriminator.acc] 65 | feed_dict = {alg.states: states, alg.actions: actions, 66 | alg.label: labels, alg.do_keep_prob: self.env.do_keep_prob} 67 | run_vals = self.sess.run(fetches, feed_dict) 68 | self.update_stats('discriminator', 'loss', run_vals[1]) 69 | self.update_stats('discriminator', 'accuracy', run_vals[2]) 70 | 71 | def train_policy(self): 72 | alg = self.algorithm 73 | 74 | # reset the policy gradient 75 | self.sess.run([alg.policy.reset_grad_op], {}) 76 | 77 | # Adversarial Learning 78 | if self.env.get_status(): 79 | state = self.env.reset() 80 | else: 81 | state = self.env.get_state() 82 | 83 | # Accumulate the (noisy) adversarial gradient 84 | for i in range(self.env.policy_accum_steps): 85 | # accumulate AL gradient 86 | fetches = [alg.policy.accum_grads_al, alg.policy.loss_al] 87 | feed_dict = {alg.states: np.array([state]), alg.gamma: self.env.gamma, 88 | alg.do_keep_prob: self.env.do_keep_prob, alg.noise: 1., alg.temp: self.env.temp} 89 | run_vals = self.sess.run(fetches, feed_dict) 90 | self.update_stats('policy', 'loss', run_vals[1]) 91 | 92 | # apply AL gradient 93 | self.sess.run([alg.policy.apply_grads_al], {}) 94 | 95 | def collect_experience(self, record=1, vis=0, n_steps=None, 
noise_flag=True, start_at_zero=True): 96 | alg = self.algorithm 97 | 98 | # environment initialization point 99 | if start_at_zero: 100 | observation = self.env.reset() 101 | else: 102 | qposs, qvels = alg.er_expert.sample()[5:] 103 | observation = self.env.reset(qpos=qposs[0], qvel=qvels[0]) 104 | 105 | do_keep_prob = self.env.do_keep_prob 106 | t = 0 107 | R = 0 108 | done = 0 109 | if n_steps is None: 110 | n_steps = self.env.n_steps_test 111 | 112 | while not done: 113 | if vis: 114 | self.env.render() 115 | 116 | if not noise_flag: 117 | do_keep_prob = 1. 118 | 119 | a = self.sess.run(fetches=[alg.action_test], feed_dict={alg.states: np.reshape(observation, [1, -1]), 120 | alg.do_keep_prob: do_keep_prob, 121 | alg.noise: noise_flag, 122 | alg.temp: self.env.temp}) 123 | 124 | observation, reward, done, info, qpos, qvel = self.env.step(a, mode='python') 125 | 126 | done = done or t > n_steps 127 | t += 1 128 | R += reward 129 | 130 | if record: 131 | if self.env.continuous_actions: 132 | action = a 133 | else: 134 | action = np.zeros((1, self.env.action_size)) 135 | action[0, a[0]] = 1 136 | alg.er_agent.add(actions=action, rewards=[reward], next_states=[observation], terminals=[done], 137 | qposs=[qpos], qvels=[qvel]) 138 | 139 | return R 140 | 141 | def train_step(self): 142 | # phase_1 - Adversarial training 143 | # forward_model: learning from agent data 144 | # discriminator: learning in an interleaved mode with policy 145 | # policy: learning in adversarial mode 146 | 147 | # Fill Experience Buffer 148 | if self.itr == 0: 149 | while self.algorithm.er_agent.current == self.algorithm.er_agent.count: 150 | self.collect_experience() 151 | buf = 'Collecting examples...%d/%d' % \ 152 | (self.algorithm.er_agent.current, self.algorithm.er_agent.states.shape[0]) 153 | sys.stdout.write('\r' + buf) 154 | 155 | # Adversarial Learning 156 | else: 157 | self.train_forward_model() 158 | 159 | self.mode = 'Prep' 160 | if self.itr < self.env.prep_time: 161 | self.train_discriminator() 162 | else: 163 | self.mode = 'AL' 164 | 165 | if self.discriminator_policy_switch: 166 | self.train_discriminator() 167 | else: 168 | self.train_policy() 169 | 170 | if self.itr % self.env.collect_experience_interval == 0: 171 | self.collect_experience(start_at_zero=False, n_steps=self.env.n_steps_train) 172 | 173 | # switch discriminator-policy 174 | if self.itr % self.env.discr_policy_itrvl == 0: 175 | self.discriminator_policy_switch = not self.discriminator_policy_switch 176 | 177 | # print progress 178 | if self.itr % 100 == 0: 179 | self.print_info_line('slim') 180 | 181 | def print_info_line(self, mode): 182 | if mode == 'full': 183 | buf = '%s Training(%s): iter %d, loss: %s R: %.1f, R_std: %.2f\n' % \ 184 | (time.strftime("%H:%M:%S"), self.mode, self.itr, self.loss, self.reward_mean, self.reward_std) 185 | else: 186 | buf = "processing iter: %d, loss(forward_model,discriminator,policy): %s" % (self.itr, self.loss) 187 | sys.stdout.write('\r' + buf) 188 | 189 | def save_model(self, dir_name=None): 190 | import os 191 | if dir_name is None: 192 | dir_name = self.run_dir + '/snapshots/' 193 | if not os.path.isdir(dir_name): 194 | os.mkdir(dir_name) 195 | fname = dir_name + time.strftime("%Y-%m-%d-%H-%M-") + ('%0.6d.sn' % self.itr) 196 | common.save_params(fname=fname, saver=self.saver, session=self.sess) 197 | -------------------------------------------------------------------------------- /environment.py: -------------------------------------------------------------------------------- 1 | import tensorflow 
as tf 2 | import numpy as np 3 | import gym 4 | 5 | 6 | class Environment(object): 7 | def __init__(self, run_dir, env_name): 8 | self.name = env_name 9 | self.gym = gym.make(self.name) 10 | self.random_initialization = True 11 | self._connect() 12 | self._train_params() 13 | self.run_dir = run_dir 14 | 15 | def _step(self, action): 16 | action = np.squeeze(action) 17 | self.t += 1 18 | result = self.gym.step(action) 19 | self.state, self.reward, self.done, self.info = result[:4] 20 | if self.random_initialization: 21 | self.qpos, self.qvel = self.gym.env.model.data.qpos.flatten(), self.gym.env.model.data.qvel.flatten() 22 | return np.float32(self.state), np.float32(self.reward), self.done, np.float32(self.qpos), np.float32(self.qvel) 23 | else: 24 | return np.float32(self.state), np.float32(self.reward), self.done 25 | 26 | def step(self, action, mode): 27 | qvel, qpos = [], [] 28 | if mode == 'tensorflow': 29 | if self.random_initialization: 30 | state, reward, done, qval, qpos = tf.py_func(self._step, inp=[action], Tout=[tf.float32, tf.float32, tf.bool, tf.float32, tf.float32], name='env_step_func') 31 | else: 32 | state, reward, done = tf.py_func(self._step, inp=[action], 33 | Tout=[tf.float32, tf.float32, tf.bool], 34 | name='env_step_func') 35 | 36 | state = tf.reshape(state, shape=(self.state_size,)) 37 | done.set_shape(()) 38 | else: 39 | if self.random_initialization: 40 | state, reward, done, qvel, qpos = self._step(action) 41 | else: 42 | state, reward, done = self._step(action) 43 | 44 | return state, reward, done, 0., qvel, qpos 45 | 46 | def reset(self, qpos=None, qvel=None): 47 | self.t = 0 48 | self.state = self.gym.reset() 49 | if self.random_initialization and qpos is not None and qvel is not None: 50 | self.gym.env.set_state(qpos, qvel) 51 | return self.state 52 | 53 | def get_status(self): 54 | return self.done 55 | 56 | def get_state(self): 57 | return self.state 58 | 59 | def render(self): 60 | self.gym.render() 61 | 62 | def _connect(self): 63 | self.state_size = self.gym.observation_space.shape[0] 64 | self.action_size = self.gym.action_space.shape[0] 65 | self.action_space = np.asarray([None]*self.action_size) 66 | self.qpos_size = self.gym.env.data.qpos.shape[0] 67 | self.qvel_size = self.gym.env.data.qvel.shape[0] 68 | 69 | def _train_params(self): 70 | self.trained_model = None 71 | self.train_mode = True 72 | self.expert_data = 'expert_trajectories/hopper_er.bin' 73 | self.n_train_iters = 1000000 74 | self.n_episodes_test = 1 75 | self.test_interval = 1000 76 | self.n_steps_test = 1000 77 | self.vis_flag = True 78 | self.save_models = True 79 | self.config_dir = None 80 | self.continuous_actions = True 81 | 82 | # Main parameters to play with: 83 | self.er_agent_size = 50000 84 | self.prep_time = 1000 85 | self.collect_experience_interval = 15 86 | self.n_steps_train = 10 87 | self.discr_policy_itrvl = 100 88 | self.gamma = 0.99 89 | self.batch_size = 70 90 | self.weight_decay = 1e-7 91 | self.policy_al_w = 1e-2 92 | self.policy_tr_w = 1e-4 93 | self.policy_accum_steps = 7 94 | self.total_trans_err_allowed = 1000 95 | self.temp = 1. 96 | self.cost_sensitive_weight = 0.8 97 | self.noise_intensity = 6. 
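# (noise_intensity scales exploration: mgail.py sets sigma = er_expert.actions_std / noise_intensity)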
98 | self.do_keep_prob = 0.75 99 | 100 | # Hidden layers size 101 | self.fm_size = 100 102 | self.d_size = [200, 100] 103 | self.p_size = [100, 50] 104 | 105 | # Learning rates 106 | self.fm_lr = 1e-4 107 | self.d_lr = 1e-3 108 | self.p_lr = 1e-4 109 | 110 | 111 | 112 | -------------------------------------------------------------------------------- /forward_model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import ops 3 | import tensorflow as tf 4 | import common 5 | 6 | 7 | class ForwardModel(object): 8 | def __init__(self, state_size, action_size, encoding_size, lr): 9 | self.state_size = state_size 10 | self.action_size = action_size 11 | self.encoding_size = encoding_size 12 | 13 | self.lr = lr 14 | 15 | def forward(self, input, reuse=False): 16 | with tf.variable_scope('forward_model'): 17 | state = tf.cast(input[0], tf.float32) 18 | action = tf.cast(input[1], tf.float32) 19 | gru_state = tf.cast(input[2], tf.float32) 20 | 21 | # State embedding 22 | state_embedder1 = ops.dense(state, self.state_size, self.encoding_size, tf.nn.relu, "encoder1_state", reuse) 23 | gru_state = ops.gru(state_embedder1, gru_state, self.encoding_size, self.encoding_size, 'gru1', reuse) 24 | state_embedder2 = ops.dense(gru_state, self.encoding_size, self.encoding_size, tf.sigmoid, "encoder2_state", reuse) 25 | 26 | # Action embedding 27 | action_embedder1 = ops.dense(action, self.action_size, self.encoding_size, tf.nn.relu, "encoder1_action", reuse) 28 | action_embedder2 = ops.dense(action_embedder1, self.encoding_size, self.encoding_size, tf.sigmoid, "encoder2_action", reuse) 29 | 30 | # Joint embedding 31 | joint_embedding = tf.multiply(state_embedder2, action_embedder2) 32 | 33 | # Next state prediction 34 | hidden1 = ops.dense(joint_embedding, self.encoding_size, self.encoding_size, tf.nn.relu, "encoder3", reuse) 35 | hidden2 = ops.dense(hidden1, self.encoding_size, self.encoding_size, tf.nn.relu, "encoder4", reuse) 36 | hidden3 = ops.dense(hidden2, self.encoding_size, self.encoding_size, tf.nn.relu, "decoder1", reuse) 37 | next_state = ops.dense(hidden3, self.encoding_size, self.state_size, None, "decoder2", reuse) 38 | 39 | gru_state = tf.cast(gru_state, tf.float64) 40 | 41 | return next_state, gru_state 42 | 43 | def backward(self, loss): 44 | self.weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='forward_model') 45 | 46 | # create an optimizer 47 | opt = tf.train.AdamOptimizer(learning_rate=self.lr) 48 | 49 | # compute the gradients for a list of variables 50 | grads_and_vars = opt.compute_gradients(loss=loss, var_list=self.weights) 51 | 52 | # apply the gradient 53 | apply_grads = opt.apply_gradients(grads_and_vars) 54 | 55 | return apply_grads 56 | 57 | def train(self, objective): 58 | self.loss = objective 59 | self.minimize = self.backward(self.loss) 60 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from environment import Environment 4 | from driver import Driver 5 | 6 | 7 | def dispatcher(env): 8 | 9 | driver = Driver(env) 10 | 11 | while driver.itr < env.n_train_iters: 12 | 13 | # Train 14 | if env.train_mode: 15 | driver.train_step() 16 | 17 | # Test 18 | if driver.itr % env.test_interval == 0: 19 | 20 | # measure performance 21 | R = [] 22 | for n in range(env.n_episodes_test): 23 | 
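# roll out one full evaluation episode (no exploration noise) and record it into the agent buffer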
R.append(driver.collect_experience(record=True, vis=env.vis_flag, noise_flag=False, n_steps=1000)) 24 | 25 | # update stats 26 | driver.reward_mean = sum(R) / len(R) 27 | driver.reward_std = np.std(R) 28 | 29 | # print info line 30 | driver.print_info_line('full') 31 | 32 | # save snapshot 33 | if env.train_mode and env.save_models: 34 | driver.save_model(dir_name=env.config_dir) 35 | 36 | driver.itr += 1 37 | 38 | 39 | if __name__ == '__main__': 40 | # load environment 41 | env = Environment(os.path.curdir, 'Hopper-v1') 42 | 43 | # start training 44 | dispatcher(env=env) 45 | -------------------------------------------------------------------------------- /mgail.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | import os 5 | import common 6 | from ER import ER 7 | from forward_model import ForwardModel 8 | from discriminator import Discriminator 9 | from policy import Policy 10 | 11 | 12 | class MGAIL(object): 13 | def __init__(self, environment): 14 | 15 | self.env = environment 16 | 17 | # Create placeholders for all the inputs 18 | self.states_ = tf.placeholder("float", shape=(None, self.env.state_size), name='states_') # Batch x State 19 | self.states = tf.placeholder("float", shape=(None, self.env.state_size), name='states') # Batch x State 20 | self.actions = tf.placeholder("float", shape=(None, self.env.action_size), name='action') # Batch x Action 21 | self.label = tf.placeholder("float", shape=(None, 1), name='label') 22 | self.gamma = tf.placeholder("float", shape=(), name='gamma') 23 | self.temp = tf.placeholder("float", shape=(), name='temperature') 24 | self.noise = tf.placeholder("float", shape=(), name='noise_flag') 25 | self.do_keep_prob = tf.placeholder("float", shape=(), name='do_keep_prob') 26 | 27 | # Create MGAIL blocks 28 | self.forward_model = ForwardModel(state_size=self.env.state_size, 29 | action_size=self.env.action_size, 30 | encoding_size=self.env.fm_size, 31 | lr=self.env.fm_lr) 32 | 33 | self.discriminator = Discriminator(in_dim=self.env.state_size + self.env.action_size, 34 | out_dim=2, 35 | size=self.env.d_size, 36 | lr=self.env.d_lr, 37 | do_keep_prob=self.do_keep_prob, 38 | weight_decay=self.env.weight_decay) 39 | 40 | self.policy = Policy(in_dim=self.env.state_size, 41 | out_dim=self.env.action_size, 42 | size=self.env.p_size, 43 | lr=self.env.p_lr, 44 | do_keep_prob=self.do_keep_prob, 45 | n_accum_steps=self.env.policy_accum_steps, 46 | weight_decay=self.env.weight_decay) 47 | 48 | # Create experience buffers 49 | self.er_agent = ER(memory_size=self.env.er_agent_size, 50 | state_dim=self.env.state_size, 51 | action_dim=self.env.action_size, 52 | reward_dim=1, # stub connection 53 | qpos_dim=self.env.qpos_size, 54 | qvel_dim=self.env.qvel_size, 55 | batch_size=self.env.batch_size, 56 | history_length=1) 57 | 58 | self.er_expert = common.load_er(fname=os.path.join(self.env.run_dir, self.env.expert_data), 59 | batch_size=self.env.batch_size, 60 | history_length=1, 61 | traj_length=2) 62 | 63 | self.env.sigma = self.er_expert.actions_std / self.env.noise_intensity 64 | 65 | # Normalize the inputs 66 | states_ = common.normalize(self.states_, self.er_expert.states_mean, self.er_expert.states_std) 67 | states = common.normalize(self.states, self.er_expert.states_mean, self.er_expert.states_std) 68 | if self.env.continuous_actions: 69 | actions = common.normalize(self.actions, self.er_expert.actions_mean, self.er_expert.actions_std) 70 | else: 71 | actions = 
self.actions 72 | 73 | # 1. Forward Model 74 | initial_gru_state = np.ones((1, self.forward_model.encoding_size)) 75 | forward_model_prediction, _ = self.forward_model.forward([states_, actions, initial_gru_state]) 76 | forward_model_loss = tf.reduce_mean(tf.square(states-forward_model_prediction)) 77 | self.forward_model.train(objective=forward_model_loss) 78 | 79 | # 2. Discriminator 80 | labels = tf.concat([1 - self.label, self.label], 1) 81 | d = self.discriminator.forward(states, actions) 82 | 83 | # 2.1 0-1 accuracy 84 | correct_predictions = tf.equal(tf.argmax(d, 1), tf.argmax(labels, 1)) 85 | self.discriminator.acc = tf.reduce_mean(tf.cast(correct_predictions, "float")) 86 | # 2.2 prediction 87 | d_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=d, labels=labels) 88 | # cost sensitive weighting (weight true=expert, predict=agent mistakes) 89 | d_loss_weighted = self.env.cost_sensitive_weight * tf.multiply(tf.to_float(tf.equal(tf.squeeze(self.label), 1.)), d_cross_entropy) +\ 90 | tf.multiply(tf.to_float(tf.equal(tf.squeeze(self.label), 0.)), d_cross_entropy) 91 | discriminator_loss = tf.reduce_mean(d_loss_weighted) 92 | self.discriminator.train(objective=discriminator_loss) 93 | 94 | # 3. Collect experience 95 | mu = self.policy.forward(states) 96 | if self.env.continuous_actions: 97 | a = common.denormalize(mu, self.er_expert.actions_mean, self.er_expert.actions_std) 98 | eta = tf.random_normal(shape=tf.shape(a), stddev=self.env.sigma) 99 | self.action_test = tf.squeeze(a + self.noise * eta) 100 | else: 101 | a = common.gumbel_softmax(logits=mu, temperature=self.temp) 102 | self.action_test = tf.argmax(a, dimension=1) 103 | 104 | # 4.3 AL 105 | def policy_loop(state_, t, total_cost, total_trans_err, _): 106 | mu = self.policy.forward(state_, reuse=True) 107 | 108 | if self.env.continuous_actions: 109 | eta = self.env.sigma * tf.random_normal(shape=tf.shape(mu)) 110 | action = mu + eta 111 | else: 112 | action = common.gumbel_softmax_sample(logits=mu, temperature=self.temp) 113 | 114 | # minimize the gap between agent logit (d[:,0]) and expert logit (d[:,1]) 115 | d = self.discriminator.forward(state_, action, reuse=True) 116 | cost = self.al_loss(d) 117 | 118 | # add step cost 119 | total_cost += tf.multiply(tf.pow(self.gamma, t), cost) 120 | 121 | # get action 122 | if self.env.continuous_actions: 123 | a_sim = common.denormalize(action, self.er_expert.actions_mean, self.er_expert.actions_std) 124 | else: 125 | a_sim = tf.argmax(action, dimension=1) 126 | 127 | # get next state 128 | state_env, _, env_term_sig, = self.env.step(a_sim, mode='tensorflow')[:3] 129 | state_e = common.normalize(state_env, self.er_expert.states_mean, self.er_expert.states_std) 130 | state_e = tf.stop_gradient(state_e) 131 | 132 | state_a, _ = self.forward_model.forward([state_, action, initial_gru_state], reuse=True) 133 | 134 | state, nu = common.re_parametrization(state_e=state_e, state_a=state_a) 135 | total_trans_err += tf.reduce_mean(abs(nu)) 136 | t += 1 137 | 138 | return state, t, total_cost, total_trans_err, env_term_sig 139 | 140 | def policy_stop_condition(state_, t, cost, trans_err, env_term_sig): 141 | cond = tf.logical_not(env_term_sig) 142 | cond = tf.logical_and(cond, t < self.env.n_steps_train) 143 | cond = tf.logical_and(cond, trans_err < self.env.total_trans_err_allowed) 144 | return cond 145 | 146 | state_0 = tf.slice(states, [0, 0], [1, -1]) 147 | loop_outputs = tf.while_loop(policy_stop_condition, policy_loop, [state_0, 0., 0., 0., False]) 148 | 
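# loop_outputs unpacks as [state, t, total_cost, total_trans_err, env_term_sig];
# index 2 is the discounted adversarial cost accumulated along the unrolled trajectory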
self.policy.train(objective=loop_outputs[2]) 149 | 150 | def al_loss(self, d): 151 | logit_agent, logit_expert = tf.split(axis=1, num_or_size_splits=2, value=d) 152 | 153 | # Cross entropy loss 154 | labels = tf.concat([tf.zeros_like(logit_agent), tf.ones_like(logit_expert)], 1) 155 | d_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=d, labels=labels) 156 | loss = tf.reduce_mean(d_cross_entropy) 157 | 158 | return loss*self.env.policy_al_w 159 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def dense(input, input_size, output_size, activation, name, reuse=False): 5 | with tf.variable_scope(name, reuse=reuse, initializer=tf.random_normal_initializer(stddev=0.15)): 6 | weights = tf.get_variable('weights', [input_size, output_size]) 7 | biases = tf.get_variable('biases', [output_size]) 8 | output = tf.matmul(input, weights) + biases 9 | if activation: 10 | output = activation(output) 11 | return output 12 | 13 | 14 | def gru(input, hidden, input_size, hidden_size, name, reuse=False): 15 | with tf.variable_scope(name, reuse=reuse, initializer=tf.random_normal_initializer(stddev=0.15)): 16 | Wxr = tf.get_variable('weights_xr', [input_size, hidden_size]) 17 | Wxz = tf.get_variable('weights_xz', [input_size, hidden_size]) 18 | Wxh = tf.get_variable('weights_xh', [input_size, hidden_size]) 19 | Whr = tf.get_variable('weights_hr', [hidden_size, hidden_size]) 20 | Whz = tf.get_variable('weights_hz', [hidden_size, hidden_size]) 21 | Whh = tf.get_variable('weights_hh', [hidden_size, hidden_size]) 22 | br = tf.get_variable('biases_r', [1, hidden_size]) 23 | bz = tf.get_variable('biases_z', [1, hidden_size]) 24 | bh = tf.get_variable('biases_h', [1, hidden_size]) 25 | 26 | x, h_ = input, hidden 27 | r = tf.sigmoid(tf.matmul(x, Wxr) + tf.matmul(h_, Whr) + br) 28 | z = tf.sigmoid(tf.matmul(x, Wxz) + tf.matmul(h_, Whz) + bz) 29 | 30 | h_hat = tf.tanh(tf.matmul(x, Wxh) + tf.matmul(tf.multiply(r, h_), Whh) + bh) 31 | 32 | output = tf.multiply((1 - z), h_hat) + tf.multiply(z, h_) 33 | 34 | return output 35 | -------------------------------------------------------------------------------- /policy.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import ops 3 | 4 | 5 | class Policy(object): 6 | def __init__(self, in_dim, out_dim, size, lr, do_keep_prob, n_accum_steps, weight_decay): 7 | 8 | self.arch_params = { 9 | 'in_dim': in_dim, 10 | 'out_dim': out_dim, 11 | 'n_hidden_0': size[0], 12 | 'n_hidden_1': size[1], 13 | 'do_keep_prob': do_keep_prob 14 | } 15 | 16 | self.solver_params = { 17 | 'lr': lr, 18 | 'weight_decay': weight_decay, 19 | 'n_accum_steps': n_accum_steps, 20 | } 21 | 22 | def forward(self, state, reuse=False): 23 | with tf.variable_scope('policy'): 24 | h0 = ops.dense(state, self.arch_params['in_dim'], self.arch_params['n_hidden_0'], tf.nn.relu, 'dense0', reuse) 25 | h1 = ops.dense(h0, self.arch_params['n_hidden_0'], self.arch_params['n_hidden_1'], tf.nn.relu, 'dense1', reuse) 26 | relu1_do = tf.nn.dropout(h1, self.arch_params['do_keep_prob']) 27 | a = ops.dense(relu1_do, self.arch_params['n_hidden_1'], self.arch_params['out_dim'], None, 'dense2', reuse) 28 | 29 | return a 30 | 31 | def backward(self, loss): 32 | 33 | self.weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='policy') 34 | 35 | self.accum_grads = 
[tf.Variable(tf.zeros(w.get_shape())) for w in self.weights] 36 | 37 | # reset gradients op 38 | self.reset_grad_op = [] 39 | for acc_grad in self.accum_grads: 40 | self.reset_grad_op.append(acc_grad.assign(0. * acc_grad)) 41 | 42 | # create an optimizer 43 | opt = tf.train.AdamOptimizer(learning_rate=self.solver_params['lr']) 44 | 45 | # weight decay 46 | loss += self.solver_params['weight_decay'] * tf.add_n([tf.nn.l2_loss(w) for w in self.weights if 'weights' in w.name]) 47 | 48 | # compute the gradients for a list of variables 49 | grads_and_vars = opt.compute_gradients(loss=loss, var_list=self.weights) 50 | 51 | # get clipped gradients 52 | grads = [tf.clip_by_value(g, -2, 2) for g, v in grads_and_vars] 53 | 54 | variables = [v for g, v in grads_and_vars] 55 | 56 | # accumulate the grads 57 | accum_grads_op = [] 58 | for i, accum_grad in enumerate(self.accum_grads): 59 | accum_grads_op.append(accum_grad.assign_add(grads[i])) 60 | 61 | # pack accumulated gradient and vars back in grads_and_vars (while normalizing by policy_accum_steps) 62 | grads_and_vars = [] 63 | for g, v in zip(self.accum_grads, variables): 64 | grads_and_vars.append([tf.div(g, self.solver_params['n_accum_steps']), v]) 65 | 66 | # apply the gradient 67 | apply_grads = opt.apply_gradients(grads_and_vars) 68 | 69 | return apply_grads, accum_grads_op 70 | 71 | def train(self, objective): 72 | self.loss_al = objective 73 | self.apply_grads_al, self.accum_grads_al = self.backward(self.loss_al) 74 | --------------------------------------------------------------------------------
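For reference, the straight-through substitution in `common.re_parametrization` is what lets the adversarial policy gradient flow back through the forward model while the trajectory itself follows the real environment: the returned state equals the environment state in the forward pass, but its gradient is taken with respect to the model prediction. A minimal, self-contained illustration (TensorFlow 1.x assumed; the constant values are made up for the example):

```python
import tensorflow as tf

state_a = tf.constant([1.0, 2.0])          # differentiable forward-model prediction
state_e = tf.constant([1.5, 1.5])          # environment state (treated as non-differentiable)

nu = tf.stop_gradient(state_e - state_a)   # correction term, gradient blocked
state = state_a + nu                       # forward value equals state_e

grad = tf.gradients(tf.reduce_sum(state), state_a)[0]

with tf.Session() as sess:
    print(sess.run(state))                 # [1.5 1.5]  -> follows the environment
    print(sess.run(grad))                  # [1. 1.]    -> flows to the model prediction
```

This is the identity `policy_loop` in mgail.py relies on when it combines `state_e` (from `env.step` in tensorflow mode, wrapped in `tf.stop_gradient`) with `state_a` (from the forward model).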