├── results.JPG ├── LICENSE ├── utils.py ├── README.md ├── sparse_utils.py ├── TD3.py ├── main.py ├── StaticSparseTD3.py └── DS_TD3.py /results.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GhadaSokar/Dynamic-Sparse-Training-for-Deep-Reinforcement-Learning/HEAD/results.JPG -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Ghada Sokar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class ReplayBuffer(object): 6 | def __init__(self, state_dim, action_dim, max_size=int(1e6)): 7 | self.max_size = max_size 8 | self.ptr = 0 9 | self.size = 0 10 | 11 | self.state = np.zeros((max_size, state_dim)) 12 | self.action = np.zeros((max_size, action_dim)) 13 | self.next_state = np.zeros((max_size, state_dim)) 14 | self.reward = np.zeros((max_size, 1)) 15 | self.not_done = np.zeros((max_size, 1)) 16 | 17 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | 19 | 20 | def add(self, state, action, next_state, reward, done): 21 | self.state[self.ptr] = state 22 | self.action[self.ptr] = action 23 | self.next_state[self.ptr] = next_state 24 | self.reward[self.ptr] = reward 25 | self.not_done[self.ptr] = 1. 
- done 26 | 27 | self.ptr = (self.ptr + 1) % self.max_size 28 | self.size = min(self.size + 1, self.max_size) 29 | 30 | 31 | def sample(self, batch_size): 32 | ind = np.random.randint(0, self.size, size=batch_size) 33 | 34 | return ( 35 | torch.FloatTensor(self.state[ind]).to(self.device), 36 | torch.FloatTensor(self.action[ind]).to(self.device), 37 | torch.FloatTensor(self.next_state[ind]).to(self.device), 38 | torch.FloatTensor(self.reward[ind]).to(self.device), 39 | torch.FloatTensor(self.not_done[ind]).to(self.device) 40 | ) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dynamic Sparse Training for Deep Reinforcement Learning 2 | 3 | This is the PyTorch implementation of the IJCAI 2022 paper [Dynamic Sparse Training for Deep Reinforcement Learning](https://arxiv.org/pdf/2106.04217.pdf). 4 | 5 | # Abstract 6 | Deep reinforcement learning (DRL) agents are trained through trial-and-error interactions with the environment. This leads to a long training time for dense neural networks to achieve good performance. Hence, prohibitive computation and memory resources are consumed. Recently, learning efficient DRL agents has received increasing attention. Yet, current methods focus on accelerating inference time. In this paper, we introduce for the first time a dynamic sparse training approach for deep reinforcement learning to accelerate the training process. The proposed approach trains a sparse neural network from scratch and dynamically adapts its topology to the changing data distribution during training. Experiments on continuous control tasks show that our dynamic sparse agents achieve higher performance than the equivalent dense methods, reduce the parameter count and floating-point operations (FLOPs) by 50%, and have a faster learning speed, reaching the performance of the dense agents with a 40-50% reduction in training steps. 7 | 8 | # Requirements 9 | * Python 3.8 10 | * PyTorch 1.5 11 | * [Mujoco-py](https://github.com/openai/mujoco-py) 12 | * [OpenAI gym](https://github.com/openai/gym) 13 | 14 | # Usage 15 | 16 | For DS-TD3 (dynamic sparse training of the TD3 algorithm): 17 | ``` 18 | python main.py --env HalfCheetah-v3 --policy DS-TD3 19 | ``` 20 | 21 | For Static-TD3: 22 | ``` 23 | python main.py --env HalfCheetah-v3 --policy StaticSparseTD3 24 | ``` 25 | 26 | For TD3: 27 | ``` 28 | python main.py --env HalfCheetah-v3 --policy TD3 29 | ``` 30 | 31 | # Results 32 | ![](results.JPG) 33 | # Reference 34 | 35 | If you use this code, please cite our paper: 36 | ``` 37 | @inproceedings{sokar2022dynamic, 38 | title={Dynamic Sparse Training for Deep Reinforcement Learning}, 39 | author={Sokar, Ghada and Mocanu, Elena and Mocanu, Decebal Constantin and Pechenizkiy, Mykola and Stone, Peter}, 40 | booktitle={International Joint Conference on Artificial Intelligence}, 41 | year={2022} 42 | } 43 | ``` 44 | 45 | # Acknowledgments 46 | We start from the official implementation of the TD3 method, available at the following repository: 47 | 48 | [TD3](https://github.com/sfujim/TD3) 49 | -------------------------------------------------------------------------------- /sparse_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def initializeEpsilonWeightsMask(text,epsilon, noRows, noCols): 4 | # generate an epsilon based Erdos Renyi sparse weights mask 5 | mask_weights = np.random.rand(noRows, noCols) 6 | prob = 1 - 
(epsilon * (noRows + noCols)) / (noRows * noCols) 7 | mask_weights = np.random.rand(noRows, noCols) 8 | mask_weights[mask_weights < prob] = 0 9 | mask_weights[mask_weights >= prob] = 1 10 | noParameters = np.sum(mask_weights) 11 | sparsity = 1-noParameters/(noRows * noCols) 12 | print("Epsilon Sparse Initialization ",text,": Epsilon ",epsilon,"; Sparsity ",sparsity,"; NoParameters ",noParameters,"; NoRows ",noRows,"; NoCols ",noCols,"; NoDenseParam ",noRows*noCols) 13 | print (" OutDegreeBottomNeurons %.2f ± %.2f; InDegreeTopNeurons %.2f ± %.2f" % (np.mean(mask_weights.sum(axis=1)),np.std(mask_weights.sum(axis=1)),np.mean(mask_weights.sum(axis=0)),np.std(mask_weights.sum(axis=0)))) 14 | return [noParameters, mask_weights.transpose()] 15 | 16 | def initializeSparsityLevelWeightMask(text,sparsityLevel,noRows, noCols): 17 | # generate an Erdos Renyi sparse weights mask 18 | prob=sparsityLevel 19 | mask_weights = np.random.rand(noRows, noCols) 20 | mask_weights[mask_weights < prob] = 0 21 | mask_weights[mask_weights >= prob] = 1 22 | noParameters = np.sum(mask_weights) 23 | sparsity = 1-noParameters/(noRows * noCols) 24 | epsilon = int((prob*(noRows * noCols)-1)/(noRows + noCols)) 25 | print("Sparsity Level Initialization ",text,": Computed Epsilon ",epsilon,"; Sparsity ",sparsity,"; NoParameters ",noParameters,"; NoRows ",noRows,"; NoCols ",noCols,"; NoDenseParam ",noRows*noCols) 26 | print (" OutDegreeBottomNeurons %.2f ± %.2f; InDegreeTopNeurons %.2f ± %.2f" % (np.mean(mask_weights.sum(axis=1)),np.std(mask_weights.sum(axis=1)),np.mean(mask_weights.sum(axis=0)),np.std(mask_weights.sum(axis=0)))) 27 | return [noParameters, mask_weights.transpose()] 28 | 29 | def find_first_pos(array, value): 30 | idx = (np.abs(array - value)).argmin() 31 | return idx 32 | 33 | def find_last_pos(array, value): 34 | idx = (np.abs(array - value))[::-1].argmin() 35 | return array.shape[0] - idx 36 | 37 | def changeConnectivitySET(weights, noWeights, initMask, zeta, lastTopologyChange, iteration): 38 | # change Connectivity 39 | # remove zeta largest negative and smallest positive weights 40 | weights = weights * initMask 41 | values = np.sort(weights.ravel()) 42 | firstZeroPos = find_first_pos(values, 0) 43 | lastZeroPos = find_last_pos(values, 0) 44 | largestNegative = values[int((1 - zeta) * firstZeroPos)] 45 | smallestPositive = values[int(min(values.shape[0] - 1, lastZeroPos + zeta * (values.shape[0] - lastZeroPos)))] 46 | rewiredWeights = weights.copy(); 47 | rewiredWeights[rewiredWeights > smallestPositive] = 1; 48 | rewiredWeights[rewiredWeights < largestNegative] = 1; 49 | rewiredWeights[rewiredWeights != 1] = 0; 50 | 51 | # add random weights 52 | nrAdd = 0 53 | if (lastTopologyChange==False): 54 | noRewires = noWeights - np.sum(rewiredWeights) 55 | while (nrAdd < noRewires): 56 | i = np.random.randint(0, rewiredWeights.shape[0]) 57 | j = np.random.randint(0, rewiredWeights.shape[1]) 58 | if (rewiredWeights[i, j] == 0): 59 | rewiredWeights[i, j] = 1 60 | nrAdd += 1 61 | 62 | ascStats=[iteration, nrAdd, noWeights, np.count_nonzero(rewiredWeights)] 63 | 64 | return [rewiredWeights,ascStats] 65 | 66 | def changeConnectivityXReLU(weights, noWeights, initMask, lastTopologyChange, iteration): 67 | weights = weights * initMask 68 | 69 | weightspos = weights.copy() 70 | weightspos[weightspos < 0] = 0 71 | strengthpos = np.sum(weightspos, axis=0) 72 | 73 | weightsneg = weights.copy() 74 | weightsneg[weightsneg > 0] = 0 75 | strengthneg = np.sum(weightsneg, axis=0) 76 | 77 | for j in range(strengthpos.shape[0]): 
78 | if (strengthpos[j] + strengthneg[j] < 0): 79 | difference = strengthpos[j] 80 | iis = np.nonzero(weightsneg[:, j])[0] 81 | ww = weightsneg[iis, j] 82 | iisort = np.argsort(ww) 83 | for i in iisort: 84 | if (difference > 0): 85 | difference += weightsneg[iis[i], j] 86 | else: 87 | weights[iis[i], j] = 0 88 | rewiredWeights = weights.copy(); 89 | rewiredWeights[rewiredWeights != 0] = 1; 90 | 91 | nrAdd = 0 92 | if (lastTopologyChange == False): 93 | noRewires = noWeights - np.sum(rewiredWeights) 94 | while (nrAdd < noRewires): 95 | i = np.random.randint(0, rewiredWeights.shape[0]) 96 | j = np.random.randint(0, rewiredWeights.shape[1]) 97 | if (rewiredWeights[i, j] == 0): 98 | rewiredWeights[i, j] = 1 99 | nrAdd += 1 100 | 101 | ascStats = [iteration, nrAdd, noWeights, np.count_nonzero(rewiredWeights)] 102 | return [rewiredWeights, ascStats] 103 | 104 | -------------------------------------------------------------------------------- /TD3.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | # Implementation of Twin Delayed Deep Deterministic Policy Gradients (TD3) 12 | # Paper: https://arxiv.org/abs/1802.09477 13 | # This is the offical code from the authors 14 | 15 | class Actor(nn.Module): 16 | def __init__(self, state_dim, action_dim, max_action): 17 | super(Actor, self).__init__() 18 | 19 | self.l1 = nn.Linear(state_dim, 256) 20 | self.l2 = nn.Linear(256, 256) 21 | self.l3 = nn.Linear(256, action_dim) 22 | 23 | self.max_action = max_action 24 | 25 | 26 | def forward(self, state): 27 | a = F.relu(self.l1(state)) 28 | a = F.relu(self.l2(a)) 29 | return self.max_action * torch.tanh(self.l3(a)) 30 | 31 | 32 | class Critic(nn.Module): 33 | def __init__(self, state_dim, action_dim): 34 | super(Critic, self).__init__() 35 | 36 | # Q1 architecture 37 | self.l1 = nn.Linear(state_dim + action_dim, 256) 38 | self.l2 = nn.Linear(256, 256) 39 | self.l3 = nn.Linear(256, 1) 40 | 41 | # Q2 architecture 42 | self.l4 = nn.Linear(state_dim + action_dim, 256) 43 | self.l5 = nn.Linear(256, 256) 44 | self.l6 = nn.Linear(256, 1) 45 | 46 | 47 | def forward(self, state, action): 48 | sa = torch.cat([state, action], 1) 49 | 50 | q1 = F.relu(self.l1(sa)) 51 | q1 = F.relu(self.l2(q1)) 52 | q1 = self.l3(q1) 53 | 54 | q2 = F.relu(self.l4(sa)) 55 | q2 = F.relu(self.l5(q2)) 56 | q2 = self.l6(q2) 57 | return q1, q2 58 | 59 | 60 | def Q1(self, state, action): 61 | sa = torch.cat([state, action], 1) 62 | 63 | q1 = F.relu(self.l1(sa)) 64 | q1 = F.relu(self.l2(q1)) 65 | q1 = self.l3(q1) 66 | return q1 67 | 68 | 69 | class TD3(object): 70 | def __init__( 71 | self, 72 | state_dim, 73 | action_dim, 74 | max_action, 75 | discount=0.99, 76 | tau=0.005, 77 | noHidNeurons=256, 78 | policy_noise=0.2, 79 | noise_clip=0.5, 80 | policy_freq=2 81 | 82 | ): 83 | 84 | self.actor = Actor(state_dim, action_dim, max_action).to(device) 85 | self.actor_target = copy.deepcopy(self.actor) 86 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=1e-3, weight_decay=0.0002) 87 | 88 | self.critic = Critic(state_dim, action_dim).to(device) 89 | self.critic_target = copy.deepcopy(self.critic) 90 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=1e-3, weight_decay=0.0002) 91 | 92 | self.max_action = max_action 93 | self.discount = discount 94 | self.tau = tau 95 | 
self.policy_noise = policy_noise 96 | self.noise_clip = noise_clip 97 | self.policy_freq = policy_freq 98 | 99 | self.total_it = 0 100 | 101 | 102 | def select_action(self, state): 103 | state = torch.FloatTensor(state.reshape(1, -1)).to(device) 104 | return self.actor(state).cpu().data.numpy().flatten() 105 | 106 | 107 | def train(self, replay_buffer, batch_size=100): 108 | self.total_it += 1 109 | 110 | # Sample replay buffer 111 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 112 | 113 | with torch.no_grad(): 114 | # Select action according to policy and add clipped noise 115 | noise = ( 116 | torch.randn_like(action) * self.policy_noise 117 | ).clamp(-self.noise_clip, self.noise_clip) 118 | 119 | next_action = ( 120 | self.actor_target(next_state) + noise 121 | ).clamp(-self.max_action, self.max_action) 122 | 123 | # Compute the target Q value 124 | target_Q1, target_Q2 = self.critic_target(next_state, next_action) 125 | target_Q = torch.min(target_Q1, target_Q2) 126 | target_Q = reward + not_done * self.discount * target_Q 127 | 128 | # Get current Q estimates 129 | current_Q1, current_Q2 = self.critic(state, action) 130 | 131 | # Compute critic loss 132 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) 133 | 134 | # Optimize the critic 135 | self.critic_optimizer.zero_grad() 136 | critic_loss.backward() 137 | self.critic_optimizer.step() 138 | 139 | # Delayed policy updates 140 | if self.total_it % self.policy_freq == 0: 141 | 142 | # Compute actor losse 143 | actor_loss = -self.critic.Q1(state, self.actor(state)).mean() 144 | 145 | # Optimize the actor 146 | self.actor_optimizer.zero_grad() 147 | actor_loss.backward() 148 | self.actor_optimizer.step() 149 | 150 | # Update the frozen target models 151 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 152 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 153 | 154 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 155 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 156 | 157 | 158 | def save(self, filename): 159 | torch.save(self.critic.state_dict(), filename + "_critic") 160 | torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer") 161 | 162 | torch.save(self.actor.state_dict(), filename + "_actor") 163 | torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer") 164 | 165 | 166 | def load(self, filename): 167 | self.critic.load_state_dict(torch.load(filename + "_critic")) 168 | self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer")) 169 | self.critic_target = copy.deepcopy(self.critic) 170 | 171 | self.actor.load_state_dict(torch.load(filename + "_actor")) 172 | self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer")) 173 | self.actor_target = copy.deepcopy(self.actor) 174 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import gym 4 | import argparse 5 | import os 6 | 7 | import utils 8 | import TD3 9 | import StaticSparseTD3 10 | import DS_TD3 11 | import datetime 12 | 13 | printComments=True 14 | 15 | # Runs policy for X episodes and returns average reward 16 | # A fixed seed is used for the eval environment 17 | def eval_policy(policy, env_name, seed, 
eval_episodes=10): 18 | eval_env = gym.make(env_name) 19 | eval_env.seed(seed + 100) 20 | 21 | avg_reward = 0. 22 | for _ in range(eval_episodes): 23 | state, done = eval_env.reset(), False 24 | while not done: 25 | action = policy.select_action(np.array(state)) 26 | state, reward, done, _ = eval_env.step(action) 27 | avg_reward += reward 28 | 29 | avg_reward /= eval_episodes 30 | 31 | if (printComments): 32 | print("---------------------------------------") 33 | print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}") 34 | print("---------------------------------------") 35 | 36 | return avg_reward 37 | 38 | 39 | if __name__ == "__main__": 40 | 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--policy", default="DS-TD3") # Policy name (TD3, DS-TD3 or StaticSparseTD3) 43 | parser.add_argument("--env", default="HalfCheetah-v3") # OpenAI gym environment name 44 | parser.add_argument("--seed", default=2, type=int) # Sets Gym, PyTorch and Numpy seeds 45 | parser.add_argument("--start_timesteps", default=25e3, type=int)# Time steps initial random policy is used 46 | parser.add_argument("--eval_freq", default=5e3, type=int) # How often (time steps) we evaluate 47 | parser.add_argument("--max_timesteps", default=1e6, type=int) # Max time steps to run environment 48 | parser.add_argument("--expl_noise", default=0.1) # Std of Gaussian exploration noise 49 | parser.add_argument("--batch_size", default=100, type=int) # Batch size for both actor and critic 50 | parser.add_argument("--discount", default=0.99) # Discount factor 51 | parser.add_argument("--tau", default=0.005) # Target network update rate 52 | parser.add_argument("--policy_noise", default=0.2) # Noise added to target policy during critic update 53 | parser.add_argument("--noise_clip", default=0.5) # Range to clip target policy noise 54 | parser.add_argument("--policy_freq", default=2, type=int) # Frequency of delayed policy updates 55 | parser.add_argument("--save_model", default=True) # Save model and optimizer parameters 56 | parser.add_argument("--load_model", default="") # Model load file name, "" doesn't load, "default" uses file_name 57 | 58 | parser.add_argument("--save_model_period", default=1e5, type=int) # Save model and optimizer parameters after the set number of iterations 59 | parser.add_argument("--ann_noHidNeurons", default=256, type=int) # 60 | parser.add_argument("--ann_epsilonHid1", default=7, type=int) # lambda 1 61 | parser.add_argument("--ann_epsilonHid2", default=64, type=int) # lambda 2 62 | parser.add_argument("--ann_setZeta", default=0.05) # 63 | parser.add_argument("--ann_ascTopologyChangePeriod", default=1e3, type=int) # 64 | parser.add_argument("--ann_earlyStopTopologyChange", default=5e4, type=int) # 65 | 66 | args = parser.parse_args() 67 | 68 | file_name = f"{args.policy}_{args.env}_{args.seed}_{args.ann_epsilonHid1}_{args.ann_epsilonHid2}" 69 | print("---------------------------------------") 70 | print(f"Policy: {args.policy}, Env: {args.env}, Seed: {args.seed}") 71 | print("---------------------------------------") 72 | 73 | if not os.path.exists("./results"): 74 | os.makedirs("./results") 75 | 76 | if args.save_model and not os.path.exists("./models"): 77 | os.makedirs("./models") 78 | 79 | env = gym.make(args.env) 80 | 81 | # Set seeds 82 | env.seed(args.seed) 83 | torch.manual_seed(args.seed) 84 | np.random.seed(args.seed) 85 | 86 | state_dim = env.observation_space.shape[0] 87 | action_dim = env.action_space.shape[0] 88 | max_action = float(env.action_space.high[0]) 89 | 
90 | kwargs = { 91 | "state_dim": state_dim, 92 | "action_dim": action_dim, 93 | "max_action": max_action, 94 | "discount": args.discount, 95 | "tau": args.tau, 96 | "noHidNeurons":args.ann_noHidNeurons, 97 | } 98 | 99 | # Target policy smoothing is scaled wrt the action scale 100 | kwargs["policy_noise"] = args.policy_noise * max_action 101 | kwargs["noise_clip"] = args.noise_clip * max_action 102 | kwargs["policy_freq"] = args.policy_freq 103 | 104 | # Initialize policy 105 | if args.policy == "TD3": 106 | policy = TD3.TD3(**kwargs) 107 | elif args.policy == "StaticSparseTD3": 108 | kwargs["epsilonHid1"]=args.ann_epsilonHid1 109 | kwargs["epsilonHid2"]=args.ann_epsilonHid2 110 | policy = StaticSparseTD3.StaticSparseTD3(**kwargs) 111 | elif args.policy == "DS-TD3": 112 | kwargs["epsilonHid1"]=args.ann_epsilonHid1 113 | kwargs["epsilonHid2"]=args.ann_epsilonHid2 114 | kwargs["setZeta"]=args.ann_setZeta 115 | kwargs["ascTopologyChangePeriod"]=args.ann_ascTopologyChangePeriod 116 | kwargs["earlyStopTopologyChangeIteration"] = args.max_timesteps-args.start_timesteps-args.ann_earlyStopTopologyChange 117 | policy = DS_TD3.SETSparseTD3(**kwargs) 118 | 119 | if args.load_model != "": 120 | policy_file = file_name if args.load_model == "default" else args.load_model 121 | policy.load(f"./models/{policy_file}") 122 | 123 | if args.save_model: 124 | policy.save(f"./models/{file_name}_iter_0") 125 | 126 | replay_buffer = utils.ReplayBuffer(state_dim, action_dim) 127 | 128 | # Evaluate untrained policy 129 | evaluations = [eval_policy(policy, args.env, args.seed)] 130 | 131 | state, done = env.reset(), False 132 | episode_reward = 0 133 | episode_timesteps = 0 134 | episode_num = 0 135 | t1 = datetime.datetime.now() 136 | tin= datetime.datetime.now() 137 | for t in range(int(args.max_timesteps)): 138 | 139 | episode_timesteps += 1 140 | 141 | # Select action randomly or according to policy 142 | if t < args.start_timesteps: 143 | action = env.action_space.sample() 144 | else: 145 | action = ( 146 | policy.select_action(np.array(state)) 147 | + np.random.normal(0, max_action * args.expl_noise, size=action_dim) 148 | ).clip(-max_action, max_action) 149 | 150 | # Perform action 151 | next_state, reward, done, _ = env.step(action) 152 | done_bool = float(done) if episode_timesteps < env._max_episode_steps else 0 153 | 154 | # Store data in replay buffer 155 | replay_buffer.add(state, action, next_state, reward, done_bool) 156 | 157 | state = next_state 158 | episode_reward += reward 159 | 160 | # Train agent after collecting sufficient data 161 | if t >= args.start_timesteps: 162 | policy.train(replay_buffer, args.batch_size) 163 | 164 | if done: 165 | # +1 to account for 0 indexing. 
+0 on ep_timesteps since it will increment +1 even if done=True 166 | if (printComments): 167 | print(f"Total T: {t+1} Episode Num: {episode_num+1} Episode T: {episode_timesteps} Reward: {episode_reward:.3f} Time: {datetime.datetime.now()-t1}") 168 | t1 = datetime.datetime.now() 169 | # Reset environment 170 | state, done = env.reset(), False 171 | episode_reward = 0 172 | episode_timesteps = 0 173 | episode_num += 1 174 | 175 | # Evaluate episode 176 | if (t + 1) % args.eval_freq == 0: 177 | evaluations.append(eval_policy(policy, args.env, args.seed)) 178 | np.save(f"./results/{file_name}", evaluations) 179 | 180 | if args.save_model: 181 | if (t+1) % args.save_model_period == 0: 182 | policy.save(f"./models/{file_name}_iter_{t+1}") 183 | 184 | print("Total running time",datetime.datetime.now()-tin) 185 | -------------------------------------------------------------------------------- /StaticSparseTD3.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import sparse_utils as sp 7 | 8 | 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | # Implementation of Twin Delayed Deep Deterministic Policy Gradients (TD3) 12 | # Paper: https://arxiv.org/abs/1802.09477 13 | 14 | 15 | class Actor(nn.Module): 16 | def __init__(self, state_dim, action_dim, max_action,noHidNeurons,epsilonHid1,epsilonHid2): 17 | super(Actor, self).__init__() 18 | 19 | self.l1 = nn.Linear(state_dim, noHidNeurons) 20 | [self.noPar1, self.mask1] = sp.initializeEpsilonWeightsMask("actor first layer", epsilonHid1, state_dim, noHidNeurons) 21 | self.torchMask1=torch.from_numpy(self.mask1).float().to(device) 22 | self.l1.weight.data.mul_(torch.from_numpy(self.mask1).float()) 23 | 24 | self.l2 = nn.Linear(noHidNeurons, noHidNeurons) 25 | [self.noPar2, self.mask2] = sp.initializeEpsilonWeightsMask("actor second layer", epsilonHid2, noHidNeurons, noHidNeurons) 26 | self.torchMask2 = torch.from_numpy(self.mask2).float().to(device) 27 | self.l2.weight.data.mul_(torch.from_numpy(self.mask2).float()) 28 | 29 | self.l3 = nn.Linear(noHidNeurons, action_dim) 30 | 31 | self.max_action = max_action 32 | 33 | 34 | def forward(self, state): 35 | a = F.relu(self.l1(state)) 36 | a = F.relu(self.l2(a)) 37 | return self.max_action * torch.tanh(self.l3(a)) 38 | 39 | 40 | class Critic(nn.Module): 41 | def __init__(self, state_dim, action_dim,noHidNeurons,epsilonHid1,epsilonHid2): 42 | super(Critic, self).__init__() 43 | 44 | # Q1 architecture 45 | self.l1 = nn.Linear(state_dim + action_dim, noHidNeurons) 46 | [self.noPar1, self.mask1] = sp.initializeEpsilonWeightsMask("critic Q1 first layer", epsilonHid1, state_dim + action_dim, noHidNeurons) 47 | self.torchMask1=torch.from_numpy(self.mask1).float().to(device) 48 | self.l1.weight.data.mul_(torch.from_numpy(self.mask1).float()) 49 | 50 | self.l2 = nn.Linear(noHidNeurons, noHidNeurons) 51 | [self.noPar2, self.mask2] = sp.initializeEpsilonWeightsMask("critic Q1 second layer", epsilonHid2, noHidNeurons, noHidNeurons) 52 | self.torchMask2 = torch.from_numpy(self.mask2).float().to(device) 53 | self.l2.weight.data.mul_(torch.from_numpy(self.mask2).float()) 54 | 55 | self.l3 = nn.Linear(noHidNeurons, 1) 56 | 57 | # Q2 architecture 58 | self.l4 = nn.Linear(state_dim + action_dim, noHidNeurons) 59 | [self.noPar4, self.mask4] = sp.initializeEpsilonWeightsMask("critic Q2 first layer", epsilonHid1, state_dim + action_dim, 
noHidNeurons) 60 | self.torchMask4 = torch.from_numpy(self.mask4).float().to(device) 61 | self.l4.weight.data.mul_(torch.from_numpy(self.mask4).float()) 62 | 63 | self.l5 = nn.Linear(noHidNeurons, noHidNeurons) 64 | [self.noPar5, self.mask5] = sp.initializeEpsilonWeightsMask("critic Q2 second layer", epsilonHid2, noHidNeurons, noHidNeurons) 65 | self.torchMask5 = torch.from_numpy(self.mask5).float().to(device) 66 | self.l5.weight.data.mul_(torch.from_numpy(self.mask5).float()) 67 | 68 | self.l6 = nn.Linear(noHidNeurons, 1) 69 | 70 | 71 | def forward(self, state, action): 72 | sa = torch.cat([state, action], 1) 73 | 74 | q1 = F.relu(self.l1(sa)) 75 | q1 = F.relu(self.l2(q1)) 76 | q1 = self.l3(q1) 77 | 78 | q2 = F.relu(self.l4(sa)) 79 | q2 = F.relu(self.l5(q2)) 80 | q2 = self.l6(q2) 81 | return q1, q2 82 | 83 | 84 | def Q1(self, state, action): 85 | sa = torch.cat([state, action], 1) 86 | 87 | q1 = F.relu(self.l1(sa)) 88 | q1 = F.relu(self.l2(q1)) 89 | q1 = self.l3(q1) 90 | return q1 91 | 92 | 93 | class StaticSparseTD3(object): 94 | def __init__( 95 | self, 96 | state_dim, 97 | action_dim, 98 | max_action, 99 | discount=0.99, 100 | tau=0.005, 101 | policy_noise=0.2, 102 | noise_clip=0.5, 103 | policy_freq=2, 104 | noHidNeurons=256, 105 | epsilonHid1=20, 106 | epsilonHid2=20, 107 | ): 108 | 109 | self.actor = Actor(state_dim, action_dim, max_action,noHidNeurons,epsilonHid1,epsilonHid2).to(device) 110 | self.actor_target = copy.deepcopy(self.actor) 111 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=1e-3, weight_decay=0.0002) 112 | 113 | self.critic = Critic(state_dim, action_dim,noHidNeurons,epsilonHid1,epsilonHid2).to(device) 114 | self.critic_target = copy.deepcopy(self.critic) 115 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=1e-3, weight_decay=0.0002) 116 | 117 | self.max_action = max_action 118 | self.discount = discount 119 | self.tau = tau 120 | self.policy_noise = policy_noise 121 | self.noise_clip = noise_clip 122 | self.policy_freq = policy_freq 123 | 124 | self.total_it = 0 125 | 126 | 127 | def select_action(self, state): 128 | state = torch.FloatTensor(state.reshape(1, -1)).to(device) 129 | return self.actor(state).cpu().data.numpy().flatten() 130 | 131 | 132 | def train(self, replay_buffer, batch_size=100): 133 | self.total_it += 1 134 | 135 | # Sample replay buffer 136 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 137 | 138 | with torch.no_grad(): 139 | # Select action according to policy and add clipped noise 140 | noise = ( 141 | torch.randn_like(action) * self.policy_noise 142 | ).clamp(-self.noise_clip, self.noise_clip) 143 | 144 | next_action = ( 145 | self.actor_target(next_state) + noise 146 | ).clamp(-self.max_action, self.max_action) 147 | 148 | # Compute the target Q value 149 | target_Q1, target_Q2 = self.critic_target(next_state, next_action) 150 | target_Q = torch.min(target_Q1, target_Q2) 151 | target_Q = reward + not_done * self.discount * target_Q 152 | 153 | # Get current Q estimates 154 | current_Q1, current_Q2 = self.critic(state, action) 155 | 156 | # Compute critic loss 157 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) 158 | 159 | # Optimize the critic 160 | self.critic_optimizer.zero_grad() 161 | critic_loss.backward() 162 | self.critic_optimizer.step() 163 | 164 | # Maintain the same sparse connectivity for critic 165 | self.critic.l1.weight.data.mul_(self.critic.torchMask1) 166 | self.critic.l2.weight.data.mul_(self.critic.torchMask2) 
167 | self.critic.l4.weight.data.mul_(self.critic.torchMask4) 168 | self.critic.l5.weight.data.mul_(self.critic.torchMask5) 169 | 170 | # Delayed policy updates 171 | if self.total_it % self.policy_freq == 0: 172 | 173 | # Compute actor losse 174 | actor_loss = -self.critic.Q1(state, self.actor(state)).mean() 175 | 176 | # Optimize the actor 177 | self.actor_optimizer.zero_grad() 178 | actor_loss.backward() 179 | self.actor_optimizer.step() 180 | 181 | # Maintain the same sparse connectivity for actor 182 | self.actor.l1.weight.data.mul_(self.actor.torchMask1) 183 | self.actor.l2.weight.data.mul_(self.actor.torchMask2) 184 | 185 | # Update the frozen target models 186 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 187 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 188 | 189 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 190 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 191 | 192 | def print_sparsity(self): 193 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 194 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 195 | if(len(target_param.shape)>1): 196 | critic_current_sparsity = ((target_param==0).sum().cpu().data.numpy()*1.0/(target_param.shape[0]*target_param.shape[1])) 197 | print("target critic sparsity", critic_current_sparsity) 198 | 199 | critic_current_sparsity = ((param==0).sum().cpu().data.numpy()*1.0/(param.shape[0]*param.shape[1])) 200 | print("critic sparsity", critic_current_sparsity) 201 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 202 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 203 | if(len(target_param.shape)>1): 204 | critic_current_sparsity = ((target_param==0).sum().cpu().data.numpy()*1.0/(target_param.shape[0]*target_param.shape[1])) 205 | print("target actor sparsity", critic_current_sparsity) 206 | 207 | critic_current_sparsity = ((param==0).sum().cpu().data.numpy()*1.0/(param.shape[0]*param.shape[1])) 208 | print("actor sparsity", critic_current_sparsity) 209 | 210 | 211 | def save(self, filename): 212 | torch.save(self.critic.state_dict(), filename + "_critic") 213 | torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer") 214 | 215 | torch.save(self.actor.state_dict(), filename + "_actor") 216 | torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer") 217 | 218 | 219 | def load(self, filename): 220 | self.critic.load_state_dict(torch.load(filename + "_critic")) 221 | self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer")) 222 | self.critic_target = copy.deepcopy(self.critic) 223 | 224 | self.actor.load_state_dict(torch.load(filename + "_actor")) 225 | self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer")) 226 | self.actor_target = copy.deepcopy(self.actor) 227 | -------------------------------------------------------------------------------- /DS_TD3.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import sparse_utils as sp 7 | 8 | 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | class Actor(nn.Module): 12 | def __init__(self, state_dim, action_dim, 
max_action,noHidNeurons,epsilonHid1,epsilonHid2): 13 | super(Actor, self).__init__() 14 | 15 | self.l1 = nn.Linear(state_dim, noHidNeurons) 16 | [self.noPar1, self.mask1] = sp.initializeEpsilonWeightsMask("actor first layer", epsilonHid1, state_dim, noHidNeurons) 17 | self.torchMask1=torch.from_numpy(self.mask1).float().to(device) 18 | self.l1.weight.data.mul_(torch.from_numpy(self.mask1).float()) 19 | 20 | self.l2 = nn.Linear(noHidNeurons, noHidNeurons) 21 | [self.noPar2, self.mask2] = sp.initializeEpsilonWeightsMask("actor second layer", epsilonHid2, noHidNeurons, noHidNeurons) 22 | self.torchMask2 = torch.from_numpy(self.mask2).float().to(device) 23 | self.l2.weight.data.mul_(torch.from_numpy(self.mask2).float()) 24 | 25 | self.l3 = nn.Linear(noHidNeurons, action_dim) 26 | 27 | self.max_action = max_action 28 | 29 | 30 | def forward(self, state): 31 | a = F.relu(self.l1(state)) 32 | a = F.relu(self.l2(a)) 33 | return self.max_action * torch.tanh(self.l3(a)) 34 | 35 | 36 | class Critic(nn.Module): 37 | def __init__(self, state_dim, action_dim,noHidNeurons,epsilonHid1,epsilonHid2): 38 | super(Critic, self).__init__() 39 | 40 | # Q1 architecture 41 | self.l1 = nn.Linear(state_dim + action_dim, noHidNeurons) 42 | [self.noPar1, self.mask1] = sp.initializeEpsilonWeightsMask("critic Q1 first layer", epsilonHid1, state_dim + action_dim, noHidNeurons) 43 | self.torchMask1=torch.from_numpy(self.mask1).float().to(device) 44 | self.l1.weight.data.mul_(torch.from_numpy(self.mask1).float()) 45 | 46 | self.l2 = nn.Linear(noHidNeurons, noHidNeurons) 47 | [self.noPar2, self.mask2] = sp.initializeEpsilonWeightsMask("critic Q1 second layer", epsilonHid2, noHidNeurons, noHidNeurons) 48 | self.torchMask2 = torch.from_numpy(self.mask2).float().to(device) 49 | self.l2.weight.data.mul_(torch.from_numpy(self.mask2).float()) 50 | 51 | self.l3 = nn.Linear(noHidNeurons, 1) 52 | 53 | # Q2 architecture 54 | self.l4 = nn.Linear(state_dim + action_dim, noHidNeurons) 55 | [self.noPar4, self.mask4] = sp.initializeEpsilonWeightsMask("critic Q2 first layer", epsilonHid1, state_dim + action_dim, noHidNeurons) 56 | self.torchMask4 = torch.from_numpy(self.mask4).float().to(device) 57 | self.l4.weight.data.mul_(torch.from_numpy(self.mask4).float()) 58 | 59 | self.l5 = nn.Linear(noHidNeurons, noHidNeurons) 60 | [self.noPar5, self.mask5] = sp.initializeEpsilonWeightsMask("critic Q2 second layer", epsilonHid2, noHidNeurons, noHidNeurons) 61 | self.torchMask5 = torch.from_numpy(self.mask5).float().to(device) 62 | self.l5.weight.data.mul_(torch.from_numpy(self.mask5).float()) 63 | 64 | self.l6 = nn.Linear(noHidNeurons, 1) 65 | 66 | 67 | def forward(self, state, action): 68 | sa = torch.cat([state, action], 1) 69 | 70 | q1 = F.relu(self.l1(sa)) 71 | q1 = F.relu(self.l2(q1)) 72 | q1 = self.l3(q1) 73 | 74 | q2 = F.relu(self.l4(sa)) 75 | q2 = F.relu(self.l5(q2)) 76 | q2 = self.l6(q2) 77 | return q1, q2 78 | 79 | 80 | def Q1(self, state, action): 81 | sa = torch.cat([state, action], 1) 82 | 83 | q1 = F.relu(self.l1(sa)) 84 | q1 = F.relu(self.l2(q1)) 85 | q1 = self.l3(q1) 86 | return q1 87 | 88 | 89 | class SETSparseTD3(object): 90 | def __init__( 91 | self, 92 | state_dim, 93 | action_dim, 94 | max_action, 95 | discount=0.99, 96 | tau=0.005, 97 | policy_noise=0.2, 98 | noise_clip=0.5, 99 | policy_freq=2, 100 | noHidNeurons=256, 101 | epsilonHid1=20, 102 | epsilonHid2=20, 103 | setZeta=0.05, 104 | ascTopologyChangePeriod=1000, 105 | earlyStopTopologyChangeIteration=1e8 #kind of never 106 | ): 107 | 108 | self.actor = Actor(state_dim, 
action_dim, max_action,noHidNeurons,epsilonHid1,epsilonHid2).to(device) 109 | self.actor_target = copy.deepcopy(self.actor) 110 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=1e-3, weight_decay=0.0002) 111 | 112 | self.critic = Critic(state_dim, action_dim,noHidNeurons,epsilonHid1,epsilonHid2).to(device) 113 | self.critic_target = copy.deepcopy(self.critic) 114 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=1e-3, weight_decay=0.0002) 115 | 116 | self.max_action = max_action 117 | self.discount = discount 118 | self.tau = tau 119 | self.policy_noise = policy_noise 120 | self.noise_clip = noise_clip 121 | self.policy_freq = policy_freq 122 | 123 | self.setZeta = setZeta 124 | self.ascTopologyChangePeriod = ascTopologyChangePeriod 125 | self.earlyStopTopologyChangeIteration = earlyStopTopologyChangeIteration 126 | self.lastTopologyChangeCritic = False 127 | self.lastTopologyChangeActor = False 128 | 129 | self.ascStatsActor=[] 130 | self.ascStatsCritic = [] 131 | 132 | self.total_it = 0 133 | 134 | 135 | def select_action(self, state): 136 | state = torch.FloatTensor(state.reshape(1, -1)).to(device) 137 | return self.actor(state).cpu().data.numpy().flatten() 138 | 139 | 140 | def train(self, replay_buffer, batch_size=100): 141 | self.total_it += 1 142 | 143 | # Sample replay buffer 144 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 145 | 146 | with torch.no_grad(): 147 | # Select action according to policy and add clipped noise 148 | noise = ( 149 | torch.randn_like(action) * self.policy_noise 150 | ).clamp(-self.noise_clip, self.noise_clip) 151 | 152 | next_action = ( 153 | self.actor_target(next_state) + noise 154 | ).clamp(-self.max_action, self.max_action) 155 | 156 | # Compute the target Q value 157 | target_Q1, target_Q2 = self.critic_target(next_state, next_action) 158 | target_Q = torch.min(target_Q1, target_Q2) 159 | target_Q = reward + not_done * self.discount * target_Q 160 | 161 | # Get current Q estimates 162 | current_Q1, current_Q2 = self.critic(state, action) 163 | 164 | # Compute critic loss 165 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) 166 | 167 | # Optimize the critic 168 | self.critic_optimizer.zero_grad() 169 | critic_loss.backward() 170 | self.critic_optimizer.step() 171 | 172 | # Adapt the sparse connectivity 173 | if (self.lastTopologyChangeCritic==False): 174 | if (self.total_it % self.ascTopologyChangePeriod == 2): 175 | if (self.total_it>self.earlyStopTopologyChangeIteration): 176 | self.lastTopologyChangeCritic = True 177 | [self.critic.mask1, ascStats1] = sp.changeConnectivitySET(self.critic.l1.weight.data.cpu().numpy(), self.critic.noPar1,self.critic.mask1,self.setZeta,self.lastTopologyChangeCritic,self.total_it) 178 | self.critic.torchMask1 = torch.from_numpy(self.critic.mask1).float().to(device) 179 | [self.critic.mask2, ascStats2] = sp.changeConnectivitySET(self.critic.l2.weight.data.cpu().numpy(), self.critic.noPar2,self.critic.mask2,self.setZeta,self.lastTopologyChangeCritic,self.total_it) 180 | self.critic.torchMask2 = torch.from_numpy(self.critic.mask2).float().to(device) 181 | [self.critic.mask4, ascStats4] = sp.changeConnectivitySET(self.critic.l4.weight.data.cpu().numpy(), self.critic.noPar4,self.critic.mask4,self.setZeta,self.lastTopologyChangeCritic,self.total_it) 182 | self.critic.torchMask4 = torch.from_numpy(self.critic.mask4).float().to(device) 183 | [self.critic.mask5, ascStats5] = 
sp.changeConnectivitySET(self.critic.l5.weight.data.cpu().numpy(), self.critic.noPar5,self.critic.mask5,self.setZeta,self.lastTopologyChangeCritic,self.total_it) 184 | self.critic.torchMask5 = torch.from_numpy(self.critic.mask5).float().to(device) 185 | self.ascStatsCritic.append([ascStats1,ascStats2,ascStats4,ascStats5]) 186 | 187 | # Maintain the same sparse connectivity for critic 188 | self.critic.l1.weight.data.mul_(self.critic.torchMask1) 189 | self.critic.l2.weight.data.mul_(self.critic.torchMask2) 190 | self.critic.l4.weight.data.mul_(self.critic.torchMask4) 191 | self.critic.l5.weight.data.mul_(self.critic.torchMask5) 192 | 193 | # Delayed policy updates 194 | if self.total_it % self.policy_freq == 0: 195 | 196 | # Compute actor loss 197 | actor_loss = -self.critic.Q1(state, self.actor(state)).mean() 198 | 199 | # Optimize the actor 200 | self.actor_optimizer.zero_grad() 201 | actor_loss.backward() 202 | self.actor_optimizer.step() 203 | 204 | if (self.lastTopologyChangeActor == False): 205 | if (self.total_it % self.ascTopologyChangePeriod == 2): 206 | if (self.total_it > self.earlyStopTopologyChangeIteration): 207 | self.lastTopologyChangeActor = True 208 | [self.actor.mask1, ascStats1] = sp.changeConnectivitySET(self.actor.l1.weight.data.cpu().numpy(), self.actor.noPar1, self.actor.mask1, self.setZeta, self.lastTopologyChangeActor, self.total_it) 209 | self.actor.torchMask1 = torch.from_numpy(self.actor.mask1).float().to(device) 210 | [self.actor.mask2, ascStats2] = sp.changeConnectivitySET(self.actor.l2.weight.data.cpu().numpy(), self.actor.noPar2, self.actor.mask2, self.setZeta, self.lastTopologyChangeActor, self.total_it) 211 | self.actor.torchMask2 = torch.from_numpy(self.actor.mask2).float().to(device) 212 | self.ascStatsActor.append([ascStats1, ascStats2]) 213 | 214 | # Maintain the same sparse connectivity for actor 215 | self.actor.l1.weight.data.mul_(self.actor.torchMask1) 216 | self.actor.l2.weight.data.mul_(self.actor.torchMask2) 217 | 218 | # Update the frozen target models 219 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 220 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 221 | if len(param.shape)>1: 222 | self.update_target_networks(param, target_param, device) 223 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 224 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 225 | if len(param.shape)>1: 226 | self.update_target_networks(param, target_param, device) 227 | 228 | # Maintain sparsity in target networks 229 | def update_target_networks(self, param, target_param, device): 230 | current_density = (param!=0).sum() 231 | target_density = (target_param!=0).sum() #torch.count_nonzero(target_param.data) 232 | difference = target_density - current_density 233 | # constrain the sparsity by removing the extra elements (smallest values) 234 | if(difference>0): 235 | count_rmv = difference 236 | tmp = copy.deepcopy(abs(target_param.data)) 237 | tmp[tmp==0]= 10000000 238 | unraveled = self.unravel_index(torch.argsort(tmp.view(1,-1)[0]), tmp.shape) 239 | rmv_indicies = torch.stack(unraveled, dim=1) 240 | rmv_values_smaller_than = tmp[rmv_indicies[count_rmv][0],rmv_indicies[count_rmv][1]] 241 | target_param.data[tmp < rmv_values_smaller_than] = 0 242 | 243 | def unravel_index(self, index, shape): 244 | out = [] 245 | for dim in reversed(shape): 246 | out.append(index % dim) 247 | index = index // dim 248 | return tuple(reversed(out)) 249 | 250 | def print_sparsity(self): 251 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 252 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 253 | if(len(target_param.shape)>1): 254 | critic_current_sparsity = ((target_param==0).sum().cpu().data.numpy()*1.0/(target_param.shape[0]*target_param.shape[1])) 255 | print("target critic sparsity", critic_current_sparsity) 256 |
257 | critic_current_sparsity = ((param==0).sum().cpu().data.numpy()*1.0/(param.shape[0]*param.shape[1])) 258 | print("critic sparsity", critic_current_sparsity) 259 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 260 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 261 | if(len(target_param.shape)>1): 262 | critic_current_sparsity = ((target_param==0).sum().cpu().data.numpy()*1.0/(target_param.shape[0]*target_param.shape[1])) 263 | print("target actor sparsity", critic_current_sparsity) 264 | 265 | critic_current_sparsity = ((param==0).sum().cpu().data.numpy()*1.0/(param.shape[0]*param.shape[1])) 266 | print("actor sparsity", critic_current_sparsity) 267 | 268 | def saveAscStats(self, filename): 269 | np.savez(filename+"_ASC_stats.npz",ascStatsActor=self.ascStatsActor, ascStatsCritic=self.ascStatsCritic) 270 | 271 | def save(self, filename): 272 | torch.save(self.critic.state_dict(), filename + "_critic") 273 | torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer") 274 | 275 | torch.save(self.actor.state_dict(), filename + "_actor") 276 | torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer") 277 | 278 | 279 | def load(self, filename): 280 | self.critic.load_state_dict(torch.load(filename + "_critic")) 281 | self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer")) 282 | self.critic_target = copy.deepcopy(self.critic) 283 | 284 | self.actor.load_state_dict(torch.load(filename + "_actor")) 285 | self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer")) 286 | self.actor_target = copy.deepcopy(self.actor) 287 | --------------------------------------------------------------------------------
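For readers who want to see the sparse-topology machinery in isolation, here is a minimal illustrative sketch (not a file of this repository) that exercises the two helpers from sparse_utils.py the way DS_TD3.py uses them: it builds an epsilon-based Erdos-Renyi mask for one layer, then performs a single SET step that drops roughly a fraction zeta of the smallest-magnitude connections and regrows the same number at random positions. The toy dimensions (17 inputs, 256 hidden units) and the values epsilon=7 and zeta=0.05 are assumptions that simply mirror the defaults in main.py.

```
import numpy as np
import sparse_utils as sp

in_dim, out_dim = 17, 256  # e.g. HalfCheetah-v3 state size and the default hidden width
epsilon, zeta = 7, 0.05    # mirror the --ann_epsilonHid1 and --ann_setZeta defaults in main.py

# Erdos-Renyi mask; the returned mask is transposed to match nn.Linear's (out, in) layout
noPar, mask = sp.initializeEpsilonWeightsMask("toy layer", epsilon, in_dim, out_dim)

# Stand-in for the trained weights of a sparse layer (zero outside the mask)
weights = np.random.randn(out_dim, in_dim) * mask

# One SET step: prune the weakest connections, regrow the same number at random positions
new_mask, ascStats = sp.changeConnectivitySET(weights, noPar, mask, zeta, False, 1)
print("active connections:", int(np.count_nonzero(new_mask)), "of", int(noPar))
print("ASC stats [iteration, nrAdd, noWeights, nonzeros]:", ascStats)
```

In DS_TD3.train(), the same call pattern is applied to the actor and critic layers once every ascTopologyChangePeriod iterations, and the updated mask is multiplied into the layer weights so that pruned connections stay at zero.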