├── README.md
├── core
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── agent.cpython-36.pyc
│   │   ├── ddpg.cpython-36.pyc
│   │   ├── ddpg.cpython-37.pyc
│   │   ├── mod_neuro_evo.cpython-36.pyc
│   │   ├── mod_neuro_evo.cpython-37.pyc
│   │   ├── mod_utils.cpython-36.pyc
│   │   ├── mod_utils.cpython-37.pyc
│   │   ├── operator_runner.cpython-36.pyc
│   │   ├── operator_runner.cpython-37.pyc
│   │   ├── replay_memory.cpython-36.pyc
│   │   ├── replay_memory.cpython-37.pyc
│   │   └── utils.cpython-36.pyc
│   ├── agent.py
│   ├── ddpg.py
│   ├── mod_neuro_evo.py
│   ├── mod_utils.py
│   ├── operator_runner.py
│   ├── replay_memory.py
│   └── utils.py
├── parameters.py
├── run.sh
└── run_re2.py

/README.md:
--------------------------------------------------------------------------------
 1 | 
 2 | # ICLR 2023: ERL-Re$^2$: Efficient Evolutionary Reinforcement Learning with Shared State Representation and Individual Policy Representation
 3 | 
 4 | Official code for the paper "ERL-Re$^2$: Efficient Evolutionary Reinforcement Learning with Shared State Representation and Individual Policy Representation" (). **:trophy: ERL-Re$^2$ achieves the current SOTA in the ERL field (OpenAI MUJOCO tasks)**.
 5 | 
 6 | **ERL-Re$^2$** is a novel framework that integrates EA and RL. The cornerstone of ERL-Re$^2$ is the two-scale representation: all EA and RL policies share the same nonlinear state representation while maintaining individual linear policy representations. The state representation conveys expressive common features of the environment learned collectively by all the agents; the linear policy representation provides a favorable space for efficient policy optimization, where novel behavior-level crossover and mutation operations can be performed. Moreover, the linear policy representation allows convenient generalization of policy fitness with the help of the Policy-extended Value Function Approximator (PeVFA), further improving sample efficiency. This repository is based on .
 7 | 
 8 | 
 9 | ### NEWS :fire::fire::fire: Two ERL works were accepted at ICML 2024:
10 | 
11 | :triangular_flag_on_post: [EvoRainbow: Combining Improvements in Evolutionary Reinforcement Learning for Policy Search](https://openreview.net/forum?id=75Hes6Zse4)
12 | 
13 | :triangular_flag_on_post: [Value-Evolutionary-Based Reinforcement Learning](https://openreview.net/forum?id=h9LcbJ3l9r)
14 | 
15 | **The code for both works has already been released.
16 | We invite everyone to follow our work!**
17 | 
18 | ### ![Awesome](https://cdn.rawgit.com/sindresorhus/awesome/d7305f38d29fed78fa85652e3a63e154dd8e8829/media/badge.svg) You can explore more information on state-of-the-art ERL works by visiting [https://github.com/yeshenpy/Awesome-Evolutionary-Reinforcement-Learning](https://github.com/yeshenpy/Awesome-Evolutionary-Reinforcement-Learning).
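To make the two-scale representation concrete, here is a minimal, illustrative PyTorch sketch. It is not the repository's exact code (the real implementations are `shared_state_embedding` and `Actor` in `core/ddpg.py`), and the layer sizes and dimensions below are placeholders: every policy queries one shared nonlinear state encoder, while each individual owns only a small linear action head.

```python
import torch
import torch.nn as nn

class SharedStateEncoder(nn.Module):
    """Nonlinear state representation Z = f(s), shared by all EA and RL policies."""
    def __init__(self, state_dim, z_dim=300):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 400), nn.Tanh(),
            nn.Linear(400, z_dim), nn.Tanh(),
        )

    def forward(self, state):
        return self.net(state)

class LinearPolicyHead(nn.Module):
    """Individual linear policy representation: one small head per population member."""
    def __init__(self, z_dim, action_dim):
        super().__init__()
        self.w_out = nn.Linear(z_dim, action_dim)

    def forward(self, z):
        return torch.tanh(self.w_out(z))

# One shared encoder, many cheap linear heads (population + RL agent).
encoder = SharedStateEncoder(state_dim=17)
population = [LinearPolicyHead(z_dim=300, action_dim=6) for _ in range(5)]
state = torch.randn(1, 17)
actions = [head(encoder(state)) for head in population]
```

Because crossover and mutation only need to touch the small linear heads, the evolutionary operators become cheap behavior-level operations in this space, while the shared encoder is trained jointly from the experience of all agents.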
19 | 
20 | 
21 | 
22 | 
23 | # Installation
24 | Known dependencies: MUJOCO 200,
25 | Python (3.6.13), gym (0.10.5), torch (1.1.0), fastrand, wandb, mujoco_py=2.0.2.13
26 | 
27 | ## Hyperparameters
28 | - `-env`: The MUJOCO environment to run
29 | - `-OFF_TYPE`: The type of PeVFA, default 1
30 | - `-pr`: The policy representation size, default 64
31 | - `-pop_size`: Population size, default 5
32 | - `-prob_reset_and_sup`: Probability of parameter reset and super mutation, default 0.05
33 | - `-time_steps`: The step length of the H-step return
34 | - `-theta`: Probability of using MC estimates
35 | - `-frac`: Mutation ratio
36 | - `-gamma`: Discount factor for RL
37 | - `-TD3_noise`: Noise for TD3
38 | - `-EA`: Whether to use EA
39 | - `-RL`: Whether to use RL
40 | - `-K`: The number of individuals selected from the population to optimize the shared representation
41 | - `-EA_actor_alpha`: The coefficient used to balance the weight of the PeVFA loss
42 | - `-actor_alpha`: The coefficient used to balance the weight of the RL loss
43 | - `-tau`: The coefficient for soft target updates
44 | - `-seed`: Random seed, default from 1 to 5
45 | - `-logdir`: Log directory
46 | 
47 | ## Code structure
48 | 
49 | - `./parameters.py`: Hyperparameter settings for ERL-Re$^2$
50 | 
51 | - `./run_re2.py`: Entry point for running ERL-Re$^2$
52 | 
53 | - `./core/agent.py`: Algorithm flow
54 | 
55 | - `./core/ddpg.py`: The core code of ERL-Re$^2$ (DDPG and TD3 versions)
56 | 
57 | - `./core/mod_utils.py`: Utility functions for ERL-Re$^2$
58 | 
59 | - `./core/replay_memory.py`: Replay buffer for ERL-Re$^2$
60 | 
61 | - `./core/utils.py`: Additional utility functions for ERL-Re$^2$
62 | 
63 | - `./run.sh`: Command-line launch script
64 | 
65 | ## How to run
66 | 
67 | We implement and provide ERL-Re$^2$ based on TD3.
68 | Run the `run.sh` script directly; the hyperparameter settings can be found in the paper.
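For reference, `run.sh` essentially wraps an invocation of the entry script `run_re2.py`. A hypothetical example is shown below; the exact flag parsing is defined in `parameters.py`, and the environment name, seed, and log path here are placeholders rather than prescribed values:

```bash
python run_re2.py -env HalfCheetah-v2 \
    -pop_size 5 -pr 64 -OFF_TYPE 1 \
    -seed 1 -logdir ./logs
```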
69 | 70 | 71 | ## Publication 72 | 73 | If you find this repository useful or use our code, please cite our paper: 74 | 75 | @inproceedings{ 76 | li2023erlre, 77 | title={{ERL}-Re\${\textasciicircum}2\$: Efficient Evolutionary Reinforcement Learning with Shared State Representation and Individual Policy Representation }, 78 | author={Pengyi Li and Hongyao Tang and Jianye HAO and YAN ZHENG and Xian Fu and Zhaopeng Meng}, 79 | booktitle={International Conference on Learning Representations}, 80 | year={2023} 81 | } 82 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeshenpy/ERL-Re2/1ca10cf8909396084429fb73734c40d74859ffa5/core/__init__.py -------------------------------------------------------------------------------- /core/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeshenpy/ERL-Re2/1ca10cf8909396084429fb73734c40d74859ffa5/core/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /core/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeshenpy/ERL-Re2/1ca10cf8909396084429fb73734c40d74859ffa5/core/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/agent.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeshenpy/ERL-Re2/1ca10cf8909396084429fb73734c40d74859ffa5/core/__pycache__/agent.cpython-36.pyc -------------------------------------------------------------------------------- /core/__pycache__/ddpg.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeshenpy/ERL-Re2/1ca10cf8909396084429fb73734c40d74859ffa5/core/__pycache__/ddpg.cpython-36.pyc -------------------------------------------------------------------------------- /core/__pycache__/ddpg.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeshenpy/ERL-Re2/1ca10cf8909396084429fb73734c40d74859ffa5/core/__pycache__/ddpg.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/mod_neuro_evo.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeshenpy/ERL-Re2/1ca10cf8909396084429fb73734c40d74859ffa5/core/__pycache__/mod_neuro_evo.cpython-36.pyc -------------------------------------------------------------------------------- /core/__pycache__/mod_neuro_evo.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeshenpy/ERL-Re2/1ca10cf8909396084429fb73734c40d74859ffa5/core/__pycache__/mod_neuro_evo.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/mod_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeshenpy/ERL-Re2/1ca10cf8909396084429fb73734c40d74859ffa5/core/__pycache__/mod_utils.cpython-36.pyc 
-------------------------------------------------------------------------------- /core/__pycache__/mod_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeshenpy/ERL-Re2/1ca10cf8909396084429fb73734c40d74859ffa5/core/__pycache__/mod_utils.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/operator_runner.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeshenpy/ERL-Re2/1ca10cf8909396084429fb73734c40d74859ffa5/core/__pycache__/operator_runner.cpython-36.pyc -------------------------------------------------------------------------------- /core/__pycache__/operator_runner.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeshenpy/ERL-Re2/1ca10cf8909396084429fb73734c40d74859ffa5/core/__pycache__/operator_runner.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/replay_memory.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeshenpy/ERL-Re2/1ca10cf8909396084429fb73734c40d74859ffa5/core/__pycache__/replay_memory.cpython-36.pyc -------------------------------------------------------------------------------- /core/__pycache__/replay_memory.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeshenpy/ERL-Re2/1ca10cf8909396084429fb73734c40d74859ffa5/core/__pycache__/replay_memory.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeshenpy/ERL-Re2/1ca10cf8909396084429fb73734c40d74859ffa5/core/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /core/agent.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from core import mod_neuro_evo as utils_ne 3 | from core import mod_utils as utils 4 | from core import replay_memory 5 | from core import ddpg as ddpg 6 | from scipy.spatial import distance 7 | from core import replay_memory 8 | from parameters import Parameters 9 | import torch 10 | from core import utils 11 | import scipy.signal 12 | import torch.nn as nn 13 | import math 14 | import random 15 | 16 | def discount(x, gamma): 17 | """ Calculate discounted forward sum of a sequence at each point """ 18 | return scipy.signal.lfilter([1.0], [1.0, -gamma], x[::-1])[::-1] 19 | 20 | class Agent: 21 | def __init__(self, args: Parameters, env): 22 | self.args = args; self.env = env 23 | 24 | # Init population 25 | self.pop = [] 26 | self.buffers = [] 27 | self.all_actors = [] 28 | for _ in range(args.pop_size): 29 | #self.pop.append(ddpg.GeneticAgent(args)) 30 | genetic = ddpg.GeneticAgent(args) 31 | self.pop.append(genetic) 32 | self.all_actors.append(genetic.actor) 33 | 34 | # Init RL Agent 35 | 36 | 37 | self.rl_agent = ddpg.TD3(args) 38 | self.replay_buffer = utils.ReplayBuffer() 39 | 40 | self.all_actors.append(self.rl_agent.actor) 41 | 42 | self.ounoise = ddpg.OUNoise(args.action_dim) 43 | self.evolver = utils_ne.SSNE(self.args, self.rl_agent.critic, 
self.evaluate,self.rl_agent.state_embedding, self.args.prob_reset_and_sup, self.args.frac) 44 | 45 | # Population novelty 46 | self.ns_r = 1.0 47 | self.ns_delta = 0.1 48 | self.best_train_reward = 0.0 49 | self.time_since_improv = 0 50 | self.step = 1 51 | self.use_real = 0 52 | self.total_use = 0 53 | # Trackers 54 | self.num_games = 0; self.num_frames = 0; self.iterations = 0; self.gen_frames = None 55 | self.rl_agent_frames = 0 56 | 57 | self.old_fitness = None 58 | self.evo_times = 0 59 | 60 | 61 | def evaluate(self, agent: ddpg.GeneticAgent or ddpg.TD3, state_embedding_net, is_render=False, is_action_noise=False, 62 | store_transition=True, net_index=None, is_random =False, rl_agent_collect_data = False, use_n_step_return = False,PeVFA=None): 63 | total_reward = 0.0 64 | total_error = 0.0 65 | policy_params = torch.nn.utils.parameters_to_vector(list(agent.actor.parameters())).data.cpu().numpy().reshape([-1]) 66 | state = self.env.reset() 67 | done = False 68 | 69 | state_list = [] 70 | reward_list = [] 71 | 72 | action_list = [] 73 | policy_params_list =[] 74 | n_step_discount_reward = 0.0 75 | episode_timesteps = 0 76 | all_state = [] 77 | all_action = [] 78 | 79 | while not done: 80 | if store_transition: 81 | self.num_frames += 1; self.gen_frames += 1 82 | if rl_agent_collect_data: 83 | self.rl_agent_frames +=1 84 | if self.args.render and is_render: self.env.render() 85 | 86 | if is_random: 87 | action = self.env.action_space.sample() 88 | else : 89 | action = agent.actor.select_action(np.array(state),state_embedding_net) 90 | if is_action_noise: 91 | 92 | action = (action + np.random.normal(0, 0.1, size=self.args.action_dim)).clip(-1.0, 1.0) 93 | all_state.append(np.array(state)) 94 | all_action.append(np.array(action)) 95 | # Simulate one step in environment 96 | next_state, reward, done, info = self.env.step(action.flatten()) 97 | done_bool = 0 if episode_timesteps + 1 == 1000 else float(done) 98 | total_reward += reward 99 | n_step_discount_reward += math.pow(self.args.gamma,episode_timesteps)*reward 100 | state_list.append(state) 101 | reward_list.append(reward) 102 | policy_params_list.append(policy_params) 103 | action_list.append(action.flatten()) 104 | 105 | transition = (state, action, next_state, reward, done_bool) 106 | if store_transition: 107 | next_action = agent.actor.select_action(np.array(next_state), state_embedding_net) 108 | self.replay_buffer.add((state, next_state, action, reward, done_bool, next_action ,policy_params)) 109 | #self.replay_buffer.add(*transition) 110 | agent.buffer.add(*transition) 111 | episode_timesteps += 1 112 | state = next_state 113 | 114 | if use_n_step_return: 115 | if self.args.time_steps <= episode_timesteps: 116 | next_action = agent.actor.select_action(np.array(next_state), state_embedding_net) 117 | param = nn.utils.parameters_to_vector(list(agent.actor.parameters())).data.cpu().numpy() 118 | param = torch.FloatTensor(param).to(self.args.device) 119 | param = param.repeat(1, 1) 120 | 121 | # print("1") 122 | next_state = torch.FloatTensor(np.array([next_state])).to(self.args.device) 123 | next_action = torch.FloatTensor(np.array([next_action])).to(self.args.device) 124 | 125 | input = torch.cat([next_state, next_action], -1) 126 | # print("2") 127 | next_Q1, next_Q2 = PeVFA.forward(input, param) 128 | # print("3") 129 | next_state_Q = torch.min(next_Q1, next_Q2).cpu().data.numpy().flatten() 130 | n_step_discount_reward += math.pow(self.args.gamma,episode_timesteps) *next_state_Q[0] 131 | break 132 | if store_transition: 
self.num_games += 1 133 | 134 | return {'n_step_discount_reward':n_step_discount_reward,'reward': total_reward, 'td_error': total_error, "state_list": state_list, "reward_list":reward_list, "policy_prams_list":policy_params_list, "action_list":action_list} 135 | 136 | 137 | def rl_to_evo(self, rl_agent: ddpg.TD3, evo_net: ddpg.GeneticAgent): 138 | for target_param, param in zip(evo_net.actor.parameters(), rl_agent.actor.parameters()): 139 | target_param.data.copy_(param.data) 140 | evo_net.buffer.reset() 141 | evo_net.buffer.add_content_of(rl_agent.buffer) 142 | 143 | def evo_to_rl(self, rl_net, evo_net): 144 | for target_param, param in zip(rl_net.parameters(), evo_net.parameters()): 145 | target_param.data.copy_(param.data) 146 | 147 | def get_pop_novelty(self): 148 | epochs = self.args.ns_epochs 149 | novelties = np.zeros(len(self.pop)) 150 | for _ in range(epochs): 151 | transitions = self.replay_buffer.sample(self.args.batch_size) 152 | batch = replay_memory.Transition(*zip(*transitions)) 153 | 154 | for i, net in enumerate(self.pop): 155 | novelties[i] += (net.get_novelty(batch)) 156 | return novelties / epochs 157 | 158 | def train_ddpg(self, evo_times,all_fitness, state_list_list,reward_list_list, policy_params_list_list,action_list_list): 159 | bcs_loss, pgs_loss,c_q,t_q = [], [],[],[] 160 | 161 | 162 | if len(self.replay_buffer.storage) >= 5000:#self.args.batch_size * 5: 163 | 164 | 165 | before_rewards = np.zeros(len(self.pop)) 166 | 167 | ddpg.hard_update(self.rl_agent.old_state_embedding, self.rl_agent.state_embedding) 168 | for gen in self.pop: 169 | ddpg.hard_update(gen.old_actor, gen.actor) 170 | 171 | discount_reward_list_list =[] 172 | for reward_list in reward_list_list: 173 | discount_reward_list = discount(reward_list,0.99) 174 | discount_reward_list_list.append(discount_reward_list) 175 | state_list_list = np.concatenate(np.array(state_list_list)) 176 | #print("state_list_list ",state_list_list.shape) 177 | discount_reward_list_list = np.concatenate(np.array(discount_reward_list_list)) 178 | #print("discount_reward_list_list ",discount_reward_list_list.shape) 179 | policy_params_list_list = np.concatenate(np.array(policy_params_list_list)) 180 | #print("policy_params_list_list ", policy_params_list_list.shape) 181 | action_list_list = np.concatenate(np.array(action_list_list)) 182 | pgl, delta,pre_loss,pv_loss,keep_c_loss= self.rl_agent.train(evo_times,all_fitness, self.pop , state_list_list, policy_params_list_list, discount_reward_list_list,action_list_list, self.replay_buffer ,int(self.gen_frames * self.args.frac_frames_train), self.args.batch_size, discount=self.args.gamma, tau=self.args.tau,policy_noise=self.args.TD3_noise,train_OFN_use_multi_actor=self.args.random_choose,all_actor=self.all_actors) 183 | after_rewards = np.zeros(len(self.pop)) 184 | else: 185 | before_rewards = np.zeros(len(self.pop)) 186 | after_rewards = np.zeros(len(self.pop)) 187 | delta = 0.0 188 | pgl = 0.0 189 | pre_loss = 0.0 190 | keep_c_loss = [0.0] 191 | pv_loss = 0.0 192 | add_rewards = np.mean(after_rewards - before_rewards) 193 | return {'pv_loss':pv_loss,'bcs_loss': delta, 'pgs_loss': pgl,"current_q":0.0, "target_q":0.0, "pre_loss":pre_loss}, keep_c_loss, add_rewards 194 | 195 | def train(self): 196 | self.gen_frames = 0 197 | self.iterations += 1 198 | 199 | # ========================== EVOLUTION ========================== 200 | # Evaluate genomes/individuals 201 | real_rewards = np.zeros(len(self.pop)) 202 | fake_rewards = np.zeros(len(self.pop)) 203 | MC_n_steps_rewards = 
np.zeros(len(self.pop)) 204 | state_list_list = [] 205 | 206 | reward_list_list = [] 207 | policy_parms_list_list =[] 208 | action_list_list =[] 209 | 210 | 211 | if self.args.EA and self.rl_agent_frames>=self.args.init_steps: 212 | self.evo_times +=1 213 | random_num_num = random.random() 214 | if random_num_num< self.args.theta: 215 | for i, net in enumerate(self.pop): 216 | for _ in range(self.args.num_evals): 217 | episode = self.evaluate(net, self.rl_agent.state_embedding, is_render=False, is_action_noise=False,net_index=i) 218 | real_rewards[i] += episode['reward'] 219 | real_rewards /= self.args.num_evals 220 | all_fitness = real_rewards 221 | else : 222 | for i, net in enumerate(self.pop): 223 | episode = self.evaluate(net, self.rl_agent.state_embedding, is_render=False, is_action_noise=False,net_index=i,use_n_step_return = True,PeVFA=self.rl_agent.PVN) 224 | fake_rewards[i] += episode['n_step_discount_reward'] 225 | MC_n_steps_rewards[i] +=episode['reward'] 226 | all_fitness = fake_rewards 227 | 228 | else : 229 | all_fitness = np.zeros(len(self.pop)) 230 | 231 | self.total_use +=1.0 232 | # all_fitness = 0.8 * rankdata(rewards) + 0.2 * rankdata(errors) 233 | 234 | keep_c_loss = [0.0 / 1000] 235 | min_fintess = 0.0 236 | best_old_fitness = 0.0 237 | temp_reward =0.0 238 | 239 | # Validation test for NeuroEvolution champion 240 | best_train_fitness = np.max(all_fitness) 241 | champion = self.pop[np.argmax(all_fitness)] 242 | 243 | test_score = 0 244 | 245 | if self.args.EA and self.rl_agent_frames>=self.args.init_steps: 246 | for eval in range(10): 247 | episode = self.evaluate(champion, self.rl_agent.state_embedding, is_render=True, is_action_noise=False, store_transition=False) 248 | test_score += episode['reward'] 249 | test_score /= 10.0 250 | 251 | # NeuroEvolution's probabilistic selection and recombination step 252 | if self.args.EA: 253 | elite_index = self.evolver.epoch(self.pop, all_fitness) 254 | else : 255 | elite_index = 0 256 | # ========================== DDPG or TD3 =========================== 257 | # Collect experience for training 258 | 259 | if self.args.RL: 260 | is_random = (self.rl_agent_frames < self.args.init_steps) 261 | episode = self.evaluate(self.rl_agent, self.rl_agent.state_embedding, is_action_noise=True, is_random=is_random,rl_agent_collect_data=True) 262 | 263 | state_list_list.append(episode['state_list']) 264 | reward_list_list.append(episode['reward_list']) 265 | policy_parms_list_list.append(episode['policy_prams_list']) 266 | action_list_list.append(episode['action_list']) 267 | 268 | if self.rl_agent_frames>=self.args.init_steps: 269 | losses, _, add_rewards = self.train_ddpg(self.evo_times,all_fitness, state_list_list,reward_list_list,policy_parms_list_list,action_list_list) 270 | else : 271 | losses = {'bcs_loss': 0.0, 'pgs_loss': 0.0 ,"current_q":0.0, "target_q":0.0, "pv_loss":0.0, "pre_loss":0.0} 272 | add_rewards = np.zeros(len(self.pop)) 273 | else : 274 | losses = {'bcs_loss': 0.0, 'pgs_loss': 0.0 ,"current_q":0.0, "target_q":0.0,"pv_loss":0.0, "pre_loss":0.0} 275 | 276 | add_rewards = np.zeros(len(self.pop)) 277 | 278 | L1_before_after = np.zeros(len(self.pop)) 279 | 280 | # Validation test for RL agent 281 | testr = 0 282 | 283 | if self.args.RL: 284 | for eval in range(10): 285 | ddpg_stats = self.evaluate(self.rl_agent, self.rl_agent.state_embedding,store_transition=False, is_action_noise=False) 286 | testr += ddpg_stats['reward'] 287 | testr /= 10.0 288 | 289 | #Sync RL Agent to NE every few steps 290 | if self.args.EA and 
self.args.RL and self.rl_agent_frames>=self.args.init_steps: 291 | if self.iterations % self.args.rl_to_ea_synch_period == 0: 292 | # Replace any index different from the new elite 293 | replace_index = np.argmin(all_fitness) 294 | if replace_index == elite_index: 295 | replace_index = (replace_index + 1) % len(self.pop) 296 | 297 | self.rl_to_evo(self.rl_agent, self.pop[replace_index]) 298 | self.evolver.rl_policy = replace_index 299 | print('Sync from RL --> Nevo') 300 | 301 | 302 | self.old_fitness = all_fitness 303 | # -------------------------- Collect statistics -------------------------- 304 | 305 | 306 | return { 307 | 'min_fintess':min_fintess, 308 | 'best_old_fitness':best_old_fitness, 309 | 'new_fitness':temp_reward, 310 | 'best_train_fitness': best_train_fitness, 311 | 'test_score': test_score, 312 | 'elite_index': elite_index, 313 | 'ddpg_reward': testr, 314 | 'pvn_loss':losses['pv_loss'], 315 | 'pg_loss': np.mean(losses['pgs_loss']), 316 | 'bc_loss': np.mean(losses['bcs_loss']), 317 | 'current_q': np.mean(losses['current_q']), 318 | 'target_q':np.mean(losses['target_q']), 319 | 'pre_loss': np.mean(losses['pre_loss']), 320 | 'pop_novelty': np.mean(0), 321 | 'before_rewards':all_fitness, 322 | 'add_rewards':add_rewards, 323 | 'l1_before_after':L1_before_after, 324 | 'keep_c_loss':np.mean(keep_c_loss) 325 | } 326 | 327 | 328 | class Archive: 329 | """A record of past behaviour characterisations (BC) in the population""" 330 | 331 | def __init__(self, args): 332 | self.args = args 333 | # Past behaviours 334 | self.bcs = [] 335 | 336 | def add_bc(self, bc): 337 | if len(self.bcs) + 1 > self.args.archive_size: 338 | self.bcs = self.bcs[1:] 339 | self.bcs.append(bc) 340 | 341 | def get_novelty(self, this_bc): 342 | if self.size() == 0: 343 | return np.array(this_bc).T @ np.array(this_bc) 344 | distances = np.ravel(distance.cdist(np.expand_dims(this_bc, axis=0), np.array(self.bcs), metric='sqeuclidean')) 345 | distances = np.sort(distances) 346 | return distances[:self.args.ns_k].mean() 347 | 348 | def size(self): 349 | return len(self.bcs) -------------------------------------------------------------------------------- /core/ddpg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.optim import Adam 4 | from torch.nn import functional as F 5 | from parameters import Parameters 6 | from core import replay_memory 7 | from core.mod_utils import is_lnorm_key 8 | import numpy as np 9 | 10 | def soft_update(target, source, tau): 11 | for target_param, param in zip(target.parameters(), source.parameters()): 12 | target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau) 13 | 14 | 15 | def hard_update(target, source): 16 | for target_param, param in zip(target.parameters(), source.parameters()): 17 | target_param.data.copy_(param.data) 18 | 19 | 20 | class GeneticAgent: 21 | def __init__(self, args: Parameters): 22 | 23 | self.args = args 24 | 25 | self.actor = Actor(args) 26 | self.old_actor = Actor(args) 27 | self.temp_actor = Actor(args) 28 | self.actor_optim = Adam(self.actor.parameters(), lr=1e-4) 29 | 30 | self.buffer = replay_memory.ReplayMemory(self.args.individual_bs, args.device) 31 | self.loss = nn.MSELoss() 32 | 33 | def keep_consistency(self, z_old, z_new): 34 | target_action = self.old_actor.select_action_from_z(z_old).detach() 35 | current_action = self.actor.select_action_from_z(z_new) 36 | delta = (current_action - target_action).abs() 37 | dt = torch.mean(delta ** 2) 
38 | self.actor_optim.zero_grad() 39 | dt.backward() 40 | self.actor_optim.step() 41 | return dt.data.cpu().numpy() 42 | 43 | def keep_consistency_with_other_agent(self, z_old, z_new, other_actor): 44 | target_action = other_actor.select_action_from_z(z_old).detach() 45 | current_action = self.actor.select_action_from_z(z_new) 46 | delta = (current_action - target_action).abs() 47 | dt = torch.mean(delta ** 2) 48 | self.actor_optim.zero_grad() 49 | dt.backward() 50 | self.actor_optim.step() 51 | return dt.data.cpu().numpy() 52 | 53 | def update_parameters(self, batch, p1, p2, critic): 54 | state_batch, _, _, _, _ = batch 55 | 56 | p1_action = p1(state_batch) 57 | p2_action = p2(state_batch) 58 | p1_q = critic.Q1(state_batch, p1_action).flatten() 59 | p2_q = critic.Q1(state_batch, p2_action).flatten() 60 | 61 | eps = 0.0 62 | action_batch = torch.cat((p1_action[p1_q - p2_q > eps], p2_action[p2_q - p1_q >= eps])).detach() 63 | state_batch = torch.cat((state_batch[p1_q - p2_q > eps], state_batch[p2_q - p1_q >= eps])) 64 | actor_action = self.actor(state_batch) 65 | 66 | # Actor Update 67 | self.actor_optim.zero_grad() 68 | sq = (actor_action - action_batch)**2 69 | policy_loss = torch.sum(sq) + torch.mean(actor_action**2) 70 | policy_mse = torch.mean(sq) 71 | policy_loss.backward() 72 | self.actor_optim.step() 73 | 74 | return policy_mse.item() 75 | 76 | class shared_state_embedding(nn.Module): 77 | def __init__(self, args): 78 | super(shared_state_embedding, self).__init__() 79 | self.args = args 80 | l1 = 400 81 | l2 = args.ls 82 | l3 = l2 83 | 84 | # Construct Hidden Layer 1 85 | self.w_l1 = nn.Linear(args.state_dim, l1) 86 | if self.args.use_ln: self.lnorm1 = LayerNorm(l1) 87 | 88 | # Hidden Layer 2 89 | self.w_l2 = nn.Linear(l1, l2) 90 | if self.args.use_ln: self.lnorm2 = LayerNorm(l2) 91 | # Init 92 | self.to(self.args.device) 93 | 94 | def forward(self, state): 95 | # Hidden Layer 1 96 | out = self.w_l1(state) 97 | if self.args.use_ln: out = self.lnorm1(out) 98 | out = out.tanh() 99 | 100 | # Hidden Layer 2 101 | out = self.w_l2(out) 102 | if self.args.use_ln: out = self.lnorm2(out) 103 | out = out.tanh() 104 | 105 | return out 106 | 107 | 108 | class Actor(nn.Module): 109 | def __init__(self, args, init=False): 110 | super(Actor, self).__init__() 111 | self.args = args 112 | l1 = args.ls; l2 = args.ls; l3 = l2 113 | # Out 114 | self.w_out = nn.Linear(l3, args.action_dim) 115 | # Init 116 | if init: 117 | self.w_out.weight.data.mul_(0.1) 118 | self.w_out.bias.data.mul_(0.1) 119 | 120 | self.to(self.args.device) 121 | 122 | def forward(self, input, state_embedding): 123 | s_z = state_embedding.forward(input) 124 | action = self.w_out(s_z).tanh() 125 | return action 126 | 127 | def select_action_from_z(self,s_z): 128 | 129 | action = self.w_out(s_z).tanh() 130 | return action 131 | 132 | def select_action(self, state, state_embedding): 133 | state = torch.FloatTensor(state.reshape(1, -1)).to(self.args.device) 134 | return self.forward(state, state_embedding).cpu().data.numpy().flatten() 135 | 136 | def get_novelty(self, batch): 137 | state_batch, action_batch, _, _, _ = batch 138 | novelty = torch.mean(torch.sum((action_batch - self.forward(state_batch))**2, dim=-1)) 139 | return novelty.item() 140 | 141 | # function to return current pytorch gradient in same order as genome's flattened parameter vector 142 | def extract_grad(self): 143 | tot_size = self.count_parameters() 144 | pvec = torch.zeros(tot_size, dtype=torch.float32).to(self.args.device) 145 | count = 0 146 | for name, param 
in self.named_parameters(): 147 | if is_lnorm_key(name) or len(param.shape) != 2: 148 | continue 149 | sz = param.numel() 150 | pvec[count:count + sz] = param.grad.view(-1) 151 | count += sz 152 | return pvec.detach().clone() 153 | 154 | # function to grab current flattened neural network weights 155 | def extract_parameters(self): 156 | tot_size = self.count_parameters() 157 | pvec = torch.zeros(tot_size, dtype=torch.float32).to(self.args.device) 158 | count = 0 159 | for name, param in self.named_parameters(): 160 | if is_lnorm_key(name) or len(param.shape) != 2: 161 | continue 162 | sz = param.numel() 163 | pvec[count:count + sz] = param.view(-1) 164 | count += sz 165 | return pvec.detach().clone() 166 | 167 | # function to inject a flat vector of ANN parameters into the model's current neural network weights 168 | def inject_parameters(self, pvec): 169 | count = 0 170 | for name, param in self.named_parameters(): 171 | if is_lnorm_key(name) or len(param.shape) != 2: 172 | continue 173 | sz = param.numel() 174 | raw = pvec[count:count + sz] 175 | reshaped = raw.view(param.size()) 176 | param.data.copy_(reshaped.data) 177 | count += sz 178 | 179 | # count how many parameters are in the model 180 | def count_parameters(self): 181 | count = 0 182 | for name, param in self.named_parameters(): 183 | if is_lnorm_key(name) or len(param.shape) != 2: 184 | continue 185 | count += param.numel() 186 | return count 187 | 188 | 189 | class Critic(nn.Module): 190 | 191 | def __init__(self, args): 192 | super(Critic, self).__init__() 193 | self.args = args 194 | 195 | l1 = 400; 196 | l2 = 300; 197 | l3 = l2 198 | 199 | # Construct input interface (Hidden Layer 1) 200 | self.w_l1 = nn.Linear(args.state_dim+args.action_dim, l1) 201 | # Hidden Layer 2 202 | 203 | self.w_l2 = nn.Linear(l1, l2) 204 | if self.args.use_ln: 205 | self.lnorm1 = LayerNorm(l1) 206 | self.lnorm2 = LayerNorm(l2) 207 | 208 | # Out 209 | self.w_out = nn.Linear(l3, 1) 210 | self.w_out.weight.data.mul_(0.1) 211 | self.w_out.bias.data.mul_(0.1) 212 | 213 | self.w_l3 = nn.Linear(args.state_dim+args.action_dim, l1) 214 | # Hidden Layer 2 215 | self.w_l4 = nn.Linear(l1, l2) 216 | if self.args.use_ln: 217 | self.lnorm3 = LayerNorm(l1) 218 | self.lnorm4 = LayerNorm(l2) 219 | 220 | # Out 221 | self.w_out_2 = nn.Linear(l3, 1) 222 | self.w_out_2.weight.data.mul_(0.1) 223 | self.w_out_2.bias.data.mul_(0.1) 224 | 225 | self.to(self.args.device) 226 | 227 | def forward(self, input, action): 228 | 229 | # Hidden Layer 1 (Input Interface) 230 | concat_input = torch.cat([input,action],-1) 231 | 232 | out = self.w_l1(concat_input) 233 | if self.args.use_ln:out = self.lnorm1(out) 234 | 235 | out = F.leaky_relu(out) 236 | # Hidden Layer 2 237 | out = self.w_l2(out) 238 | if self.args.use_ln: out = self.lnorm2(out) 239 | out = F.leaky_relu(out) 240 | # Output interface 241 | out_1 = self.w_out(out) 242 | 243 | out_2 = self.w_l3(concat_input) 244 | if self.args.use_ln: out_2 = self.lnorm3(out_2) 245 | out_2 = F.leaky_relu(out_2) 246 | 247 | # Hidden Layer 2 248 | out_2 = self.w_l4(out_2) 249 | if self.args.use_ln: out_2 = self.lnorm4(out_2) 250 | out_2 = F.leaky_relu(out_2) 251 | 252 | # Output interface 253 | out_2 = self.w_out_2(out_2) 254 | 255 | return out_1, out_2 256 | 257 | def Q1(self, input, action): 258 | 259 | concat_input = torch.cat([input, action], -1) 260 | 261 | out = self.w_l1(concat_input) 262 | if self.args.use_ln:out = self.lnorm1(out) 263 | 264 | out = F.leaky_relu(out) 265 | # Hidden Layer 2 266 | out = self.w_l2(out) 267 | if 
self.args.use_ln: out = self.lnorm2(out) 268 | out = F.leaky_relu(out) 269 | # Output interface 270 | out_1 = self.w_out(out) 271 | return out_1 272 | 273 | 274 | 275 | 276 | class Policy_Value_Network(nn.Module): 277 | 278 | def __init__(self, args): 279 | super(Policy_Value_Network, self).__init__() 280 | self.args = args 281 | 282 | self.policy_size = self.args.ls * self.args.action_dim + self.args.action_dim 283 | 284 | l1 = 400; l2 = 300; l3 = l2 285 | self.l1 = l1 286 | # Construct input interface (Hidden Layer 1) 287 | 288 | if self.args.use_ln: 289 | self.lnorm1 = LayerNorm(l1) 290 | self.lnorm2 = LayerNorm(l2) 291 | self.lnorm3 = LayerNorm(l1) 292 | self.lnorm4 = LayerNorm(l2) 293 | self.policy_w_l1 = nn.Linear(self.args.ls + 1, self.args.pr) 294 | self.policy_w_l2 = nn.Linear(self.args.pr, self.args.pr) 295 | self.policy_w_l3 = nn.Linear(self.args.pr, self.args.pr) 296 | 297 | if self.args.OFF_TYPE == 1 : 298 | input_dim = self.args.state_dim + self.args.action_dim 299 | else: 300 | input_dim = self.args.ls 301 | 302 | self.w_l1 = nn.Linear(input_dim + self.args.pr, l1) 303 | # Hidden Layer 2 304 | 305 | self.w_l2 = nn.Linear(l1, l2) 306 | 307 | 308 | # Out 309 | self.w_out = nn.Linear(l3, 1) 310 | self.w_out.weight.data.mul_(0.1) 311 | self.w_out.bias.data.mul_(0.1) 312 | 313 | self.policy_w_l4 = nn.Linear(self.args.ls + 1, self.args.pr) 314 | self.policy_w_l5 = nn.Linear(self.args.pr, self.args.pr) 315 | self.policy_w_l6 = nn.Linear(self.args.pr, self.args.pr) 316 | 317 | self.w_l3 = nn.Linear(input_dim + self.args.pr, l1) 318 | # Hidden Layer 2 319 | 320 | self.w_l4 = nn.Linear(l1, l2) 321 | 322 | # Out 323 | self.w_out_2 = nn.Linear(l3, 1) 324 | self.w_out_2.weight.data.mul_(0.1) 325 | self.w_out_2.bias.data.mul_(0.1) 326 | 327 | self.to(self.args.device) 328 | 329 | def forward(self, input,param): 330 | reshape_param = param.reshape([-1,self.args.ls + 1]) 331 | 332 | out_p = F.leaky_relu(self.policy_w_l1(reshape_param)) 333 | out_p = F.leaky_relu(self.policy_w_l2(out_p)) 334 | out_p = self.policy_w_l3(out_p) 335 | out_p = out_p.reshape([-1,self.args.action_dim,self.args.pr]) 336 | out_p = torch.mean(out_p,dim=1) 337 | 338 | # Hidden Layer 1 (Input Interface) 339 | concat_input = torch.cat((input,out_p), 1) 340 | 341 | # Hidden Layer 2 342 | out = self.w_l1(concat_input) 343 | if self.args.use_ln: out = self.lnorm1(out) 344 | out = F.leaky_relu(out) 345 | out = self.w_l2(out) 346 | if self.args.use_ln: out = self.lnorm2(out) 347 | out = F.leaky_relu(out) 348 | 349 | # Output interface 350 | out_1 = self.w_out(out) 351 | 352 | out_p = F.leaky_relu(self.policy_w_l4(reshape_param)) 353 | out_p = F.leaky_relu(self.policy_w_l5(out_p)) 354 | out_p = self.policy_w_l6(out_p) 355 | out_p = out_p.reshape([-1, self.args.action_dim, self.args.pr]) 356 | out_p = torch.mean(out_p, dim=1) 357 | 358 | # Hidden Layer 1 (Input Interface) 359 | concat_input = torch.cat((input, out_p), 1) 360 | 361 | # Hidden Layer 2 362 | out = self.w_l3(concat_input) 363 | if self.args.use_ln: out = self.lnorm3(out) 364 | out = F.leaky_relu(out) 365 | 366 | out = self.w_l4(out) 367 | if self.args.use_ln: out = self.lnorm4(out) 368 | out = F.leaky_relu(out) 369 | 370 | # Output interface 371 | out_2 = self.w_out_2(out) 372 | 373 | 374 | return out_1, out_2 375 | 376 | def Q1(self, input, param): 377 | reshape_param = param.reshape([-1, self.args.ls + 1]) 378 | 379 | out_p = F.leaky_relu(self.policy_w_l1(reshape_param)) 380 | out_p = F.leaky_relu(self.policy_w_l2(out_p)) 381 | out_p = self.policy_w_l3(out_p) 382 
| out_p = out_p.reshape([-1, self.args.action_dim, self.args.pr]) 383 | out_p = torch.mean(out_p, dim=1) 384 | 385 | # Hidden Layer 1 (Input Interface) 386 | 387 | # out_state = F.elu(self.w_state_l1(input)) 388 | # out_action = F.elu(self.w_action_l1(action)) 389 | concat_input = torch.cat((input, out_p), 1) 390 | 391 | # Hidden Layer 2 392 | out = self.w_l1(concat_input) 393 | if self.args.use_ln: out = self.lnorm1(out) 394 | out = F.leaky_relu(out) 395 | out = self.w_l2(out) 396 | if self.args.use_ln: out = self.lnorm2(out) 397 | out = F.leaky_relu(out) 398 | 399 | # Output interface 400 | out_1 = self.w_out(out) 401 | return out_1 402 | 403 | import random 404 | 405 | def caculate_prob(score): 406 | 407 | X = (score - np.min(score))/(np.max(score)-np.min(score) + 1e-8) 408 | max_X = np.max(X) 409 | 410 | exp_x = np.exp(X-max_X) 411 | sum_exp_x = np.sum(exp_x) 412 | prob = exp_x/sum_exp_x 413 | return prob 414 | 415 | class TD3(object): 416 | def __init__(self, args): 417 | self.args = args 418 | self.max_action = 1.0 419 | self.device = args.device 420 | self.actor = Actor(args, init=True) 421 | self.actor_target = Actor(args, init=True) 422 | self.actor_target.load_state_dict(self.actor.state_dict()) 423 | 424 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=1e-3) 425 | 426 | self.critic = Critic(args).to(self.device) 427 | self.critic_target = Critic(args).to(self.device) 428 | self.critic_target.load_state_dict(self.critic.state_dict()) 429 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),lr=1e-3) 430 | 431 | self.buffer = replay_memory.ReplayMemory(args.individual_bs, args.device) 432 | 433 | 434 | self.PVN = Policy_Value_Network(args).to(self.device) 435 | self.PVN_Target = Policy_Value_Network(args).to(self.device) 436 | self.PVN_Target.load_state_dict(self.PVN.state_dict()) 437 | self.PVN_optimizer = torch.optim.Adam([{'params': self.PVN.parameters()}],lr=1e-3) 438 | 439 | self.state_embedding = shared_state_embedding(args) 440 | self.state_embedding_target = shared_state_embedding(args) 441 | self.state_embedding_target.load_state_dict(self.state_embedding.state_dict()) 442 | 443 | 444 | self.old_state_embedding = shared_state_embedding(args) 445 | self.state_embedding_optimizer = torch.optim.Adam(self.state_embedding.parameters(), lr=1e-3) 446 | 447 | def select_action(self, state): 448 | state = torch.FloatTensor(state.reshape(1, -1)).to(self.device) 449 | return self.actor(state).cpu().data.numpy().flatten() 450 | 451 | def train(self,evo_times,all_fitness, all_gen , on_policy_states, on_policy_params, on_policy_discount_rewards,on_policy_actions,replay_buffer, iterations, batch_size=64, discount=0.99, tau=0.005, policy_noise=0.2, 452 | noise_clip=0.5, policy_freq=2, train_OFN_use_multi_actor= False,all_actor = None): 453 | actor_loss_list =[] 454 | critic_loss_list =[] 455 | pre_loss_list = [] 456 | pv_loss_list = [0.0] 457 | keep_c_loss = [0.0] 458 | 459 | for it in range(iterations): 460 | 461 | x, y, u, r, d, _ ,_= replay_buffer.sample(batch_size) 462 | state = torch.FloatTensor(x).to(self.device) 463 | action = torch.FloatTensor(u).to(self.device) 464 | next_state = torch.FloatTensor(y).to(self.device) 465 | done = torch.FloatTensor(1 - d).to(self.device) 466 | reward = torch.FloatTensor(r).to(self.device) 467 | 468 | if self.args.EA: 469 | if self.args.use_all: 470 | use_actors = all_actor 471 | else : 472 | index = random.sample(list(range(self.args.pop_size+1)), 1)[0] 473 | use_actors = [all_actor[index]] 474 | 475 | # off 
policy update 476 | pv_loss = 0.0 477 | for actor in use_actors: 478 | param = nn.utils.parameters_to_vector(list(actor.parameters())).data.cpu().numpy() 479 | param = torch.FloatTensor(param).to(self.device) 480 | param = param.repeat(len(state), 1) 481 | 482 | with torch.no_grad(): 483 | if self.args.OFF_TYPE == 1: 484 | input = torch.cat([next_state,actor.forward(next_state,self.state_embedding)],-1) 485 | else : 486 | input = self.state_embedding.forward(next_state) 487 | next_Q1, next_Q2 = self.PVN_Target.forward(input ,param) 488 | next_target_Q = torch.min(next_Q1,next_Q2) 489 | target_Q = reward + (done * discount * next_target_Q).detach() 490 | 491 | if self.args.OFF_TYPE == 1: 492 | input = torch.cat([state,action], -1) 493 | else: 494 | input = self.state_embedding.forward(state) 495 | 496 | current_Q1, current_Q2 = self.PVN.forward(input, param) 497 | pv_loss += F.mse_loss(current_Q1, target_Q)+ F.mse_loss(current_Q2, target_Q) 498 | 499 | self.PVN_optimizer.zero_grad() 500 | pv_loss.backward() 501 | nn.utils.clip_grad_norm_(self.PVN.parameters(), 10) 502 | self.PVN_optimizer.step() 503 | pv_loss_list.append(pv_loss.cpu().data.numpy().flatten()) 504 | else : 505 | pv_loss_list.append(0.0) 506 | 507 | # Select action according to policy and add clipped noise 508 | noise = torch.FloatTensor(u).data.normal_(0, policy_noise).to(self.device) 509 | noise = noise.clamp(-noise_clip, noise_clip) 510 | 511 | next_action = (self.actor_target.forward(next_state,self.state_embedding_target)+noise).clamp(-self.max_action, self.max_action) 512 | 513 | # Compute the target Q value 514 | target_Q1, target_Q2 = self.critic_target(next_state, next_action) 515 | target_Q = torch.min(target_Q1, target_Q2) 516 | target_Q = reward + (done * discount * target_Q).detach() 517 | 518 | # Get current Q estimates 519 | current_Q1, current_Q2 = self.critic(state, action) 520 | 521 | # Compute critic loss 522 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) 523 | 524 | # Optimize the critic 525 | self.critic_optimizer.zero_grad() 526 | critic_loss.backward() 527 | nn.utils.clip_grad_norm_(self.critic.parameters(), 10) 528 | self.critic_optimizer.step() 529 | critic_loss_list.append(critic_loss.cpu().data.numpy().flatten()) 530 | 531 | # Delayed policy updates 532 | if it % policy_freq == 0: 533 | 534 | # Compute actor loss 535 | s_z= self.state_embedding.forward(state) 536 | actor_loss = -self.critic.Q1(state, self.actor.select_action_from_z(s_z)).mean() 537 | # Optimize the actor 538 | self.actor_optimizer.zero_grad() 539 | actor_loss.backward(retain_graph=True) 540 | nn.utils.clip_grad_norm_(self.actor.parameters(), 10) 541 | self.actor_optimizer.step() 542 | 543 | if self.args.EA: 544 | index = random.sample(list(range(self.args.pop_size+1)), self.args.K) 545 | new_actor_loss = 0.0 546 | 547 | if evo_times > 0 : 548 | for ind in index : 549 | actor = all_actor[ind] 550 | param = nn.utils.parameters_to_vector(list(actor.parameters())).data.cpu().numpy() 551 | param = torch.FloatTensor(param).to(self.device) 552 | param = param.repeat(len(state), 1) 553 | if self.args.OFF_TYPE == 1: 554 | input = torch.cat([state,actor.forward(state,self.state_embedding)], -1) 555 | else: 556 | input = self.state_embedding.forward(state) 557 | 558 | new_actor_loss += -self.PVN.Q1(input,param).mean() 559 | 560 | 561 | total_loss = self.args.actor_alpha * actor_loss + self.args.EA_actor_alpha* new_actor_loss 562 | else : 563 | total_loss = self.args.actor_alpha * actor_loss 564 | 565 | 
self.state_embedding_optimizer.zero_grad() 566 | total_loss.backward() 567 | nn.utils.clip_grad_norm_(self.state_embedding.parameters(), 10) 568 | self.state_embedding_optimizer.step() 569 | # Update the frozen target models 570 | 571 | for param, target_param in zip(self.state_embedding.parameters(), self.state_embedding_target.parameters()): 572 | target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) 573 | 574 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 575 | target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) 576 | 577 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 578 | target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) 579 | 580 | for param, target_param in zip(self.PVN.parameters(), self.PVN_Target.parameters()): 581 | target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) 582 | 583 | actor_loss_list.append(actor_loss.cpu().data.numpy().flatten()) 584 | pre_loss_list.append(0.0) 585 | 586 | return np.mean(actor_loss_list) , np.mean(critic_loss_list), np.mean(pre_loss_list),np.mean(pv_loss_list), np.mean(keep_c_loss) 587 | 588 | 589 | 590 | def fanin_init(size, fanin=None): 591 | v = 0.008 592 | return torch.Tensor(size).uniform_(-v, v) 593 | 594 | def actfn_none(inp): return inp 595 | 596 | class LayerNorm(nn.Module): 597 | 598 | def __init__(self, features, eps=1e-6): 599 | super().__init__() 600 | self.gamma = nn.Parameter(torch.ones(features)) 601 | self.beta = nn.Parameter(torch.zeros(features)) 602 | self.eps = eps 603 | 604 | def forward(self, x): 605 | mean = x.mean(-1, keepdim=True) 606 | std = x.std(-1, keepdim=True) 607 | return self.gamma * (x - mean) / (std + self.eps) + self.beta 608 | 609 | class OUNoise: 610 | 611 | def __init__(self, action_dimension, scale=0.3, mu=0, theta=0.15, sigma=0.2): 612 | self.action_dimension = action_dimension 613 | self.scale = scale 614 | self.mu = mu 615 | self.theta = theta 616 | self.sigma = sigma 617 | self.state = np.ones(self.action_dimension) * self.mu 618 | self.reset() 619 | 620 | def reset(self): 621 | self.state = np.ones(self.action_dimension) * self.mu 622 | 623 | def noise(self): 624 | x = self.state 625 | dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(len(x)) 626 | self.state = x + dx 627 | return self.state * self.scale 628 | -------------------------------------------------------------------------------- /core/mod_neuro_evo.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from core.ddpg import GeneticAgent, hard_update 4 | from typing import List 5 | from core import replay_memory 6 | import fastrand, math 7 | import torch 8 | import torch.distributions as dist 9 | from core.mod_utils import is_lnorm_key 10 | from parameters import Parameters 11 | import os 12 | 13 | 14 | class SSNE: 15 | def __init__(self, args: Parameters, critic, evaluate, state_embedding, prob_reset_and_sup, frac): 16 | self.state_embedding = state_embedding 17 | self.current_gen = 0 18 | self.args = args; 19 | self.critic = critic 20 | self.prob_reset_and_sup = prob_reset_and_sup 21 | 22 | self.frac = frac 23 | self.population_size = self.args.pop_size 24 | self.num_elitists = int(self.args.elite_fraction * args.pop_size) 25 | self.evaluate = evaluate 26 | self.stats = PopulationStats(self.args) 27 | if self.num_elitists < 1: self.num_elitists = 1 28 | 29 | self.rl_policy = None 30 
| self.selection_stats = {'elite': 0, 'selected': 0, 'discarded':0, 'total':0.0000001} 31 | 32 | def selection_tournament(self, index_rank, num_offsprings, tournament_size): 33 | total_choices = len(index_rank) 34 | offsprings = [] 35 | for i in range(num_offsprings): 36 | winner = np.min(np.random.randint(total_choices, size=tournament_size)) 37 | offsprings.append(index_rank[winner]) 38 | 39 | offsprings = list(set(offsprings)) # Find unique offsprings 40 | if len(offsprings) % 2 != 0: # Number of offsprings should be even 41 | offsprings.append(offsprings[fastrand.pcg32bounded(len(offsprings))]) 42 | return offsprings 43 | 44 | def list_argsort(self, seq): 45 | return sorted(range(len(seq)), key=seq.__getitem__) 46 | 47 | def regularize_weight(self, weight, mag): 48 | if weight > mag: weight = mag 49 | if weight < -mag: weight = -mag 50 | return weight 51 | 52 | def crossover_inplace(self, gene1: GeneticAgent, gene2: GeneticAgent): 53 | # Evaluate the parents 54 | trials = 5 55 | if self.args.opstat and self.stats.should_log(): 56 | test_score_p1 = 0 57 | for eval in range(trials): 58 | episode = self.evaluate(gene1, is_render=False, is_action_noise=False, store_transition=False) 59 | test_score_p1 += episode['reward'] 60 | test_score_p1 /= trials 61 | 62 | test_score_p2 = 0 63 | for eval in range(trials): 64 | episode = self.evaluate(gene2, is_render=False, is_action_noise=False, store_transition=False) 65 | test_score_p2 += episode['reward'] 66 | test_score_p2 /= trials 67 | 68 | b_1 = None 69 | b_2 = None 70 | for param1, param2 in zip(gene1.actor.parameters(), gene2.actor.parameters()): 71 | # References to the variable tensors 72 | W1 = param1.data 73 | W2 = param2.data 74 | if len(W1.shape) == 1: 75 | b_1 = W1 76 | b_2 = W2 77 | 78 | 79 | 80 | for param1, param2 in zip(gene1.actor.parameters(), gene2.actor.parameters()): 81 | # References to the variable tensors 82 | W1 = param1.data 83 | W2 = param2.data 84 | 85 | if len(W1.shape) == 2: #Weights no bias 86 | num_variables = W1.shape[0] 87 | # Crossover opertation [Indexed by row] 88 | num_cross_overs = fastrand.pcg32bounded(num_variables * 2) # Lower bounded on full swaps 89 | for i in range(num_cross_overs): 90 | receiver_choice = random.random() # Choose which gene to receive the perturbation 91 | if receiver_choice < 0.5: 92 | ind_cr = fastrand.pcg32bounded(W1.shape[0]) # 93 | W1[ind_cr,:] = W2[ind_cr,:] 94 | b_1[ind_cr] = b_2[ind_cr] 95 | else: 96 | ind_cr = fastrand.pcg32bounded(W1.shape[0]) # 97 | W2[ind_cr,:] = W1[ind_cr,:] 98 | b_2[ind_cr] = b_1[ind_cr] 99 | 100 | # Evaluate the children 101 | if self.args.opstat and self.stats.should_log(): 102 | test_score_c1 = 0 103 | for eval in range(trials): 104 | episode = self.evaluate(gene1, is_render=False, is_action_noise=False, store_transition=False) 105 | test_score_c1 += episode['reward'] 106 | test_score_c1 /= trials 107 | 108 | test_score_c2 = 0 109 | for eval in range(trials): 110 | episode = self.evaluate(gene1, is_render=False, is_action_noise=False, store_transition=False) 111 | test_score_c2 += episode['reward'] 112 | test_score_c2 /= trials 113 | 114 | if self.args.verbose_crossover: 115 | print("==================== Classic Crossover ======================") 116 | print("Parent 1", test_score_p1) 117 | print("Parent 2", test_score_p2) 118 | print("Child 1", test_score_c1) 119 | print("Child 2", test_score_c2) 120 | 121 | self.stats.add({ 122 | 'cros_parent1_fit': test_score_p1, 123 | 'cros_parent2_fit': test_score_p2, 124 | 'cros_child_fit': np.mean([test_score_c1, 
test_score_c2]), 125 | 'cros_child1_fit': test_score_c1, 126 | 'cros_child2_fit': test_score_c2, 127 | }) 128 | 129 | def distilation_crossover(self, gene1: GeneticAgent, gene2: GeneticAgent): 130 | new_agent = GeneticAgent(self.args) 131 | new_agent.buffer.add_latest_from(gene1.buffer, self.args.individual_bs // 2) 132 | new_agent.buffer.add_latest_from(gene2.buffer, self.args.individual_bs // 2) 133 | new_agent.buffer.shuffle() 134 | 135 | hard_update(new_agent.actor, gene2.actor) 136 | batch_size = min(128, len(new_agent.buffer)) 137 | iters = len(new_agent.buffer) // batch_size 138 | losses = [] 139 | for epoch in range(12): 140 | for i in range(iters): 141 | batch = new_agent.buffer.sample(batch_size) 142 | losses.append(new_agent.update_parameters(batch, gene1.actor, gene2.actor, self.critic)) 143 | 144 | if self.args.opstat and self.stats.should_log(): 145 | 146 | test_score_p1 = 0 147 | trials = 5 148 | for eval in range(trials): 149 | episode = self.evaluate(gene1, is_render=False, is_action_noise=False, store_transition=False) 150 | test_score_p1 += episode['reward'] 151 | test_score_p1 /= trials 152 | 153 | test_score_p2 = 0 154 | for eval in range(trials): 155 | episode = self.evaluate(gene2, is_render=False, is_action_noise=False, store_transition=False) 156 | test_score_p2 += episode['reward'] 157 | test_score_p2 /= trials 158 | 159 | test_score_c = 0 160 | for eval in range(trials): 161 | episode = self.evaluate(new_agent, is_render=False, is_action_noise=False, store_transition=False) 162 | test_score_c += episode['reward'] 163 | test_score_c /= trials 164 | 165 | if self.args.verbose_crossover: 166 | print("==================== Distillation Crossover ======================") 167 | print("MSE Loss:", np.mean(losses[-40:])) 168 | print("Parent 1", test_score_p1) 169 | print("Parent 2", test_score_p2) 170 | print("Crossover performance: ", test_score_c) 171 | 172 | self.stats.add({ 173 | 'cros_parent1_fit': test_score_p1, 174 | 'cros_parent2_fit': test_score_p2, 175 | 'cros_child_fit': test_score_c, 176 | }) 177 | 178 | return new_agent 179 | 180 | def mutate_inplace(self, gene: GeneticAgent): 181 | trials = 5 182 | if self.stats.should_log(): 183 | test_score_p = 0 184 | for eval in range(trials): 185 | episode = self.evaluate(gene, is_render=False, is_action_noise=False, store_transition=False) 186 | test_score_p += episode['reward'] 187 | test_score_p /= trials 188 | 189 | mut_strength = 0.1 190 | num_mutation_frac = 0.1 191 | super_mut_strength = 10 192 | super_mut_prob = self.prob_reset_and_sup 193 | reset_prob = super_mut_prob + self.prob_reset_and_sup 194 | 195 | num_params = len(list(gene.actor.parameters())) 196 | ssne_probabilities = np.random.uniform(0, 1, num_params) * 2 197 | model_params = gene.actor.state_dict() 198 | 199 | for i, key in enumerate(model_params): #Mutate each param 200 | 201 | if is_lnorm_key(key): 202 | continue 203 | 204 | # References to the variable keys 205 | W = model_params[key] 206 | if len(W.shape) == 2: #Weights, no bias 207 | 208 | ssne_prob = ssne_probabilities[i] 209 | 210 | if random.random() < ssne_prob: 211 | num_variables = W.shape[0] 212 | # Crossover opertation [Indexed by row] 213 | for index in range(num_variables): 214 | #ind_dim1 = fastrand.pcg32bounded(W.shape[0]) 215 | #ind_dim2 = fastrand.pcg32bounded(W.shape[-1]) 216 | 217 | 218 | random_num_num = random.random() 219 | if random_num_num < 1.0 : 220 | #print(W) 221 | index_list = random.sample(range(W.shape[1]), int(W.shape[1]*self.frac) ) 222 | random_num = 
random.random() 223 | if random_num < super_mut_prob: # Super Mutation probability 224 | for ind in index_list: 225 | W[index, ind] += random.gauss(0, super_mut_strength * W[index, ind]) 226 | elif random_num < reset_prob: # Reset probability 227 | for ind in index_list: 228 | W[index, ind] = random.gauss(0, 1) 229 | else: # mutation even normal 230 | for ind in index_list: 231 | W[index, ind] += random.gauss(0, mut_strength *W[index, ind]) 232 | 233 | # Regularization hard limit 234 | W[index, :] = np.clip(W[index, :], a_min=-1000000, a_max=1000000) 235 | 236 | if self.stats.should_log(): 237 | test_score_c = 0 238 | for eval in range(trials): 239 | episode = self.evaluate(gene, is_render=False, is_action_noise=False, store_transition=False) 240 | test_score_c += episode['reward'] 241 | test_score_c /= trials 242 | 243 | self.stats.add({ 244 | 'mut_parent_fit': test_score_p, 245 | 'mut_child_fit': test_score_c, 246 | }) 247 | 248 | if self.args.verbose_crossover: 249 | print("==================== Mutation ======================") 250 | print("Fitness before: ", test_score_p) 251 | print("Fitness after: ", test_score_c) 252 | 253 | def proximal_mutate(self, gene: GeneticAgent, mag): 254 | # Based on code from https://github.com/uber-research/safemutations 255 | trials = 5 256 | if self.stats.should_log(): 257 | test_score_p = 0 258 | for eval in range(trials): 259 | episode = self.evaluate(gene, is_render=False, is_action_noise=False, store_transition=False) 260 | test_score_p += episode['reward'] 261 | test_score_p /= trials 262 | 263 | model = gene.actor 264 | 265 | batch = gene.buffer.sample(min(self.args.mutation_batch_size, len(gene.buffer))) 266 | state, _, _, _, _ = batch 267 | output = model(state, self.state_embedding) 268 | 269 | params = model.extract_parameters() 270 | tot_size = model.count_parameters() 271 | num_outputs = output.size()[1] 272 | 273 | if self.args.mutation_noise: 274 | mag_dist = dist.Normal(self.args.mutation_mag, 0.02) 275 | mag = mag_dist.sample() 276 | 277 | # initial perturbation 278 | normal = dist.Normal(torch.zeros_like(params), torch.ones_like(params) * mag) 279 | delta = normal.sample() 280 | # uniform = delta.clone().detach().data.uniform_(0, 1) 281 | # delta[uniform > 0.1] = 0.0 282 | 283 | # we want to calculate a jacobian of derivatives of each output's sensitivity to each parameter 284 | jacobian = torch.zeros(num_outputs, tot_size).to(self.args.device) 285 | grad_output = torch.zeros(output.size()).to(self.args.device) 286 | 287 | # do a backward pass for each output 288 | for i in range(num_outputs): 289 | model.zero_grad() 290 | grad_output.zero_() 291 | grad_output[:, i] = 1.0 292 | 293 | output.backward(grad_output, retain_graph=True) 294 | jacobian[i] = model.extract_grad() 295 | 296 | # summed gradients sensitivity 297 | scaling = torch.sqrt((jacobian**2).sum(0)) 298 | scaling[scaling == 0] = 1.0 299 | scaling[scaling < 0.01] = 0.01 300 | delta /= scaling 301 | new_params = params + delta 302 | 303 | model.inject_parameters(new_params) 304 | 305 | if self.stats.should_log(): 306 | test_score_c = 0 307 | for eval in range(trials): 308 | episode = self.evaluate(gene, is_render=False, is_action_noise=False, store_transition=False) 309 | test_score_c += episode['reward'] 310 | test_score_c /= trials 311 | 312 | self.stats.add({ 313 | 'mut_parent_fit': test_score_p, 314 | 'mut_child_fit': test_score_c, 315 | }) 316 | 317 | if self.args.verbose_crossover: 318 | print("==================== Mutation ======================") 319 | print("Fitness 
before: ", test_score_p) 320 | print("Fitness after: ", test_score_c) 321 | print("Mean mutation change:", torch.mean(torch.abs(new_params - params)).item()) 322 | 323 | def clone(self, master: GeneticAgent, replacee: GeneticAgent): # Replace the replacee individual with master 324 | for target_param, source_param in zip(replacee.actor.parameters(), master.actor.parameters()): 325 | target_param.data.copy_(source_param.data) 326 | replacee.buffer.reset() 327 | replacee.buffer.add_content_of(master.buffer) 328 | 329 | def reset_genome(self, gene: GeneticAgent): 330 | for param in (gene.actor.parameters()): 331 | param.data.copy_(param.data) 332 | 333 | @staticmethod 334 | def sort_groups_by_fitness(genomes, fitness): 335 | groups = [] 336 | for i, first in enumerate(genomes): 337 | for second in genomes[i+1:]: 338 | if fitness[first] < fitness[second]: 339 | groups.append((second, first, fitness[first] + fitness[second])) 340 | else: 341 | groups.append((first, second, fitness[first] + fitness[second])) 342 | return sorted(groups, key=lambda group: group[2], reverse=True) 343 | 344 | @staticmethod 345 | def get_distance(gene1: GeneticAgent, gene2: GeneticAgent): 346 | batch_size = min(256, min(len(gene1.buffer), len(gene2.buffer))) 347 | batch_gene1 = gene1.buffer.sample_from_latest(batch_size, 1000) 348 | batch_gene2 = gene2.buffer.sample_from_latest(batch_size, 1000) 349 | 350 | return gene1.actor.get_novelty(batch_gene2) + gene2.actor.get_novelty(batch_gene1) 351 | 352 | @staticmethod 353 | def sort_groups_by_distance(genomes, pop): 354 | groups = [] 355 | for i, first in enumerate(genomes): 356 | for second in genomes[i+1:]: 357 | groups.append((second, first, SSNE.get_distance(pop[first], pop[second]))) 358 | return sorted(groups, key=lambda group: group[2], reverse=True) 359 | 360 | def epoch(self, pop: List[GeneticAgent], fitness_evals): 361 | # Entire epoch is handled with indices; Index rank nets by fitness evaluation (0 is the best after reversing) 362 | index_rank = np.argsort(fitness_evals)[::-1] 363 | elitist_index = index_rank[:self.num_elitists] # Elitist indexes safeguard 364 | 365 | # Selection step 366 | offsprings = self.selection_tournament(index_rank, num_offsprings=len(index_rank) - self.num_elitists, 367 | tournament_size=3) 368 | 369 | # Figure out unselected candidates 370 | unselects = []; new_elitists = [] 371 | for i in range(self.population_size): 372 | if i not in offsprings and i not in elitist_index: 373 | unselects.append(i) 374 | random.shuffle(unselects) 375 | 376 | # COMPUTE RL_SELECTION RATE 377 | if self.rl_policy is not None: # RL Transfer happened 378 | self.selection_stats['total'] += 1.0 379 | 380 | if self.rl_policy in elitist_index: self.selection_stats['elite'] += 1.0 381 | elif self.rl_policy in offsprings: self.selection_stats['selected'] += 1.0 382 | elif self.rl_policy in unselects: self.selection_stats['discarded'] += 1.0 383 | self.rl_policy = None 384 | 385 | # Elitism step, assigning elite candidates to some unselects 386 | for i in elitist_index: 387 | try: replacee = unselects.pop(0) 388 | except: replacee = offsprings.pop(0) 389 | new_elitists.append(replacee) 390 | self.clone(master=pop[i], replacee=pop[replacee]) 391 | 392 | # Crossover between elite and offsprings for the unselected genes with 100 percent probability 393 | if self.args.distil: 394 | if self.args.distil_type == 'fitness': 395 | sorted_groups = SSNE.sort_groups_by_fitness(new_elitists + offsprings, fitness_evals) 396 | elif self.args.distil_type == 'dist': 397 | 
sorted_groups = SSNE.sort_groups_by_distance(new_elitists + offsprings, pop) 398 | else: 399 | raise NotImplementedError('Unknown distilation type') 400 | for i, unselected in enumerate(unselects): 401 | first, second, _ = sorted_groups[i % len(sorted_groups)] 402 | if fitness_evals[first] < fitness_evals[second]: 403 | first, second = second, first 404 | self.clone(self.distilation_crossover(pop[first], pop[second]), pop[unselected]) 405 | else: 406 | if len(unselects) % 2 != 0: # Number of unselects left should be even 407 | unselects.append(unselects[fastrand.pcg32bounded(len(unselects))]) 408 | for i, j in zip(unselects[0::2], unselects[1::2]): 409 | off_i = random.choice(new_elitists) 410 | off_j = random.choice(offsprings) 411 | self.clone(master=pop[off_i], replacee=pop[i]) 412 | self.clone(master=pop[off_j], replacee=pop[j]) 413 | self.crossover_inplace(pop[i], pop[j]) 414 | 415 | # Crossover for selected offsprings 416 | for i in offsprings: 417 | if random.random() < self.args.crossover_prob: 418 | others = offsprings.copy() 419 | others.remove(i) 420 | off_j = random.choice(others) 421 | self.clone(self.distilation_crossover(pop[i], pop[off_j]), pop[i]) 422 | 423 | # Mutate all genes in the population except the new elitists 424 | for i in range(self.population_size): 425 | if i not in new_elitists: # Spare the new elitists 426 | if random.random() < self.args.mutation_prob: 427 | if self.args.proximal_mut: 428 | self.proximal_mutate(pop[i], mag=self.args.mutation_mag) 429 | else: 430 | self.mutate_inplace(pop[i]) 431 | 432 | if self.stats.should_log(): 433 | self.stats.log() 434 | self.stats.reset() 435 | return new_elitists[0] 436 | 437 | 438 | def unsqueeze(array, axis=1): 439 | if axis == 0: return np.reshape(array, (1, len(array))) 440 | elif axis == 1: return np.reshape(array, (len(array), 1)) 441 | 442 | 443 | class PopulationStats: 444 | def __init__(self, args: Parameters, file='population.csv'): 445 | self.data = {} 446 | self.args = args 447 | self.save_path = os.path.join(args.save_foldername, file) 448 | self.generation = 0 449 | 450 | if not os.path.exists(args.save_foldername): 451 | os.makedirs(args.save_foldername) 452 | 453 | def add(self, res): 454 | for k, v in res.items(): 455 | if k not in self.data: 456 | self.data[k] = [] 457 | self.data[k].append(v) 458 | 459 | def log(self): 460 | with open(self.save_path, 'a+') as f: 461 | if self.generation == 0: 462 | f.write('generation,') 463 | for i, k in enumerate(self.data): 464 | if i > 0: 465 | f.write(',') 466 | f.write(k) 467 | f.write('\n') 468 | 469 | f.write(str(self.generation)) 470 | f.write(',') 471 | for i, k in enumerate(self.data): 472 | if i > 0: 473 | f.write(',') 474 | f.write(str(np.mean(self.data[k]))) 475 | f.write('\n') 476 | 477 | def should_log(self): 478 | return self.generation % self.args.opstat_freq == 0 and self.args.opstat 479 | 480 | def reset(self): 481 | for k in self.data: 482 | self.data[k] = [] 483 | self.generation += 1 484 | 485 | 486 | -------------------------------------------------------------------------------- /core/mod_utils.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.autograd import Variable 3 | import random, pickle 4 | import numpy as np 5 | import torch 6 | import os 7 | import gym 8 | 9 | 10 | class Tracker: 11 | def __init__(self, parameters, vars_string, project_string): 12 | self.vars_string = vars_string; self.project_string = project_string 13 | self.foldername = parameters.save_foldername 14 | 
self.all_tracker = [[[],0.0,[]] for _ in vars_string] # [Id of var tracked][fitnesses, avg_fitness, csv_fitnesses] 15 | self.counter = 0 16 | self.conv_size = 10 17 | if not os.path.exists(self.foldername): 18 | os.makedirs(self.foldername) 19 | 20 | def update(self, updates, generation): 21 | self.counter += 1 22 | for update, var in zip(updates, self.all_tracker): 23 | if update == None: continue 24 | var[0].append(update) 25 | 26 | # Constrain size of convolution 27 | for var in self.all_tracker: 28 | if len(var[0]) > self.conv_size: var[0].pop(0) 29 | 30 | # Update new average 31 | for var in self.all_tracker: 32 | if len(var[0]) == 0: continue 33 | var[1] = sum(var[0])/float(len(var[0])) 34 | 35 | if self.counter % 4 == 0: # Save to csv file 36 | for i, var in enumerate(self.all_tracker): 37 | if len(var[0]) == 0: continue 38 | var[2].append(np.array([generation, var[1]])) 39 | filename = os.path.join(self.foldername, self.vars_string[i] + self.project_string) 40 | try: 41 | np.savetxt(filename, np.array(var[2]), fmt='%.3f', delimiter=',') 42 | except: 43 | # Common error showing up in the cluster for unknown reasons 44 | print('Failed to save progress') 45 | 46 | 47 | class Memory: # stored as ( s, a, r, s_ ) in SumTree 48 | e = 0.01 49 | a = 0.6 50 | 51 | def __init__(self, capacity): 52 | self.tree = SumTree(capacity) 53 | 54 | def _getPriority(self, error): 55 | return (error + self.e) ** self.a 56 | 57 | def add(self, error, sample): 58 | p = self._getPriority(error) 59 | self.tree.add(p, sample) 60 | 61 | def sample(self, n): 62 | batch = [] 63 | segment = self.tree.total() / n 64 | 65 | for i in range(n): 66 | a = segment * i 67 | b = segment * (i + 1) 68 | 69 | s = random.uniform(a, b) 70 | (idx, p, data) = self.tree.get(s) 71 | batch.append( (idx, data) ) 72 | 73 | return batch 74 | 75 | def update(self, idx, error): 76 | p = self._getPriority(error) 77 | self.tree.update(idx, p) 78 | 79 | 80 | class SumTree: 81 | write = 0 82 | 83 | def __init__(self, capacity): 84 | self.capacity = capacity 85 | self.tree = np.zeros( 2*capacity - 1 ) 86 | self.data = np.zeros( capacity, dtype=object ) 87 | 88 | def _propagate(self, idx, change): 89 | parent = (idx - 1) // 2 90 | 91 | self.tree[parent] += change 92 | 93 | if parent != 0: 94 | self._propagate(parent, change) 95 | 96 | def _retrieve(self, idx, s): 97 | left = 2 * idx + 1 98 | right = left + 1 99 | 100 | if left >= len(self.tree): 101 | return idx 102 | 103 | if s <= self.tree[left]: 104 | return self._retrieve(left, s) 105 | else: 106 | return self._retrieve(right, s-self.tree[left]) 107 | 108 | def total(self): 109 | return self.tree[0] 110 | 111 | def add(self, p, data): 112 | idx = self.write + self.capacity - 1 113 | 114 | self.data[self.write] = data 115 | self.update(idx, p) 116 | 117 | self.write += 1 118 | if self.write >= self.capacity: 119 | self.write = 0 120 | 121 | def update(self, idx, p): 122 | change = p - self.tree[idx] 123 | 124 | self.tree[idx] = p 125 | self._propagate(idx, change) 126 | 127 | def get(self, s): 128 | idx = self._retrieve(0, s) 129 | dataIdx = idx - self.capacity + 1 130 | 131 | return (idx, self.tree[idx], self.data[dataIdx]) 132 | 133 | 134 | class NormalizedActions(gym.ActionWrapper): 135 | 136 | def action(self, action): 137 | action = (action + 1) / 2 # [-1, 1] => [0, 1] 138 | action *= (self.action_space.high - self.action_space.low) 139 | action += self.action_space.low 140 | return action 141 | 142 | def _reverse_action(self, action): 143 | action -= self.action_space.low 144 | 
action /= (self.action_space.high - self.action_space.low) 145 | action = action * 2 - 1 146 | return action 147 | 148 | 149 | def fanin_init(size, fanin=None): 150 | fanin = fanin or size[0] 151 | #v = 1. / np.sqrt(fanin) 152 | v = 0.008 153 | return torch.Tensor(size).uniform_(-v, v) 154 | 155 | def to_numpy(var): 156 | return var.data.numpy() 157 | 158 | def to_tensor(ndarray, volatile=False, requires_grad=False): 159 | return Variable(torch.from_numpy(ndarray).float(), volatile=volatile, requires_grad=requires_grad) 160 | 161 | def pickle_obj(filename, object): 162 | handle = open(filename, "wb") 163 | pickle.dump(object, handle) 164 | 165 | def unpickle_obj(filename): 166 | with open(filename, 'rb') as f: 167 | return pickle.load(f) 168 | 169 | def odict_to_numpy(odict): 170 | l = list(odict.values()) 171 | state = l[0] 172 | for i in range(1, len(l)): 173 | if isinstance(l[i], np.ndarray): 174 | state = np.concatenate((state, l[i])) 175 | else: #Floats 176 | state = np.concatenate((state, np.array([l[i]]))) 177 | return state 178 | 179 | def min_max_normalize(x): 180 | min_x = np.min(x) 181 | max_x = np.max(x) 182 | return (x - min_x) / (max_x - min_x) 183 | 184 | def is_lnorm_key(key): 185 | return key.startswith('lnorm') 186 | 187 | 188 | 189 | 190 | -------------------------------------------------------------------------------- /core/operator_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import torch 5 | from core import ddpg 6 | from core import mod_neuro_evo 7 | 8 | 9 | class OperatorRunner: 10 | def __init__(self, args, env): 11 | self.env = env 12 | self.args = args 13 | 14 | def load_genetic_agent(self, source, model): 15 | actor_path = os.path.join(source, 'evo_net_actor_{}.pkl'.format(model)) 16 | buffer_path = os.path.join(source, 'champion_buffer_{}.pkl'.format(model)) 17 | 18 | agent = ddpg.GeneticAgent(self.args) 19 | agent.actor.load_state_dict(torch.load(actor_path)) 20 | with open(buffer_path, 'rb') as file: 21 | agent.buffer = pickle.load(file) 22 | 23 | return agent 24 | 25 | def evaluate(self, agent, trials=10): 26 | results = [] 27 | states = [] 28 | for trial in range(trials): 29 | total_reward = 0 30 | 31 | state = self.env.reset() 32 | if trial < 3: 33 | states.append(state) 34 | done = False 35 | while not done: 36 | action = agent.actor.select_action(np.array(state)) 37 | 38 | # Simulate one step in environment 39 | next_state, reward, done, info = self.env.step(action.flatten()) 40 | total_reward += reward 41 | state = next_state 42 | if trial < 3: 43 | states.append(state) 44 | 45 | results.append(total_reward) 46 | return np.mean(results), np.array(states) 47 | 48 | def test_crossover(self): 49 | source_dir = 'exp/cheetah_sm0.1_distil_save_20/models/' 50 | models = [1400, 1600, 1800, 2200] 51 | 52 | parent1 = [] 53 | parent2 = [] 54 | normal_cro = [] 55 | distil_cro = [] 56 | p1s, p2s, ncs, dcs = [], [], [], [] 57 | for i, model1 in enumerate(models): 58 | for j, model2 in enumerate(models): 59 | if j > i: 60 | print("========== Crossover between {} and {} ==============".format(model1, model2)) 61 | critic = ddpg.Critic(self.args) 62 | critic_path = os.path.join(source_dir, 'evo_net_critic_{}.pkl'.format(model2)) 63 | critic.load_state_dict(torch.load(critic_path)) 64 | 65 | agent1 = self.load_genetic_agent(source_dir, model1) 66 | agent2 = self.load_genetic_agent(source_dir, model2) 67 | 68 | p1_reward, p1_states = self.evaluate(agent1) 69 | 
p2_reward, p2_states = self.evaluate(agent2) 70 | parent1.append(p1_reward) 71 | parent2.append(p2_reward) 72 | p1s.append(p1_states) 73 | p2s.append(p2_states) 74 | 75 | ssne = mod_neuro_evo.SSNE(self.args, critic, None) 76 | child1 = ddpg.GeneticAgent(self.args) 77 | child2 = ddpg.GeneticAgent(self.args) 78 | ssne.clone(agent1, child1) 79 | ssne.clone(agent2, child2) 80 | 81 | ssne.crossover_inplace(child1, child2) 82 | 83 | c1_reward, c1_states = self.evaluate(child1) 84 | normal_cro.append(c1_reward) 85 | ncs.append(c1_states) 86 | 87 | child = ssne.distilation_crossover(agent1, agent2) 88 | c_reward, c_states = self.evaluate(child) 89 | distil_cro.append(c_reward) 90 | dcs.append(c_states) 91 | 92 | print(parent1[-1]) 93 | print(parent2[-1]) 94 | print(normal_cro[-1]) 95 | print(distil_cro[-1]) 96 | print() 97 | 98 | save_file = 'visualise/crossover' 99 | np.savez(save_file, p1=parent1, p2=parent2, nc=normal_cro, dc=distil_cro, p1s=p1s, p2s=p2s, ncs=ncs, dcs=dcs) 100 | 101 | def test_mutation(self): 102 | models = [800, 1400, 1600, 1800, 2200] 103 | source_dir = 'exp/cheetah_sm0.1_distil_save_20/models/' 104 | 105 | pr, nmr, smr = [], [], [] 106 | ps, nms, sms = [], [], [] 107 | ssne = mod_neuro_evo.SSNE(self.args, None, None) 108 | for i, model in enumerate(models): 109 | print("========== Mutation for {} ==============".format(model)) 110 | agent = self.load_genetic_agent(source_dir, model) 111 | p_reward, p_states = self.evaluate(agent) 112 | pr.append(p_reward) 113 | ps.append(p_states) 114 | 115 | nchild = ddpg.GeneticAgent(self.args) 116 | ssne.clone(agent, nchild) 117 | ssne.mutate_inplace(nchild) 118 | 119 | nm_reward, nm_states = self.evaluate(nchild) 120 | nmr.append(nm_reward) 121 | nms.append(nm_states) 122 | 123 | dchild = ddpg.GeneticAgent(self.args) 124 | ssne.clone(agent, dchild) 125 | ssne.proximal_mutate(dchild, 0.05) 126 | sm_reward, sm_states = self.evaluate(dchild) 127 | smr.append(sm_reward) 128 | sms.append(sm_states) 129 | 130 | print("Parent", pr[-1]) 131 | print("Normal", nmr[-1]) 132 | print("Safe", smr[-1]) 133 | 134 | # Ablation for safe mutation 135 | ablation_mag = [0.0, 0.01, 0.05, 0.1, 0.2] 136 | agent = self.load_genetic_agent(source_dir, 2200) 137 | ablr = [] 138 | abls = [] 139 | for mag in ablation_mag: 140 | dchild = ddpg.GeneticAgent(self.args) 141 | ssne.clone(agent, dchild) 142 | ssne.proximal_mutate(dchild, mag) 143 | 144 | sm_reward, sm_states = self.evaluate(dchild) 145 | ablr.append(sm_reward) 146 | abls.append(sm_states) 147 | 148 | save_file = 'visualise/mutation' 149 | np.savez(save_file, pr=pr, nmr=nmr, smr=smr, ps=ps, nms=nms, sms=sms, ablr=ablr, abls=abls, 150 | abl_mag=ablation_mag) 151 | 152 | def run(self): 153 | self.test_crossover() 154 | self.test_mutation() 155 | -------------------------------------------------------------------------------- /core/replay_memory.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import numpy as np 4 | from collections import namedtuple 5 | from core import mod_utils as utils 6 | 7 | # Taken and adapted from 8 | # https://github.com/pytorch/tutorials/blob/master/Reinforcement%20(Q-)Learning%20with%20PyTorch.ipynb 9 | 10 | Transition = namedtuple( 11 | 'Transition', ('state', 'action', 'next_state', 'reward', 'done')) 12 | 13 | 14 | class ReplayMemory(object): 15 | 16 | def __init__(self, capacity, device): 17 | self.device = device 18 | self.capacity = capacity 19 | self.memory = [] 20 | self.position = 0 21 | 22 | def 
add(self, *args): 23 | """Saves a transition.""" 24 | if len(self.memory) < self.capacity: 25 | self.memory.append(None) 26 | 27 | reshaped_args = [] 28 | for arg in args: 29 | reshaped_args.append(np.reshape(arg, (1, -1))) 30 | 31 | self.memory[self.position] = Transition(*reshaped_args) 32 | self.position = (self.position + 1) % self.capacity 33 | 34 | def add_content_of(self, other): 35 | """ 36 | Adds the content of another replay buffer to this replay buffer 37 | :param other: another replay buffer 38 | """ 39 | latest_trans = other.get_latest(self.capacity) 40 | for transition in latest_trans: 41 | self.add(*transition) 42 | 43 | def get_latest(self, latest): 44 | """ 45 | Returns the latest element from the other buffer with the most recent ones at the end of the returned list 46 | :param other: another replay buffer 47 | :param latest: the number of latest elements to return 48 | :return: a list with the latest elements 49 | """ 50 | if self.capacity < latest: 51 | latest_trans = self.memory[self.position:].copy() + self.memory[:self.position].copy() 52 | elif len(self.memory) < self.capacity: 53 | latest_trans = self.memory[-latest:].copy() 54 | elif self.position >= latest: 55 | latest_trans = self.memory[:self.position][-latest:].copy() 56 | else: 57 | latest_trans = self.memory[-latest+self.position:].copy() + self.memory[:self.position].copy() 58 | return latest_trans 59 | 60 | def add_latest_from(self, other, latest): 61 | """ 62 | Adds the latest samples from the other buffer to this buffer 63 | :param other: another replay buffer 64 | :param latest: the number of elements to add 65 | """ 66 | latest_trans = other.get_latest(latest) 67 | for transition in latest_trans: 68 | self.add(*transition) 69 | 70 | def shuffle(self): 71 | random.shuffle(self.memory) 72 | 73 | def sample(self, batch_size): 74 | transitions = random.sample(self.memory, batch_size) 75 | batch = Transition(*zip(*transitions)) 76 | 77 | state = torch.FloatTensor(np.concatenate(batch.state)).to(self.device) 78 | action = torch.FloatTensor(np.concatenate(batch.action)).to(self.device) 79 | next_state = torch.FloatTensor(np.concatenate(batch.next_state)).to(self.device) 80 | reward = torch.FloatTensor(np.concatenate(batch.reward)).to(self.device) 81 | done = torch.FloatTensor(np.concatenate(batch.done)).to(self.device) 82 | return state, action, next_state, reward, done 83 | 84 | def sample_from_latest(self, batch_size, latest): 85 | latest_trans = self.get_latest(latest) 86 | transitions = random.sample(latest_trans, batch_size) 87 | batch = Transition(*zip(*transitions)) 88 | 89 | state = torch.FloatTensor(np.concatenate(batch.state)).to(self.device) 90 | action = torch.FloatTensor(np.concatenate(batch.action)).to(self.device) 91 | next_state = torch.FloatTensor(np.concatenate(batch.next_state)).to(self.device) 92 | reward = torch.FloatTensor(np.concatenate(batch.reward)).to(self.device) 93 | done = torch.FloatTensor(np.concatenate(batch.done)).to(self.device) 94 | return state, action, next_state, reward, done 95 | 96 | def __len__(self): 97 | return len(self.memory) 98 | 99 | def reset(self): 100 | self.memory = [] 101 | self.position = 0 102 | 103 | 104 | class PrioritizedReplayMemory(object): 105 | def __init__(self, capacity, device, alpha=0.6, beta_start=0.4, beta_frames=100000): 106 | self.prob_alpha = alpha 107 | self.capacity = capacity 108 | self.buffer = [] 109 | self.pos = 0 110 | self.priorities = np.zeros((capacity,), dtype=np.float32) 111 | self.frame = 1 112 | self.beta_start = beta_start 
113 | self.beta_frames = beta_frames 114 | self.device = device 115 | 116 | def beta_by_frame(self, frame_idx): 117 | return min(1.0, self.beta_start + frame_idx * (1.0 - self.beta_start) / self.beta_frames) 118 | 119 | def push(self, transition: Transition): 120 | max_prio = self.priorities.max() if self.buffer else 1.0 ** self.prob_alpha 121 | 122 | if len(self.buffer) < self.capacity: 123 | self.buffer.append(transition) 124 | else: 125 | self.buffer[self.pos] = transition 126 | 127 | self.priorities[self.pos] = max_prio 128 | 129 | self.pos = (self.pos + 1) % self.capacity 130 | 131 | def sample(self, batch_size): 132 | if len(self.buffer) == self.capacity: 133 | prios = self.priorities 134 | else: 135 | prios = self.priorities[:self.pos] 136 | 137 | total = len(self.buffer) 138 | 139 | probs = prios / prios.sum() 140 | 141 | indices = np.random.choice(total, batch_size, p=probs) 142 | samples = [self.buffer[idx] for idx in indices] 143 | 144 | beta = self.beta_by_frame(self.frame) 145 | self.frame += 1 146 | 147 | # min of ALL probs, not just sampled probs 148 | prob_min = probs.min() 149 | max_weight = (prob_min * total) ** (-beta) 150 | 151 | weights = (total * probs[indices]) ** (-beta) 152 | weights /= max_weight 153 | weights = torch.tensor(weights, device=self.device, dtype=torch.float) 154 | 155 | return samples, indices, weights 156 | 157 | def update_priorities(self, batch_indices, batch_priorities): 158 | for idx, prio in zip(batch_indices, batch_priorities): 159 | self.priorities[idx] = (prio + 1e-5) ** self.prob_alpha 160 | 161 | def __len__(self): 162 | return len(self.buffer) 163 | 164 | -------------------------------------------------------------------------------- /core/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.signal 3 | # from mpi_tools import mpi_statistics_scalar 4 | 5 | 6 | # Code based on: 7 | # https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py 8 | 9 | # Expects tuples of (state, next_state, action, reward, done) 10 | class ReplayBuffer(object): 11 | def __init__(self, max_size=1e6): 12 | self.storage = [] 13 | self.max_size = max_size 14 | self.ptr = 0 15 | 16 | def add(self, data): 17 | if len(self.storage) == self.max_size: 18 | self.storage[int(self.ptr)] = data 19 | self.ptr = (self.ptr + 1) % self.max_size 20 | else: 21 | self.storage.append(data) 22 | 23 | def sample(self, batch_size): 24 | ind = np.random.randint(0, len(self.storage), size=batch_size) 25 | x, y, u, r, d,nu, parameters =[], [],[], [], [], [], [] 26 | # print(len(self.storage), " ", ind) 27 | 28 | # print(self.storage[ind[0]]) 29 | 30 | for i in ind: 31 | X, Y, U, R, D, NU, P = self.storage[i] 32 | x.append(np.array(X, copy=False)) 33 | y.append(np.array(Y, copy=False)) 34 | u.append(np.array(U, copy=False)) 35 | r.append(np.array(R, copy=False)) 36 | d.append(np.array(D, copy=False)) 37 | nu.append(np.array(NU, copy=False)) 38 | parameters.append(np.array(P, copy=False)) 39 | return np.array(x), np.array(y), np.array(u), np.array(r).reshape(-1, 1), np.array(d).reshape(-1, 1),np.array(parameters),np.array(nu) 40 | 41 | def combined_shape(length, shape=None): 42 | if shape is None: 43 | return (length,) 44 | return (length, shape) if np.isscalar(shape) else (length, *shape) 45 | 46 | 47 | class ReplayBufferPPO(object): 48 | """ 49 | original from: https://github.com/bluecontra/tsallis_actor_critic_mujoco/blob/master/spinup/algos/ppo/ppo.py 50 | A buffer for storing 
trajectories experienced by a PPO agent interacting 51 | with the environment, and using Generalized Advantage Estimation (GAE-Lambda) 52 | for calculating the advantages of state-action pairs. 53 | """ 54 | 55 | def __init__(self, obs_dim, act_dim, size, gamma=0.99, lam=0.95): 56 | self.obs_dim = obs_dim 57 | self.act_dim = act_dim 58 | self.size = size 59 | self.gamma, self.lam = gamma, lam 60 | self.ptr = 0 61 | self.path_start_idx, self.max_size = 0, size 62 | 63 | self.reset() 64 | 65 | def reset(self): 66 | self.obs_buf = np.zeros([self.size, self.obs_dim], dtype=np.float32) 67 | self.act_buf = np.zeros([self.size, self.act_dim], dtype=np.float32) 68 | self.adv_buf = np.zeros(self.size, dtype=np.float32) 69 | self.rew_buf = np.zeros(self.size, dtype=np.float32) 70 | self.ret_buf = np.zeros(self.size, dtype=np.float32) 71 | self.val_buf = np.zeros(self.size, dtype=np.float32) 72 | self.logp_buf = np.zeros(self.size, dtype=np.float32) 73 | 74 | def add(self, obs, act, rew, val, logp): 75 | """ 76 | Append one timestep of agent-environment interaction to the buffer. 77 | """ 78 | assert self.ptr < self.max_size # buffer has to have room so you can store 79 | self.obs_buf[self.ptr] = obs 80 | self.act_buf[self.ptr] = act 81 | self.rew_buf[self.ptr] = rew 82 | self.val_buf[self.ptr] = val 83 | self.logp_buf[self.ptr] = logp 84 | self.ptr += 1 85 | 86 | def finish_path(self, last_val=0): 87 | path_slice = slice(self.path_start_idx, self.ptr) 88 | rews = np.append(self.rew_buf[path_slice], last_val) 89 | vals = np.append(self.val_buf[path_slice], last_val) 90 | 91 | # the next two lines implement GAE-Lambda advantage calculation 92 | deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1] 93 | self.adv_buf[path_slice] = discount(deltas, self.gamma * self.lam) 94 | 95 | # the next line computes rewards-to-go, to be targets for the value function 96 | self.ret_buf[path_slice] = discount(rews, self.gamma)[:-1] 97 | 98 | self.path_start_idx = self.ptr 99 | 100 | def get(self): 101 | assert self.ptr == self.max_size # buffer has to be full before you can get 102 | self.ptr, self.path_start_idx = 0, 0 103 | # the next two lines implement the advantage normalization trick 104 | # adv_mean, adv_std = mpi_statistics_scalar(self.adv_buf) 105 | adv_mean = np.mean(self.adv_buf) 106 | adv_std = np.std(self.adv_buf) 107 | self.adv_buf = (self.adv_buf - adv_mean) / adv_std 108 | return [self.obs_buf, self.act_buf, self.adv_buf, 109 | self.ret_buf, self.logp_buf] 110 | 111 | 112 | class ReplayBuffer_MC(object): 113 | def __init__(self, max_size=1e6): 114 | self.storage = [] 115 | self.max_size = max_size 116 | self.ptr = 0 117 | 118 | def add(self, data): 119 | if len(self.storage) == self.max_size: 120 | self.storage[int(self.ptr)] = data 121 | self.ptr = (self.ptr + 1) % self.max_size 122 | else: 123 | self.storage.append(data) 124 | 125 | def sample(self, batch_size): 126 | ind = np.random.randint(0, len(self.storage), size=batch_size) 127 | x, u, r = [], [], [] 128 | 129 | for i in ind: 130 | X, U, R = self.storage[i] 131 | x.append(np.array(X, copy=False)) 132 | u.append(np.array(U, copy=False)) 133 | r.append(np.array(R, copy=False)) 134 | 135 | return np.array(x), np.array(u), np.array(r).reshape(-1, 1) 136 | 137 | 138 | class ReplayBuffer_VDFP(object): 139 | def __init__(self, max_size=1e5): 140 | self.storage = [] 141 | self.max_size = int(max_size) 142 | self.ptr = 0 143 | 144 | def add(self, data): 145 | if len(self.storage) == self.max_size: 146 | self.storage[self.ptr] = data 147 | self.ptr = 
(self.ptr + 1) % self.max_size 148 | else: 149 | self.storage.append(data) 150 | 151 | def sample(self, batch_size): 152 | ind = np.random.randint(0, len(self.storage), size=batch_size) 153 | s, a, u, x = [], [], [], [] 154 | 155 | for i in ind: 156 | S, A, U, X = self.storage[i] 157 | s.append(np.array(S, copy=False)) 158 | a.append(np.array(A, copy=False)) 159 | u.append(np.array(U, copy=False)) 160 | x.append(np.array(X, copy=False)) 161 | 162 | return np.array(s), np.array(a), np.array(u).reshape(-1, 1), np.array(x) 163 | 164 | def sample_traj(self, batch_size, offset=0): 165 | ind = np.random.randint(0, len(self.storage) - int(offset), size=batch_size) 166 | if len(self.storage) == self.max_size: 167 | ind = (self.ptr + self.max_size - ind) % self.max_size 168 | else: 169 | ind = len(self.storage) - ind - 1 170 | # ind = (self.ptr - ind + len(self.storage)) % len(self.storage) 171 | s, a, x = [], [], [] 172 | 173 | for i in ind: 174 | S, A, _, X = self.storage[i] 175 | s.append(np.array(S, copy=False)) 176 | a.append(np.array(A, copy=False)) 177 | x.append(np.array(X, copy=False)) 178 | 179 | return np.array(s), np.array(a), np.array(x) 180 | 181 | def sample_traj_return(self, batch_size): 182 | ind = np.random.randint(0, len(self.storage), size=batch_size) 183 | u, x = [], [] 184 | 185 | for i in ind: 186 | _, _, U, X = self.storage[i] 187 | u.append(np.array(U, copy=False)) 188 | x.append(np.array(X, copy=False)) 189 | 190 | return np.array(u).reshape(-1, 1), np.array(x) 191 | 192 | 193 | def store_experience(replay_buffer, trajectory, s_dim, a_dim, 194 | sequence_length, min_sequence_length=0, is_padding=False, gamma=0.99, 195 | ): 196 | s_traj, a_traj, r_traj = trajectory 197 | 198 | # for the convenience of manipulation 199 | arr_s_traj = np.array(s_traj) 200 | arr_a_traj = np.array(a_traj) 201 | arr_r_traj = np.array(r_traj) 202 | 203 | zero_pads = np.zeros(shape=[sequence_length, s_dim + a_dim]) 204 | 205 | # for i in range(len(s_traj) - self.sequence_length): 206 | for i in range(len(s_traj) - min_sequence_length): 207 | tmp_s = arr_s_traj[i] 208 | tmp_a = arr_a_traj[i] 209 | tmp_soff = arr_s_traj[i:i + sequence_length] 210 | tmp_aoff = arr_a_traj[i:i + sequence_length] 211 | tmp_saoff = np.concatenate([tmp_soff, tmp_aoff], axis=1) 212 | 213 | tmp_saoff_padded = np.concatenate([tmp_saoff, zero_pads], axis=0) 214 | tmp_saoff_padded_clip = tmp_saoff_padded[:sequence_length, :] 215 | 216 | tmp_roff = arr_r_traj[i:i + sequence_length] 217 | tmp_u = np.matmul(tmp_roff, np.power(gamma, [j for j in range(len(tmp_roff))])) 218 | 219 | replay_buffer.add((tmp_s, tmp_a, tmp_u, tmp_saoff_padded_clip)) 220 | 221 | 222 | def discount(x, gamma): 223 | """ Calculate discounted forward sum of a sequence at each point """ 224 | """ 225 | magic from rllab for computing discounted cumulative sums of vectors. 226 | input: 227 | vector x, 228 | [x0, 229 | x1, 230 | x2] 231 | output: 232 | [x0 + discount * x1 + discount^2 * x2, 233 | x1 + discount * x2, 234 | x2] 235 | """ 236 | # return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1] 237 | return scipy.signal.lfilter([1.0], [1.0, -gamma], x[::-1])[::-1] 238 | 239 | 240 | class Scaler(object): 241 | """ Generate scale and offset based on running mean and stddev along axis=0 242 | offset = running mean 243 | scale = 1 / (stddev + 0.1) / 3 (i.e. 
3x stddev = +/- 1.0) 244 | """ 245 | 246 | def __init__(self, obs_dim): 247 | """ 248 | Args: 249 | obs_dim: dimension of axis=1 250 | """ 251 | self.vars = np.zeros(obs_dim) 252 | self.means = np.zeros(obs_dim) 253 | self.m = 0 254 | self.n = 0 255 | self.first_pass = True 256 | 257 | def update(self, x): 258 | """ Update running mean and variance (this is an exact method) 259 | Args: 260 | x: NumPy array, shape = (N, obs_dim) 261 | see: https://stats.stackexchange.com/questions/43159/how-to-calculate-pooled- 262 | variance-of-two-groups-given-known-group-variances-mean 263 | """ 264 | if self.first_pass: 265 | self.means = np.mean(x, axis=0) 266 | self.vars = np.var(x, axis=0) 267 | self.m = x.shape[0] 268 | self.first_pass = False 269 | else: 270 | n = x.shape[0] 271 | new_data_var = np.var(x, axis=0) 272 | new_data_mean = np.mean(x, axis=0) 273 | new_data_mean_sq = np.square(new_data_mean) 274 | new_means = ((self.means * self.m) + (new_data_mean * n)) / (self.m + n) 275 | self.vars = (((self.m * (self.vars + np.square(self.means))) + 276 | (n * (new_data_var + new_data_mean_sq))) / (self.m + n) - 277 | np.square(new_means)) 278 | self.vars = np.maximum(0.0, self.vars) # occasionally goes negative, clip 279 | self.means = new_means 280 | self.m += n 281 | 282 | def get(self): 283 | """ returns 2-tuple: (scale, offset) """ 284 | return 1/(np.sqrt(self.vars) + 0.1)/3, self.means 285 | -------------------------------------------------------------------------------- /parameters.py: -------------------------------------------------------------------------------- 1 | import pprint 2 | import torch 3 | import os 4 | import wandb 5 | os.environ["WANDB_API_KEY"] = "" 6 | os.environ["WANDB_MODE"] = "" 7 | class Parameters: 8 | def __init__(self, cla, init=True): 9 | if not init: 10 | return 11 | cla = cla.parse_args() 12 | 13 | # Set the device to run on CUDA or CPU 14 | if not cla.disable_cuda and torch.cuda.is_available(): 15 | self.device = torch.device('cuda') 16 | else: 17 | self.device = torch.device('cpu') 18 | 19 | # Render episodes 20 | self.render = cla.render 21 | self.env_name = cla.env 22 | self.save_periodic = cla.save_periodic 23 | 24 | # Number of Frames to Run 25 | self.num_frames = 1000000 26 | 27 | # Synchronization 28 | if cla.env == 'HalfCheetah-v2' or cla.env == 'Hopper-v2' or cla.env == 'Ant-v2' or cla.env == 'Walker2d-v2' or cla.env == "Humanoid-v2": 29 | self.rl_to_ea_synch_period = 1 30 | else: 31 | self.rl_to_ea_synch_period = 10 32 | 33 | # Overwrite sync from command line if value is passed 34 | if cla.sync_period is not None: 35 | self.rl_to_ea_synch_period = cla.sync_period 36 | 37 | # Novelty Search 38 | self.ns = cla.novelty 39 | self.ns_epochs = 10 40 | 41 | # Model save frequency if save is active 42 | self.next_save = cla.next_save 43 | 44 | # DDPG params 45 | self.use_ln = True 46 | self.gamma = cla.gamma 47 | self.tau = cla.tau 48 | self.seed = cla.seed 49 | self.batch_size = 128 50 | self.frac_frames_train = 1.0 51 | self.use_done_mask = True 52 | self.buffer_size = 1000000 53 | self.ls = 300 54 | 55 | # Prioritised Experience Replay 56 | self.per = cla.per 57 | self.replace_old = True 58 | self.alpha = 0.7 59 | self.beta_zero = 0.5 60 | self.learn_start = (1 + self.buffer_size / self.batch_size) * 2 61 | self.total_steps = self.num_frames 62 | 63 | # ========================================== NeuroEvolution Params ============================================= 64 | 65 | # Num of trials 66 | self.num_evals = 1 67 | if cla.num_evals is not None: 68 | 
self.num_evals = cla.num_evals 69 | 70 | # Elitism Rate 71 | self.elite_fraction = 0.2 72 | # Number of actors in the population 73 | self.pop_size = cla.pop_size 74 | 75 | # Mutation and crossover 76 | self.crossover_prob = 0.0 77 | self.mutation_prob = 0.9 78 | self.mutation_mag = cla.mut_mag 79 | self.mutation_noise = cla.mut_noise 80 | self.mutation_batch_size = 256 81 | self.proximal_mut = cla.proximal_mut 82 | self.distil = cla.distil 83 | self.distil_type = cla.distil_type 84 | self.verbose_mut = cla.verbose_mut 85 | self.verbose_crossover = cla.verbose_crossover 86 | 87 | # Genetic memory size 88 | self.individual_bs = 8000 89 | self.intention = cla.intention 90 | # Variation operator statistics 91 | self.opstat = cla.opstat 92 | self.opstat_freq = cla.opstat_freq 93 | self.test_operators = cla.test_operators 94 | 95 | # Save Results 96 | self.state_dim = None # To be initialised externally 97 | self.action_dim = None # To be initialised externally 98 | self.random_choose = cla.random_choose 99 | self.EA = cla.EA 100 | self.RL = cla.RL 101 | self.K = cla.K 102 | self.state_alpha = cla.state_alpha 103 | self.detach_z = cla.detach_z 104 | self.actor_alpha = cla.actor_alpha 105 | self.TD3_noise = cla.TD3_noise 106 | self.pr = cla.pr 107 | self.use_all = cla.use_all 108 | self.OFF_TYPE = cla.OFF_TYPE 109 | self.prob_reset_and_sup = cla.prob_reset_and_sup 110 | self.frac = cla.frac 111 | self.EA_actor_alpha = cla.EA_actor_alpha 112 | self.theta = cla.theta 113 | self.time_steps = cla.time_steps 114 | self.init_steps = 10000 115 | self.name = "Steps_"+str(self.time_steps)+"_theta_"+str(self.theta)+ "_eval_"+str(self.num_evals)+"_rs_prob_"+ str(self.prob_reset_and_sup)+"_frac_p_"+str(self.frac)+"_our_M_"+str(self.OFF_TYPE)+"_" + str(self.elite_fraction) +"_"+ str(self.rl_to_ea_synch_period) +"_"+ str(self.pop_size) + "_"+str(self.EA_actor_alpha) + "_"+str(self.pr)+"_noise_"+str(self.TD3_noise)+"_Pavn_detach_"+str(self.detach_z)+"_"+str(self.actor_alpha)+ "_actorloss_MI_sa_s_"+ str(self.state_alpha) + "_random_K_"+ str(self.K)+ "_"+str(cla.env) + "_"+ str(self.tau) 116 | 117 | self.wandb = wandb.init(project="TSR",name=self.name) 118 | 119 | self.wandb.config.rl_to_ea_synch_period = self.rl_to_ea_synch_period 120 | self.wandb.config.env = cla.env 121 | self.wandb.config.tau = self.tau 122 | 123 | self.wandb.config.num_evals = self.num_evals 124 | self.wandb.config.elite_fraction = self.elite_fraction 125 | self.wandb.config.crossover_prob = self.crossover_prob 126 | self.wandb.config.mutation_prob = self.mutation_prob 127 | self.wandb.config.mutation_batch_size = self.mutation_batch_size 128 | self.wandb.config.distil = self.distil 129 | self.wandb.config.proximal_mut = self.proximal_mut 130 | 131 | self.save_foldername = cla.logdir + "/"+self.name 132 | if not os.path.exists(self.save_foldername): 133 | os.makedirs(self.save_foldername) 134 | 135 | def write_params(self, stdout=True): 136 | # Dump all the hyper-parameters in a file. 
137 | params = pprint.pformat(vars(self), indent=4) 138 | if stdout: 139 | print(params) 140 | 141 | with open(os.path.join(self.save_foldername, 'info.txt'), 'a') as f: 142 | f.write(params) -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | nohup python run_re2.py -env="Ant-v2" -disable_cuda -OFF_TYPE=1 -pr=64 -pop_size=5 -prob_reset_and_sup=0.05 -time_steps=200 -theta=0.5 -frac=0.7 -gamma=0.99 -TD3_noise=0.2 -EA -RL -K=1 -state_alpha=0.0 -actor_alpha=1.0 -EA_actor_alpha=1.0 -tau=0.005 -seed=1 -logdir="./logs" > ./logs/xxx.log 2>&1 & 2 | -------------------------------------------------------------------------------- /run_re2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import random 4 | 5 | import gym, torch 6 | import argparse 7 | import pickle 8 | from core.operator_runner import OperatorRunner 9 | from parameters import Parameters 10 | 11 | 12 | import os 13 | 14 | cpu_num = 1 15 | os.environ ['OMP_NUM_THREADS'] = str(cpu_num) 16 | os.environ ['OPENBLAS_NUM_THREADS'] = str(cpu_num) 17 | os.environ ['MKL_NUM_THREADS'] = str(cpu_num) 18 | os.environ ['VECLIB_MAXIMUM_THREADS'] = str(cpu_num) 19 | os.environ ['NUMEXPR_NUM_THREADS'] = str(cpu_num) 20 | torch.set_num_threads(cpu_num) 21 | 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('-env', help='Environment Choices: (Swimmer-v2) (HalfCheetah-v2) (Hopper-v2) ' + 25 | '(Walker2d-v2) (Ant-v2)', required=True, type=str) 26 | parser.add_argument('-seed', help='Random seed to be used', type=int, default=7) 27 | parser.add_argument('-pr', help='pr', type=int, default=128) 28 | parser.add_argument('-pop_size', help='pop_size', type=int, default=10) 29 | 30 | parser.add_argument('-disable_cuda', help='Disables CUDA', action='store_true') 31 | parser.add_argument('-render', help='Render gym episodes', action='store_true') 32 | parser.add_argument('-sync_period', help="How often to sync to population", type=int) 33 | parser.add_argument('-novelty', help='Use novelty exploration', action='store_true') 34 | parser.add_argument('-proximal_mut', help='Use safe mutation', action='store_true') 35 | parser.add_argument('-distil', help='Use distilation crossover', action='store_true') 36 | parser.add_argument('-distil_type', help='Use distilation crossover. 
Choices: (fitness) (dist)', 37 | type=str, default='fitness') 38 | parser.add_argument('-EA', help='Use ea', action='store_true') 39 | parser.add_argument('-RL', help='Use rl', action='store_true') 40 | parser.add_argument('-detach_z', help='detach_z', action='store_true') 41 | parser.add_argument('-random_choose', help='Use random_choose', action='store_true') 42 | 43 | parser.add_argument('-per', help='Use Prioritised Experience Replay', action='store_true') 44 | parser.add_argument('-use_all', help='Use all', action='store_true') 45 | 46 | parser.add_argument('-intention', help='intention', action='store_true') 47 | 48 | parser.add_argument('-mut_mag', help='The magnitude of the mutation', type=float, default=0.05) 49 | parser.add_argument('-tau', help='tau', type=float, default=0.005) 50 | 51 | parser.add_argument('-prob_reset_and_sup', help='prob_reset_and_sup', type=float, default=0.05) 52 | parser.add_argument('-frac', help='frac', type=float, default=0.1) 53 | 54 | 55 | parser.add_argument('-TD3_noise', help='TD3_noise', type=float, default=0.2) 56 | parser.add_argument('-mut_noise', help='Use a random mutation magnitude', action='store_true') 57 | parser.add_argument('-verbose_mut', help='Make mutations verbose', action='store_true') 58 | parser.add_argument('-verbose_crossover', help='Make crossovers verbose', action='store_true') 59 | parser.add_argument('-logdir', help='Folder where to save results', type=str, required=True) 60 | parser.add_argument('-opstat', help='Store statistics for the variation operators', action='store_true') 61 | parser.add_argument('-opstat_freq', help='Frequency (in generations) to store operator statistics', type=int, default=1) 62 | parser.add_argument('-save_periodic', help='Save actor, critic and memory periodically', action='store_true') 63 | parser.add_argument('-next_save', help='Generation save frequency for save_periodic', type=int, default=200) 64 | parser.add_argument('-K', help='K', type=int, default=5) 65 | parser.add_argument('-OFF_TYPE', help='OFF_TYPE', type=int, default=1) 66 | parser.add_argument('-num_evals', help='num_evals', type=int, default=1) 67 | 68 | parser.add_argument('-version', help='version', type=int, default=1) 69 | parser.add_argument('-time_steps', help='time_steps', type=int, default=1) 70 | 71 | 72 | parser.add_argument('-test_operators', help='Runs the operator runner to test the operators', action='store_true') 73 | parser.add_argument('-EA_actor_alpha', help='EA_actor_alpha', type=float, default=1.0) 74 | parser.add_argument('-state_alpha', help='state_alpha', type=float, default=1.0) 75 | parser.add_argument('-actor_alpha', help='actor_alpha', type=float, default=1.0) 76 | parser.add_argument('-theta', help='theta', type=float, default=0.5) 77 | 78 | parser.add_argument('-gamma', help='gamma', type=float, default=0.99) 79 | 80 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 81 | parameters = Parameters(parser) # Inject the cla arguments in the parameters object 82 | 83 | # Create Env 84 | #env = utils.NormalizedActions(gym.make(parameters.env_name)) 85 | env = gym.make(parameters.env_name) 86 | print("env.action_space.low",env.action_space.low, "env.action_space.high",env.action_space.high) 87 | parameters.action_dim = env.action_space.shape[0] 88 | parameters.state_dim = env.observation_space.shape[0] 89 | 90 | # Write the parameters to the info file and print them 91 | parameters.write_params(stdout=True) 92 | 93 | # Seed 94 | os.environ['PYTHONHASHSEED']= str(parameters.seed) 95 | 
env.seed(parameters.seed) 96 | torch.manual_seed(parameters.seed) 97 | np.random.seed(parameters.seed) 98 | random.seed(parameters.seed) 99 | 100 | from core import mod_utils as utils, agent 101 | tracker = utils.Tracker(parameters, ['erl'], '_score.csv') # Initiate tracker 102 | frame_tracker = utils.Tracker(parameters, ['frame_erl'], '_score.csv') # Initiate tracker 103 | time_tracker = utils.Tracker(parameters, ['time_erl'], '_score.csv') 104 | ddpg_tracker = utils.Tracker(parameters, ['ddpg'], '_score.csv') 105 | selection_tracker = utils.Tracker(parameters, ['elite', 'selected', 'discarded'], '_selection.csv') 106 | #env.action_space.seed(parameters.seed) 107 | 108 | 109 | if __name__ == "__main__": 110 | 111 | # Tests the variation operators after that is saved first with -save_periodic 112 | if parameters.test_operators: 113 | operator_runner = OperatorRunner(parameters, env) 114 | operator_runner.run() 115 | exit() 116 | 117 | # Create Agent 118 | agent = agent.Agent(parameters, env) 119 | print('Running', parameters.env_name, ' State_dim:', parameters.state_dim, ' Action_dim:', parameters.action_dim) 120 | 121 | next_save = parameters.next_save; time_start = time.time() 122 | while agent.num_frames <= parameters.num_frames: 123 | stats = agent.train() 124 | best_train_fitness = stats['best_train_fitness'] 125 | erl_score = stats['test_score'] 126 | elite_index = stats['elite_index'] 127 | ddpg_reward = stats['ddpg_reward'] 128 | policy_gradient_loss = stats['pg_loss'] 129 | behaviour_cloning_loss = stats['bc_loss'] 130 | population_novelty = stats['pop_novelty'] 131 | current_q = stats['current_q'] 132 | target_q = stats['target_q'] 133 | pre_loss = stats['pre_loss'] 134 | before_rewards = stats['before_rewards'] 135 | add_rewards = stats['add_rewards'] 136 | l1_before_after = stats['l1_before_after'] 137 | keep_c_loss = stats['keep_c_loss'] 138 | pvn_loss = stats['pvn_loss'] 139 | min_fintess = stats['min_fintess'] 140 | best_old_fitness = stats['best_old_fitness'] 141 | new_fitness = stats['new_fitness'] 142 | 143 | print('#Games:', agent.num_games, '#Frames:', agent.num_frames, 144 | ' Train_Max:', '%.2f'%best_train_fitness if best_train_fitness is not None else None, 145 | ' Test_Score:','%.2f'%erl_score if erl_score is not None else None, 146 | ' Avg:','%.2f'%tracker.all_tracker[0][1], 147 | ' ENV: '+ parameters.env_name, 148 | ' DDPG Reward:', '%.2f'%ddpg_reward, 149 | ' PG Loss:', '%.4f' % policy_gradient_loss) 150 | 151 | elite = agent.evolver.selection_stats['elite']/agent.evolver.selection_stats['total'] 152 | selected = agent.evolver.selection_stats['selected'] / agent.evolver.selection_stats['total'] 153 | discarded = agent.evolver.selection_stats['discarded'] / agent.evolver.selection_stats['total'] 154 | 155 | print() 156 | 157 | min_fintess = stats['min_fintess'] 158 | best_old_fitness = stats['best_old_fitness'] 159 | new_fitness = stats['new_fitness'] 160 | best_reward = np.max([ddpg_reward,erl_score]) 161 | 162 | parameters.wandb.log( 163 | {'best_reward': best_reward, 'add_rewards': add_rewards, 164 | 'pvn_loss': pvn_loss, 'keep_c_loss': keep_c_loss, 'l1_before_after': l1_before_after, 165 | 'pre_loss': pre_loss, 'num_frames': agent.num_frames, 'num_games': agent.num_games, 166 | 'erl_score': erl_score, 'ddpg_reward': ddpg_reward, 'elite': elite, 'selected': selected, 'discarded': discarded, 167 | 'policy_gradient_loss': policy_gradient_loss, 'population_novelty': population_novelty, 168 | 'best_train_fitness': best_train_fitness, 'behaviour_cloning_loss': 
behaviour_cloning_loss}) 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | --------------------------------------------------------------------------------
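
Usage sketch (variation operators). This is a heavily hedged sketch of exercising the distillation crossover and proximal mutation from core/mod_neuro_evo.py outside training, mirroring the flow of core/operator_runner.py. It assumes `parameters` is the Parameters object built in run_re2.py; MODEL_DIR and MODEL_IDS are hypothetical placeholders for checkpoints written with -save_periodic, not files shipped with this repository.

import os
import pickle
import torch
from core import ddpg, mod_neuro_evo

MODEL_DIR = 'exp/your_run/models/'   # hypothetical checkpoint folder
MODEL_IDS = [1400, 1600]             # hypothetical generation indices

def load_agent(model_id):
    # Mirrors OperatorRunner.load_genetic_agent: restore the actor and its buffer.
    agent = ddpg.GeneticAgent(parameters)
    agent.actor.load_state_dict(torch.load(os.path.join(MODEL_DIR, 'evo_net_actor_{}.pkl'.format(model_id))))
    with open(os.path.join(MODEL_DIR, 'champion_buffer_{}.pkl'.format(model_id)), 'rb') as f:
        agent.buffer = pickle.load(f)
    return agent

critic = ddpg.Critic(parameters)
critic.load_state_dict(torch.load(os.path.join(MODEL_DIR, 'evo_net_critic_{}.pkl'.format(MODEL_IDS[1]))))

parent1, parent2 = load_agent(MODEL_IDS[0]), load_agent(MODEL_IDS[1])
ssne = mod_neuro_evo.SSNE(parameters, critic, None)

# Distillation crossover returns a fresh agent trained to imitate both parents on their
# recent experience; proximal (safe) mutation then perturbs it in place, scaling the
# Gaussian noise by each parameter's output sensitivity.
child = ssne.distilation_crossover(parent1, parent2)
ssne.proximal_mutate(child, mag=parameters.mutation_mag)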
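
Usage sketch (ReplayMemory). ReplayMemory in core/replay_memory.py reshapes every transition to (1, -1) on add() and returns concatenated float tensors on sample(). A minimal self-contained sketch; the 4-dimensional state and 1-dimensional action are arbitrary assumptions made only for illustration.

import numpy as np
import torch
from core.replay_memory import ReplayMemory

buffer = ReplayMemory(capacity=1000, device=torch.device('cpu'))

for _ in range(256):
    state = np.random.randn(4)
    action = np.random.uniform(-1.0, 1.0, size=1)
    next_state = np.random.randn(4)
    reward = float(np.random.randn())
    done = 0.0
    buffer.add(state, action, next_state, reward, done)  # each argument is reshaped to (1, -1)

# sample() returns five float tensors already on the buffer's device.
state, action, next_state, reward, done = buffer.sample(batch_size=128)
print(state.shape, action.shape, reward.shape)  # torch.Size([128, 4]) torch.Size([128, 1]) torch.Size([128, 1])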
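
Usage sketch (PrioritizedReplayMemory). In the same file, PrioritizedReplayMemory samples transitions proportionally to priority**alpha and returns importance-sampling weights annealed by beta, while new transitions enter with the current maximum priority. The random numbers below stand in for real TD errors.

import numpy as np
import torch
from core.replay_memory import PrioritizedReplayMemory, Transition

per = PrioritizedReplayMemory(capacity=1000, device=torch.device('cpu'))

for _ in range(300):
    t = Transition(state=np.random.randn(1, 4), action=np.random.randn(1, 1),
                   next_state=np.random.randn(1, 4), reward=np.array([[0.0]]),
                   done=np.array([[0.0]]))
    per.push(t)  # stored with the current maximum priority

samples, indices, weights = per.sample(batch_size=64)  # weights: float tensor on the buffer's device
td_errors = np.abs(np.random.randn(64))                # placeholder for the learner's TD errors
per.update_priorities(indices, td_errors)              # priorities become (|error| + 1e-5) ** alpha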
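
Worked example (discount). discount() in core/utils.py computes the discounted forward sum that ReplayBufferPPO.finish_path uses for both the GAE-Lambda advantages and the rewards-to-go targets. A quick numerical check:

import numpy as np
from core.utils import discount

x = np.array([1.0, 2.0, 3.0])
gamma = 0.5
# Expected: [1 + 0.5*2 + 0.25*3, 2 + 0.5*3, 3] = [2.75, 3.5, 3.0]
print(discount(x, gamma))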