├── LICENSE ├── README.md ├── algorithms ├── run_sac.py ├── run_td3.py ├── sac_based │ ├── anf_sac.py │ ├── sac.py │ └── ss_sac.py ├── td3_based │ ├── anf_td3.py │ ├── ss_td3.py │ └── td3.py └── test_algos ├── figures ├── ANF.png ├── learning_curves_halfcheetah_nf98.png ├── mujoco_gym_4envs_captions.png └── mujoco_visual_halfcheetah.png ├── main.py ├── tutorial.md ├── utils ├── activations.py ├── core.py ├── core_anf_sac.py ├── core_sac.py ├── load_feats_distr.py ├── mask_adam.py ├── noise_distributions │ └── real_feats_distr_HalfCheetah.npy ├── pretrained_models │ └── ANF-SAC_HalfCheetah-v3_relu_sparsity0.0_uniform_inlayspars0.8_hid-lay2_maskadam_fakefeats0.9_seed3101_best ├── replay_memory_sac.py ├── sparse_utils.py ├── target_network.py └── utils.py └── view_mujoco.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Bram Grooten 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Automatic Noise Filtering 2 | _with Dynamic Sparse Training in Deep Reinforcement Learning_ 3 | 4 | Paper: [arxiv.org/abs/2302.06548](https://arxiv.org/abs/2302.06548) accepted at [AAMAS'23](https://aamas2023.soton.ac.uk/). 5 | If you use this code, please cite: 6 | ``` 7 | @article{grooten2023automatic, 8 | title={{Automatic Noise Filtering with Dynamic Sparse Training in Deep Reinforcement Learning}}, 9 | author={Grooten, Bram and Sokar, Ghada and Dohare, Shibhansh and Mocanu, Elena and Taylor, Matthew E. and Pechenizkiy, Mykola and Mocanu, Decebal Constantin}, 10 | year={2023}, 11 | journal={The 22nd International Conference on Autonomous Agents and Multiagent Systems (AAMAS)}, 12 | note={URL: \url{https://arxiv.org/abs/2302.06548}} 13 | } 14 | ``` 15 | 16 | ![Image showing the overview of ANF](figures/ANF.png) 17 | 18 | # Abstract 19 | Tomorrow's robots will need to distinguish useful information from noise when performing different tasks. 20 | A household robot for instance may continuously receive a plethora of information about the home, 21 | but needs to focus on just a small subset to successfully execute its current chore. 22 | 23 | Filtering distracting inputs that contain irrelevant data 24 | has received little attention in the reinforcement learning literature. 
25 | To start resolving this, we formulate a **problem setting** in reinforcement learning 26 | called the _extremely noisy environment_ (ENE) where up to 99% of the input features are pure noise. 27 | Agents need to detect which features actually provide task-relevant information 28 | about the state of the environment. 29 | 30 | Consequently, we propose a new **method** termed _Automatic Noise Filtering_ (ANF) 31 | which uses the principles of dynamic sparse training to focus the input layer's connectivity 32 | on task-relevant features. 33 | ANF outperforms standard SAC and TD3 by a large margin, while using up to 95% fewer weights. 34 | 35 | Furthermore, we devise a transfer learning setting for ENEs, 36 | by permuting all features of the environment after 1M timesteps 37 | to simulate the fact that other information sources can become task-relevant as the world evolves. 38 | Again ANF surpasses the baselines in final performance and sample complexity. 39 | 40 | 41 | 42 | # Install 43 | ### Requirements 44 | * Python 3.8 45 | * PyTorch 1.9 46 | * [MuJoCo-py](https://github.com/openai/mujoco-py) 47 | * [OpenAI gym](https://github.com/openai/gym) 48 | * Linux (using [WSL](https://learn.microsoft.com/en-us/windows/wsl/install) may work in Windows) 49 | 50 | ### Instructions 51 | First make a virtual environment: 52 | ```shell 53 | sudo apt install python3.8 python3.8-venv python3.8-dev 54 | python3.8 -m venv venv 55 | source venv/bin/activate 56 | ``` 57 | 58 | If you don't have MuJoCo 2.10 yet: 59 | ```shell 60 | cd ~ 61 | wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz 62 | tar -xzf mujoco210-linux-x86_64.tar.gz 63 | mkdir .mujoco 64 | mv mujoco210 .mujoco/ 65 | rm mujoco210-linux-x86_64.tar.gz 66 | ``` 67 | 68 | Now you have MuJoCo. Proceed with: 69 | ```shell 70 | pip install mujoco_py==2.1.2.14 gym==0.21.0 torch==1.9.0 71 | pip install wandb --upgrade 72 | ``` 73 | 74 | 75 | Now try to import mujoco_py in a python console, 76 | and do what the error messages tell you. 77 | (Like adding lines to your `.bashrc` file.) 78 | ```python 79 | $ python 80 | >>> import mujoco_py 81 | ``` 82 | 83 | You may need to install the following packages to solve some errors: 84 | ```shell 85 | sudo apt install libosmesa6-dev libglew-dev patchelf 86 | ``` 87 | 88 | 89 | # Usage 90 | 91 | ### Train 92 | To train an ANF agent on the ENE with 90% noise features, run: 93 | ``` 94 | python main.py \ 95 | --policy ANF-SAC \ 96 | --env HalfCheetah-v3 \ 97 | --fake_features 0.9 \ 98 | --input_layer_sparsity 0.8 \ 99 | --wandb_mode disabled 100 | ``` 101 | 102 | Possible policies: `ANF-SAC`, `ANF-TD3`, `SAC`, `TD3`. 103 | 104 | Possible environments: `HalfCheetah-v3`, `Hopper-v3`, `Walker2d-v3`, `Humanoid-v3`. 105 | 106 | Show all available arguments: `python main.py --help` 107 | 108 | ### Test 109 | 110 | See the file `view_mujoco.py` to test a trained agent on a single episode and view its behavior. 
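The sketch below shows roughly what such an evaluation loop boils down to (illustrative only: `view_mujoco.py` is the maintained script, and `agent` must be constructed with the same architecture arguments that were used during training, as `main.py` does):

```python
import gym

env = gym.make("HalfCheetah-v3")
env.seed(42)

# agent = ...  # build the agent exactly as in main.py (policy type, sparsity, fake_features, ...)
agent.load("./output/models/<trained_model_name>")  # checkpoints saved by the training scripts

state, done, episode_return = env.reset(), False, 0.0
while not done:
    action = agent.select_action(state, evaluate=True)  # deterministic action for evaluation
    state, reward, done, _ = env.step(action)
    episode_return += reward
    env.render()
print(f"Episode return: {episode_return:.1f}")
```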
111 | 112 | 113 | 114 | -------------------------------------------------------------------------------- /algorithms/run_sac.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import wandb 3 | from utils.replay_memory_sac import ReplayMemory, fill_initial_replay_memory, refill_replay_buffer 4 | from utils import utils 5 | import datetime 6 | 7 | 8 | def run(args, file_name, device, main_start_time): 9 | env, eval_env, next_env_change, adjust_env_period, env_num = utils.initialize_environments(args) 10 | agent = utils.setup_sac_based_agent(args, env, device) 11 | memory = ReplayMemory(args.buffer_size, args.seed) 12 | avg_return0 = utils.eval_policy(agent, eval_env, args.seed, args.print_comments, args.eval_episodes) 13 | evaluations = [avg_return0] 14 | wandb.log({'eval_return': avg_return0}, step=0) 15 | 16 | num_connections0 = utils.count_weights(agent, args) 17 | connections = [num_connections0] 18 | wandb.log({'num_connections': num_connections0}, step=0) 19 | wandb.watch((agent.policy, agent.critic), log="all", log_freq=5000) 20 | if args.save_model: 21 | agent.save(f"./output/models/{file_name}_iter_0") 22 | 23 | fill_initial_replay_memory(memory, env, args) 24 | 25 | updates = 0 26 | state, done = env.reset(), False 27 | episode_reward, episode_steps, episode_num = 0, 0, 0 28 | max_eval_return = float('-inf') 29 | episode_start_time = datetime.datetime.now() 30 | loss_info_dict = {} 31 | 32 | # Training Loop 33 | print(f"\nNow the training starts") 34 | for t in range(int(args.max_timesteps)): 35 | action = agent.select_action(state) # Sample action from policy 36 | next_state, reward, done, _ = env.step(action) # Perform action 37 | episode_steps += 1 38 | episode_reward += reward 39 | 40 | # Ignore the "done" signal if it comes from hitting the time horizon. 41 | # see https://github.com/openai/spinningup/blob/038665d62d569055401d91856abb287263096178/spinup/algos/pytorch/sac/sac.py#L304 42 | not_done = 1 if episode_steps == env._max_episode_steps else float(not done) 43 | memory.push(state, action, reward, next_state, not_done) # Append transition to memory 44 | state = next_state 45 | 46 | # Number of updates per step in environment 47 | for i in range(args.updates_per_step): 48 | # Train the agent 49 | loss_info_dict = agent.update_parameters(memory, args.batch_size, updates) 50 | updates += 1 51 | # wandb.log(loss_info_dict, step=t+1) 52 | 53 | if done: 54 | if args.print_comments: 55 | print(f"Total Steps: {t+1} Episode Num: {episode_num} " 56 | f"Epi. Steps: {episode_steps} Reward: {round(episode_reward, 2)} " 57 | f"Epi. 
Time: {datetime.datetime.now() - episode_start_time} " 58 | f"Total Train Time: {datetime.datetime.now() - main_start_time}") 59 | if t > next_env_change: 60 | agent.set_new_permutation() 61 | next_env_change += adjust_env_period 62 | if args.empty_buffer_on_env_change: 63 | memory.empty_buffer() 64 | refill_replay_buffer(memory, env, agent, args) 65 | # Reset environment 66 | state, done = env.reset(), False 67 | episode_reward, episode_steps = 0, 0 68 | episode_num += 1 69 | episode_start_time = datetime.datetime.now() 70 | 71 | # Evaluate the policy 72 | if (t + 1) % args.eval_freq == 0: 73 | avg_return = utils.eval_policy(agent, eval_env, args.seed, args.print_comments, args.eval_episodes) 74 | wandb.log({'eval_return': avg_return}, step=t+1) 75 | num_connections = utils.count_weights(agent, args) 76 | # wandb.log({'num_connections': num_connections}, step=t+1) 77 | wandb.log({'actor_real_connections': num_connections[0]}, step=t+1) 78 | wandb.log({'actor_fake_connections': num_connections[1]}, step=t+1) 79 | if args.save_results: 80 | evaluations.append(avg_return) 81 | np.save(f"./output/results/{file_name}", evaluations) 82 | connections.append(num_connections) 83 | np.save(f"./output/connectivity/{file_name}", connections) 84 | if t > 0.8 * int(args.max_timesteps) and avg_return > max_eval_return: 85 | max_eval_return = avg_return 86 | if args.save_model: 87 | agent.save(f"./output/models/{file_name}_best") 88 | 89 | # Save current policy 90 | if args.save_model and (t + 1) % args.save_model_period == 0: 91 | agent.save(f"./output/models/{file_name}_iter_{t + 1}") 92 | 93 | # Tracking the sparsity 94 | if args.policy in ['ANF-SAC', 'Static-SAC'] and t % 7_100 == 0: 95 | wandb.log(agent.print_sparsity(), step=t) 96 | 97 | wandb.log({'max_eval_return': max_eval_return}) 98 | print(f"Maximum evaluation return value was {max_eval_return} (only measured after 80% of training steps onwards)") 99 | -------------------------------------------------------------------------------- /algorithms/run_td3.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import datetime 3 | import wandb 4 | from utils import utils 5 | 6 | 7 | def run(args, file_name, device, main_start_time): 8 | env, eval_env, next_env_change, adjust_env_period, env_num = utils.initialize_environments(args) 9 | state_dim, action_dim = env.observation_space.shape[0], env.action_space.shape[0] 10 | max_action = float(env.action_space.high[0]) 11 | 12 | agent = utils.set_policy_kwargs(state_dim, action_dim, max_action, args, device) 13 | replay_buffer = utils.ReplayBuffer(state_dim, action_dim, max_size=args.buffer_size) 14 | 15 | if args.load_model != "": # Loading previously trained model 16 | agent.load(f"./output/models/{args.load_model}") 17 | current_iter = agent.total_it 18 | num_eval_to_keep = int(current_iter / args.eval_freq) + 1 19 | evaluations = list(np.load(f"./output/results/{file_name}.npy"))[:num_eval_to_keep] 20 | # replay_buffer.load_buffer(f"./output/replay_buffers/{file_name}") 21 | else: # No model loaded, training from scratch 22 | if args.save_model: # Save untrained policy 23 | agent.save(f"./output/models/{file_name}_iter0") 24 | # Evaluate untrained policy 25 | avg_return0 = utils.eval_policy(agent, eval_env, args.seed, args.print_comments, args.eval_episodes) 26 | evaluations = [avg_return0] 27 | wandb.log({'eval_return': avg_return0}, step=0) 28 | 29 | num_connections0 = utils.count_weights(agent, args) 30 | connections = [num_connections0] 
31 | wandb.log({'num_connections': num_connections0}, step=0) 32 | if args.print_comments: 33 | print(f"\nFirst running {args.start_timesteps} steps with random policy to fill ReplayBuffer") 34 | utils.fill_initial_replay_buffer(replay_buffer, env, args) 35 | 36 | state, done = env.reset(), False 37 | episode_reward, episode_steps, episode_num = 0, 0, 0 38 | max_eval_return = float('-inf') 39 | episode_start_time = datetime.datetime.now() 40 | wandb.watch((agent.actor, agent.critic), log="all", log_freq=5000) 41 | 42 | print(f"\nNow the training starts") 43 | for t in range(int(args.max_timesteps)): 44 | # Select action according to policy, then add some noise 45 | action = (agent.select_action(np.array(state)) 46 | + np.random.normal(0, max_action * args.expl_noise, size=action_dim) 47 | ).clip(-max_action, max_action) 48 | # Perform action 49 | next_state, reward, done, _ = env.step(action) 50 | episode_steps += 1 51 | episode_reward += reward 52 | 53 | done_bool = float(done) if episode_steps < env._max_episode_steps else 0 54 | # Store data in replay buffer 55 | replay_buffer.add(state, action, next_state, reward, done_bool) 56 | state = next_state 57 | 58 | # Train the agent 59 | agent.train(replay_buffer, args.batch_size) 60 | 61 | if done: 62 | if args.print_comments: 63 | print(f"Total Steps: {t + 1} Episode Num: {episode_num + 1} Epi. Steps: {episode_steps} " 64 | f"Reward: {episode_reward:.3f} Epi. Time: {datetime.datetime.now() - episode_start_time} " 65 | f"Total train time: {datetime.datetime.now() - main_start_time}") 66 | if t > next_env_change: 67 | agent.set_new_permutation() 68 | next_env_change += adjust_env_period 69 | if args.empty_buffer_on_env_change: 70 | replay_buffer.empty_buffer() 71 | utils.refill_replay_buffer(replay_buffer, env, agent, args) 72 | # Reset environment 73 | state, done = env.reset(), False 74 | episode_reward, episode_steps = 0, 0 75 | episode_num += 1 76 | episode_start_time = datetime.datetime.now() 77 | 78 | # Evaluate the agent 79 | if (t + 1) % args.eval_freq == 0: 80 | avg_return = utils.eval_policy(agent, eval_env, args.seed, args.print_comments, args.eval_episodes) 81 | wandb.log({'eval_return': avg_return}, step=agent.total_it) 82 | num_connections = utils.count_weights(agent, args) 83 | # wandb.log({'num_connections': num_connections}, step=agent.total_it) 84 | wandb.log({'actor_real_connections': num_connections[0]}, step=agent.total_it) 85 | wandb.log({'actor_fake_connections': num_connections[1]}, step=agent.total_it) 86 | if args.save_results: 87 | evaluations.append(avg_return) 88 | np.save(f"./output/results/{file_name}", evaluations) 89 | connections.append(num_connections) 90 | np.save(f"./output/connectivity/{file_name}", connections) 91 | if t > 0.8 * int(args.max_timesteps) and avg_return > max_eval_return: 92 | max_eval_return = avg_return 93 | if args.save_model: 94 | agent.save(f"./output/models/{file_name}_best") 95 | 96 | # Save current policy 97 | if args.save_model and (t + 1) % args.save_model_period == 0: 98 | agent.save(f"./output/models/{file_name}_iter{agent.total_it}") 99 | # replay_buffer.save_buffer(f"./output/replay_buffers/{file_name}") 100 | 101 | # Tracking the sparsity 102 | if args.policy in ['ANF-TD3', 'Static-TD3'] and t % 7_100 == 0: 103 | wandb.log(agent.print_sparsity(), step=t) 104 | 105 | wandb.log({'max_eval_return': max_eval_return}) 106 | print(f"Maximum evaluation return value was {max_eval_return} (only measured after 80% of training steps onwards)") 107 | 
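Note that both training scripts delegate the construction of the extremely noisy environment to the agents themselves: the actor and critic forward passes pad every observation with pure-noise features (`utils.add_fake_features`) and, after an environment change, permute the feature order (`utils.permute_features`), with the padded size computed as ceil(state_dim / (1 - fake_features)). The snippet below is a minimal, self-contained sketch of that idea, using hypothetical helper names and plain zero-mean Gaussian noise:

```python
import torch


def pad_with_noise(state, num_fake, noise_std=1.0):
    """Append `num_fake` pure-noise features to a batch of observations (sketch of add_fake_features)."""
    noise = torch.randn(state.shape[0], num_fake, device=state.device) * noise_std
    return torch.cat([state, noise], dim=1)


def permute(state, permutation=None):
    """Reorder all features after an 'environment change'; None keeps the original ordering."""
    return state if permutation is None else state[:, permutation]


# Example: with --fake_features 0.9, HalfCheetah-v3's 17 real observation features
# are padded with 153 noise features, so 90% of the 170-dimensional input is noise.
state_dim, fake_frac = 17, 0.9
num_fake = int(round(state_dim * fake_frac / (1 - fake_frac)))    # 153
obs_batch = torch.zeros(4, state_dim)                             # dummy batch of 4 observations
padded = pad_with_noise(obs_batch, num_fake)                      # shape (4, 170)
shuffled = permute(padded, torch.randperm(state_dim + num_fake))  # simulates the transfer setting
assert padded.shape == shuffled.shape == (4, state_dim + num_fake)
```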
-------------------------------------------------------------------------------- /algorithms/sac_based/anf_sac.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.optim import Adam, SGD 6 | from utils.mask_adam import MaskAdam 7 | from utils.target_network import soft_update, hard_update 8 | from utils.core_anf_sac import GaussianPolicy, QNetwork, DeterministicPolicy 9 | import utils.sparse_utils as sp 10 | 11 | 12 | class ANF_SAC(object): 13 | def __init__(self, state_dim, action_space, args, device): 14 | self.device = device 15 | 16 | self.gamma = args.discount 17 | self.tau = args.tau 18 | self.alpha = args.temperature 19 | 20 | self.target_update_interval = args.target_update_interval 21 | self.automatic_entropy_tuning = args.automatic_entropy_tuning 22 | 23 | self.total_it = 0 24 | self.setZeta = args.ann_setZeta 25 | self.ascTopologyChangePeriod = args.ann_ascTopologyChangePeriod 26 | self.earlyStopTopologyChangeIteration = args.ann_earlyStopTopologyChange 27 | self.lastTopologyChangeCritic = False 28 | self.lastTopologyChangePolicy = False 29 | self.ascStatsPolicy = [] 30 | self.ascStatsCritic = [] 31 | self.ascStatsValue = [] 32 | 33 | self.dim_state_with_fake = int(np.ceil(state_dim / (1 - args.fake_features))) 34 | self.prev_permutations = [] 35 | 36 | self.critic = QNetwork(state_dim, action_space.shape[0], args, 37 | self.dim_state_with_fake, self.device).to(device=self.device) 38 | self.critic_target = QNetwork(state_dim, action_space.shape[0], args, 39 | self.dim_state_with_fake, self.device).to(self.device) 40 | hard_update(self.critic_target, self.critic) 41 | 42 | if args.sac_type == "Gaussian": 43 | # Target Entropy = −dim(A) (e.g. 
, -6 for HalfCheetah-v2) as given in the paper 44 | if self.automatic_entropy_tuning is True: 45 | self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item() 46 | self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device) 47 | self.alpha_optim = Adam([self.log_alpha], lr=args.lr) 48 | self.policy = GaussianPolicy(state_dim, action_space.shape[0], args, 49 | self.dim_state_with_fake, self.device, action_space).to(self.device) 50 | else: 51 | self.alpha = 0 52 | self.automatic_entropy_tuning = False 53 | self.policy = DeterministicPolicy(state_dim, action_space.shape[0], args, 54 | self.dim_state_with_fake, self.device, action_space).to(self.device) 55 | 56 | self.optimizer_name = args.optimizer 57 | if args.optimizer == 'adam': 58 | self.policy_optim = Adam(self.policy.parameters(), lr=args.lr, weight_decay=0.0002) 59 | self.critic_optim = Adam(self.critic.parameters(), lr=args.lr, weight_decay=0.0002) 60 | elif args.optimizer == 'sgd': 61 | self.policy_optim = SGD(self.policy.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0002) 62 | self.critic_optim = SGD(self.critic.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0002) 63 | elif args.optimizer == 'maskadam': 64 | self.policy_optim = MaskAdam(self.policy.parameters(), lr=args.lr, weight_decay=0.0002) 65 | self.critic_optim = MaskAdam(self.critic.parameters(), lr=args.lr, weight_decay=0.0002) 66 | else: 67 | raise ValueError(f'Unknown optimizer {args.optimizer} given') 68 | 69 | def select_action(self, state, evaluate=False): 70 | state = torch.FloatTensor(state).to(self.device).unsqueeze(0) 71 | if evaluate is False: 72 | action, _, _ = self.policy.sample(state) 73 | else: 74 | _, _, action = self.policy.sample(state) 75 | return action.detach().cpu().numpy()[0] 76 | 77 | def update_parameters(self, memory, batch_size, updates): 78 | self.total_it += 1 79 | 80 | # Sample a batch from memory 81 | state_batch, action_batch, reward_batch, next_state_batch, done_batch = memory.sample(batch_size=batch_size) 82 | 83 | state_batch = torch.FloatTensor(state_batch).to(self.device) 84 | next_state_batch = torch.FloatTensor(next_state_batch).to(self.device) 85 | action_batch = torch.FloatTensor(action_batch).to(self.device) 86 | reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1) 87 | done_batch = torch.FloatTensor(done_batch).to(self.device).unsqueeze(1) 88 | 89 | with torch.no_grad(): 90 | next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch) 91 | qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action) 92 | min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi 93 | next_q_value = reward_batch + done_batch * self.gamma * min_qf_next_target 94 | qf1, qf2 = self.critic(state_batch, action_batch) # Two Q-functions to mitigate positive bias in the policy improvement step 95 | qf1_loss = F.mse_loss(qf1, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2] 96 | qf2_loss = F.mse_loss(qf2, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2] 97 | qf_loss = qf1_loss + qf2_loss 98 | 99 | self.critic_optim.zero_grad() 100 | qf_loss.backward() 101 | if self.optimizer_name == 'maskadam': 102 | self.critic_optim.step(masks=self.get_mask_list_critic()) 103 | else: 104 | self.critic_optim.step() 105 | 106 | # Maintain the same sparse connectivity for critic 107 | self.apply_masks_critic() 108 | 109 | # Adapt the sparse 
connectivity 110 | if not self.lastTopologyChangeCritic and self.total_it % self.ascTopologyChangePeriod == 2: 111 | if self.total_it > self.earlyStopTopologyChangeIteration: 112 | self.lastTopologyChangeCritic = True 113 | 114 | self.update_topology_critic() 115 | self.apply_masks_critic() 116 | 117 | pi, log_pi, _ = self.policy.sample(state_batch) 118 | 119 | qf1_pi, qf2_pi = self.critic(state_batch, pi) 120 | min_qf_pi = torch.min(qf1_pi, qf2_pi) 121 | 122 | policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))] 123 | 124 | self.policy_optim.zero_grad() 125 | policy_loss.backward() 126 | if self.optimizer_name == 'maskadam': 127 | self.policy_optim.step(masks=self.get_mask_list_actor()) 128 | else: 129 | self.policy_optim.step() 130 | 131 | # Maintain the same sparse connectivity for actor 132 | self.apply_masks_actor() 133 | 134 | if not self.lastTopologyChangePolicy and self.total_it % self.ascTopologyChangePeriod == 2: 135 | if self.total_it > self.earlyStopTopologyChangeIteration: 136 | self.lastTopologyChangePolicy = True 137 | 138 | self.update_topology_actor() 139 | self.apply_masks_actor() 140 | 141 | if self.automatic_entropy_tuning: 142 | alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean() 143 | 144 | self.alpha_optim.zero_grad() 145 | alpha_loss.backward() 146 | self.alpha_optim.step() 147 | 148 | self.alpha = self.log_alpha.exp() 149 | alpha_tlogs = self.alpha.clone() # For TensorboardX logs 150 | else: 151 | alpha_loss = torch.tensor(0.).to(self.device) 152 | alpha_tlogs = torch.tensor(self.alpha) # For TensorboardX logs 153 | 154 | if updates % self.target_update_interval == 0: 155 | soft_update(self.critic_target, self.critic, self.tau) 156 | 157 | loss_info = {'q1_loss': qf1_loss.item(), 158 | 'q2_loss': qf2_loss.item(), 159 | 'actor_loss': policy_loss.item(), 160 | 'alpha_loss': alpha_loss.item(), 161 | 'alpha_val': alpha_tlogs.item()} 162 | return loss_info 163 | 164 | def update_topology_critic(self): 165 | if not self.critic.dense_layers[0]: 166 | self.critic.mask1 = sp.adjust_connectivity_set(self.critic.linear1.weight.data.cpu().numpy(), 167 | self.critic.noPar1, self.setZeta, self.critic.mask1) 168 | self.critic.torchMask1 = torch.from_numpy(self.critic.mask1).float().to(self.device) 169 | self.critic.mask4 = sp.adjust_connectivity_set(self.critic.linear4.weight.data.cpu().numpy(), 170 | self.critic.noPar4, self.setZeta, self.critic.mask4) 171 | self.critic.torchMask4 = torch.from_numpy(self.critic.mask4).float().to(self.device) 172 | if not self.critic.dense_layers[1]: 173 | self.critic.mask2 = sp.adjust_connectivity_set(self.critic.linear2.weight.data.cpu().numpy(), 174 | self.critic.noPar2, self.setZeta, self.critic.mask2) 175 | self.critic.torchMask2 = torch.from_numpy(self.critic.mask2).float().to(self.device) 176 | self.critic.mask5 = sp.adjust_connectivity_set(self.critic.linear5.weight.data.cpu().numpy(), 177 | self.critic.noPar5, self.setZeta, self.critic.mask5) 178 | self.critic.torchMask5 = torch.from_numpy(self.critic.mask5).float().to(self.device) 179 | 180 | def apply_masks_critic(self): 181 | if not self.critic.dense_layers[0]: 182 | self.critic.linear1.weight.data.mul_(self.critic.torchMask1) 183 | self.critic.linear4.weight.data.mul_(self.critic.torchMask4) 184 | if not self.critic.dense_layers[1]: 185 | self.critic.linear2.weight.data.mul_(self.critic.torchMask2) 186 | self.critic.linear5.weight.data.mul_(self.critic.torchMask5) 187 | 188 | def 
update_topology_actor(self): 189 | if not self.policy.dense_layers[0]: 190 | self.policy.mask1 = sp.adjust_connectivity_set(self.policy.linear1.weight.data.cpu().numpy(), 191 | self.policy.noPar1, self.setZeta, self.policy.mask1) 192 | self.policy.torchMask1 = torch.from_numpy(self.policy.mask1).float().to(self.device) 193 | if not self.policy.dense_layers[1]: 194 | self.policy.mask2 = sp.adjust_connectivity_set(self.policy.linear2.weight.data.cpu().numpy(), 195 | self.policy.noPar2, self.setZeta, self.policy.mask2) 196 | self.policy.torchMask2 = torch.from_numpy(self.policy.mask2).float().to(self.device) 197 | 198 | def apply_masks_actor(self): 199 | if not self.policy.dense_layers[0]: 200 | self.policy.linear1.weight.data.mul_(self.policy.torchMask1) 201 | if not self.policy.dense_layers[1]: 202 | self.policy.linear2.weight.data.mul_(self.policy.torchMask2) 203 | 204 | def get_mask_list_critic(self): 205 | mask_list = [] 206 | if not self.critic.dense_layers[0]: 207 | mask_list.append(self.critic.torchMask1) 208 | else: 209 | mask_list.append(None) 210 | if not self.critic.dense_layers[1]: 211 | mask_list.append(self.critic.torchMask2) 212 | else: 213 | mask_list.append(None) 214 | mask_list.append(None) # output layer is dense 215 | if not self.critic.dense_layers[0]: 216 | mask_list.append(self.critic.torchMask4) 217 | else: 218 | mask_list.append(None) 219 | if not self.critic.dense_layers[1]: 220 | mask_list.append(self.critic.torchMask5) 221 | else: 222 | mask_list.append(None) 223 | mask_list.append(None) # output layer is dense 224 | return mask_list 225 | 226 | def get_mask_list_actor(self): 227 | mask_list = [] 228 | if not self.policy.dense_layers[0]: 229 | mask_list.append(self.policy.torchMask1) 230 | else: 231 | mask_list.append(None) 232 | if not self.policy.dense_layers[1]: 233 | mask_list.append(self.policy.torchMask2) 234 | else: 235 | mask_list.append(None) 236 | mask_list.append(None) # output layer is dense 237 | mask_list.append(None) # output layer of the actor has two heads (mean and log_std) 238 | # at least for sac_type Gaussian. If sac_type is Deterministic, then this is not the case, 239 | # but then an extra None in the list does not matter. 
240 | return mask_list 241 | 242 | def print_sparsity(self): 243 | return sp.print_sparsities(self.critic.parameters(), self.critic_target.parameters(), self.policy.parameters()) 244 | 245 | def set_new_permutation(self): 246 | # sample a new permutation until it is not a duplicate 247 | duplicate = True 248 | while duplicate: 249 | permutation = torch.randperm(self.dim_state_with_fake) 250 | for p in self.prev_permutations: 251 | if torch.equal(p, permutation): 252 | break 253 | else: 254 | duplicate = False 255 | print(f'\nEnvironment change: new permutation of input features.') 256 | self.prev_permutations.append(permutation) 257 | self.policy.set_new_permutation(permutation) 258 | self.critic.set_new_permutation(permutation) 259 | self.critic_target.set_new_permutation(permutation) 260 | 261 | # Save model parameters 262 | def save(self, filename): 263 | checkpoint = { 264 | 'actor': self.policy.state_dict(), 265 | 'critic': self.critic.state_dict(), 266 | 'critic_target': self.critic_target.state_dict(), 267 | 'actor_optim': self.policy_optim.state_dict(), 268 | 'critic_optim': self.critic_optim.state_dict(), 269 | } 270 | torch.save(checkpoint, filename) 271 | print(f"Saved current model in: {filename}") 272 | 273 | # Load model parameters 274 | def load(self, filename, load_device=None): 275 | if load_device is None: 276 | load_device = self.device 277 | loaded_checkpoint = torch.load(filename, map_location=load_device) 278 | self.policy.load_state_dict(loaded_checkpoint["actor"]) 279 | self.policy_optim.load_state_dict(loaded_checkpoint["actor_optim"]) 280 | self.critic.load_state_dict(loaded_checkpoint["critic"]) 281 | self.critic_target.load_state_dict(loaded_checkpoint["critic_target"]) 282 | self.critic_optim.load_state_dict(loaded_checkpoint["critic_optim"]) 283 | print(f"Loaded model from: {filename}") 284 | 285 | -------------------------------------------------------------------------------- /algorithms/sac_based/sac.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.optim import Adam, SGD 6 | from utils.target_network import soft_update, hard_update 7 | from utils.core_sac import GaussianPolicy, QNetwork, DeterministicPolicy 8 | 9 | 10 | class SAC(object): 11 | def __init__(self, state_dim, action_space, args, device): 12 | self.device = device 13 | 14 | self.gamma = args.discount 15 | self.tau = args.tau 16 | self.alpha = args.temperature 17 | 18 | self.target_update_interval = args.target_update_interval 19 | self.automatic_entropy_tuning = args.automatic_entropy_tuning 20 | 21 | self.dim_state_with_fake = int(np.ceil(state_dim / (1 - args.fake_features))) 22 | self.prev_permutations = [] 23 | 24 | self.critic = QNetwork(state_dim, action_space.shape[0], args, 25 | self.dim_state_with_fake, self.device).to(device=self.device) 26 | self.critic_target = QNetwork(state_dim, action_space.shape[0], args, 27 | self.dim_state_with_fake, self.device).to(self.device) 28 | hard_update(self.critic_target, self.critic) 29 | 30 | if args.sac_type == "Gaussian": 31 | # Target Entropy = −dim(A) (e.g. 
, -6 for HalfCheetah-v2) as given in the paper 32 | if self.automatic_entropy_tuning: 33 | self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item() 34 | self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device) 35 | self.alpha_optim = Adam([self.log_alpha], lr=args.lr) 36 | self.policy = GaussianPolicy(state_dim, action_space.shape[0], args, 37 | self.dim_state_with_fake, self.device, action_space).to(self.device) 38 | else: 39 | self.alpha = 0 40 | self.automatic_entropy_tuning = False 41 | self.policy = DeterministicPolicy(state_dim, action_space.shape[0], args, 42 | self.dim_state_with_fake, self.device, action_space).to(self.device) 43 | 44 | if args.optimizer in ['adam', 'maskadam']: # for all-dense networks: adam == maskadam 45 | self.policy_optim = Adam(self.policy.parameters(), lr=args.lr, weight_decay=0.0002) 46 | self.critic_optim = Adam(self.critic.parameters(), lr=args.lr, weight_decay=0.0002) 47 | elif args.optimizer == 'sgd': 48 | self.policy_optim = SGD(self.policy.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0002) 49 | self.critic_optim = SGD(self.critic.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0002) 50 | else: 51 | raise ValueError(f'Unknown optimizer {args.optimizer} given') 52 | 53 | def select_action(self, state, evaluate=False): 54 | state = torch.FloatTensor(state).to(self.device).unsqueeze(0) 55 | if evaluate is False: 56 | action, _, _ = self.policy.sample(state) 57 | else: 58 | _, _, action = self.policy.sample(state) 59 | return action.detach().cpu().numpy()[0] 60 | 61 | def update_parameters(self, memory, batch_size, updates): 62 | # Sample a batch from memory 63 | state_batch, action_batch, reward_batch, next_state_batch, done_batch = memory.sample(batch_size=batch_size) 64 | 65 | state_batch = torch.FloatTensor(state_batch).to(self.device) 66 | next_state_batch = torch.FloatTensor(next_state_batch).to(self.device) 67 | action_batch = torch.FloatTensor(action_batch).to(self.device) 68 | reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1) 69 | done_batch = torch.FloatTensor(done_batch).to(self.device).unsqueeze(1) 70 | 71 | with torch.no_grad(): 72 | next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch) 73 | qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action) 74 | min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi 75 | next_q_value = reward_batch + done_batch * self.gamma * min_qf_next_target 76 | qf1, qf2 = self.critic(state_batch, action_batch) # Two Q-functions to mitigate positive bias in the policy improvement step 77 | qf1_loss = F.mse_loss(qf1, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2] 78 | qf2_loss = F.mse_loss(qf2, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2] 79 | qf_loss = qf1_loss + qf2_loss 80 | 81 | self.critic_optim.zero_grad() 82 | qf_loss.backward() 83 | self.critic_optim.step() 84 | 85 | pi, log_pi, _ = self.policy.sample(state_batch) 86 | 87 | qf1_pi, qf2_pi = self.critic(state_batch, pi) 88 | min_qf_pi = torch.min(qf1_pi, qf2_pi) 89 | 90 | policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))] 91 | 92 | self.policy_optim.zero_grad() 93 | policy_loss.backward() 94 | self.policy_optim.step() 95 | 96 | if self.automatic_entropy_tuning: 97 | alpha_loss = -(self.log_alpha * (log_pi + 
self.target_entropy).detach()).mean() 98 | 99 | self.alpha_optim.zero_grad() 100 | alpha_loss.backward() 101 | self.alpha_optim.step() 102 | 103 | self.alpha = self.log_alpha.exp() 104 | alpha_tlogs = self.alpha.clone() # For TensorboardX logs 105 | else: 106 | alpha_loss = torch.tensor(0.).to(self.device) 107 | alpha_tlogs = torch.tensor(self.alpha) # For TensorboardX logs 108 | 109 | if updates % self.target_update_interval == 0: 110 | soft_update(self.critic_target, self.critic, self.tau) 111 | 112 | loss_info = {'q1_loss': qf1_loss.item(), 113 | 'q2_loss': qf2_loss.item(), 114 | 'actor_loss': policy_loss.item(), 115 | 'alpha_loss': alpha_loss.item(), 116 | 'alpha_val': alpha_tlogs.item()} 117 | return loss_info 118 | 119 | def set_new_permutation(self): 120 | # sample a new permutation until it is not a duplicate 121 | duplicate = True 122 | while duplicate: 123 | permutation = torch.randperm(self.dim_state_with_fake) 124 | for p in self.prev_permutations: 125 | if torch.equal(p, permutation): 126 | break 127 | else: 128 | duplicate = False 129 | print(f'\nEnvironment change: new permutation of input features.') 130 | self.prev_permutations.append(permutation) 131 | self.policy.set_new_permutation(permutation) 132 | self.critic.set_new_permutation(permutation) 133 | self.critic_target.set_new_permutation(permutation) 134 | 135 | # Save model parameters 136 | def save(self, filename): 137 | checkpoint = { 138 | 'actor': self.policy.state_dict(), 139 | 'critic': self.critic.state_dict(), 140 | 'critic_target': self.critic_target.state_dict(), 141 | 'actor_optim': self.policy_optim.state_dict(), 142 | 'critic_optim': self.critic_optim.state_dict(), 143 | } 144 | torch.save(checkpoint, filename) 145 | print(f"Saved current model in: {filename}") 146 | 147 | # Load model parameters 148 | def load(self, filename, load_device=None): 149 | if load_device is None: 150 | load_device = self.device 151 | loaded_checkpoint = torch.load(filename, map_location=load_device) 152 | self.policy.load_state_dict(loaded_checkpoint["actor"]) 153 | self.policy_optim.load_state_dict(loaded_checkpoint["actor_optim"]) 154 | self.critic.load_state_dict(loaded_checkpoint["critic"]) 155 | self.critic_target.load_state_dict(loaded_checkpoint["critic_target"]) 156 | self.critic_optim.load_state_dict(loaded_checkpoint["critic_optim"]) 157 | print(f"Loaded model from: {filename}") 158 | -------------------------------------------------------------------------------- /algorithms/sac_based/ss_sac.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.optim import Adam, SGD 6 | from utils.target_network import soft_update, hard_update 7 | from utils.core_anf_sac import GaussianPolicy, QNetwork, DeterministicPolicy 8 | import utils.sparse_utils as sp 9 | 10 | 11 | class Static_SAC(object): 12 | def __init__(self, state_dim, action_space, args, device): 13 | self.device = device 14 | 15 | self.gamma = args.discount 16 | self.tau = args.tau 17 | self.alpha = args.temperature 18 | 19 | self.target_update_interval = args.target_update_interval 20 | self.automatic_entropy_tuning = args.automatic_entropy_tuning 21 | 22 | self.total_it = 0 23 | self.setZeta = args.ann_setZeta 24 | self.ascTopologyChangePeriod = args.ann_ascTopologyChangePeriod 25 | self.lastTopologyChangeCritic = False 26 | self.lastTopologyChangePolicy = False 27 | self.ascStatsPolicy = [] 28 | self.ascStatsCritic = [] 29 | 
self.ascStatsValue = [] 30 | 31 | self.dim_state_with_fake = int(np.ceil(state_dim / (1 - args.fake_features))) 32 | self.prev_permutations = [] 33 | 34 | self.critic = QNetwork(state_dim, action_space.shape[0], args, 35 | self.dim_state_with_fake, self.device).to(device=self.device) 36 | self.critic_target = QNetwork(state_dim, action_space.shape[0], args, 37 | self.dim_state_with_fake, self.device).to(self.device) 38 | hard_update(self.critic_target, self.critic) 39 | 40 | if args.sac_type == "Gaussian": 41 | # Target Entropy = −dim(A) (e.g. , -6 for HalfCheetah-v2) as given in the paper 42 | if self.automatic_entropy_tuning is True: 43 | self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item() 44 | self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device) 45 | self.alpha_optim = Adam([self.log_alpha], lr=args.lr) 46 | self.policy = GaussianPolicy(state_dim, action_space.shape[0], args, 47 | self.dim_state_with_fake, self.device, action_space).to(self.device) 48 | else: 49 | self.alpha = 0 50 | self.automatic_entropy_tuning = False 51 | self.policy = DeterministicPolicy(state_dim, action_space.shape[0], args, 52 | self.dim_state_with_fake, self.device, action_space).to(self.device) 53 | 54 | if args.optimizer in ['adam', 'maskadam']: # for all-dense networks: adam == maskadam 55 | self.policy_optim = Adam(self.policy.parameters(), lr=args.lr, weight_decay=0.0002) 56 | self.critic_optim = Adam(self.critic.parameters(), lr=args.lr, weight_decay=0.0002) 57 | elif args.optimizer == 'sgd': 58 | self.policy_optim = SGD(self.policy.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0002) 59 | self.critic_optim = SGD(self.critic.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0002) 60 | else: 61 | raise ValueError(f'Unknown optimizer {args.optimizer} given') 62 | 63 | def select_action(self, state, evaluate=False): 64 | state = torch.FloatTensor(state).to(self.device).unsqueeze(0) 65 | if evaluate is False: 66 | action, _, _ = self.policy.sample(state) 67 | else: 68 | _, _, action = self.policy.sample(state) 69 | return action.detach().cpu().numpy()[0] 70 | 71 | def update_parameters(self, memory, batch_size, updates): 72 | self.total_it += 1 73 | 74 | # Sample a batch from memory 75 | state_batch, action_batch, reward_batch, next_state_batch, done_batch = memory.sample(batch_size=batch_size) 76 | 77 | state_batch = torch.FloatTensor(state_batch).to(self.device) 78 | next_state_batch = torch.FloatTensor(next_state_batch).to(self.device) 79 | action_batch = torch.FloatTensor(action_batch).to(self.device) 80 | reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1) 81 | done_batch = torch.FloatTensor(done_batch).to(self.device).unsqueeze(1) 82 | 83 | with torch.no_grad(): 84 | next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch) 85 | qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action) 86 | min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi 87 | next_q_value = reward_batch + done_batch * self.gamma * min_qf_next_target 88 | qf1, qf2 = self.critic(state_batch, action_batch) # Two Q-functions to mitigate positive bias in the policy improvement step 89 | qf1_loss = F.mse_loss(qf1, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2] 90 | qf2_loss = F.mse_loss(qf2, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2] 91 | qf_loss = qf1_loss + qf2_loss 92 | 
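        # Static-SAC: the sparse masks drawn at initialization stay fixed for the whole run;
        # they are simply re-applied after each optimizer step below (no prune-and-regrow rewiring as in ANF-SAC).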
93 | self.critic_optim.zero_grad() 94 | qf_loss.backward() 95 | self.critic_optim.step() 96 | 97 | # Maintain the same sparse connectivity for critic 98 | self.apply_masks_critic() 99 | 100 | pi, log_pi, _ = self.policy.sample(state_batch) 101 | 102 | qf1_pi, qf2_pi = self.critic(state_batch, pi) 103 | min_qf_pi = torch.min(qf1_pi, qf2_pi) 104 | 105 | policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))] 106 | 107 | self.policy_optim.zero_grad() 108 | policy_loss.backward() 109 | self.policy_optim.step() 110 | 111 | # Maintain the same sparse connectivity for actor 112 | self.apply_masks_actor() 113 | 114 | if self.automatic_entropy_tuning: 115 | alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean() 116 | 117 | self.alpha_optim.zero_grad() 118 | alpha_loss.backward() 119 | self.alpha_optim.step() 120 | 121 | self.alpha = self.log_alpha.exp() 122 | alpha_tlogs = self.alpha.clone() # For logs 123 | else: 124 | alpha_loss = torch.tensor(0.).to(self.device) 125 | alpha_tlogs = torch.tensor(self.alpha) # For logs 126 | 127 | if updates % self.target_update_interval == 0: 128 | soft_update(self.critic_target, self.critic, self.tau) 129 | 130 | loss_info = {'q1_loss': qf1_loss.item(), 131 | 'q2_loss': qf2_loss.item(), 132 | 'actor_loss': policy_loss.item(), 133 | 'alpha_loss': alpha_loss.item(), 134 | 'alpha_val': alpha_tlogs.item()} 135 | return loss_info 136 | 137 | def apply_masks_critic(self): 138 | if not self.critic.dense_layers[0]: 139 | self.critic.linear1.weight.data.mul_(self.critic.torchMask1) 140 | self.critic.linear4.weight.data.mul_(self.critic.torchMask4) 141 | if not self.critic.dense_layers[1]: 142 | self.critic.linear2.weight.data.mul_(self.critic.torchMask2) 143 | self.critic.linear5.weight.data.mul_(self.critic.torchMask5) 144 | 145 | def apply_masks_actor(self): 146 | if not self.policy.dense_layers[0]: 147 | self.policy.linear1.weight.data.mul_(self.policy.torchMask1) 148 | if not self.policy.dense_layers[1]: 149 | self.policy.linear2.weight.data.mul_(self.policy.torchMask2) 150 | 151 | def print_sparsity(self): 152 | return sp.print_sparsities(self.critic.parameters(), self.critic_target.parameters(), self.policy.parameters()) 153 | 154 | def set_new_permutation(self): 155 | # sample a new permutation until it is not a duplicate 156 | duplicate = True 157 | while duplicate: 158 | permutation = torch.randperm(self.dim_state_with_fake) 159 | for p in self.prev_permutations: 160 | if torch.equal(p, permutation): 161 | break 162 | else: 163 | duplicate = False 164 | print(f'\nEnvironment change: new permutation of input features.') 165 | self.prev_permutations.append(permutation) 166 | self.policy.set_new_permutation(permutation) 167 | self.critic.set_new_permutation(permutation) 168 | self.critic_target.set_new_permutation(permutation) 169 | 170 | # Save model parameters 171 | def save(self, filename): 172 | checkpoint = { 173 | 'actor': self.policy.state_dict(), 174 | 'critic': self.critic.state_dict(), 175 | 'critic_target': self.critic_target.state_dict(), 176 | 'actor_optim': self.policy_optim.state_dict(), 177 | 'critic_optim': self.critic_optim.state_dict(), 178 | } 179 | torch.save(checkpoint, filename) 180 | print(f"Saved current model in: {filename}") 181 | 182 | # Load model parameters 183 | def load(self, filename, load_device=None): 184 | if load_device is None: 185 | load_device = self.device 186 | loaded_checkpoint = torch.load(filename, map_location=load_device) 187 | 
self.policy.load_state_dict(loaded_checkpoint["actor"]) 188 | self.policy_optim.load_state_dict(loaded_checkpoint["actor_optim"]) 189 | self.critic.load_state_dict(loaded_checkpoint["critic"]) 190 | self.critic_target.load_state_dict(loaded_checkpoint["critic_target"]) 191 | self.critic_optim.load_state_dict(loaded_checkpoint["critic_optim"]) 192 | print(f"Loaded model from: {filename}") 193 | 194 | -------------------------------------------------------------------------------- /algorithms/td3_based/anf_td3.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from utils import utils 7 | from utils import sparse_utils as sp 8 | from utils.mask_adam import MaskAdam 9 | from utils.core import SparseBaseAgent 10 | from utils.activations import setup_activation_funcs_list 11 | 12 | 13 | class Actor(nn.Module): 14 | def __init__(self, state_dim, action_dim, max_action, activation, act_func_args, global_sparsity, 15 | sparsity_distribution_method, input_layer_dense, output_layer_dense, device, 16 | dim_state_with_fake, args, num_hid_layers=2, num_hid_neurons=256): 17 | super().__init__() 18 | assert num_hid_layers >= 1 19 | self.num_hid_layers = num_hid_layers 20 | all_connection_layers = num_hid_layers-1 + 2 21 | # -1 for neuron layers to connection layers, +2 for input and output layer 22 | self.device = device 23 | self.permutation = None # first env is without permutation 24 | self.num_fake_features = dim_state_with_fake - state_dim 25 | self.fake_noise_std = args.fake_noise_std 26 | self.fake_noise_generator = utils.setup_noise_generator(args.load_noise_distribution) 27 | 28 | sparsities = sp.compute_sparsity_per_layer( 29 | global_sparsity=global_sparsity, 30 | neuron_layers=[dim_state_with_fake] + [num_hid_neurons for _ in range(num_hid_layers)] + [action_dim], 31 | keep_dense=[input_layer_dense] + [False for _ in range(num_hid_layers-1)] + [output_layer_dense], 32 | method=sparsity_distribution_method, 33 | input_layer_sparsity=args.input_layer_sparsity) 34 | self.dense_layers = [True if sparsity == 0 else False for sparsity in sparsities] 35 | 36 | # First define the dense network 37 | self.input_layer = nn.Linear(dim_state_with_fake, num_hid_neurons) 38 | self.hid_layers = nn.ModuleList() 39 | for hid_connection_layer in range(num_hid_layers - 1): 40 | self.hid_layers.append(nn.Linear(num_hid_neurons, num_hid_neurons)) 41 | self.output_layer = nn.Linear(num_hid_neurons, action_dim) 42 | 43 | # Now make masks for the sparse layers 44 | self.num_parm_in_layer = [dim_state_with_fake * num_hid_layers] + \ 45 | [num_hid_neurons**2 for _ in range(num_hid_layers-1)] + \ 46 | [num_hid_neurons * action_dim] 47 | self.masks = [None for _ in range(all_connection_layers)] 48 | self.torch_masks = [None for _ in range(all_connection_layers)] 49 | for layer in range(all_connection_layers): 50 | if not self.dense_layers[layer]: 51 | if layer == 0: 52 | self.num_parm_in_layer[layer], self.masks[layer] = sp.initialize_mask( 53 | f"actor input layer", sparsities[layer], dim_state_with_fake, num_hid_neurons) 54 | self.torch_masks[layer] = torch.from_numpy(self.masks[layer]).float().to(device) 55 | self.input_layer.weight.data.mul_(torch.from_numpy(self.masks[layer]).float()) 56 | elif layer == all_connection_layers - 1: 57 | self.num_parm_in_layer[layer], self.masks[layer] = sp.initialize_mask( 58 | f"actor output layer", sparsities[layer], 
num_hid_neurons, action_dim) 59 | self.torch_masks[layer] = torch.from_numpy(self.masks[layer]).float().to(device) 60 | self.output_layer.weight.data.mul_(torch.from_numpy(self.masks[layer]).float()) 61 | else: 62 | self.num_parm_in_layer[layer], self.masks[layer] = sp.initialize_mask( 63 | f"actor hid layer {layer}", sparsities[layer], num_hid_neurons, num_hid_neurons) 64 | self.torch_masks[layer] = torch.from_numpy(self.masks[layer]).float().to(device) 65 | self.hid_layers[layer - 1].weight.data.mul_(torch.from_numpy(self.masks[layer]).float()) 66 | # weights are put .to(device) later on, whole network at once 67 | 68 | self.activation_funcs = setup_activation_funcs_list(activation, act_func_args, num_hid_layers, num_hid_neurons) 69 | self.output_activation = nn.Tanh() 70 | self.max_action = max_action 71 | 72 | def forward(self, state): 73 | state = utils.add_fake_features(state, self.num_fake_features, self.device, 74 | self.fake_noise_std, self.fake_noise_generator) 75 | state = utils.permute_features(state, self.permutation) 76 | a = self.activation_funcs[0](self.input_layer(state)) 77 | for hid_layer in range(self.num_hid_layers - 1): 78 | a = self.activation_funcs[hid_layer + 1](self.hid_layers[hid_layer](a)) 79 | return self.max_action * self.output_activation(self.output_layer(a)) 80 | 81 | def set_new_permutation(self, permutation): 82 | self.permutation = permutation 83 | 84 | 85 | class Critic(nn.Module): 86 | def __init__(self, state_dim, action_dim, activation, act_func_args, global_sparsity, 87 | sparsity_distribution_method, input_layer_dense, output_layer_dense, device, 88 | dim_state_with_fake, args, num_hid_layers=2, num_hid_neurons=256): 89 | super().__init__() 90 | assert num_hid_layers >= 1 91 | self.num_hid_layers = num_hid_layers 92 | all_connection_layers = num_hid_layers-1 + 2 93 | self.device = device 94 | self.permutation = None 95 | self.num_fake_features = dim_state_with_fake - state_dim 96 | self.fake_noise_std = args.fake_noise_std 97 | self.fake_noise_generator = utils.setup_noise_generator(args.load_noise_distribution) 98 | 99 | sparsities = sp.compute_sparsity_per_layer( 100 | global_sparsity=global_sparsity, 101 | neuron_layers=[dim_state_with_fake + action_dim] + [num_hid_neurons for _ in range(num_hid_layers)] + [1], 102 | keep_dense=[input_layer_dense] + [False for _ in range(num_hid_layers-1)] + [True], 103 | method=sparsity_distribution_method, 104 | input_layer_sparsity=args.input_layer_sparsity) 105 | self.dense_layers = [True if sparsity == 0 else False for sparsity in sparsities] 106 | 107 | # Q1 dense architecture 108 | self.q1_input_layer = nn.Linear(dim_state_with_fake + action_dim, num_hid_neurons) 109 | self.q1_hid_layers = nn.ModuleList() 110 | for hid_connection_layer in range(num_hid_layers - 1): 111 | self.q1_hid_layers.append(nn.Linear(num_hid_neurons, num_hid_neurons)) 112 | self.q1_output_layer = nn.Linear(num_hid_neurons, 1) 113 | 114 | # Q2 dense architecture 115 | self.q2_input_layer = nn.Linear(dim_state_with_fake + action_dim, num_hid_neurons) 116 | self.q2_hid_layers = nn.ModuleList() 117 | for hid_connection_layer in range(num_hid_layers - 1): 118 | self.q2_hid_layers.append(nn.Linear(num_hid_neurons, num_hid_neurons)) 119 | self.q2_output_layer = nn.Linear(num_hid_neurons, 1) 120 | 121 | # Setup masks for Q1 122 | self.q1_num_parm_in_layer = [(dim_state_with_fake+action_dim) * num_hid_layers] + \ 123 | [num_hid_neurons**2 for _ in range(num_hid_layers-1)] + \ 124 | [num_hid_neurons] 125 | self.q1_masks = [None for _ in 
range(all_connection_layers)] 126 | self.q1_torch_masks = [None for _ in range(all_connection_layers)] 127 | for layer in range(all_connection_layers): 128 | if not self.dense_layers[layer]: 129 | if layer == 0: 130 | self.q1_num_parm_in_layer[layer], self.q1_masks[layer] = sp.initialize_mask( 131 | f"critic Q1 input layer", sparsities[layer], dim_state_with_fake+action_dim, num_hid_neurons) 132 | self.q1_torch_masks[layer] = torch.from_numpy(self.q1_masks[layer]).float().to(device) 133 | self.q1_input_layer.weight.data.mul_(torch.from_numpy(self.q1_masks[layer]).float()) 134 | elif layer == all_connection_layers - 1: 135 | self.q1_num_parm_in_layer[layer], self.q1_masks[layer] = sp.initialize_mask( 136 | f"critic Q1 output layer", sparsities[layer], num_hid_neurons, 1) 137 | self.q1_torch_masks[layer] = torch.from_numpy(self.q1_masks[layer]).float().to(device) 138 | self.q1_output_layer.weight.data.mul_(torch.from_numpy(self.q1_masks[layer]).float()) 139 | else: 140 | self.q1_num_parm_in_layer[layer], self.q1_masks[layer] = sp.initialize_mask( 141 | f"critic Q1 hid layer {layer}", sparsities[layer], num_hid_neurons, num_hid_neurons) 142 | self.q1_torch_masks[layer] = torch.from_numpy(self.q1_masks[layer]).float().to(device) 143 | self.q1_hid_layers[layer-1].weight.data.mul_(torch.from_numpy(self.q1_masks[layer]).float()) 144 | 145 | # Setup masks for Q2 146 | self.q2_num_parm_in_layer = [(dim_state_with_fake+action_dim) * num_hid_layers] + \ 147 | [num_hid_neurons**2 for _ in range(num_hid_layers-1)] + \ 148 | [num_hid_neurons] 149 | self.q2_masks = [None for _ in range(all_connection_layers)] 150 | self.q2_torch_masks = [None for _ in range(all_connection_layers)] 151 | for layer in range(all_connection_layers): 152 | if not self.dense_layers[layer]: 153 | if layer == 0: 154 | self.q2_num_parm_in_layer[layer], self.q2_masks[layer] = sp.initialize_mask( 155 | f"critic Q2 input layer", sparsities[layer], dim_state_with_fake+action_dim, num_hid_neurons) 156 | self.q2_torch_masks[layer] = torch.from_numpy(self.q2_masks[layer]).float().to(device) 157 | self.q2_input_layer.weight.data.mul_(torch.from_numpy(self.q2_masks[layer]).float()) 158 | elif layer == all_connection_layers - 1: 159 | self.q2_num_parm_in_layer[layer], self.q2_masks[layer] = sp.initialize_mask( 160 | f"critic Q2 output layer", sparsities[layer], num_hid_neurons, 1) 161 | self.q2_torch_masks[layer] = torch.from_numpy(self.q2_masks[layer]).float().to(device) 162 | self.q2_output_layer.weight.data.mul_(torch.from_numpy(self.q2_masks[layer]).float()) 163 | else: 164 | self.q2_num_parm_in_layer[layer], self.q2_masks[layer] = sp.initialize_mask( 165 | f"critic Q2 hid layer {layer}", sparsities[layer], num_hid_neurons, num_hid_neurons) 166 | self.q2_torch_masks[layer] = torch.from_numpy(self.q2_masks[layer]).float().to(device) 167 | self.q2_hid_layers[layer-1].weight.data.mul_(torch.from_numpy(self.q2_masks[layer]).float()) 168 | 169 | # Activation functions 170 | self.q1_activation_funcs = setup_activation_funcs_list(activation, act_func_args, num_hid_layers, num_hid_neurons) 171 | self.q2_activation_funcs = setup_activation_funcs_list(activation, act_func_args, num_hid_layers, num_hid_neurons) 172 | 173 | def forward(self, state, action): 174 | state = utils.add_fake_features(state, self.num_fake_features, self.device, 175 | self.fake_noise_std, self.fake_noise_generator) 176 | state = utils.permute_features(state, self.permutation) 177 | sa = torch.cat([state, action], 1) 178 | 179 | q1 = 
self.q1_activation_funcs[0](self.q1_input_layer(sa)) 180 | for hid_layer in range(self.num_hid_layers - 1): 181 | q1 = self.q1_activation_funcs[hid_layer + 1](self.q1_hid_layers[hid_layer](q1)) 182 | q1 = self.q1_output_layer(q1) 183 | 184 | q2 = self.q2_activation_funcs[0](self.q2_input_layer(sa)) 185 | for hid_layer in range(self.num_hid_layers - 1): 186 | q2 = self.q2_activation_funcs[hid_layer + 1](self.q2_hid_layers[hid_layer](q2)) 187 | q2 = self.q2_output_layer(q2) 188 | 189 | return q1, q2 190 | 191 | def Q1(self, state, action): 192 | state = utils.add_fake_features(state, self.num_fake_features, self.device, 193 | self.fake_noise_std, self.fake_noise_generator) 194 | state = utils.permute_features(state, self.permutation) 195 | sa = torch.cat([state, action], 1) 196 | q1 = self.q1_activation_funcs[0](self.q1_input_layer(sa)) 197 | for hid_layer in range(self.num_hid_layers - 1): 198 | q1 = self.q1_activation_funcs[hid_layer + 1](self.q1_hid_layers[hid_layer](q1)) 199 | return self.q1_output_layer(q1) 200 | 201 | def set_new_permutation(self, permutation): 202 | self.permutation = permutation 203 | 204 | 205 | class ANF_TD3(SparseBaseAgent): 206 | def __init__( 207 | self, 208 | state_dim, 209 | action_dim, 210 | max_action, 211 | args, 212 | discount=0.99, 213 | tau=0.005, 214 | policy_noise=0.2, 215 | noise_clip=0.5, 216 | policy_freq=2, 217 | num_hid_layers=2, 218 | num_hid_neurons=256, 219 | activation='relu', 220 | act_func_args=(None, False), 221 | optimizer='adam', 222 | lr=0.001, 223 | global_sparsity=0.5, 224 | sparsity_distribution_method='ER', 225 | input_layer_dense=False, 226 | output_layer_dense=True, 227 | setZeta=0.05, 228 | init_new_weights_method='zero', 229 | ascTopologyChangePeriod=1000, 230 | earlyStopTopologyChangeIteration=1e9, # kind of never 231 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), 232 | fake_features=0.0, 233 | ): 234 | super().__init__() 235 | self.dim_state_with_fake = int(np.ceil(state_dim / (1 - fake_features))) 236 | self.prev_permutations = [] 237 | 238 | self.actor = Actor(state_dim, action_dim, max_action, activation, act_func_args, global_sparsity, 239 | sparsity_distribution_method, input_layer_dense, output_layer_dense, device, 240 | self.dim_state_with_fake, args, num_hid_layers, num_hid_neurons).to(device) 241 | self.actor_target = copy.deepcopy(self.actor) 242 | 243 | self.critic = Critic(state_dim, action_dim, activation, act_func_args, global_sparsity, 244 | sparsity_distribution_method, input_layer_dense, output_layer_dense, device, 245 | self.dim_state_with_fake, args, num_hid_layers, num_hid_neurons).to(device) 246 | self.critic_target = copy.deepcopy(self.critic) 247 | 248 | self.optimizer_name = optimizer 249 | if optimizer == 'adam': 250 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr, weight_decay=0.0002) 251 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr, weight_decay=0.0002) 252 | elif optimizer == 'sgd': 253 | self.actor_optimizer = torch.optim.SGD(self.actor.parameters(), lr=lr, momentum=0.9, weight_decay=0.0002) 254 | self.critic_optimizer = torch.optim.SGD(self.critic.parameters(), lr=lr, momentum=0.9, weight_decay=0.0002) 255 | elif optimizer == 'maskadam': 256 | self.actor_optimizer = MaskAdam(self.actor.parameters(), lr=lr, weight_decay=0.0002) 257 | self.critic_optimizer = MaskAdam(self.critic.parameters(), lr=lr, weight_decay=0.0002) 258 | else: 259 | raise ValueError(f'Unknown optimizer {optimizer} given') 260 | 261 | self.device = 
device 262 | self.max_action = max_action 263 | self.discount = discount 264 | self.tau = tau 265 | self.policy_noise = policy_noise 266 | self.noise_clip = noise_clip 267 | self.policy_freq = policy_freq 268 | 269 | self.setZeta = setZeta 270 | self.init_new_weights_method = init_new_weights_method 271 | self.ascTopologyChangePeriod = ascTopologyChangePeriod 272 | self.earlyStopTopologyChangeIteration = earlyStopTopologyChangeIteration 273 | self.lastTopologyChangeCritic = False 274 | self.lastTopologyChangeActor = False 275 | self.ascStatsActor = [] 276 | self.ascStatsCritic = [] 277 | self.total_it = 0 278 | 279 | def select_action(self, state, evaluate=False): 280 | state = torch.FloatTensor(state.reshape(1, -1)).to(self.device) 281 | return self.actor(state).cpu().data.numpy().flatten() 282 | 283 | def train(self, replay_buffer, batch_size=100): 284 | self.total_it += 1 285 | # Sample replay buffer 286 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 287 | 288 | with torch.no_grad(): 289 | # Select action according to policy and add clipped noise 290 | noise = (torch.randn_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip) 291 | next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action) 292 | # Compute the target Q value 293 | target_Q1, target_Q2 = self.critic_target(next_state, next_action) 294 | target_Q = reward + not_done * self.discount * torch.min(target_Q1, target_Q2) 295 | 296 | # Get current Q estimates 297 | current_Q1, current_Q2 = self.critic(state, action) 298 | # Compute critic loss 299 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) 300 | # Optimize the critic 301 | self.critic_optimizer.zero_grad() 302 | critic_loss.backward() 303 | if self.optimizer_name == 'maskadam': 304 | self.critic_optimizer.step(masks=self.critic.q1_torch_masks + self.critic.q2_torch_masks) 305 | else: 306 | self.critic_optimizer.step() 307 | # Maintain the same sparse connectivity for critic 308 | self.apply_masks_critic() 309 | 310 | # Adapt the sparse connectivity 311 | if not self.lastTopologyChangeCritic and self.total_it % self.ascTopologyChangePeriod == 2: 312 | if self.total_it > self.earlyStopTopologyChangeIteration: 313 | self.lastTopologyChangeCritic = True 314 | if self.init_new_weights_method != 'zero': 315 | q1_old_masks, q2_old_masks = copy.deepcopy(self.critic.q1_masks), copy.deepcopy(self.critic.q2_masks) 316 | 317 | self.update_topology_critic() 318 | 319 | if self.init_new_weights_method != 'zero': 320 | sp.critic_give_new_connections_init_values(self.critic, q1_old_masks, q2_old_masks, 321 | self.init_new_weights_method, self.device) 322 | self.apply_masks_critic() 323 | 324 | # Delayed policy updates 325 | if self.total_it % self.policy_freq == 0: 326 | # Compute actor loss 327 | actor_loss = -self.critic.Q1(state, self.actor(state)).mean() 328 | # Optimize the actor 329 | self.actor_optimizer.zero_grad() 330 | actor_loss.backward() 331 | if self.optimizer_name == 'maskadam': 332 | self.actor_optimizer.step(masks=self.actor.torch_masks) 333 | else: 334 | self.actor_optimizer.step() 335 | # Maintain the same sparse connectivity for actor 336 | self.apply_masks_actor() 337 | 338 | # Adapt the sparse connectivity of the actor 339 | if not self.lastTopologyChangeActor and self.total_it % self.ascTopologyChangePeriod == 2: 340 | if self.total_it > self.earlyStopTopologyChangeIteration: 341 | self.lastTopologyChangeActor = True 342 | if 
self.init_new_weights_method != 'zero': 343 | old_masks = copy.deepcopy(self.actor.masks) 344 | 345 | self.update_topology_actor() 346 | 347 | if self.init_new_weights_method != 'zero': 348 | sp.actor_give_new_connections_init_values(self.actor, old_masks, 349 | self.init_new_weights_method, self.device) 350 | self.apply_masks_actor() 351 | 352 | # Update the frozen target models 353 | self.update_target_networks() 354 | 355 | def update_topology_critic(self): 356 | for layer in range(self.critic.num_hid_layers + 1): 357 | if not self.critic.dense_layers[layer]: 358 | if layer == 0: 359 | self.critic.q1_masks[layer] = sp.adjust_connectivity_set( 360 | self.critic.q1_input_layer.weight.data.cpu().numpy(), self.critic.q1_num_parm_in_layer[layer], 361 | self.setZeta, self.critic.q1_masks[layer]) 362 | self.critic.q2_masks[layer] = sp.adjust_connectivity_set( 363 | self.critic.q2_input_layer.weight.data.cpu().numpy(), self.critic.q2_num_parm_in_layer[layer], 364 | self.setZeta, self.critic.q2_masks[layer]) 365 | elif layer == self.critic.num_hid_layers: 366 | self.critic.q1_masks[layer] = sp.adjust_connectivity_set( 367 | self.critic.q1_output_layer.weight.data.cpu().numpy(), self.critic.q1_num_parm_in_layer[layer], 368 | self.setZeta, self.critic.q1_masks[layer]) 369 | self.critic.q2_masks[layer] = sp.adjust_connectivity_set( 370 | self.critic.q2_output_layer.weight.data.cpu().numpy(), self.critic.q2_num_parm_in_layer[layer], 371 | self.setZeta, self.critic.q2_masks[layer]) 372 | else: 373 | self.critic.q1_masks[layer] = sp.adjust_connectivity_set( 374 | self.critic.q1_hid_layers[layer - 1].weight.data.cpu().numpy(), 375 | self.critic.q1_num_parm_in_layer[layer], self.setZeta, self.critic.q1_masks[layer]) 376 | self.critic.q2_masks[layer] = sp.adjust_connectivity_set( 377 | self.critic.q2_hid_layers[layer - 1].weight.data.cpu().numpy(), 378 | self.critic.q2_num_parm_in_layer[layer], self.setZeta, self.critic.q2_masks[layer]) 379 | self.critic.q1_torch_masks[layer] = torch.from_numpy(self.critic.q1_masks[layer]).float().to( 380 | self.device) 381 | self.critic.q2_torch_masks[layer] = torch.from_numpy(self.critic.q2_masks[layer]).float().to( 382 | self.device) 383 | 384 | def update_topology_actor(self): 385 | for layer in range(self.actor.num_hid_layers + 1): 386 | if not self.actor.dense_layers[layer]: 387 | if layer == 0: 388 | self.actor.masks[layer] = sp.adjust_connectivity_set( 389 | self.actor.input_layer.weight.data.cpu().numpy(), 390 | self.actor.num_parm_in_layer[layer], self.setZeta, self.actor.masks[layer]) 391 | elif layer == self.actor.num_hid_layers: 392 | self.actor.masks[layer] = sp.adjust_connectivity_set( 393 | self.actor.output_layer.weight.data.cpu().numpy(), 394 | self.actor.num_parm_in_layer[layer], self.setZeta, self.actor.masks[layer]) 395 | else: 396 | self.actor.masks[layer] = sp.adjust_connectivity_set( 397 | self.actor.hid_layers[layer - 1].weight.data.cpu().numpy(), 398 | self.actor.num_parm_in_layer[layer], self.setZeta, self.actor.masks[layer]) 399 | self.actor.torch_masks[layer] = torch.from_numpy(self.actor.masks[layer]).float().to(self.device) 400 | 401 | def update_target_networks(self): 402 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 403 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 404 | if len(param.shape) > 1: 405 | self.maintain_sparsity_target_networks(param, target_param, self.device) 406 | for param, target_param in zip(self.actor.parameters(), 
self.actor_target.parameters()): 407 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 408 | if len(param.shape) > 1: 409 | self.maintain_sparsity_target_networks(param, target_param, self.device) 410 | 411 | def maintain_sparsity_target_networks(self, param, target_param, device): 412 | current_density = (param != 0).sum() 413 | target_density = (target_param != 0).sum() # torch.count_nonzero(target_param.data) 414 | difference = target_density - current_density 415 | # constrain the sparsity by removing the extra elements (smallest values) 416 | if difference > 0: 417 | count_rmv = difference 418 | tmp = copy.deepcopy(abs(target_param.data)) 419 | tmp[tmp == 0] = 10000000 420 | unraveled = self.unravel_index(torch.argsort(tmp.view(1, -1)[0]), tmp.shape) 421 | rmv_indicies = torch.stack(unraveled, dim=1) 422 | rmv_values_smaller_than = tmp[rmv_indicies[count_rmv][0], rmv_indicies[count_rmv][1]] 423 | target_param.data[tmp < rmv_values_smaller_than] = 0 424 | 425 | def unravel_index(self, index, shape): 426 | out = [] 427 | for dim in reversed(shape): 428 | out.append(index % dim) 429 | index = index // dim 430 | return tuple(reversed(out)) 431 | 432 | def apply_masks_critic(self): 433 | for layer in range(self.critic.num_hid_layers + 1): 434 | if not self.critic.dense_layers[layer]: 435 | if layer == 0: 436 | self.critic.q1_input_layer.weight.data.mul_(self.critic.q1_torch_masks[layer]) 437 | self.critic.q2_input_layer.weight.data.mul_(self.critic.q2_torch_masks[layer]) 438 | elif layer == self.critic.num_hid_layers: 439 | self.critic.q1_output_layer.weight.data.mul_(self.critic.q1_torch_masks[layer]) 440 | self.critic.q2_output_layer.weight.data.mul_(self.critic.q2_torch_masks[layer]) 441 | else: 442 | self.critic.q1_hid_layers[layer - 1].weight.data.mul_(self.critic.q1_torch_masks[layer]) 443 | self.critic.q2_hid_layers[layer - 1].weight.data.mul_(self.critic.q2_torch_masks[layer]) 444 | 445 | def apply_masks_actor(self): 446 | for layer in range(self.actor.num_hid_layers + 1): 447 | if not self.actor.dense_layers[layer]: 448 | if layer == 0: 449 | self.actor.input_layer.weight.data.mul_(self.actor.torch_masks[layer]) 450 | elif layer == self.actor.num_hid_layers: 451 | self.actor.output_layer.weight.data.mul_(self.actor.torch_masks[layer]) 452 | else: 453 | self.actor.hid_layers[layer - 1].weight.data.mul_(self.actor.torch_masks[layer]) 454 | 455 | def print_sparsity(self): 456 | return sp.print_sparsities(self.critic.parameters(), self.critic_target.parameters(), 457 | self.actor.parameters(), self.actor_target.parameters()) 458 | 459 | def saveAscStats(self, filename): 460 | np.savez(filename + "_ASC_stats.npz", ascStatsActor=self.ascStatsActor, ascStatsCritic=self.ascStatsCritic) 461 | 462 | def set_new_permutation(self): 463 | # sample a new permutation until it is not a duplicate 464 | duplicate = True 465 | while duplicate: 466 | permutation = torch.randperm(self.dim_state_with_fake) 467 | for p in self.prev_permutations: 468 | if torch.equal(p, permutation): 469 | break 470 | else: 471 | duplicate = False 472 | print(f'\nEnvironment change: new permutation of input features.') 473 | self.prev_permutations.append(permutation) 474 | self.actor.set_new_permutation(permutation) 475 | self.critic.set_new_permutation(permutation) 476 | self.actor_target.set_new_permutation(permutation) 477 | self.critic_target.set_new_permutation(permutation) 478 | -------------------------------------------------------------------------------- 
/algorithms/td3_based/ss_td3.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from utils import utils 7 | from utils.core import SparseBaseAgent 8 | from utils.activations import setup_activation_funcs_list 9 | import utils.sparse_utils as sp 10 | 11 | 12 | class Actor(nn.Module): 13 | def __init__(self, state_dim, action_dim, max_action, activation, act_func_args, global_sparsity, 14 | sparsity_distribution_method, input_layer_dense, output_layer_dense, device, 15 | dim_state_with_fake, args, num_hid_layers=2, num_hid_neurons=256): 16 | super().__init__() 17 | assert num_hid_layers >= 1 18 | self.num_hid_layers = num_hid_layers 19 | all_connection_layers = num_hid_layers-1 + 2 20 | # -1 for neuron layers to connection layers, +2 for input and output layer 21 | self.device = device 22 | self.permutation = None 23 | self.num_fake_features = dim_state_with_fake - state_dim 24 | self.fake_noise_std = args.fake_noise_std 25 | self.fake_noise_generator = utils.setup_noise_generator(args.load_noise_distribution) 26 | 27 | sparsities = sp.compute_sparsity_per_layer( 28 | global_sparsity=global_sparsity, 29 | neuron_layers=[dim_state_with_fake] + [num_hid_neurons for _ in range(num_hid_layers)] + [action_dim], 30 | keep_dense=[input_layer_dense] + [False for _ in range(num_hid_layers - 1)] + [output_layer_dense], 31 | method=sparsity_distribution_method, 32 | input_layer_sparsity=args.input_layer_sparsity) 33 | self.dense_layers = [True if sparsity == 0 else False for sparsity in sparsities] 34 | 35 | # First define the dense network 36 | self.input_layer = nn.Linear(dim_state_with_fake, num_hid_neurons) 37 | self.hid_layers = nn.ModuleList() 38 | for hid_connection_layer in range(num_hid_layers - 1): 39 | self.hid_layers.append(nn.Linear(num_hid_neurons, num_hid_neurons)) 40 | self.output_layer = nn.Linear(num_hid_neurons, action_dim) 41 | 42 | # Now make masks for the sparse layers 43 | self.num_parm_in_layer = [dim_state_with_fake * num_hid_layers] + \ 44 | [num_hid_neurons**2 for _ in range(num_hid_layers-1)] + \ 45 | [num_hid_neurons * action_dim] 46 | self.masks = [None for _ in range(all_connection_layers)] 47 | self.torch_masks = [None for _ in range(all_connection_layers)] 48 | for layer in range(all_connection_layers): 49 | if not self.dense_layers[layer]: 50 | if layer == 0: 51 | self.num_parm_in_layer[layer], self.masks[layer] = sp.initialize_mask( 52 | f"actor input layer", sparsities[layer], dim_state_with_fake, num_hid_neurons) 53 | self.torch_masks[layer] = torch.from_numpy(self.masks[layer]).float().to(device) 54 | self.input_layer.weight.data.mul_(torch.from_numpy(self.masks[layer]).float()) 55 | elif layer == all_connection_layers - 1: 56 | self.num_parm_in_layer[layer], self.masks[layer] = sp.initialize_mask( 57 | f"actor output layer", sparsities[layer], num_hid_neurons, action_dim) 58 | self.torch_masks[layer] = torch.from_numpy(self.masks[layer]).float().to(device) 59 | self.output_layer.weight.data.mul_(torch.from_numpy(self.masks[layer]).float()) 60 | else: 61 | self.num_parm_in_layer[layer], self.masks[layer] = sp.initialize_mask( 62 | f"actor hid layer {layer}", sparsities[layer], num_hid_neurons, num_hid_neurons) 63 | self.torch_masks[layer] = torch.from_numpy(self.masks[layer]).float().to(device) 64 | self.hid_layers[layer - 1].weight.data.mul_(torch.from_numpy(self.masks[layer]).float()) 65 | # weights are put 
.to(device) later on, whole network at once 66 | 67 | self.activation_funcs = setup_activation_funcs_list(activation, act_func_args, num_hid_layers, num_hid_neurons) 68 | self.output_activation = nn.Tanh() 69 | self.max_action = max_action 70 | 71 | def forward(self, state): 72 | state = utils.add_fake_features(state, self.num_fake_features, self.device, 73 | self.fake_noise_std, self.fake_noise_generator) 74 | state = utils.permute_features(state, self.permutation) 75 | a = self.activation_funcs[0](self.input_layer(state)) 76 | for hid_layer in range(self.num_hid_layers - 1): 77 | a = self.activation_funcs[hid_layer + 1](self.hid_layers[hid_layer](a)) 78 | return self.max_action * self.output_activation(self.output_layer(a)) 79 | 80 | def set_new_permutation(self, permutation): 81 | self.permutation = permutation 82 | 83 | 84 | class Critic(nn.Module): 85 | def __init__(self, state_dim, action_dim, activation, act_func_args, global_sparsity, 86 | sparsity_distribution_method, input_layer_dense, output_layer_dense, device, 87 | dim_state_with_fake, args, num_hid_layers=2, num_hid_neurons=256): 88 | super().__init__() 89 | assert num_hid_layers >= 1 90 | self.num_hid_layers = num_hid_layers 91 | all_connection_layers = num_hid_layers-1 + 2 92 | self.device = device 93 | self.permutation = None 94 | self.num_fake_features = dim_state_with_fake - state_dim 95 | self.fake_noise_std = args.fake_noise_std 96 | self.fake_noise_generator = utils.setup_noise_generator(args.load_noise_distribution) 97 | 98 | sparsities = sp.compute_sparsity_per_layer( 99 | global_sparsity=global_sparsity, 100 | neuron_layers=[dim_state_with_fake + action_dim] + [num_hid_neurons for _ in range(num_hid_layers)] + [1], 101 | keep_dense=[input_layer_dense] + [False for _ in range(num_hid_layers-1)] + [True], 102 | method=sparsity_distribution_method, 103 | input_layer_sparsity=args.input_layer_sparsity) 104 | self.dense_layers = [True if sparsity == 0 else False for sparsity in sparsities] 105 | 106 | # Q1 dense architecture 107 | self.q1_input_layer = nn.Linear(dim_state_with_fake + action_dim, num_hid_neurons) 108 | self.q1_hid_layers = nn.ModuleList() 109 | for hid_connection_layer in range(num_hid_layers - 1): 110 | self.q1_hid_layers.append(nn.Linear(num_hid_neurons, num_hid_neurons)) 111 | self.q1_output_layer = nn.Linear(num_hid_neurons, 1) 112 | 113 | # Q2 dense architecture 114 | self.q2_input_layer = nn.Linear(dim_state_with_fake + action_dim, num_hid_neurons) 115 | self.q2_hid_layers = nn.ModuleList() 116 | for hid_connection_layer in range(num_hid_layers - 1): 117 | self.q2_hid_layers.append(nn.Linear(num_hid_neurons, num_hid_neurons)) 118 | self.q2_output_layer = nn.Linear(num_hid_neurons, 1) 119 | 120 | # Setup masks for Q1 121 | self.q1_num_parm_in_layer = [(dim_state_with_fake+action_dim) * num_hid_layers] + \ 122 | [num_hid_neurons**2 for _ in range(num_hid_layers-1)] + \ 123 | [num_hid_neurons] 124 | self.q1_masks = [None for _ in range(all_connection_layers)] 125 | self.q1_torch_masks = [None for _ in range(all_connection_layers)] 126 | for layer in range(all_connection_layers): 127 | if not self.dense_layers[layer]: 128 | if layer == 0: 129 | self.q1_num_parm_in_layer[layer], self.q1_masks[layer] = sp.initialize_mask( 130 | f"critic Q1 input layer", sparsities[layer], dim_state_with_fake+action_dim, num_hid_neurons) 131 | self.q1_torch_masks[layer] = torch.from_numpy(self.q1_masks[layer]).float().to(device) 132 | self.q1_input_layer.weight.data.mul_(torch.from_numpy(self.q1_masks[layer]).float()) 133 
| elif layer == all_connection_layers - 1: 134 | self.q1_num_parm_in_layer[layer], self.q1_masks[layer] = sp.initialize_mask( 135 | f"critic Q1 output layer", sparsities[layer], num_hid_neurons, 1) 136 | self.q1_torch_masks[layer] = torch.from_numpy(self.q1_masks[layer]).float().to(device) 137 | self.q1_output_layer.weight.data.mul_(torch.from_numpy(self.q1_masks[layer]).float()) 138 | else: 139 | self.q1_num_parm_in_layer[layer], self.q1_masks[layer] = sp.initialize_mask( 140 | f"critic Q1 hid layer {layer}", sparsities[layer], num_hid_neurons, num_hid_neurons) 141 | self.q1_torch_masks[layer] = torch.from_numpy(self.q1_masks[layer]).float().to(device) 142 | self.q1_hid_layers[layer-1].weight.data.mul_(torch.from_numpy(self.q1_masks[layer]).float()) 143 | 144 | # Setup masks for Q2 145 | self.q2_num_parm_in_layer = [(dim_state_with_fake+action_dim) * num_hid_layers] + \ 146 | [num_hid_neurons**2 for _ in range(num_hid_layers-1)] + \ 147 | [num_hid_neurons] 148 | self.q2_masks = [None for _ in range(all_connection_layers)] 149 | self.q2_torch_masks = [None for _ in range(all_connection_layers)] 150 | for layer in range(all_connection_layers): 151 | if not self.dense_layers[layer]: 152 | if layer == 0: 153 | self.q2_num_parm_in_layer[layer], self.q2_masks[layer] = sp.initialize_mask( 154 | f"critic Q2 input layer", sparsities[layer], dim_state_with_fake+action_dim, num_hid_neurons) 155 | self.q2_torch_masks[layer] = torch.from_numpy(self.q2_masks[layer]).float().to(device) 156 | self.q2_input_layer.weight.data.mul_(torch.from_numpy(self.q2_masks[layer]).float()) 157 | elif layer == all_connection_layers - 1: 158 | self.q2_num_parm_in_layer[layer], self.q2_masks[layer] = sp.initialize_mask( 159 | f"critic Q2 output layer", sparsities[layer], num_hid_neurons, 1) 160 | self.q2_torch_masks[layer] = torch.from_numpy(self.q2_masks[layer]).float().to(device) 161 | self.q2_output_layer.weight.data.mul_(torch.from_numpy(self.q2_masks[layer]).float()) 162 | else: 163 | self.q2_num_parm_in_layer[layer], self.q2_masks[layer] = sp.initialize_mask( 164 | f"critic Q2 hid layer {layer}", sparsities[layer], num_hid_neurons, num_hid_neurons) 165 | self.q2_torch_masks[layer] = torch.from_numpy(self.q2_masks[layer]).float().to(device) 166 | self.q2_hid_layers[layer-1].weight.data.mul_(torch.from_numpy(self.q2_masks[layer]).float()) 167 | 168 | # Activation functions 169 | self.q1_activation_funcs = setup_activation_funcs_list(activation, act_func_args, num_hid_layers, num_hid_neurons) 170 | self.q2_activation_funcs = setup_activation_funcs_list(activation, act_func_args, num_hid_layers, num_hid_neurons) 171 | 172 | def forward(self, state, action): 173 | state = utils.add_fake_features(state, self.num_fake_features, self.device, 174 | self.fake_noise_std, self.fake_noise_generator) 175 | state = utils.permute_features(state, self.permutation) 176 | sa = torch.cat([state, action], 1) 177 | 178 | q1 = self.q1_activation_funcs[0](self.q1_input_layer(sa)) 179 | for hid_layer in range(self.num_hid_layers - 1): 180 | q1 = self.q1_activation_funcs[hid_layer + 1](self.q1_hid_layers[hid_layer](q1)) 181 | q1 = self.q1_output_layer(q1) 182 | 183 | q2 = self.q2_activation_funcs[0](self.q2_input_layer(sa)) 184 | for hid_layer in range(self.num_hid_layers - 1): 185 | q2 = self.q2_activation_funcs[hid_layer + 1](self.q2_hid_layers[hid_layer](q2)) 186 | q2 = self.q2_output_layer(q2) 187 | 188 | return q1, q2 189 | 190 | def Q1(self, state, action): 191 | state = utils.add_fake_features(state, self.num_fake_features, 
self.device, 192 | self.fake_noise_std, self.fake_noise_generator) 193 | state = utils.permute_features(state, self.permutation) 194 | sa = torch.cat([state, action], 1) 195 | q1 = self.q1_activation_funcs[0](self.q1_input_layer(sa)) 196 | for hid_layer in range(self.num_hid_layers - 1): 197 | q1 = self.q1_activation_funcs[hid_layer + 1](self.q1_hid_layers[hid_layer](q1)) 198 | return self.q1_output_layer(q1) 199 | 200 | def set_new_permutation(self, permutation): 201 | self.permutation = permutation 202 | 203 | 204 | class StaticSparseTD3(SparseBaseAgent): 205 | def __init__( 206 | self, 207 | state_dim, 208 | action_dim, 209 | max_action, 210 | args, 211 | discount=0.99, 212 | tau=0.005, 213 | policy_noise=0.2, 214 | noise_clip=0.5, 215 | policy_freq=2, 216 | num_hid_layers=2, 217 | num_hid_neurons=256, 218 | activation='relu', 219 | act_func_args=(None, False), 220 | optimizer='adam', 221 | lr=0.001, 222 | global_sparsity=0.5, 223 | sparsity_distribution_method='ER', 224 | input_layer_dense=False, 225 | output_layer_dense=True, 226 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), 227 | fake_features=0.0, 228 | ): 229 | super().__init__() 230 | self.dim_state_with_fake = int(np.ceil(state_dim / (1 - fake_features))) 231 | self.prev_permutations = [] 232 | 233 | self.actor = Actor(state_dim, action_dim, max_action, activation, act_func_args, global_sparsity, 234 | sparsity_distribution_method, input_layer_dense, output_layer_dense, device, 235 | self.dim_state_with_fake, args, num_hid_layers, num_hid_neurons).to(device) 236 | self.actor_target = copy.deepcopy(self.actor) 237 | 238 | self.critic = Critic(state_dim, action_dim, activation, act_func_args, global_sparsity, 239 | sparsity_distribution_method, input_layer_dense, output_layer_dense, device, 240 | self.dim_state_with_fake, args, num_hid_layers, num_hid_neurons).to(device) 241 | self.critic_target = copy.deepcopy(self.critic) 242 | 243 | if optimizer in ['adam', 'maskadam']: # for static networks: adam == maskadam 244 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr, weight_decay=0.0002) 245 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr, weight_decay=0.0002) 246 | elif optimizer == 'sgd': 247 | self.actor_optimizer = torch.optim.SGD(self.actor.parameters(), lr=lr, momentum=0.9, weight_decay=0.0002) 248 | self.critic_optimizer = torch.optim.SGD(self.critic.parameters(), lr=lr, momentum=0.9, weight_decay=0.0002) 249 | else: 250 | raise ValueError(f'unknown optimizer {optimizer} given') 251 | 252 | self.device = device 253 | self.max_action = max_action 254 | self.discount = discount 255 | self.tau = tau 256 | self.policy_noise = policy_noise 257 | self.noise_clip = noise_clip 258 | self.policy_freq = policy_freq 259 | 260 | self.total_it = 0 261 | 262 | def select_action(self, state, evaluate=False): 263 | state = torch.FloatTensor(state.reshape(1, -1)).to(self.device) 264 | return self.actor(state).cpu().data.numpy().flatten() 265 | 266 | def train(self, replay_buffer, batch_size=100): 267 | self.total_it += 1 268 | 269 | # Sample replay buffer 270 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 271 | 272 | with torch.no_grad(): 273 | # Select action according to policy and add clipped noise 274 | noise = (torch.randn_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip) 275 | next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action) 276 | 277 | # Compute the target Q 
value 278 | target_Q1, target_Q2 = self.critic_target(next_state, next_action) 279 | target_Q = reward + not_done * self.discount * torch.min(target_Q1, target_Q2) 280 | 281 | # Get current Q estimates 282 | current_Q1, current_Q2 = self.critic(state, action) 283 | 284 | # Compute critic loss 285 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) 286 | 287 | # Optimize the critic 288 | self.critic_optimizer.zero_grad() 289 | critic_loss.backward() 290 | self.critic_optimizer.step() 291 | 292 | # Maintain the same sparse connectivity for critic 293 | self.apply_masks_critic() 294 | 295 | # Delayed policy updates 296 | if self.total_it % self.policy_freq == 0: 297 | 298 | # Compute actor loss 299 | actor_loss = -self.critic.Q1(state, self.actor(state)).mean() 300 | 301 | # Optimize the actor 302 | self.actor_optimizer.zero_grad() 303 | actor_loss.backward() 304 | self.actor_optimizer.step() 305 | 306 | # Maintain the same sparse connectivity for actor 307 | self.apply_masks_actor() 308 | 309 | # Update the frozen target models 310 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 311 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 312 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 313 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 314 | # no need to maintain sparsity in target networks, 315 | # as the masks are static and already applied to main networks 316 | 317 | def apply_masks_critic(self): 318 | for layer in range(self.critic.num_hid_layers + 1): 319 | if not self.critic.dense_layers[layer]: 320 | if layer == 0: 321 | self.critic.q1_input_layer.weight.data.mul_(self.critic.q1_torch_masks[layer]) 322 | self.critic.q2_input_layer.weight.data.mul_(self.critic.q2_torch_masks[layer]) 323 | elif layer == self.critic.num_hid_layers: 324 | self.critic.q1_output_layer.weight.data.mul_(self.critic.q1_torch_masks[layer]) 325 | self.critic.q2_output_layer.weight.data.mul_(self.critic.q2_torch_masks[layer]) 326 | else: 327 | self.critic.q1_hid_layers[layer - 1].weight.data.mul_(self.critic.q1_torch_masks[layer]) 328 | self.critic.q2_hid_layers[layer - 1].weight.data.mul_(self.critic.q2_torch_masks[layer]) 329 | 330 | def apply_masks_actor(self): 331 | for layer in range(self.actor.num_hid_layers + 1): 332 | if not self.actor.dense_layers[layer]: 333 | if layer == 0: 334 | self.actor.input_layer.weight.data.mul_(self.actor.torch_masks[layer]) 335 | elif layer == self.actor.num_hid_layers: 336 | self.actor.output_layer.weight.data.mul_(self.actor.torch_masks[layer]) 337 | else: 338 | self.actor.hid_layers[layer - 1].weight.data.mul_(self.actor.torch_masks[layer]) 339 | 340 | def print_sparsity(self): 341 | return sp.print_sparsities(self.critic.parameters(), self.critic_target.parameters(), 342 | self.actor.parameters(), self.actor_target.parameters()) 343 | 344 | def set_new_permutation(self): 345 | # sample a new permutation until it is not a duplicate 346 | duplicate = True 347 | while duplicate: 348 | permutation = torch.randperm(self.dim_state_with_fake) 349 | for p in self.prev_permutations: 350 | if torch.equal(p, permutation): 351 | break 352 | else: 353 | duplicate = False 354 | print(f'\nEnvironment change: new permutation of input features.') 355 | self.prev_permutations.append(permutation) 356 | self.actor.set_new_permutation(permutation) 357 | self.critic.set_new_permutation(permutation) 358 | 
self.actor_target.set_new_permutation(permutation) 359 | self.critic_target.set_new_permutation(permutation) 360 | 361 | -------------------------------------------------------------------------------- /algorithms/td3_based/td3.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from utils import utils 7 | from utils.core import BaseAgent 8 | from utils.activations import setup_activation_funcs_list 9 | 10 | # Implementation of Twin Delayed Deep Deterministic Policy Gradients (TD3) 11 | # Paper: https://arxiv.org/abs/1802.09477 12 | 13 | 14 | class Actor(nn.Module): 15 | def __init__(self, state_dim, action_dim, max_action, activation, act_func_args, device, args, 16 | dim_state_with_fake, num_hid_layers=2, num_hid_neurons=256): 17 | super().__init__() 18 | assert num_hid_layers >= 1 19 | self.num_hid_layers = num_hid_layers 20 | self.device = device 21 | self.permutation = None # first env is without permutation 22 | self.num_fake_features = dim_state_with_fake - state_dim 23 | self.fake_noise_std = args.fake_noise_std 24 | self.fake_noise_generator = utils.setup_noise_generator(args.load_noise_distribution) 25 | 26 | self.input_layer = nn.Linear(dim_state_with_fake, num_hid_neurons) 27 | self.hid_layers = nn.ModuleList() # a simple python list does not work here, see: 28 | # https://discuss.pytorch.org/t/when-should-i-use-nn-modulelist-and-when-should-i-use-nn-sequential/5463 29 | for hid_connection_layer in range(num_hid_layers - 1): 30 | self.hid_layers.append(nn.Linear(num_hid_neurons, num_hid_neurons)) 31 | self.output_layer = nn.Linear(num_hid_neurons, action_dim) 32 | 33 | self.activation_funcs = setup_activation_funcs_list(activation, act_func_args, num_hid_layers, num_hid_neurons) 34 | self.output_activation = nn.Tanh() 35 | self.max_action = max_action 36 | 37 | def forward(self, state): 38 | state = utils.add_fake_features(state, self.num_fake_features, self.device, 39 | self.fake_noise_std, self.fake_noise_generator) 40 | state = utils.permute_features(state, self.permutation) 41 | a = self.activation_funcs[0](self.input_layer(state)) 42 | for hid_layer in range(self.num_hid_layers - 1): 43 | a = self.activation_funcs[hid_layer + 1](self.hid_layers[hid_layer](a)) 44 | return self.max_action * self.output_activation(self.output_layer(a)) 45 | 46 | def set_new_permutation(self, permutation): 47 | self.permutation = permutation 48 | 49 | 50 | class Critic(nn.Module): 51 | def __init__(self, state_dim, action_dim, activation, act_func_args, device, args, 52 | dim_state_with_fake, num_hid_layers=2, num_hid_neurons=256): 53 | super().__init__() 54 | assert num_hid_layers >= 1 55 | self.num_hid_layers = num_hid_layers 56 | self.device = device 57 | self.permutation = None 58 | self.num_fake_features = dim_state_with_fake - state_dim 59 | self.fake_noise_std = args.fake_noise_std 60 | self.fake_noise_generator = utils.setup_noise_generator(args.load_noise_distribution) 61 | 62 | # Q1 architecture 63 | self.q1_input_layer = nn.Linear(dim_state_with_fake + action_dim, num_hid_neurons) 64 | self.q1_hid_layers = nn.ModuleList() 65 | for hid_connection_layer in range(num_hid_layers - 1): 66 | self.q1_hid_layers.append(nn.Linear(num_hid_neurons, num_hid_neurons)) 67 | self.q1_output_layer = nn.Linear(num_hid_neurons, 1) 68 | 69 | # Q2 architecture 70 | self.q2_input_layer = nn.Linear(dim_state_with_fake + action_dim, num_hid_neurons) 71 | 
self.q2_hid_layers = nn.ModuleList() 72 | for hid_connection_layer in range(num_hid_layers - 1): 73 | self.q2_hid_layers.append(nn.Linear(num_hid_neurons, num_hid_neurons)) 74 | self.q2_output_layer = nn.Linear(num_hid_neurons, 1) 75 | 76 | # Activation functions 77 | self.q1_activation_funcs = setup_activation_funcs_list(activation, act_func_args, num_hid_layers, num_hid_neurons) 78 | self.q2_activation_funcs = setup_activation_funcs_list(activation, act_func_args, num_hid_layers, num_hid_neurons) 79 | 80 | def forward(self, state, action): 81 | state = utils.add_fake_features(state, self.num_fake_features, self.device, 82 | self.fake_noise_std, self.fake_noise_generator) 83 | state = utils.permute_features(state, self.permutation) 84 | sa = torch.cat([state, action], 1) 85 | 86 | q1 = self.q1_activation_funcs[0](self.q1_input_layer(sa)) 87 | for hid_layer in range(self.num_hid_layers - 1): 88 | q1 = self.q1_activation_funcs[hid_layer + 1](self.q1_hid_layers[hid_layer](q1)) 89 | q1 = self.q1_output_layer(q1) 90 | 91 | q2 = self.q2_activation_funcs[0](self.q2_input_layer(sa)) 92 | for hid_layer in range(self.num_hid_layers - 1): 93 | q2 = self.q2_activation_funcs[hid_layer + 1](self.q2_hid_layers[hid_layer](q2)) 94 | q2 = self.q2_output_layer(q2) 95 | 96 | return q1, q2 97 | 98 | def Q1(self, state, action): 99 | state = utils.add_fake_features(state, self.num_fake_features, self.device, 100 | self.fake_noise_std, self.fake_noise_generator) 101 | state = utils.permute_features(state, self.permutation) 102 | sa = torch.cat([state, action], 1) 103 | q1 = self.q1_activation_funcs[0](self.q1_input_layer(sa)) 104 | for hid_layer in range(self.num_hid_layers - 1): 105 | q1 = self.q1_activation_funcs[hid_layer + 1](self.q1_hid_layers[hid_layer](q1)) 106 | return self.q1_output_layer(q1) 107 | 108 | def set_new_permutation(self, permutation): 109 | self.permutation = permutation 110 | 111 | 112 | class TD3(BaseAgent): 113 | def __init__( 114 | self, 115 | state_dim, 116 | action_dim, 117 | max_action, 118 | args, 119 | discount=0.99, 120 | tau=0.005, 121 | num_hid_layers=2, 122 | num_hid_neurons=256, 123 | activation='relu', 124 | act_func_args=(None, False), 125 | optimizer='adam', 126 | lr=0.001, 127 | policy_noise=0.2, 128 | noise_clip=0.5, 129 | policy_freq=2, 130 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), 131 | fake_features=0.0, 132 | ): 133 | super().__init__() 134 | self.dim_state_with_fake = int(np.ceil(state_dim / (1 - fake_features))) 135 | 136 | self.actor = Actor(state_dim, action_dim, max_action, activation, act_func_args, device, args, 137 | self.dim_state_with_fake, num_hid_layers, num_hid_neurons).to(device) 138 | self.actor_target = copy.deepcopy(self.actor) 139 | 140 | self.critic = Critic(state_dim, action_dim, activation, act_func_args, device, args, 141 | self.dim_state_with_fake, num_hid_layers, num_hid_neurons).to(device) 142 | self.critic_target = copy.deepcopy(self.critic) 143 | 144 | if optimizer in ['adam', 'maskadam']: # for all-dense networks: adam == maskadam 145 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr, weight_decay=0.0002) 146 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr, weight_decay=0.0002) 147 | elif optimizer == 'sgd': 148 | self.actor_optimizer = torch.optim.SGD(self.actor.parameters(), lr=lr, momentum=0.9, weight_decay=0.0002) 149 | self.critic_optimizer = torch.optim.SGD(self.critic.parameters(), lr=lr, momentum=0.9, weight_decay=0.0002) 150 | else: 151 | raise 
ValueError(f'Unknown optimizer {optimizer} given') 152 | 153 | self.device = device 154 | self.max_action = max_action 155 | self.discount = discount 156 | self.tau = tau 157 | self.policy_noise = policy_noise 158 | self.noise_clip = noise_clip 159 | self.policy_freq = policy_freq 160 | self.prev_permutations = [] 161 | self.total_it = 0 162 | 163 | def select_action(self, state, evaluate=False): 164 | state = torch.FloatTensor(state.reshape(1, -1)).to(self.device) 165 | return self.actor(state).cpu().data.numpy().flatten() 166 | 167 | def train(self, replay_buffer, batch_size=100): 168 | self.total_it += 1 169 | 170 | # Sample replay buffer 171 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 172 | 173 | with torch.no_grad(): 174 | # Select action according to policy and add clipped noise 175 | noise = (torch.randn_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip) 176 | next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action) 177 | 178 | # Compute the target Q value 179 | target_Q1, target_Q2 = self.critic_target(next_state, next_action) 180 | target_Q = torch.min(target_Q1, target_Q2) 181 | target_Q = reward + not_done * self.discount * target_Q 182 | 183 | # Get current Q estimates 184 | current_Q1, current_Q2 = self.critic(state, action) 185 | 186 | # Compute critic loss 187 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) 188 | 189 | # Optimize the critic 190 | self.critic_optimizer.zero_grad() 191 | critic_loss.backward() 192 | self.critic_optimizer.step() 193 | 194 | # Delayed policy updates 195 | if self.total_it % self.policy_freq == 0: 196 | 197 | # Compute actor loss 198 | actor_loss = -self.critic.Q1(state, self.actor(state)).mean() 199 | 200 | # Optimize the actor 201 | self.actor_optimizer.zero_grad() 202 | actor_loss.backward() 203 | self.actor_optimizer.step() 204 | 205 | # Update the frozen target models 206 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 207 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 208 | 209 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 210 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 211 | 212 | def set_new_permutation(self): 213 | # sample a new permutation until it is not a duplicate 214 | duplicate = True 215 | while duplicate: 216 | permutation = torch.randperm(self.dim_state_with_fake) 217 | for p in self.prev_permutations: 218 | if torch.equal(p, permutation): 219 | break 220 | else: 221 | duplicate = False 222 | print(f'\nEnvironment change: new permutation of input features.') 223 | self.prev_permutations.append(permutation) 224 | self.actor.set_new_permutation(permutation) 225 | self.critic.set_new_permutation(permutation) 226 | self.actor_target.set_new_permutation(permutation) 227 | self.critic_target.set_new_permutation(permutation) 228 | 229 | 230 | if __name__ == '__main__': 231 | # to test a bit 232 | 233 | td3agent = TD3(state_dim=17, action_dim=6, max_action=1.0, activation='relu') 234 | stdict = td3agent.actor.state_dict() 235 | # print(stdict) 236 | 237 | for key in stdict: 238 | print(key) 239 | -------------------------------------------------------------------------------- /algorithms/test_algos: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ### call with: bash test_algos 3 | # 
File to test if all algorithms are working correctly. 4 | # Runs a few training steps, 5 | # runs a few evaluation episodes, 6 | # runs a change in environment (permutation) 7 | 8 | cd .. 9 | source venv/bin/activate 10 | wandb disabled 11 | 12 | 13 | for policy in ANF-SAC ANF-TD3 SAC TD3 Static-SAC Static-TD3 14 | do 15 | python main.py --policy $policy \ 16 | --env HalfCheetah-v3 \ 17 | --seed 42 \ 18 | --eval_freq 2500 \ 19 | --start_timesteps 1000 \ 20 | --activation relu \ 21 | --optimizer maskadam \ 22 | --global_sparsity 0.0 \ 23 | --sparsity_distribution_method uniform \ 24 | --input_layer_sparsity 0.8 \ 25 | --not_save_model --not_save_results \ 26 | --fake_features 0.5 \ 27 | --fake_noise_std 1.0 \ 28 | --adjust_env_period 3000 \ 29 | --max_timesteps 6000 30 | printf "\nFinished policy: $policy \n\n" 31 | done 32 | printf "\n\n\nFinished all policies. \n\n" 33 | -------------------------------------------------------------------------------- /figures/ANF.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bramgrooten/automatic-noise-filtering/6d3bca27affb4036e6e0bd60ef5b9a4ccdfc6daa/figures/ANF.png -------------------------------------------------------------------------------- /figures/learning_curves_halfcheetah_nf98.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bramgrooten/automatic-noise-filtering/6d3bca27affb4036e6e0bd60ef5b9a4ccdfc6daa/figures/learning_curves_halfcheetah_nf98.png -------------------------------------------------------------------------------- /figures/mujoco_gym_4envs_captions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bramgrooten/automatic-noise-filtering/6d3bca27affb4036e6e0bd60ef5b9a4ccdfc6daa/figures/mujoco_gym_4envs_captions.png -------------------------------------------------------------------------------- /figures/mujoco_visual_halfcheetah.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bramgrooten/automatic-noise-filtering/6d3bca27affb4036e6e0bd60ef5b9a4ccdfc6daa/figures/mujoco_visual_halfcheetah.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import argparse 4 | import datetime 5 | import wandb 6 | from utils import utils 7 | from algorithms import run_sac, run_td3 8 | 9 | 10 | if __name__ == "__main__": 11 | main_start_time = datetime.datetime.now() 12 | parser = argparse.ArgumentParser() 13 | utils.add_arguments(parser) 14 | args = parser.parse_args() 15 | utils.print_all_args(args) 16 | utils.make_folders() 17 | file_name = utils.set_file_name(args) 18 | wandb.init(project="ANF", config=vars(args), name=file_name, entity="VScAIL", mode=args.wandb_mode) 19 | 20 | # Set seeds 21 | torch.manual_seed(args.seed) 22 | np.random.seed(args.seed) 23 | 24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | print(f"Using device: {device}") 26 | 27 | if 'TD3' in args.policy: 28 | run_td3.run(args, file_name, device, main_start_time) 29 | elif 'SAC' in args.policy: 30 | run_sac.run(args, file_name, device, main_start_time) 31 | else: 32 | raise ValueError("Invalid policy name") 33 | 34 | total_runtime = datetime.datetime.now() - main_start_time 35 | print(f"\nTotal running time 
{total_runtime}\n\n") 36 | wandb.log({"total_runtime_hours": round(total_runtime.total_seconds() / 3600, 2)}) 37 | -------------------------------------------------------------------------------- /tutorial.md: -------------------------------------------------------------------------------- 1 | 2 | # Sparse Training in Deep RL - Tutorial 3 | 4 | This guide is designed for anyone interested in using sparse neural networks in deep reinforcement learning. 5 | It was part of the [IJCAI 2023 tutorial](https://ijcai-23.org/tutorials/) 6 | T27: _Sparse Training for Supervised, Unsupervised, Continual, and Deep Reinforcement Learning with Deep Neural Networks_. 7 | In the following we will play around with some of the settings in this repository. 8 | I hope this will give you a feeling for the types of problems we are researching in this field. 9 | 10 | Author: [Bram Grooten](https://www.bramgrooten.nl/). 11 | 12 | ## Install 13 | 14 | First `git clone` and install this repository, 15 | by following the instructions in 16 | the [README](https://github.com/bramgrooten/automatic-noise-filtering/blob/main/README.md). 17 | 18 | 19 | ## Visualizing a trained agent 20 | 21 | We have included a trained agent in the repository already, 22 | so let's view what this actually looks like! 23 | Go to this repository's main folder in your terminal 24 | (the directory with the `view_mujoco.py` file in it), 25 | make sure your venv is activated, and run: 26 | 27 | ``` 28 | python view_mujoco.py 29 | ``` 30 | 31 | A window should pop up showing you one episode of a running HalfCheetah! 32 | The camera doesn't track the agent by default, press `TAB` to do that. 33 | 34 | ![Image showing the MuJoCo visualization window](figures/mujoco_visual_halfcheetah.png) 35 | 36 | The agent is rewarded for moving forward as quickly as possible. See the exact details in the documentation of [MuJoCo Gym](https://gymnasium.farama.org/environments/mujoco/half_cheetah/). 37 | 38 | At the bottom of the `view_mujoco.py` file there are some instructions on downloading more of our trained agents, 39 | if you want to compare a few or look at some agents move in different environments. 40 | 41 | ![Image of the 4 environments used in the ANF paper](figures/mujoco_gym_4envs_captions.png) 42 | 43 | 44 | 45 | ## Training sparse neural networks 46 | 47 | In this repository you can train a sparse version of [SAC](https://arxiv.org/abs/1801.01290) or [TD3](https://arxiv.org/abs/1802.09477). 48 | Run the script `main.py` as follows: 49 | 50 | ``` 51 | python main.py --policy ANF-SAC --env HalfCheetah-v3 --global_sparsity 0.5 --wandb_mode disabled 52 | ``` 53 | 54 | to train a sparse SAC agent on the HalfCheetah-v3 environment with 50% sparsity. 55 | The original DS-TD3 and DS-SAC algorithms from [Sokar et al.](https://arxiv.org/abs/2106.04217) (which this repository is based on) 56 | apply **D**ynamic **S**parse Training to show that only using half of the weights 57 | can actually improve performance over dense models in many environments! 58 | 59 | A training run in reinforcement learning can take quite a while 60 | (i.e., between 4-24 hours depending on the settings and compute hardware you use). 61 | However, you should be able to see the rewards go up pretty quickly for HalfCheetah. 62 | Running the above command on my laptop with an NVIDIA Quadro P2000 GPU shows: 63 | 64 | ``` 65 | Total Steps: 1000 Episode Num: 0 Epi. Steps: 1000 Reward: -106.9 Epi. 
Time: 0:00:38.113582 Total Train Time: 0:01:04.060303 66 | Total Steps: 2000 Episode Num: 1 Epi. Steps: 1000 Reward: -229.72 Epi. Time: 0:00:36.587989 Total Train Time: 0:01:40.648814 67 | Total Steps: 3000 Episode Num: 2 Epi. Steps: 1000 Reward: -254.99 Epi. Time: 0:00:34.820381 Total Train Time: 0:02:15.469734 68 | Total Steps: 4000 Episode Num: 3 Epi. Steps: 1000 Reward: -123.78 Epi. Time: 0:00:36.723802 Total Train Time: 0:02:52.194053 69 | Total Steps: 5000 Episode Num: 4 Epi. Steps: 1000 Reward: -65.56 Epi. Time: 0:00:37.499780 Total Train Time: 0:03:29.694346 70 | --------------------------------------- 71 | Evaluation over 5 episode(s): -88.544 72 | --------------------------------------- 73 | Total Steps: 6000 Episode Num: 5 Epi. Steps: 1000 Reward: -98.79 Epi. Time: 0:00:54.771722 Total Train Time: 0:04:24.466859 74 | Total Steps: 7000 Episode Num: 6 Epi. Steps: 1000 Reward: -20.56 Epi. Time: 0:00:38.878653 Total Train Time: 0:05:03.346438 75 | Total Steps: 8000 Episode Num: 7 Epi. Steps: 1000 Reward: 261.05 Epi. Time: 0:00:40.444133 Total Train Time: 0:05:43.791106 76 | Total Steps: 9000 Episode Num: 8 Epi. Steps: 1000 Reward: 186.32 Epi. Time: 0:00:39.042627 Total Train Time: 0:06:22.834215 77 | Total Steps: 10000 Episode Num: 9 Epi. Steps: 1000 Reward: 1363.95 Epi. Time: 0:00:37.373520 Total Train Time: 0:07:00.208328 78 | --------------------------------------- 79 | Evaluation over 5 episode(s): 1635.536 80 | --------------------------------------- 81 | Total Steps: 11000 Episode Num: 10 Epi. Steps: 1000 Reward: 1653.07 Epi. Time: 0:00:48.873239 Total Train Time: 0:07:49.082014 82 | Total Steps: 12000 Episode Num: 11 Epi. Steps: 1000 Reward: 1855.95 Epi. Time: 0:00:34.300092 Total Train Time: 0:08:23.382577 83 | Total Steps: 13000 Episode Num: 12 Epi. Steps: 1000 Reward: 2008.34 Epi. Time: 0:00:37.275500 Total Train Time: 0:09:00.658566 84 | Total Steps: 14000 Episode Num: 13 Epi. Steps: 1000 Reward: 2141.13 Epi. Time: 0:00:39.403169 Total Train Time: 0:09:40.062282 85 | Total Steps: 15000 Episode Num: 14 Epi. Steps: 1000 Reward: 2531.5 Epi. Time: 0:00:41.110036 Total Train Time: 0:10:21.172832 86 | --------------------------------------- 87 | Evaluation over 5 episode(s): 3113.252 88 | --------------------------------------- 89 | ``` 90 | 91 | After 10 minutes of training we already have an average evaluation return (=sum of rewards) of >3000. 92 | There is often a lot of variance in RL unfortunately, so the results will probably differ for you. 93 | We need to run many seeds in RL to get a good indication of an algorithm's performance. 94 | 95 | Note that the evaluation returns will almost always be better than the returns during training, 96 | as the agents often use extra stochasticity during training to boost exploration. 97 | During evaluation this is turned off, so the expected best action is always selected. 98 | 99 | The graph below should give an indication of the expected returns for HalfCheetah-v3. 100 | Note that the graph is from Figure 5 of the [ANF paper](https://arxiv.org/abs/2302.06548), 101 | where we used environments with large amounts of noise features. 102 | Sparse neural networks can help to focus on the task-relevant features, 103 | but we'll cover that in the section on Automatic Noise Filtering further below. 
104 | 105 | ![Learning curves plot on HalfCheetah](figures/learning_curves_halfcheetah_nf98.png) 106 | 107 | 108 | The number of environment steps to train on is set to 1 million by default, 109 | but you can change this with the argument `--max_timesteps`. 110 | See the file [utils.py](https://github.com/bramgrooten/automatic-noise-filtering/blob/main/utils/utils.py) 111 | for all the arguments you can play with, or run `python main.py --help` in the terminal. 112 | 113 | To see the code showing how this actually works under the hood, check out the files: 114 | - `main.py`: where the training call is started 115 | - `algorithms/sac_based/anf_sac.py`: to see when we update and apply the sparsity masks 116 | - `utils/sparse_utils.py`: to see how we initialize and update the sparsity masks 117 | 118 | 119 | Go ahead and try out different values for `--global_sparsity` and see how that affects performance. 120 | There is an interesting paper called "[The State of Sparse Training in Deep RL](https://arxiv.org/abs/2206.10369)" 121 | that investigated how far sparsity levels can be pushed in Deep RL, 122 | going all the way to 99%. 123 | In RL, performance often stays competitive with fully dense networks up to a sparsity level of about 90%. 124 | 125 | 126 | 127 | ### Sparsity Distribution 128 | 129 | An example of a setting that you can play around with is the 130 | distribution of sparsity levels over the different layers of your neural network. 131 | In the simplest setting of Dynamic Sparse Training (DST) we usually have 132 | a fixed global sparsity level throughout training. 133 | Also, the sparsity distribution over the layers is usually kept fixed. 134 | What is always adapted in DST is which weights are currently activated. 135 | 136 | In this repository we often use small MLP networks with 3 weight layers. 137 | The hidden layers have 256 neurons each. The input layer size depends on the environment, 138 | and the output layer has either the same number of neurons as the number of actions 139 | (for the Actor network) or just 1 neuron (for the Critic networks). 140 | For example, the Actor network on HalfCheetah consists of: 17-256-256-6. 141 | We often keep the crucial output layer dense, which especially makes sense for the Critic. 142 | 143 | A typical sparsity distribution for a 90% sparse NN in our work would be: 144 | 145 | | Layer 1 | Layer 2 | Layer 3 | 146 | |---------|---------|---------| 147 | | 90.3% | 90.3% | 0% | 148 | 149 | 150 | Notice that to get a 90% sparsity level over the whole network while keeping the output layer dense, 151 | we need to increase the sparsity level of the sparse layers by just a bit. 152 | (The output layer is much smaller, i.e., has far fewer weights.) 153 | 154 | The table above uses the default `--sparsity_distribution_method`: `uniform`. 155 | This means it uniformly distributes the sparsity levels over all layers that are not kept dense, 156 | such that the requested `global_sparsity` level is reached. 157 | 158 | Another popular sparsity distribution method is `ER`, which was developed by Mocanu et al. 159 | in the [SET paper](https://www.nature.com/articles/s41467-018-04316-3). 160 | The name comes from the famous Erdős-Rényi random graph. 161 | This method gives larger layers a higher sparsity level, as they may be able to handle more pruning. 162 | (Similar to how richer people pay a higher percentage of income tax.)
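To make the default `uniform` method concrete, here is a minimal sketch of the redistribution idea. It is illustrative only: the real logic lives in `compute_sparsity_per_layer` in `utils/sparse_utils.py` (which also handles `ER` and `new`), the helper name `uniform_sparsities` is made up for this example, and the layer sizes are those of the small HalfCheetah Actor, so the resulting percentages differ from the table above.

```python
# Illustrative sketch, not the repository's implementation: spread one global
# sparsity target uniformly over the layers that are not kept dense.
def uniform_sparsities(layer_sizes, keep_dense, global_sparsity):
    total_weights = sum(layer_sizes)
    dense_weights = sum(n for n, d in zip(layer_sizes, keep_dense) if d)
    sparse_weights = total_weights - dense_weights
    # all pruned weights have to come out of the non-dense layers
    per_layer = global_sparsity * total_weights / sparse_weights
    assert per_layer < 1.0, "dense layers too large for this global sparsity"
    return [0.0 if dense else round(per_layer, 3) for dense in keep_dense]

# Actor on plain HalfCheetah (17-256-256-6), output layer kept dense:
sizes = [17 * 256, 256 * 256, 256 * 6]
print(uniform_sparsities(sizes, [False, False, True], global_sparsity=0.9))
# -> [0.92, 0.92, 0.0]: slightly above 90%, as described above
```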
163 | 164 | Start a training run with: 165 | ``` 166 | python main.py --sparsity_distribution_method ER --global_sparsity 0.9 --policy ANF-SAC --env HalfCheetah-v3 --wandb_mode disabled 167 | ``` 168 | and check how the sparsity distribution differs from the table above. 169 | 170 | This repository also includes a third method, which for now is simply called "new". 171 | I came up with it to define a distribution that sits 172 | somewhere in between ER and uniform. 173 | I have plotted the differences in this [graph](https://www.desmos.com/calculator/yuvwypolsm). 174 | See all the details in the function `compute_sparsity_per_layer` of `sparse_utils.py`. 175 | 176 | 177 | ### Automatic Noise Filtering 178 | 179 | In the [ANF paper](https://arxiv.org/abs/2302.06548), 180 | we discovered that sparsity can be very beneficial in filtering noise. 181 | In environments with large amounts of noise features 182 | (where just a small subset of all state features is task-relevant), 183 | Dynamic Sparse Training learns to focus the network's connections on the important features. 184 | 185 | We simulate environments with lots of irrelevant features 186 | by adding many fake / noise features to the existing MuJoCo environments. 187 | The fraction of noise features you want in your environment can be set with `--fake_features`. 188 | 189 | The main difference in ANF is that we only sparsify the input layer. 190 | The sparsity distribution we use in the paper is thus as follows: 191 | 192 | | Layer 1 | Layer 2 | Layer 3 | 193 | |---------|---------|---------| 194 | | 80% | 0% | 0% | 195 | 196 | This can be set by using the argument `--input_layer_sparsity`. (Even if you then set `--global_sparsity` to 0, which is the default, 197 | the input layer will still be sparse and the new global sparsity level will be printed for you.) 198 | 199 | Let's try out an environment with 90% noise features by running: 200 | 201 | ``` 202 | python main.py --input_layer_sparsity 0.8 --fake_features 0.9 --policy ANF-SAC --env HalfCheetah-v3 --wandb_mode disabled 203 | ``` 204 | and compare the performance with 205 | ``` 206 | python main.py --input_layer_sparsity 0.8 --fake_features 0.9 --policy SAC --env HalfCheetah-v3 --wandb_mode disabled 207 | ``` 208 | 209 | Note that the state space of the environment is now 10x larger, so the runtime will increase. 210 | As shown in the graph with learning curves above, 211 | ANF is able to outperform standard dense networks by a wide margin on environments with lots of noise.
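To get a feeling for the sizes involved, the sketch below is illustrative only (the helper name `ene_input_layer_stats` is made up for this example). It reuses the same `ceil(state_dim / (1 - fake_features))` padding formula that the agent classes use, and counts how many input-layer weights survive an 80% input-layer sparsity. Dynamic sparse training then decides *which* of those connections to keep, ideally the ones attached to the real features.

```python
import math

# Hypothetical helper, only meant to show the sizes involved; the agent classes
# compute dim_state_with_fake with the same ceil-based formula.
def ene_input_layer_stats(state_dim=17, fake_features=0.9,
                          input_layer_sparsity=0.8, num_hid_neurons=256):
    dim_state_with_fake = math.ceil(state_dim / (1 - fake_features))
    total_input_weights = dim_state_with_fake * num_hid_neurons
    kept_input_weights = round((1 - input_layer_sparsity) * total_input_weights)
    return dim_state_with_fake, total_input_weights, kept_input_weights

print(ene_input_layer_stats())
# roughly (170, 43520, 8704): about 10x more input features than the real 17,
# while only ~20% of the possible input connections stay active
```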
212 | 213 | 214 | 215 | 216 | 217 | 218 | -------------------------------------------------------------------------------- /utils/activations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def srelu_func(x: torch.Tensor, 6 | threshold_right, 7 | slope_right, 8 | threshold_left, 9 | slope_left 10 | ) -> torch.Tensor: 11 | far_positives = (x > threshold_right) 12 | far_negatives = (x < threshold_left) 13 | # middle_ones = torch.logical_not(torch.logical_or(far_negatives, far_positives)) 14 | # not needed, as middle ones keep the same value in output 15 | output = x.clone() 16 | output[far_positives] = threshold_right + slope_right * (x[far_positives] - threshold_right) 17 | output[far_negatives] = threshold_left + slope_left * (x[far_negatives] - threshold_left) 18 | return output 19 | 20 | 21 | def srelu_func_per_neuron(x: torch.Tensor, 22 | threshold_right, 23 | slope_right, 24 | threshold_left, 25 | slope_left 26 | ) -> torch.Tensor: 27 | """ The four params here are of type nn.Parameter 28 | with same shape as input x in last dimension (shape[-1]) """ 29 | far_positives = (x > threshold_right) 30 | far_negatives = (x < threshold_left) 31 | # middle_ones = torch.logical_not(torch.logical_or(far_negatives, far_positives)) 32 | # not needed, as middle ones keep the same value in output 33 | output = x.clone() 34 | # expanding parameters threshold_right (and others) to have same (batch) size 35 | output[far_positives] = threshold_right.expand_as(x)[far_positives] + slope_right.expand_as(x)[far_positives] * ( 36 | x[far_positives] - threshold_right.expand_as(x)[far_positives]) 37 | output[far_negatives] = threshold_left.expand_as(x)[far_negatives] + slope_left.expand_as(x)[far_negatives] * ( 38 | x[far_negatives] - threshold_left.expand_as(x)[far_negatives]) 39 | return output 40 | 41 | 42 | def lex_func(x: torch.Tensor, 43 | multiplier_right, 44 | exponent_right, 45 | multiplier_left, 46 | exponent_left, 47 | ) -> torch.Tensor: 48 | positives = (x >= 0) 49 | negatives = (x < 0) 50 | output = torch.zeros_like(x) 51 | output[positives] = multiplier_right * (x[positives] ** exponent_right) 52 | output[negatives] = -multiplier_left * ((-x[negatives]) ** exponent_left) 53 | return output 54 | 55 | 56 | def lex_func_per_neuron(x: torch.Tensor, 57 | multiplier_right, 58 | exponent_right, 59 | multiplier_left, 60 | exponent_left, 61 | ) -> torch.Tensor: 62 | """ The four params here are of type nn.Parameter 63 | with same shape as input x in last dimension (shape[-1]) """ 64 | positives = (x >= 0) 65 | negatives = (x < 0) 66 | output = torch.zeros_like(x) 67 | output[positives] = multiplier_right[positives] * (x[positives] ** exponent_right[positives]) 68 | output[negatives] = -multiplier_left[negatives] * ((-x[negatives]) ** exponent_left[negatives]) 69 | return output 70 | 71 | 72 | class SymSqrt(nn.Module): 73 | def __init__(self): 74 | super().__init__() 75 | 76 | def forward(self, x: torch.Tensor) -> torch.Tensor: 77 | positives = (x >= 0) 78 | negatives = (x < 0) 79 | output = torch.zeros_like(x) 80 | output[positives] = torch.sqrt(x[positives]) 81 | output[negatives] = -torch.sqrt(-x[negatives]) 82 | return output 83 | 84 | 85 | class SymSqrt1(nn.Module): 86 | def __init__(self): 87 | super().__init__() 88 | 89 | def forward(self, x: torch.Tensor) -> torch.Tensor: 90 | far_positives = (x > 1) 91 | far_negatives = (x < -1) 92 | # middle_ones = torch.logical_not(torch.logical_or(far_negatives, 
far_positives)) 93 | # not needed, as middle ones keep the same value in output 94 | output = x.clone() 95 | output[far_positives] = torch.sqrt(x[far_positives]) 96 | output[far_negatives] = -torch.sqrt(-x[far_negatives]) 97 | return output 98 | 99 | 100 | class NonLEx(nn.Module): 101 | """ Non-Learnable Exponents (fixed params version of LEx) """ 102 | 103 | def __init__(self, 104 | multiplier_right: float = 1., 105 | exponent_right: float = 1., 106 | multiplier_left: float = 1., 107 | exponent_left: float = 0.5 108 | ): 109 | super().__init__() 110 | self.mr = multiplier_right 111 | self.er = exponent_right 112 | self.ml = multiplier_left 113 | self.el = exponent_left 114 | 115 | def forward(self, x: torch.Tensor) -> torch.Tensor: 116 | return lex_func(x, self.mr, self.er, self.ml, self.el) 117 | 118 | 119 | class LEx(nn.Module): 120 | """ Learnable Exponents. 121 | num_neurons (int): 1 means per-layer shared LEx, >1 means per-neuron LEx. 122 | Although it takes an int as input, there are only two values legitimate: 123 | 1, or the number of channels(neurons) at input. Default: 1""" 124 | 125 | def __init__(self, 126 | multiplier_right: float = 1., 127 | exponent_right: float = 1., 128 | multiplier_left: float = 1., 129 | exponent_left: float = 0.5, 130 | num_neurons: int = 1, 131 | random_init=False # this is a (optional) feature, to add later maybe 132 | ): 133 | super().__init__() 134 | self.num_neurons = num_neurons 135 | param_shape = (num_neurons,) 136 | self.mr = nn.Parameter(torch.full(param_shape, float(multiplier_right))) 137 | self.er = nn.Parameter(torch.full(param_shape, float(exponent_right))) 138 | self.ml = nn.Parameter(torch.full(param_shape, float(multiplier_left))) 139 | self.el = nn.Parameter(torch.full(param_shape, float(exponent_left))) 140 | 141 | def forward(self, x: torch.Tensor) -> torch.Tensor: 142 | if self.num_neurons == 1: 143 | return lex_func(x, self.mr, self.er, self.ml, self.el) 144 | else: 145 | return lex_func_per_neuron(x, self.mr, self.er, self.ml, self.el) 146 | 147 | 148 | class AlternatedLeftReLU(nn.Module): 149 | """ ALL-ReLU activation function 150 | The alternating behavior of this function needs to be implemented by yourself at a higher level of abstraction 151 | (by giving each layer -alpha or alpha as the slope_left) 152 | Function that is then implemented is: 153 | -alpha * x if x < 0 and layer_index % 2 == 0 154 | f(x) = { alpha * x if x < 0 and layer_index % 2 == 1 155 | x if x > 0 156 | The input layer (layer_index = 1) and output layer (layer_index = L) are excluded. 
157 | See the paper: Truly Sparse Neural Networks at Scale, by Curci, Mocanu & Pechenizkiy: 158 | https://arxiv.org/abs/2102.01732 159 | """ 160 | def __init__(self, slope_left: float): 161 | super().__init__() 162 | self.slope_left = slope_left 163 | 164 | def forward(self, x: torch.Tensor) -> torch.Tensor: 165 | negatives = (x < 0) 166 | output = x.clone() 167 | output[negatives] = self.slope_left * x[negatives] 168 | return output 169 | 170 | 171 | class FixedSReLU(nn.Module): 172 | def __init__(self, 173 | threshold_right: float = 0.4, 174 | slope_right: float = 0.2, 175 | threshold_left: float = -0.4, 176 | slope_left: float = 0.2 177 | ): 178 | super().__init__() 179 | self.tr = threshold_right 180 | self.ar = slope_right 181 | self.tl = threshold_left 182 | self.al = slope_left 183 | 184 | def forward(self, x: torch.Tensor) -> torch.Tensor: 185 | return srelu_func(x, self.tr, self.ar, self.tl, self.al) 186 | 187 | 188 | class SReLU(nn.Module): 189 | """ Activation function SReLU. 190 | num_neurons (int): 1 means per-layer shared SReLU, >1 means per-neuron SReLU. 191 | Although it takes an int as input, there are only two values legitimate: 192 | 1, or the number of channels(neurons) at input. Default: 1 193 | """ 194 | 195 | def __init__(self, 196 | threshold_right: float = 0.4, 197 | slope_right: float = 0.2, 198 | threshold_left: float = -0.4, 199 | slope_left: float = 0.2, 200 | num_neurons: int = 1, 201 | random_init=False # this is a new (optional) feature, by Bram 202 | ): 203 | super().__init__() 204 | self.num_neurons = num_neurons 205 | param_shape = (num_neurons,) 206 | self.tr = nn.Parameter(torch.full(param_shape, float(threshold_right))) 207 | self.ar = nn.Parameter(torch.full(param_shape, float(slope_right))) 208 | self.tl = nn.Parameter(torch.full(param_shape, float(threshold_left))) 209 | self.al = nn.Parameter(torch.full(param_shape, float(slope_left))) 210 | 211 | def forward(self, x: torch.Tensor) -> torch.Tensor: 212 | if self.num_neurons == 1: 213 | return srelu_func(x, self.tr, self.ar, self.tl, self.al) 214 | else: 215 | return srelu_func_per_neuron(x, self.tr, self.ar, self.tl, self.al) 216 | 217 | 218 | def setup_act_func_args(act_func_args, activation, num_hid_neurons): 219 | """ Sets up the arguments for the activation function. 
220 | :param act_func_args: list of floats 221 | :param activation: str, name of activation function 222 | :param num_hid_neurons: int, for example 256 223 | :return: act_args: a list of floats (parameters for the activation func) 224 | per_neuron: a dict with the number of neurons to share params with (1 or num_hid_neurons) 225 | """ 226 | act_args, per_neuron = [], {} 227 | act_funcs_with_per_neuron_option = ['srelu', 'lex'] 228 | act_funcs_with_params = ['srelu', 'fixedsrelu', 'elu', 'leakyrelu', 'nonlex', 'lex'] 229 | if activation in act_funcs_with_per_neuron_option and act_func_args[1]: 230 | per_neuron = {"num_neurons": num_hid_neurons} 231 | if act_func_args[0] is not None and activation in act_funcs_with_params: 232 | for arg in act_func_args[0]: 233 | act_args.append(arg) 234 | return act_args, per_neuron 235 | 236 | 237 | act_funcs = {'relu': nn.ReLU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid, 'elu': nn.ELU, 'leakyrelu': nn.LeakyReLU, 238 | 'symsqrt': SymSqrt, 'symsqrt1': SymSqrt1, 'nonlex': NonLEx, 'lex': LEx, 239 | 'fixedsrelu': FixedSReLU, 'srelu': SReLU, 'allrelu': AlternatedLeftReLU, 240 | 'swish': nn.SiLU, 'selu': nn.SELU, 241 | } 242 | 243 | 244 | def setup_activation_funcs_list(activation, act_func_args, num_hid_layers, num_hid_neurons): 245 | activation_funcs = nn.ModuleList() 246 | if activation == 'allrelu': 247 | for hid_layer in range(num_hid_layers): 248 | if hid_layer % 2 == 0: 249 | activation_funcs.append(act_funcs['allrelu'](-act_func_args[0][0])) 250 | else: 251 | activation_funcs.append(act_funcs['allrelu'](act_func_args[0][0])) 252 | else: 253 | act_args, per_neuron = setup_act_func_args(act_func_args, activation, num_hid_neurons) 254 | for hid_layer in range(num_hid_layers): 255 | activation_funcs.append(act_funcs[activation](*act_args, **per_neuron)) 256 | return activation_funcs 257 | 258 | 259 | if __name__ == '__main__': 260 | # to test whether gradients are computed correctly 261 | 262 | inpt = [[4., -0.4, 1.5, -100], 263 | [3., -1, 1.2, -30], 264 | [2, 5.3, 0.4, 0.01]] 265 | 266 | # act_func = SymSqrt() 267 | # act_func = SymSqrt1() 268 | # act_func = FixedSReLU(threshold_right=0.4, slope_right=0.2, threshold_left=-0.4, slope_left=0.2) 269 | # act_func = SReLU(threshold_right=0.4, slope_right=0.2, threshold_left=-0.4, slope_left=0.2, num_neurons=len(inpt[0])) 270 | # act_func = NonLEx(multiplier_right=1, exponent_right=1, multiplier_left=1, exponent_left=0.5) 271 | # act_func = LEx(multiplier_right=1, exponent_right=1, multiplier_left=1, exponent_left=0.5, num_neurons=len(inpt[0])) 272 | act_func = AlternatedLeftReLU(slope_left=-0.6) 273 | 274 | x = torch.tensor(inpt, requires_grad=True) 275 | out = act_func.forward(x) 276 | loss = torch.sum(out) 277 | loss.backward() 278 | 279 | print(x.grad) 280 | 281 | # for SReLU 282 | # print(act_func.tr.grad) # 1 - ar 283 | # print(act_func.ar.grad) # x - tr 284 | # print(act_func.tl.grad) # 1 - al 285 | # print(act_func.al.grad) # x - tl 286 | 287 | # for LEx 288 | # print(act_func.mr.grad) 289 | # print(act_func.er.grad) 290 | # print(act_func.ml.grad) 291 | # print(act_func.el.grad) 292 | -------------------------------------------------------------------------------- /utils/core.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import copy 4 | from utils.activations import act_funcs, setup_act_func_args 5 | 6 | 7 | class SparseBaseAgent: 8 | """ Sparse Base Agent class to inherit """ 9 | def __init__(self): 10 | self.actor = None 
11 | self.actor_target = None 12 | self.actor_optimizer = None 13 | self.critic = None 14 | self.critic_target = None 15 | self.critic_optimizer = None 16 | self.device = None 17 | self.total_it = None 18 | self.prev_permutations = None 19 | 20 | def save(self, filename): 21 | checkpoint = { 22 | "iteration": self.total_it, 23 | "actor": self.actor.state_dict(), 24 | "actor_target": self.actor_target.state_dict(), 25 | "actor_optimizer": self.actor_optimizer.state_dict(), 26 | "actor_masks": self.actor.masks, 27 | "actor_torch_masks": self.actor.torch_masks, 28 | "critic": self.critic.state_dict(), 29 | "critic_target": self.critic_target.state_dict(), 30 | "critic_optimizer": self.critic_optimizer.state_dict(), 31 | "critic_q1_masks": self.critic.q1_masks, 32 | "critic_q2_masks": self.critic.q2_masks, 33 | "critic_q1_torch_masks": self.critic.q1_torch_masks, 34 | "critic_q2_torch_masks": self.critic.q2_torch_masks, 35 | "prev_permutations": self.prev_permutations, 36 | } 37 | # torch.save(checkpoint, f"{filename}_iter_{self.total_it}") 38 | torch.save(checkpoint, filename) 39 | print(f"Saved current model in: {filename}") 40 | 41 | def load(self, filename, load_device=None): 42 | if load_device is None: 43 | load_device = self.device 44 | loaded_checkpoint = torch.load(filename, map_location=load_device) 45 | self.total_it = loaded_checkpoint["iteration"] 46 | self.actor.load_state_dict(loaded_checkpoint["actor"]) 47 | self.actor_target.load_state_dict(loaded_checkpoint["actor_target"]) 48 | self.actor_optimizer.load_state_dict(loaded_checkpoint["actor_optimizer"]) 49 | self.actor.masks = loaded_checkpoint["actor_masks"] 50 | self.actor.torch_masks = loaded_checkpoint["actor_torch_masks"] 51 | self.critic.load_state_dict(loaded_checkpoint["critic"]) 52 | self.critic_target.load_state_dict(loaded_checkpoint["critic_target"]) 53 | self.critic_optimizer.load_state_dict(loaded_checkpoint["critic_optimizer"]) 54 | self.critic.q1_masks = loaded_checkpoint["critic_q1_masks"] 55 | self.critic.q2_masks = loaded_checkpoint["critic_q2_masks"] 56 | self.critic.q1_torch_masks = loaded_checkpoint["critic_q1_torch_masks"] 57 | self.critic.q2_torch_masks = loaded_checkpoint["critic_q2_torch_masks"] 58 | self.prev_permutations = loaded_checkpoint.get("prev_permutations") 59 | print(f"Loaded model from: {filename}") 60 | 61 | 62 | class BaseAgent: 63 | """ Base Agent class to inherit """ 64 | def __init__(self): 65 | self.actor = None 66 | self.actor_target = None 67 | self.actor_optimizer = None 68 | self.critic = None 69 | self.critic_target = None 70 | self.critic_optimizer = None 71 | self.device = None 72 | self.total_it = None 73 | self.prev_permutations = None 74 | 75 | def save(self, filename): 76 | checkpoint = { 77 | "iteration": self.total_it, 78 | "actor": self.actor.state_dict(), 79 | "actor_target": self.actor_target.state_dict(), 80 | "actor_optimizer": self.actor_optimizer.state_dict(), 81 | "critic": self.critic.state_dict(), 82 | "critic_target": self.critic_target.state_dict(), 83 | "critic_optimizer": self.critic_optimizer.state_dict(), 84 | "prev_permutations": self.prev_permutations, 85 | } 86 | # torch.save(checkpoint, f"{filename}_iter_{self.total_it}") 87 | torch.save(checkpoint, filename) 88 | print(f"Saved current model in: {filename}") 89 | 90 | def load(self, filename, load_device=None): 91 | if load_device is None: 92 | load_device = self.device 93 | loaded_checkpoint = torch.load(filename, map_location=load_device) 94 | self.total_it = loaded_checkpoint["iteration"] 95 | 
self.actor.load_state_dict(loaded_checkpoint["actor"]) 96 | self.actor_target.load_state_dict(loaded_checkpoint["actor_target"]) 97 | self.actor_optimizer.load_state_dict(loaded_checkpoint["actor_optimizer"]) 98 | self.critic.load_state_dict(loaded_checkpoint["critic"]) 99 | self.critic_target.load_state_dict(loaded_checkpoint["critic_target"]) 100 | self.critic_optimizer.load_state_dict(loaded_checkpoint["critic_optimizer"]) 101 | self.prev_permutations = loaded_checkpoint.get("prev_permutations") 102 | print(f"Loaded model from: {filename}") 103 | 104 | 105 | class Agent: 106 | """ Old version of the base agent. Didn't save the target networks. """ 107 | def __init__(self): 108 | self.actor = None 109 | self.critic = None 110 | self.actor_target = None 111 | self.critic_target = None 112 | self.actor_optimizer = None 113 | self.critic_optimizer = None 114 | self.device = None 115 | self.total_it = None 116 | 117 | def save(self, filename): 118 | checkpoint = { 119 | "iteration": self.total_it, 120 | "critic": self.critic.state_dict(), 121 | "critic_optimizer": self.critic_optimizer.state_dict(), 122 | "actor": self.actor.state_dict(), 123 | "actor_optimizer": self.actor_optimizer.state_dict(), 124 | } 125 | # torch.save(checkpoint, f"{filename}_iter_{self.total_it}") 126 | torch.save(checkpoint, filename) 127 | print(f"Saved current model in: {filename}") 128 | 129 | def load(self, filename, load_device=None): 130 | if load_device is None: 131 | load_device = self.device 132 | loaded_checkpoint = torch.load(filename, map_location=load_device) 133 | self.total_it = loaded_checkpoint["iteration"] 134 | self.critic.load_state_dict(loaded_checkpoint["critic"]) 135 | self.critic_optimizer.load_state_dict(loaded_checkpoint["critic_optimizer"]) 136 | self.critic_target = copy.deepcopy(self.critic) 137 | self.actor.load_state_dict(loaded_checkpoint["actor"]) 138 | self.actor_optimizer.load_state_dict(loaded_checkpoint["actor_optimizer"]) 139 | self.actor_target = copy.deepcopy(self.actor) 140 | print(f"Loaded model from: {filename}") 141 | 142 | 143 | 144 | 145 | -------------------------------------------------------------------------------- /utils/core_anf_sac.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.distributions import Normal 5 | import utils.sparse_utils as sp 6 | from utils import utils 7 | from utils.activations import setup_activation_funcs_list 8 | 9 | LOG_SIG_MAX = 2 10 | LOG_SIG_MIN = -20 11 | epsilon = 1e-6 12 | 13 | 14 | # Initialize Policy weights 15 | def weights_init_(m): 16 | if isinstance(m, nn.Linear): 17 | torch.nn.init.xavier_uniform_(m.weight, gain=1) 18 | torch.nn.init.constant_(m.bias, 0) 19 | 20 | 21 | class QNetwork(nn.Module): 22 | def __init__(self, state_dim, action_dim, args, dim_state_with_fake, device): 23 | super(QNetwork, self).__init__() 24 | hidden_dim = args.num_hid_neurons 25 | self.device = device 26 | self.num_fake_features = dim_state_with_fake - state_dim 27 | self.permutation = None 28 | self.fake_noise_std = args.fake_noise_std 29 | self.fake_noise_generator = utils.setup_noise_generator(args.load_noise_distribution) 30 | 31 | sparsities = sp.compute_sparsity_per_layer( 32 | global_sparsity=args.global_sparsity, 33 | neuron_layers=[dim_state_with_fake + action_dim, hidden_dim, hidden_dim, 1], 34 | keep_dense=[(args.input_layer_sparsity == 0), False, True], # sparse output layer not implemented yet 35 | 
method=args.sparsity_distribution_method, 36 | input_layer_sparsity=args.input_layer_sparsity) 37 | self.dense_layers = [True if sparsity == 0 else False for sparsity in sparsities] 38 | 39 | # Q1 architecture 40 | self.linear1 = nn.Linear(dim_state_with_fake + action_dim, hidden_dim) 41 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 42 | self.linear3 = nn.Linear(hidden_dim, 1) 43 | 44 | # Q2 architecture 45 | self.linear4 = nn.Linear(dim_state_with_fake + action_dim, hidden_dim) 46 | self.linear5 = nn.Linear(hidden_dim, hidden_dim) 47 | self.linear6 = nn.Linear(hidden_dim, 1) 48 | 49 | self.apply(weights_init_) 50 | 51 | activation = args.activation 52 | act_func_args = (args.act_func_args, args.act_func_per_neuron) 53 | self.activation_funcs = setup_activation_funcs_list(activation, act_func_args, args.num_hid_layers, args.num_hid_neurons) 54 | 55 | if not self.dense_layers[0]: 56 | self.noPar1, self.mask1 = sp.initialize_mask( 57 | 'critic Q1 first layer', sparsities[0], dim_state_with_fake + action_dim, hidden_dim) 58 | self.torchMask1 = torch.from_numpy(self.mask1).float().to(device) 59 | self.linear1.weight.data.mul_(torch.from_numpy(self.mask1).float()) 60 | 61 | if not self.dense_layers[1]: 62 | self.noPar2, self.mask2 = sp.initialize_mask( 63 | 'critic Q1 second layer', sparsities[1], hidden_dim, hidden_dim) 64 | self.torchMask2 = torch.from_numpy(self.mask2).float().to(device) 65 | self.linear2.weight.data.mul_(torch.from_numpy(self.mask2).float()) 66 | 67 | if not self.dense_layers[0]: 68 | self.noPar4, self.mask4 = sp.initialize_mask( 69 | 'critic Q2 first layer', sparsities[0], dim_state_with_fake + action_dim, hidden_dim) 70 | self.torchMask4 = torch.from_numpy(self.mask4).float().to(device) 71 | self.linear4.weight.data.mul_(torch.from_numpy(self.mask4).float()) 72 | 73 | if not self.dense_layers[1]: 74 | self.noPar5, self.mask5 = sp.initialize_mask( 75 | 'critic Q2 second layer', sparsities[1], hidden_dim, hidden_dim) 76 | self.torchMask5 = torch.from_numpy(self.mask5).float().to(device) 77 | self.linear5.weight.data.mul_(torch.from_numpy(self.mask5).float()) 78 | 79 | def forward(self, state, action): 80 | state = utils.add_fake_features(state, self.num_fake_features, self.device, 81 | self.fake_noise_std, self.fake_noise_generator) 82 | state = utils.permute_features(state, self.permutation) 83 | xu = torch.cat([state, action], 1) 84 | 85 | x1 = self.activation_funcs[0](self.linear1(xu)) 86 | x1 = self.activation_funcs[1](self.linear2(x1)) 87 | x1 = self.linear3(x1) 88 | 89 | x2 = self.activation_funcs[0](self.linear4(xu)) 90 | x2 = self.activation_funcs[1](self.linear5(x2)) 91 | x2 = self.linear6(x2) 92 | 93 | return x1, x2 94 | 95 | def set_new_permutation(self, permutation): 96 | self.permutation = permutation 97 | 98 | 99 | class GaussianPolicy(nn.Module): 100 | def __init__(self, state_dim, action_dim, args, dim_state_with_fake, device, action_space=None): 101 | super(GaussianPolicy, self).__init__() 102 | hidden_dim = args.num_hid_neurons 103 | self.device = device 104 | self.num_fake_features = dim_state_with_fake - state_dim 105 | self.permutation = None 106 | self.fake_noise_std = args.fake_noise_std 107 | self.fake_noise_generator = utils.setup_noise_generator(args.load_noise_distribution) 108 | 109 | sparsities = sp.compute_sparsity_per_layer( 110 | global_sparsity=args.global_sparsity, 111 | neuron_layers=[dim_state_with_fake, hidden_dim, hidden_dim, 2 * action_dim], # *2 for two heads: mean and log_std 112 | keep_dense=[(args.input_layer_sparsity == 0), 
False, True], # sparse output layer not implemented yet 113 | method=args.sparsity_distribution_method, 114 | input_layer_sparsity=args.input_layer_sparsity) 115 | self.dense_layers = [True if sparsity == 0 else False for sparsity in sparsities] 116 | 117 | self.linear1 = nn.Linear(dim_state_with_fake, hidden_dim) 118 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 119 | self.mean_linear = nn.Linear(hidden_dim, action_dim) 120 | self.log_std_linear = nn.Linear(hidden_dim, action_dim) 121 | 122 | self.apply(weights_init_) 123 | 124 | activation = args.activation 125 | act_func_args = (args.act_func_args, args.act_func_per_neuron) 126 | self.activation_funcs = setup_activation_funcs_list(activation, act_func_args, args.num_hid_layers, args.num_hid_neurons) 127 | 128 | if not self.dense_layers[0]: 129 | self.noPar1, self.mask1 = sp.initialize_mask( 130 | 'Gaussian actor input layer', sparsities[0], dim_state_with_fake, hidden_dim) 131 | self.torchMask1 = torch.from_numpy(self.mask1).float().to(device) 132 | self.linear1.weight.data.mul_(torch.from_numpy(self.mask1).float()) 133 | 134 | if not self.dense_layers[1]: 135 | self.noPar2, self.mask2 = sp.initialize_mask( 136 | 'Gaussian actor hidden layer', sparsities[1], hidden_dim, hidden_dim) 137 | self.torchMask2 = torch.from_numpy(self.mask2).float().to(device) 138 | self.linear2.weight.data.mul_(torch.from_numpy(self.mask2).float()) 139 | 140 | # action rescaling 141 | if action_space is None: 142 | self.action_scale = torch.tensor(1.) 143 | self.action_bias = torch.tensor(0.) 144 | else: 145 | self.action_scale = torch.FloatTensor( 146 | (action_space.high - action_space.low) / 2.) 147 | self.action_bias = torch.FloatTensor( 148 | (action_space.high + action_space.low) / 2.) 149 | 150 | def forward(self, state): 151 | state = utils.add_fake_features(state, self.num_fake_features, self.device, 152 | self.fake_noise_std, self.fake_noise_generator) 153 | state = utils.permute_features(state, self.permutation) 154 | x = self.activation_funcs[0](self.linear1(state)) 155 | x = self.activation_funcs[1](self.linear2(x)) 156 | mean = self.mean_linear(x) 157 | log_std = self.log_std_linear(x) 158 | log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX) 159 | return mean, log_std 160 | 161 | def sample(self, state): 162 | mean, log_std = self.forward(state) 163 | std = log_std.exp() 164 | normal = Normal(mean, std) 165 | x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) 166 | y_t = torch.tanh(x_t) 167 | action = y_t * self.action_scale + self.action_bias 168 | log_prob = normal.log_prob(x_t) 169 | # Enforcing Action Bound 170 | log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon) 171 | log_prob = log_prob.sum(1, keepdim=True) 172 | mean = torch.tanh(mean) * self.action_scale + self.action_bias 173 | return action, log_prob, mean 174 | 175 | def to(self, device): 176 | self.action_scale = self.action_scale.to(device) 177 | self.action_bias = self.action_bias.to(device) 178 | return super(GaussianPolicy, self).to(device) 179 | 180 | def set_new_permutation(self, permutation): 181 | self.permutation = permutation 182 | 183 | 184 | class DeterministicPolicy(nn.Module): 185 | def __init__(self, state_dim, action_dim, args, dim_state_with_fake, device, action_space=None): 186 | super(DeterministicPolicy, self).__init__() 187 | hidden_dim = args.num_hid_neurons 188 | self.device = device 189 | self.num_fake_features = dim_state_with_fake - state_dim 190 | self.permutation = None 191 | self.fake_noise_std 
= args.fake_noise_std 192 | self.fake_noise_generator = utils.setup_noise_generator(args.load_noise_distribution) 193 | 194 | sparsities = sp.compute_sparsity_per_layer( 195 | global_sparsity=args.global_sparsity, 196 | neuron_layers=[dim_state_with_fake, hidden_dim, hidden_dim, action_dim], 197 | keep_dense=[(args.input_layer_sparsity == 0), False, True], # sparse output layer not implemented yet 198 | method=args.sparsity_distribution_method, 199 | input_layer_sparsity=args.input_layer_sparsity) 200 | self.dense_layers = [True if sparsity == 0 else False for sparsity in sparsities] 201 | 202 | self.linear1 = nn.Linear(dim_state_with_fake, hidden_dim) 203 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 204 | self.mean = nn.Linear(hidden_dim, action_dim) 205 | self.noise = torch.Tensor(action_dim) 206 | 207 | self.apply(weights_init_) 208 | 209 | activation = args.activation 210 | act_func_args = (args.act_func_args, args.act_func_per_neuron) 211 | self.activation_funcs = setup_activation_funcs_list(activation, act_func_args, args.num_hid_layers, args.num_hid_neurons) 212 | 213 | if not self.dense_layers[0]: 214 | self.noPar1, self.mask1 = sp.initialize_mask( 215 | 'Gaussian actor input layer', sparsities[0], dim_state_with_fake, hidden_dim) 216 | self.torchMask1 = torch.from_numpy(self.mask1).float().to(device) 217 | self.linear1.weight.data.mul_(torch.from_numpy(self.mask1).float()) 218 | 219 | if not self.dense_layers[1]: 220 | self.noPar2, self.mask2 = sp.initialize_mask( 221 | 'Gaussian actor hidden layer', sparsities[1], hidden_dim, hidden_dim) 222 | self.torchMask2 = torch.from_numpy(self.mask2).float().to(device) 223 | self.linear2.weight.data.mul_(torch.from_numpy(self.mask2).float()) 224 | 225 | # action rescaling 226 | if action_space is None: 227 | self.action_scale = 1. 228 | self.action_bias = 0. 229 | else: 230 | self.action_scale = torch.FloatTensor( 231 | (action_space.high - action_space.low) / 2.) 232 | self.action_bias = torch.FloatTensor( 233 | (action_space.high + action_space.low) / 2.) 
234 | 235 | def forward(self, state): 236 | state = utils.add_fake_features(state, self.num_fake_features, self.device, 237 | self.fake_noise_std, self.fake_noise_generator) 238 | state = utils.permute_features(state, self.permutation) 239 | x = self.activation_funcs[0](self.linear1(state)) 240 | x = self.activation_funcs[1](self.linear2(x)) 241 | mean = torch.tanh(self.mean(x)) * self.action_scale + self.action_bias 242 | return mean 243 | 244 | def sample(self, state): 245 | mean = self.forward(state) 246 | noise = self.noise.normal_(0., std=0.1) 247 | noise = noise.clamp(-0.25, 0.25) 248 | action = mean + noise 249 | return action, torch.tensor(0.), mean 250 | 251 | def to(self, device): 252 | self.action_scale = self.action_scale.to(device) 253 | self.action_bias = self.action_bias.to(device) 254 | self.noise = self.noise.to(device) 255 | return super(DeterministicPolicy, self).to(device) 256 | 257 | def set_new_permutation(self, permutation): 258 | self.permutation = permutation 259 | 260 | -------------------------------------------------------------------------------- /utils/core_sac.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.distributions import Normal 5 | from utils import utils 6 | from utils.activations import setup_activation_funcs_list 7 | # was called model.py in the original code 8 | 9 | LOG_SIG_MAX = 2 10 | LOG_SIG_MIN = -20 11 | epsilon = 1e-6 12 | 13 | 14 | # Initialize Policy weights 15 | def weights_init_(m): 16 | if isinstance(m, nn.Linear): 17 | torch.nn.init.xavier_uniform_(m.weight, gain=1) 18 | torch.nn.init.constant_(m.bias, 0) 19 | 20 | 21 | # class ValueNetwork(nn.Module): 22 | # def __init__(self, state_dim, hidden_dim): 23 | # super(ValueNetwork, self).__init__() 24 | # 25 | # self.linear1 = nn.Linear(state_dim, hidden_dim) 26 | # self.linear2 = nn.Linear(hidden_dim, hidden_dim) 27 | # self.linear3 = nn.Linear(hidden_dim, 1) 28 | # 29 | # self.apply(weights_init_) 30 | # 31 | # def forward(self, state): 32 | # x = F.relu(self.linear1(state)) 33 | # x = F.relu(self.linear2(x)) 34 | # x = self.linear3(x) 35 | # return x 36 | 37 | 38 | class QNetwork(nn.Module): 39 | def __init__(self, state_dim, num_actions, args, dim_state_with_fake, device): 40 | super(QNetwork, self).__init__() 41 | hidden_dim = args.num_hid_neurons 42 | self.device = device 43 | self.num_fake_features = dim_state_with_fake - state_dim 44 | self.permutation = None 45 | self.fake_noise_std = args.fake_noise_std 46 | self.fake_noise_generator = utils.setup_noise_generator(args.load_noise_distribution) 47 | 48 | # Q1 architecture 49 | self.linear1 = nn.Linear(dim_state_with_fake + num_actions, hidden_dim) 50 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 51 | self.linear3 = nn.Linear(hidden_dim, 1) 52 | 53 | # Q2 architecture 54 | self.linear4 = nn.Linear(dim_state_with_fake + num_actions, hidden_dim) 55 | self.linear5 = nn.Linear(hidden_dim, hidden_dim) 56 | self.linear6 = nn.Linear(hidden_dim, 1) 57 | 58 | self.apply(weights_init_) 59 | 60 | activation = args.activation 61 | act_func_args = (args.act_func_args, args.act_func_per_neuron) 62 | self.activation_funcs = setup_activation_funcs_list(activation, act_func_args, args.num_hid_layers, args.num_hid_neurons) 63 | 64 | def forward(self, state, action): 65 | state = utils.add_fake_features(state, self.num_fake_features, self.device, 66 | self.fake_noise_std, self.fake_noise_generator) 67 | state = 
utils.permute_features(state, self.permutation) 68 | xu = torch.cat([state, action], 1) 69 | 70 | x1 = self.activation_funcs[0](self.linear1(xu)) 71 | x1 = self.activation_funcs[1](self.linear2(x1)) 72 | x1 = self.linear3(x1) 73 | 74 | x2 = self.activation_funcs[0](self.linear4(xu)) 75 | x2 = self.activation_funcs[1](self.linear5(x2)) 76 | x2 = self.linear6(x2) 77 | 78 | return x1, x2 79 | 80 | def set_new_permutation(self, permutation): 81 | self.permutation = permutation 82 | 83 | 84 | class GaussianPolicy(nn.Module): 85 | def __init__(self, state_dim, num_actions, args, dim_state_with_fake, device, action_space=None): 86 | super(GaussianPolicy, self).__init__() 87 | hidden_dim = args.num_hid_neurons 88 | self.device = device 89 | self.num_fake_features = dim_state_with_fake - state_dim 90 | self.permutation = None 91 | self.fake_noise_std = args.fake_noise_std 92 | self.fake_noise_generator = utils.setup_noise_generator(args.load_noise_distribution) 93 | 94 | self.linear1 = nn.Linear(dim_state_with_fake, hidden_dim) 95 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 96 | 97 | self.mean_linear = nn.Linear(hidden_dim, num_actions) 98 | self.log_std_linear = nn.Linear(hidden_dim, num_actions) 99 | 100 | self.apply(weights_init_) 101 | 102 | activation = args.activation 103 | act_func_args = (args.act_func_args, args.act_func_per_neuron) 104 | self.activation_funcs = setup_activation_funcs_list(activation, act_func_args, args.num_hid_layers, args.num_hid_neurons) 105 | 106 | # action rescaling 107 | if action_space is None: 108 | self.action_scale = torch.tensor(1.) 109 | self.action_bias = torch.tensor(0.) 110 | else: 111 | self.action_scale = torch.FloatTensor( 112 | (action_space.high - action_space.low) / 2.) 113 | self.action_bias = torch.FloatTensor( 114 | (action_space.high + action_space.low) / 2.) 
115 | 116 | def forward(self, state): 117 | state = utils.add_fake_features(state, self.num_fake_features, self.device, 118 | self.fake_noise_std, self.fake_noise_generator) 119 | state = utils.permute_features(state, self.permutation) 120 | x = self.activation_funcs[0](self.linear1(state)) 121 | x = self.activation_funcs[1](self.linear2(x)) 122 | mean = self.mean_linear(x) 123 | log_std = self.log_std_linear(x) 124 | log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX) 125 | return mean, log_std 126 | 127 | def sample(self, state): 128 | mean, log_std = self.forward(state) 129 | std = log_std.exp() 130 | normal = Normal(mean, std) 131 | x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) 132 | y_t = torch.tanh(x_t) 133 | action = y_t * self.action_scale + self.action_bias 134 | log_prob = normal.log_prob(x_t) 135 | # Enforcing Action Bound 136 | log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon) 137 | log_prob = log_prob.sum(1, keepdim=True) 138 | mean = torch.tanh(mean) * self.action_scale + self.action_bias 139 | return action, log_prob, mean 140 | 141 | def to(self, device): 142 | self.action_scale = self.action_scale.to(device) 143 | self.action_bias = self.action_bias.to(device) 144 | return super(GaussianPolicy, self).to(device) 145 | 146 | def set_new_permutation(self, permutation): 147 | self.permutation = permutation 148 | 149 | 150 | class DeterministicPolicy(nn.Module): 151 | def __init__(self, state_dim, num_actions, args, dim_state_with_fake, device, action_space=None): 152 | super(DeterministicPolicy, self).__init__() 153 | hidden_dim = args.num_hid_neurons 154 | self.device = device 155 | self.num_fake_features = dim_state_with_fake - state_dim 156 | self.permutation = None 157 | self.fake_noise_std = args.fake_noise_std 158 | self.fake_noise_generator = utils.setup_noise_generator(args.load_noise_distribution) 159 | 160 | self.linear1 = nn.Linear(dim_state_with_fake, hidden_dim) 161 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 162 | 163 | self.mean = nn.Linear(hidden_dim, num_actions) 164 | self.noise = torch.Tensor(num_actions) 165 | 166 | self.apply(weights_init_) 167 | 168 | activation = args.activation 169 | act_func_args = (args.act_func_args, args.act_func_per_neuron) 170 | self.activation_funcs = setup_activation_funcs_list(activation, act_func_args, args.num_hid_layers, args.num_hid_neurons) 171 | 172 | # action rescaling 173 | if action_space is None: 174 | self.action_scale = 1. 175 | self.action_bias = 0. 176 | else: 177 | self.action_scale = torch.FloatTensor( 178 | (action_space.high - action_space.low) / 2.) 179 | self.action_bias = torch.FloatTensor( 180 | (action_space.high + action_space.low) / 2.) 
181 | 182 | def forward(self, state): 183 | state = utils.add_fake_features(state, self.num_fake_features, self.device, 184 | self.fake_noise_std, self.fake_noise_generator) 185 | state = utils.permute_features(state, self.permutation) 186 | x = self.activation_funcs[0](self.linear1(state)) 187 | x = self.activation_funcs[1](self.linear2(x)) 188 | mean = torch.tanh(self.mean(x)) * self.action_scale + self.action_bias 189 | return mean 190 | 191 | def sample(self, state): 192 | mean = self.forward(state) 193 | noise = self.noise.normal_(0., std=0.1) 194 | noise = noise.clamp(-0.25, 0.25) 195 | action = mean + noise 196 | return action, torch.tensor(0.), mean 197 | 198 | def to(self, device): 199 | self.action_scale = self.action_scale.to(device) 200 | self.action_bias = self.action_bias.to(device) 201 | self.noise = self.noise.to(device) 202 | return super(DeterministicPolicy, self).to(device) 203 | 204 | def set_new_permutation(self, permutation): 205 | self.permutation = permutation 206 | 207 | -------------------------------------------------------------------------------- /utils/load_feats_distr.py: -------------------------------------------------------------------------------- 1 | """ 2 | To load and sample from the distribution of real features 3 | """ 4 | import numpy as np 5 | import torch 6 | 7 | 8 | class RealFeatureDistribution: 9 | """ A class to sample from the distribution of real features """ 10 | def __init__(self, path): 11 | self.path = path 12 | self.probs, self.bin_edges = self.load_distribution(path) 13 | self.num_dims = len(self.probs) 14 | 15 | def load_distribution(self, path): 16 | """ Load the distribution of real features from a file """ 17 | distr_dict = np.load(path, allow_pickle=True) 18 | probs = distr_dict.item().get('probs') 19 | bin_edges = distr_dict.item().get('bin_edges') 20 | return probs, bin_edges 21 | 22 | def sample_one_feat_dim(self, probs, bin_edges, batch_size=1): 23 | """ Sample one number from the distribution of a specific feature dimension 24 | :arg probs: the probabilities of the bins, for the feature dimension that you want to sample from 25 | :arg bin_edges: the bin edges of the bins, for the feature dimension that you want to sample from 26 | """ 27 | bin_idx = np.random.choice(len(probs), size=batch_size, p=probs) 28 | edge_left = bin_edges[bin_idx] 29 | edge_right = bin_edges[bin_idx+1] 30 | return np.random.uniform(edge_left, edge_right) 31 | 32 | def sample(self, num_feats, batch_size=1): 33 | """ Sample n features from the distribution 34 | start from feat_dim 0, go towards the last real feat_dim |state_space|-1, 35 | then keep cycling through the real feat_dims until you have n samples 36 | :return an 2D array (torch tensor) of sampled features, size (batch_size, num_feats) 37 | """ 38 | generated_feats = np.empty((batch_size, num_feats)) 39 | for i in range(num_feats): 40 | feat_dim = i % self.num_dims 41 | feat = self.sample_one_feat_dim(self.probs[feat_dim], self.bin_edges[feat_dim], batch_size) 42 | generated_feats[:, i] = feat.transpose() 43 | return torch.from_numpy(generated_feats).to(torch.float32) 44 | 45 | 46 | if __name__ == '__main__': 47 | feats_distr_folder = "../experiments/plots_fake_feats/real_feats_distributions/" 48 | feats_distr_file = "real_feats_distr_HalfCheetah.npy" 49 | feats_sampler = RealFeatureDistribution(path=f'{feats_distr_folder}{feats_distr_file}') 50 | sampled_feats = feats_sampler.sample(5, 3) 51 | print(sampled_feats) 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 
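Side note: the `__main__` block above reads the distribution from an `experiments/` folder that you may not have checked out. The distribution file that ships with this repository lives at `utils/noise_distributions/real_feats_distr_HalfCheetah.npy`, so a quick way to try the sampler from the repository root could look like this sketch (the feature count and batch size are arbitrary examples):

```python
# Minimal sketch (run from the repository root): sample noise features from the
# real-feature distribution file included under utils/noise_distributions/.
from utils.load_feats_distr import RealFeatureDistribution

sampler = RealFeatureDistribution(
    path="utils/noise_distributions/real_feats_distr_HalfCheetah.npy")
fake_feats = sampler.sample(num_feats=5, batch_size=3)  # tensor of shape (3, 5)
print(fake_feats)
```

This is presumably the same kind of sampler that `utils.setup_noise_generator(args.load_noise_distribution)` wires into the networks above when a distribution path is supplied.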
-------------------------------------------------------------------------------- /utils/mask_adam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer 3 | from torch import Tensor 4 | from typing import List 5 | 6 | 7 | class MaskAdam(Optimizer): 8 | r"""Implements the MaskAdam optimizer. 9 | Especially useful for Dynamic Sparse Training. 10 | The difference with regular Adam is that the gradients and its first and 11 | second (raw) moments are masked for non-existing connections. 12 | Args: 13 | params (iterable): iterable of parameters to optimize or dicts defining 14 | parameter groups 15 | lr (float, optional): learning rate (default: 1e-3) 16 | betas (Tuple[float, float], optional): coefficients used for computing 17 | running averages of gradient and its square (default: (0.9, 0.999)) 18 | eps (float, optional): term added to the denominator to improve 19 | numerical stability (default: 1e-8) 20 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 21 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 22 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 23 | (default: False) 24 | .. _Adam\: A Method for Stochastic Optimization: 25 | https://arxiv.org/abs/1412.6980 26 | .. _On the Convergence of Adam and Beyond: 27 | https://openreview.net/forum?id=ryQu7f-RZ 28 | """ 29 | 30 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 31 | weight_decay=0, amsgrad=False): 32 | if not 0.0 <= lr: 33 | raise ValueError("Invalid learning rate: {}".format(lr)) 34 | if not 0.0 <= eps: 35 | raise ValueError("Invalid epsilon value: {}".format(eps)) 36 | if not 0.0 <= betas[0] < 1.0: 37 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 38 | if not 0.0 <= betas[1] < 1.0: 39 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 40 | if not 0.0 <= weight_decay: 41 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 42 | defaults = dict(lr=lr, betas=betas, eps=eps, 43 | weight_decay=weight_decay, amsgrad=amsgrad) 44 | super().__init__(params, defaults) 45 | 46 | def __setstate__(self, state): 47 | super().__setstate__(state) 48 | for group in self.param_groups: 49 | group.setdefault('amsgrad', False) 50 | 51 | @torch.no_grad() 52 | def step(self, masks, closure=None): 53 | """Performs a single optimization step. 54 | Args: 55 | masks (list): List of masks (torch tensors) for each layer in the network. 56 | Should be of length equal to the number of connection-layers in the network, 57 | and should have element None if a layer is dense. 58 | closure (callable, optional): A closure that reevaluates the model 59 | and returns the loss. 60 | """ 61 | loss = None 62 | if closure is not None: 63 | with torch.enable_grad(): 64 | loss = closure() 65 | 66 | for group in self.param_groups: 67 | params_with_grad = [] 68 | grads = [] 69 | exp_avgs = [] 70 | exp_avg_sqs = [] 71 | max_exp_avg_sqs = [] 72 | state_steps = [] 73 | 74 | params = group['params'] 75 | beta1, beta2 = group['betas'] 76 | amsgrad = group['amsgrad'] 77 | 78 | for p in params: 79 | if p.grad is not None: 80 | params_with_grad.append(p) 81 | if p.grad.is_sparse: 82 | raise RuntimeError('MaskAdam does not support sparse gradients, ' 83 | 'which is something distinct from sparse connectivity. 
' 84 | 'Please consider SparseAdam instead if you have sparse gradients.') 85 | grads.append(p.grad) 86 | 87 | state = self.state[p] 88 | # Lazy state initialization 89 | if len(state) == 0: 90 | state['step'] = torch.tensor(0.) 91 | # Exponential moving average of gradient values 92 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 93 | # Exponential moving average of squared gradient values 94 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 95 | if amsgrad: 96 | # Maintains max of all exp. moving avg. of sq. grad. values 97 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 98 | 99 | state['step'].add_(1) 100 | 101 | # record the updates 102 | state_steps.append(state['step']) 103 | exp_avgs.append(state['exp_avg']) 104 | exp_avg_sqs.append(state['exp_avg_sq']) 105 | if amsgrad: 106 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 107 | 108 | self.mask_adam(params_with_grad, 109 | grads, 110 | exp_avgs, 111 | exp_avg_sqs, 112 | max_exp_avg_sqs, 113 | state_steps, 114 | masks, 115 | amsgrad=amsgrad, 116 | beta1=beta1, 117 | beta2=beta2, 118 | lr=group['lr'], 119 | weight_decay=group['weight_decay'], 120 | eps=group['eps'], 121 | ) 122 | return loss 123 | 124 | def mask_adam(self, 125 | params: List[Tensor], 126 | grads: List[Tensor], 127 | exp_avgs: List[Tensor], 128 | exp_avg_sqs: List[Tensor], 129 | max_exp_avg_sqs: List[Tensor], 130 | state_steps: List[Tensor], 131 | masks: List[Tensor or None], 132 | *, 133 | amsgrad: bool, 134 | beta1: float, 135 | beta2: float, 136 | lr: float, 137 | weight_decay: float, 138 | eps: float, 139 | ): 140 | """ 141 | Function that performs the MaskAdam optimizer computation. 142 | It uses masks to simulate sparsity. Some layers may not have masks. 143 | 144 | ASSUMES: all connection-layers (weight matrices) have gradients 145 | (they are in params_with_grad in step function above, and so are given in param argument here) 146 | """ 147 | layer_idx = 0 148 | for param_idx, param in enumerate(params): 149 | grad = grads[param_idx] 150 | exp_avg = exp_avgs[param_idx] 151 | exp_avg_sq = exp_avg_sqs[param_idx] 152 | step = state_steps[param_idx] 153 | 154 | bias_correction1 = 1 - torch.pow(beta1, step) 155 | bias_correction2 = 1 - torch.pow(beta2, step) 156 | 157 | if len(param.size()) == 2: 158 | mask = masks[layer_idx] 159 | layer_idx += 1 160 | if mask is not None: 161 | # this is a sparse layer, everything is masked (put to zero) for non-existing connections 162 | grad[mask == 0] = 0 163 | exp_avg[mask == 0] = 0 164 | exp_avg_sq[mask == 0] = 0 165 | if amsgrad: 166 | max_exp_avg_sqs[param_idx][mask == 0] = 0 167 | 168 | if weight_decay != 0: 169 | grad = grad.add(param, alpha=weight_decay) 170 | 171 | # Decay the first and second moment running average coefficient 172 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 173 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) 174 | if amsgrad: 175 | # Maintains the maximum of all 2nd moment running avg. till now 176 | torch.max(max_exp_avg_sqs[param_idx], exp_avg_sq, out=max_exp_avg_sqs[param_idx]) 177 | # Use the max. for normalizing running avg. 
of gradient 178 | denom = (max_exp_avg_sqs[param_idx] / bias_correction2).sqrt().add_(eps) 179 | else: 180 | denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps) 181 | 182 | step_size = lr / bias_correction1 183 | param -= step_size * exp_avg / denom 184 | -------------------------------------------------------------------------------- /utils/noise_distributions/real_feats_distr_HalfCheetah.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bramgrooten/automatic-noise-filtering/6d3bca27affb4036e6e0bd60ef5b9a4ccdfc6daa/utils/noise_distributions/real_feats_distr_HalfCheetah.npy -------------------------------------------------------------------------------- /utils/pretrained_models/ANF-SAC_HalfCheetah-v3_relu_sparsity0.0_uniform_inlayspars0.8_hid-lay2_maskadam_fakefeats0.9_seed3101_best: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bramgrooten/automatic-noise-filtering/6d3bca27affb4036e6e0bd60ef5b9a4ccdfc6daa/utils/pretrained_models/ANF-SAC_HalfCheetah-v3_relu_sparsity0.0_uniform_inlayspars0.8_hid-lay2_maskadam_fakefeats0.9_seed3101_best -------------------------------------------------------------------------------- /utils/replay_memory_sac.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import datetime 4 | 5 | 6 | class ReplayMemory: 7 | def __init__(self, capacity, seed): 8 | random.seed(seed) 9 | self.capacity = int(capacity) 10 | self.buffer = [] 11 | self.position = 0 12 | 13 | def push(self, state, action, reward, next_state, done): 14 | if len(self.buffer) < self.capacity: 15 | self.buffer.append(None) 16 | self.buffer[self.position] = (state, action, reward, next_state, done) 17 | self.position = int((self.position + 1) % self.capacity) 18 | 19 | def sample(self, batch_size): 20 | batch = random.sample(self.buffer, batch_size) 21 | state, action, reward, next_state, done = map(np.stack, zip(*batch)) 22 | return state, action, reward, next_state, done 23 | 24 | def __len__(self): 25 | return len(self.buffer) 26 | 27 | def empty_buffer(self): 28 | self.buffer = [] 29 | self.position = 0 30 | 31 | 32 | def fill_initial_replay_memory(replay_buffer, env, args): 33 | print("Filling initial replay memory...") 34 | state, done = env.reset(), False 35 | episode_reward, episode_timesteps, episode_num = 0, 0, 0 36 | episode_start_time = datetime.datetime.now() 37 | for t in range(int(args.start_timesteps)): 38 | episode_timesteps += 1 39 | action = env.action_space.sample() 40 | # Perform action 41 | next_state, reward, done, _ = env.step(action) 42 | 43 | # Ignore the "done" signal if it comes from hitting the time horizon. 
44 | # see https://github.com/openai/spinningup/blob/038665d62d569055401d91856abb287263096178/spinup/algos/pytorch/sac/sac.py#L304 45 | not_done = 1 if episode_timesteps == env._max_episode_steps else float(not done) 46 | replay_buffer.push(state, action, reward, next_state, not_done) # Append transition to memory 47 | 48 | state = next_state 49 | episode_reward += reward 50 | if done: 51 | if args.print_comments: 52 | print(f"Total T: {t + 1} Episode Num: {episode_num + 1} Episode T: {episode_timesteps} " 53 | f"Reward: {episode_reward:.3f} Time: {datetime.datetime.now() - episode_start_time}") 54 | # Reset environment 55 | state, done = env.reset(), False 56 | episode_reward = 0 57 | episode_timesteps = 0 58 | episode_num += 1 59 | episode_start_time = datetime.datetime.now() 60 | 61 | 62 | def refill_replay_buffer(replay_buffer, env, policy, args): 63 | print("Refilling replay buffer...") 64 | state, done = env.reset(), False 65 | episode_reward, episode_timesteps, episode_num = 0, 0, 0 66 | episode_start_time = datetime.datetime.now() 67 | 68 | for t in range(int(args.refill_timesteps)): 69 | episode_timesteps += 1 70 | if args.refill_mode == 'random': 71 | action = env.action_space.sample() 72 | elif args.refill_mode == 'current': 73 | action = policy.select_action(state) 74 | else: 75 | raise ValueError('Invalid refill_mode. Options: random, current.') 76 | # Perform action 77 | next_state, reward, done, _ = env.step(action) 78 | 79 | not_done = 1 if episode_timesteps == env._max_episode_steps else float(not done) 80 | replay_buffer.push(state, action, reward, next_state, not_done) # Append transition to memory 81 | 82 | state = next_state 83 | episode_reward += reward 84 | if done: 85 | if args.print_comments: 86 | print(f"Total T: {t + 1} Episode Num: {episode_num + 1} Episode T: {episode_timesteps} " 87 | f"Reward: {episode_reward:.3f} Time: {datetime.datetime.now() - episode_start_time}") 88 | episode_start_time = datetime.datetime.now() 89 | # Reset environment 90 | state, done = env.reset(), False 91 | episode_reward = 0 92 | episode_timesteps = 0 93 | episode_num += 1 94 | 95 | -------------------------------------------------------------------------------- /utils/sparse_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def initialize_mask(layer_name, layer_sparsity_level, num_in_neurons, num_out_neurons, print_info=True): 6 | """ New version of initializeSparsityLevelWeightMask 7 | that produces the exact sparsity level, instead of something closeby. 
""" 8 | total_connections = num_in_neurons * num_out_neurons 9 | num_connections = int((1 - layer_sparsity_level) * total_connections) 10 | mask = np.zeros((num_in_neurons, num_out_neurons)) 11 | mask = grow_set(mask, num_connections) 12 | # if print_info: 13 | # print(f"Sparsity Level Initialization {layer_name}: sparsity {1 - np.sum(mask)/total_connections:.6f}; " 14 | # f"num active connections {num_connections}; num_in_neurons {num_in_neurons}; " 15 | # f"num_out_neurons {num_out_neurons}; num_dense_connections {total_connections}") 16 | # print(f" OutDegreeBottomNeurons {np.mean(mask.sum(axis=1)):.2f} ± {np.std(mask.sum(axis=1)):.2f};" 17 | # f" InDegreeTopNeurons {np.mean(mask.sum(axis=0)):.2f} ± {np.std(mask.sum(axis=0)):.2f}") 18 | return num_connections, mask.transpose() 19 | 20 | 21 | def adjust_connectivity_set(weights, num_weights, zeta, mask): 22 | """ New version of changeConnectivitySET that uses a slightly different method: 23 | remove the zeta fraction of weights that are nearest to zero (so zeta smallest absolute values, 24 | instead of zeta largest negative and zeta smallest positive). 25 | """ 26 | # mask = prune_set(weights, zeta) 27 | mask = prune_set_only_active(weights, zeta, mask, num_weights) 28 | mask = grow_set(mask, num_weights) 29 | return mask 30 | 31 | 32 | def prune_set(weights, zeta): 33 | """ Remove zeta fraction of weights that are nearest to zero. """ 34 | abs_weights = np.abs(weights) 35 | abs_vector = np.sort(abs_weights.ravel()) 36 | first_nonzero = np.nonzero(abs_vector)[0][0] 37 | threshold = abs_vector[int(first_nonzero + (len(abs_vector) - first_nonzero) * zeta)] 38 | # wrong: there could also be active connections that still have a weight of 0. 39 | # Hmm, this should be fixed by grow_set though (more connections than zeta will grow back in that case) 40 | new_mask = np.zeros_like(weights) 41 | new_mask[abs_weights > threshold] = 1 42 | return new_mask 43 | 44 | 45 | def prune_set_only_active(weights, zeta, mask, num_weights): 46 | """ Remove zeta fraction of weights that are nearest to zero. 47 | Prune threshold now also based on existing connections that have value 0. """ 48 | active = np.where(mask == 1) 49 | abs_weights = np.abs(weights) 50 | 51 | abs_vector = np.sort(abs_weights[active].ravel()) 52 | assert len(abs_vector) == num_weights 53 | threshold = abs_vector[int(num_weights * zeta)] 54 | # assumes threshold is unique, otherwise more connections than zeta might be pruned. They will grow back. 55 | # prune_set assumes this as well. 56 | 57 | new_mask = np.zeros_like(weights) 58 | new_mask[abs_weights > threshold] = 1 59 | return new_mask 60 | 61 | 62 | def grow_set(mask, num_weights): 63 | """ Grow new connections according to SET: Choose randomly from the available options. 
""" 64 | num_to_grow = num_weights - np.sum(mask) 65 | idx_zeros_i, idx_zeros_j = np.where(mask == 0) 66 | new_conn_idx = np.random.choice(idx_zeros_i.shape[0], int(num_to_grow), replace=False) 67 | mask[idx_zeros_i[new_conn_idx], idx_zeros_j[new_conn_idx]] = 1 68 | return mask 69 | 70 | 71 | def critic_give_new_connections_init_values(critic, q1_old_masks, q2_old_masks, init_new_weights_method, device): 72 | for layer in range(critic.num_hid_layers + 1): 73 | if layer == 0: 74 | layer_give_new_connections_init_values( 75 | critic.q1_input_layer.weight, critic.q1_masks[0], q1_old_masks[0], init_new_weights_method, device) 76 | layer_give_new_connections_init_values( 77 | critic.q2_input_layer.weight, critic.q2_masks[0], q2_old_masks[0], init_new_weights_method, device) 78 | elif layer == critic.num_hid_layers: 79 | layer_give_new_connections_init_values( 80 | critic.q1_output_layer.weight, critic.q1_masks[layer], q1_old_masks[layer], init_new_weights_method, device) 81 | layer_give_new_connections_init_values( 82 | critic.q2_output_layer.weight, critic.q2_masks[layer], q2_old_masks[layer], init_new_weights_method, device) 83 | else: 84 | layer_give_new_connections_init_values( 85 | critic.q1_hid_layers[layer-1].weight, critic.q1_masks[layer], q1_old_masks[layer], init_new_weights_method, device) 86 | layer_give_new_connections_init_values( 87 | critic.q2_hid_layers[layer-1].weight, critic.q2_masks[layer], q2_old_masks[layer], init_new_weights_method, device) 88 | 89 | 90 | def actor_give_new_connections_init_values(actor, old_masks, init_new_weights_method, device): 91 | for layer in range(actor.num_hid_layers + 1): 92 | if layer == 0: 93 | layer_give_new_connections_init_values( 94 | actor.input_layer.weight, actor.masks[0], old_masks[0], init_new_weights_method, device) 95 | elif layer == actor.num_hid_layers: 96 | layer_give_new_connections_init_values( 97 | actor.output_layer.weight, actor.masks[layer], old_masks[layer], init_new_weights_method, device) 98 | else: 99 | layer_give_new_connections_init_values( 100 | actor.hid_layers[layer-1].weight, actor.masks[layer], old_masks[layer], init_new_weights_method, device) 101 | 102 | 103 | def layer_give_new_connections_init_values(layer_weights, new_mask, old_mask, init_new_weights_method, device): 104 | if init_new_weights_method == "unif": 105 | reinit_values_unif(layer_weights, new_mask, old_mask, device) 106 | elif init_new_weights_method == "xavier": 107 | reinit_values_xavier(layer_weights, new_mask, old_mask, device) 108 | else: 109 | raise ValueError("Unknown init_new_weights_method") 110 | 111 | 112 | def reinit_values_unif(layer_weights, new_mask, old_mask, device): 113 | # the default initialization values of PyTorch 114 | # see https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear 115 | if old_mask is not None: 116 | weights = layer_weights.data.cpu().numpy() 117 | diff = new_mask - old_mask # new connections will have value 1 in diff 118 | num_in_neurons = layer_weights.data.shape[1] # weight matrix is transposed 119 | bound = 1 / np.sqrt(num_in_neurons) 120 | weights[diff == 1] = np.random.uniform(-bound, bound) # only new connections will get new values 121 | layer_weights.data = torch.from_numpy(weights).float().to(device) 122 | 123 | 124 | def reinit_values_xavier(layer_weights, new_mask, old_mask, device): 125 | # also called Glorot initialization 126 | # see https://keras.io/api/layers/initializers/#glorotuniform-class 127 | # and 
https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.xavier_uniform_ 128 | if old_mask is not None: 129 | weights = layer_weights.data.cpu().numpy() 130 | diff = new_mask - old_mask # new connections will have value 1 in diff 131 | num_in_neurons = layer_weights.data.shape[1] # weight matrix is transposed 132 | num_out_neurons = layer_weights.data.shape[0] 133 | bound = np.sqrt(12 / (num_in_neurons + num_out_neurons)) # 12 = 2 * 6 (see 'gain' variable in links above) 134 | weights[diff == 1] = np.random.uniform(-bound, bound) # only new connections will get new values 135 | layer_weights.data = torch.from_numpy(weights).float().to(device) 136 | 137 | 138 | def print_sparsities(critic_params, critic_target_params, actor_params, actor_target_params=None): 139 | sparsities_dict = {} 140 | sparsities_dict['critic_sparsity'], layers_sp = print_sparsity_one_network(critic_params, 'critic') 141 | for lay_idx, lay_sp in enumerate(layers_sp): 142 | sparsities_dict[f'critic_sparsity_layer{lay_idx}'] = lay_sp 143 | 144 | sparsities_dict['critic_target_sparsity'], layers_sp = print_sparsity_one_network(critic_target_params, 'critic_target') 145 | for lay_idx, lay_sp in enumerate(layers_sp): 146 | sparsities_dict[f'critic_target_sparsity_layer{lay_idx}'] = lay_sp 147 | 148 | sparsities_dict['actor_sparsity'], layers_sp = print_sparsity_one_network(actor_params, 'actor') 149 | for lay_idx, lay_sp in enumerate(layers_sp): 150 | sparsities_dict[f'actor_sparsity_layer{lay_idx}'] = lay_sp 151 | 152 | if actor_target_params is not None: 153 | sparsities_dict['actor_target_sparsity'], layers_sp = print_sparsity_one_network(actor_target_params, 'actor_target') 154 | for lay_idx, lay_sp in enumerate(layers_sp): 155 | sparsities_dict[f'actor_target_sparsity_layer{lay_idx}'] = lay_sp 156 | return sparsities_dict 157 | 158 | 159 | def print_sparsity_one_network(params, network_name): 160 | # layer = 0 161 | total_pruned_connections = 0 162 | total_connections_possible = 0 163 | layer_sparsities = [] 164 | for param in params: 165 | if len(param.shape) > 1: 166 | num_pruned_connections = np.sum(param.data.cpu().numpy() == 0) 167 | total_layer_connections = param.shape[0] * param.shape[1] 168 | this_layer_current_sparsity = num_pruned_connections / total_layer_connections 169 | # print(f"{network_name} sparsity layer {layer}: {this_layer_current_sparsity}") 170 | # layer += 1 171 | layer_sparsities.append(round(this_layer_current_sparsity, 5)) 172 | total_pruned_connections += num_pruned_connections 173 | total_connections_possible += total_layer_connections 174 | global_sparsity = total_pruned_connections / total_connections_possible 175 | # print(f" {round(global_sparsity, 5)} {network_name} global sparsity. Layer sparsities: {layer_sparsities}") 176 | return global_sparsity, layer_sparsities 177 | 178 | 179 | def compute_sparsity_per_layer(global_sparsity, neuron_layers, keep_dense, closeness=0.2, method='ER', input_layer_sparsity=-1.): 180 | """ 181 | Function to compute the sparsity levels of individual layers, based on a given global sparsity level. 182 | Instead of a uniform sparsity (for example 80% in every layer), 183 | this function gives bigger layers a larger sparsity level. 184 | Assumes an MLP architecture. 185 | :param global_sparsity: float, number between 0 and 1, the desired global sparsity of the whole network 186 | :param neuron_layers: list of ints, number of neurons in each neuron-layer 187 | :param keep_dense: list of bools, must be of length neuron_layers-1. 
188 | Put a True if this connection-layer should be dense, and a False if not. 189 | :param closeness: float, the exponent in the computation, between 0 and 1. 190 | Only used in method 'new'. Determines how close the sparsity levels of layers should be. 191 | Value closer to 0 gives more uniform (0 = same sparsity level in each sparse layer) 192 | Value closer to 1 gives less uniform (1 = most differences in sparsity levels) 193 | :param method: str, 'new' or 'ER' or 'uniform'. ER is from Mocanu et al. 2018: 194 | https://www.nature.com/articles/s41467-018-04316-3 Uniform gives each layer same sparsity level 195 | (which needs to be higher than the desired global sparsity if you want some dense layers as well.) 196 | :param input_layer_sparsity: float, number between 0 and 1, the desired sparsity of the input layer. Default is -1, 197 | meaning that the input layer sparsity is computed based on the global sparsity. 198 | :return: list of floats (between 0 and 1) giving the sparsity of each connection-layer (length: neuron_layers-1) 199 | 200 | Example: 201 | compute_sparsity_per_layer(global_sparsity=0.8, neuron_layers=[17, 256, 256, 6], keep_dense=[False, False, True]) 202 | output: [0.40, 0.85, 0.0] (output is normally not rounded, just for brevity here) 203 | """ 204 | assert len(neuron_layers) - 1 == len(keep_dense) 205 | 206 | total_connections_possible = 0 207 | connections_possible_per_layer = [] 208 | for n_layer_idx in range(len(neuron_layers)-1): 209 | layer_connections = neuron_layers[n_layer_idx] * neuron_layers[n_layer_idx + 1] 210 | total_connections_possible += layer_connections 211 | connections_possible_per_layer.append(layer_connections) 212 | 213 | global_density = 1 - global_sparsity 214 | total_connections_needed = round(global_density * total_connections_possible) 215 | 216 | keep_input_layer_fixed = False 217 | if input_layer_sparsity == 0: 218 | keep_dense[0] = True 219 | elif 0 < input_layer_sparsity < 1: 220 | keep_input_layer_fixed = True 221 | elif input_layer_sparsity != -1: 222 | raise ValueError("input_layer_sparsity must be in the interval [0,1) " 223 | "or equal to -1 to let it be computed based on global sparsity") 224 | 225 | total_conn_needed_sparse_lays = total_connections_needed 226 | keep_probs = [] # the probability of keeping a connection, for each layer (i.e. 
the layer density) 227 | for c_layer_idx, c_layer_dense in enumerate(keep_dense): 228 | if c_layer_dense: 229 | keep_probs.append(1) 230 | total_conn_needed_sparse_lays -= connections_possible_per_layer[c_layer_idx] 231 | else: 232 | if c_layer_idx == 0 and input_layer_sparsity > 0: 233 | prob = 1 - input_layer_sparsity 234 | total_conn_needed_sparse_lays -= round(prob * connections_possible_per_layer[c_layer_idx]) 235 | else: 236 | if method == 'new': 237 | prob = 2 / ((neuron_layers[c_layer_idx] * neuron_layers[c_layer_idx + 1]) ** closeness) 238 | elif method == 'ER': 239 | prob = (neuron_layers[c_layer_idx] + neuron_layers[c_layer_idx + 1]) \ 240 | / (neuron_layers[c_layer_idx] * neuron_layers[c_layer_idx + 1]) 241 | elif method == 'uniform': 242 | prob = global_density 243 | else: 244 | raise ValueError('Unknown method name for computing layer sparsities.') 245 | keep_probs.append(prob) 246 | 247 | # Counting the number of connections that we have in the sparse layers (that don't stay on a fixed sparsity level) 248 | total_conn_current_sparse_lays = 0 249 | for c_layer_idx, c_layer_dense in enumerate(keep_dense): 250 | if not c_layer_dense and not (keep_input_layer_fixed and c_layer_idx == 0): 251 | num_connections = round(keep_probs[c_layer_idx] * connections_possible_per_layer[c_layer_idx]) 252 | total_conn_current_sparse_lays += num_connections 253 | 254 | if total_conn_current_sparse_lays == 0: 255 | return handle_impossible_config_input_layer(input_layer_sparsity, keep_probs, connections_possible_per_layer, 256 | total_connections_possible) 257 | 258 | adjustment_factor = total_conn_needed_sparse_lays / total_conn_current_sparse_lays 259 | for c_layer_idx, c_layer_dense in enumerate(keep_dense): 260 | if not c_layer_dense and not (keep_input_layer_fixed and c_layer_idx == 0): 261 | keep_probs[c_layer_idx] *= adjustment_factor 262 | # print(f"adjustment_factor (epsilon) is {adjustment_factor}") 263 | 264 | # Check if all probabilities are valid 265 | do_again = False 266 | for c_layer_idx, keep_prob in enumerate(keep_probs): 267 | prob_for_one_connection = 1 / connections_possible_per_layer[c_layer_idx] 268 | minimum_connections = max(neuron_layers[c_layer_idx], neuron_layers[c_layer_idx+1]) 269 | prob_minimum_connections = minimum_connections / connections_possible_per_layer[c_layer_idx] 270 | if keep_prob > 1: 271 | keep_dense[c_layer_idx] = True 272 | do_again = True 273 | elif keep_prob < prob_for_one_connection: 274 | # example input: global_sparsity=0.9, neuron_layers=[10, 256, 256], keep_dense=[False, True] 275 | # can happen if dense layers & high sparsity is desired 276 | raise ValueError(f"This sparsity configuration is impossible, empty layers are required (layer {c_layer_idx}).") 277 | elif keep_prob < prob_minimum_connections: 278 | # warn if prob so low that this c_layer would have fewer connections than num_neurons on either side 279 | print(f"\nWARNING: extremely sparse layer, some neurons will have no connections in layer: {c_layer_idx}.\n") 280 | if do_again: 281 | return compute_sparsity_per_layer(global_sparsity, neuron_layers, keep_dense, closeness, method, input_layer_sparsity) 282 | 283 | return collect_output_sparsity_per_layer(keep_probs, connections_possible_per_layer, total_connections_possible) 284 | 285 | 286 | def collect_output_sparsity_per_layer(keep_probs, connections_possible_per_layer, total_connections_possible): 287 | sparsity_per_layer = [] 288 | total_connections = 0 289 | for c_layer_idx, keep_prob in enumerate(keep_probs): 290 | 
sparsity_per_layer.append(float(1 - keep_prob)) 291 | total_connections += round(keep_prob * connections_possible_per_layer[c_layer_idx]) 292 | 293 | print(f"\nconnections: {total_connections}, out of: {total_connections_possible}, " 294 | f"global sparsity: {round(1 - total_connections/total_connections_possible, 6)}") 295 | print(f"sparsity per layer: {sparsity_per_layer}") 296 | return sparsity_per_layer 297 | 298 | 299 | def handle_impossible_config_input_layer(input_layer_sparsity, keep_probs, connections_possible_per_layer, 300 | total_connections_possible): 301 | assert 0 < input_layer_sparsity < 1, "other situation not handled yet" 302 | conns = total_connections_possible - round(input_layer_sparsity * connections_possible_per_layer[0]) 303 | print(f"\nUsing higher global sparsity than requested." 304 | f"\nYour chosen sparsity configuration is impossible. Minimum sparsity level with " 305 | f"input_layer_sparsity {input_layer_sparsity} is {1 - conns / total_connections_possible}." 306 | f"\nConfiguring that now :)") 307 | new_keep_probs = [1] * len(keep_probs) 308 | new_keep_probs[0] = 1 - input_layer_sparsity 309 | return collect_output_sparsity_per_layer(new_keep_probs, connections_possible_per_layer, total_connections_possible) 310 | 311 | 312 | 313 | if __name__ == '__main__': 314 | # to test some of the functions 315 | 316 | ### to get an overview for each environment 317 | env_dims = { 318 | 'HalfCheetah-v3': (17, 6), 319 | # 'Walker2d-v3': (17, 6), 320 | # 'Hopper-v3': (11, 3), 321 | 'Humanoid-v3': (376, 17), 322 | # 'Ant-v3': (111, 8), 323 | # 'Swimmer-v3': (8, 2), 324 | # 'SlipperyAnt': (29, 8) 325 | } 326 | 327 | # glob_sparsities = [0, .5, .8, .9, .95, 0.96, 0.97] # , .98] 328 | glob_sparsities = [0] 329 | fake_feats_multiplier = 10 330 | 331 | for env_name, env_dims in env_dims.items(): 332 | print(f"\n{env_name}") 333 | for glob_sparsity in glob_sparsities: 334 | for method in ['uniform']: # 'ER', 335 | sparsities = compute_sparsity_per_layer(global_sparsity=glob_sparsity, 336 | # neuron_layers=[env_dims[0]*fake_feats_multiplier, 256, 256, 2 * env_dims[1]], # SAC has 2 output heads in actor 337 | # neuron_layers=[env_dims[0]*fake_feats_multiplier, 256, 256, env_dims[1]], # TD3 has 1 output head in actor. 
338 | neuron_layers=[env_dims[0]*fake_feats_multiplier + env_dims[1], 256, 256, 1], # for all critic networks 339 | keep_dense=[False, False, True], 340 | method=method, 341 | input_layer_sparsity=0.8) 342 | print(f"desired global: {glob_sparsity}, sparsity per layer {method[:2]}: {sparsities}") 343 | 344 | 345 | 346 | 347 | -------------------------------------------------------------------------------- /utils/target_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | 4 | 5 | # Used in SAC algos 6 | def soft_update(target, source, tau): 7 | for target_param, param in zip(target.parameters(), source.parameters()): 8 | target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau) 9 | if len(param.shape) > 1: 10 | update_target_networks(param, target_param) 11 | 12 | 13 | def hard_update(target, source): 14 | for target_param, param in zip(target.parameters(), source.parameters()): 15 | target_param.data.copy_(param.data) 16 | 17 | 18 | def update_target_networks(param, target_param): 19 | current_density = (param != 0).sum() 20 | target_density = (target_param != 0).sum() # torch.count_nonzero(target_param.data) 21 | difference = target_density - current_density 22 | # constrain the sparsity by removing the extra elements (smallest values) 23 | if difference > 0: 24 | count_rmv = difference 25 | tmp = copy.deepcopy(abs(target_param.data)) 26 | tmp[tmp == 0] = 10000000 27 | # rmv_indicies = torch.dstack(unravel_index(torch.argsort(tmp.ravel()), tmp.shape)) 28 | unraveled = unravel_index(torch.argsort(tmp.view(1, -1)[0]), tmp.shape) 29 | rmv_indicies = torch.stack(unraveled, dim=1) 30 | rmv_values_smaller_than = tmp[rmv_indicies[count_rmv][0], rmv_indicies[count_rmv][1]] 31 | target_param.data[tmp < rmv_values_smaller_than] = 0 32 | 33 | 34 | def unravel_index(index, shape): 35 | out = [] 36 | for dim in reversed(shape): 37 | out.append(index % dim) 38 | index = index // dim 39 | return tuple(reversed(out)) 40 | 41 | -------------------------------------------------------------------------------- /view_mujoco.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gym 3 | import numpy as np 4 | import argparse 5 | import time 6 | from utils import utils 7 | 8 | 9 | def show_policy(config): 10 | policy = load_policy(config) 11 | env = gym.make(config['env']) 12 | state, done = env.reset(), False 13 | while not done: 14 | env.render() 15 | action = policy.select_action(np.array(state), evaluate=True) 16 | state, reward, done, _ = env.step(action) 17 | 18 | 19 | def show_random_actions(config): 20 | env = gym.make(config['env']) 21 | env.reset() 22 | done, step_num = False, 0 23 | while not done: 24 | env.render() 25 | action = env.action_space.sample() 26 | obs, rew, done, _ = env.step(action) 27 | 28 | time.sleep(0.01) 29 | print(f"step {step_num}") 30 | step_num += 1 31 | env.close() 32 | 33 | 34 | def format_optimizer_arg(config): 35 | if config['policy'] in ['TD3', 'SAC', 'Static-TD3', 'Static-SAC']: 36 | optimizer = '' # for dense/static it's all the same (adam=maskadam), we use default (Adam) 37 | elif config['policy'] in ['ANF-TD3', 'ANF-SAC']: 38 | optimizer = 'maskadam_' 39 | else: 40 | raise ValueError('unknown policy') 41 | return optimizer 42 | 43 | 44 | def format_sparsity(exp): 45 | if exp['policy'] in ['TD3', 'SAC']: 46 | return '' 47 | else: 48 | return f'sparsity0.0_uniform_inlayspars0.8_' 49 | 50 | 51 | def format_env(config): 52 
| if config['noise_fraction'] == 0.95: 53 | return f'{config["env"]}-adjust1000000' 54 | else: 55 | return config['env'] 56 | 57 | 58 | def load_policy(config): 59 | env = gym.make(config['env']) 60 | state_dim = env.observation_space.shape[0] 61 | action_dim = env.action_space.shape[0] 62 | max_action = float(env.action_space.high[0]) 63 | 64 | parser = argparse.ArgumentParser() 65 | utils.add_arguments(parser) 66 | seed = 3101 67 | args = parser.parse_args(["--policy", config['policy'], 68 | "--env", config['env'], 69 | "--fake_features", str(config.get('noise_fraction', 0)), 70 | "--fake_noise_std", str(config.get('noise_amplitude', 1)), 71 | "--global_sparsity", str(0), 72 | "--sparsity_distribution_method", 'uniform', 73 | "--input_layer_sparsity", str(0.8), 74 | "--seed", str(int(seed)), 75 | ]) 76 | utils.print_all_args(args) 77 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 78 | if 'TD3' in config['policy']: 79 | policy = utils.set_policy_kwargs(state_dim, action_dim, max_action, args, device) 80 | elif 'SAC' in config['policy']: 81 | policy = utils.setup_sac_based_agent(args, env, device) 82 | else: 83 | raise ValueError('Unknown policy name') 84 | 85 | sparsity_info = format_sparsity(config) 86 | optim = format_optimizer_arg(config) 87 | noisefeats = f'fakefeats{config["noise_fraction"]}_' if config["noise_fraction"] != 0 else '' 88 | noiseamp = config.get('noise_amplitude', 1) 89 | noise_ampl = f'noise-std{noiseamp}.0_' if noiseamp != 1 else '' 90 | folder = './utils/pretrained_models/' 91 | file_name = f"{config['policy']}_{format_env(config)}_relu_" \ 92 | f"{sparsity_info}hid-lay2_{optim}{noisefeats}{noise_ampl}seed{seed}_best" 93 | file_path = f'{folder}{file_name}' 94 | policy.load(file_path) 95 | return policy 96 | 97 | 98 | if __name__ == '__main__': 99 | # PRESS >>>TAB<<< TO SWITCH CAMERA TO TRACK THE AGENT :) 100 | # script to see an agent in action! 101 | # run this script from the terminal with: 102 | # python view_mujoco.py 103 | 104 | # possible environments: HalfCheetah-v3, Hopper-v3, Walker2d-v3, Humanoid-v3 105 | # possible policies: ANF-SAC, ANF-TD3, SAC, TD3 106 | # possible noise_fractions: 0, 0.8, 0.9, 0.95, 0.98, 0.99 107 | # possible noise_amplitudes: 1 (for all) or 2, 4, 8, 16 (for noise_fraction=0.9, env=HalfCheetah-v3) 108 | config = { 109 | 'env': 'HalfCheetah-v3', 110 | 'policy': 'ANF-SAC', 111 | 'noise_fraction': 0.9, 112 | 'noise_amplitude': 1, 113 | } 114 | show_policy(config) 115 | 116 | # if you want to run this file with other config settings, 117 | # first download the pretrained models from: 118 | # https://www.dropbox.com/s/qr1l7bscnnd8non/pretrained_models.zip?dl=0 119 | # extract the zip, and put the model files in the folder './utils/pretrained_models/' 120 | --------------------------------------------------------------------------------
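The `reinit_values_unif` and `reinit_values_xavier` helpers in `utils/sparse_utils.py` above re-initialize only the newly grown connections, using either the PyTorch `nn.Linear` default bound or a Glorot-style bound. A small worked example of the two bounds (plain NumPy, not repo code; the 256x256 hidden-layer size is just illustrative):

```python
import numpy as np

# Worked example of the two re-initialization bounds used for newly grown connections.
fan_in, fan_out = 256, 256
unif_bound = 1 / np.sqrt(fan_in)                  # nn.Linear default: 0.0625
xavier_bound = np.sqrt(12 / (fan_in + fan_out))   # Glorot-style bound, factor 12 = 2 * 6: ~0.153
new_weight = np.random.uniform(-xavier_bound, xavier_bound)
print(unif_bound, xavier_bound, new_weight)
```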
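The core of `compute_sparsity_per_layer` in `utils/sparse_utils.py` is the Erdős–Rényi (ER) rule of Mocanu et al. (2018): a connection-layer between `n_in` and `n_out` neurons gets a keep probability proportional to `(n_in + n_out) / (n_in * n_out)`, and all sparse layers are then rescaled by one common factor so the kept connections sum to the requested global density. Below is a minimal sketch of just that rule; the helper name `er_keep_probs` is illustrative, not a repo function, and the full implementation above additionally handles dense layers, a fixed input-layer sparsity, and re-running when a probability exceeds 1:

```python
import numpy as np

def er_keep_probs(neuron_layers, global_density):
    # Keep probability per connection-layer under the ER rule, rescaled so that
    # the expected number of kept connections equals global_density * total possible.
    pairs = list(zip(neuron_layers[:-1], neuron_layers[1:]))
    raw = np.array([(n_in + n_out) / (n_in * n_out) for n_in, n_out in pairs])
    possible = np.array([n_in * n_out for n_in, n_out in pairs], dtype=float)
    factor = global_density * possible.sum() / (raw * possible).sum()
    return np.clip(raw * factor, 0.0, 1.0)

# Example: TD3-style actor for HalfCheetah-v3 with 90% fake features (170 inputs)
# and 80% global sparsity; prints the resulting sparsity per connection-layer.
print(1 - er_keep_probs([170, 256, 256, 6], global_density=0.2))
```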
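`soft_update` in `utils/target_network.py` combines Polyak averaging with a pruning step (`update_target_networks`) that keeps the target network as sparse as the online network by zeroing its smallest-magnitude surplus weights. A simplified, self-contained sketch of that idea for a single weight matrix (the function name `soft_update_sparse` is illustrative, not repo code):

```python
import torch

def soft_update_sparse(target_w, source_w, tau=0.005):
    # Polyak-average the weights, then prune the smallest-magnitude target
    # entries so the target keeps no more nonzeros than the online network.
    target_w.mul_(1.0 - tau).add_(tau * source_w)
    n_extra = int((target_w != 0).sum()) - int((source_w != 0).sum())
    if n_extra > 0:
        mags = target_w.abs().clone()
        mags[mags == 0] = float('inf')      # existing zeros are never pruning candidates
        cutoff = mags.flatten().sort().values[n_extra]
        target_w[mags < cutoff] = 0.0
    return target_w
```

Pruning after the average is what keeps the target and online networks at matching sparsity throughout training, mirroring the behavior of the repo's `soft_update`.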