├── .gitignore ├── README.md ├── algos ├── __init__.py ├── dqn │ ├── __init__.py │ ├── core.py │ ├── dqn.py │ └── dqn_cartpole.py ├── ppo │ ├── core_cnn_torch.py │ ├── core_resnet_simple_torch.py │ ├── core_torch.py │ ├── generate_expert_dmc.py │ ├── plot_seed.py │ ├── run_ppo_atari_torch.py │ ├── run_ppo_dmc_torch.py │ ├── run_ppo_multi_torch.py │ ├── run_ppo_torch.py │ ├── test_ppo_torch.py │ └── utils.py └── vae │ ├── core_vae.py │ ├── decoder.py │ ├── encoder.py │ ├── run_ppo_vae_dmc.py │ └── run_vae.py ├── data ├── dqn │ └── dqn_s0 │ │ ├── config.json │ │ └── progress.txt ├── dqn_cartpole │ └── dqn_cartpole_s0 │ │ ├── config.json │ │ ├── logs │ │ ├── events.out.tfevents.1566310371.zhr-AERO-15WV8 │ │ ├── events.out.tfevents.1566310768.zhr-AERO-15WV8 │ │ └── events.out.tfevents.1566313578.zhr-AERO-15WV8 │ │ └── progress.txt └── dqn_seaquest │ └── dqn_seaquest_s0 │ ├── config.json │ ├── logs │ └── events.out.tfevents.1565228049.zhr-System-Product-Name │ └── progress.txt ├── env ├── __init__.py ├── atari_lib.py ├── atari_process.py ├── dmc_env.py ├── gym_example.py ├── vecenv.py └── wrappers.py ├── user_config.py └── utils ├── __init__.py ├── logx.py ├── mpi_tools.py ├── normalization.py ├── plot.py ├── run_utils.py ├── serialization_utils.py └── video.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[cod] 2 | data 3 | .vscode 4 | imgs 5 | videos 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RL-Implementation 2 | simple code to reinforcement learning 3 | 4 | - pendulum 5 | - python -um algos.ppo.run_ppo_torch --exp_name xxx --max_grad_norm 0.5 --steps 2048 --anneal_lr --is_clip_v --seed 20 --env Pendulum-v0 --is_gae --target_kl 0.03 --lr 0.0003 --norm_state --iteration 1000 --batch 64 --last_v --norm_rewards returns --a_update 10 6 | 7 | - mujoco 8 | - python -um algos.ppo.run_ppo_torch --exp_name xxx --max_grad_norm 0.5 --steps 2048 --anneal_lr --is_clip_v --seed 30 --env HalfCheetah-v2 --is_gae --target_kl 0.07 --lr 0.0003 --norm_state --iteration 2000 --batch 64 --last_v --norm_rewards returns --a_update 10 9 | 10 | - dmc 11 | - python -um algos.ppo.run_ppo_dmc_torch --exp_name xxx --max_grad_norm 0.5 --steps 2048 --anneal_lr --is_clip_v --seed 10 --domain_name cheetah --task_name run --network mlp --target_kl 0.05 --lr 0.0003 --iteration 2000 --batch 64 --encoder_type state --last_v --is_gae --a_update 10 --norm_state --norm_rewards returns --c_en 0 12 | 13 | - if you want to test dmc and see video 14 | - python -um algos.ppo.test_ppo_torch --exp_name xxx --encoder_type state 15 | 16 | - if you want to plot results (mujoco) 17 | - please modify taskname in plot_seed.py 18 | - then python -um algos.ppo.plot_seed 19 | 20 | - output results in data/* when you train experiments 21 | - you can find logs(tensorboard),checkpoints(pytorch),args.json and process.txt 22 | 23 | - if you know chinese, you can read 24 | - 强化学习中的调参经验与编程技巧(on policy 篇) 25 | https://blog.csdn.net/qq_27008079/article/details/108313137 26 | -------------------------------------------------------------------------------- /algos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feidieufo/RL-Implementation/c6f1a466d7897ac896debc1e7f3b2601cb245544/algos/__init__.py -------------------------------------------------------------------------------- 
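The training runs above drop everything into data/<exp_name>/<exp_name>_s<seed>/ (tensorboard logs, pytorch checkpoints, args.json, progress.txt). As a rough illustration of reusing those artifacts outside the provided test script, the sketch below reloads a policy checkpoint and the saved observation filter for evaluation. It is only a sketch: the path, iteration number and environment are placeholders, it assumes training used --norm_state, and it assumes the MuJoCo runner saves checkpoints in the same format as run_ppo_atari_torch.py and generate_expert_dmc.py shown further down ("<iter>.pth" holding actor/critic state dicts, "<iter>.pkl" holding the state/reward filters).

import os
import pickle
import gym
import numpy as np
import torch
import algos.ppo.core_torch as core

out_dir = "data/xxx/xxx_s20"                      # data/<exp_name>/<exp_name>_s<seed>, placeholder
env = gym.make("HalfCheetah-v2")
act_dim = env.action_space.shape[0]
act_max = env.action_space.high[0]

actor = core.Actor(env.observation_space.shape, act_dim, act_max)
ckpt = torch.load(os.path.join(out_dir, "checkpoints", "900.pth"), map_location="cpu")
actor.load_state_dict(ckpt["actor"])
with open(os.path.join(out_dir, "checkpoints", "900.pkl"), "rb") as f:
    norm = pickle.load(f)                         # {"state": state filter, "reward": reward filter}

obs = norm["state"](env.reset(), update=False)    # reuse the frozen state filter, no further updates
done, ret = False, 0.0
while not done:
    s = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
    a = actor.select_action(s).numpy()
    a = np.clip(a, -act_max, act_max)             # Gaussian samples can exceed the action bounds
    obs, r, done, _ = env.step(a)
    obs = norm["state"](obs, update=False)
    ret += r
print("episode return:", ret)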
/algos/dqn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feidieufo/RL-Implementation/c6f1a466d7897ac896debc1e7f3b2601cb245544/algos/dqn/__init__.py -------------------------------------------------------------------------------- /algos/dqn/core.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.keras import layers 3 | 4 | 5 | def nature_dqn(num_actions): 6 | 7 | model = tf.keras.Sequential() 8 | model.add(layers.Conv2D(32, [8, 8], strides=4, input_shape=(84, 84, 4), activation="relu")) 9 | model.add(layers.Conv2D(64, [4, 4], strides=2, activation="relu")) 10 | model.add(layers.Conv2D(64, [3, 3], strides=1, activation="relu")) 11 | model.add(layers.Flatten()) 12 | model.add(layers.Dense(512, activation="relu")) 13 | model.add(layers.Dense(num_actions)) 14 | 15 | model.compile(optimizer=tf.train.RMSPropOptimizer(0.001), 16 | loss="mse", 17 | metrics=["mae"]) 18 | 19 | model_target = tf.keras.Sequential() 20 | model_target.add(layers.Conv2D(32, [8, 8], strides=4, input_shape=(84, 84, 4), activation="relu")) 21 | model_target.add(layers.Conv2D(64, [4, 4], strides=2, activation="relu")) 22 | model_target.add(layers.Conv2D(64, [3, 3], strides=1, activation="relu")) 23 | model_target.add(layers.Flatten()) 24 | model_target.add(layers.Dense(512, activation="relu")) 25 | model_target.add(layers.Dense(num_actions)) 26 | 27 | return model, model_target 28 | 29 | 30 | def mlp_dqn(num_actions, input_shape, hidden_sizes=(400,300), activation="relu"): 31 | 32 | model = tf.keras.Sequential() 33 | model.add(layers.Dense(hidden_sizes[0], input_shape=input_shape, activation=activation)) 34 | for i in hidden_sizes[1:-1]: 35 | model.add(layers.Dense(hidden_sizes[i], activation=activation)) 36 | 37 | model.add(layers.Dense(num_actions)) 38 | 39 | model.compile(optimizer=tf.train.RMSPropOptimizer(0.01), 40 | loss="mse", 41 | metrics=["mae"]) 42 | 43 | model_target = tf.keras.Sequential() 44 | model_target.add(layers.Dense(hidden_sizes[0], input_shape=input_shape, activation=activation)) 45 | for i in hidden_sizes[1:-1]: 46 | model_target.add(layers.Dense(hidden_sizes[i], activation=activation)) 47 | 48 | model_target.add(layers.Dense(num_actions)) 49 | 50 | return model, model_target 51 | 52 | -------------------------------------------------------------------------------- /algos/dqn/dqn.py: -------------------------------------------------------------------------------- 1 | from env.atari_lib import make_atari, wrap_deepmind 2 | import tensorflow as tf 3 | from tensorflow.python.keras import layers 4 | from algos.dqn.core import nature_dqn 5 | import numpy as np 6 | import random 7 | import copy 8 | from utils.logx import EpochLogger 9 | import os 10 | import gym 11 | from env.atari_process import AtariPreprocessing 12 | debug_mode = True 13 | 14 | class ReplayBuffer: 15 | def __init__(self, size): 16 | self.storage = [] 17 | self.maxsize = size 18 | self.next_idx = 0 19 | self.size = 0 20 | 21 | def add(self, s, a, s_, r, done): 22 | data = (s, a, s_, r, done) 23 | if self.size < self.maxsize: 24 | self.storage.append(data) 25 | 26 | else: 27 | self.storage[self.next_idx] = data 28 | 29 | self.next_idx = (self.next_idx + 1) % self.maxsize 30 | self.size = min(self.size+1, self.maxsize) 31 | 32 | def sample(self, batch_size=32): 33 | ids = np.random.randint(0, self.size, size=batch_size) 34 | # storage = np.array(self.storage) 35 | state = [] 
36 | action = [] 37 | state_ = [] 38 | reward = [] 39 | done = [] 40 | for i in ids: 41 | (s, a, s_, r, d) = self.storage[i] 42 | state.append(s) 43 | action.append(a) 44 | state_.append(s_) 45 | reward.append(r) 46 | done.append(d) 47 | 48 | return np.array(state), np.array(action), np.array(state_), np.array(reward), np.array(done) 49 | 50 | 51 | def linearly_decaying_epsilon(decay_period, step, warmup_steps, epsilon): 52 | step_left = decay_period + warmup_steps - step 53 | ep = step_left/decay_period*(1-epsilon) 54 | ep = np.clip(ep, 0, 1-epsilon) 55 | 56 | return epsilon + ep 57 | 58 | 59 | class Dqn: 60 | def __init__(self, env_name, train_step=250000/4, evaluation_step=125000/4, max_ep_len=27000/4, epsilon_train=0.1, 61 | epsilon_eval=0.01, batch_size=32, replay_size=1e6, 62 | epsilon_decay_period=250000/4, warmup_steps=20000/4, iteration=200, gamma=0.99, 63 | target_update_period=8000/4, update_period=4, logger_kwargs=dict()): 64 | 65 | self.logger = EpochLogger(**logger_kwargs) 66 | self.logger.save_config(locals()) 67 | 68 | # self.env = make_atari(env_name) 69 | # self.env = wrap_deepmind(self.env, frame_stack=True) 70 | self.env = gym.make(env_name) 71 | env = self.env.env 72 | self.env = AtariPreprocessing(env) 73 | 74 | self.train_step = train_step 75 | self.evaluation_step = evaluation_step 76 | self.max_ep_len = max_ep_len 77 | self.epsilon_train = epsilon_train 78 | self.epsilon_eval = epsilon_eval 79 | self.batch_size = batch_size 80 | self.replay_size = replay_size 81 | self.epsilon_decay_period = epsilon_decay_period 82 | self.warmup_steps = warmup_steps 83 | self.iteration = iteration 84 | self.replay_buffer = ReplayBuffer(replay_size) 85 | self.gamma = gamma 86 | self.target_update_period = target_update_period 87 | self.update_period = update_period 88 | 89 | self.build_model() 90 | self.cur_train_step = 0 91 | 92 | self.observation_shape = (84, 84) 93 | self.state_shape = (1,) + self.observation_shape + (4,) 94 | self.s = np.zeros(self.state_shape) 95 | self.last_s = np.zeros(self.state_shape) 96 | 97 | if debug_mode: 98 | self.summary = tf.summary.FileWriter(os.path.join(self.logger.output_dir, "logs")) 99 | 100 | self.sess = tf.Session() 101 | self.loss = tf.placeholder(tf.float32, shape=[]) 102 | self.q = tf.placeholder(tf.float32, shape=[None, self.env.action_space.n]) 103 | self.q_target = tf.placeholder(tf.float32, shape=[None, self.env.action_space.n]) 104 | self.target_q = tf.placeholder(tf.float32, shape=[None, self.env.action_space.n]) 105 | tf.summary.scalar("loss", self.loss) 106 | # tf.summary.histogram("q", self.q) 107 | # tf.summary.histogram("q_target", self.q_target) 108 | # tf.summary.histogram("target_q", self.target_q) 109 | self.merge = tf.summary.merge_all() 110 | 111 | def build_model(self): 112 | self.model, self.model_target = nature_dqn(self.env.action_space.n) 113 | self.model_target.set_weights(self.model.get_weights()) 114 | 115 | def choose_action(self, s, eval_mode=False): 116 | epsilon = self.epsilon_eval if eval_mode \ 117 | else linearly_decaying_epsilon(self.epsilon_decay_period, self.cur_train_step, self.warmup_steps, self.epsilon_train) 118 | 119 | if random.random() <= 1-epsilon: 120 | q = self.model.predict(s[np.newaxis, :]) 121 | a = np.argmax(q, axis=1)[0] 122 | # print() 123 | else: 124 | a = self.env.action_space.sample() 125 | 126 | return a 127 | 128 | def record_obs(self, observation): 129 | self.last_s = copy.copy(self.s) 130 | self.s = np.roll(self.s, -1, axis=-1) 131 | self.s[0, ..., -1] = np.squeeze(observation) 132 
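    # Frame stacking in record_obs, illustrated: self.s has shape (1, 84, 84, 4),
    # np.roll shifts every stored frame one slot toward the front of the channel
    # axis, and the newest observation overwrites the last slot.  With a 1-D
    # stand-in for the channel axis:
    #
    #     s = np.array([1, 2, 3, 4])   # oldest ... newest frame ids
    #     s = np.roll(s, -1)           # -> [2, 3, 4, 1]
    #     s[-1] = 5                    # -> [2, 3, 4, 5]  (oldest frame dropped)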
| 133 | def store(self, s, a, s_, r, done): 134 | pass 135 | 136 | def run_one_phrase(self, min_step, eval_mode=False): 137 | step = 0 138 | episode = 0 139 | reward = 0. 140 | 141 | while step < min_step: 142 | done = False 143 | obs = self.env.reset() 144 | # o = np.array(obs) 145 | 146 | step_episode = 0 147 | reward_episode = 0 148 | while not done: 149 | a = self.choose_action(np.array(obs), eval_mode) 150 | obs_, r, done, _ = self.env.step(a) 151 | 152 | step += 1 153 | step_episode += 1 154 | reward += r 155 | reward_episode += r 156 | 157 | if not eval_mode: 158 | self.cur_train_step += 1 159 | self.replay_buffer.add(np.array(obs), a, np.array(obs_), r, done) 160 | 161 | if self.cur_train_step > 20000/4: 162 | if self.cur_train_step % self.update_period == 0: 163 | # data = self.replay_buffer.sample() 164 | (s, a, s_, r, d) = self.replay_buffer.sample() 165 | q_ = np.max(self.model_target.predict(s_), axis=1) 166 | q_target = r + (1-d)*self.gamma * q_ 167 | q = self.model.predict(s) 168 | batch_index = np.arange(self.batch_size) 169 | q[batch_index, a] = q_target 170 | result = self.model.train_on_batch(np.array(s), q) 171 | # print("result:", result) 172 | 173 | merge = self.sess.run(self.merge, feed_dict={self.loss: result[0]}) 174 | self.summary.add_summary(merge, (self.cur_train_step-20000)/self.update_period) 175 | 176 | if self.cur_train_step % self.target_update_period == 0: 177 | self.model_target.set_weights(self.model.get_weights()) 178 | 179 | if step_episode >= self.max_ep_len: 180 | break 181 | obs = obs_ 182 | 183 | episode += 1 184 | 185 | # print("ep:", episode, "step:", step, "r:", reward) 186 | self.logger.store(step=step_episode, reward=reward_episode) 187 | 188 | return reward, episode 189 | 190 | def train_test(self): 191 | for i in range(self.iteration): 192 | print("iter:", i+1) 193 | self.logger.store(iter=i+1) 194 | reward, episode = self.run_one_phrase(self.train_step) 195 | print("reward:", reward/episode, "episode:", episode) 196 | 197 | self.logger.log_tabular("reward", with_min_and_max=True) 198 | self.logger.log_tabular("step", with_min_and_max=True) 199 | self.logger.dump_tabular() 200 | 201 | reward, episode = self.run_one_phrase(self.evaluation_step, True) 202 | print("reward:", reward / episode, "episode:", episode) 203 | 204 | 205 | if __name__ == '__main__': 206 | import argparse 207 | 208 | parser = argparse.ArgumentParser() 209 | parser.add_argument('--env', type=str, default='SeaquestNoFrameskip-v4') 210 | parser.add_argument('--seed', '-s', type=int, default=0) 211 | parser.add_argument('--exp_name', type=str, default='dqn_seaquest') 212 | args = parser.parse_args() 213 | 214 | from utils.run_utils import setup_logger_kwargs 215 | logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed) 216 | 217 | dqn = Dqn(args.env, logger_kwargs=logger_kwargs) 218 | dqn.train_test() 219 | 220 | 221 | -------------------------------------------------------------------------------- /algos/dqn/dqn_cartpole.py: -------------------------------------------------------------------------------- 1 | from env.atari_lib import make_atari, wrap_deepmind 2 | import tensorflow as tf 3 | from tensorflow.python.keras import layers 4 | from algos.dqn.core import mlp_dqn 5 | import numpy as np 6 | import random 7 | import copy 8 | from utils.logx import EpochLogger 9 | import gym 10 | import os 11 | 12 | debug_mode = True 13 | 14 | class ReplayBuffer: 15 | def __init__(self, size): 16 | self.storage = [] 17 | self.maxsize = size 18 | self.next_idx = 0 19 | self.size = 
0 20 | 21 | def add(self, s, a, s_, r, done): 22 | data = (s, a, s_, r, done) 23 | if self.size < self.maxsize: 24 | self.storage.append(data) 25 | 26 | else: 27 | self.storage[self.next_idx] = data 28 | 29 | self.next_idx = (self.next_idx + 1) % self.maxsize 30 | self.size = min(self.size+1, self.maxsize) 31 | 32 | def sample(self, batch_size=32): 33 | ids = np.random.randint(0, self.size, size=batch_size) 34 | # storage = np.array(self.storage) 35 | state = [] 36 | action = [] 37 | state_ = [] 38 | reward = [] 39 | done = [] 40 | for i in ids: 41 | (s, a, s_, r, d) = self.storage[i] 42 | state.append(s) 43 | action.append(a) 44 | state_.append(s_) 45 | reward.append(r) 46 | done.append(d) 47 | 48 | return np.array(state), np.array(action), np.array(state_), np.array(reward), np.array(done) 49 | 50 | 51 | def linearly_decaying_epsilon(decay_period, step, warmup_steps, epsilon): 52 | step_left = decay_period + warmup_steps - step 53 | ep = step_left/decay_period*(1-epsilon) 54 | ep = np.clip(ep, 0, 1-epsilon) 55 | 56 | return epsilon + ep 57 | 58 | 59 | class Dqn: 60 | def __init__(self, env_name, train_step=200, evaluation_step=1000, max_ep_len=200, epsilon_train=0.1, 61 | epsilon_eval=0.01, batch_size=32, replay_size=1e6, 62 | epsilon_decay_period=100, warmup_steps=0, iteration=200, gamma=0.99, 63 | target_update_period=50, update_period=10, logger_kwargs=dict()): 64 | 65 | self.logger = EpochLogger(**logger_kwargs) 66 | self.logger.save_config(locals()) 67 | 68 | self.env = gym.make(env_name) 69 | 70 | self.train_step = train_step 71 | self.evaluation_step = evaluation_step 72 | self.max_ep_len = max_ep_len 73 | self.epsilon_train = epsilon_train 74 | self.epsilon_eval = epsilon_eval 75 | self.batch_size = batch_size 76 | self.replay_size = replay_size 77 | self.epsilon_decay_period = epsilon_decay_period 78 | self.warmup_steps = warmup_steps 79 | self.iteration = iteration 80 | self.replay_buffer = ReplayBuffer(replay_size) 81 | self.gamma = gamma 82 | self.target_update_period = target_update_period 83 | self.update_period = update_period 84 | 85 | self.build_model() 86 | self.cur_train_step = 0 87 | 88 | if debug_mode: 89 | self.summary = tf.summary.FileWriter(os.path.join(self.logger.output_dir, "logs")) 90 | 91 | self.sess = tf.Session() 92 | self.loss = tf.placeholder(tf.float32, shape=[]) 93 | self.q = tf.placeholder(tf.float32, shape=[None, self.env.action_space.n]) 94 | self.q_target = tf.placeholder(tf.float32, shape=[None, self.env.action_space.n]) 95 | self.target_q = tf.placeholder(tf.float32, shape=[None, self.env.action_space.n]) 96 | tf.summary.scalar("loss", self.loss) 97 | tf.summary.histogram("q", self.q) 98 | tf.summary.histogram("q_target", self.q_target) 99 | tf.summary.histogram("target_q", self.target_q) 100 | self.merge = tf.summary.merge_all() 101 | 102 | 103 | 104 | def build_model(self): 105 | self.input_shape = self.env.observation_space.shape 106 | self.model, self.model_target = mlp_dqn(self.env.action_space.n, self.input_shape) 107 | self.model_target.set_weights(self.model.get_weights()) 108 | 109 | def choose_action(self, s, eval_mode=False): 110 | epsilon = self.epsilon_eval if eval_mode \ 111 | else linearly_decaying_epsilon(self.epsilon_decay_period, self.cur_train_step, self.warmup_steps, self.epsilon_train) 112 | # print("epsilon:", epsilon) 113 | if random.random() <= 1-epsilon: 114 | q = self.model.predict(s[np.newaxis, :]) 115 | a = np.argmax(q, axis=1)[0] 116 | # print() 117 | else: 118 | a = self.env.action_space.sample() 119 | 120 | return a 
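    # Worked example of the epsilon schedule with the defaults above
    # (epsilon_decay_period=100, warmup_steps=0, epsilon_train=0.1):
    #
    #     linearly_decaying_epsilon(100,   0, 0, 0.1) -> 1.00   (fully random)
    #     linearly_decaying_epsilon(100,  50, 0, 0.1) -> 0.55
    #     linearly_decaying_epsilon(100, 100, 0, 0.1) -> 0.10   (stays at 0.10 afterwards)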
121 | 122 | def run_one_phrase(self, min_step, eval_mode=False): 123 | step = 0 124 | episode = 0 125 | reward = 0. 126 | 127 | while step < min_step: 128 | reward_episode = 0. 129 | step_episode = 0 130 | done = False 131 | obs = self.env.reset() 132 | # o = np.array(obs) 133 | 134 | while not done: 135 | a = self.choose_action(np.array(obs), eval_mode) 136 | obs_, r, done, _ = self.env.step(a) 137 | 138 | step += 1 139 | step_episode += 1 140 | reward += r 141 | reward_episode += r 142 | 143 | if not eval_mode: 144 | self.cur_train_step += 1 145 | self.replay_buffer.add(np.array(obs), a, np.array(obs_), r, done) 146 | 147 | if self.cur_train_step > 100: 148 | if self.cur_train_step % self.update_period == 0: 149 | # data = self.replay_buffer.sample() 150 | (s, a, s_, r, d) = self.replay_buffer.sample() 151 | target_q = self.model_target.predict(s_) 152 | q_ = np.max(target_q, axis=1) 153 | 154 | q_target = r + (1-d) *self.gamma * q_ 155 | 156 | q = self.model.predict(s) 157 | ori_q = np.copy(q) 158 | batch_index = np.arange(self.batch_size) 159 | q[batch_index, a] = q_target 160 | 161 | result = self.model.train_on_batch(np.array(s), q) 162 | 163 | if debug_mode: 164 | merge = self.sess.run(self.merge, 165 | feed_dict={self.loss: result[0], self.q: ori_q, self.q_target: q, 166 | self.target_q: target_q}) 167 | self.summary.add_summary(merge, (self.cur_train_step-100)/self.update_period) 168 | # print("result:", result) 169 | 170 | if self.cur_train_step % self.target_update_period == 0: 171 | self.model_target.set_weights(self.model.get_weights()) 172 | 173 | if step_episode >= self.max_ep_len: 174 | break 175 | obs = obs_ 176 | 177 | episode += 1 178 | 179 | # print("ep:", episode, "step:", step, "r:", reward) 180 | if not eval_mode: 181 | self.logger.store(step=step_episode, reward=reward_episode) 182 | 183 | return reward, episode 184 | 185 | def train_test(self): 186 | for i in range(self.iteration): 187 | print("iter:", i+1) 188 | self.logger.store(iter=i+1) 189 | reward, episode = self.run_one_phrase(self.train_step) 190 | print("reward:", reward/episode, "episode:", episode) 191 | 192 | self.logger.log_tabular("iter", i+1) 193 | self.logger.log_tabular("reward", with_min_and_max=True) 194 | self.logger.log_tabular("step", with_min_and_max=True) 195 | self.logger.dump_tabular() 196 | 197 | reward, episode = self.run_one_phrase(self.evaluation_step, True) 198 | print("reward:", reward / episode, "episode:", episode) 199 | 200 | 201 | if __name__ == '__main__': 202 | import argparse 203 | 204 | parser = argparse.ArgumentParser() 205 | parser.add_argument('--env', type=str, default='CartPole-v1') 206 | parser.add_argument('--seed', '-s', type=int, default=0) 207 | parser.add_argument('--exp_name', type=str, default='dqn_cartpole') 208 | args = parser.parse_args() 209 | 210 | from utils.run_utils import setup_logger_kwargs 211 | logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed) 212 | 213 | dqn = Dqn(args.env, logger_kwargs=logger_kwargs) 214 | dqn.train_test() 215 | 216 | 217 | -------------------------------------------------------------------------------- /algos/ppo/core_cnn_torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import Normal, MultivariateNormal 3 | import numpy as np 4 | 5 | def initialize_weights(mod, initialization_type, scale=1): 6 | ''' 7 | Weight initializer for the models. 
8 | Inputs: A model, Returns: none, initializes the parameters 9 | ''' 10 | for p in mod.parameters(): 11 | if initialization_type == "normal": 12 | p.data.normal_(0.01) 13 | elif initialization_type == "xavier": 14 | if len(p.data.shape) >= 2: 15 | torch.nn.init.xavier_uniform_(p.data) 16 | else: 17 | p.data.zero_() 18 | elif initialization_type == "orthogonal": 19 | if len(p.data.shape) >= 2: 20 | torch.nn.init.orthogonal_(p.data, gain=scale) 21 | else: 22 | p.data.zero_() 23 | else: 24 | raise ValueError("Need a valid initialization key") 25 | 26 | 27 | class ActorCnnEmb(torch.nn.Module): 28 | def __init__(self, s_dim, emb_dim=512): 29 | super().__init__() 30 | self.cnn = torch.nn.Sequential( 31 | torch.nn.Conv2d(s_dim[0], 32, kernel_size=8, stride=4), 32 | torch.nn.ReLU(), 33 | torch.nn.Conv2d(32, 64, kernel_size=4, stride=2), 34 | torch.nn.ReLU(), 35 | torch.nn.Conv2d(64, 64, kernel_size=3, stride=1), 36 | torch.nn.ReLU(), 37 | ) 38 | 39 | h = self.conv2d_size_out(self.conv2d_size_out( 40 | self.conv2d_size_out(s_dim[1], kernel_size=8, stride=4), 4, 2), 3, 1) 41 | w = self.conv2d_size_out(self.conv2d_size_out( 42 | self.conv2d_size_out(s_dim[2], kernel_size=8, stride=4), 4, 2), 3, 1) 43 | 44 | self.fc = torch.nn.Sequential( 45 | torch.nn.Linear(h*w*64, emb_dim), 46 | torch.nn.ReLU(),) 47 | 48 | def conv2d_size_out(self, size, kernel_size=5, stride=2): 49 | return (size - (kernel_size - 1) - 1) // stride + 1 50 | 51 | def forward(self, s): 52 | cnn = self.cnn(s) 53 | x = cnn.view(s.shape[0], -1) 54 | x = self.fc(x) 55 | return x 56 | 57 | class Critic(torch.nn.Module): 58 | def __init__(self, s_dim, emb_dim=512): 59 | super().__init__() 60 | 61 | self.fc = torch.nn.Sequential( 62 | torch.nn.Linear(emb_dim, 1), 63 | ) 64 | 65 | initialize_weights(self.fc, "orthogonal", scale=1) 66 | 67 | def forward(self, s_emb): 68 | # s_emb = self.emb(s_emb) 69 | v = self.fc(s_emb) 70 | return v 71 | 72 | 73 | class Actor(torch.nn.Module): 74 | def __init__(self, s_dim, a_dim, a_max=1, emb_dim=512): 75 | super().__init__() 76 | self.act_dim = a_dim 77 | self.a_max = a_max 78 | self.emb = ActorCnnEmb(s_dim) 79 | 80 | self.mu = torch.nn.Sequential( 81 | torch.nn.Linear(emb_dim, 100), 82 | torch.nn.ReLU(), 83 | torch.nn.Linear(100, a_dim), 84 | torch.nn.Tanh() 85 | ) 86 | self.var = torch.nn.Parameter(torch.zeros(1, a_dim)) 87 | 88 | initialize_weights(self.emb, "orthogonal", scale=np.sqrt(2)) 89 | initialize_weights(self.mu, "orthogonal", scale=0.01) 90 | 91 | def forward(self, s): 92 | emb = self.emb(s) 93 | mu = self.mu(emb)*self.a_max 94 | 95 | return mu, torch.exp(self.var) 96 | 97 | def select_action(self, s): 98 | with torch.no_grad(): 99 | mean, std = self(s) 100 | 101 | normal = Normal(mean, std) 102 | action = normal.sample() # [1, a_dim] 103 | action = torch.squeeze(action, dim=0) # [a_dim] 104 | 105 | return action 106 | 107 | def log_pi(self, s, a): 108 | mean, std = self(s) 109 | 110 | normal = Normal(mean, std) 111 | logpi = normal.log_prob(a) 112 | logpi = torch.sum(logpi, dim=1) 113 | entropy = normal.entropy().sum(dim=1) 114 | 115 | return logpi, entropy # [None,] 116 | 117 | class ActorDisc(torch.nn.Module): 118 | def __init__(self, s_dim, a_num, emb_dim=512): 119 | super().__init__() 120 | self.act_num = a_num 121 | self.emb = ActorCnnEmb(s_dim) 122 | 123 | self.final = torch.nn.Sequential( 124 | torch.nn.Linear(emb_dim, a_num), 125 | # torch.nn.Softmax() 126 | ) 127 | 128 | initialize_weights(self.emb, "orthogonal", scale=np.sqrt(2)) 129 | initialize_weights(self.final, "orthogonal", 
scale=0.01) 130 | 131 | def forward(self, s): 132 | emb = self.emb(s) 133 | x = self.final(emb) 134 | 135 | return x 136 | 137 | def select_action(self, s): 138 | with torch.no_grad(): 139 | x = self(s) 140 | normal = torch.distributions.Categorical(logits=x) 141 | action = normal.sample() 142 | action = torch.squeeze(action, dim=0) 143 | 144 | return action 145 | 146 | def log_pi(self, s, a): 147 | x = self(s) 148 | import torch.nn.functional as F 149 | softx = F.softmax(x, dim=1) 150 | 151 | normal = torch.distributions.Categorical(logits=x) 152 | logpi = normal.log_prob(a) 153 | 154 | return logpi # [None,] 155 | 156 | class PPO(torch.nn.Module): 157 | def __init__(self, state_dim, act_dim, act_max, epsilon, device, lr_a=0.001, 158 | c_en=0.01, c_vf=0.5, max_grad_norm=-1, anneal_lr=False, train_steps=1000,): 159 | super().__init__() 160 | if type(act_dim) == np.int64 or type(act_dim) == np.int: 161 | self.actor = ActorDisc(state_dim, act_dim).to(device) 162 | self.old_actor = ActorDisc(state_dim, act_dim).to(device) 163 | else: 164 | self.actor = Actor(state_dim, act_dim[0], act_max).to(device) 165 | self.old_actor = Actor(state_dim, act_dim[0], act_max).to(device) 166 | self.critic = Critic(state_dim).to(device) 167 | self.epsilon = epsilon 168 | self.c_en = c_en 169 | self.c_vf = c_vf 170 | 171 | self.max_grad_norm = max_grad_norm 172 | self.anneal_lr = anneal_lr 173 | 174 | self.opti = torch.optim.Adam(list(self.actor.parameters()) + list(self.critic.parameters()), lr=lr_a, eps=1e-5) 175 | 176 | if anneal_lr: 177 | lam = lambda f: 1 - f / train_steps 178 | self.opti_scheduler = torch.optim.lr_scheduler.LambdaLR(self.opti, lr_lambda=lam) 179 | 180 | def train_ac(self, s, a, adv, vs, oldv, is_clip_v=True): 181 | self.opti.zero_grad() 182 | logpi, entropy = self.actor.log_pi(s, a) 183 | old_logpi, _ = self.old_actor.log_pi(s, a) 184 | 185 | ratio = torch.exp(logpi - old_logpi) 186 | surr = ratio * adv 187 | clip_adv = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * adv 188 | aloss = -torch.mean(torch.min(surr, clip_adv)) 189 | # loss_entropy = torch.mean(torch.exp(logpi) * logpi) 190 | loss_entropy = entropy.mean() 191 | kl = torch.mean(old_logpi - logpi) 192 | 193 | emb = self.actor.emb(s) 194 | v = self.critic(emb) 195 | v = torch.squeeze(v, 1) 196 | 197 | if not is_clip_v: 198 | v_loss = ((v - vs) ** 2).mean() 199 | else: 200 | clip_v = oldv + torch.clamp(v - oldv, -self.epsilon, self.epsilon) 201 | v_max = torch.max(((v - vs) ** 2), ((clip_v - vs) ** 2)) 202 | v_loss = v_max.mean() 203 | 204 | loss = aloss - loss_entropy*self.c_en + v_loss*self.c_vf 205 | loss.backward() 206 | if self.max_grad_norm != -1: 207 | torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm) 208 | torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm) 209 | self.opti.step() 210 | 211 | info = dict(vloss=v_loss.item(), aloss=aloss.item(), entropy=loss_entropy.item(), kl=kl.item(), loss=loss.item()) 212 | return info 213 | 214 | def lr_scheduler(self): 215 | if self.anneal_lr: 216 | self.opti_scheduler.step() 217 | 218 | def getV(self, s): 219 | with torch.no_grad(): 220 | emb = self.actor.emb(s) 221 | v = self.critic(emb) # [1,1] 222 | 223 | return torch.squeeze(v) # 224 | 225 | def update_a(self): 226 | self.old_actor.load_state_dict(self.actor.state_dict()) -------------------------------------------------------------------------------- /algos/ppo/core_resnet_simple_torch.py: -------------------------------------------------------------------------------- 1 
| import torch 2 | from torch.distributions import Normal, MultivariateNormal 3 | import numpy as np 4 | import torchvision.models as models 5 | 6 | class Discriminator(torch.nn.Module): 7 | def __init__(self, s_dim, a_dim): 8 | self.s_emb = ActorEmb(s_dim) 9 | self.a_emb = torch.nn.Sequential( 10 | torch.nn.Linear(a_dim, 100) 11 | ) 12 | self.d = torch.nn.Sequential( 13 | torch.nn.Linear(100+100, 100), 14 | torch.nn.LeakyReLU(0.2), 15 | torch.nn.Linear(100, 1), 16 | torch.nn.Sigmoid(), 17 | ) 18 | 19 | def forward(self, s, a): 20 | s = self.s_emb(s) 21 | a = self.a_emb(a) 22 | x = torch.cat([s, a], dim=1) 23 | x = self.d(x) 24 | return x 25 | 26 | 27 | 28 | def initialize_weights(mod, initialization_type, scale=np.sqrt(2)): 29 | ''' 30 | Weight initializer for the models. 31 | Inputs: A model, Returns: none, initializes the parameters 32 | ''' 33 | for p in mod.parameters(): 34 | if initialization_type == "normal": 35 | p.data.normal_(0.01) 36 | elif initialization_type == "xavier": 37 | if len(p.data.shape) >= 2: 38 | torch.nn.init.xavier_uniform_(p.data) 39 | else: 40 | p.data.zero_() 41 | elif initialization_type == "orthogonal": 42 | if len(p.data.shape) >= 2: 43 | torch.nn.init.orthogonal_(p.data, gain=scale) 44 | else: 45 | p.data.zero_() 46 | else: 47 | raise ValueError("Need a valid initialization key") 48 | 49 | 50 | class ActorCnnEmb(torch.nn.Module): 51 | def __init__(self, s_dim, emb_dim=100): 52 | super().__init__() 53 | self.conv_block = models.resnet18(pretrained=True) 54 | for param in self.conv_block.parameters(): 55 | param.requires_grad = False 56 | self.conv_block.fc = torch.nn.Linear(self.conv_block.fc.in_features, 512) 57 | 58 | 59 | self.fc = torch.nn.Sequential( 60 | torch.nn.Linear(512, emb_dim), 61 | torch.nn.ReLU(), 62 | torch.nn.Linear(emb_dim, emb_dim), 63 | torch.nn.ReLU()) 64 | 65 | 66 | def forward(self, s): 67 | cnn = self.conv_block(s) 68 | x = self.fc(cnn) 69 | return x 70 | 71 | 72 | class Critic(torch.nn.Module): 73 | def __init__(self, s_dim, emb_dim=100): 74 | super().__init__() 75 | 76 | self.fc = torch.nn.Sequential( 77 | torch.nn.Linear(emb_dim, 100), 78 | torch.nn.ReLU(), 79 | torch.nn.Linear(100, 100), 80 | torch.nn.ReLU(), 81 | torch.nn.Linear(100, 1), 82 | ) 83 | 84 | initialize_weights(self.fc, "orthogonal", scale=1) 85 | 86 | def forward(self, s_emb): 87 | v = self.fc(s_emb) 88 | return v 89 | 90 | 91 | class Actor(torch.nn.Module): 92 | def __init__(self, s_dim, a_dim, a_max=1, emb_dim=100): 93 | super().__init__() 94 | self.act_dim = a_dim 95 | self.a_max = a_max 96 | 97 | self.emb = ActorCnnEmb(s_dim) 98 | 99 | self.mu = torch.nn.Sequential( 100 | torch.nn.Linear(emb_dim, 100), 101 | torch.nn.ReLU(), 102 | torch.nn.Linear(100, a_dim), 103 | torch.nn.Tanh() 104 | ) 105 | self.var = torch.nn.Parameter(torch.zeros(1, a_dim)) 106 | 107 | initialize_weights(self.mu, "orthogonal", scale=1) 108 | 109 | def forward(self, s): 110 | emb = self.emb(s) 111 | mu = self.mu(emb)*self.a_max 112 | 113 | return mu, torch.exp(self.var) 114 | 115 | def select_action(self, s): 116 | with torch.no_grad(): 117 | mean, std = self(s) 118 | 119 | normal = Normal(mean, std) 120 | action = normal.sample() # [1, a_dim] 121 | action = torch.squeeze(action, dim=0) # [a_dim] 122 | 123 | return action 124 | 125 | def log_pi(self, s, a): 126 | mean, std = self(s) 127 | 128 | normal = Normal(mean, std) 129 | logpi = normal.log_prob(a) 130 | logpi = torch.sum(logpi, dim=1) 131 | entropy = normal.entropy().sum(dim=1) 132 | 133 | return logpi, entropy # [None,] 134 | 135 | 
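# Usage sketch (not part of this file): the Actor above is a diagonal Gaussian
# policy with a tanh-squashed, a_max-scaled mean and a state-independent learned
# log-std (self.var).  The run scripts sample it the same way as the other
# cores, i.e. sample, then clip to the action bounds before stepping the env:
#
#     s = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)   # [1, C, H, W]
#     a = actor.select_action(s).cpu().numpy()
#     a = np.clip(a, -a_max, a_max)
#     obs, r, done, _ = env.step(a)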
136 | class PPO(torch.nn.Module): 137 | def __init__(self, state_dim, act_dim, act_max, epsilon, device, lr_a=0.001, 138 | c_en=0.01, c_vf=0.5, max_grad_norm=-1, anneal_lr=False, train_steps=1000,: 139 | super().__init__() 140 | self.actor = Actor(state_dim, act_dim, act_max).to(device) 141 | self.old_actor = Actor(state_dim, act_dim, act_max).to(device) 142 | self.critic = Critic(state_dim).to(device) 143 | self.epsilon = epsilon 144 | self.c_en = c_en 145 | self.c_vf = c_vf 146 | 147 | self.max_grad_norm = max_grad_norm 148 | self.anneal_lr = anneal_lr 149 | 150 | self.opti = torch.optim.Adam(list(self.actor.parameters()) + list(self.critic.parameters()), lr=lr_a) 151 | 152 | if anneal_lr: 153 | lam = lambda f: 1 - f / train_steps 154 | self.opti_scheduler = torch.optim.lr_scheduler.LambdaLR(self.opti, lr_lambda=lam) 155 | 156 | def train(self, s, a, adv, vs, oldv, is_clip_v=True): 157 | self.opti.zero_grad() 158 | logpi, entropy = self.actor.log_pi(s, a) 159 | old_logpi, _ = self.old_actor.log_pi(s, a) 160 | 161 | ratio = torch.exp(logpi - old_logpi) 162 | surr = ratio * adv 163 | clip_adv = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * adv 164 | aloss = -torch.mean(torch.min(surr, clip_adv)) 165 | # loss_entropy = torch.mean(torch.exp(logpi) * logpi) 166 | loss_entropy = entropy.mean() 167 | kl = torch.mean(old_logpi - logpi) 168 | 169 | emb = self.actor.emb(s) 170 | v = self.critic(emb) 171 | v = torch.squeeze(v, 1) 172 | 173 | if not is_clip_v: 174 | v_loss = ((v - vs) ** 2).mean() 175 | else: 176 | clip_v = oldv + torch.clamp(v - oldv, -self.epsilon, self.epsilon) 177 | v_loss = torch.max(((v - vs) ** 2).mean(), ((clip_v - vs) ** 2).mean()) 178 | 179 | loss = aloss - loss_entropy*self.c_en + v_loss*self.c_vf 180 | loss.backward() 181 | if self.max_grad_norm != -1: 182 | torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm) 183 | torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm) 184 | self.opti.step() 185 | 186 | info = dict(vloss=v_loss.item(), aloss=aloss.item(), entropy=loss_entropy.item(), kl=kl.item(), loss=loss.item()) 187 | return info 188 | 189 | def lr_scheduler(self): 190 | if self.anneal_lr: 191 | self.opti_scheduler.step() 192 | 193 | def getV(self, s): 194 | with torch.no_grad(): 195 | emb = self.actor.emb(s) 196 | v = self.critic(emb) # [1,1] 197 | 198 | return torch.squeeze(v) # 199 | 200 | def update_a(self): 201 | self.old_actor.load_state_dict(self.actor.state_dict()) -------------------------------------------------------------------------------- /algos/ppo/core_torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import Normal, MultivariateNormal 3 | import numpy as np 4 | from utils.normalization import RunningMeanStd 5 | 6 | 7 | class Discriminator(torch.nn.Module): 8 | def __init__(self, s_dim, a_dim): 9 | super().__init__() 10 | self.s_emb = torch.nn.Sequential( 11 | torch.nn.Linear(np.prod(s_dim), 100) 12 | ) 13 | self.a_emb = torch.nn.Sequential( 14 | torch.nn.Linear(np.prod(a_dim), 100) 15 | ) 16 | self.d = torch.nn.Sequential( 17 | torch.nn.Linear(100+100, 100), 18 | torch.nn.LeakyReLU(0.2), 19 | torch.nn.Linear(100, 1), 20 | torch.nn.Sigmoid(), 21 | ) 22 | 23 | self.returns = None 24 | self.ret_rms = RunningMeanStd(shape=()) 25 | 26 | def forward(self, s, a): 27 | s = self.s_emb(s) 28 | a = self.a_emb(a) 29 | x = torch.cat([s, a], dim=1) 30 | x = self.d(x) 31 | return x 32 | 33 | def predict_v(self, s, a, masks, 
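# The PPO classes in core_cnn_torch.py, core_resnet_simple_torch.py and this
# file share one interface, and the run scripts below drive them identically.
# Condensed sketch of the loop in run_ppo_atari_torch.py / run_ppo_dmc_torch.py:
#
#     collect args.steps transitions into the ReplayBuffer via replay.add
#     replay.update_v(ppo.getV(s), pos)      # value estimate for every stored state
#     replay.finish_path(last_v)             # GAE advantages / discounted returns
#     ppo.update_a()                         # snapshot old policy for the ratio
#     for each of args.a_update epochs:
#         for (s, a, r, adv, v) in replay.get_batch(batch=args.batch):
#             ppo.train_ac(s, a, adv, r, v, is_clip_v=args.is_clip_v)
#         stop early if the logged approximate kl exceeds args.target_kl
#     ppo.lr_scheduler()                     # optional linear LR anneal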
update_rms=True, gamma=0.99): 34 | d = self(s, a) 35 | reward = torch.squeeze(-torch.log(d)) 36 | if self.returns is None: 37 | self.returns = reward.clone() 38 | 39 | if update_rms: 40 | self.returns = self.returns * masks * gamma + reward 41 | self.ret_rms.update(self.returns.detach().cpu().numpy()) 42 | 43 | return reward / np.sqrt(self.ret_rms.var + 1e-8) 44 | 45 | 46 | def initialize_weights(mod, initialization_type, scale=np.sqrt(2)): 47 | ''' 48 | Weight initializer for the models. 49 | Inputs: A model, Returns: none, initializes the parameters 50 | ''' 51 | for p in mod.parameters(): 52 | if initialization_type == "normal": 53 | p.data.normal_(0.01) 54 | elif initialization_type == "xavier": 55 | if len(p.data.shape) >= 2: 56 | torch.nn.init.xavier_uniform_(p.data) 57 | else: 58 | p.data.zero_() 59 | elif initialization_type == "orthogonal": 60 | if len(p.data.shape) >= 2: 61 | torch.nn.init.orthogonal_(p.data, gain=scale) 62 | else: 63 | p.data.zero_() 64 | else: 65 | raise ValueError("Need a valid initialization key") 66 | 67 | class Critic(torch.nn.Module): 68 | def __init__(self, s_dim, emb_dim=64): 69 | super().__init__() 70 | 71 | self.fc = torch.nn.Sequential( 72 | torch.nn.Linear(np.prod(s_dim), 64), 73 | torch.nn.ReLU(), 74 | torch.nn.Linear(64, 64), 75 | torch.nn.ReLU(), 76 | ) 77 | self.v = torch.nn.Sequential( 78 | torch.nn.Linear(64, 1), 79 | ) 80 | 81 | initialize_weights(self.fc, "orthogonal") 82 | initialize_weights(self.v, "orthogonal", scale=1) 83 | 84 | def forward(self, s_emb): 85 | x = self.fc(s_emb) 86 | v = self.v(x) 87 | return v 88 | 89 | 90 | class Actor(torch.nn.Module): 91 | def __init__(self, s_dim, a_dim, a_max=1, emb_dim=64): 92 | super().__init__() 93 | self.act_dim = a_dim 94 | self.a_max = a_max 95 | self.fc = torch.nn.Sequential( 96 | torch.nn.Linear(np.prod(s_dim), 64), 97 | torch.nn.ReLU(), 98 | torch.nn.Linear(64, 64), 99 | torch.nn.ReLU(), 100 | ) 101 | 102 | self.mu = torch.nn.Sequential( 103 | torch.nn.Linear(64, a_dim), 104 | torch.nn.Tanh() 105 | ) 106 | self.var = torch.nn.Parameter(torch.zeros(1, a_dim)) 107 | 108 | initialize_weights(self.fc, "orthogonal") 109 | initialize_weights(self.mu, "orthogonal", scale=0.01) 110 | 111 | def forward(self, s): 112 | emb = self.fc(s) 113 | mu = self.mu(emb)*self.a_max 114 | 115 | return mu, torch.exp(self.var) 116 | 117 | def select_action(self, s): 118 | with torch.no_grad(): 119 | mean, std = self(s) 120 | 121 | normal = Normal(mean, std) 122 | action = normal.sample() # [1, a_dim] 123 | action = torch.squeeze(action, dim=0) # [a_dim] 124 | 125 | return action 126 | 127 | def log_pi(self, s, a): 128 | mean, std = self(s) 129 | 130 | normal = Normal(mean, std) 131 | logpi = normal.log_prob(a) 132 | logpi = torch.sum(logpi, dim=1) 133 | entropy = normal.entropy().sum(dim=1) 134 | 135 | return logpi, entropy # [None,] 136 | 137 | class ActorDisc(torch.nn.Module): 138 | def __init__(self, s_dim, a_num, emb_dim=64): 139 | super().__init__() 140 | self.act_num = a_num 141 | self.fc = torch.nn.Sequential( 142 | torch.nn.Linear(np.prod(s_dim), 64), 143 | torch.nn.ReLU(), 144 | torch.nn.Linear(64, 64), 145 | torch.nn.ReLU(), 146 | ) 147 | 148 | self.final = torch.nn.Sequential( 149 | torch.nn.Linear(emb_dim, a_num), 150 | ) 151 | 152 | initialize_weights(self.fc, "orthogonal") 153 | initialize_weights(self.final, "orthogonal", 0.01) 154 | 155 | def forward(self, s): 156 | emb = self.fc(s) 157 | x = self.final(emb) 158 | 159 | return x 160 | 161 | def select_action(self, s): 162 | with torch.no_grad(): 163 
| x = self(s) 164 | normal = torch.distributions.Categorical(logits=x) 165 | action = normal.sample() 166 | action = torch.squeeze(action, dim=0) 167 | 168 | return action 169 | 170 | def log_pi(self, s, a): 171 | x = self(s) 172 | 173 | normal = torch.distributions.Categorical(logits=x) 174 | logpi = normal.log_prob(a) 175 | 176 | return logpi # [None,] 177 | class PPO(torch.nn.Module): 178 | def __init__(self, state_dim, act_dim, act_max, epsilon, device, lr_a=0.001, 179 | c_en=0.01, c_vf=0.5, max_grad_norm=-1, anneal_lr=False, train_steps=1000): 180 | super().__init__() 181 | if type(act_dim) == np.int64 or type(act_dim) == np.int: 182 | self.actor = ActorDisc(state_dim, act_dim).to(device) 183 | self.old_actor = ActorDisc(state_dim, act_dim).to(device) 184 | else: 185 | self.actor = Actor(state_dim, act_dim[0], act_max).to(device) 186 | self.old_actor = Actor(state_dim, act_dim[0], act_max).to(device) 187 | self.critic = Critic(state_dim).to(device) 188 | self.epsilon = epsilon 189 | self.c_en = c_en 190 | self.c_vf = c_vf 191 | 192 | self.max_grad_norm = max_grad_norm 193 | self.anneal_lr = anneal_lr 194 | 195 | self.opti = torch.optim.Adam(list(self.actor.parameters()) + list(self.critic.parameters()), lr=lr_a, eps=1e-5) 196 | 197 | if anneal_lr: 198 | lam = lambda f: 1 - f / train_steps 199 | self.opti_scheduler = torch.optim.lr_scheduler.LambdaLR(self.opti, lr_lambda=lam) 200 | 201 | 202 | def train_ac(self, s, a, adv, vs, oldv, is_clip_v=True): 203 | self.opti.zero_grad() 204 | logpi, entropy = self.actor.log_pi(s, a) 205 | old_logpi, _ = self.old_actor.log_pi(s, a) 206 | 207 | ratio = torch.exp(logpi - old_logpi) 208 | surr = ratio * adv 209 | clip_adv = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * adv 210 | aloss = -torch.mean(torch.min(surr, clip_adv)) 211 | # loss_entropy = torch.mean(torch.exp(logpi) * logpi) 212 | loss_entropy = entropy.mean() 213 | kl = torch.mean(old_logpi - logpi) 214 | 215 | v = self.critic(s) 216 | v = torch.squeeze(v, 1) 217 | 218 | if not is_clip_v: 219 | v_loss = ((v - vs) ** 2).mean() 220 | else: 221 | clip_v = oldv + torch.clamp(v - oldv, -self.epsilon, self.epsilon) 222 | v_max = torch.max(((v - vs) ** 2), ((clip_v - vs) ** 2)) 223 | v_loss = v_max.mean() 224 | 225 | loss = aloss - loss_entropy*self.c_en + v_loss*self.c_vf 226 | loss.backward() 227 | if self.max_grad_norm != -1: 228 | torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm) 229 | torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm) 230 | self.opti.step() 231 | 232 | info = dict(vloss=v_loss.item(), aloss=aloss.item(), entropy=loss_entropy.item(), kl=kl.item(), loss=loss.item()) 233 | return info 234 | 235 | def lr_scheduler(self): 236 | if self.anneal_lr: 237 | self.opti_scheduler.step() 238 | 239 | def getV(self, s): 240 | with torch.no_grad(): 241 | v = self.critic(s) # [1,1] 242 | 243 | return torch.squeeze(v) # 244 | 245 | def update_a(self): 246 | self.old_actor.load_state_dict(self.actor.state_dict()) -------------------------------------------------------------------------------- /algos/ppo/generate_expert_dmc.py: -------------------------------------------------------------------------------- 1 | import algos.ppo.core_cnn_torch as core 2 | 3 | import numpy as np 4 | import gym 5 | import argparse 6 | import scipy 7 | from scipy import signal 8 | 9 | import os 10 | from utils.logx import EpochLogger 11 | import torch 12 | import dmc2gym 13 | from collections import deque 14 | import pickle 15 | from env.dmc_env import 
DMCFrameStack 16 | from utils.normalization import * 17 | 18 | 19 | if __name__ == '__main__': 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--domain_name', default='cheetah') 22 | parser.add_argument('--task_name', default='run') 23 | parser.add_argument('--image_size', default=84, type=int) 24 | parser.add_argument('--action_repeat', default=1, type=int) 25 | parser.add_argument('--frame_stack', default=3, type=int) 26 | parser.add_argument('--encoder_type', default='pixel', type=str) 27 | 28 | parser.add_argument('--exp_name', default="ppo_cheetah_run_clipv_maxgrad_anneallr2.5e-3_stack3_normal_state01_maxkl0.03_gae") 29 | parser.add_argument('--seed', default=10, type=int) 30 | parser.add_argument('--norm_state', default=False) 31 | parser.add_argument('--norm_rewards', default=False) 32 | parser.add_argument('--expert_num', default=10, type=int) 33 | parser.add_argument('--check_num', default=900, type=int) 34 | args = parser.parse_args() 35 | 36 | # env = gym.make("Hopper-v2") 37 | env = dmc2gym.make( 38 | domain_name=args.domain_name, 39 | task_name=args.task_name, 40 | seed=args.seed, 41 | visualize_reward=False, 42 | from_pixels=(args.encoder_type == 'pixel'), 43 | height=args.image_size, 44 | width=args.image_size, 45 | frame_skip=args.action_repeat 46 | ) 47 | if args.encoder_type == 'pixel': 48 | env = DMCFrameStack(env, k=args.frame_stack) 49 | state_dim = env.observation_space.shape 50 | act_dim = env.action_space.shape[0] 51 | device = torch.device("cuda:" + str(0) if torch.cuda.is_available() else "cpu") 52 | 53 | from utils.run_utils import setup_logger_kwargs 54 | logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed) 55 | 56 | actor = core.Actor(state_dim, act_dim).to(device) 57 | checkpoint = torch.load(os.path.join(logger_kwargs["output_dir"], "checkpoints", str(args.check_num) + '.pth')) 58 | actor.load_state_dict(checkpoint["actor"]) 59 | 60 | state_norm = Identity() 61 | state_norm = ImageProcess(state_norm) 62 | reward_norm = Identity() 63 | file = os.path.join(logger_kwargs["output_dir"], "checkpoints", str(args.check_num) + '.pkl') 64 | with open(file, "rb") as f: 65 | if args.norm_state: 66 | state_norm = pickle.load(f)["state"] 67 | if args.norm_rewards: 68 | reward_norm = pickle.load(f)["reward"] 69 | 70 | expert_data_file = os.path.join(logger_kwargs["output_dir"], "experts") 71 | if not os.path.exists(expert_data_file): 72 | os.mkdir(expert_data_file) 73 | expert_data = {"obs":[], "action":[]} 74 | 75 | obs = env.reset() 76 | state = state_norm(obs, update=False) 77 | rew = 0 78 | rew_list = [] 79 | epi = 0 80 | while epi <= args.expert_num: 81 | # env.render() 82 | expert_data["obs"].append(obs) 83 | state_tensor = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0) 84 | # a, var = actor(state_tensor) 85 | a = actor.select_action(state_tensor) 86 | # pi = actor.log_pi(state_tensor, a) 87 | a = torch.squeeze(a, 0).detach().cpu().numpy() 88 | a = np.clip(a, -1, 1) 89 | expert_data["action"].append(a) 90 | obs, r, d, _ = env.step(a) 91 | 92 | rew += r 93 | if d: 94 | rew_list.append(rew) 95 | epi += 1 96 | print("reward", rew) 97 | 98 | # if epi % 10 == 0: 99 | # print("teset_", np.mean(rew_list)) 100 | # rew_list = [] 101 | obs = env.reset() 102 | rew = 0 103 | 104 | state = state_norm(obs, update=False) 105 | 106 | expert_data["obs"] = np.array(expert_data["obs"]) 107 | expert_data["action"] = np.array(expert_data["action"]) 108 | 109 | with open(os.path.join(expert_data_file, 110 | args.domain_name + "_" + 
args.task_name + "_epoch" + str(args.expert_num) + ".pkl"), "wb") as f: 111 | pickle.dump(expert_data, f) 112 | 113 | 114 | -------------------------------------------------------------------------------- /algos/ppo/plot_seed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import seaborn as sns 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | from user_config import DEFAULT_IMG_DIR, DEFAULT_DATA_DIR 7 | import argparse 8 | import glob 9 | 10 | def smooth(data, sm=1, value="Averagetest_reward"): 11 | if sm > 1: 12 | smooth_data = [] 13 | for d in data: 14 | x = np.asarray(d[value]) 15 | y = np.ones(sm)*1.0/sm 16 | d[value] = np.convolve(y, x, "same") 17 | 18 | smooth_data.append(d) 19 | 20 | return smooth_data 21 | return data 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--plot_name', default=None) 26 | parser.add_argument('--seed', default='10', type=int) 27 | parser.add_argument('--smooth', default='2', type=int) 28 | parser.add_argument('--output_name', default=None, type=str) 29 | parser.add_argument('--x', default="Epoch", type=str) 30 | parser.add_argument('--y', default="Averagetest_reward", type=str) 31 | args = parser.parse_args() 32 | 33 | tasks = ["hopper", "halfcheetah", "walker2d"] 34 | for task in tasks: 35 | plt.cla() 36 | data = [] 37 | for kl in [0.03, 0.05, 0.07]: 38 | for seed in [20, 30, 40]: 39 | taskname = "ppo_kl_" + task + "_clipv_maxgrad_anneallr3e-4_normal_maxkl" + str(kl) \ 40 | + "_gae_norm-state-return_steps2048_batch64_notdone_lastv_4_entropy_update10" 41 | file_dir = os.path.join(DEFAULT_DATA_DIR, taskname) 42 | file_seed = os.path.join(file_dir, taskname+"_s" + str(seed), "progress.txt") 43 | pd_data = pd.read_table(file_seed) 44 | pd_data["KL"] = "max_kl" + str(kl) 45 | 46 | data.append(pd_data) 47 | 48 | smooth(data, sm=args.smooth) 49 | data = pd.concat(data, ignore_index=True) 50 | 51 | sns.set(style="darkgrid", font_scale=1.5) 52 | sns.lineplot(data=data, x=args.x, y=args.y, hue="KL") 53 | 54 | output_name = "ppo_" + task + "_smooth" 55 | out_file = os.path.join(DEFAULT_IMG_DIR, output_name + ".png") 56 | plt.legend(loc='best').set_draggable(True) 57 | plt.tight_layout(pad=0.5) 58 | plt.savefig(out_file) 59 | 60 | 61 | -------------------------------------------------------------------------------- /algos/ppo/run_ppo_atari_torch.py: -------------------------------------------------------------------------------- 1 | try: 2 | import algos.ppo.core_cnn_torch as core 3 | except Exception: 4 | import core_cnn_torch as core 5 | 6 | import numpy as np 7 | import gym 8 | import argparse 9 | import scipy 10 | from scipy import signal 11 | import pickle 12 | 13 | import os 14 | from utils.logx import EpochLogger 15 | import torch 16 | from torch.utils.tensorboard import SummaryWriter 17 | 18 | from env.atari_lib import make_atari, wrap_deepmind 19 | from utils.normalization import * 20 | import json 21 | from algos.ppo.utils import discount_path, get_path_indices 22 | 23 | class ReplayBuffer: 24 | def __init__(self, size, state_dim, act_dim, gamma=0.99, lam=0.95, is_gae=True): 25 | self.size = size 26 | self.state_dim = state_dim 27 | self.act_dim = act_dim 28 | self.gamma = gamma 29 | self.lam = lam 30 | self.is_gae = is_gae 31 | self.reset() 32 | 33 | def reset(self): 34 | self.state = np.zeros((self.size,) + self.state_dim, np.float32) 35 | if type(self.act_dim) == np.int64 or type(self.act_dim) == np.int: 36 | 
self.action = np.zeros((self.size, ), np.int32) 37 | else: 38 | self.action = np.zeros((self.size,) + self.act_dim, np.float32) 39 | self.v = np.zeros((self.size, ), np.float32) 40 | self.reward = np.zeros((self.size, ), np.float32) 41 | self.adv = np.zeros((self.size, ), np.float32) 42 | self.mask = np.zeros((self.size, ), np.float32) 43 | self.ptr, self.path_start = 0, 0 44 | 45 | def add(self, s, a, r, mask): 46 | if self.ptr < self.size: 47 | self.state[self.ptr] = s 48 | self.action[self.ptr] = a 49 | self.reward[self.ptr] = r 50 | self.mask[self.ptr] = mask 51 | self.ptr += 1 52 | 53 | def update_v(self, v, pos): 54 | self.v[pos] = v 55 | 56 | def finish_path(self, last_v): 57 | """ 58 | Calculate GAE advantage, discounted returns, and 59 | true reward (average reward per trajectory) 60 | 61 | GAE: delta_t^V = r_t + discount * V(s_{t+1}) - V(s_t) 62 | using formula from John Schulman's code: 63 | V(s_t+1) = {0 if s_t is terminal 64 | {v_s_{t+1} if s_t not terminal and t != T (last step) 65 | {v_s if s_t not terminal and t == T 66 | """ 67 | v_ = np.concatenate([self.v[1:], [last_v]], axis=0) * self.mask 68 | adv = self.reward + self.gamma * v_ - self.v 69 | 70 | indices = get_path_indices(self.mask) 71 | 72 | for (start, end) in indices: 73 | self.adv[start:end] = discount_path(adv[start:end], self.gamma * self.lam) 74 | if not self.is_gae: 75 | self.reward[start:end] = discount_path(self.reward[start:end], self.gamma) 76 | if self.is_gae: 77 | self.reward = self.adv + self.v 78 | 79 | self.adv = (self.adv - np.mean(self.adv))/(np.std(self.adv) + 1e-8) 80 | 81 | def get_batch(self, batch=100, shuffle=True): 82 | if shuffle: 83 | indices = np.random.permutation(self.size) 84 | else: 85 | indices = np.arange(self.size) 86 | 87 | for idx in np.arange(0, self.size, batch): 88 | pos = indices[idx:(idx + batch)] 89 | yield (self.state[pos], self.action[pos], self.reward[pos], self.adv[pos], self.v[pos]) 90 | 91 | class ImageToPyTorch(gym.ObservationWrapper): 92 | """ 93 | Image shape to channels x weight x height 94 | """ 95 | 96 | def __init__(self, env): 97 | super(ImageToPyTorch, self).__init__(env) 98 | old_shape = self.observation_space.shape 99 | self.observation_space = gym.spaces.Box( 100 | low=0, 101 | high=255, 102 | shape=(old_shape[-1], old_shape[0], old_shape[1]), 103 | dtype=np.float32, 104 | ) 105 | 106 | def observation(self, observation): 107 | obs = np.array(observation).astype(np.float32) / 255.0 108 | return np.transpose(obs, axes=(2, 0, 1)) 109 | 110 | 111 | if __name__ == '__main__': 112 | 113 | parser = argparse.ArgumentParser() 114 | parser.add_argument('--iteration', default=int(1e3), type=int) 115 | parser.add_argument('--gamma', default=0.99, type=float) 116 | parser.add_argument('--lam', default=0.95, type=float) 117 | parser.add_argument('--a_update', default=10, type=int) 118 | parser.add_argument('--lr', default=2.5e-4, type=float) 119 | parser.add_argument('--log', type=str, default="logs") 120 | parser.add_argument('--steps', default=3000, type=int) 121 | parser.add_argument('--gpu', default=0, type=int) 122 | parser.add_argument('--env', default="BreakoutNoFrameskip-v4") 123 | parser.add_argument('--env_num', default=4, type=int) 124 | parser.add_argument('--exp_name', default="ppo_Pong") 125 | parser.add_argument('--seed', default=0, type=int) 126 | parser.add_argument('--batch', default=50, type=int) 127 | parser.add_argument('--norm_state', action="store_true") 128 | parser.add_argument('--norm_rewards', default=False) 129 | 
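    # --norm_rewards is a string switch handled further down: "rewards" selects
    # AutoNormalization, "returns" selects RewardFilter, and any other value
    # (including the default False) leaves rewards unfiltered (Identity).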
parser.add_argument('--clip_coef', type=float, default=0.2) 130 | parser.add_argument('--is_clip_v', action="store_true") 131 | parser.add_argument('--is_gae', action="store_true") 132 | parser.add_argument('--last_v', action="store_true") 133 | parser.add_argument('--max_grad_norm', default=-1, type=float) 134 | parser.add_argument('--anneal_lr', action="store_true") 135 | parser.add_argument('--debug', action="store_false") 136 | parser.add_argument('--log_every', default=10, type=int) 137 | parser.add_argument('--target_kl', default=0.03, type=float) 138 | parser.add_argument('--test_epoch', default=10, type=int) 139 | args = parser.parse_args() 140 | 141 | device = torch.device("cuda:"+str(args.gpu) if torch.cuda.is_available() else "cpu") 142 | 143 | from utils.run_utils import setup_logger_kwargs 144 | logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed) 145 | logger = EpochLogger(**logger_kwargs) 146 | writer = SummaryWriter(os.path.join(logger.output_dir, "logs")) 147 | with open(os.path.join(logger.output_dir, 'args.json'), 'w') as f: 148 | json.dump(vars(args), f, sort_keys=True, indent=4) 149 | 150 | env = make_atari(args.env) 151 | env = gym.wrappers.RecordEpisodeStatistics(env) 152 | env = wrap_deepmind(env, frame_stack=True) 153 | env = ImageToPyTorch(env) 154 | # test_env = make_atari(args.env) 155 | # test_env = gym.wrappers.RecordEpisodeStatistics(test_env) 156 | # test_env = wrap_deepmind(test_env, frame_stack=True) 157 | # test_env = ImageToPyTorch(test_env) 158 | torch.manual_seed(args.seed) 159 | np.random.seed(args.seed) 160 | env.seed(args.seed) 161 | 162 | state_dim = env.observation_space.shape 163 | act_dim = env.action_space.n 164 | ppo = core.PPO(state_dim, act_dim, 1, args.clip_coef, device, lr_a=args.lr, 165 | max_grad_norm=args.max_grad_norm, 166 | anneal_lr=args.anneal_lr, train_steps=args.iteration) 167 | replay = ReplayBuffer(args.steps, state_dim, act_dim, is_gae=args.is_gae) 168 | 169 | state_norm = Identity() 170 | reward_norm = Identity() 171 | if args.norm_state: 172 | state_norm = AutoNormalization(state_norm, state_dim, clip=10.0) 173 | if args.norm_rewards == "rewards": 174 | reward_norm = AutoNormalization(reward_norm, (), clip=10.0) 175 | elif args.norm_rewards == "returns": 176 | reward_norm = RewardFilter(reward_norm, (), clip=10.0) 177 | 178 | obs = env.reset() 179 | obs = state_norm(obs) 180 | rew = 0 181 | for iter in range(args.iteration): 182 | ppo.train() 183 | replay.reset() 184 | flag = 0 185 | 186 | for step in range(args.steps): 187 | state_tensor = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0) 188 | a_tensor = ppo.actor.select_action(state_tensor) 189 | a = a_tensor.detach().cpu().numpy() 190 | obs_, r, done, info = env.step(a) 191 | rew += r 192 | r = reward_norm(r) 193 | mask = 1-done 194 | 195 | replay.add(obs, a, r, mask) 196 | 197 | obs = obs_ 198 | if done: 199 | if 'episode' in info.keys(): 200 | logger.store(reward=info['episode']['r']) 201 | flag = 1 202 | rew = 0 203 | obs = env.reset() 204 | obs = state_norm(obs) 205 | if flag == 0: 206 | logger.store(reward=0) 207 | print("null reward") 208 | 209 | state = replay.state 210 | for idx in np.arange(0, state.shape[0], args.batch): 211 | if idx + args.batch <= state.shape[0]: 212 | pos = np.arange(idx, idx + args.batch) 213 | else: 214 | pos = np.arange(idx, state.shape[0]) 215 | s = torch.tensor(state[pos], dtype=torch.float32).to(device) 216 | v = ppo.getV(s).detach().cpu().numpy() 217 | replay.update_v(v, pos) 218 | s_tensor = torch.tensor(obs, 
dtype=torch.float32).to(device).unsqueeze(0) 219 | last_v = ppo.getV(s_tensor).detach().cpu().numpy() 220 | replay.finish_path(last_v) 221 | 222 | ppo.update_a() 223 | 224 | for i in range(args.a_update): 225 | for (s, a, r, adv, v) in replay.get_batch(batch=args.batch): 226 | s_tensor = torch.tensor(s, dtype=torch.float32, device=device) 227 | a_tensor = torch.tensor(a, dtype=torch.float32, device=device) 228 | adv_tensor = torch.tensor(adv, dtype=torch.float32, device=device) 229 | r_tensor = torch.tensor(r, dtype=torch.float32, device=device) 230 | v_tensor = torch.tensor(v, dtype=torch.float32, device=device) 231 | 232 | info = ppo.train_ac(s_tensor, a_tensor, adv_tensor, r_tensor, v_tensor, is_clip_v=args.is_clip_v) 233 | 234 | if args.debug: 235 | logger.store(aloss=info["aloss"]) 236 | logger.store(vloss=info["vloss"]) 237 | logger.store(entropy=info["entropy"]) 238 | logger.store(kl=info["kl"]) 239 | 240 | if logger.get_stats("kl")[0] > args.target_kl: 241 | print("stop at:", str(i)) 242 | break 243 | 244 | if args.anneal_lr: 245 | ppo.lr_scheduler() 246 | 247 | 248 | # writer.add_scalar("test_reward", logger.get_stats("test_reward")[0], global_step=iter) 249 | writer.add_scalar("reward", logger.get_stats("reward")[0], global_step=iter) 250 | writer.add_histogram("action", np.array(replay.action), global_step=iter) 251 | if args.debug: 252 | writer.add_scalar("aloss", logger.get_stats("aloss")[0], global_step=iter) 253 | writer.add_scalar("vloss", logger.get_stats("vloss")[0], global_step=iter) 254 | writer.add_scalar("entropy", logger.get_stats("entropy")[0], global_step=iter) 255 | writer.add_scalar("kl", logger.get_stats("kl")[0], global_step=iter) 256 | 257 | logger.log_tabular('Epoch', iter) 258 | logger.log_tabular("reward", with_min_and_max=True) 259 | # logger.log_tabular("test_reward", with_min_and_max=True) 260 | if args.debug: 261 | logger.log_tabular("aloss", with_min_and_max=True) 262 | logger.log_tabular("vloss", with_min_and_max=True) 263 | logger.log_tabular("entropy", with_min_and_max=True) 264 | logger.log_tabular("kl", with_min_and_max=True) 265 | logger.dump_tabular() 266 | 267 | if not os.path.exists(os.path.join(logger.output_dir, "checkpoints")): 268 | os.makedirs(os.path.join(logger.output_dir, "checkpoints")) 269 | if iter % args.log_every == 0: 270 | state = { 271 | "actor": ppo.actor.state_dict(), 272 | "critic": ppo.critic.state_dict(), 273 | 274 | } 275 | torch.save(state, os.path.join(logger.output_dir, "checkpoints", str(iter) + '.pth')) 276 | norm = {"state": state_norm, "reward": reward_norm} 277 | with open(os.path.join(logger.output_dir, "checkpoints", str(iter) + '.pkl'), "wb") as f: 278 | pickle.dump(norm, f) 279 | 280 | 281 | -------------------------------------------------------------------------------- /algos/ppo/run_ppo_dmc_torch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | import argparse 4 | import scipy 5 | from scipy import signal 6 | import pickle 7 | from collections import deque 8 | import json 9 | 10 | import os 11 | from utils.logx import EpochLogger 12 | import torch 13 | from torch.utils.tensorboard import SummaryWriter 14 | import dmc2gym 15 | import env.atari_lib as atari 16 | from env.dmc_env import DMCFrameStack 17 | from utils.normalization import * 18 | from algos.ppo.utils import discount_path, get_path_indices 19 | 20 | class ReplayBuffer: 21 | def __init__(self, size, state_dim, act_dim, gamma=0.99, lam=0.95, is_gae=True): 22 | self.size = 
size 23 | self.state_dim = state_dim 24 | self.act_dim = act_dim 25 | self.gamma = gamma 26 | self.lam = lam 27 | self.is_gae = is_gae 28 | self.reset() 29 | 30 | def reset(self): 31 | self.state = np.zeros((self.size,) + self.state_dim, np.float32) 32 | if type(self.act_dim) == np.int64 or type(self.act_dim) == np.int: 33 | self.action = np.zeros((self.size, ), np.int32) 34 | else: 35 | self.action = np.zeros((self.size,) + self.act_dim, np.float32) 36 | self.v = np.zeros((self.size, ), np.float32) 37 | self.reward = np.zeros((self.size, ), np.float32) 38 | self.adv = np.zeros((self.size, ), np.float32) 39 | self.mask = np.zeros((self.size, ), np.float32) 40 | self.ptr, self.path_start = 0, 0 41 | 42 | def add(self, s, a, r, mask): 43 | if self.ptr < self.size: 44 | self.state[self.ptr] = s 45 | self.action[self.ptr] = a 46 | self.reward[self.ptr] = r 47 | self.mask[self.ptr] = mask 48 | self.ptr += 1 49 | 50 | def update_v(self, v, pos): 51 | self.v[pos] = v 52 | 53 | def finish_path(self, last_v=None): 54 | """ 55 | Calculate GAE advantage, discounted returns, and 56 | true reward (average reward per trajectory) 57 | 58 | GAE: delta_t^V = r_t + discount * V(s_{t+1}) - V(s_t) 59 | using formula from John Schulman's code: 60 | V(s_t+1) = {0 if s_t is terminal 61 | {v_s_{t+1} if s_t not terminal and t != T (last step) 62 | {v_s if s_t not terminal and t == T 63 | """ 64 | if last_v is None: 65 | v_ = np.concatenate([self.v[1:], self.v[-1:]], axis=0) * self.mask 66 | else: 67 | v_ = np.concatenate([self.v[1:], [last_v]], axis=0) * self.mask 68 | adv = self.reward + self.gamma * v_ - self.v 69 | 70 | indices = get_path_indices(self.mask) 71 | 72 | for (start, end) in indices: 73 | self.adv[start:end] = discount_path(adv[start:end], self.gamma * self.lam) 74 | if not self.is_gae: 75 | self.reward[start:end] = discount_path(self.reward[start:end], self.gamma) 76 | if self.is_gae: 77 | self.reward = self.adv + self.v 78 | 79 | self.adv = (self.adv - np.mean(self.adv))/(np.std(self.adv) + 1e-8) 80 | 81 | def get_batch(self, batch=100, shuffle=True): 82 | if shuffle: 83 | indices = np.random.permutation(self.size) 84 | else: 85 | indices = np.arange(self.size) 86 | 87 | for idx in np.arange(0, self.size, batch): 88 | pos = indices[idx:(idx + batch)] 89 | yield (self.state[pos], self.action[pos], self.reward[pos], self.adv[pos], self.v[pos]) 90 | 91 | 92 | if __name__ == '__main__': 93 | 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument('--domain_name', default='cheetah') 96 | parser.add_argument('--task_name', default='run') 97 | parser.add_argument('--image_size', default=84, type=int) 98 | parser.add_argument('--action_repeat', default=1, type=int) 99 | parser.add_argument('--frame_stack', default=4, type=int) 100 | parser.add_argument('--encoder_type', default='pixel', type=str) 101 | 102 | parser.add_argument('--iteration', default=int(1e3), type=int) 103 | parser.add_argument('--gamma', default=0.99, type=float) 104 | parser.add_argument('--lam', default=0.95, type=float) 105 | parser.add_argument('--c_en', default=0.01, type=float) 106 | parser.add_argument('--c_vf', default=0.5, type=float) 107 | parser.add_argument('--a_update', default=10, type=int) 108 | parser.add_argument('--lr', default=3e-4, type=float) 109 | parser.add_argument('--log', type=str, default="logs") 110 | parser.add_argument('--steps', default=3000, type=int) 111 | parser.add_argument('--gpu', default=0, type=int) 112 | parser.add_argument('--env_num', default=4, type=int) 113 | 
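# Illustration (a minimal, self-contained sketch, not called anywhere in this script): what
# ReplayBuffer.finish_path above computes for a single path. The helper name _gae_sketch and
# every number in it are assumptions chosen only for demonstration.
def _gae_sketch():
    import numpy as np
    r = np.array([1.0, 1.0, 1.0], dtype=np.float32)      # rewards of one 3-step path
    v = np.array([0.5, 0.6, 0.7], dtype=np.float32)      # critic values V(s_t)
    mask = np.array([1.0, 1.0, 0.0], dtype=np.float32)   # 0 marks the terminal step
    gamma, lam = 0.99, 0.95
    v_next = np.concatenate([v[1:], v[-1:]]) * mask      # bootstrap value, zeroed at the terminal
    delta = r + gamma * v_next - v                       # delta_t = r_t + gamma*V(s_{t+1}) - V(s_t)
    adv = np.zeros_like(delta)
    running = 0.0
    for t in reversed(range(len(delta))):                # same backward recursion as discount_path
        running = delta[t] + gamma * lam * running
        adv[t] = running
    returns = adv + v                                    # value targets when is_gae=True
    return adv, returns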
parser.add_argument('--exp_name', default="ppo_cheetah_run") 114 | parser.add_argument('--seed', default=0, type=int) 115 | parser.add_argument('--batch', default=50, type=int) 116 | parser.add_argument('--norm_state', action="store_true") 117 | parser.add_argument('--norm_rewards', default=False, type=str) 118 | parser.add_argument('--is_clip_v', action="store_true") 119 | parser.add_argument('--last_v', action="store_true") 120 | parser.add_argument('--is_gae', action="store_true") 121 | parser.add_argument('--max_grad_norm', default=-1, type=float) 122 | parser.add_argument('--anneal_lr', action="store_true") 123 | parser.add_argument('--debug', action="store_false") 124 | parser.add_argument('--log_every', default=10, type=int) 125 | parser.add_argument('--network', default="cnn") 126 | parser.add_argument('--target_kl', default=0.03, type=float) 127 | parser.add_argument('--test_epoch', default=10, type=int) 128 | args = parser.parse_args() 129 | 130 | device = torch.device("cuda:"+str(args.gpu) if torch.cuda.is_available() else "cpu") 131 | 132 | from utils.run_utils import setup_logger_kwargs 133 | logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed) 134 | logger = EpochLogger(**logger_kwargs) 135 | with open(os.path.join(logger.output_dir, 'args.json'), 'w') as f: 136 | json.dump(vars(args), f, sort_keys=True, indent=4) 137 | writer = SummaryWriter(os.path.join(logger.output_dir, "logs")) 138 | 139 | env = dmc2gym.make( 140 | domain_name=args.domain_name, 141 | task_name=args.task_name, 142 | seed=args.seed, 143 | visualize_reward=False, 144 | from_pixels=(args.encoder_type == 'pixel'), 145 | height=args.image_size, 146 | width=args.image_size, 147 | frame_skip=args.action_repeat 148 | ) 149 | test_env = dmc2gym.make( 150 | domain_name=args.domain_name, 151 | task_name=args.task_name, 152 | seed=args.seed, 153 | visualize_reward=False, 154 | from_pixels=(args.encoder_type == 'pixel'), 155 | height=args.image_size, 156 | width=args.image_size, 157 | frame_skip=args.action_repeat 158 | ) 159 | if args.encoder_type == 'pixel': 160 | env = DMCFrameStack(env, k=args.frame_stack) 161 | test_env = DMCFrameStack(test_env, k=args.frame_stack) 162 | torch.manual_seed(args.seed) 163 | np.random.seed(args.seed) 164 | env.seed(args.seed) 165 | test_env.seed(args.seed) 166 | 167 | state_dim = env.observation_space.shape 168 | act_dim = env.action_space.shape 169 | action_max = env.action_space.high[0] 170 | if args.network == "cnn": 171 | import algos.ppo.core_cnn_torch as core 172 | elif args.network == "resnet": 173 | import algos.ppo.core_resnet_simple_torch as core 174 | elif args.network == "mlp" or args.encoder_type == 'state': 175 | import algos.ppo.core_torch as core 176 | ppo = core.PPO(state_dim, act_dim, action_max, 0.2, device, lr_a=args.lr, 177 | max_grad_norm=args.max_grad_norm, 178 | anneal_lr=args.anneal_lr, train_steps=args.iteration, c_en=args.c_en, c_vf=args.c_vf) 179 | replay = ReplayBuffer(args.steps, state_dim, act_dim, is_gae=args.is_gae) 180 | 181 | state_norm = Identity() 182 | if args.encoder_type == 'pixel': 183 | state_norm = ImageProcess(state_norm) 184 | reward_norm = Identity() 185 | if args.norm_state: 186 | state_norm = AutoNormalization(state_norm, state_dim, clip=10.0) 187 | if args.norm_rewards == "rewards": 188 | reward_norm = AutoNormalization(reward_norm, (), clip=10.0) 189 | elif args.norm_rewards == "returns": 190 | reward_norm = RewardFilter(reward_norm, (), clip=10.0) 191 | 192 | state_norm.reset() 193 | reward_norm.reset() 194 | obs = 
env.reset() 195 | obs = state_norm(obs) 196 | rew = 0 197 | for iter in range(args.iteration): 198 | ppo.train() 199 | replay.reset() 200 | 201 | for step in range(args.steps): 202 | state_tensor = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0) 203 | a_tensor = ppo.actor.select_action(state_tensor) 204 | a = a_tensor.detach().cpu().numpy() 205 | obs_, r, done, _ = env.step(np.clip(a, -1, 1)) 206 | rew += r 207 | r = reward_norm(r) 208 | mask = 1-done 209 | 210 | replay.add(obs, a, r, mask) 211 | 212 | obs = obs_ 213 | if done: 214 | logger.store(reward=rew) 215 | rew = 0 216 | obs = env.reset() 217 | state_norm.reset() 218 | reward_norm.reset() 219 | obs = state_norm(obs) 220 | 221 | state = replay.state 222 | for idx in np.arange(0, state.shape[0], args.batch): 223 | if idx + args.batch <= state.shape[0]: 224 | pos = np.arange(idx, idx + args.batch) 225 | else: 226 | pos = np.arange(idx, state.shape[0]) 227 | s = torch.tensor(state[pos], dtype=torch.float32).to(device) 228 | v = ppo.getV(s).detach().cpu().numpy() 229 | replay.update_v(v, pos) 230 | if args.last_v: 231 | s_tensor = torch.tensor(obs, dtype=torch.float32).to(device).unsqueeze(0) 232 | last_v = ppo.getV(s_tensor).detach().cpu().numpy() 233 | replay.finish_path(last_v=last_v) 234 | else: 235 | replay.finish_path() 236 | 237 | ppo.update_a() 238 | for i in range(args.a_update): 239 | for (s, a, r, adv, v) in replay.get_batch(batch=args.batch): 240 | s_tensor = torch.tensor(s, dtype=torch.float32, device=device) 241 | a_tensor = torch.tensor(a, dtype=torch.float32, device=device) 242 | adv_tensor = torch.tensor(adv, dtype=torch.float32, device=device) 243 | r_tensor = torch.tensor(r, dtype=torch.float32, device=device) 244 | v_tensor = torch.tensor(v, dtype=torch.float32, device=device) 245 | 246 | info = ppo.train_ac(s_tensor, a_tensor, adv_tensor, r_tensor, v_tensor, is_clip_v=args.is_clip_v) 247 | 248 | if args.debug: 249 | logger.store(aloss=info["aloss"]) 250 | logger.store(vloss=info["vloss"]) 251 | logger.store(entropy=info["entropy"]) 252 | logger.store(kl=info["kl"]) 253 | 254 | if logger.get_stats("kl", with_min_and_max=True)[3] > args.target_kl: 255 | print("stop at:", str(i)) 256 | break 257 | 258 | if args.anneal_lr: 259 | ppo.lr_scheduler() 260 | 261 | ppo.eval() 262 | test_a = [] 263 | test_a_std = [] 264 | for i in range(args.test_epoch): 265 | test_obs = test_env.reset() 266 | test_obs = state_norm(test_obs, update=False) 267 | test_rew = 0 268 | 269 | while True: 270 | state_tensor = torch.tensor(test_obs, dtype=torch.float32, device=device).unsqueeze(0) 271 | a_tensor, std = ppo.actor(state_tensor) 272 | a_tensor = torch.squeeze(a_tensor, dim=0) 273 | a = a_tensor.detach().cpu().numpy() 274 | test_obs, r, done, _ = test_env.step(np.clip(a, -1, 1)) 275 | test_rew += r 276 | 277 | test_a.append(a) 278 | test_a_std.append(std.detach().cpu().numpy()) 279 | 280 | if done: 281 | logger.store(test_reward=test_rew) 282 | break 283 | test_obs = state_norm(test_obs, update=False) 284 | 285 | writer.add_scalar("test_reward", logger.get_stats("test_reward")[0], global_step=iter) 286 | writer.add_scalar("reward", logger.get_stats("reward")[0], global_step=iter) 287 | writer.add_histogram("action", np.array(replay.action), global_step=iter) 288 | writer.add_histogram("test_action", np.array(test_a), global_step=iter) 289 | writer.add_histogram("test_action_std", np.array(test_a_std), global_step=iter) 290 | 291 | if args.debug: 292 | writer.add_scalar("aloss", logger.get_stats("aloss")[0], 
global_step=iter) 293 | writer.add_scalar("vloss", logger.get_stats("vloss")[0], global_step=iter) 294 | writer.add_scalar("entropy", logger.get_stats("entropy")[0], global_step=iter) 295 | writer.add_scalar("kl", logger.get_stats("kl")[0], global_step=iter) 296 | 297 | logger.log_tabular('Epoch', iter) 298 | logger.log_tabular("reward", with_min_and_max=True) 299 | logger.log_tabular("test_reward", with_min_and_max=True) 300 | if args.debug: 301 | logger.log_tabular("aloss", with_min_and_max=True) 302 | logger.log_tabular("vloss", with_min_and_max=True) 303 | logger.log_tabular("entropy", with_min_and_max=True) 304 | logger.log_tabular("kl", with_min_and_max=True) 305 | logger.dump_tabular() 306 | 307 | if not os.path.exists(os.path.join(logger.output_dir, "checkpoints")): 308 | os.makedirs(os.path.join(logger.output_dir, "checkpoints")) 309 | if iter % args.log_every == 0: 310 | state = { 311 | "actor": ppo.actor.state_dict(), 312 | "critic": ppo.critic.state_dict(), 313 | 314 | } 315 | torch.save(state, os.path.join(logger.output_dir, "checkpoints", str(iter) + '.pth')) 316 | norm = {"state": state_norm, "reward": reward_norm} 317 | with open(os.path.join(logger.output_dir, "checkpoints", str(iter) + '.pkl'), "wb") as f: 318 | pickle.dump(norm, f) 319 | 320 | 321 | -------------------------------------------------------------------------------- /algos/ppo/run_ppo_multi_torch.py: -------------------------------------------------------------------------------- 1 | try: 2 | import algos.gail.pytorch.core_torch as core 3 | except Exception: 4 | import core_torch as core 5 | 6 | import numpy as np 7 | import gym 8 | import argparse 9 | import scipy 10 | from scipy import signal 11 | 12 | import os 13 | from utils.logx import EpochLogger 14 | import torch 15 | from torch.utils.tensorboard import SummaryWriter 16 | from env.car_racing import CarRacing 17 | from env.vecenv import SubVectorEnv, Env, VectorEnv 18 | import pickle 19 | 20 | 21 | # Can be used to convert rewards into discounted returns: 22 | # ret[i] = sum of t = i to T of gamma^(t-i) * rew[t] 23 | def discount_path(path, gamma): 24 | ''' 25 | Given a "path" of items x_1, x_2, ... x_n, return the discounted 26 | path, i.e. 27 | X_1 = x_1 + h*x_2 + h^2 x_3 + h^3 x_4 28 | X_2 = x_2 + h*x_3 + h^2 x_4 + h^3 x_5 29 | etc. 30 | Can do (more efficiently?) w SciPy. Python here for readability 31 | Inputs: 32 | - path, list/tensor of floats 33 | - h, discount rate 34 | Outputs: 35 | - Discounted path, as above 36 | ''' 37 | curr = 0 38 | rets = [] 39 | for i in range(len(path)): 40 | curr = curr*gamma + path[-1-i] 41 | rets.append(curr) 42 | rets = np.stack(list(reversed(rets)), 0) 43 | return rets 44 | 45 | 46 | def get_path_indices(not_dones): 47 | """ 48 | Returns list of tuples of the form: 49 | (agent index, time index start, time index end + 1) 50 | For each path seen in the not_dones array of shape (# agents, # time steps) 51 | E.g. 
if we have an not_dones of composition: 52 | tensor([[1, 1, 0, 1, 1, 1, 1, 1, 1, 1], 53 | [1, 1, 0, 1, 1, 0, 1, 1, 0, 1]], dtype=torch.uint8) 54 | Then we would return: 55 | [(0, 0, 3), (0, 3, 10), (1, 0, 3), (1, 3, 5), (1, 5, 9), (1, 9, 10)] 56 | """ 57 | indices = [] 58 | num_timesteps = not_dones.shape[1] 59 | for actor in range(not_dones.shape[0]): 60 | last_index = 0 61 | for i in range(num_timesteps): 62 | if not_dones[actor, i] == 0.: 63 | indices.append((actor, last_index, i + 1)) 64 | last_index = i + 1 65 | if last_index != num_timesteps: 66 | indices.append((actor, last_index, num_timesteps)) 67 | return indices 68 | 69 | 70 | class ReplayBuffer: 71 | def __init__(self, env_num, size, state_dim, action_dim, gamma=0.99, lam=0.95): 72 | self.env_num = env_num 73 | self.size = size 74 | self.state_dim = state_dim 75 | self.action_dim = action_dim 76 | self.gamma = gamma 77 | self.lam = lam 78 | self.reset() 79 | 80 | def reset(self): 81 | self.state = np.zeros((self.env_num, self.size, self.state_dim), np.float32) 82 | self.action = np.zeros((self.env_num, self.size, self.action_dim), np.float32) 83 | self.mask = np.zeros((self.env_num, self.size), np.int32) 84 | self.v = np.zeros((self.env_num, self.size), np.float32) 85 | self.reward = np.zeros((self.env_num, self.size, ), np.float32) 86 | self.adv = np.zeros((self.env_num, self.size, ), np.float32) 87 | self.ptr, self.path_start = 0, 0 88 | 89 | def add(self, s, a, v, r, mask): 90 | if self.ptr < self.size: 91 | self.state[:, self.ptr, :] = s 92 | self.action[:, self.ptr, :] = a 93 | self.v[:, self.ptr] = v 94 | self.reward[:, self.ptr] = r 95 | self.mask[:, self.ptr] = mask 96 | self.ptr += 1 97 | 98 | def finish_path(self): 99 | """ 100 | Calculate GAE advantage, discounted returns, and 101 | true reward (average reward per trajectory) 102 | 103 | GAE: delta_t^V = r_t + discount * V(s_{t+1}) - V(s_t) 104 | using formula from John Schulman's code: 105 | V(s_t+1) = {0 if s_t is terminal 106 | {v_s_{t+1} if s_t not terminal and t != T (last step) 107 | {v_s if s_t not terminal and t == T 108 | """ 109 | v_ = np.concatenate([self.v[:, 1:], self.v[:, -1:]], axis=1)*self.mask 110 | adv = self.reward + self.gamma*v_ -self.v 111 | 112 | indices = get_path_indices(self.mask) 113 | 114 | for (num, start, end) in indices: 115 | self.adv[num, start:end] = discount_path(adv[num, start:end], self.gamma*self.lam) 116 | self.reward[num, start:end] = discount_path(self.reward[num, start:end], self.gamma) 117 | 118 | 119 | def get(self): 120 | self.state = np.concatenate([state for state in self.state], axis=0) 121 | self.action = np.concatenate([action for action in self.action], axis=0) 122 | self.v = np.concatenate([v for v in self.v], axis=0) 123 | self.adv = np.concatenate([adv for adv in self.adv], axis=0) 124 | self.reward = np.concatenate([r for r in self.reward], axis=0) 125 | self.adv = (self.adv - np.mean(self.adv))/np.std(self.adv) 126 | 127 | def get_batch(self, batch=100, shuffle=True): 128 | if shuffle: 129 | indices = np.random.permutation(self.size) 130 | else: 131 | indices = np.arange(self.size) 132 | 133 | state = np.array(self.state) 134 | action = np.array(self.action) 135 | for idx in np.arange(0, self.size, batch): 136 | pos = indices[idx:(idx + batch)] 137 | yield (state[pos], action[pos], self.reward[pos], self.adv[pos], self.v[pos]) 138 | 139 | 140 | if __name__ == '__main__': 141 | 142 | parser = argparse.ArgumentParser() 143 | parser.add_argument('--iteration', default=int(1e3), type=int) 144 | 
parser.add_argument('--gamma', default=0.99) 145 | parser.add_argument('--lam', default=0.95) 146 | parser.add_argument('--a_update', default=10) 147 | parser.add_argument('--c_update', default=10) 148 | parser.add_argument('--lr_a', default=4e-4) 149 | parser.add_argument('--lr_c', default=1e-3) 150 | parser.add_argument('--log', type=str, default="logs") 151 | parser.add_argument('--steps', default=3000, type=int) 152 | parser.add_argument('--gpu', default=0) 153 | parser.add_argument('--env', default="Pendulum-v0") 154 | parser.add_argument('--env_num', default=12, type=int) 155 | parser.add_argument('--exp_name', default="ppo_Pendulum_test") 156 | parser.add_argument('--seed', default=0) 157 | parser.add_argument('--batch', default=50) 158 | parser.add_argument('--norm_state', default=True) 159 | parser.add_argument('--norm_rewards', default=True) 160 | parser.add_argument('--is_clip_v', default=True) 161 | parser.add_argument('--max_grad_norm', default=-1, type=float) 162 | parser.add_argument('--anneal_lr', default=False) 163 | parser.add_argument('--debug', default=False) 164 | parser.add_argument('--log_every', default=10) 165 | args = parser.parse_args() 166 | 167 | device = torch.device("cuda:"+str(args.gpu) if torch.cuda.is_available() else "cpu") 168 | 169 | from utils.run_utils import setup_logger_kwargs 170 | logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed) 171 | logger = EpochLogger(**logger_kwargs) 172 | writer = SummaryWriter(os.path.join(logger.output_dir, "logs")) 173 | 174 | env = gym.make(args.env) 175 | if args.env_num > 1: 176 | env = [Env(args.env, norm_state=args.norm_state, norm_rewards=args.norm_rewards) 177 | for _ in range(args.env_num)] 178 | env = SubVectorEnv(env) 179 | # env = CarRacing() 180 | state_dim = env.observation_space.shape[0] 181 | act_dim = env.action_space.shape 182 | action_max = env.action_space.high[0] 183 | ppo = core.PPO(state_dim, act_dim, action_max, 0.2, device, lr_a=args.lr_a, 184 | lr_c=args.lr_c, max_grad_norm=args.max_grad_norm, 185 | anneal_lr=args.anneal_lr, train_steps=args.iteration) 186 | replay = ReplayBuffer(args.env_num, args.steps, state_dim, act_dim) 187 | 188 | for iter in range(args.iteration): 189 | replay.reset() 190 | rewards = [] 191 | obs = env.reset() 192 | rew = 0 193 | 194 | for step in range(args.steps): 195 | state_tensor = torch.tensor(obs, dtype=torch.float32, device=device) 196 | a_tensor = ppo.actor.select_action(state_tensor) 197 | a = a_tensor.detach().cpu().numpy() 198 | obs_, r, done, _ = env.step(a) 199 | mask = 1-done 200 | 201 | v_pred = ppo.getV(state_tensor) 202 | replay.add(obs, a, v_pred.detach().cpu().numpy(), r, mask) 203 | 204 | obs = obs_ 205 | replay.finish_path() 206 | epi, total_reward = env.statistics() 207 | 208 | ppo.update_a() 209 | replay.get() 210 | writer.add_scalar("reward", total_reward/epi, global_step=iter) 211 | writer.add_histogram("action", np.array(replay.action), global_step=iter) 212 | 213 | for i in range(args.a_update): 214 | for (s, a, r, adv, v) in replay.get_batch(batch=args.batch): 215 | s_tensor = torch.tensor(s, dtype=torch.float32, device=device) 216 | a_tensor = torch.tensor(a, dtype=torch.float32, device=device) 217 | adv_tensor = torch.tensor(adv, dtype=torch.float32, device=device) 218 | r_tensor = torch.tensor(r, dtype=torch.float32, device=device) 219 | v_tensor = torch.tensor(v, dtype=torch.float32, device=device) 220 | 221 | info = ppo.train(s_tensor, a_tensor, adv_tensor, r_tensor, v_tensor, is_clip_v=args.is_clip_v) 222 | 223 | if 
args.debug: 224 | logger.store(aloss=info["aloss"]) 225 | logger.store(vloss=info["vloss"]) 226 | logger.store(entropy=info["entropy"]) 227 | logger.store(kl=info["kl"]) 228 | 229 | if args.anneal_lr: 230 | ppo.lr_scheduler() 231 | if args.debug: 232 | writer.add_scalar("aloss", logger.get_stats("aloss")[0], global_step=iter) 233 | writer.add_scalar("vloss", logger.get_stats("vloss")[0], global_step=iter) 234 | writer.add_scalar("entropy", logger.get_stats("entropy")[0], global_step=iter) 235 | writer.add_scalar("kl", logger.get_stats("kl")[0], global_step=iter) 236 | 237 | logger.log_tabular('Epoch', iter) 238 | logger.log_tabular("reward", total_reward/epi) 239 | if args.debug: 240 | logger.log_tabular("aloss", with_min_and_max=True) 241 | logger.log_tabular("vloss", with_min_and_max=True) 242 | logger.log_tabular("entropy", with_min_and_max=True) 243 | logger.log_tabular("kl", with_min_and_max=True) 244 | logger.dump_tabular() 245 | 246 | if not os.path.exists(os.path.join(logger.output_dir, "checkpoints")): 247 | os.makedirs(os.path.join(logger.output_dir, "checkpoints")) 248 | if iter % args.log_every == 0: 249 | state = { 250 | "actor": ppo.actor.state_dict(), 251 | "critic": ppo.critic.state_dict() 252 | } 253 | torch.save(state, os.path.join(logger.output_dir, "checkpoints", str(iter) + '.pth')) 254 | 255 | norm = {"state": env.envs[0].state_norm, "reward": env.envs[0].reward_norm} 256 | with open(os.path.join(logger.output_dir, "checkpoints", str(iter) + '.pkl'), "wb") as f: 257 | pickle.dump(norm, f) 258 | 259 | 260 | 261 | -------------------------------------------------------------------------------- /algos/ppo/run_ppo_torch.py: -------------------------------------------------------------------------------- 1 | try: 2 | import algos.ppo.core_torch as core 3 | except Exception: 4 | import core_torch as core 5 | 6 | import numpy as np 7 | import gym 8 | import argparse 9 | import scipy 10 | from scipy import signal 11 | import pickle 12 | 13 | import os 14 | from utils.logx import EpochLogger 15 | import torch 16 | from torch.utils.tensorboard import SummaryWriter 17 | from utils.normalization import * 18 | import json 19 | from algos.ppo.utils import discount_path, get_path_indices 20 | 21 | 22 | class ReplayBuffer: 23 | def __init__(self, size, state_dim, act_dim, gamma=0.99, lam=0.95, is_gae=True): 24 | self.size = size 25 | self.state_dim = state_dim 26 | self.act_dim = act_dim 27 | self.gamma = gamma 28 | self.lam = lam 29 | self.is_gae = is_gae 30 | self.reset() 31 | 32 | def reset(self): 33 | self.state = np.zeros((self.size,) + self.state_dim, np.float32) 34 | if type(self.act_dim) == np.int64 or type(self.act_dim) == np.int: 35 | self.action = np.zeros((self.size, ), np.int32) 36 | else: 37 | self.action = np.zeros((self.size,) + self.act_dim, np.float32) 38 | self.v = np.zeros((self.size, ), np.float32) 39 | self.reward = np.zeros((self.size, ), np.float32) 40 | self.adv = np.zeros((self.size, ), np.float32) 41 | self.mask = np.zeros((self.size, ), np.float32) 42 | self.ptr, self.path_start = 0, 0 43 | 44 | def add(self, s, a, r, mask): 45 | if self.ptr < self.size: 46 | self.state[self.ptr] = s 47 | self.action[self.ptr] = a 48 | self.reward[self.ptr] = r 49 | self.mask[self.ptr] = mask 50 | self.ptr += 1 51 | 52 | def update_v(self, v, pos): 53 | self.v[pos] = v 54 | 55 | def finish_path(self, last_v=None): 56 | """ 57 | Calculate GAE advantage, discounted returns, and 58 | true reward (average reward per trajectory) 59 | 60 | GAE: delta_t^V = r_t + discount * 
V(s_{t+1}) - V(s_t) 61 | using formula from John Schulman's code: 62 | V(s_t+1) = {0 if s_t is terminal 63 | {v_s_{t+1} if s_t not terminal and t != T (last step) 64 | {v_s if s_t not terminal and t == T 65 | """ 66 | if last_v is None: 67 | v_ = np.concatenate([self.v[1:], self.v[-1:]], axis=0) * self.mask 68 | else: 69 | v_ = np.concatenate([self.v[1:], [last_v]], axis=0) * self.mask 70 | adv = self.reward + self.gamma * v_ - self.v 71 | 72 | indices = get_path_indices(self.mask) 73 | 74 | for (start, end) in indices: 75 | self.adv[start:end] = discount_path(adv[start:end], self.gamma * self.lam) 76 | if not self.is_gae: 77 | self.reward[start:end] = discount_path(self.reward[start:end], self.gamma) 78 | if self.is_gae: 79 | self.reward = self.adv + self.v 80 | 81 | self.adv = (self.adv - np.mean(self.adv))/(np.std(self.adv) + 1e-8) 82 | 83 | def get_batch(self, batch=100, shuffle=True): 84 | if shuffle: 85 | indices = np.random.permutation(self.size) 86 | else: 87 | indices = np.arange(self.size) 88 | 89 | for idx in np.arange(0, self.size, batch): 90 | pos = indices[idx:(idx + batch)] 91 | yield (self.state[pos], self.action[pos], self.reward[pos], self.adv[pos], self.v[pos]) 92 | 93 | 94 | if __name__ == '__main__': 95 | 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument('--iteration', default=int(1e3), type=int) 98 | parser.add_argument('--gamma', default=0.99, type=float) 99 | parser.add_argument('--lam', default=0.95, type=float) 100 | parser.add_argument('--c_en', default=0.01, type=float) 101 | parser.add_argument('--c_vf', default=0.5, type=float) 102 | parser.add_argument('--a_update', default=10, type=int) 103 | parser.add_argument('--lr', default=3e-4, type=float) 104 | parser.add_argument('--log', type=str, default="logs") 105 | parser.add_argument('--steps', default=3000, type=int) 106 | parser.add_argument('--gpu', default=0, type=int) 107 | parser.add_argument('--env', default="CartPole-v1") 108 | parser.add_argument('--env_num', default=4, type=int) 109 | parser.add_argument('--exp_name', default="ppo_cartpole") 110 | parser.add_argument('--seed', default=0, type=int) 111 | parser.add_argument('--batch', default=50, type=int) 112 | parser.add_argument('--norm_state', action="store_true") 113 | parser.add_argument('--norm_rewards', default=None, type=str) 114 | parser.add_argument('--is_clip_v', action="store_true") 115 | parser.add_argument('--last_v', action="store_true") 116 | parser.add_argument('--is_gae', action="store_true") 117 | parser.add_argument('--max_grad_norm', default=-1, type=float) 118 | parser.add_argument('--anneal_lr', action="store_true") 119 | parser.add_argument('--debug', action="store_false") 120 | parser.add_argument('--log_every', default=10, type=int) 121 | parser.add_argument('--target_kl', default=0.03, type=float) 122 | parser.add_argument('--test_epoch', default=10, type=int) 123 | args = parser.parse_args() 124 | 125 | device = torch.device("cuda:"+str(args.gpu) if torch.cuda.is_available() else "cpu") 126 | 127 | from utils.run_utils import setup_logger_kwargs 128 | logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed) 129 | logger = EpochLogger(**logger_kwargs) 130 | with open(os.path.join(logger.output_dir, 'args.json'), 'w') as f: 131 | json.dump(vars(args), f, sort_keys=True, indent=4) 132 | writer = SummaryWriter(os.path.join(logger.output_dir, "logs")) 133 | 134 | env = gym.make(args.env) 135 | test_env = gym.make(args.env) 136 | torch.manual_seed(args.seed) 137 | np.random.seed(args.seed) 138 | 
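# Note: torch.manual_seed and np.random.seed above cover the PyTorch and NumPy RNGs (recent
# PyTorch versions also seed the CUDA generators from torch.manual_seed); GPU runs can still
# differ slightly unless non-deterministic cuDNN kernels are disabled via
# torch.backends.cudnn.deterministic.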
env.seed(args.seed) 139 | test_env.seed(args.seed) 140 | 141 | state_dim = env.observation_space.shape 142 | if type(env.action_space) == gym.spaces.Discrete: 143 | act_dim = env.action_space.n 144 | action_max = 1 145 | else: 146 | act_dim = env.action_space.shape 147 | action_max = env.action_space.high[0] 148 | ppo = core.PPO(state_dim, act_dim, action_max, 0.2, device, lr_a=args.lr, max_grad_norm=args.max_grad_norm, 149 | anneal_lr=args.anneal_lr, train_steps=args.iteration, c_en=args.c_en, c_vf=args.c_vf) 150 | replay = ReplayBuffer(args.steps, state_dim, act_dim, is_gae=args.is_gae) 151 | 152 | state_norm = Identity() 153 | reward_norm = Identity() 154 | if args.norm_state: 155 | state_norm = AutoNormalization(state_norm, state_dim, clip=10.0) 156 | if args.norm_rewards == "rewards": 157 | reward_norm = AutoNormalization(reward_norm, (), clip=10.0) 158 | elif args.norm_rewards == "returns": 159 | reward_norm = RewardFilter(reward_norm, (), clip=10.0) 160 | 161 | state_norm.reset() 162 | reward_norm.reset() 163 | obs = env.reset() 164 | obs = state_norm(obs) 165 | rew = 0 166 | for iter in range(args.iteration): 167 | ppo.train() 168 | replay.reset() 169 | 170 | for step in range(args.steps): 171 | state_tensor = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0) 172 | a_tensor = ppo.actor.select_action(state_tensor) 173 | a = a_tensor.detach().cpu().numpy() 174 | obs_, r, done, _ = env.step(a) 175 | rew += r 176 | r = reward_norm(r) 177 | mask = 1-done 178 | 179 | replay.add(obs, a, r, mask) 180 | 181 | obs = obs_ 182 | if done: 183 | logger.store(reward=rew) 184 | rew = 0 185 | obs = env.reset() 186 | state_norm.reset() 187 | reward_norm.reset() 188 | obs = state_norm(obs) 189 | 190 | state = replay.state 191 | for idx in np.arange(0, state.shape[0], args.batch): 192 | if idx + args.batch <= state.shape[0]: 193 | pos = np.arange(idx, idx + args.batch) 194 | else: 195 | pos = np.arange(idx, state.shape[0]) 196 | s = torch.tensor(state[pos], dtype=torch.float32).to(device) 197 | v = ppo.getV(s).detach().cpu().numpy() 198 | replay.update_v(v, pos) 199 | if args.last_v: 200 | s_tensor = torch.tensor(obs, dtype=torch.float32).to(device).unsqueeze(0) 201 | last_v = ppo.getV(s_tensor).detach().cpu().numpy() 202 | replay.finish_path(last_v=last_v) 203 | else: 204 | replay.finish_path() 205 | 206 | ppo.update_a() 207 | for i in range(args.a_update): 208 | for (s, a, r, adv, v) in replay.get_batch(batch=args.batch): 209 | s_tensor = torch.tensor(s, dtype=torch.float32, device=device) 210 | a_tensor = torch.tensor(a, dtype=torch.float32, device=device) 211 | adv_tensor = torch.tensor(adv, dtype=torch.float32, device=device) 212 | r_tensor = torch.tensor(r, dtype=torch.float32, device=device) 213 | v_tensor = torch.tensor(v, dtype=torch.float32, device=device) 214 | 215 | info = ppo.train_ac(s_tensor, a_tensor, adv_tensor, r_tensor, v_tensor, is_clip_v=args.is_clip_v) 216 | 217 | if args.debug: 218 | logger.store(aloss=info["aloss"]) 219 | logger.store(vloss=info["vloss"]) 220 | logger.store(entropy=info["entropy"]) 221 | logger.store(kl=info["kl"]) 222 | 223 | if logger.get_stats("kl", with_min_and_max=True)[3] > args.target_kl: 224 | print("stop at:", str(i)) 225 | break 226 | 227 | if args.anneal_lr: 228 | ppo.lr_scheduler() 229 | 230 | ppo.eval() 231 | test_a = [] 232 | test_a_std = [] 233 | for i in range(args.test_epoch): 234 | test_obs = test_env.reset() 235 | test_obs = state_norm(test_obs, update=False) 236 | test_rew = 0 237 | 238 | while True: 239 | state_tensor = 
torch.tensor(test_obs, dtype=torch.float32, device=device).unsqueeze(0) 240 | if type(env.action_space) == gym.spaces.Discrete: 241 | a_tensor = ppo.actor(state_tensor) 242 | a_tensor, std = torch.argmax(a_tensor, dim=1), None  # a discrete policy has no std 243 | else: 244 | a_tensor, std = ppo.actor(state_tensor) 245 | a_tensor = torch.squeeze(a_tensor, dim=0) 246 | a = a_tensor.detach().cpu().numpy() 247 | test_obs, r, done, _ = test_env.step(a) 248 | test_rew += r 249 | 250 | test_a.append(a) 251 | if std is not None: test_a_std.append(std.detach().cpu().numpy()) 252 | 253 | if done: 254 | logger.store(test_reward=test_rew) 255 | break 256 | test_obs = state_norm(test_obs, update=False) 257 | 258 | writer.add_scalar("test_reward", logger.get_stats("test_reward")[0], global_step=iter) 259 | writer.add_scalar("reward", logger.get_stats("reward")[0], global_step=iter) 260 | writer.add_histogram("action", np.array(replay.action), global_step=iter) 261 | writer.add_histogram("test_action", np.array(test_a), global_step=iter) 262 | if len(test_a_std) > 0: writer.add_histogram("test_action_std", np.array(test_a_std), global_step=iter) 263 | 264 | if args.debug: 265 | writer.add_scalar("aloss", logger.get_stats("aloss")[0], global_step=iter) 266 | writer.add_scalar("vloss", logger.get_stats("vloss")[0], global_step=iter) 267 | writer.add_scalar("entropy", logger.get_stats("entropy")[0], global_step=iter) 268 | writer.add_scalar("kl", logger.get_stats("kl")[0], global_step=iter) 269 | 270 | logger.log_tabular('Epoch', iter) 271 | logger.log_tabular("reward", with_min_and_max=True) 272 | logger.log_tabular("test_reward", with_min_and_max=True) 273 | if args.debug: 274 | logger.log_tabular("aloss", with_min_and_max=True) 275 | logger.log_tabular("vloss", with_min_and_max=True) 276 | logger.log_tabular("entropy", with_min_and_max=True) 277 | logger.log_tabular("kl", with_min_and_max=True) 278 | logger.dump_tabular() 279 | 280 | if not os.path.exists(os.path.join(logger.output_dir, "checkpoints")): 281 | os.makedirs(os.path.join(logger.output_dir, "checkpoints")) 282 | if iter % args.log_every == 0: 283 | state = { 284 | "actor": ppo.actor.state_dict(), 285 | "critic": ppo.critic.state_dict(), 286 | 287 | } 288 | torch.save(state, os.path.join(logger.output_dir, "checkpoints", str(iter) + '.pth')) 289 | norm = {"state": state_norm, "reward": reward_norm} 290 | with open(os.path.join(logger.output_dir, "checkpoints", str(iter) + '.pkl'), "wb") as f: 291 | pickle.dump(norm, f) 292 | 293 | 294 | -------------------------------------------------------------------------------- /algos/ppo/test_ppo_torch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | import argparse 4 | import scipy 5 | from scipy import signal 6 | 7 | import os 8 | from utils.logx import EpochLogger 9 | import torch 10 | import dmc2gym 11 | from collections import deque 12 | from env.dmc_env import DMCFrameStack 13 | from utils.normalization import * 14 | from utils.video import VideoRecorder 15 | from user_config import DEFAULT_VIDEO_DIR 16 | import pickle 17 | 18 | 19 | if __name__ == '__main__': 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--domain_name', default='cheetah') 22 | parser.add_argument('--task_name', default='run') 23 | parser.add_argument('--image_size', default=84, type=int) 24 | parser.add_argument('--action_repeat', default=1, type=int) 25 | parser.add_argument('--frame_stack', default=3, type=int) 26 | parser.add_argument('--encoder_type', default='pixel', type=str) 27 | 28 | 
parser.add_argument('--exp_name', default="ppo_test_cheetah_run_clipv_maxgrad_anneallr3e-4_normal_maxkl0.05_gae_norm-state_return_steps2048_batch256_notdone_lastv_4_entropy_update30") 29 | parser.add_argument('--seed', default=10, type=int) 30 | parser.add_argument('--norm_state', default=True, type=bool) 31 | parser.add_argument('--norm_rewards', default=True, type=bool) 32 | parser.add_argument('--check_num', default=900, type=int) 33 | parser.add_argument('--test_num', default=10, type=int) 34 | args = parser.parse_args() 35 | 36 | # env = gym.make("Hopper-v2") 37 | env = dmc2gym.make( 38 | domain_name=args.domain_name, 39 | task_name=args.task_name, 40 | seed=args.seed, 41 | visualize_reward=False, 42 | from_pixels=(args.encoder_type == 'pixel'), 43 | height=args.image_size, 44 | width=args.image_size, 45 | frame_skip=args.action_repeat 46 | ) 47 | if args.encoder_type == 'pixel': 48 | env = DMCFrameStack(env, k=args.frame_stack) 49 | import algos.ppo.core_cnn_torch as core 50 | else: 51 | import algos.ppo.core_torch as core 52 | 53 | state_dim = env.observation_space.shape 54 | act_dim = env.action_space.shape[0] 55 | action_max = env.action_space.high[0] 56 | device = torch.device("cuda:" + str(0) if torch.cuda.is_available() else "cpu") 57 | 58 | from utils.run_utils import setup_logger_kwargs 59 | logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed) 60 | 61 | actor = core.Actor(state_dim, act_dim, action_max).to(device) 62 | checkpoint = torch.load(os.path.join(logger_kwargs["output_dir"], "checkpoints", str(args.check_num) + '.pth')) 63 | actor.load_state_dict(checkpoint["actor"]) 64 | 65 | state_norm = Identity() 66 | state_norm = ImageProcess(state_norm) 67 | reward_norm = Identity() 68 | file = os.path.join(logger_kwargs["output_dir"], "checkpoints", str(args.check_num) + '.pkl') 69 | with open(file, "rb") as f: 70 | norm = pickle.load(f) 71 | if args.norm_state: 72 | state_norm = norm["state"] 73 | if args.norm_rewards: 74 | reward_norm = norm["reward"] 75 | 76 | out_file = os.path.join(DEFAULT_VIDEO_DIR, args.exp_name) 77 | if not os.path.exists(DEFAULT_VIDEO_DIR): 78 | os.mkdir(DEFAULT_VIDEO_DIR) 79 | if not os.path.exists(out_file): 80 | os.mkdir(out_file) 81 | 82 | video = VideoRecorder(out_file) 83 | rew_list = [] 84 | for i in range(args.test_num): 85 | video.init() 86 | obs = env.reset() 87 | obs = state_norm(obs) 88 | rew = 0 89 | 90 | while True: 91 | # env.render() 92 | video.record(env) 93 | state_tensor = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0) 94 | a, var = actor(state_tensor) 95 | logpi = actor.log_pi(state_tensor, a) 96 | a = torch.squeeze(a, 0).detach().cpu().numpy() 97 | obs, r, d, _ = env.step(a) 98 | 99 | rew += r 100 | obs = state_norm(obs) 101 | if d: 102 | rew_list.append(rew) 103 | print("reward", rew) 104 | video.save(str(i) + ".mp4") 105 | 106 | if (i+1) % 10 == 0: 107 | print("teset_", np.mean(rew_list)) 108 | rew_list = [] 109 | 110 | break 111 | 112 | 113 | 114 | -------------------------------------------------------------------------------- /algos/ppo/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Can be used to convert rewards into discounted returns: 4 | # ret[i] = sum of t = i to T of gamma^(t-i) * rew[t] 5 | def discount_path(path, gamma): 6 | ''' 7 | Given a "path" of items x_1, x_2, ... x_n, return the discounted 8 | path, i.e. 9 | X_1 = x_1 + h*x_2 + h^2 x_3 + h^3 x_4 10 | X_2 = x_2 + h*x_3 + h^2 x_4 + h^3 x_5 11 | etc. 
12 | Can do (more efficiently?) w SciPy. Python here for readability 13 | Inputs: 14 | - path, list/tensor of floats 15 | - h, discount rate 16 | Outputs: 17 | - Discounted path, as above 18 | ''' 19 | curr = 0 20 | rets = [] 21 | for i in range(len(path)): 22 | curr = curr*gamma + path[-1-i] 23 | rets.append(curr) 24 | rets = np.stack(list(reversed(rets)), 0) 25 | return rets 26 | 27 | 28 | def get_path_indices(not_dones): 29 | """ 30 | Returns list of tuples of the form: 31 | (time index start, time index end + 1) 32 | for each path seen in the 1-D not_dones array of shape (# time steps,) 33 | E.g. if we have a not_dones array that looks like: 34 | array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1], dtype=np.float32) 35 | then we would return: 36 | [(0, 3), (3, 6), (6, 9), (9, 10)] 37 | (the final, possibly unfinished, path runs to the end of the array) 38 | """ 39 | indices = [] 40 | num_timesteps = not_dones.shape[0] 41 | last_index = 0 42 | for i in range(num_timesteps): 43 | if not_dones[i] == 0.: 44 | indices.append((last_index, i + 1)) 45 | last_index = i + 1 46 | if last_index != num_timesteps: 47 | indices.append((last_index, num_timesteps)) 48 | return indices 49 | -------------------------------------------------------------------------------- /algos/vae/core_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import Normal 3 | import numpy as np 4 | 5 | def initialize_weights(mod, initialization_type, scale=1): 6 | ''' 7 | Weight initializer for the models. 8 | Inputs: A model, Returns: none, initializes the parameters 9 | ''' 10 | for p in mod.parameters(): 11 | if initialization_type == "normal": 12 | p.data.normal_(0.01) 13 | elif initialization_type == "xavier": 14 | if len(p.data.shape) >= 2: 15 | torch.nn.init.xavier_uniform_(p.data) 16 | else: 17 | p.data.zero_() 18 | elif initialization_type == "orthogonal": 19 | if len(p.data.shape) >= 2: 20 | torch.nn.init.orthogonal_(p.data, gain=scale) 21 | else: 22 | p.data.zero_() 23 | else: 24 | raise ValueError("Need a valid initialization key") 25 | 26 | class Critic(torch.nn.Module): 27 | def __init__(self, s_dim, emb_dim=50): 28 | super().__init__() 29 | 30 | self.fc = torch.nn.Sequential( 31 | torch.nn.Linear(emb_dim, 1), 32 | ) 33 | 34 | initialize_weights(self.fc, "orthogonal", scale=1) 35 | 36 | def forward(self, s_emb): 37 | v = self.fc(s_emb) 38 | return v 39 | 40 | 41 | class Actor(torch.nn.Module): 42 | def __init__(self, s_dim, a_dim, a_max=1, emb_dim=50): 43 | super().__init__() 44 | self.act_dim = a_dim 45 | self.a_max = a_max 46 | 47 | self.mu = torch.nn.Sequential( 48 | torch.nn.Linear(emb_dim, 64), 49 | torch.nn.ReLU(), 50 | torch.nn.Linear(64, a_dim), 51 | torch.nn.Tanh() 52 | ) 53 | self.var = torch.nn.Parameter(torch.zeros(1, a_dim)) 54 | 55 | initialize_weights(self.mu, "orthogonal", scale=0.01) 56 | 57 | def forward(self, emb): 58 | mu = self.mu(emb)*self.a_max 59 | 60 | return mu, torch.exp(self.var) 61 | 62 | def select_action(self, emb): 63 | with torch.no_grad(): 64 | mean, std = self(emb) 65 | 66 | normal = Normal(mean, std) 67 | action = normal.sample() # [1, a_dim] 68 | action = torch.squeeze(action, dim=0) # [a_dim] 69 | 70 | return action 71 | 72 | def log_pi(self, emb, a): 73 | mean, std = self(emb)  # forward() already returns the std (exp(self.var)), matching select_action 74 | 75 | 76 | normal = Normal(mean, std) 77 | logpi = normal.log_prob(a) 78 | logpi = torch.sum(logpi, dim=1) 79 | entropy = normal.entropy().sum(dim=1) 80 | 81 | return logpi, entropy # [None,] 
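# Usage illustration (a self-contained sketch, not called anywhere in this module): how Actor is
# typically driven by PPO.train_ac further down -- evaluate log-probs under the new and old policy
# and form the clipped importance ratio. The batch size, emb_dim and act_dim below are assumptions
# made only for demonstration.
def _actor_usage_sketch():
    import torch
    emb_dim, act_dim = 50, 6
    actor = Actor(None, act_dim, emb_dim=emb_dim)
    old_actor = Actor(None, act_dim, emb_dim=emb_dim)
    old_actor.load_state_dict(actor.state_dict())            # what PPO.update_a does
    emb = torch.randn(8, emb_dim)                             # a batch of encoder embeddings
    a = torch.stack([actor.select_action(e.unsqueeze(0)) for e in emb])
    logpi, entropy = actor.log_pi(emb, a)
    old_logpi, _ = old_actor.log_pi(emb, a)
    ratio = torch.exp(logpi - old_logpi)                      # equals 1 right after update_a
    clipped = torch.clamp(ratio, 1 - 0.2, 1 + 0.2)            # epsilon = 0.2, as PPO below uses
    return ratio, clipped, entropy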
82 | 83 | class ActorDisc(torch.nn.Module): 84 | def __init__(self, s_dim, a_num, emb_dim=512): 85 | super().__init__() 86 | self.act_num = a_num 87 | 88 | self.final = torch.nn.Sequential( 89 | torch.nn.Linear(emb_dim, a_num), 90 | torch.nn.Softmax() 91 | ) 92 | 93 | initialize_weights(self.emb, "orthogonal", scale=np.sqrt(2)) 94 | initialize_weights(self.final, "orthogonal", scale=0.01) 95 | 96 | def forward(self, s): 97 | emb = self.emb(s) 98 | x = self.final(emb) 99 | 100 | return x 101 | 102 | def select_action(self, s): 103 | with torch.no_grad(): 104 | x = self(s) 105 | normal = torch.distributions.Categorical(x) 106 | action = normal.sample() 107 | action = torch.squeeze(action, dim=0) 108 | 109 | return action 110 | 111 | def log_pi(self, s, a): 112 | x = self(s) 113 | 114 | normal = torch.distributions.Categorical(x) 115 | logpi = normal.log_prob(a) 116 | 117 | return logpi # [None,] 118 | 119 | class PPO(torch.nn.Module): 120 | def __init__(self, state_dim, act_dim, act_max, epsilon, device, lr_a=0.001, 121 | c_en=0.01, c_vf=0.5, max_grad_norm=-1, anneal_lr=False, train_steps=1000, 122 | emb_dim=50): 123 | super().__init__() 124 | if type(act_dim) == np.int64 or type(act_dim) == np.int: 125 | self.actor = ActorDisc(state_dim, act_dim, emb_dim=emb_dim).to(device) 126 | self.old_actor = ActorDisc(state_dim, act_dim, emb_dim=emb_dim).to(device) 127 | else: 128 | self.actor = Actor(state_dim, act_dim[0], act_max, emb_dim=emb_dim).to(device) 129 | self.old_actor = Actor(state_dim, act_dim[0], act_max, emb_dim=emb_dim).to(device) 130 | self.critic = Critic(state_dim, emb_dim=emb_dim).to(device) 131 | self.epsilon = epsilon 132 | self.c_en = c_en 133 | self.c_vf = c_vf 134 | 135 | self.max_grad_norm = max_grad_norm 136 | self.anneal_lr = anneal_lr 137 | 138 | self.opti = torch.optim.Adam(list(self.actor.parameters()) + list(self.critic.parameters()), lr=lr_a, eps=1e-5) 139 | 140 | if anneal_lr: 141 | lam = lambda f: 1 - f / train_steps 142 | self.opti_scheduler = torch.optim.lr_scheduler.LambdaLR(self.opti, lr_lambda=lam) 143 | 144 | def train_ac(self, s, a, adv, vs, oldv, is_clip_v=True): 145 | self.opti.zero_grad() 146 | logpi, entropy = self.actor.log_pi(s, a) 147 | old_logpi, _ = self.old_actor.log_pi(s, a) 148 | 149 | ratio = torch.exp(logpi - old_logpi) 150 | surr = ratio * adv 151 | clip_adv = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * adv 152 | aloss = -torch.mean(torch.min(surr, clip_adv)) 153 | # loss_entropy = torch.mean(torch.exp(logpi) * logpi) 154 | loss_entropy = entropy.mean() 155 | kl = torch.mean(torch.exp(logpi) * (logpi - old_logpi)) 156 | 157 | v = self.critic(s) 158 | v = torch.squeeze(v, 1) 159 | 160 | if not is_clip_v: 161 | v_loss = ((v - vs) ** 2).mean() 162 | else: 163 | clip_v = oldv + torch.clamp(v - oldv, -self.epsilon, self.epsilon) 164 | v_loss = torch.max(((v - vs) ** 2).mean(), ((clip_v - vs) ** 2).mean()) 165 | 166 | loss = aloss - loss_entropy*self.c_en + v_loss*self.c_vf 167 | loss.backward() 168 | if self.max_grad_norm != -1: 169 | torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm) 170 | torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm) 171 | self.opti.step() 172 | 173 | info = dict(vloss=v_loss.item(), aloss=aloss.item(), entropy=loss_entropy.item(), kl=kl.item(), loss=loss.item()) 174 | return info 175 | 176 | def lr_scheduler(self): 177 | if self.anneal_lr: 178 | self.opti_scheduler.step() 179 | 180 | def getV(self, emb): 181 | with torch.no_grad(): 182 | v = self.critic(emb) # 
[1,1] 183 | 184 | return torch.squeeze(v) # 185 | 186 | def update_a(self): 187 | self.old_actor.load_state_dict(self.actor.state_dict()) -------------------------------------------------------------------------------- /algos/vae/decoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from algos.vae.encoder import OUT_DIM 4 | 5 | class PixelDecoder(nn.Module): 6 | def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32): 7 | super().__init__() 8 | 9 | self.num_layers = num_layers 10 | self.num_filters = num_filters 11 | self.out_dim = OUT_DIM[num_layers] 12 | 13 | self.fc = nn.Linear( 14 | feature_dim, num_filters * self.out_dim * self.out_dim 15 | ) 16 | 17 | self.deconvs = nn.ModuleList() 18 | 19 | for i in range(self.num_layers - 1): 20 | self.deconvs.append( 21 | nn.ConvTranspose2d(num_filters, num_filters, 3, stride=1) 22 | ) 23 | self.deconvs.append( 24 | nn.ConvTranspose2d( 25 | num_filters, obs_shape[0], 3, stride=2, output_padding=1 26 | ) 27 | ) 28 | 29 | self.outputs = dict() 30 | 31 | def forward(self, h): 32 | h = torch.relu(self.fc(h)) 33 | self.outputs['fc'] = h 34 | 35 | deconv = h.view(-1, self.num_filters, self.out_dim, self.out_dim) 36 | self.outputs['deconv1'] = deconv 37 | 38 | for i in range(0, self.num_layers - 1): 39 | deconv = torch.relu(self.deconvs[i](deconv)) 40 | self.outputs['deconv%s' % (i + 1)] = deconv 41 | 42 | obs = self.deconvs[-1](deconv) 43 | self.outputs['obs'] = obs 44 | 45 | return obs 46 | -------------------------------------------------------------------------------- /algos/vae/encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | OUT_DIM = {2: 39, 4: 35, 6: 31} 5 | 6 | 7 | class PixelEncoder(nn.Module): 8 | """Convolutional encoder of pixels observations.""" 9 | def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32): 10 | super().__init__() 11 | 12 | assert len(obs_shape) == 3 13 | 14 | self.feature_dim = feature_dim 15 | self.num_layers = num_layers 16 | 17 | self.convs = nn.ModuleList( 18 | [nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)] 19 | ) 20 | for i in range(num_layers - 1): 21 | self.convs.append(nn.Conv2d(num_filters, num_filters, 3, stride=1)) 22 | 23 | out_dim = OUT_DIM[num_layers] 24 | self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim) 25 | self.ln = nn.LayerNorm(self.feature_dim) 26 | 27 | self.outputs = dict() 28 | 29 | def reparameterize(self, mu, logstd): 30 | std = torch.exp(logstd) 31 | eps = torch.randn_like(std) 32 | return mu + eps * std 33 | 34 | def forward_conv(self, obs): 35 | obs = obs / 255. 
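# Pixel observations come in as stacked frames with values in [0, 255]; dividing by 255 rescales
# them to [0, 1] before the conv stack. This is separate from preprocess_obs in run_vae.py, which
# additionally quantizes the input, adds uniform noise and centres it around zero.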
36 | self.outputs['obs'] = obs 37 | 38 | conv = torch.relu(self.convs[0](obs)) 39 | self.outputs['conv1'] = conv 40 | 41 | for i in range(1, self.num_layers): 42 | conv = torch.relu(self.convs[i](conv)) 43 | self.outputs['conv%s' % (i + 1)] = conv 44 | 45 | h = conv.view(conv.size(0), -1) 46 | return h 47 | 48 | def forward(self, obs, detach=False): 49 | h = self.forward_conv(obs) 50 | 51 | if detach: 52 | h = h.detach() 53 | 54 | h_fc = self.fc(h) 55 | self.outputs['fc'] = h_fc 56 | 57 | h_norm = self.ln(h_fc) 58 | self.outputs['ln'] = h_norm 59 | 60 | out = torch.tanh(h_norm) 61 | self.outputs['tanh'] = out 62 | 63 | return out 64 | 65 | def copy_conv_weights_from(self, source): 66 | """Tie convolutional layers""" 67 | # only tie conv layers 68 | for i in range(self.num_layers): 69 | self.convs[i].weight = source.convs[i].weight; self.convs[i].bias = source.convs[i].bias  # share conv parameters with source 70 | -------------------------------------------------------------------------------- /algos/vae/run_ppo_vae_dmc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | import argparse 4 | import scipy 5 | from scipy import signal 6 | import pickle 7 | from collections import deque 8 | 9 | import os 10 | from utils.logx import EpochLogger 11 | import torch 12 | from torch.utils.tensorboard import SummaryWriter 13 | import dmc2gym 14 | import env.atari_lib as atari 15 | from env.dmc_env import DMCFrameStack 16 | from utils.normalization import * 17 | from algos.ppo.utils import discount_path, get_path_indices 18 | import json 19 | 20 | from algos.vae.encoder import PixelEncoder 21 | from algos.vae.decoder import PixelDecoder 22 | 23 | class ReplayBuffer: 24 | def __init__(self, size, state_dim, act_dim, gamma=0.99, lam=0.95, is_gae=True): 25 | self.size = size 26 | self.state_dim = state_dim 27 | self.act_dim = act_dim 28 | self.gamma = gamma 29 | self.lam = lam 30 | self.is_gae = is_gae 31 | self.reset() 32 | 33 | def reset(self): 34 | self.state = np.zeros((self.size, self.state_dim), np.float32) 35 | if type(self.act_dim) == np.int64 or type(self.act_dim) == np.int: 36 | self.action = np.zeros((self.size, ), np.int32) 37 | else: 38 | self.action = np.zeros((self.size,) + self.act_dim, np.float32) 39 | self.v = np.zeros((self.size, ), np.float32) 40 | self.reward = np.zeros((self.size, ), np.float32) 41 | self.adv = np.zeros((self.size, ), np.float32) 42 | self.mask = np.zeros((self.size, ), np.float32) 43 | self.ptr, self.path_start = 0, 0 44 | 45 | def add(self, s, a, r, mask): 46 | if self.ptr < self.size: 47 | self.state[self.ptr] = s 48 | self.action[self.ptr] = a 49 | self.reward[self.ptr] = r 50 | self.mask[self.ptr] = mask 51 | self.ptr += 1 52 | 53 | def update_v(self, v, pos): 54 | self.v[pos] = v 55 | 56 | def finish_path(self, last_v=None): 57 | """ 58 | Calculate GAE advantage, discounted returns, and 59 | true reward (average reward per trajectory) 60 | 61 | GAE: delta_t^V = r_t + discount * V(s_{t+1}) - V(s_t) 62 | using formula from John Schulman's code: 63 | V(s_t+1) = {0 if s_t is terminal 64 | {v_s_{t+1} if s_t not terminal and t != T (last step) 65 | {v_s if s_t not terminal and t == T 66 | """ 67 | v_ = np.concatenate([self.v[1:], self.v[-1:] if last_v is None else [last_v]], axis=0) * self.mask 68 | adv = self.reward + self.gamma * v_ - self.v 69 | 70 | indices = get_path_indices(self.mask) 71 | 72 | for (start, end) in indices: 73 | self.adv[start:end] = discount_path(adv[start:end], self.gamma * self.lam) 74 | if not self.is_gae: 75 | self.reward[start:end] = discount_path(self.reward[start:end], 
self.gamma) 76 | if self.is_gae: 77 | self.reward = self.adv + self.v 78 | 79 | self.adv = (self.adv - np.mean(self.adv))/(np.std(self.adv) + 1e-8) 80 | 81 | def get_batch(self, batch=100, shuffle=True): 82 | if shuffle: 83 | indices = np.random.permutation(self.size) 84 | else: 85 | indices = np.arange(self.size) 86 | 87 | for idx in np.arange(0, self.size, batch): 88 | pos = indices[idx:(idx + batch)] 89 | yield (self.state[pos], self.action[pos], self.reward[pos], self.adv[pos], self.v[pos]) 90 | 91 | 92 | class ImageEncodeProcess: 93 | def __init__(self, pre_filter): 94 | self.pre_filter = pre_filter 95 | def __call__(self, x, update=True): 96 | x = self.pre_filter(x) 97 | x = np.array(x).astype(np.float32) 98 | x = torch.tensor(x, device=device).unsqueeze(0) 99 | x = encoder(x).detach() 100 | return x 101 | 102 | def reset(self): 103 | self.pre_filter.reset() 104 | 105 | if __name__ == '__main__': 106 | 107 | parser = argparse.ArgumentParser() 108 | parser.add_argument('--domain_name', default='cheetah') 109 | parser.add_argument('--task_name', default='run') 110 | parser.add_argument('--image_size', default=84, type=int) 111 | parser.add_argument('--action_repeat', default=1, type=int) 112 | parser.add_argument('--frame_stack', default=3, type=int) 113 | parser.add_argument('--encoder_type', default='pixel', type=str) 114 | 115 | parser.add_argument('--iteration', default=int(1e3), type=int) 116 | parser.add_argument('--gamma', default=0.99, type=float) 117 | parser.add_argument('--lam', default=0.95, type=float) 118 | parser.add_argument('--a_update', default=10, type=int) 119 | parser.add_argument('--lr_a', default=2.5e-4, type=float) 120 | parser.add_argument('--c_en', default=0.01, type=float) 121 | parser.add_argument('--c_vf', default=0.5, type=float) 122 | parser.add_argument('--log', type=str, default="logs") 123 | parser.add_argument('--steps', default=3000, type=int) 124 | parser.add_argument('--gpu', default=0, type=int) 125 | parser.add_argument('--env_num', default=4, type=int) 126 | parser.add_argument('--exp_name', default="ppo_vae_cheetah_run_test") 127 | parser.add_argument('--seed', default=10, type=int) 128 | parser.add_argument('--batch', default=50, type=int) 129 | parser.add_argument('--norm_state', action="store_true") 130 | parser.add_argument('--norm_rewards', default=False) 131 | parser.add_argument('--is_clip_v', action="store_true") parser.add_argument('--last_v', action="store_true")  # args.last_v is read in the rollout loop below 132 | parser.add_argument('--max_grad_norm', default=-1, type=float) 133 | parser.add_argument('--anneal_lr', action="store_true") 134 | parser.add_argument('--debug', action="store_false") 135 | parser.add_argument('--log_every', default=10, type=int) 136 | parser.add_argument('--network', default="cnn") 137 | parser.add_argument('--feature_dim', default=50, type=int) 138 | parser.add_argument('--target_kl', default=0.03, type=float) 139 | parser.add_argument('--encoder_dir', default="vae_2") 140 | parser.add_argument('--encoder_check', default=300, type=int) 141 | parser.add_argument('--test_epoch', default=10, type=int) 142 | args = parser.parse_args() 143 | 144 | device = torch.device("cuda:"+str(args.gpu) if torch.cuda.is_available() else "cpu") 145 | 146 | from utils.run_utils import setup_logger_kwargs 147 | logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed) 148 | logger = EpochLogger(**logger_kwargs) 149 | with open(os.path.join(logger.output_dir, 'args.json'), 'w') as f: 150 | json.dump(vars(args), f, sort_keys=True, indent=4) 151 | writer = SummaryWriter(os.path.join(logger.output_dir, "logs")) 152 | 153 | 
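# Observation pipeline used below: observations go through ImageEncodeProcess (defined above),
# which applies its pre_filter, converts the frame stack to a float32 tensor on device, passes it
# through the frozen PixelEncoder loaded further down, and returns the detached embedding; that
# embedding, rather than raw pixels, is what the actor-critic and the ReplayBuffer treat as the state.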
env = dmc2gym.make( 154 | domain_name=args.domain_name, 155 | task_name=args.task_name, 156 | seed=args.seed, 157 | visualize_reward=False, 158 | from_pixels=(args.encoder_type == 'pixel'), 159 | height=args.image_size, 160 | width=args.image_size, 161 | frame_skip=args.action_repeat 162 | ) 163 | test_env = dmc2gym.make( 164 | domain_name=args.domain_name, 165 | task_name=args.task_name, 166 | seed=args.seed, 167 | visualize_reward=False, 168 | from_pixels=(args.encoder_type == 'pixel'), 169 | height=args.image_size, 170 | width=args.image_size, 171 | frame_skip=args.action_repeat 172 | ) 173 | if args.encoder_type == 'pixel': 174 | env = DMCFrameStack(env, k=args.frame_stack) 175 | test_env = DMCFrameStack(test_env, k=args.frame_stack) 176 | torch.manual_seed(args.seed) 177 | np.random.seed(args.seed) 178 | env.seed(args.seed) 179 | test_env.seed(args.seed) 180 | 181 | state_dim = env.observation_space.shape 182 | act_dim = env.action_space.shape 183 | action_max = env.action_space.high[0] 184 | if args.network == "cnn": 185 | import algos.vae.core_vae as core 186 | ppo = core.PPO(state_dim, act_dim, action_max, 0.2, device, lr_a=args.lr_a, 187 | max_grad_norm=args.max_grad_norm, 188 | anneal_lr=args.anneal_lr, train_steps=args.iteration, emb_dim=args.feature_dim, 189 | c_en=args.c_en, c_vf=args.c_vf) 190 | 191 | replay = ReplayBuffer(args.steps, args.feature_dim, act_dim) 192 | encoder = PixelEncoder(state_dim, args.feature_dim, num_layers=4).to(device) 193 | encoder_kwargs = setup_logger_kwargs(args.encoder_dir, args.seed) 194 | encoder_file = os.path.join(encoder_kwargs["output_dir"], "checkpoints", str(args.encoder_check) + ".pth") 195 | check = torch.load(encoder_file) 196 | encoder.load_state_dict(check["encoder"]) 197 | 198 | state_norm = Identity() 199 | state_norm = ImageEncodeProcess(state_norm) 200 | reward_norm = Identity() 201 | if args.norm_state: 202 | state_norm = AutoNormalization(state_norm, state_dim, clip=5.0) 203 | if args.norm_rewards == "rewards": 204 | reward_norm = AutoNormalization(reward_norm, (), clip=5.0) 205 | elif args.norm_rewards == "returns": 206 | reward_norm = RewardFilter(reward_norm, (), clip=5.0) 207 | 208 | state_norm.reset() 209 | reward_norm.reset() 210 | obs = env.reset() 211 | obs = state_norm(obs) 212 | rew = 0 213 | for iter in range(args.iteration): 214 | ppo.train() 215 | replay.reset() 216 | 217 | for step in range(args.steps): 218 | a_tensor = ppo.actor.select_action(obs) 219 | a = a_tensor.detach().cpu().numpy() 220 | a = np.clip(a, -1, 1) 221 | obs_, r, done, _ = env.step(a) 222 | obs_ = state_norm(obs_) 223 | rew += r 224 | r = reward_norm(r) 225 | 226 | mask = 1-done 227 | replay.add(obs.cpu().numpy()[0], a, r, mask) 228 | 229 | obs = obs_ 230 | if done: 231 | logger.store(reward=rew) 232 | rew = 0 233 | state_norm.reset() 234 | reward_norm.reset() 235 | obs = env.reset() 236 | obs = state_norm(obs) 237 | 238 | state = replay.state 239 | for idx in np.arange(0, state.shape[0], args.batch): 240 | if idx + args.batch <= state.shape[0]: 241 | pos = np.arange(idx, idx + args.batch) 242 | else: 243 | pos = np.arange(idx, state.shape[0]) 244 | s = torch.tensor(state[pos], dtype=torch.float32).to(device) 245 | v = ppo.getV(s).detach().cpu().numpy() 246 | replay.update_v(v, pos) 247 | if args.last_v: 248 | s_tensor = torch.tensor(obs, dtype=torch.float32).to(device).unsqueeze(0) 249 | last_v = ppo.getV(s_tensor).detach().cpu().numpy() 250 | replay.finish_path(last_v=last_v) 251 | else: 252 | replay.finish_path() 253 | ppo.update_a() 254 | 255 
| for i in range(args.a_update): 256 | for (s, a, r, adv, v) in replay.get_batch(batch=args.batch): 257 | s_tensor = torch.tensor(s, dtype=torch.float32, device=device) 258 | a_tensor = torch.tensor(a, dtype=torch.float32, device=device) 259 | adv_tensor = torch.tensor(adv, dtype=torch.float32, device=device) 260 | r_tensor = torch.tensor(r, dtype=torch.float32, device=device) 261 | v_tensor = torch.tensor(v, dtype=torch.float32, device=device) 262 | 263 | info = ppo.train_ac(s_tensor, a_tensor, adv_tensor, r_tensor, v_tensor, is_clip_v=args.is_clip_v) 264 | 265 | if args.debug: 266 | logger.store(aloss=info["aloss"]) 267 | logger.store(vloss=info["vloss"]) 268 | logger.store(entropy=info["entropy"]) 269 | logger.store(kl=info["kl"]) 270 | 271 | if logger.get_stats("kl", with_min_and_max=True)[3] > args.target_kl: 272 | print("stop at:", str(i)) 273 | break 274 | 275 | if args.anneal_lr: 276 | ppo.lr_scheduler() 277 | 278 | ppo.eval() 279 | for i in range(args.test_epoch): 280 | test_obs = test_env.reset() 281 | test_obs = state_norm(test_obs, update=False) 282 | test_rew = 0 283 | 284 | while True: 285 | a_tensor, var = ppo.actor(test_obs) 286 | a_tensor = torch.squeeze(a_tensor, dim=0) 287 | a = a_tensor.detach().cpu().numpy() 288 | test_obs, r, done, _ = test_env.step(np.clip(a, -1, 1)) 289 | test_rew += r 290 | 291 | if done: 292 | logger.store(test_reward=test_rew) 293 | break 294 | test_obs = state_norm(test_obs, update=False) 295 | 296 | writer.add_scalar("test_reward", logger.get_stats("test_reward")[0], global_step=iter) 297 | writer.add_scalar("reward", logger.get_stats("reward")[0], global_step=iter) 298 | writer.add_histogram("action", np.array(replay.action), global_step=iter) 299 | if args.debug: 300 | writer.add_scalar("aloss", logger.get_stats("aloss")[0], global_step=iter) 301 | writer.add_scalar("vloss", logger.get_stats("vloss")[0], global_step=iter) 302 | writer.add_scalar("entropy", logger.get_stats("entropy")[0], global_step=iter) 303 | writer.add_scalar("kl", logger.get_stats("kl")[0], global_step=iter) 304 | 305 | logger.log_tabular('Epoch', iter) 306 | logger.log_tabular("reward", with_min_and_max=True) 307 | logger.log_tabular("test_reward", with_min_and_max=True) 308 | if args.debug: 309 | logger.log_tabular("aloss", with_min_and_max=True) 310 | logger.log_tabular("vloss", with_min_and_max=True) 311 | logger.log_tabular("entropy", with_min_and_max=True) 312 | logger.log_tabular("kl", with_min_and_max=True) 313 | logger.dump_tabular() 314 | 315 | if not os.path.exists(os.path.join(logger.output_dir, "checkpoints")): 316 | os.makedirs(os.path.join(logger.output_dir, "checkpoints")) 317 | if iter % args.log_every == 0: 318 | state = { 319 | "actor": ppo.actor.state_dict(), 320 | "critic": ppo.critic.state_dict(), 321 | 322 | } 323 | torch.save(state, os.path.join(logger.output_dir, "checkpoints", str(iter) + '.pth')) 324 | norm = {"state": state_norm, "reward": reward_norm} 325 | with open(os.path.join(logger.output_dir, "checkpoints", str(iter) + '.pkl'), "wb") as f: 326 | pickle.dump(norm, f) 327 | 328 | 329 | -------------------------------------------------------------------------------- /algos/vae/run_vae.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from algos.vae.encoder import PixelEncoder 3 | from algos.vae.decoder import PixelDecoder 4 | from torch.utils.data import Dataset 5 | from env.dmc_env import DMCFrameStack 6 | import torch.nn.functional as F 7 | import torch 8 | import argparse 9 | 
import dmc2gym 10 | import os 11 | from utils.logx import EpochLogger 12 | from torch.utils.tensorboard import SummaryWriter 13 | import numpy as np 14 | import json 15 | 16 | class ExpertDataset(Dataset): 17 | def __init__(self, data, device): 18 | self.data = data 19 | self.device = device 20 | 21 | def __getitem__(self, item): 22 | expert = {} 23 | 24 | expert["obs"] = torch.tensor(self.data["obs"][item], dtype=torch.float32) 25 | expert["action"] = torch.tensor(self.data["action"][item], dtype=torch.float32) 26 | 27 | return expert 28 | 29 | def __len__(self): 30 | return len(self.data["obs"]) 31 | 32 | def preprocess_obs(obs, bits=5): 33 | """Preprocessing image, see https://arxiv.org/abs/1807.03039.""" 34 | bins = 2**bits 35 | assert obs.dtype == torch.float32 36 | if bits < 8: 37 | obs = torch.floor(obs / 2**(8 - bits)) 38 | obs = obs / bins 39 | obs = obs + torch.rand_like(obs) / bins 40 | obs = obs - 0.5 41 | return obs 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('--domain_name', default='cheetah') 47 | parser.add_argument('--task_name', default='run') 48 | parser.add_argument('--image_size', default=84, type=int) 49 | parser.add_argument('--action_repeat', default=1, type=int) 50 | parser.add_argument('--frame_stack', default=3, type=int) 51 | parser.add_argument('--encoder_type', default='pixel', type=str) 52 | 53 | parser.add_argument('--exp_name', default="ppo_cheetah_run_clipv_maxgrad_anneallr2.5e-3_stack3_normal_state01_maxkl0.03_gae") 54 | parser.add_argument('--expert_num', default=10, type=int) 55 | parser.add_argument('--seed', default=10, type=int) 56 | parser.add_argument('--gpu', default=0) 57 | parser.add_argument('--batch', default=50, type=int) 58 | parser.add_argument('--encoder_lr', default=1e-3, type=float) 59 | parser.add_argument('--decoder_lr', default=1e-3, type=float) 60 | parser.add_argument('--decoder_latent_lambda', default=1e-6, type=float) 61 | parser.add_argument('--decoder_weight_lambda', default=1e-7, type=float) 62 | parser.add_argument('--epoch', default=1000, type=int) 63 | parser.add_argument('--out_dir', default="vae_test") 64 | parser.add_argument('--log_every', default=10, type=int) 65 | parser.add_argument('--feature_dim', default=50, type=int) 66 | parser.add_argument('--num_workers', default=8, type=int) 67 | args = parser.parse_args() 68 | 69 | device = torch.device("cuda:"+str(args.gpu) if torch.cuda.is_available() else "cpu") 70 | 71 | env = dmc2gym.make( 72 | domain_name=args.domain_name, 73 | task_name=args.task_name, 74 | seed=args.seed, 75 | visualize_reward=False, 76 | from_pixels=(args.encoder_type == 'pixel'), 77 | height=args.image_size, 78 | width=args.image_size, 79 | frame_skip=args.action_repeat 80 | ) 81 | if args.encoder_type == 'pixel': 82 | env = DMCFrameStack(env, k=args.frame_stack) 83 | torch.manual_seed(args.seed) 84 | np.random.seed(args.seed) 85 | env.seed(args.seed) 86 | state_dim = env.observation_space.shape 87 | 88 | from utils.run_utils import setup_logger_kwargs 89 | logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed) 90 | 91 | expert_data_file = os.path.join(logger_kwargs["output_dir"], "experts") 92 | with open(os.path.join(expert_data_file, 93 | args.domain_name + "_" + args.task_name + "_epoch" + str(args.expert_num) + ".pkl"), "rb") as f: 94 | expert_data = pickle.load(f) 95 | 96 | out_kwargs = setup_logger_kwargs(args.out_dir, args.seed) 97 | logger = EpochLogger(**out_kwargs) 98 | writer = 
SummaryWriter(os.path.join(logger.output_dir, "logs")) 99 | with open(os.path.join(logger.output_dir, 'args.json'), 'w') as f: 100 | json.dump(vars(args), f, sort_keys=True, indent=4) 101 | if not os.path.exists(os.path.join(logger.output_dir, "checkpoints")): 102 | os.makedirs(os.path.join(logger.output_dir, "checkpoints")) 103 | 104 | encoder = PixelEncoder(state_dim, args.feature_dim, num_layers=4).to(device) 105 | decoder = PixelDecoder(state_dim, args.feature_dim, num_layers=4).to(device) 106 | encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=args.encoder_lr) 107 | decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=args.decoder_lr, weight_decay=args.decoder_weight_lambda) 108 | 109 | expert_dataset = ExpertDataset(expert_data, device=device) 110 | expert_loader = torch.utils.data.DataLoader(expert_dataset, 111 | batch_size=args.batch, shuffle=True, num_workers=args.num_workers) 112 | 113 | for iter in range(args.epoch): 114 | for expert in expert_loader: 115 | obs = expert["obs"].to(device) 116 | h = encoder(obs) 117 | rec_obs = decoder(h) 118 | 119 | target_obs = obs.clone() 120 | target_obs = preprocess_obs(obs) 121 | 122 | rec_loss = F.mse_loss(target_obs, rec_obs) 123 | # add L2 penalty on latent representation 124 | # see https://arxiv.org/pdf/1903.12436.pdf 125 | latent_loss = (0.5 * h.pow(2).sum(1)).mean() 126 | loss = rec_loss + args.decoder_latent_lambda * latent_loss 127 | 128 | encoder_optimizer.zero_grad() 129 | decoder_optimizer.zero_grad() 130 | loss.backward() 131 | 132 | encoder_optimizer.step() 133 | decoder_optimizer.step() 134 | 135 | logger.store(loss=loss) 136 | 137 | writer.add_scalar("loss", logger.get_stats("loss")[0], global_step=iter) 138 | logger.log_tabular('Epoch', iter) 139 | logger.log_tabular("loss", with_min_and_max=True) 140 | logger.dump_tabular() 141 | 142 | if iter % args.log_every == 0: 143 | state = { 144 | "encoder": encoder.state_dict(), 145 | "decoder": decoder.state_dict(), 146 | } 147 | 148 | torch.save(state, os.path.join(logger.output_dir, "checkpoints", str(iter) + '.pth')) 149 | 150 | 151 | 152 | 153 | 154 | -------------------------------------------------------------------------------- /data/dqn/dqn_s0/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 32, 3 | "env_name": "SeaquestNoFrameskip-v4", 4 | "epsilon_decay_period": 250000, 5 | "epsilon_eval": 0.01, 6 | "epsilon_train": 0.1, 7 | "evaluation_step": 125000, 8 | "exp_name": "dqn", 9 | "gamma": 0.99, 10 | "iteration": 200, 11 | "logger_kwargs": { 12 | "exp_name": "dqn", 13 | "output_dir": "/home/zhr/workspace/Value-based-rl/data/dqn/dqn_s0" 14 | }, 15 | "max_ep_len": 6750.0, 16 | "replay_size": 1000000.0, 17 | "self": { 18 | "<__main__.Dqn object at 0x7f31db6c4ef0>": { 19 | "logger": { 20 | "": { 21 | "epoch_dict": {}, 22 | "exp_name": "dqn", 23 | "first_row": true, 24 | "log_current_row": {}, 25 | "log_headers": [], 26 | "output_dir": "/home/zhr/workspace/Value-based-rl/data/dqn/dqn_s0", 27 | "output_file": { 28 | "<_io.TextIOWrapper name='/home/zhr/workspace/Value-based-rl/data/dqn/dqn_s0/progress.txt' mode='a+' encoding='UTF-8'>": { 29 | "mode": "a+" 30 | } 31 | } 32 | } 33 | } 34 | } 35 | }, 36 | "target_update_period": 8000, 37 | "train_step": 62500.0, 38 | "update_period": 4, 39 | "warmup_steps": 20000 40 | } -------------------------------------------------------------------------------- /data/dqn/dqn_s0/progress.txt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/feidieufo/RL-Implementation/c6f1a466d7897ac896debc1e7f3b2601cb245544/data/dqn/dqn_s0/progress.txt -------------------------------------------------------------------------------- /data/dqn_cartpole/dqn_cartpole_s0/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 32, 3 | "env_name": "CartPole-v1", 4 | "epsilon_decay_period": 100, 5 | "epsilon_eval": 0.01, 6 | "epsilon_train": 0.1, 7 | "evaluation_step": 1000, 8 | "exp_name": "dqn_cartpole", 9 | "gamma": 0.99, 10 | "iteration": 200, 11 | "logger_kwargs": { 12 | "exp_name": "dqn_cartpole", 13 | "output_dir": "/home/zhr/workspace/pycharm/Value-based-rl/data/dqn_cartpole/dqn_cartpole_s0" 14 | }, 15 | "max_ep_len": 200, 16 | "replay_size": 1000000.0, 17 | "self": { 18 | "<__main__.Dqn object at 0x7f02ec92db70>": { 19 | "logger": { 20 | "": { 21 | "epoch_dict": {}, 22 | "exp_name": "dqn_cartpole", 23 | "first_row": true, 24 | "log_current_row": {}, 25 | "log_headers": [], 26 | "output_dir": "/home/zhr/workspace/pycharm/Value-based-rl/data/dqn_cartpole/dqn_cartpole_s0", 27 | "output_file": { 28 | "<_io.TextIOWrapper name='/home/zhr/workspace/pycharm/Value-based-rl/data/dqn_cartpole/dqn_cartpole_s0/progress.txt' mode='w+' encoding='UTF-8'>": { 29 | "mode": "w+" 30 | } 31 | } 32 | } 33 | } 34 | } 35 | }, 36 | "target_update_period": 50, 37 | "train_step": 200, 38 | "update_period": 10, 39 | "warmup_steps": 0 40 | } -------------------------------------------------------------------------------- /data/dqn_cartpole/dqn_cartpole_s0/logs/events.out.tfevents.1566310371.zhr-AERO-15WV8: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feidieufo/RL-Implementation/c6f1a466d7897ac896debc1e7f3b2601cb245544/data/dqn_cartpole/dqn_cartpole_s0/logs/events.out.tfevents.1566310371.zhr-AERO-15WV8 -------------------------------------------------------------------------------- /data/dqn_cartpole/dqn_cartpole_s0/logs/events.out.tfevents.1566310768.zhr-AERO-15WV8: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feidieufo/RL-Implementation/c6f1a466d7897ac896debc1e7f3b2601cb245544/data/dqn_cartpole/dqn_cartpole_s0/logs/events.out.tfevents.1566310768.zhr-AERO-15WV8 -------------------------------------------------------------------------------- /data/dqn_cartpole/dqn_cartpole_s0/logs/events.out.tfevents.1566313578.zhr-AERO-15WV8: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feidieufo/RL-Implementation/c6f1a466d7897ac896debc1e7f3b2601cb245544/data/dqn_cartpole/dqn_cartpole_s0/logs/events.out.tfevents.1566313578.zhr-AERO-15WV8 -------------------------------------------------------------------------------- /data/dqn_cartpole/dqn_cartpole_s0/progress.txt: -------------------------------------------------------------------------------- 1 | iter Averagereward Stdreward Maxreward Minreward Averagestep Stdstep Maxstep Minstep 2 | 1 10.578947 2.7205997 18.0 8.0 10.578947 2.7205997 18.0 8.0 3 | 2 9.454545 0.5821022 10.0 8.0 9.454545 0.5821022 10.0 8.0 4 | 3 9.666667 1.356934 14.0 8.0 9.666667 1.356934 14.0 8.0 5 | 4 10.1 1.6999999 14.0 8.0 10.1 1.6999999 14.0 8.0 6 | 5 11.666667 2.728451 20.0 9.0 11.666667 2.728451 20.0 9.0 7 | 6 14.428572 3.5196242 21.0 9.0 14.428572 3.5196242 21.0 9.0 8 | 7 43.8 
30.986448 94.0 10.0 43.8 30.986448 94.0 10.0 9 | 8 23.0 10.34945 43.0 9.0 23.0 10.34945 43.0 9.0 10 | 9 25.5 9.110434 39.0 11.0 25.5 9.110434 39.0 11.0 11 | 10 20.9 8.490583 36.0 11.0 20.9 8.490583 36.0 11.0 12 | 11 23.777779 14.950123 59.0 9.0 23.777779 14.950123 59.0 9.0 13 | 12 19.636364 10.377185 46.0 9.0 19.636364 10.377185 46.0 9.0 14 | 13 20.6 11.262327 44.0 10.0 20.6 11.262327 44.0 10.0 15 | 14 25.25 15.204851 61.0 12.0 25.25 15.204851 61.0 12.0 16 | 15 20.9 6.0074954 31.0 11.0 20.9 6.0074954 31.0 11.0 17 | 16 28.0 18.96781 59.0 9.0 28.0 18.96781 59.0 9.0 18 | 17 55.5 36.017357 108.0 10.0 55.5 36.017357 108.0 10.0 19 | 18 43.6 18.884914 67.0 23.0 43.6 18.884914 67.0 23.0 20 | 19 64.25 34.766182 123.0 37.0 64.25 34.766182 123.0 37.0 21 | 20 88.333336 28.581268 123.0 53.0 88.333336 28.581268 123.0 53.0 22 | 21 67.333336 25.629843 92.0 32.0 67.333336 25.629843 92.0 32.0 23 | 22 34.833332 24.217188 69.0 9.0 34.833332 24.217188 69.0 9.0 24 | 23 80.25 56.299976 173.0 30.0 80.25 56.299976 173.0 30.0 25 | 24 70.0 12.675436 85.0 54.0 70.0 12.675436 85.0 54.0 26 | 25 157.5 13.5 171.0 144.0 157.5 13.5 171.0 144.0 27 | 26 85.333336 64.61338 167.0 9.0 85.333336 64.61338 167.0 9.0 28 | 27 118.5 30.5 149.0 88.0 118.5 30.5 149.0 88.0 29 | 28 110.5 25.5 136.0 85.0 110.5 25.5 136.0 85.0 30 | 29 158.0 38.0 196.0 120.0 158.0 38.0 196.0 120.0 31 | 30 114.5 8.5 123.0 106.0 114.5 8.5 123.0 106.0 32 | 31 62.25 48.272015 135.0 12.0 62.25 48.272015 135.0 12.0 33 | 32 176.0 6.0 182.0 170.0 176.0 6.0 182.0 170.0 34 | 33 135.0 20.0 155.0 115.0 135.0 20.0 155.0 115.0 35 | 34 134.0 4.0 138.0 130.0 134.0 4.0 138.0 130.0 36 | 35 124.5 1.5 126.0 123.0 124.5 1.5 126.0 123.0 37 | 36 116.0 18.0 134.0 98.0 116.0 18.0 134.0 98.0 38 | 37 83.666664 6.5996633 91.0 75.0 83.666664 6.5996633 91.0 75.0 39 | 38 159.5 40.5 200.0 119.0 159.5 40.5 200.0 119.0 40 | 39 138.0 5.0 143.0 133.0 138.0 5.0 143.0 133.0 41 | 40 137.5 7.5 145.0 130.0 137.5 7.5 145.0 130.0 42 | 41 155.5 11.5 167.0 144.0 155.5 11.5 167.0 144.0 43 | 42 130.0 7.0 137.0 123.0 130.0 7.0 137.0 123.0 44 | 43 159.5 23.5 183.0 136.0 159.5 23.5 183.0 136.0 45 | 44 167.0 33.0 200.0 134.0 167.0 33.0 200.0 134.0 46 | 45 200.0 0.0 200.0 200.0 200.0 0.0 200.0 200.0 47 | 46 173.5 26.5 200.0 147.0 173.5 26.5 200.0 147.0 48 | 47 200.0 0.0 200.0 200.0 200.0 0.0 200.0 200.0 49 | 48 196.0 4.0 200.0 192.0 196.0 4.0 200.0 192.0 50 | 49 151.5 13.5 165.0 138.0 151.5 13.5 165.0 138.0 51 | 50 200.0 0.0 200.0 200.0 200.0 0.0 200.0 200.0 52 | 51 200.0 0.0 200.0 200.0 200.0 0.0 200.0 200.0 53 | 52 146.0 30.0 176.0 116.0 146.0 30.0 176.0 116.0 54 | 53 179.5 20.5 200.0 159.0 179.5 20.5 200.0 159.0 55 | 54 199.5 0.5 200.0 199.0 199.5 0.5 200.0 199.0 56 | 55 188.5 6.5 195.0 182.0 188.5 6.5 195.0 182.0 57 | 56 184.5 1.5 186.0 183.0 184.5 1.5 186.0 183.0 58 | 57 200.0 0.0 200.0 200.0 200.0 0.0 200.0 200.0 59 | 58 200.0 0.0 200.0 200.0 200.0 0.0 200.0 200.0 60 | 59 181.0 19.0 200.0 162.0 181.0 19.0 200.0 162.0 61 | 60 200.0 0.0 200.0 200.0 200.0 0.0 200.0 200.0 62 | 61 200.0 0.0 200.0 200.0 200.0 0.0 200.0 200.0 63 | 62 200.0 0.0 200.0 200.0 200.0 0.0 200.0 200.0 64 | 63 194.0 6.0 200.0 188.0 194.0 6.0 200.0 188.0 65 | 64 200.0 0.0 200.0 200.0 200.0 0.0 200.0 200.0 66 | 65 154.5 30.5 185.0 124.0 154.5 30.5 185.0 124.0 67 | 66 200.0 0.0 200.0 200.0 200.0 0.0 200.0 200.0 68 | 67 177.5 2.5 180.0 175.0 177.5 2.5 180.0 175.0 69 | 68 200.0 0.0 200.0 200.0 200.0 0.0 200.0 200.0 70 | -------------------------------------------------------------------------------- 
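The progress.txt files under data/ are the tab-separated tables written by the logger in utils/logx.py: one header row, then one row per epoch. As a minimal illustrative sketch (not part of the repo), the CartPole run shown above can be loaded with pandas; the column names are exactly those in the header row:

import pandas as pd
import matplotlib.pyplot as plt

# progress.txt is tab-separated: a header row followed by one row per epoch.
df = pd.read_csv("data/dqn_cartpole/dqn_cartpole_s0/progress.txt", sep="\t")

plt.plot(df["iter"], df["Averagereward"], label="Averagereward")
plt.fill_between(df["iter"],
                 df["Averagereward"] - df["Stdreward"],
                 df["Averagereward"] + df["Stdreward"],
                 alpha=0.3)
plt.xlabel("iter")
plt.ylabel("reward")
plt.legend()
plt.show()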
/data/dqn_seaquest/dqn_seaquest_s0/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 32, 3 | "env_name": "SeaquestNoFrameskip-v4", 4 | "epsilon_decay_period": 62500.0, 5 | "epsilon_eval": 0.01, 6 | "epsilon_train": 0.1, 7 | "evaluation_step": 31250.0, 8 | "exp_name": "dqn_seaquest", 9 | "gamma": 0.99, 10 | "iteration": 200, 11 | "logger_kwargs": { 12 | "exp_name": "dqn_seaquest", 13 | "output_dir": "/home/zhr/workspace/Value-based-rl/data/dqn_seaquest/dqn_seaquest_s0" 14 | }, 15 | "max_ep_len": 6750.0, 16 | "replay_size": 1000000.0, 17 | "self": { 18 | "<__main__.Dqn object at 0x7f681f51d978>": { 19 | "logger": { 20 | "": { 21 | "epoch_dict": {}, 22 | "exp_name": "dqn_seaquest", 23 | "first_row": true, 24 | "log_current_row": {}, 25 | "log_headers": [], 26 | "output_dir": "/home/zhr/workspace/Value-based-rl/data/dqn_seaquest/dqn_seaquest_s0", 27 | "output_file": { 28 | "<_io.TextIOWrapper name='/home/zhr/workspace/Value-based-rl/data/dqn_seaquest/dqn_seaquest_s0/progress.txt' mode='w+' encoding='UTF-8'>": { 29 | "mode": "w+" 30 | } 31 | } 32 | } 33 | } 34 | } 35 | }, 36 | "target_update_period": 2000.0, 37 | "train_step": 62500.0, 38 | "update_period": 4, 39 | "warmup_steps": 5000.0 40 | } -------------------------------------------------------------------------------- /data/dqn_seaquest/dqn_seaquest_s0/logs/events.out.tfevents.1565228049.zhr-System-Product-Name: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feidieufo/RL-Implementation/c6f1a466d7897ac896debc1e7f3b2601cb245544/data/dqn_seaquest/dqn_seaquest_s0/logs/events.out.tfevents.1565228049.zhr-System-Product-Name -------------------------------------------------------------------------------- /data/dqn_seaquest/dqn_seaquest_s0/progress.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feidieufo/RL-Implementation/c6f1a466d7897ac896debc1e7f3b2601cb245544/data/dqn_seaquest/dqn_seaquest_s0/progress.txt -------------------------------------------------------------------------------- /env/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feidieufo/RL-Implementation/c6f1a466d7897ac896debc1e7f3b2601cb245544/env/__init__.py -------------------------------------------------------------------------------- /env/atari_lib.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | os.environ.setdefault('PATH', '') 4 | from collections import deque 5 | import gym 6 | from gym import spaces 7 | import cv2 8 | cv2.ocl.setUseOpenCL(False) 9 | from env.wrappers import TimeLimit 10 | 11 | 12 | class NoopResetEnv(gym.Wrapper): 13 | def __init__(self, env, noop_max=30): 14 | """Sample initial states by taking random number of no-ops on reset. 15 | No-op is assumed to be action 0. 
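The number of no-ops is drawn uniformly from [1, noop_max], unless override_num_noops is set.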
16 | """ 17 | gym.Wrapper.__init__(self, env) 18 | self.noop_max = noop_max 19 | self.override_num_noops = None 20 | self.noop_action = 0 21 | assert env.unwrapped.get_action_meanings()[0] == 'NOOP' 22 | 23 | def reset(self, **kwargs): 24 | """ Do no-op action for a number of steps in [1, noop_max].""" 25 | self.env.reset(**kwargs) 26 | if self.override_num_noops is not None: 27 | noops = self.override_num_noops 28 | else: 29 | noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) #pylint: disable=E1101 30 | assert noops > 0 31 | obs = None 32 | for _ in range(noops): 33 | obs, _, done, _ = self.env.step(self.noop_action) 34 | if done: 35 | obs = self.env.reset(**kwargs) 36 | return obs 37 | 38 | def step(self, ac): 39 | return self.env.step(ac) 40 | 41 | class FireResetEnv(gym.Wrapper): 42 | def __init__(self, env): 43 | """Take action on reset for environments that are fixed until firing.""" 44 | gym.Wrapper.__init__(self, env) 45 | assert env.unwrapped.get_action_meanings()[1] == 'FIRE' 46 | assert len(env.unwrapped.get_action_meanings()) >= 3 47 | 48 | def reset(self, **kwargs): 49 | self.env.reset(**kwargs) 50 | obs, _, done, _ = self.env.step(1) 51 | if done: 52 | self.env.reset(**kwargs) 53 | obs, _, done, _ = self.env.step(2) 54 | if done: 55 | self.env.reset(**kwargs) 56 | return obs 57 | 58 | def step(self, ac): 59 | return self.env.step(ac) 60 | 61 | class EpisodicLifeEnv(gym.Wrapper): 62 | def __init__(self, env): 63 | """Make end-of-life == end-of-episode, but only reset on true game over. 64 | Done by DeepMind for the DQN and co. since it helps value estimation. 65 | """ 66 | gym.Wrapper.__init__(self, env) 67 | self.lives = 0 68 | self.was_real_done = True 69 | 70 | def step(self, action): 71 | obs, reward, done, info = self.env.step(action) 72 | self.was_real_done = done 73 | # check current lives, make loss of life terminal, 74 | # then update lives to handle bonus lives 75 | lives = self.env.unwrapped.ale.lives() 76 | if lives < self.lives and lives > 0: 77 | # for Qbert sometimes we stay in lives == 0 condition for a few frames 78 | # so it's important to keep lives > 0, so that we only reset once 79 | # the environment advertises done. 80 | done = True 81 | self.lives = lives 82 | return obs, reward, done, info 83 | 84 | def reset(self, **kwargs): 85 | """Reset only when lives are exhausted. 86 | This way all states are still reachable even though lives are episodic, 87 | and the learner need not know about any of this behind-the-scenes. 
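When the previous done came from a life loss rather than a true game over, a single no-op step is taken instead of calling the underlying reset.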
88 | """ 89 | if self.was_real_done: 90 | obs = self.env.reset(**kwargs) 91 | else: 92 | # no-op step to advance from terminal/lost life state 93 | obs, _, _, _ = self.env.step(0) 94 | self.lives = self.env.unwrapped.ale.lives() 95 | return obs 96 | 97 | class MaxAndSkipEnv(gym.Wrapper): 98 | def __init__(self, env, skip=4): 99 | """Return only every `skip`-th frame""" 100 | gym.Wrapper.__init__(self, env) 101 | # most recent raw observations (for max pooling across time steps) 102 | self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8) 103 | self._skip = skip 104 | 105 | def step(self, action): 106 | """Repeat action, sum reward, and max over last observations.""" 107 | total_reward = 0.0 108 | done = None 109 | for i in range(self._skip): 110 | obs, reward, done, info = self.env.step(action) 111 | if i == self._skip - 2: self._obs_buffer[0] = obs 112 | if i == self._skip - 1: self._obs_buffer[1] = obs 113 | total_reward += reward 114 | if done: 115 | break 116 | # Note that the observation on the done=True frame 117 | # doesn't matter 118 | max_frame = self._obs_buffer.max(axis=0) 119 | 120 | return max_frame, total_reward, done, info 121 | 122 | def reset(self, **kwargs): 123 | return self.env.reset(**kwargs) 124 | 125 | class ClipRewardEnv(gym.RewardWrapper): 126 | def __init__(self, env): 127 | gym.RewardWrapper.__init__(self, env) 128 | 129 | def reward(self, reward): 130 | """Bin reward to {+1, 0, -1} by its sign.""" 131 | return np.sign(reward) 132 | 133 | 134 | class WarpFrame(gym.ObservationWrapper): 135 | def __init__(self, env, width=84, height=84, grayscale=True, dict_space_key=None): 136 | """ 137 | Warp frames to 84x84 as done in the Nature paper and later work. 138 | 139 | If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which 140 | observation should be warped. 141 | """ 142 | super().__init__(env) 143 | self._width = width 144 | self._height = height 145 | self._grayscale = grayscale 146 | self._key = dict_space_key 147 | if self._grayscale: 148 | num_colors = 1 149 | else: 150 | num_colors = 3 151 | 152 | new_space = gym.spaces.Box( 153 | low=0, 154 | high=255, 155 | shape=(self._height, self._width, num_colors), 156 | dtype=np.uint8, 157 | ) 158 | if self._key is None: 159 | original_space = self.observation_space 160 | self.observation_space = new_space 161 | else: 162 | original_space = self.observation_space.spaces[self._key] 163 | self.observation_space.spaces[self._key] = new_space 164 | assert original_space.dtype == np.uint8 and len(original_space.shape) == 3 165 | 166 | def observation(self, obs): 167 | if self._key is None: 168 | frame = obs 169 | else: 170 | frame = obs[self._key] 171 | 172 | if self._grayscale: 173 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) 174 | frame = cv2.resize( 175 | frame, (self._width, self._height), interpolation=cv2.INTER_AREA 176 | ) 177 | if self._grayscale: 178 | frame = np.expand_dims(frame, -1) 179 | 180 | if self._key is None: 181 | obs = frame 182 | else: 183 | obs = obs.copy() 184 | obs[self._key] = frame 185 | return obs 186 | 187 | 188 | class FrameStack(gym.Wrapper): 189 | def __init__(self, env, k): 190 | """Stack k last frames. 191 | 192 | Returns lazy array, which is much more memory efficient. 
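Frames are concatenated along the last (channel) axis, so the stacked observation has shape (H, W, C * k).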
193 | 194 | See Also 195 | -------- 196 | baselines.common.atari_wrappers.LazyFrames 197 | """ 198 | gym.Wrapper.__init__(self, env) 199 | self.k = k 200 | self.frames = deque([], maxlen=k) 201 | shp = env.observation_space.shape 202 | self.observation_space = spaces.Box(low=0, high=255, shape=(shp[:-1] + (shp[-1] * k,)), dtype=env.observation_space.dtype) 203 | 204 | def reset(self): 205 | ob = self.env.reset() 206 | for _ in range(self.k): 207 | self.frames.append(ob) 208 | return self._get_ob() 209 | 210 | def step(self, action): 211 | ob, reward, done, info = self.env.step(action) 212 | self.frames.append(ob) 213 | return self._get_ob(), reward, done, info 214 | 215 | def _get_ob(self): 216 | assert len(self.frames) == self.k 217 | return LazyFrames(list(self.frames)) 218 | 219 | class ScaledFloatFrame(gym.ObservationWrapper): 220 | def __init__(self, env): 221 | gym.ObservationWrapper.__init__(self, env) 222 | self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32) 223 | 224 | def observation(self, observation): 225 | # careful! This undoes the memory optimization, use 226 | # with smaller replay buffers only. 227 | return np.array(observation).astype(np.float32) / 255.0 228 | 229 | class LazyFrames(object): 230 | def __init__(self, frames): 231 | """This object ensures that common frames between the observations are only stored once. 232 | It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay 233 | buffers. 234 | 235 | This object should only be converted to numpy array before being passed to the model. 236 | 237 | You'd not believe how complex the previous solution was.""" 238 | self._frames = frames 239 | self._out = None 240 | 241 | def _force(self): 242 | if self._out is None: 243 | self._out = np.concatenate(self._frames, axis=-1) 244 | self._frames = None 245 | return self._out 246 | 247 | def __array__(self, dtype=None): 248 | out = self._force() 249 | if dtype is not None: 250 | out = out.astype(dtype) 251 | return out 252 | 253 | def __len__(self): 254 | return len(self._force()) 255 | 256 | def __getitem__(self, i): 257 | return self._force()[..., i] 258 | 259 | def make_atari(env_id, max_episode_steps=None): 260 | env = gym.make(env_id) 261 | assert 'NoFrameskip' in env.spec.id 262 | env = NoopResetEnv(env, noop_max=30) 263 | env = MaxAndSkipEnv(env, skip=4) 264 | if max_episode_steps is not None: 265 | env = TimeLimit(env, max_episode_steps=max_episode_steps) 266 | return env 267 | 268 | def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False): 269 | """Configure environment for DeepMind-style Atari. 
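Depending on the flags, this applies: episodic-life termination, FIRE-on-reset (when the game supports it), 84x84 grayscale warping, float scaling to [0, 1], reward clipping to {+1, 0, -1}, and 4-frame stacking.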
270 | """ 271 | if episode_life: 272 | env = EpisodicLifeEnv(env) 273 | if 'FIRE' in env.unwrapped.get_action_meanings(): 274 | env = FireResetEnv(env) 275 | env = WarpFrame(env) 276 | if scale: 277 | env = ScaledFloatFrame(env) 278 | if clip_rewards: 279 | env = ClipRewardEnv(env) 280 | if frame_stack: 281 | env = FrameStack(env, 4) 282 | return env 283 | 284 | if __name__ == '__main__': 285 | env = make_atari("BreakoutNoFrameskip-v4") 286 | env = wrap_deepmind(env) 287 | # s = env.reset() 288 | obs = env.observation_space 289 | act = env.action_space 290 | # 291 | # s_ = s 292 | # env = gym.make("SeaquestNoFrameskip-v4") 293 | s = env.reset() 294 | a = env.action_space.sample() 295 | env.step(a) 296 | while 1: 297 | env.render() 298 | -------------------------------------------------------------------------------- /env/atari_process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from gym.spaces.box import Box 4 | 5 | class AtariPreprocessing(object): 6 | """A class implementing image preprocessing for Atari 2600 agents. 7 | 8 | Specifically, this provides the following subset from the JAIR paper 9 | (Bellemare et al., 2013) and Nature DQN paper (Mnih et al., 2015): 10 | 11 | * Frame skipping (defaults to 4). 12 | * Terminal signal when a life is lost (off by default). 13 | * Grayscale and max-pooling of the last two frames. 14 | * Downsample the screen to a square image (defaults to 84x84). 15 | 16 | More generally, this class follows the preprocessing guidelines set down in 17 | Machado et al. (2018), "Revisiting the Arcade Learning Environment: 18 | Evaluation Protocols and Open Problems for General Agents". 19 | """ 20 | 21 | def __init__(self, environment, frame_skip=4, terminal_on_life_loss=False, 22 | screen_size=84): 23 | """Constructor for an Atari 2600 preprocessor. 24 | 25 | Args: 26 | environment: Gym environment whose observations are preprocessed. 27 | frame_skip: int, the frequency at which the agent experiences the game. 28 | terminal_on_life_loss: bool, If True, the step() method returns 29 | is_terminal=True whenever a life is lost. See Mnih et al. 2015. 30 | screen_size: int, size of a resized Atari 2600 frame. 31 | 32 | Raises: 33 | ValueError: if frame_skip or screen_size are not strictly positive. 34 | """ 35 | if frame_skip <= 0: 36 | raise ValueError('Frame skip should be strictly positive, got {}'. 37 | format(frame_skip)) 38 | if screen_size <= 0: 39 | raise ValueError('Target screen size should be strictly positive, got {}'. 40 | format(screen_size)) 41 | 42 | self.environment = environment 43 | self.terminal_on_life_loss = terminal_on_life_loss 44 | self.frame_skip = frame_skip 45 | self.screen_size = screen_size 46 | 47 | obs_dims = self.environment.observation_space 48 | # Stores temporary observations used for pooling over two successive 49 | # frames. 50 | self.screen_buffer = [ 51 | np.empty((obs_dims.shape[0], obs_dims.shape[1]), dtype=np.uint8), 52 | np.empty((obs_dims.shape[0], obs_dims.shape[1]), dtype=np.uint8) 53 | ] 54 | 55 | self.game_over = False 56 | self.lives = 0 # Will need to be set by reset(). 57 | 58 | @property 59 | def observation_space(self): 60 | # Return the observation space adjusted to match the shape of the processed 61 | # observations. 
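# Each processed frame is a single-channel uint8 image of shape (screen_size, screen_size, 1).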
62 | return Box(low=0, high=255, shape=(self.screen_size, self.screen_size, 1), 63 | dtype=np.uint8) 64 | 65 | @property 66 | def action_space(self): 67 | return self.environment.action_space 68 | 69 | @property 70 | def reward_range(self): 71 | return self.environment.reward_range 72 | 73 | @property 74 | def metadata(self): 75 | return self.environment.metadata 76 | 77 | def close(self): 78 | return self.environment.close() 79 | 80 | def reset(self): 81 | """Resets the environment. 82 | 83 | Returns: 84 | observation: numpy array, the initial observation emitted by the 85 | environment. 86 | """ 87 | self.environment.reset() 88 | self.lives = self.environment.ale.lives() 89 | self._fetch_grayscale_observation(self.screen_buffer[0]) 90 | self.screen_buffer[1].fill(0) 91 | return self._pool_and_resize() 92 | 93 | def render(self, mode): 94 | """Renders the current screen, before preprocessing. 95 | 96 | This calls the Gym API's render() method. 97 | 98 | Args: 99 | mode: Mode argument for the environment's render() method. 100 | Valid values (str) are: 101 | 'rgb_array': returns the raw ALE image. 102 | 'human': renders to display via the Gym renderer. 103 | 104 | Returns: 105 | if mode='rgb_array': numpy array, the most recent screen. 106 | if mode='human': bool, whether the rendering was successful. 107 | """ 108 | return self.environment.render(mode) 109 | 110 | def step(self, action): 111 | """Applies the given action in the environment. 112 | 113 | Remarks: 114 | 115 | * If a terminal state (from life loss or episode end) is reached, this may 116 | execute fewer than self.frame_skip steps in the environment. 117 | * Furthermore, in this case the returned observation may not contain valid 118 | image data and should be ignored. 119 | 120 | Args: 121 | action: The action to be executed. 122 | 123 | Returns: 124 | observation: numpy array, the observation following the action. 125 | reward: float, the reward following the action. 126 | is_terminal: bool, whether the environment has reached a terminal state. 127 | This is true when a life is lost and terminal_on_life_loss, or when the 128 | episode is over. 129 | info: Gym API's info data structure. 130 | """ 131 | accumulated_reward = 0. 132 | 133 | for time_step in range(self.frame_skip): 134 | # We bypass the Gym observation altogether and directly fetch the 135 | # grayscale image from the ALE. This is a little faster. 136 | _, reward, game_over, info = self.environment.step(action) 137 | accumulated_reward += reward 138 | 139 | if self.terminal_on_life_loss: 140 | new_lives = self.environment.ale.lives() 141 | is_terminal = game_over or new_lives < self.lives 142 | self.lives = new_lives 143 | else: 144 | is_terminal = game_over 145 | 146 | if is_terminal: 147 | break 148 | # We max-pool over the last two frames, in grayscale. 149 | elif time_step >= self.frame_skip - 2: 150 | t = time_step - (self.frame_skip - 2) 151 | self._fetch_grayscale_observation(self.screen_buffer[t]) 152 | 153 | # Pool the last two observations. 154 | observation = self._pool_and_resize() 155 | 156 | self.game_over = game_over 157 | return observation, accumulated_reward, is_terminal, info 158 | 159 | def _fetch_grayscale_observation(self, output): 160 | """Returns the current observation in grayscale. 161 | 162 | The returned observation is stored in 'output'. 163 | 164 | Args: 165 | output: numpy array, screen buffer to hold the returned observation. 166 | 167 | Returns: 168 | observation: numpy array, the current observation in grayscale. 
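Note that the frame is written into 'output' in place and the same buffer is returned.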
169 | """ 170 | self.environment.ale.getScreenGrayscale(output) 171 | return output 172 | 173 | def _pool_and_resize(self): 174 | """Transforms two frames into a Nature DQN observation. 175 | 176 | For efficiency, the transformation is done in-place in self.screen_buffer. 177 | 178 | Returns: 179 | transformed_screen: numpy array, pooled, resized screen. 180 | """ 181 | # Pool if there are enough screens to do so. 182 | if self.frame_skip > 1: 183 | np.maximum(self.screen_buffer[0], self.screen_buffer[1], 184 | out=self.screen_buffer[0]) 185 | 186 | transformed_image = cv2.resize(self.screen_buffer[0], 187 | (self.screen_size, self.screen_size), 188 | interpolation=cv2.INTER_AREA) 189 | int_image = np.asarray(transformed_image, dtype=np.uint8) 190 | return np.expand_dims(int_image, axis=2) 191 | -------------------------------------------------------------------------------- /env/dmc_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | from collections import deque 4 | 5 | class DMCFrameStack(gym.Wrapper): 6 | def __init__(self, env, k): 7 | gym.Wrapper.__init__(self, env) 8 | self._k = k 9 | self._frames = deque([], maxlen=k) 10 | shp = env.observation_space.shape 11 | self.observation_space = gym.spaces.Box( 12 | low=0, 13 | high=255, 14 | shape=((shp[0] * k,) + shp[1:]), 15 | dtype=env.observation_space.dtype 16 | ) 17 | self._max_episode_steps = env._max_episode_steps 18 | 19 | def reset(self): 20 | obs = self.env.reset() 21 | for _ in range(self._k): 22 | self._frames.append(obs) 23 | return self._get_obs() 24 | 25 | def step(self, action): 26 | obs, reward, done, info = self.env.step(action) 27 | self._frames.append(obs) 28 | return self._get_obs(), reward, done, info 29 | 30 | def _get_obs(self): 31 | assert len(self._frames) == self._k 32 | return np.concatenate(list(self._frames), axis=0) -------------------------------------------------------------------------------- /env/gym_example.py: -------------------------------------------------------------------------------- 1 | from gym import envs 2 | env_names = [spec.id for spec in envs.registry.all()] 3 | for name in sorted(env_names): 4 | print(name) -------------------------------------------------------------------------------- /env/vecenv.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | from multiprocessing import Process, Pipe 4 | 5 | import cloudpickle 6 | 7 | 8 | class CloudpickleWrapper(object): 9 | 10 | def __init__(self, data): 11 | self.data = data 12 | 13 | def __getstate__(self): 14 | return cloudpickle.dumps(self.data) 15 | 16 | def __setstate__(self, data): 17 | self.data = cloudpickle.loads(data) 18 | 19 | class RunningStat: # for class AutoNormalization 20 | def __init__(self, shape): 21 | self._n = 0 22 | self._M = np.zeros(shape) 23 | self._S = np.zeros(shape) 24 | 25 | def push(self, x): 26 | x = np.asarray(x) 27 | # assert x.shape == self._M.shape 28 | self._n += 1 29 | if self._n == 1: 30 | self._M[...] = x 31 | else: 32 | pre_memo = self._M.copy() 33 | self._M[...] = pre_memo + (x - pre_memo) / self._n 34 | self._S[...] 
= self._S + (x - pre_memo) * (x - self._M) 35 | 36 | @property 37 | def n(self): 38 | return self._n 39 | 40 | @property 41 | def mean(self): 42 | return self._M 43 | 44 | @property 45 | def var(self): 46 | return self._S / (self._n - 1) if self._n > 1 else np.square(self._M) 47 | 48 | @property 49 | def std(self): 50 | return np.sqrt(self.var) 51 | 52 | @property 53 | def shape(self): 54 | return self._M.shape 55 | 56 | 57 | class Identity: 58 | def __call__(self, x): 59 | return x 60 | 61 | 62 | class RewardFilter: 63 | def __init__(self, pre_filter, shape, center=True, scale=True, clip=10.0, gamma=0.99): 64 | self.pre_filter = pre_filter 65 | self.center = center 66 | self.scale = scale 67 | self.clip = clip 68 | self.gamma = gamma 69 | 70 | self.rs = RunningStat(shape) 71 | self.ret = np.zeros(shape) 72 | 73 | def __call__(self, x): 74 | x = self.pre_filter(x) 75 | self.ret = self.ret*self.gamma + x 76 | self.rs.push(self.ret) 77 | x = self.ret/self.rs.std 78 | if self.clip: 79 | x = np.clip(x, -self.clip, self.clip) 80 | return x 81 | 82 | 83 | class AutoNormalization: 84 | def __init__(self, pre_filter, shape, demean=True, destd=True, clip=10.0): 85 | self.demean = demean 86 | self.destd = destd 87 | self.clip = clip 88 | self.pre_filter = pre_filter 89 | 90 | self.rs = RunningStat(shape) 91 | 92 | def __call__(self, x, update=True): 93 | x = self.pre_filter(x) 94 | if update: 95 | self.rs.push(x) 96 | if self.demean: 97 | x = x - self.rs.mean 98 | if self.destd: 99 | x = x / (self.rs.std + 1e-8) 100 | if self.clip: 101 | x = np.clip(x, -self.clip, self.clip) 102 | return x 103 | 104 | @staticmethod 105 | def output_shape(input_space): 106 | return input_space.shape 107 | 108 | 109 | 110 | class Env(gym.Wrapper): 111 | def __init__(self, task, norm_state=False, norm_rewards=False): 112 | env = gym.make(task) 113 | super().__init__(env) 114 | self.total_reward = 0 115 | self.episode = 0 116 | self.total_step = 0 117 | 118 | self.state_norm = Identity() 119 | self.reward_norm = Identity() 120 | if norm_state: 121 | self.state_norm = AutoNormalization(self.state_norm, shape=env.observation_space.shape, clip=5) 122 | if norm_rewards: 123 | self.reward_norm = AutoNormalization(self.reward_norm, shape=(), clip=5) 124 | 125 | def reset(self, **kwargs): 126 | self.total_reward = 0 127 | self.episode = 0 128 | self.total_step = 0 129 | 130 | s = self.env.reset() 131 | return self.state_norm(s) 132 | 133 | def step(self, action): 134 | s_, r, d, info = self.env.step(action) 135 | 136 | self.total_reward += r 137 | self.total_step += 1 138 | if d: 139 | self.episode += 1 140 | s_ = self.env.reset() 141 | 142 | return self.state_norm(s_), self.reward_norm(r), d, info 143 | 144 | def statistics(self): 145 | return self.episode, self.total_reward 146 | 147 | 148 | class VectorEnv: 149 | def __init__(self, env_fn): 150 | self.envs = env_fn 151 | self.env_num = len(env_fn) 152 | self.observation_space = self.envs[0].observation_space 153 | self.action_space = self.envs[0].action_space 154 | 155 | def reset(self): 156 | self._obs = np.stack([e.reset() for e in self.envs], axis=0) 157 | return self._obs 158 | 159 | def step(self, action): 160 | result = [e.step(a) for e, a in zip(self.envs, action)] 161 | self._obs, self._rew, self._done, self._info = zip(*result) 162 | self._obs = np.stack(self._obs) 163 | self._rew = np.stack(self._rew) 164 | self._done = np.stack(self._done) 165 | self._info = np.stack(self._info) 166 | return self._obs, self._rew, self._done, self._info 167 | 168 | def 
statistics(self): 169 | result = [env.statistics() for env in self.envs] 170 | epi, total_reward = zip(*result) 171 | return np.sum(epi, axis=0), np.sum(total_reward, axis=0) 172 | 173 | 174 | def worker(parent, p, env_fn_wrapper): 175 | parent.close() 176 | # env = env_fn_wrapper.data() 177 | env = env_fn_wrapper 178 | try: 179 | while True: 180 | cmd, data = p.recv() 181 | if cmd == 'step': 182 | p.send(env.step(data)) 183 | elif cmd == 'reset': 184 | p.send(env.reset()) 185 | elif cmd == 'close': 186 | p.send(env.close()) 187 | p.close() 188 | break 189 | elif cmd == "statistics": 190 | p.send(env.statistics()) 191 | elif cmd == 'render': 192 | p.send(env.render(**data) if hasattr(env, 'render') else None) 193 | elif cmd == 'seed': 194 | p.send(env.seed(data) if hasattr(env, 'seed') else None) 195 | elif cmd == 'getattr': 196 | p.send(getattr(env, data) if hasattr(env, data) else None) 197 | else: 198 | p.close() 199 | raise NotImplementedError 200 | except KeyboardInterrupt: 201 | p.close() 202 | 203 | 204 | class SubVectorEnv: 205 | def __init__(self, env_fn): 206 | self.envs = env_fn 207 | self.env_num = len(env_fn) 208 | self.observation_space = self.envs[0].observation_space 209 | self.action_space = self.envs[0].action_space 210 | 211 | self.parent_remote, self.child_remote = \ 212 | zip(*[Pipe() for _ in range(self.env_num)]) 213 | 214 | self.processes = [ 215 | Process(target=worker, args=( 216 | parent, child, env_fn), daemon=True) 217 | for (parent, child, env_fn) in zip( 218 | self.parent_remote, self.child_remote, self.envs) 219 | ] 220 | 221 | for p in self.processes: 222 | p.start() 223 | for c in self.child_remote: 224 | c.close() 225 | 226 | def reset(self): 227 | for p in self.parent_remote: 228 | p.send(['reset', None]) 229 | self._obs = np.stack([p.recv() for p in self.parent_remote]) 230 | return self._obs 231 | 232 | def step(self, action): 233 | for p, a in zip(self.parent_remote, action): 234 | p.send(['step', a]) 235 | result = [p.recv() for p in self.parent_remote] 236 | self._obs, self._rew, self._done, self._info = zip(*result) 237 | self._obs = np.stack(self._obs) 238 | self._rew = np.stack(self._rew) 239 | self._done = np.stack(self._done) 240 | self._info = np.stack(self._info) 241 | return self._obs, self._rew, self._done, self._info 242 | 243 | def statistics(self): 244 | for p in self.parent_remote: 245 | p.send(['statistics', None]) 246 | result = [p.recv() for p in self.parent_remote] 247 | 248 | epi, total_reward = zip(*result) 249 | return np.sum(epi), np.sum(total_reward) 250 | -------------------------------------------------------------------------------- /env/wrappers.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | class TimeLimit(gym.Wrapper): 4 | def __init__(self, env, max_episode_steps=None): 5 | super(TimeLimit, self).__init__(env) 6 | self._max_episode_steps = max_episode_steps 7 | self._elapsed_steps = 0 8 | 9 | def step(self, ac): 10 | observation, reward, done, info = self.env.step(ac) 11 | self._elapsed_steps += 1 12 | if self._elapsed_steps >= self._max_episode_steps: 13 | done = True 14 | info['TimeLimit.truncated'] = True 15 | return observation, reward, done, info 16 | 17 | def reset(self, **kwargs): 18 | self._elapsed_steps = 0 19 | return self.env.reset(**kwargs) 20 | 21 | class ClipActionsWrapper(gym.Wrapper): 22 | def step(self, action): 23 | import numpy as np 24 | action = np.nan_to_num(action) 25 | action = np.clip(action, self.action_space.low, 
self.action_space.high) 26 | return self.env.step(action) 27 | 28 | def reset(self, **kwargs): 29 | return self.env.reset(**kwargs) 30 | -------------------------------------------------------------------------------- /user_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | # Where experiment outputs are saved by default: 5 | DEFAULT_DATA_DIR = osp.join(osp.abspath(osp.dirname(__file__)),'data') 6 | DEFAULT_IMG_DIR = osp.join(osp.abspath(osp.dirname(__file__)),'imgs') 7 | DEFAULT_VIDEO_DIR = osp.join(osp.abspath(osp.dirname(__file__)),'videos') 8 | 9 | # Whether to automatically insert a date and time stamp into the names of 10 | # save directories: 11 | FORCE_DATESTAMP = False 12 | 13 | # Whether GridSearch provides automatically-generated default shorthands: 14 | DEFAULT_SHORTHAND = True 15 | 16 | # Tells the GridSearch how many seconds to pause for before launching 17 | # experiments. 18 | WAIT_BEFORE_LAUNCH = 5 -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feidieufo/RL-Implementation/c6f1a466d7897ac896debc1e7f3b2601cb245544/utils/__init__.py -------------------------------------------------------------------------------- /utils/logx.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Some simple logging functionality, inspired by rllab's logging. 4 | 5 | Logs to a tab-separated-values file (path/to/output_directory/progress.txt) 6 | 7 | """ 8 | import json 9 | import joblib 10 | import shutil 11 | import numpy as np 12 | import tensorflow as tf 13 | import os.path as osp, time, atexit, os 14 | from utils.mpi_tools import proc_id, mpi_statistics_scalar 15 | from utils.serialization_utils import convert_json 16 | 17 | color2num = dict( 18 | gray=30, 19 | red=31, 20 | green=32, 21 | yellow=33, 22 | blue=34, 23 | magenta=35, 24 | cyan=36, 25 | white=37, 26 | crimson=38 27 | ) 28 | 29 | def colorize(string, color, bold=False, highlight=False): 30 | """ 31 | Colorize a string. 32 | 33 | This function was originally written by John Schulman. 34 | """ 35 | attr = [] 36 | num = color2num[color] 37 | if highlight: num += 10 38 | attr.append(str(num)) 39 | if bold: attr.append('1') 40 | return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string) 41 | 42 | def restore_tf_graph(sess, fpath): 43 | """ 44 | Loads graphs saved by Logger. 45 | 46 | Will output a dictionary whose keys and values are from the 'inputs' 47 | and 'outputs' dict you specified with logger.setup_tf_saver(). 48 | 49 | Args: 50 | sess: A Tensorflow session. 51 | fpath: Filepath to save directory. 52 | 53 | Returns: 54 | A dictionary mapping from keys to tensors in the computation graph 55 | loaded from ``fpath``. 56 | """ 57 | tf.saved_model.loader.load( 58 | sess, 59 | [tf.saved_model.tag_constants.SERVING], 60 | fpath 61 | ) 62 | model_info = joblib.load(osp.join(fpath, 'model_info.pkl')) 63 | graph = tf.get_default_graph() 64 | model = dict() 65 | model.update({k: graph.get_tensor_by_name(v) for k,v in model_info['inputs'].items()}) 66 | model.update({k: graph.get_tensor_by_name(v) for k,v in model_info['outputs'].items()}) 67 | return model 68 | 69 | class Logger: 70 | """ 71 | A general-purpose logger. 
72 | 73 | Makes it easy to save diagnostics, hyperparameter configurations, the 74 | state of a training run, and the trained model. 75 | """ 76 | 77 | def __init__(self, output_dir=None, output_fname='progress.txt', exp_name=None, mode='w'): 78 | """ 79 | Initialize a Logger. 80 | 81 | Args: 82 | output_dir (string): A directory for saving results to. If 83 | ``None``, defaults to a temp directory of the form 84 | ``/tmp/experiments/somerandomnumber``. 85 | 86 | output_fname (string): Name for the tab-separated-value file 87 | containing metrics logged throughout a training run. 88 | Defaults to ``progress.txt``. 89 | 90 | exp_name (string): Experiment name. If you run multiple training 91 | runs and give them all the same ``exp_name``, the plotter 92 | will know to group them. (Use case: if you run the same 93 | hyperparameter configuration with multiple random seeds, you 94 | should give them all the same ``exp_name``.) 95 | """ 96 | if proc_id()==0: 97 | self.output_dir = output_dir or "/tmp/experiments/%i"%int(time.time()) 98 | if osp.exists(self.output_dir): 99 | print("Warning: Log dir %s already exists! Storing info there anyway."%self.output_dir) 100 | else: 101 | os.makedirs(self.output_dir) 102 | self.output_file = open(osp.join(self.output_dir, output_fname), mode) 103 | atexit.register(self.output_file.close) 104 | print(colorize("Logging data to %s"%self.output_file.name, 'green', bold=True)) 105 | else: 106 | self.output_dir = None 107 | self.output_file = None 108 | self.first_row=True 109 | self.log_headers = [] 110 | self.log_current_row = {} 111 | self.exp_name = exp_name 112 | 113 | def log(self, msg, color='green'): 114 | """Print a colorized message to stdout.""" 115 | if proc_id()==0: 116 | print(colorize(msg, color, bold=True)) 117 | 118 | def log_tabular(self, key, val): 119 | """ 120 | Log a value of some diagnostic. 121 | 122 | Call this only once for each diagnostic quantity, each iteration. 123 | After using ``log_tabular`` to store values for each diagnostic, 124 | make sure to call ``dump_tabular`` to write them out to file and 125 | stdout (otherwise they will not get saved anywhere). 126 | """ 127 | if self.first_row: 128 | self.log_headers.append(key) 129 | else: 130 | assert key in self.log_headers, "Trying to introduce a new key %s that you didn't include in the first iteration"%key 131 | assert key not in self.log_current_row, "You already set %s this iteration. Maybe you forgot to call dump_tabular()"%key 132 | self.log_current_row[key] = val 133 | 134 | def save_config(self, config): 135 | """ 136 | Log an experiment configuration. 137 | 138 | Call this once at the top of your experiment, passing in all important 139 | config vars as a dict. This will serialize the config to JSON, while 140 | handling anything which can't be serialized in a graceful way (writing 141 | as informative a string as possible). 142 | 143 | Example use: 144 | 145 | .. 
code-block:: python 146 | 147 | logger = EpochLogger(**logger_kwargs) 148 | logger.save_config(locals()) 149 | """ 150 | config_json = convert_json(config) 151 | if self.exp_name is not None: 152 | config_json['exp_name'] = self.exp_name 153 | if proc_id()==0: 154 | output = json.dumps(config_json, separators=(',',':\t'), indent=4, sort_keys=True) 155 | print(colorize('Saving config:\n', color='cyan', bold=True)) 156 | print(output) 157 | with open(osp.join(self.output_dir, "config.json"), 'w') as out: 158 | out.write(output) 159 | 160 | def save_state(self, state_dict, itr=None): 161 | """ 162 | Saves the state of an experiment. 163 | 164 | To be clear: this is about saving *state*, not logging diagnostics. 165 | All diagnostic logging is separate from this function. This function 166 | will save whatever is in ``state_dict``---usually just a copy of the 167 | environment---and the most recent parameters for the model you 168 | previously set up saving for with ``setup_tf_saver``. 169 | 170 | Call with any frequency you prefer. If you only want to maintain a 171 | single state and overwrite it at each call with the most recent 172 | version, leave ``itr=None``. If you want to keep all of the states you 173 | save, provide unique (increasing) values for 'itr'. 174 | 175 | Args: 176 | state_dict (dict): Dictionary containing essential elements to 177 | describe the current state of training. 178 | 179 | itr: An int, or None. Current iteration of training. 180 | """ 181 | if proc_id()==0: 182 | fname = 'vars.pkl' if itr is None else 'vars%d.pkl'%itr 183 | try: 184 | joblib.dump(state_dict, osp.join(self.output_dir, fname)) 185 | except: 186 | self.log('Warning: could not pickle state_dict.', color='red') 187 | if hasattr(self, 'tf_saver_elements'): 188 | self._tf_simple_save(itr) 189 | 190 | def setup_tf_saver(self, sess, inputs, outputs): 191 | """ 192 | Set up easy model saving for tensorflow. 193 | 194 | Call once, after defining your computation graph but before training. 195 | 196 | Args: 197 | sess: The Tensorflow session in which you train your computation 198 | graph. 199 | 200 | inputs (dict): A dictionary that maps from keys of your choice 201 | to the tensorflow placeholders that serve as inputs to the 202 | computation graph. Make sure that *all* of the placeholders 203 | needed for your outputs are included! 204 | 205 | outputs (dict): A dictionary that maps from keys of your choice 206 | to the outputs from your computation graph. 207 | """ 208 | self.tf_saver_elements = dict(session=sess, inputs=inputs, outputs=outputs) 209 | self.tf_saver_info = {'inputs': {k:v.name for k,v in inputs.items()}, 210 | 'outputs': {k:v.name for k,v in outputs.items()}} 211 | 212 | ###########zhr 213 | def save_model(self, model, step=None): 214 | fpath = "saver" 215 | if self.output_dir is None: 216 | self.log("no output_file", color="red") 217 | return 218 | 219 | fpath = osp.join(self.output_dir, fpath) 220 | 221 | if not osp.exists(fpath): 222 | os.makedirs(fpath) 223 | 224 | if step != None: 225 | fpath = osp.join(fpath, "step_"+str(step)) 226 | else: 227 | fpath = osp.join(fpath, "step") 228 | 229 | model.save(fpath) 230 | #########zhr 231 | 232 | 233 | def _tf_simple_save(self, itr=None): 234 | """ 235 | Uses simple_save to save a trained model, plus info to make it easy 236 | to associated tensors to variables after restore. 
237 | """ 238 | if proc_id()==0: 239 | assert hasattr(self, 'tf_saver_elements'), \ 240 | "First have to setup saving with self.setup_tf_saver" 241 | fpath = 'simple_save' + ('%d'%itr if itr is not None else '') 242 | fpath = osp.join(self.output_dir, fpath) 243 | if osp.exists(fpath): 244 | # simple_save refuses to be useful if fpath already exists, 245 | # so just delete fpath if it's there. 246 | shutil.rmtree(fpath) 247 | tf.saved_model.simple_save(export_dir=fpath, **self.tf_saver_elements) 248 | joblib.dump(self.tf_saver_info, osp.join(fpath, 'model_info.pkl')) 249 | 250 | def dump_tabular(self): 251 | """ 252 | Write all of the diagnostics from the current iteration. 253 | 254 | Writes both to stdout, and to the output file. 255 | """ 256 | if proc_id()==0: 257 | vals = [] 258 | key_lens = [len(key) for key in self.log_headers] 259 | max_key_len = max(15,max(key_lens)) 260 | keystr = '%'+'%d'%max_key_len 261 | fmt = "| " + keystr + "s | %15s |" 262 | n_slashes = 22 + max_key_len 263 | print("-"*n_slashes) 264 | for key in self.log_headers: 265 | val = self.log_current_row.get(key, "") 266 | valstr = "%8.3g"%val if hasattr(val, "__float__") else val 267 | print(fmt%(key, valstr)) 268 | vals.append(val) 269 | print("-"*n_slashes) 270 | if self.output_file is not None: 271 | if self.first_row: 272 | self.output_file.write("\t".join(self.log_headers)+"\n") 273 | self.output_file.write("\t".join(map(str,vals))+"\n") 274 | self.output_file.flush() 275 | self.log_current_row.clear() 276 | self.first_row=False 277 | 278 | class EpochLogger(Logger): 279 | """ 280 | A variant of Logger tailored for tracking average values over epochs. 281 | 282 | Typical use case: there is some quantity which is calculated many times 283 | throughout an epoch, and at the end of the epoch, you would like to 284 | report the average / std / min / max value of that quantity. 285 | 286 | With an EpochLogger, each time the quantity is calculated, you would 287 | use 288 | 289 | .. code-block:: python 290 | 291 | epoch_logger.store(NameOfQuantity=quantity_value) 292 | 293 | to load it into the EpochLogger's state. Then at the end of the epoch, you 294 | would use 295 | 296 | .. code-block:: python 297 | 298 | epoch_logger.log_tabular(NameOfQuantity, **options) 299 | 300 | to record the desired values. 301 | """ 302 | 303 | def __init__(self, *args, **kwargs): 304 | super().__init__(*args, **kwargs) 305 | self.epoch_dict = dict() 306 | 307 | def store(self, **kwargs): 308 | """ 309 | Save something into the epoch_logger's current state. 310 | 311 | Provide an arbitrary number of keyword arguments with numerical 312 | values. 313 | """ 314 | for k,v in kwargs.items(): 315 | if not(k in self.epoch_dict.keys()): 316 | self.epoch_dict[k] = [] 317 | self.epoch_dict[k].append(v) 318 | 319 | def log_tabular(self, key, val=None, with_min_and_max=False, average_only=False): 320 | """ 321 | Log a value or possibly the mean/std/min/max values of a diagnostic. 322 | 323 | Args: 324 | key (string): The name of the diagnostic. If you are logging a 325 | diagnostic whose state has previously been saved with 326 | ``store``, the key here has to match the key you used there. 327 | 328 | val: A value for the diagnostic. If you have previously saved 329 | values for this key via ``store``, do *not* provide a ``val`` 330 | here. 331 | 332 | with_min_and_max (bool): If true, log min and max values of the 333 | diagnostic over the epoch. 
--------------------------------------------------------------------------------
/utils/mpi_tools.py:
--------------------------------------------------------------------------------
from mpi4py import MPI
import os, subprocess, sys
import numpy as np


def mpi_fork(n, bind_to_core=False):
    """
    Re-launches the current script with workers linked by MPI.

    Also, terminates the original process that launched it.

    Taken almost without modification from the Baselines function of the
    `same name`_.

    .. _`same name`: https://github.com/openai/baselines/blob/master/baselines/common/mpi_fork.py

    Args:
        n (int): Number of processes to split into.

        bind_to_core (bool): Bind each MPI process to a core.
    """
    if n<=1:
        return
    if os.getenv("IN_MPI") is None:
        env = os.environ.copy()
        env.update(
            MKL_NUM_THREADS="1",
            OMP_NUM_THREADS="1",
            IN_MPI="1"
        )
        args = ["mpirun", "-np", str(n)]
        if bind_to_core:
            args += ["-bind-to", "core"]
        args += [sys.executable] + sys.argv
        subprocess.check_call(args, env=env)
        sys.exit()


def msg(m, string=''):
    print(('Message from %d: %s \t '%(MPI.COMM_WORLD.Get_rank(), string))+str(m))

def proc_id():
    """Get rank of calling process."""
    return MPI.COMM_WORLD.Get_rank()

def allreduce(*args, **kwargs):
    return MPI.COMM_WORLD.Allreduce(*args, **kwargs)

def num_procs():
    """Count active MPI processes."""
    return MPI.COMM_WORLD.Get_size()

def broadcast(x, root=0):
    MPI.COMM_WORLD.Bcast(x, root=root)

def mpi_op(x, op):
    x, scalar = ([x], True) if np.isscalar(x) else (x, False)
    x = np.asarray(x, dtype=np.float32)
    buff = np.zeros_like(x, dtype=np.float32)
    allreduce(x, buff, op=op)
    return buff[0] if scalar else buff

def mpi_sum(x):
    return mpi_op(x, MPI.SUM)

def mpi_avg(x):
    """Average a scalar or vector over MPI processes."""
    return mpi_sum(x) / num_procs()

def mpi_statistics_scalar(x, with_min_and_max=False):
    """
    Get mean/std and optional min/max of scalar x across MPI processes.

    Args:
        x: An array containing samples of the scalar to produce statistics
            for.

        with_min_and_max (bool): If true, return min and max of x in
            addition to mean and std.
    """
    x = np.array(x, dtype=np.float32)
    global_sum, global_n = mpi_sum([np.sum(x), len(x)])
    mean = global_sum / global_n

    global_sum_sq = mpi_sum(np.sum((x - mean)**2))
    std = np.sqrt(global_sum_sq / global_n)  # compute global std

    if with_min_and_max:
        global_min = mpi_op(np.min(x) if len(x) > 0 else np.inf, op=MPI.MIN)
        global_max = mpi_op(np.max(x) if len(x) > 0 else -np.inf, op=MPI.MAX)
        return mean, std, global_min, global_max
    return mean, std
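
# Minimal single-process sanity check (illustrative, not used elsewhere in the
# repo): with a single MPI rank the reductions are effectively no-ops, so the
# statistics reduce to plain numpy results.
if __name__ == "__main__":
    m, s = mpi_statistics_scalar([1.0, 2.0, 3.0])
    print("mean=%.3f std=%.3f" % (m, s))  # mean=2.000 std=0.816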
--------------------------------------------------------------------------------
/utils/normalization.py:
--------------------------------------------------------------------------------
import numpy as np

class RunningStat:  # for class AutoNormalization
    def __init__(self, shape):
        self._n = 0
        self._M = np.zeros(shape)
        self._S = np.zeros(shape)

    def push(self, x):
        x = np.asarray(x)
        # assert x.shape == self._M.shape
        self._n += 1
        if self._n == 1:
            self._M[...] = x
        else:
            pre_memo = self._M.copy()
            self._M[...] = pre_memo + (x - pre_memo) / self._n
            self._S[...] = self._S + (x - pre_memo) * (x - self._M)

    @property
    def n(self):
        return self._n

    @property
    def mean(self):
        return self._M

    @property
    def var(self):
        return self._S / (self._n - 1) if self._n > 1 else np.square(self._M)

    @property
    def std(self):
        return np.sqrt(self.var)

    @property
    def shape(self):
        return self._M.shape


class Identity:
    def __call__(self, x, update=True):
        return x

    def reset(self):
        pass

class ImageProcess:
    def __init__(self, pre_filter):
        self.pre_filter = pre_filter

    def __call__(self, x, update=True):
        x = self.pre_filter(x)
        x = np.array(x).astype(np.float32) / 255.0
        return x

    def reset(self):
        self.pre_filter.reset()

class RewardFilter:
    def __init__(self, pre_filter, shape, center=True, scale=True, clip=10.0, gamma=0.99):
        self.pre_filter = pre_filter
        self.center = center
        self.scale = scale
        self.clip = clip
        self.gamma = gamma

        self.rs = RunningStat(shape)
        self.ret = np.zeros(shape)

    def __call__(self, x, update=True):
        x = self.pre_filter(x)
        self.ret = self.ret*self.gamma + x
        if update:
            self.rs.push(self.ret)
        x = x/(self.rs.std + 1e-8)
        if self.clip:
            x = np.clip(x, -self.clip, self.clip)
        return x

    def reset(self):
        self.pre_filter.reset()
        self.ret = np.zeros(self.ret.shape)


class AutoNormalization:
    def __init__(self, pre_filter, shape, demean=True, destd=True, clip=10.0):
        self.demean = demean
        self.destd = destd
        self.clip = clip
        self.pre_filter = pre_filter

        self.rs = RunningStat(shape)

    def __call__(self, x, update=True):
        x = self.pre_filter(x)
        if update:
            self.rs.push(x)
        if self.demean:
            x = x - self.rs.mean
        if self.destd:
            x = x / (self.rs.std + 1e-8)
        if self.clip:
            x = np.clip(x, -self.clip, self.clip)
        return x

    def reset(self):
        self.pre_filter.reset()

    @staticmethod
    def output_shape(input_space):
        return input_space.shape


class RunningMeanStd(object):
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, x):
        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        batch_count = x.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)

def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count
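
# Illustrative composition of the filters above (a sketch only; the shapes and
# values are made up, and nothing in this file runs it automatically). The PPO
# runners expose related --norm_state / --norm_rewards options.
if __name__ == "__main__":
    state_filter = AutoNormalization(Identity(), shape=(3,), clip=10.0)
    reward_filter = RewardFilter(Identity(), shape=(), gamma=0.99)
    for t in range(5):
        obs = state_filter(np.array([0.1 * t, -0.2, 0.3]))  # whitened observation
        rew = reward_filter(1.0)                            # return-scaled reward
    print("normalized obs:", obs, "scaled reward:", rew)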
--------------------------------------------------------------------------------
/utils/plot.py:
--------------------------------------------------------------------------------
import os
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from user_config import DEFAULT_IMG_DIR, DEFAULT_DATA_DIR
import argparse
import glob

def smooth(data, sm=1):
    if sm > 1:
        y = np.ones(sm)*1.0/sm
        data = np.convolve(y, data, "same")

    return data

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--plot_name', default=None)
    parser.add_argument('--seed', default='10', type=int)
    parser.add_argument('--output_name', default=None, type=str)
    args = parser.parse_args()

    plt.style.use('fivethirtyeight')

    if args.plot_name is None:
        file_list = glob.glob(os.path.join(DEFAULT_DATA_DIR, "ppo_mask*"))
        for file in file_list:
            for file_seed in glob.glob(os.path.join(file, "*")):
                data_file = os.path.join(file_seed, "progress.txt")
                pd_data = pd.read_table(data_file)
                mean_name = "Averagetest_reward"
                std_name = "Stdtest_reward"
                if mean_name not in pd_data.columns or std_name not in pd_data.columns:
                    continue
                mean = pd_data[mean_name]
                std = pd_data[std_name]
                x = pd_data["Epoch"]
                mean = smooth(mean, sm=3)

                plt.plot(x, mean, c="deepskyblue", linewidth=1)
                plt.fill_between(x, mean+std, mean, color="lightskyblue")
                plt.fill_between(x, mean-std, mean, color="lightskyblue")

                output_name = file_seed.split(os.sep)[-1]
                plt.title(output_name)
                plt.xlabel("epoch")
                plt.ylabel("return")
                # plt.legend(loc = 'lower right',   # defaults to 'upper left'; change via loc
                #            fancybox = True,       # rounded frame
                #            framealpha = 0.5,      # frame transparency
                #            shadow = True,         # drop shadow
                #            borderpad = 1)         # border padding

                if not os.path.exists(DEFAULT_IMG_DIR):
                    os.mkdir(DEFAULT_IMG_DIR)
                out_file = os.path.join(DEFAULT_IMG_DIR, output_name + ".png")
                plt.savefig(out_file)
                plt.clf()

    else:

        from utils.run_utils import setup_logger_kwargs
        logger_kwargs = setup_logger_kwargs(args.plot_name, args.seed)
        data_file = os.path.join(logger_kwargs["output_dir"], "progress.txt")

        pd_data = pd.read_table(data_file)
        mean_name = "Averagetest_reward"
        std_name = "Stdtest_reward"
        mean = pd_data[mean_name]
        std = pd_data[std_name]
        x = pd_data["Epoch"]

        plt.plot(x, mean)
        plt.fill_between(x, mean+std, mean-std)

        output_name = args.output_name
        if args.output_name is None:
            output_name = args.plot_name + "_s" + str(args.seed)
        if not os.path.exists(DEFAULT_IMG_DIR):
            os.mkdir(DEFAULT_IMG_DIR)
        out_file = os.path.join(DEFAULT_IMG_DIR, output_name + ".png")
        plt.savefig(out_file)
--------------------------------------------------------------------------------
/utils/serialization_utils.py:
--------------------------------------------------------------------------------
import json

def convert_json(obj):
    """ Convert obj to a version which can be serialized with JSON. """
    if is_json_serializable(obj):
        return obj
    else:
        if isinstance(obj, dict):
            return {convert_json(k): convert_json(v)
                    for k,v in obj.items()}

        elif isinstance(obj, tuple):
            # materialize as a tuple; a bare generator would not be JSON-serializable
            return tuple(convert_json(x) for x in obj)

        elif isinstance(obj, list):
            return [convert_json(x) for x in obj]

        elif hasattr(obj,'__name__') and not('lambda' in obj.__name__):
            return convert_json(obj.__name__)

        elif hasattr(obj,'__dict__') and obj.__dict__:
            obj_dict = {convert_json(k): convert_json(v)
                        for k,v in obj.__dict__.items()}
            return {str(obj): obj_dict}

        return str(obj)

def is_json_serializable(v):
    try:
        json.dumps(v)
        return True
    except:
        return False
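
# Quick illustration of the graceful-degradation behavior (a sketch; run this
# file directly to see it). Functions fall back to their __name__ and tuples
# become JSON arrays:
if __name__ == "__main__":
    import numpy as np
    print(json.dumps(convert_json({"activation": np.tanh, "hidden": (400, 300)})))
    # -> {"activation": "tanh", "hidden": [400, 300]}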
--------------------------------------------------------------------------------
/utils/video.py:
--------------------------------------------------------------------------------
import imageio
import os
import numpy as np


class VideoRecorder(object):
    def __init__(self, dir_name, height=256, width=256, camera_id=0, fps=30):
        self.dir_name = dir_name
        self.height = height
        self.width = width
        self.camera_id = camera_id
        self.fps = fps
        self.frames = []

    def init(self, enabled=True):
        self.frames = []
        self.enabled = self.dir_name is not None and enabled

    def record(self, env):
        if self.enabled:
            frame = env.render(
                mode='rgb_array',
                height=self.height,
                width=self.width,
                camera_id=self.camera_id
            )
            self.frames.append(frame)

    def save(self, file_name):
        if self.enabled:
            path = os.path.join(self.dir_name, file_name)
            imageio.mimsave(path, self.frames, fps=self.fps)
--------------------------------------------------------------------------------