├── .gitignore ├── Acrobot-v1 └── training_q_learning.ipynb ├── Ant-v3 └── training_td3-SNN.ipynb ├── CartPole-v0 └── training_q_learning.ipynb ├── HalfCheetah-v3 └── training_td3-SNN.ipynb ├── Hopper-v3 └── training_td3-SNN.ipynb ├── Pendulum-v0 └── training_td3-SNN.ipynb ├── README.md ├── requirements.txt ├── rstdp_domain_adaptation ├── pretrained_models │ ├── acrobot │ │ ├── checkpoint_DSQN_0.pt │ │ ├── checkpoint_DSQN_1.pt │ │ └── checkpoint_DSQN_2.pt │ └── cartpole │ │ ├── checkpoint_DSQN_0.pt │ │ ├── checkpoint_DSQN_1.pt │ │ └── checkpoint_DSQN_2.pt ├── rstdp_acrobot_adaptation.ipynb ├── rstdp_cartpole_adaptation.ipynb └── utils.py ├── seeds ├── evaluation_seeds.npy ├── rstdp_training_seeds.npy └── training_seeds.npy ├── src ├── dqn_agent.py ├── dsnn.py ├── model.py ├── rstdp.py └── td3_agent.py └── swing-up-cartpole ├── checkpoint_TD3_actor_1500.pt ├── evaluate_snn.py ├── memory_buffer.py ├── model.py ├── smoothed_rewards_500.npy ├── snn_results ├── checkpoint_TD3_actor_0100.pt ├── checkpoint_TD3_actor_0200.pt ├── checkpoint_TD3_actor_0300.pt ├── checkpoint_TD3_actor_0400.pt ├── checkpoint_TD3_actor_0500.pt ├── checkpoint_TD3_actor_0600.pt ├── checkpoint_TD3_actor_0700.pt ├── checkpoint_TD3_actor_0800.pt ├── checkpoint_TD3_actor_0900.pt ├── checkpoint_TD3_actor_1000.pt ├── checkpoint_TD3_actor_1100.pt ├── checkpoint_TD3_actor_1200.pt ├── checkpoint_TD3_actor_1300.pt ├── checkpoint_TD3_actor_1400.pt ├── checkpoint_TD3_actor_1500.pt ├── checkpoint_TD3_actor_1600.pt ├── checkpoint_TD3_actor_1700.pt ├── checkpoint_TD3_actor_1800.pt ├── checkpoint_TD3_actor_1900.pt ├── checkpoint_TD3_actor_2000.pt ├── checkpoint_TD3_actor_2100.pt ├── checkpoint_TD3_actor_2200.pt ├── checkpoint_TD3_actor_2300.pt ├── checkpoint_TD3_actor_2400.pt ├── checkpoint_TD3_actor_2500.pt ├── checkpoint_TD3_actor_2600.pt ├── checkpoint_TD3_actor_2700.pt ├── checkpoint_TD3_actor_2800.pt ├── checkpoint_TD3_actor_2900.pt ├── checkpoint_TD3_actor_3000.pt ├── checkpoint_TD3_actor_3100.pt ├── checkpoint_TD3_actor_3200.pt ├── checkpoint_TD3_actor_3300.pt ├── checkpoint_TD3_actor_3400.pt ├── checkpoint_TD3_actor_3500.pt ├── checkpoint_TD3_actor_3600.pt ├── checkpoint_TD3_actor_3700.pt ├── checkpoint_TD3_actor_3800.pt ├── checkpoint_TD3_actor_3900.pt ├── checkpoint_TD3_actor_4000.pt ├── checkpoint_TD3_actor_4100.pt ├── checkpoint_TD3_actor_4200.pt ├── checkpoint_TD3_actor_4300.pt ├── checkpoint_TD3_actor_4400.pt ├── checkpoint_TD3_actor_4500.pt ├── checkpoint_TD3_actor_4600.pt ├── checkpoint_TD3_actor_4700.pt ├── checkpoint_TD3_actor_4800.pt ├── checkpoint_TD3_actor_4900.pt ├── checkpoint_TD3_actor_5000.pt └── td3_snn_swing_up_cartpole.png ├── td3_ann_main.py └── td3_snn.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .idea/ 3 | __pycache__/ 4 | .ipynb_checkpoints/ 5 | -------------------------------------------------------------------------------- /Acrobot-v1/training_q_learning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# DQN vs. 
DSQN for the CartPole Environment" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os\n", 17 | "import gym\n", 18 | "import site\n", 19 | "import torch\n", 20 | "import random\n", 21 | "\n", 22 | "import numpy as np\n", 23 | "import torch.optim as optim\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "\n", 26 | "site.addsitedir('../src/')\n", 27 | "\n", 28 | "from datetime import date\n", 29 | "from model import QNetwork, DSNN\n", 30 | "from dqn_agent import Agent, ReplayBuffer\n", 31 | "from matplotlib.gridspec import GridSpec\n", 32 | "\n", 33 | "%matplotlib inline" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# Environment specific parameters\n", 43 | "env_name = 'Acrobot-v1'\n", 44 | "n_runs = 5\n", 45 | "n_evaluations = 100\n", 46 | "max_steps = 500\n", 47 | "num_episodes = 500\n", 48 | "\n", 49 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "# Create Results Directory\n", 59 | "dirs = os.listdir('.')\n", 60 | "if not any('result' in d for d in dirs):\n", 61 | " result_id = 1\n", 62 | "else:\n", 63 | " results = [d for d in dirs if 'result' in d]\n", 64 | " result_id = len(results) + 1\n", 65 | "\n", 66 | "# Get today's date and add it to the results directory\n", 67 | "d = date.today()\n", 68 | "result_dir = 'dqn_result_' + str(result_id) + '_{}'.format(\n", 69 | " str(d.year) + str(d.month) + str(d.day))\n", 70 | "os.mkdir(result_dir)\n", 71 | "print('Created Directory {} to store the results in'.format(result_dir))" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "# Hyperparameters\n", 81 | "batch_size = 128\n", 82 | "discount_factor = 0.999\n", 83 | "eps_start = 1.0\n", 84 | "eps_end = 0.05\n", 85 | "eps_decay = 0.999\n", 86 | "update_every = 4\n", 87 | "target_update_frequency = 100\n", 88 | "learning_rate = 0.001\n", 89 | "replay_memory_size = 4*10**4\n", 90 | "tau = 1e-3" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "# SNN Hyperparameters\n", 100 | "simulation_time = 3\n", 101 | "alpha = 0.8\n", 102 | "beta = 0.8\n", 103 | "threshold = 1.0\n", 104 | "weight_scale = 1\n", 105 | "architecture = [6, 256, 256, 3]" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "seeds = np.load('../seeds/training_seeds.npy')" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": {}, 120 | "source": [ 121 | "## DQN Training" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": { 128 | "scrolled": true 129 | }, 130 | "outputs": [], 131 | "source": [ 132 | "smoothed_scores_dqn_all = []\n", 133 | "dqn_completion_after = []\n", 134 | "\n", 135 | "for i_run in range(n_runs):\n", 136 | " print(\"Run # {}\".format(i_run))\n", 137 | " seed = int(seeds[i_run])\n", 138 | " \n", 139 | " torch.manual_seed(seed)\n", 140 | " random.seed(seed)\n", 141 | "\n", 142 | " policy_net = QNetwork(architecture, seed).to(device)\n", 143 | " target_net = QNetwork(architecture, seed).to(device)\n", 144 | " 
target_net.load_state_dict(policy_net.state_dict())\n", 145 | "\n", 146 | " optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)\n", 147 | " agent = Agent(env_name, policy_net, target_net, architecture, batch_size,\n", 148 | " replay_memory_size, discount_factor, eps_start, eps_end, eps_decay,\n", 149 | " update_every, target_update_frequency, optimizer, learning_rate,\n", 150 | " num_episodes, max_steps, i_run, result_dir, seed, tau)\n", 151 | " \n", 152 | " smoothed_scores, scores, best_average, best_average_after = agent.train_agent()\n", 153 | "\n", 154 | " np.save(result_dir + '/scores_{}'.format(i_run), scores)\n", 155 | " np.save(result_dir + '/smoothed_scores_DQN_{}'.format(i_run), smoothed_scores)\n", 156 | "\n", 157 | " # save smoothed scores in list to plot later\n", 158 | " dqn_completion_after.append(best_average_after)\n", 159 | " smoothed_scores_dqn_all.append(smoothed_scores)\n", 160 | " print(\"\")" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": { 167 | "scrolled": true 168 | }, 169 | "outputs": [], 170 | "source": [ 171 | "# Plot scores of individual runs\n", 172 | "for i in range(len(smoothed_scores_dqn_all)):\n", 173 | " fig = plt.figure()\n", 174 | " plt.plot(smoothed_scores_dqn_all[i])\n", 175 | " plt.ylim(-550, 0)\n", 176 | " plt.grid(True)\n", 177 | " plt.savefig(result_dir + '/training_dqn_{}.png'.format(i), dpi=1000)\n", 178 | " plt.show()" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "# Plot results (mean)\n", 188 | "best_smoothed_scores_dqn = [smoothed_scores_dqn_all[0],\n", 189 | " smoothed_scores_dqn_all[1],\n", 190 | " smoothed_scores_dqn_all[2],\n", 191 | " smoothed_scores_dqn_all[3],\n", 192 | " smoothed_scores_dqn_all[4]]\n", 193 | "mean_smoothed_scores_dqn = np.mean(best_smoothed_scores_dqn, axis=0)\n", 194 | "std_smoothed_scores = np.std(best_smoothed_scores_dqn, axis=0)\n", 195 | "\n", 196 | "avg_dqn_completion_after = np.mean([dqn_completion_after[0],\n", 197 | " dqn_completion_after[1],\n", 198 | " dqn_completion_after[2],\n", 199 | " dqn_completion_after[3],\n", 200 | " dqn_completion_after[4]])\n", 201 | "\n", 202 | "fig = plt.figure()\n", 203 | "plt.plot(range(len(best_smoothed_scores_dqn[0])), mean_smoothed_scores_dqn)\n", 204 | "plt.fill_between(range(len(best_smoothed_scores_dqn[0])),\n", 205 | " np.nanpercentile(best_smoothed_scores_dqn, 2, axis=0),\n", 206 | " np.nanpercentile(best_smoothed_scores_dqn, 97, axis=0), alpha=0.25)\n", 207 | "plt.ylim(-550, 0)\n", 208 | "plt.grid(True)\n", 209 | "plt.savefig(result_dir + '/DQN_training.png', dpi=300)\n", 210 | "plt.show()" 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "metadata": {}, 216 | "source": [ 217 | "## DSQN Training" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "metadata": { 224 | "scrolled": false 225 | }, 226 | "outputs": [], 227 | "source": [ 228 | "smoothed_scores_dsqn_all = []\n", 229 | "dsqn_completion_after = []\n", 230 | "\n", 231 | "for i_run in range(n_runs):\n", 232 | " print(\"Run # {}\".format(i_run))\n", 233 | " seed = int(seeds[i_run])\n", 234 | "\n", 235 | " torch.manual_seed(seed)\n", 236 | " random.seed(seed)\n", 237 | "\n", 238 | " policy_net = DSNN(architecture, seed, alpha, beta, weight_scale, batch_size, threshold,\n", 239 | " simulation_time, learning_rate)\n", 240 | " target_net = DSNN(architecture, seed, alpha, beta, weight_scale, 
batch_size, threshold,\n", 241 | " simulation_time, learning_rate)\n", 242 | " target_net.load_state_dict(policy_net.state_dict())\n", 243 | " optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)\n", 244 | "\n", 245 | " agent = Agent(env_name, policy_net, target_net, architecture, batch_size,\n", 246 | " replay_memory_size, discount_factor, eps_start, eps_end, eps_decay,\n", 247 | " update_every, target_update_frequency, optimizer, learning_rate,\n", 248 | " num_episodes, max_steps, i_run, result_dir, seed, tau, spiking=True)\n", 249 | "\n", 250 | " smoothed_scores, scores, best_average, best_average_after = agent.train_agent()\n", 251 | "\n", 252 | " np.save(result_dir + '/scores_{}'.format(i_run), scores)\n", 253 | " np.save(result_dir + '/smoothed_scores_DSQN_{}'.format(i_run), smoothed_scores)\n", 254 | "\n", 255 | " # save smoothed scores in list to plot later\n", 256 | " smoothed_scores_dsqn_all.append(smoothed_scores)\n", 257 | " dsqn_completion_after.append(best_average_after)\n", 258 | " print(\"\")" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "best_smoothed_scores_dsqn = [smoothed_scores_dsqn_all[0],\n", 268 | " smoothed_scores_dsqn_all[1],\n", 269 | " smoothed_scores_dsqn_all[2],\n", 270 | " smoothed_scores_dsqn_all[3],\n", 271 | " smoothed_scores_dsqn_all[4]]\n", 272 | "mean_smoothed_scores_dsqn = np.mean(best_smoothed_scores_dsqn, axis=0)\n", 273 | "\n", 274 | "avg_dsqn_completion_after = np.mean([dsqn_completion_after[0],\n", 275 | " dsqn_completion_after[1],\n", 276 | " dsqn_completion_after[2],\n", 277 | " dsqn_completion_after[3],\n", 278 | " dsqn_completion_after[4]])\n", 279 | "\n", 280 | "fig = plt.figure()\n", 281 | "plt.plot(range(len(best_smoothed_scores_dsqn[0])), mean_smoothed_scores_dsqn)\n", 282 | "plt.fill_between(range(len(best_smoothed_scores_dsqn[0])),\n", 283 | " np.nanpercentile(best_smoothed_scores_dsqn, 2, axis=0),\n", 284 | " np.nanpercentile(best_smoothed_scores_dsqn, 97, axis=0), alpha=0.25)\n", 285 | "\n", 286 | "plt.vlines(avg_dsqn_completion_after, 0, 250, 'C0')\n", 287 | "\n", 288 | "\n", 289 | "plt.ylim(-550, 0)\n", 290 | "plt.grid(True)\n", 291 | "plt.savefig(result_dir + '/DSQN_training.png', dpi=1000)\n", 292 | "plt.title('Acrobot-v1 DSQN')\n", 293 | "plt.show()" 294 | ] 295 | } 296 | ], 297 | "metadata": { 298 | "kernelspec": { 299 | "display_name": "Python 3 (ipykernel)", 300 | "language": "python", 301 | "name": "python3" 302 | }, 303 | "language_info": { 304 | "codemirror_mode": { 305 | "name": "ipython", 306 | "version": 3 307 | }, 308 | "file_extension": ".py", 309 | "mimetype": "text/x-python", 310 | "name": "python", 311 | "nbconvert_exporter": "python", 312 | "pygments_lexer": "ipython3", 313 | "version": "3.8.12" 314 | } 315 | }, 316 | "nbformat": 4, 317 | "nbformat_minor": 2 318 | } 319 | -------------------------------------------------------------------------------- /Ant-v3/training_td3-SNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Twin Delayed Deep Deterministic Policy Gradients (TD3)" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os\n", 17 | "import site\n", 18 | "import torch\n", 19 | "import random\n", 20 | "\n", 21 | "import numpy as np\n", 22 | "import gymnasium as 
gym\n", 23 | "import matplotlib.pyplot as plt\n", 24 | "%matplotlib inline\n", 25 | "\n", 26 | "site.addsitedir('../src/')\n", 27 | "\n", 28 | "from datetime import date\n", 29 | "from td3_agent import Agent\n", 30 | "from model import TD3CriticNetwork, TD3ActorDSNN" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "# Create Results Directory\n", 40 | "dirs = os.listdir('.')\n", 41 | "if not any('result' in d for d in dirs):\n", 42 | " result_id = 1\n", 43 | "else:\n", 44 | " results = [d for d in dirs if 'result' in d]\n", 45 | " result_id = len(results) + 1\n", 46 | "\n", 47 | "# Get today's date and add it to the results directory\n", 48 | "d = date.today()\n", 49 | "result_dir = 'td3_result_' + str(result_id) + '_{}'.format(\n", 50 | " str(d.year) + str(d.month) + str(d.day))\n", 51 | "os.mkdir(result_dir)\n", 52 | "print('Created Directory {} to store the results in'.format(result_dir))" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "n_runs = 10\n", 62 | "n_timesteps = 1e6\n", 63 | "batch_size = 100\n", 64 | "\n", 65 | "seeds = np.load('../seeds/training_seeds.npy')\n", 66 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "actor_learning_rate = 0.001\n", 76 | "critic_learning_rate = 0.001\n", 77 | "tau = 0.005\n", 78 | "layer1_size = 400\n", 79 | "layer2_size = 300\n", 80 | "noise = 0.1\n", 81 | "warmup = 1000\n", 82 | "update_actor_interval = 2\n", 83 | "update_target_interval = 1\n", 84 | "buffer_size = int(1e6)\n", 85 | "pop_size = 10\n", 86 | "pop_coding = False\n", 87 | "two_neuron = True\n", 88 | "mutually_exclusive = False" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "alpha = 0.5\n", 98 | "beta = 0.5\n", 99 | "weight_scale = 1\n", 100 | "threshold = 2.5\n", 101 | "sim_time = 5" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": { 108 | "scrolled": true 109 | }, 110 | "outputs": [], 111 | "source": [ 112 | "smoothed_scores_all = []\n", 113 | "#torch.autograd.set_detect_anomaly(True)\n", 114 | "\n", 115 | "for i in range(n_runs):\n", 116 | " print(\"Run # {}\".format(i))\n", 117 | "\n", 118 | " seed = int(seeds[i])\n", 119 | " \n", 120 | " env = gym.make('Ant-v3')\n", 121 | " \n", 122 | " if two_neuron:\n", 123 | " input_dims = (env.observation_space.shape[0]*2,)\n", 124 | " elif pop_coding:\n", 125 | " input_dims = (env.observation_space.shape[0]*pop_size,)\n", 126 | " else:\n", 127 | " input_dims = env.observation_space.shape\n", 128 | " n_actions = env.action_space.shape[0]\n", 129 | "\n", 130 | " actor_architecture = [input_dims[0], layer1_size, layer2_size, n_actions]\n", 131 | " \n", 132 | " torch.manual_seed(seed)\n", 133 | " np.random.seed(seed)\n", 134 | " random.seed(seed)\n", 135 | "\n", 136 | " actor = TD3ActorDSNN(actor_architecture, seed, alpha, beta, weight_scale, 1,\n", 137 | " threshold, sim_time, actor_learning_rate, name='actor_{}'.format(i), device=device)\n", 138 | " target_actor = TD3ActorDSNN(actor_architecture, seed, alpha, beta, weight_scale, 1,\n", 139 | " threshold, sim_time, actor_learning_rate, name='target_actor_{}'.format(i), device=device)\n", 140 | 
"\n", 141 | " critic_1 = TD3CriticNetwork(critic_learning_rate, input_dims, layer1_size,\n", 142 | " layer2_size, n_actions=n_actions, name='critic_1_{}'.format(i))\n", 143 | " critic_2 = TD3CriticNetwork(critic_learning_rate, input_dims, layer1_size,\n", 144 | " layer2_size, n_actions=n_actions, name='critic_2_{}'.format(i))\n", 145 | " target_critic_1 = TD3CriticNetwork(critic_learning_rate, input_dims, layer1_size,\n", 146 | " layer2_size, n_actions=n_actions, name='target_critic_1_{}'.format(i))\n", 147 | " target_critic_2 = TD3CriticNetwork(critic_learning_rate, input_dims, layer1_size,\n", 148 | " layer2_size, n_actions=n_actions, name='target_critic_2_{}'.format(i))\n", 149 | "\n", 150 | " agent = Agent(actor, critic_1, critic_2, target_actor, target_critic_1, target_critic_2,\n", 151 | " input_dims, tau, env, n_timesteps, result_dir, n_actions=n_actions, seed=seed,\n", 152 | " noise=noise, update_actor_interval=update_actor_interval, warmup=warmup,\n", 153 | " update_target_interval=update_target_interval, two_neuron=two_neuron,\n", 154 | " buffer_size=buffer_size, spiking=True, normalize=True)\n", 155 | "\n", 156 | " smoothed_scores, reward_history, best_average, best_average_after = agent.train_agent()\n", 157 | " smoothed_scores_all.append(smoothed_scores)" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "metadata": { 164 | "pycharm": { 165 | "name": "#%%\n" 166 | } 167 | }, 168 | "outputs": [], 169 | "source": [ 170 | "final_smoothed_scores = [smoothed_scores_all[i] for i in range(n_runs)]\n", 171 | "mean_smoothed_scores_dqn = np.mean(final_smoothed_scores, axis=0)\n", 172 | "std_smoothed_scores = np.std(final_smoothed_scores, axis=0)\n", 173 | "\n", 174 | "fig = plt.figure()\n", 175 | "plt.plot(range(len(final_smoothed_scores[0])), mean_smoothed_scores_dqn)\n", 176 | "plt.fill_between(range(len(final_smoothed_scores[0])),\n", 177 | " np.nanpercentile(final_smoothed_scores, 2, axis=0),\n", 178 | " np.nanpercentile(final_smoothed_scores, 97, axis=0), alpha=0.25)\n", 179 | "plt.grid(True)\n", 180 | "plt.savefig(result_dir + '/td3_training_snn.png', dpi=300)\n", 181 | "plt.show()" 182 | ] 183 | } 184 | ], 185 | "metadata": { 186 | "kernelspec": { 187 | "display_name": "Python 3 (ipykernel)", 188 | "language": "python", 189 | "name": "python3" 190 | }, 191 | "language_info": { 192 | "codemirror_mode": { 193 | "name": "ipython", 194 | "version": 3 195 | }, 196 | "file_extension": ".py", 197 | "mimetype": "text/x-python", 198 | "name": "python", 199 | "nbconvert_exporter": "python", 200 | "pygments_lexer": "ipython3", 201 | "version": "3.10.12" 202 | } 203 | }, 204 | "nbformat": 4, 205 | "nbformat_minor": 2 206 | } 207 | -------------------------------------------------------------------------------- /CartPole-v0/training_q_learning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# DQN vs. 
DSQN for the CartPole Environment" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os\n", 17 | "import site\n", 18 | "import torch\n", 19 | "import random\n", 20 | "\n", 21 | "import numpy as np\n", 22 | "import torch.optim as optim\n", 23 | "import matplotlib.pyplot as plt\n", 24 | "\n", 25 | "site.addsitedir('../src/')\n", 26 | "\n", 27 | "from datetime import date\n", 28 | "from model import QNetwork, DSNN\n", 29 | "from dqn_agent import Agent, ReplayBuffer\n", 30 | "from matplotlib.gridspec import GridSpec\n", 31 | "\n", 32 | "%matplotlib inline" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "# Environment specific parameters\n", 42 | "env_name = 'CartPole-v0'\n", 43 | "n_runs = 5\n", 44 | "n_evaluations = 100\n", 45 | "max_steps = 200\n", 46 | "num_episodes = 500\n", 47 | "\n", 48 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "# Create Results Directory\n", 58 | "dirs = os.listdir('.')\n", 59 | "if not any('result' in d for d in dirs):\n", 60 | " result_id = 1\n", 61 | "else:\n", 62 | " results = [d for d in dirs if 'result' in d]\n", 63 | " result_id = len(results) + 1\n", 64 | "\n", 65 | "# Get today's date and add it to the results directory\n", 66 | "d = date.today()\n", 67 | "result_dir = 'dqn_result_' + str(result_id) + '_{}'.format(\n", 68 | " str(d.year) + str(d.month) + str(d.day))\n", 69 | "os.mkdir(result_dir)\n", 70 | "print('Created Directory {} to store the results in'.format(result_dir))" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "# Hyperparameters\n", 80 | "batch_size = 128\n", 81 | "discount_factor = 0.999\n", 82 | "eps_start = 1.0\n", 83 | "eps_end = 0.05\n", 84 | "eps_decay = 0.999\n", 85 | "update_every = 4\n", 86 | "target_update_frequency = 100\n", 87 | "learning_rate = 0.001\n", 88 | "replay_memory_size = 4*10**4\n", 89 | "tau = 1e-3" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "# SNN Hyperparameters\n", 99 | "simulation_time = 3\n", 100 | "alpha = 0.5\n", 101 | "beta = 0.5\n", 102 | "threshold = 0.2\n", 103 | "weight_scale = 1\n", 104 | "architecture = [4, 64, 64, 2]" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "seeds = np.load('../seeds/training_seeds.npy')" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": {}, 119 | "source": [ 120 | "## DQN Training" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": { 127 | "scrolled": true 128 | }, 129 | "outputs": [], 130 | "source": [ 131 | "smoothed_scores_dqn_all = []\n", 132 | "dqn_completion_after = []\n", 133 | "\n", 134 | "for i_run in range(n_runs):\n", 135 | " print(\"Run # {}\".format(i_run))\n", 136 | " seed = int(seeds[i_run])\n", 137 | " \n", 138 | " torch.manual_seed(seed)\n", 139 | " random.seed(seed)\n", 140 | "\n", 141 | " policy_net = QNetwork(architecture, seed).to(device)\n", 142 | " target_net = QNetwork(architecture, seed).to(device)\n", 143 | " 
target_net.load_state_dict(policy_net.state_dict())\n", 144 | "\n", 145 | " optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)\n", 146 | " agent = Agent(env_name, policy_net, target_net, architecture, batch_size,\n", 147 | " replay_memory_size, discount_factor, eps_start, eps_end, eps_decay,\n", 148 | " update_every, target_update_frequency, optimizer, learning_rate,\n", 149 | " num_episodes, max_steps, i_run, result_dir, seed, tau)\n", 150 | " \n", 151 | " smoothed_scores, scores, best_average, best_average_after = agent.train_agent()\n", 152 | "\n", 153 | " np.save(result_dir + '/scores_{}'.format(i_run), scores)\n", 154 | " np.save(result_dir + '/smoothed_scores_DQN_{}'.format(i_run), smoothed_scores)\n", 155 | "\n", 156 | " # save smoothed scores in list to plot later\n", 157 | " dqn_completion_after.append(best_average_after)\n", 158 | " smoothed_scores_dqn_all.append(smoothed_scores)\n", 159 | " print(\"\")" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": { 166 | "scrolled": true 167 | }, 168 | "outputs": [], 169 | "source": [ 170 | "# Plot scores of individual runs\n", 171 | "for i in range(len(smoothed_scores_dqn_all)):\n", 172 | " fig = plt.figure()\n", 173 | " plt.plot(smoothed_scores_dqn_all[i])\n", 174 | " plt.ylim(0, 250)\n", 175 | " plt.grid(True)\n", 176 | " plt.savefig(result_dir + '/training_dqn_{}.png'.format(i), dpi=1000)\n", 177 | " plt.show()" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [ 186 | "# Plot results (mean)\n", 187 | "best_smoothed_scores_dqn = [smoothed_scores_dqn_all[0],\n", 188 | " smoothed_scores_dqn_all[1],\n", 189 | " smoothed_scores_dqn_all[2],\n", 190 | " smoothed_scores_dqn_all[3],\n", 191 | " smoothed_scores_dqn_all[4]]\n", 192 | "mean_smoothed_scores_dqn = np.mean(best_smoothed_scores_dqn, axis=0)\n", 193 | "std_smoothed_scores = np.std(best_smoothed_scores_dqn, axis=0)\n", 194 | "\n", 195 | "avg_dqn_completion_after = np.mean([dqn_completion_after[0],\n", 196 | " dqn_completion_after[1],\n", 197 | " dqn_completion_after[2],\n", 198 | " dqn_completion_after[3],\n", 199 | " dqn_completion_after[4]])\n", 200 | "\n", 201 | "fig = plt.figure()\n", 202 | "plt.plot(range(len(best_smoothed_scores_dqn[0])), mean_smoothed_scores_dqn)\n", 203 | "plt.fill_between(range(len(best_smoothed_scores_dqn[0])),\n", 204 | " np.nanpercentile(best_smoothed_scores_dqn, 2, axis=0),\n", 205 | " np.nanpercentile(best_smoothed_scores_dqn, 97, axis=0), alpha=0.25)\n", 206 | "plt.ylim(0, 250)\n", 207 | "plt.grid(True)\n", 208 | "plt.savefig(result_dir + '/DQN_training.png', dpi=300)\n", 209 | "plt.show()" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "metadata": {}, 215 | "source": [ 216 | "## DSQN Training" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "metadata": { 223 | "scrolled": false 224 | }, 225 | "outputs": [], 226 | "source": [ 227 | "smoothed_scores_dsqn_all = []\n", 228 | "dsqn_completion_after = []\n", 229 | "\n", 230 | "for i_run in range(n_runs):\n", 231 | " print(\"Run # {}\".format(i_run))\n", 232 | " seed = int(seeds[i_run])\n", 233 | "\n", 234 | " torch.manual_seed(seed)\n", 235 | " random.seed(seed)\n", 236 | "\n", 237 | " policy_net = DSNN(architecture, seed, alpha, beta, weight_scale, batch_size, threshold,\n", 238 | " simulation_time, learning_rate)\n", 239 | " target_net = DSNN(architecture, seed, alpha, beta, weight_scale, 
batch_size, threshold,\n", 240 | " simulation_time, learning_rate)\n", 241 | " target_net.load_state_dict(policy_net.state_dict())\n", 242 | " optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)\n", 243 | "\n", 244 | " agent = Agent(env_name, policy_net, target_net, architecture, batch_size,\n", 245 | " replay_memory_size, discount_factor, eps_start, eps_end, eps_decay,\n", 246 | " update_every, target_update_frequency, optimizer, learning_rate,\n", 247 | " num_episodes, max_steps, i_run, result_dir, seed, tau, spiking=True)\n", 248 | "\n", 249 | " smoothed_scores, scores, best_average, best_average_after = agent.train_agent()\n", 250 | "\n", 251 | " np.save(result_dir + '/scores_{}'.format(i_run), scores)\n", 252 | " np.save(result_dir + '/smoothed_scores_DSQN_{}'.format(i_run), smoothed_scores)\n", 253 | "\n", 254 | " # save smoothed scores in list to plot later\n", 255 | " smoothed_scores_dsqn_all.append(smoothed_scores)\n", 256 | " dsqn_completion_after.append(best_average_after)\n", 257 | " print(\"\")" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "best_smoothed_scores_dsqn = [smoothed_scores_dsqn_all[0],\n", 267 | " smoothed_scores_dsqn_all[1],\n", 268 | " smoothed_scores_dsqn_all[2],\n", 269 | " smoothed_scores_dsqn_all[3],\n", 270 | " smoothed_scores_dsqn_all[4]]\n", 271 | "mean_smoothed_scores_dsqn = np.mean(best_smoothed_scores_dsqn, axis=0)\n", 272 | "\n", 273 | "avg_dsqn_completion_after = np.mean([dsqn_completion_after[0],\n", 274 | " dsqn_completion_after[1],\n", 275 | " dsqn_completion_after[2],\n", 276 | " dsqn_completion_after[3],\n", 277 | " dsqn_completion_after[4]])\n", 278 | "\n", 279 | "fig = plt.figure()\n", 280 | "plt.plot(range(len(best_smoothed_scores_dsqn[0])), mean_smoothed_scores_dsqn)\n", 281 | "plt.fill_between(range(len(best_smoothed_scores_dsqn[0])),\n", 282 | " np.nanpercentile(best_smoothed_scores_dsqn, 2, axis=0),\n", 283 | " np.nanpercentile(best_smoothed_scores_dsqn, 97, axis=0), alpha=0.25)\n", 284 | "\n", 285 | "plt.vlines(avg_dsqn_completion_after, 0, 250, 'C0')\n", 286 | "\n", 287 | "\n", 288 | "plt.ylim(0, 250)\n", 289 | "plt.grid(True)\n", 290 | "plt.savefig(result_dir + '/DSQN_training.png', dpi=1000)\n", 291 | "plt.title('CartPole-v0 DSQN')\n", 292 | "plt.show()" 293 | ] 294 | } 295 | ], 296 | "metadata": { 297 | "kernelspec": { 298 | "display_name": "Python 3 (ipykernel)", 299 | "language": "python", 300 | "name": "python3" 301 | }, 302 | "language_info": { 303 | "codemirror_mode": { 304 | "name": "ipython", 305 | "version": 3 306 | }, 307 | "file_extension": ".py", 308 | "mimetype": "text/x-python", 309 | "name": "python", 310 | "nbconvert_exporter": "python", 311 | "pygments_lexer": "ipython3", 312 | "version": "3.10.12" 313 | } 314 | }, 315 | "nbformat": 4, 316 | "nbformat_minor": 2 317 | } 318 | -------------------------------------------------------------------------------- /HalfCheetah-v3/training_td3-SNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Twin Delayed Deep Deterministic Policy Gradients (TD3)" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os\n", 17 | "import site\n", 18 | "import torch\n", 19 | "import random\n", 20 | "\n", 21 | "import numpy as np\n", 22 | "import 
gymnasium as gym\n", 23 | "import torch.optim as optim\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "%matplotlib inline\n", 26 | "\n", 27 | "site.addsitedir('../src/')\n", 28 | "\n", 29 | "from datetime import date\n", 30 | "from td3_agent import Agent\n", 31 | "from collections import deque\n", 32 | "from model import TD3CriticNetwork, TD3ActorDSNN" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "# Create Results Directory\n", 42 | "dirs = os.listdir('.')\n", 43 | "if not any('result' in d for d in dirs):\n", 44 | " result_id = 1\n", 45 | "else:\n", 46 | " results = [d for d in dirs if 'result' in d]\n", 47 | " result_id = len(results) + 1\n", 48 | "\n", 49 | "# Get today's date and add it to the results directory\n", 50 | "d = date.today()\n", 51 | "result_dir = 'td3_result_' + str(result_id) + '_{}'.format(\n", 52 | " str(d.year) + str(d.month) + str(d.day))\n", 53 | "os.mkdir(result_dir)\n", 54 | "print('Created Directory {} to store the results in'.format(result_dir))" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "n_runs = 10\n", 64 | "n_timesteps = 1e6\n", 65 | "batch_size = 128\n", 66 | "\n", 67 | "seeds = np.load('../seeds/training_seeds.npy')\n", 68 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "actor_learning_rate = 0.001\n", 78 | "critic_learning_rate = 0.001\n", 79 | "tau = 0.005\n", 80 | "layer1_size = 400\n", 81 | "layer2_size = 300\n", 82 | "noise = 0.1\n", 83 | "warmup = 1000\n", 84 | "update_actor_interval = 2\n", 85 | "update_target_interval = 2\n", 86 | "buffer_size = int(2e5)\n", 87 | "pop_size = 10\n", 88 | "pop_coding = False\n", 89 | "two_neuron = True\n", 90 | "mutually_exclusive = False" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "alpha = 0.5\n", 100 | "beta = 0.5\n", 101 | "weight_scale = 1\n", 102 | "threshold = 0.8\n", 103 | "sim_time = 5" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": { 110 | "scrolled": false 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "smoothed_scores_all = []\n", 115 | "#torch.autograd.set_detect_anomaly(True)\n", 116 | "\n", 117 | "for i in range(n_runs):\n", 118 | " print(\"Run # {}\".format(i))\n", 119 | "\n", 120 | " seed = int(seeds[i])\n", 121 | " \n", 122 | " env = gym.make('HalfCheetah-v3')\n", 123 | " \n", 124 | " if two_neuron:\n", 125 | " input_dims = (env.observation_space.shape[0]*2,)\n", 126 | " elif pop_coding:\n", 127 | " input_dims = (env.observation_space.shape[0]*pop_size,)\n", 128 | " else:\n", 129 | " input_dims = env.observation_space.shape\n", 130 | " n_actions = env.action_space.shape[0]\n", 131 | "\n", 132 | " actor_architecture = [input_dims[0], layer1_size, layer2_size, n_actions]\n", 133 | " \n", 134 | " torch.manual_seed(seed)\n", 135 | " np.random.seed(seed)\n", 136 | " random.seed(seed)\n", 137 | "\n", 138 | " actor = TD3ActorDSNN(actor_architecture, seed, alpha, beta, weight_scale, 1,\n", 139 | " threshold, sim_time, actor_learning_rate, name='actor_{}'.format(i), device=device)\n", 140 | " target_actor = TD3ActorDSNN(actor_architecture, seed, alpha, beta, weight_scale, 1,\n", 141 | " 
threshold, sim_time, actor_learning_rate, name='target_actor_{}'.format(i), device=device)\n", 142 | "\n", 143 | " critic_1 = TD3CriticNetwork(critic_learning_rate, input_dims, layer1_size,\n", 144 | " layer2_size, n_actions=n_actions, name='critic_1_{}'.format(i))\n", 145 | " critic_2 = TD3CriticNetwork(critic_learning_rate, input_dims, layer1_size,\n", 146 | " layer2_size, n_actions=n_actions, name='critic_2_{}'.format(i))\n", 147 | " target_critic_1 = TD3CriticNetwork(critic_learning_rate, input_dims, layer1_size,\n", 148 | " layer2_size, n_actions=n_actions, name='target_critic_1_{}'.format(i))\n", 149 | " target_critic_2 = TD3CriticNetwork(critic_learning_rate, input_dims, layer1_size,\n", 150 | " layer2_size, n_actions=n_actions, name='target_critic_2_{}'.format(i))\n", 151 | "\n", 152 | " agent = Agent(actor, critic_1, critic_2, target_actor, target_critic_1, target_critic_2,\n", 153 | " input_dims, tau, env, n_timesteps, result_dir, n_actions=n_actions, seed=seed,\n", 154 | " noise=noise, update_actor_interval=update_actor_interval, warmup=warmup,\n", 155 | " update_target_interval=update_target_interval, two_neuron=two_neuron,\n", 156 | " buffer_size=buffer_size, spiking=True, normalize=True)\n", 157 | " \n", 158 | " smoothed_scores, reward_history, best_average, best_average_after = agent.train_agent()\n", 159 | " print(agent.max_obs)\n", 160 | " smoothed_scores_all.append(smoothed_scores)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "final_smoothed_scores = [smoothed_scores_all[i] for i in range(n_runs)]\n", 170 | "mean_smoothed_scores_dqn = np.mean(final_smoothed_scores, axis=0)\n", 171 | "std_smoothed_scores = np.std(final_smoothed_scores, axis=0)\n", 172 | "\n", 173 | "fig = plt.figure()\n", 174 | "plt.plot(range(len(final_smoothed_scores[0])), mean_smoothed_scores_dqn)\n", 175 | "plt.fill_between(range(len(final_smoothed_scores[0])),\n", 176 | " np.nanpercentile(final_smoothed_scores, 2, axis=0),\n", 177 | " np.nanpercentile(final_smoothed_scores, 97, axis=0), alpha=0.25)\n", 178 | "plt.grid(True)\n", 179 | "plt.savefig(result_dir + '/td3_training_snn.png', dpi=300)\n", 180 | "plt.show()" 181 | ] 182 | } 183 | ], 184 | "metadata": { 185 | "kernelspec": { 186 | "display_name": "Python 3 (ipykernel)", 187 | "language": "python", 188 | "name": "python3" 189 | }, 190 | "language_info": { 191 | "codemirror_mode": { 192 | "name": "ipython", 193 | "version": 3 194 | }, 195 | "file_extension": ".py", 196 | "mimetype": "text/x-python", 197 | "name": "python", 198 | "nbconvert_exporter": "python", 199 | "pygments_lexer": "ipython3", 200 | "version": "3.10.12" 201 | } 202 | }, 203 | "nbformat": 4, 204 | "nbformat_minor": 2 205 | } 206 | -------------------------------------------------------------------------------- /Hopper-v3/training_td3-SNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Twin Delayed Deep Deterministic Policy Gradients (TD3)" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os\n", 17 | "import site\n", 18 | "import torch\n", 19 | "import random\n", 20 | "\n", 21 | "import numpy as np\n", 22 | "import gymnasium as gym\n", 23 | "import torch.optim as optim\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "%matplotlib inline\n", 
26 | "\n", 27 | "site.addsitedir('../src/')\n", 28 | "\n", 29 | "from datetime import date\n", 30 | "from td3_agent import Agent\n", 31 | "from collections import deque\n", 32 | "from model import TD3CriticNetwork, TD3ActorDSNN" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "# Create Results Directory\n", 42 | "dirs = os.listdir('.')\n", 43 | "if not any('result' in d for d in dirs):\n", 44 | " result_id = 1\n", 45 | "else:\n", 46 | " results = [d for d in dirs if 'result' in d]\n", 47 | " result_id = len(results) + 1\n", 48 | "\n", 49 | "# Get today's date and add it to the results directory\n", 50 | "d = date.today()\n", 51 | "result_dir = 'td3_result_' + str(result_id) + '_{}'.format(\n", 52 | " str(d.year) + str(d.month) + str(d.day))\n", 53 | "os.mkdir(result_dir)\n", 54 | "print('Created Directory {} to store the results in'.format(result_dir))" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "n_runs = 10\n", 64 | "n_timesteps = 1e6\n", 65 | "batch_size = 100\n", 66 | "\n", 67 | "seeds = np.load('../seeds/training_seeds.npy')\n", 68 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "actor_learning_rate = 0.001\n", 78 | "critic_learning_rate = 0.001\n", 79 | "tau = 0.005\n", 80 | "layer1_size = 400\n", 81 | "layer2_size = 300\n", 82 | "noise = 0.1\n", 83 | "warmup = 1000\n", 84 | "update_actor_interval = 2\n", 85 | "update_target_interval = 1\n", 86 | "buffer_size = int(1e6)\n", 87 | "pop_size = 10\n", 88 | "pop_coding = False\n", 89 | "two_neuron = True\n", 90 | "mutually_exclusive = False" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "alpha = 0.5\n", 100 | "beta = 0.5\n", 101 | "weight_scale = 1.0\n", 102 | "threshold = 2.0\n", 103 | "sim_time = 5" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": { 110 | "scrolled": true 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "smoothed_scores_all = []\n", 115 | "#torch.autograd.set_detect_anomaly(True)\n", 116 | "\n", 117 | "for i in range(n_runs):\n", 118 | " print(\"Run # {}\".format(i))\n", 119 | "\n", 120 | " seed = int(seeds[i])\n", 121 | " \n", 122 | " env = gym.make('Hopper-v3')\n", 123 | " \n", 124 | " if two_neuron:\n", 125 | " input_dims = (env.observation_space.shape[0]*2,)\n", 126 | " elif pop_coding:\n", 127 | " input_dims = (env.observation_space.shape[0]*pop_size,)\n", 128 | " else:\n", 129 | " input_dims = env.observation_space.shape\n", 130 | " n_actions = env.action_space.shape[0]\n", 131 | "\n", 132 | " actor_architecture = [input_dims[0], layer1_size, layer2_size, n_actions]\n", 133 | " \n", 134 | " torch.manual_seed(seed)\n", 135 | " np.random.seed(seed)\n", 136 | " random.seed(seed)\n", 137 | "\n", 138 | " actor = TD3ActorDSNN(actor_architecture, seed, alpha, beta, weight_scale, 1,\n", 139 | " threshold, sim_time, actor_learning_rate, name='actor_{}'.format(i), device=device)\n", 140 | " target_actor = TD3ActorDSNN(actor_architecture, seed, alpha, beta, weight_scale, 1,\n", 141 | " threshold, sim_time, actor_learning_rate, name='target_actor_{}'.format(i),\n", 142 | " device=device)\n", 143 | "\n", 144 | " critic_1 
= TD3CriticNetwork(critic_learning_rate, input_dims, layer1_size,\n", 145 | " layer2_size, n_actions=n_actions, name='critic_1_{}'.format(i))\n", 146 | " critic_2 = TD3CriticNetwork(critic_learning_rate, input_dims, layer1_size,\n", 147 | " layer2_size, n_actions=n_actions, name='critic_2_{}'.format(i))\n", 148 | " target_critic_1 = TD3CriticNetwork(critic_learning_rate, input_dims, layer1_size,\n", 149 | " layer2_size, n_actions=n_actions, name='target_critic_1_{}'.format(i))\n", 150 | " target_critic_2 = TD3CriticNetwork(critic_learning_rate, input_dims, layer1_size,\n", 151 | " layer2_size, n_actions=n_actions, name='target_critic_2_{}'.format(i))\n", 152 | "\n", 153 | " agent = Agent(actor, critic_1, critic_2, target_actor, target_critic_1, target_critic_2,\n", 154 | " input_dims, tau, env, n_timesteps, result_dir, n_actions=n_actions, seed=seed,\n", 155 | " noise=noise, update_actor_interval=update_actor_interval, warmup=warmup,\n", 156 | " update_target_interval=update_target_interval, two_neuron=two_neuron,\n", 157 | " buffer_size=buffer_size, spiking=True, normalize=True)\n", 158 | " \n", 159 | " smoothed_scores, reward_history, best_average, best_average_after = agent.train_agent()\n", 160 | " smoothed_scores_all.append(smoothed_scores)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "final_smoothed_scores = [smoothed_scores_all[i] for i in range(n_runs)]\n", 170 | "mean_smoothed_scores = np.mean(final_smoothed_scores, axis=0)\n", 171 | "std_smoothed_scores = np.std(final_smoothed_scores, axis=0)\n", 172 | "\n", 173 | "fig = plt.figure()\n", 174 | "plt.plot(range(len(final_smoothed_scores[0])), mean_smoothed_scores)\n", 175 | "plt.fill_between(range(len(final_smoothed_scores[0])),\n", 176 | " np.nanpercentile(final_smoothed_scores, 2, axis=0),\n", 177 | " np.nanpercentile(final_smoothed_scores, 97, axis=0), alpha=0.25)\n", 178 | "\n", 179 | "plt.grid(True)\n", 180 | "plt.savefig(result_dir + '/td3_training_pop_10.png', dpi=300)\n", 181 | "plt.show()" 182 | ] 183 | } 184 | ], 185 | "metadata": { 186 | "kernelspec": { 187 | "display_name": "Python 3 (ipykernel)", 188 | "language": "python", 189 | "name": "python3" 190 | }, 191 | "language_info": { 192 | "codemirror_mode": { 193 | "name": "ipython", 194 | "version": 3 195 | }, 196 | "file_extension": ".py", 197 | "mimetype": "text/x-python", 198 | "name": "python", 199 | "nbconvert_exporter": "python", 200 | "pygments_lexer": "ipython3", 201 | "version": "3.10.12" 202 | } 203 | }, 204 | "nbformat": 4, 205 | "nbformat_minor": 2 206 | } 207 | -------------------------------------------------------------------------------- /Pendulum-v0/training_td3-SNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Twin Delayed Deep Deterministic Policy Gradients (TD3)" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os\n", 17 | "import site\n", 18 | "import torch\n", 19 | "import random\n", 20 | "import itertools\n", 21 | "\n", 22 | "import numpy as np\n", 23 | "import gymnasium as gym\n", 24 | "import torch.optim as optim\n", 25 | "import matplotlib.pyplot as plt\n", 26 | "%matplotlib inline\n", 27 | "\n", 28 | "site.addsitedir('../src/')\n", 29 | "\n", 30 | "from datetime import date\n", 31 | "from td3_agent import 
Agent\n", 32 | "from collections import deque\n", 33 | "from model import TD3CriticNetwork, TD3ActorDSNN" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# Create Results Directory\n", 43 | "dirs = os.listdir('.')\n", 44 | "if not any('result' in d for d in dirs):\n", 45 | " result_id = 1\n", 46 | "else:\n", 47 | " results = [d for d in dirs if 'result' in d]\n", 48 | " result_id = len(results) + 1\n", 49 | "\n", 50 | "# Get today's date and add it to the results directory\n", 51 | "d = date.today()\n", 52 | "result_dir = 'td3_result_' + str(result_id) + '_{}'.format(\n", 53 | " str(d.year) + str(d.month) + str(d.day))\n", 54 | "os.mkdir(result_dir)\n", 55 | "print('Created Directory {} to store the results in'.format(result_dir))" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "n_runs = 10\n", 65 | "n_episodes = 1e6\n", 66 | "batch_size = 128\n", 67 | "\n", 68 | "seeds = np.load('../seeds/training_seeds.npy')\n", 69 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "actor_learning_rate = 0.001\n", 79 | "critic_learning_rate = 0.001\n", 80 | "tau = 0.005\n", 81 | "layer1_size = 256\n", 82 | "layer2_size = 256\n", 83 | "noise = 0.1\n", 84 | "warmup = 1000\n", 85 | "update_actor_interval = 2\n", 86 | "update_target_interval = 2\n", 87 | "buffer_size = int(2e5)\n", 88 | "pop_size = 10\n", 89 | "two_neuron = True\n", 90 | "pop_coding = False\n", 91 | "mutually_exclusive = False\n", 92 | "obs_range = [(-1,1), (-1,1), (-8,8)] " 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "actor_alpha = 0.5\n", 102 | "actor_beta = 0.5\n", 103 | "weight_scale = 1\n", 104 | "actor_threshold = 1\n", 105 | "actor_sim_time = 5" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": { 112 | "scrolled": true 113 | }, 114 | "outputs": [], 115 | "source": [ 116 | "smoothed_scores_all = []\n", 117 | "\n", 118 | "for i in range(n_runs):\n", 119 | " print(\"Run # {}\".format(i))\n", 120 | "\n", 121 | " seed = int(seeds[i])\n", 122 | " \n", 123 | " env = gym.make('Pendulum-v1')\n", 124 | " \n", 125 | " if two_neuron:\n", 126 | " input_dims = (env.observation_space.shape[0]*2,)\n", 127 | " elif pop_coding:\n", 128 | " input_dims = (env.observation_space.shape[0]*pop_size,)\n", 129 | " else:\n", 130 | " input_dims = env.observation_space.shape\n", 131 | " n_actions = env.action_space.shape[0]\n", 132 | "\n", 133 | " actor_architecture = [input_dims[0], layer1_size, layer2_size, n_actions]\n", 134 | " critic_architecture = [input_dims[0] + n_actions, layer1_size, layer2_size, n_actions]\n", 135 | "\n", 136 | " torch.manual_seed(seed)\n", 137 | " np.random.seed(seed)\n", 138 | " random.seed(seed)\n", 139 | "\n", 140 | " actor = TD3ActorDSNN(actor_architecture, seed, actor_alpha, actor_beta, weight_scale, 1,\n", 141 | " actor_threshold, actor_sim_time, actor_learning_rate, name='actor_{}'.format(i),\n", 142 | " device=device)\n", 143 | " target_actor = TD3ActorDSNN(actor_architecture, seed, actor_alpha, actor_beta, weight_scale, 1,\n", 144 | " actor_threshold, actor_sim_time, actor_learning_rate,\n", 145 | " 
name='target_actor_{}'.format(i), device=device)\n", 146 | " critic_1 = TD3CriticNetwork(critic_learning_rate, input_dims, layer1_size,\n", 147 | " layer2_size, n_actions=n_actions, name='critic_1_{}'.format(i))\n", 148 | " critic_2 = TD3CriticNetwork(critic_learning_rate, input_dims, layer1_size,\n", 149 | " layer2_size, n_actions=n_actions, name='critic_2_{}'.format(i))\n", 150 | " target_critic_1 = TD3CriticNetwork(critic_learning_rate, input_dims, layer1_size,\n", 151 | " layer2_size, n_actions=n_actions, name='target_critic_1_{}'.format(i))\n", 152 | " target_critic_2 = TD3CriticNetwork(critic_learning_rate, input_dims, layer1_size,\n", 153 | " layer2_size, n_actions=n_actions, name='target_critic_2_{}'.format(i))\n", 154 | "\n", 155 | " agent = Agent(actor, critic_1, critic_2, target_actor, target_critic_1, target_critic_2,\n", 156 | " input_dims, tau, env, n_episodes, result_dir, n_actions=n_actions, seed=seed,\n", 157 | " noise=noise, update_actor_interval=update_actor_interval, warmup=warmup,\n", 158 | " update_target_interval=update_target_interval, two_neuron=two_neuron,\n", 159 | " buffer_size=buffer_size, spiking=True, spiking_critic=False, normalize=True)\n", 160 | "\n", 161 | " smoothed_scores, reward_history, best_average, best_average_after = agent.train_agent()\n", 162 | " smoothed_scores_all.append(smoothed_scores)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "final_smoothed_scores = [smoothed_scores_all[i] for i in range(n_runs)]\n", 172 | "mean_smoothed_scores = np.mean(final_smoothed_scores, axis=0)\n", 173 | "std_smoothed_scores = np.std(final_smoothed_scores, axis=0)\n", 174 | "\n", 175 | "fig = plt.figure()\n", 176 | "plt.plot(range(len(final_smoothed_scores[0])), mean_smoothed_scores)\n", 177 | "plt.fill_between(range(len(final_smoothed_scores[0])),\n", 178 | " np.nanpercentile(final_smoothed_scores, 5, axis=0),\n", 179 | " np.nanpercentile(final_smoothed_scores, 95, axis=0), alpha=0.25)\n", 180 | "\n", 181 | "plt.ylim(-1600, 0)\n", 182 | "plt.grid(True)\n", 183 | "plt.legend()\n", 184 | "plt.savefig(result_dir + '/td3_training_snn_pendulum.png', dpi=300, bbox_inches='tight')\n", 185 | "plt.show()" 186 | ] 187 | } 188 | ], 189 | "metadata": { 190 | "kernelspec": { 191 | "display_name": "Python 3 (ipykernel)", 192 | "language": "python", 193 | "name": "python3" 194 | }, 195 | "language_info": { 196 | "codemirror_mode": { 197 | "name": "ipython", 198 | "version": 3 199 | }, 200 | "file_extension": ".py", 201 | "mimetype": "text/x-python", 202 | "name": "python", 203 | "nbconvert_exporter": "python", 204 | "pygments_lexer": "ipython3", 205 | "version": "3.10.12" 206 | } 207 | }, 208 | "nbformat": 4, 209 | "nbformat_minor": 2 210 | } 211 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Spiking Reinforcement Learning 2 | Implementations of Deep Reinforcement Learning (DRL) algorithms with 3 | Spiking Neural Netowrks (SNNs) in PyTorch. SNNs are based on the 4 | SpyTorch implementations, with custom encoding and decoding mechanisms. 5 | 6 | ## Dependency installation 7 | Tested on Ubuntu 20.04 and Python 3.8.12. Creating a virtual environment is recommended. 
8 | 9 | ``` 10 | pip install -r requirements.txt 11 | ``` 12 | For MuJoCo based environments (Ant-v3, HalfCheetah-v3, and Hopper-v3), install MuJoCo as described [here](https://github.com/openai/mujoco-py#install-mujoco). 13 | 14 | ## Citation 15 | 16 | If you use our code, please consider citing our research: 17 | 18 | ```bibtex 19 | @ARTICLE{10.3389/fnbot.2022.1075647, 20 | AUTHOR={Akl, Mahmoud and Ergene, Deniz and Walter, Florian and Knoll, Alois}, 21 | TITLE={Toward robust and scalable deep spiking reinforcement learning}, 22 | JOURNAL={Frontiers in Neurorobotics}, 23 | VOLUME={16}, 24 | YEAR={2023}, 25 | URL={https://www.frontiersin.org/articles/10.3389/fnbot.2022.1075647}, 26 | DOI={10.3389/fnbot.2022.1075647}, 27 | ISSN={1662-5218}, 28 | } 29 | ``` 30 | 31 | ```bibtex 32 | @inproceedings{10.1145/3546790.3546804, 33 | author = {Akl, Mahmoud and Sandamirskaya, Yulia and Ergene, Deniz and Walter, Florian and Knoll, Alois}, 34 | title = {Fine-Tuning Deep Reinforcement Learning Policies with r-STDP for Domain Adaptation}, 35 | year = {2022}, 36 | isbn = {9781450397896}, 37 | publisher = {Association for Computing Machinery}, 38 | address = {New York, NY, USA}, 39 | url = {https://doi.org/10.1145/3546790.3546804}, 40 | doi = {10.1145/3546790.3546804}, 41 | booktitle = {Proceedings of the International Conference on Neuromorphic Systems 2022}, 42 | articleno = {14}, 43 | numpages = {8}, 44 | keywords = {neural networks, spiking neural networks, reinforcement learning}, 45 | location = {Knoxville, TN, USA}, 46 | series = {ICONS '22} 47 | } 48 | ``` 49 | 50 | ```bibtex 51 | @inproceedings{10.1145/3477145.3477159, 52 | author = {Akl, Mahmoud and Sandamirskaya, Yulia and Walter, Florian and Knoll, Alois}, 53 | title = {Porting Deep Spiking Q-Networks to Neuromorphic Chip Loihi}, 54 | year = {2021}, 55 | isbn = {9781450386913}, 56 | publisher = {Association for Computing Machinery}, 57 | address = {New York, NY, USA}, 58 | url = {https://doi.org/10.1145/3477145.3477159}, 59 | doi = {10.1145/3477145.3477159}, 60 | booktitle = {International Conference on Neuromorphic Systems 2021}, 61 | articleno = {13}, 62 | numpages = {7}, 63 | keywords = {neuromorphic hardware, reinforcement learning, Spiking neural networks}, 64 | location = {Knoxville, TN, USA}, 65 | series = {ICONS 2021} 66 | } 67 | ``` 68 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | notebook==6.5.2 2 | gym==0.15.7 3 | numpy==1.23.5 4 | torch==1.13.0 5 | matplotlib==3.5.1 6 | mujoco-py<2.2,>=2.1 7 | patchelf==0.17.0.0 8 | -------------------------------------------------------------------------------- /rstdp_domain_adaptation/pretrained_models/acrobot/checkpoint_DSQN_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/rstdp_domain_adaptation/pretrained_models/acrobot/checkpoint_DSQN_0.pt -------------------------------------------------------------------------------- /rstdp_domain_adaptation/pretrained_models/acrobot/checkpoint_DSQN_1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/rstdp_domain_adaptation/pretrained_models/acrobot/checkpoint_DSQN_1.pt -------------------------------------------------------------------------------- 
/rstdp_domain_adaptation/pretrained_models/acrobot/checkpoint_DSQN_2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/rstdp_domain_adaptation/pretrained_models/acrobot/checkpoint_DSQN_2.pt -------------------------------------------------------------------------------- /rstdp_domain_adaptation/pretrained_models/cartpole/checkpoint_DSQN_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/rstdp_domain_adaptation/pretrained_models/cartpole/checkpoint_DSQN_0.pt -------------------------------------------------------------------------------- /rstdp_domain_adaptation/pretrained_models/cartpole/checkpoint_DSQN_1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/rstdp_domain_adaptation/pretrained_models/cartpole/checkpoint_DSQN_1.pt -------------------------------------------------------------------------------- /rstdp_domain_adaptation/pretrained_models/cartpole/checkpoint_DSQN_2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/rstdp_domain_adaptation/pretrained_models/cartpole/checkpoint_DSQN_2.pt -------------------------------------------------------------------------------- /rstdp_domain_adaptation/rstdp_acrobot_adaptation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "fef2e3ee", 6 | "metadata": {}, 7 | "source": [ 8 | "# RSTDP domain adaptation of pre-trained agents for modified Acrobot environments" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "084483bc", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import os\n", 19 | "from datetime import date\n", 20 | "\n", 21 | "import numpy as np\n", 22 | "import torch\n", 23 | "import torch.optim as optim\n", 24 | "import random\n", 25 | "import gym\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "from copy import deepcopy\n", 28 | "\n", 29 | "from utils import evaluate_policy, rstdp_train_acrobot\n", 30 | "\n", 31 | "import site\n", 32 | "site.addsitedir('../src/')\n", 33 | "\n", 34 | "from dsnn import RSTDPNet\n", 35 | "\n", 36 | "%matplotlib inline" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "id": "ab71a941", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 47 | "dtype = torch.float" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "id": "7deedd78", 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "# Environment specific parameters\n", 58 | "env_name = 'Acrobot-v1'\n", 59 | "min_reward = -500\n", 60 | "max_steps = 500\n", 61 | "\n", 62 | "n_evaluations = 100\n", 63 | "rstdp_episodes = 250" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "id": "ac8cbe74", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "# Create environments\n", 74 | "original_env = gym.make(env_name)\n", 75 | "\n", 76 | "modified_env = gym.make(env_name)\n", 77 | 
"modified_env.unwrapped.LINK_LENGTH_1 *= 1.5\n", 78 | "modified_env.unwrapped.LINK_LENGTH_2 *= 1.5" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "id": "94627454", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "# SNN Hyperparameters\n", 89 | "simulation_time = 3\n", 90 | "alpha = 0.8\n", 91 | "beta = 0.8\n", 92 | "threshold = 1.0\n", 93 | "weight_scale = 1\n", 94 | "architecture = [12, 256, 256, 3]" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "id": "955f4b15", 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "# RSTDP Hyperparameters\n", 105 | "tau = 30\n", 106 | "tau_e = 10\n", 107 | "C = 0.05\n", 108 | "# A+/- are calculated from the pre-trained network weights" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "id": "541f666b", 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "evaluation_seeds = np.load('../seeds/evaluation_seeds.npy')\n", 119 | "rstdp_seeds = np.load('../seeds/rstdp_training_seeds.npy')" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "id": "10da94f2", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "# Load pre-trained model weights\n", 130 | "weights_0 = torch.load('pretrained_models/acrobot/checkpoint_DSQN_0.pt', map_location=torch.device(device))\n", 131 | "weights_1 = torch.load('pretrained_models/acrobot/checkpoint_DSQN_1.pt', map_location=torch.device(device))\n", 132 | "weights_2 = torch.load('pretrained_models/acrobot/checkpoint_DSQN_2.pt', map_location=torch.device(device))\n", 133 | "weights = [weights_0, weights_1, weights_2]" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "id": "cafb7b6a", 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "# Helper for printing\n", 144 | "eraser = '\\b \\b'" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "id": "185ee721", 150 | "metadata": {}, 151 | "source": [ 152 | "### Evaluate pre-trained models on original environment" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "id": "28f2b5b7", 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "original_eval_rewards = []\n", 163 | "\n", 164 | "for i, w in enumerate(weights):\n", 165 | " print('Run {:02d} ...'.format(i), end='')\n", 166 | " policy_net = RSTDPNet(alpha, beta, threshold, architecture, simulation_time, w, \n", 167 | " device=device, dtype=dtype)\n", 168 | " rewards = evaluate_policy(policy_net, original_env, n_evaluations, evaluation_seeds)\n", 169 | " original_eval_rewards.append(rewards)\n", 170 | " print(eraser*3 + '-> Avg reward: {:7.2f}'.format(np.mean(rewards)))" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "id": "c0c40289", 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "plt.figure(figsize=(16, 4))\n", 181 | "\n", 182 | "for i, oer in enumerate(original_eval_rewards):\n", 183 | " plt.plot(oer, label='Run {:02d}'.format(i))\n", 184 | "\n", 185 | "plt.legend()\n", 186 | "plt.grid()\n", 187 | "plt.show()" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "id": "c811342f", 193 | "metadata": {}, 194 | "source": [ 195 | "### Evaluate pre-trained models on modified environment" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "id": "c835c189", 202 | "metadata": {}, 203 | "outputs": [], 204 | 
"source": [ 205 | "modified_env_eval_rewards = []\n", 206 | "\n", 207 | "for i, w in enumerate(weights):\n", 208 | " print('Run {:02d} ...'.format(i), end='')\n", 209 | " policy_net = RSTDPNet(alpha, beta, threshold, architecture, simulation_time, w,\n", 210 | " device=device, dtype=dtype)\n", 211 | " rewards = evaluate_policy(policy_net, modified_env, n_evaluations, evaluation_seeds)\n", 212 | " modified_env_eval_rewards.append(rewards)\n", 213 | " print(eraser*3 + '-> Avg reward: {:7.2f}'.format(np.mean(rewards)))" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "id": "90edc465", 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "plt.figure(figsize=(16, 4))\n", 224 | "\n", 225 | "for i, meer in enumerate(modified_env_eval_rewards):\n", 226 | " plt.plot(meer, label='Run {:02d}'.format(i))\n", 227 | "\n", 228 | "plt.legend()\n", 229 | "plt.grid()\n", 230 | "plt.show()" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "id": "fbd4678c", 236 | "metadata": {}, 237 | "source": [ 238 | "### RSTDP Adaptation" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "id": "5f9bfabb", 245 | "metadata": {}, 246 | "outputs": [], 247 | "source": [ 248 | "rstdp_adaptation_rewards = []\n", 249 | "adapted_weights_collection = []\n", 250 | "\n", 251 | "for i, w in enumerate(weights):\n", 252 | " w_plus = deepcopy(w[0][1])\n", 253 | " w_minus = deepcopy(w[0][1])\n", 254 | " w_plus[w_plus < 0] = 0\n", 255 | " w_minus[w_minus > 0] = 0\n", 256 | " A_plus = torch.mean(w_plus)\n", 257 | " A_minus = torch.abs(torch.mean(w_minus))\n", 258 | " \n", 259 | " policy_net = RSTDPNet(alpha, beta, threshold, architecture, simulation_time, w, \n", 260 | " tau, tau_e, A_plus, A_minus, C, \n", 261 | " device=device, dtype=dtype)\n", 262 | " \n", 263 | " adapted_weights, rewards = rstdp_train_acrobot(policy_net, modified_env, min_reward, rstdp_episodes, \n", 264 | " n_evaluations, max_steps, rstdp_seeds, evaluation_seeds)\n", 265 | " \n", 266 | " rstdp_adaptation_rewards.append(rewards)\n", 267 | " adapted_weights_collection.append(adapted_weights)\n", 268 | " \n", 269 | "adapted_weights_collection = [(list(aw.values()), []) for aw in adapted_weights_collection]" 270 | ] 271 | }, 272 | { 273 | "cell_type": "markdown", 274 | "id": "585101eb", 275 | "metadata": {}, 276 | "source": [ 277 | "### Evaluate adapted models on modified environment" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "id": "6640b7d6", 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "adapted_eval_rewards = []\n", 288 | "\n", 289 | "for i, w in enumerate(adapted_weights_collection):\n", 290 | " print('Run {:02d} ...'.format(i), end='')\n", 291 | " policy_net = RSTDPNet(alpha, beta, threshold, architecture, simulation_time, w,\n", 292 | " device=device, dtype=dtype)\n", 293 | " rewards = evaluate_policy(policy_net, modified_env, n_evaluations, evaluation_seeds)\n", 294 | " adapted_eval_rewards.append(rewards)\n", 295 | " print(eraser*3 + '-> Avg reward: {:7.2f}'.format(np.mean(rewards)))" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": null, 301 | "id": "2231eb92", 302 | "metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "plt.figure(figsize=(16, 4))\n", 306 | "\n", 307 | "for i, aer in enumerate(adapted_eval_rewards):\n", 308 | " plt.plot(aer, label='Run {:02d}'.format(i))\n", 309 | "\n", 310 | "plt.legend()\n", 311 | "plt.grid()\n", 312 | "plt.show()" 313 | ] 
314 | } 315 | ], 316 | "metadata": { 317 | "kernelspec": { 318 | "display_name": "Python 3 (ipykernel)", 319 | "language": "python", 320 | "name": "python3" 321 | }, 322 | "language_info": { 323 | "codemirror_mode": { 324 | "name": "ipython", 325 | "version": 3 326 | }, 327 | "file_extension": ".py", 328 | "mimetype": "text/x-python", 329 | "name": "python", 330 | "nbconvert_exporter": "python", 331 | "pygments_lexer": "ipython3", 332 | "version": "3.8.12" 333 | } 334 | }, 335 | "nbformat": 4, 336 | "nbformat_minor": 5 337 | } 338 | -------------------------------------------------------------------------------- /rstdp_domain_adaptation/rstdp_cartpole_adaptation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "fef2e3ee", 6 | "metadata": {}, 7 | "source": [ 8 | "# RSTDP domain adaptation of pre-trained agents for modified CartPole environments" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "084483bc", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import os\n", 19 | "from datetime import date\n", 20 | "\n", 21 | "import numpy as np\n", 22 | "import torch\n", 23 | "import torch.optim as optim\n", 24 | "import random\n", 25 | "import gym\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "from copy import deepcopy\n", 28 | "\n", 29 | "from utils import evaluate_policy, rstdp_train_cartpole\n", 30 | "\n", 31 | "import site\n", 32 | "site.addsitedir('../src/')\n", 33 | "\n", 34 | "from dsnn import RSTDPNet\n", 35 | "\n", 36 | "%matplotlib inline" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "id": "ab71a941", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 47 | "dtype = torch.float" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "id": "7deedd78", 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "# Environment specific parameters\n", 58 | "env_name = 'CartPole-v0'\n", 59 | "max_reward = 200\n", 60 | "max_steps = 200\n", 61 | "\n", 62 | "n_evaluations = 100\n", 63 | "rstdp_episodes = 250" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "id": "ac8cbe74", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "# Create environments\n", 74 | "original_env = gym.make(env_name)\n", 75 | "\n", 76 | "modified_env = gym.make(env_name)\n", 77 | "modified_env.unwrapped.length *= 1.5" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "id": "94627454", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "# SNN Hyperparameters\n", 88 | "simulation_time = 8\n", 89 | "alpha = 0.8\n", 90 | "beta = 0.8\n", 91 | "threshold = 0.5\n", 92 | "weight_scale = 1\n", 93 | "architecture = [8, 64, 64, 2]" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "id": "955f4b15", 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "# RSTDP Hyperparameters\n", 104 | "tau = 5\n", 105 | "tau_e = 10\n", 106 | "C = 0.01\n", 107 | "# A+/- are calculated from the pre-trained network weights" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "id": "541f666b", 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "evaluation_seeds = np.load('../seeds/evaluation_seeds.npy')\n", 118 | "rstdp_seeds = 
np.load('../seeds/rstdp_training_seeds.npy')" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "id": "10da94f2", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "# Load pre-trained model weights\n", 129 | "weights_0 = torch.load('pretrained_models/cartpole/checkpoint_DSQN_0.pt', map_location=torch.device(device))\n", 130 | "weights_1 = torch.load('pretrained_models/cartpole/checkpoint_DSQN_1.pt', map_location=torch.device(device))\n", 131 | "weights_2 = torch.load('pretrained_models/cartpole/checkpoint_DSQN_2.pt', map_location=torch.device(device))\n", 132 | "weights = [weights_0, weights_1, weights_2]" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "id": "cafb7b6a", 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "# Helper for printing\n", 143 | "eraser = '\\b \\b'" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "id": "185ee721", 149 | "metadata": {}, 150 | "source": [ 151 | "### Evaluate pre-trained models on original environment" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "id": "28f2b5b7", 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "original_eval_rewards = []\n", 162 | "\n", 163 | "for i, w in enumerate(weights):\n", 164 | " print('Run {:02d} ...'.format(i), end='')\n", 165 | " policy_net = RSTDPNet(alpha, beta, threshold, architecture, simulation_time, w, \n", 166 | " device=device, dtype=dtype)\n", 167 | " rewards = evaluate_policy(policy_net, original_env, n_evaluations, evaluation_seeds)\n", 168 | " original_eval_rewards.append(rewards)\n", 169 | " print(eraser*3 + '-> Avg reward: {:7.2f}'.format(np.mean(rewards)))" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "id": "c0c40289", 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "plt.figure(figsize=(16, 4))\n", 180 | "\n", 181 | "for i, oer in enumerate(original_eval_rewards):\n", 182 | " plt.plot(oer, label='Run {:02d}'.format(i))\n", 183 | "\n", 184 | "plt.legend()\n", 185 | "plt.grid()\n", 186 | "plt.show()" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "id": "c811342f", 192 | "metadata": {}, 193 | "source": [ 194 | "### Evaluate pre-trained models on modified environment" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "id": "c835c189", 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "modified_env_eval_rewards = []\n", 205 | "\n", 206 | "for i, w in enumerate(weights):\n", 207 | " print('Run {:02d} ...'.format(i), end='')\n", 208 | " policy_net = RSTDPNet(alpha, beta, threshold, architecture, simulation_time, w,\n", 209 | " device=device, dtype=dtype)\n", 210 | " rewards = evaluate_policy(policy_net, modified_env, n_evaluations, evaluation_seeds)\n", 211 | " modified_env_eval_rewards.append(rewards)\n", 212 | " print(eraser*3 + '-> Avg reward: {:7.2f}'.format(np.mean(rewards)))" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "id": "90edc465", 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "plt.figure(figsize=(16, 4))\n", 223 | "\n", 224 | "for i, meer in enumerate(modified_env_eval_rewards):\n", 225 | " plt.plot(meer, label='Run {:02d}'.format(i))\n", 226 | "\n", 227 | "plt.legend()\n", 228 | "plt.grid()\n", 229 | "plt.show()" 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "id": "fbd4678c", 235 | "metadata": {}, 236 | 
"source": [ 237 | "### RSTDP Adaptation" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "id": "5f9bfabb", 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "rstdp_adaptation_rewards = []\n", 248 | "adapted_weights_collection = []\n", 249 | "\n", 250 | "for i, w in enumerate(weights):\n", 251 | " w_plus = deepcopy(w[0][1])\n", 252 | " w_minus = deepcopy(w[0][1])\n", 253 | " w_plus[w_plus < 0] = 0\n", 254 | " w_minus[w_minus > 0] = 0\n", 255 | " A_plus = torch.mean(w_plus)\n", 256 | " A_minus = torch.abs(torch.mean(w_minus))\n", 257 | " \n", 258 | " policy_net = RSTDPNet(alpha, beta, threshold, architecture, simulation_time, w, \n", 259 | " tau, tau_e, A_plus, A_minus, C, \n", 260 | " device=device, dtype=dtype)\n", 261 | " \n", 262 | " adapted_weights, rewards = rstdp_train_cartpole(policy_net, modified_env, max_reward, rstdp_episodes, \n", 263 | " n_evaluations, max_steps, rstdp_seeds, evaluation_seeds)\n", 264 | " \n", 265 | " rstdp_adaptation_rewards.append(rewards)\n", 266 | " adapted_weights_collection.append(adapted_weights)\n", 267 | " \n", 268 | "adapted_weights_collection = [(list(aw.values()), []) for aw in adapted_weights_collection]" 269 | ] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "id": "585101eb", 274 | "metadata": {}, 275 | "source": [ 276 | "### Evaluate adapted models on modified environment" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "id": "6640b7d6", 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [ 286 | "adapted_eval_rewards = []\n", 287 | "\n", 288 | "for i, w in enumerate(adapted_weights_collection):\n", 289 | " print('Run {:02d} ...'.format(i), end='')\n", 290 | " policy_net = RSTDPNet(alpha, beta, threshold, architecture, simulation_time, w,\n", 291 | " device=device, dtype=dtype)\n", 292 | " rewards = evaluate_policy(policy_net, modified_env, n_evaluations, evaluation_seeds)\n", 293 | " adapted_eval_rewards.append(rewards)\n", 294 | " print(eraser*3 + '-> Avg reward: {:7.2f}'.format(np.mean(rewards)))" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": null, 300 | "id": "2231eb92", 301 | "metadata": {}, 302 | "outputs": [], 303 | "source": [ 304 | "plt.figure(figsize=(16, 4))\n", 305 | "\n", 306 | "for i, aer in enumerate(adapted_eval_rewards):\n", 307 | " plt.plot(aer, label='Run {:02d}'.format(i))\n", 308 | "\n", 309 | "plt.legend()\n", 310 | "plt.grid()\n", 311 | "plt.show()" 312 | ] 313 | } 314 | ], 315 | "metadata": { 316 | "kernelspec": { 317 | "display_name": "Python 3 (ipykernel)", 318 | "language": "python", 319 | "name": "python3" 320 | }, 321 | "language_info": { 322 | "codemirror_mode": { 323 | "name": "ipython", 324 | "version": 3 325 | }, 326 | "file_extension": ".py", 327 | "mimetype": "text/x-python", 328 | "name": "python", 329 | "nbconvert_exporter": "python", 330 | "pygments_lexer": "ipython3", 331 | "version": "3.8.12" 332 | } 333 | }, 334 | "nbformat": 4, 335 | "nbformat_minor": 5 336 | } 337 | -------------------------------------------------------------------------------- /rstdp_domain_adaptation/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from copy import deepcopy 4 | 5 | import site 6 | site.addsitedir('../src/') 7 | 8 | default_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 9 | 10 | 11 | def evaluate_policy(policy_net, env, n_evaluations, seeds): 12 | """ Evaluate a policy 
on a cartpole environment. 13 | 14 | Arguments: 15 | - policy_net: policy to evaluate. Should have four inputs and 2 outputs. 16 | - env: Environment object to use for evaluation. 17 | - n_evaluations: Number of evaluation runs. 18 | - seeds: Seeds for the environment. Should have at least 'n_evaluations' entries. 19 | 20 | Returns: 21 | List of rewards. 22 | """ 23 | eval_rewards = [] 24 | for i in range(n_evaluations): 25 | env.seed(int(seeds[i])) 26 | 27 | state = env.reset() 28 | reward = 0 29 | done = False 30 | 31 | while not done: 32 | inputs = torch.from_numpy(state).float() 33 | mem_result, _ = policy_net(inputs, rstdp_state=(None, None, None)) 34 | action = torch.argmax(mem_result) 35 | state, r, done, _ = env.step(action.item()) 36 | reward += r 37 | eval_rewards.append(reward) 38 | 39 | return eval_rewards 40 | 41 | 42 | def rstdp_train_cartpole(policy_net, env, max_reward, num_episodes, n_evaluations, max_steps, rstdp_seeds, evaluation_seeds): 43 | """ Train policy on cartpole environment with RSTDP. 44 | 45 | Arguments: 46 | - policy_net: Policy to train. 47 | - env: Environment object to train on. 48 | - max_reward: Maximum achievable reward in 'env'. Used for weight update calculation. 49 | - num_episodes: Number of training episodes. 50 | - n_evaluations: Number of evaluations for stopping/saving criterion. 51 | - max_steps: Maximum number of steps in 'env'. 52 | - rstdp_seeds: Environment seeds for rstdp-training. 53 | - evaluations_seeds: Environment seeds for evaluation. 54 | 55 | Returns: 56 | - Trained weights. 57 | - Rewards achieved during training. 58 | """ 59 | env._max_episode_steps = max_steps 60 | 61 | best_reward = -np.inf 62 | best_episode = -1 63 | best_weights = None 64 | 65 | rewards = [] 66 | 67 | for i_episode in range(num_episodes): 68 | env.seed(int(rstdp_seeds[i_episode])) 69 | 70 | e_trace = None 71 | 72 | state = env.reset() 73 | total_reward = 0 74 | for t in range(max_steps): 75 | inputs = torch.from_numpy(state).float() 76 | final_layer_values, rstdp_out = policy_net.forward(inputs, (e_trace, None, None)) 77 | final_layer_values = final_layer_values.cpu().data.numpy() 78 | e_trace = rstdp_out[0][0] 79 | 80 | action = np.argmax(final_layer_values) 81 | state, reward, done, _ = env.step(action) 82 | 83 | total_reward += reward 84 | if done: 85 | break 86 | 87 | rewards.append(total_reward) 88 | 89 | # RSTDP update 90 | delta_rstdp = -e_trace * (1.0 - total_reward / max_reward) 91 | policy_net.l2.weights += delta_rstdp[0] 92 | 93 | eval_rewards = evaluate_policy(policy_net, env, n_evaluations, evaluation_seeds) 94 | avg_eval_reward = np.mean(eval_rewards) 95 | 96 | print("Episode: {:4d} -- Reward: {:7.2f} -- Best reward: {:7.2f} in episode {:4d}"\ 97 | .format(i_episode, avg_eval_reward, best_reward, best_episode), end='\r') 98 | 99 | if avg_eval_reward > best_reward: 100 | best_reward = avg_eval_reward 101 | best_episode = i_episode 102 | best_weights = deepcopy(policy_net.state_dict()) 103 | 104 | if best_reward >= max_reward: 105 | break 106 | 107 | print('\nBest individual stored after episode {:d} with reward {:6.2f}'.format(best_episode, best_reward)) 108 | print() 109 | return best_weights, rewards 110 | 111 | 112 | def rstdp_train_acrobot(policy_net, env, min_reward, num_episodes, n_evaluations, max_steps, rstdp_seeds, evaluation_seeds): 113 | """ Train policy on acrobot environment with RSTDP. 114 | 115 | Arguments: 116 | - policy_net: Policy to train. 117 | - env: Environment object to train on. 
118 | - max_reward: Maximum achievable reward in 'env'. Used for weight update calculation. 119 | - num_episodes: Number of training episodes. 120 | - n_evaluations: Number of evaluations for stopping/saving criterion. 121 | - max_steps: Maximum number of steps in 'env'. 122 | - rstdp_seeds: Environment seeds for rstdp-training. 123 | - evaluations_seeds: Environment seeds for evaluation. 124 | 125 | Returns: 126 | - Trained weights. 127 | - Rewards achieved during training. 128 | """ 129 | env._max_episode_steps = max_steps 130 | 131 | best_reward = -np.inf 132 | best_episode = -1 133 | best_weights = None 134 | 135 | rewards = [] 136 | 137 | for i_episode in range(num_episodes): 138 | env.seed(int(rstdp_seeds[i_episode])) 139 | 140 | e_trace = None 141 | 142 | state = env.reset() 143 | total_reward = 0 144 | for t in range(max_steps): 145 | inputs = torch.from_numpy(state).float() 146 | final_layer_values, rstdp_out = policy_net.forward(inputs, (e_trace, None, None)) 147 | final_layer_values = final_layer_values.cpu().data.numpy() 148 | e_trace = rstdp_out[0][0] 149 | 150 | action = np.argmax(final_layer_values) 151 | state, reward, done, _ = env.step(action) 152 | 153 | total_reward += reward 154 | if done: 155 | break 156 | 157 | rewards.append(total_reward) 158 | 159 | # RSTDP update 160 | delta_rstdp = e_trace * (total_reward/min_reward) 161 | policy_net.l2.weights += delta_rstdp[0] 162 | 163 | eval_rewards = evaluate_policy(policy_net, env, n_evaluations, evaluation_seeds) 164 | avg_eval_reward = np.mean(eval_rewards) 165 | 166 | print("Episode: {:4d} -- Reward: {:7.2f} -- Best reward: {:7.2f} in episode {:4d}"\ 167 | .format(i_episode, avg_eval_reward, best_reward, best_episode), end='\r') 168 | 169 | if avg_eval_reward > best_reward: 170 | best_reward = avg_eval_reward 171 | best_episode = i_episode 172 | best_weights = deepcopy(policy_net.state_dict()) 173 | 174 | if best_reward >= -100: 175 | break 176 | 177 | print('\nBest individual stored after episode {:d} with reward {:6.2f}'.format(best_episode, best_reward)) 178 | print() 179 | return best_weights, rewards -------------------------------------------------------------------------------- /seeds/evaluation_seeds.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/seeds/evaluation_seeds.npy -------------------------------------------------------------------------------- /seeds/rstdp_training_seeds.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/seeds/rstdp_training_seeds.npy -------------------------------------------------------------------------------- /seeds/training_seeds.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/seeds/training_seeds.npy -------------------------------------------------------------------------------- /src/dqn_agent.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import random 4 | 5 | import numpy as np 6 | import gymnasium as gym 7 | import torch.nn.functional as F 8 | 9 | from model import DSNN 10 | from collections import namedtuple, deque 11 | 12 | 13 | sys.path.append('../') 14 | device = torch.device("cuda:0" if torch.cuda.is_available() else 
"cpu") 15 | 16 | 17 | class ReplayBuffer: 18 | 19 | def __init__(self, buffer_size, batch_size, seed): 20 | self.memory = deque(maxlen=buffer_size) 21 | self.batch_size = batch_size 22 | self.experience = namedtuple("Experience", field_names=["state", "action", "reward", 23 | "next_state", "done"]) 24 | random.seed(seed) 25 | np.random.seed(seed) 26 | 27 | def add(self, state, action, reward, next_state, done): 28 | e = self.experience(state, action, reward, next_state, done) 29 | self.memory.append(e) 30 | 31 | def sample(self): 32 | experiences = random.sample(self.memory, k=self.batch_size) 33 | 34 | states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).\ 35 | float().to(device) 36 | actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).\ 37 | long().to(device) 38 | rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).\ 39 | float().to(device) 40 | next_states = torch.from_numpy( 41 | np.vstack([e.next_state for e in experiences if e is not None])).float().to(device) 42 | dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]). 43 | astype(np.uint8)).float().to(device) 44 | 45 | return (states, actions, rewards, next_states, dones) 46 | 47 | def __len__(self): 48 | return len(self.memory) 49 | 50 | 51 | class Agent: 52 | def __init__(self, env, policy_net, target_net, architecture, batch_size, memory_size, gamma, 53 | eps_start, eps_end, eps_decay, update_every, target_update_frequency, optimizer, 54 | learning_rate, num_episodes, max_steps, i_run, result_dir, seed, tau, 55 | spiking=False, two_neuron=False): 56 | 57 | self.env = gym.make(env) 58 | 59 | random.seed(seed) 60 | np.random.seed(seed) 61 | torch.manual_seed(seed) 62 | 63 | self.policy_net = policy_net 64 | self.target_net = target_net 65 | 66 | self.architecture = architecture 67 | self.batch_size = batch_size 68 | self.memory_size = memory_size 69 | self.gamma = gamma 70 | self.eps_start = eps_start 71 | self.eps_end = eps_end 72 | self.eps_decay = eps_decay 73 | self.update_every = update_every 74 | self.target_update_frequency = target_update_frequency 75 | self.optimizer = optimizer 76 | self.learning_rate = learning_rate 77 | self.num_episodes = num_episodes 78 | self.max_steps = max_steps 79 | self.i_run = i_run 80 | self.result_dir = result_dir 81 | self.tau = tau 82 | self.spiking = spiking 83 | self.random = random 84 | self.two_neuron = two_neuron 85 | self.seed = seed 86 | 87 | # Initialize Replay Memory 88 | self.memory = ReplayBuffer(self.memory_size, self.batch_size, seed) 89 | 90 | # Initialize time step 91 | self.t_step = 0 92 | self.t_step_total = 0 93 | 94 | def select_action(self, state, eps=0.): 95 | state = torch.from_numpy(state) 96 | state = state.unsqueeze(0).to(device) 97 | if random.random() > eps: 98 | with torch.no_grad(): 99 | if self.spiking: 100 | final_layer_values = self.policy_net.forward(state.float())[0].\ 101 | cpu().data.numpy() 102 | return np.argmax(final_layer_values) 103 | else: 104 | return np.argmax(self.policy_net.forward(state.float())[0].cpu().data.numpy()) 105 | else: 106 | return random.choice(np.arange(self.architecture[-1])) 107 | 108 | def step(self, state, action, reward, next_state, done): 109 | self.memory.add(state, action, reward, next_state, done) 110 | 111 | # Learn every UPDATE_EVERY time steps. 
112 | self.t_step = (self.t_step + 1) % self.update_every 113 | if self.t_step == 0: 114 | # If enough samples are available in memory, get random subset and learn 115 | if len(self.memory) > self.batch_size: 116 | experiences = self.memory.sample() 117 | self.optimize_model(experiences) 118 | 119 | def optimize_model(self, experiences): 120 | states, actions, rewards, next_states, dones = experiences 121 | 122 | # Get max predicted Q values (for next states) from target model 123 | if self.spiking: 124 | Q_targets_next = self.target_net.forward(next_states)[0].detach().max(1)[0].unsqueeze(1) 125 | else: 126 | Q_targets_next = self.target_net(next_states).detach().max(1)[0].unsqueeze(1) 127 | 128 | # Compute Q targets for current states 129 | Q_targets = rewards + (self.gamma * Q_targets_next*(1 - dones)) 130 | 131 | # Get expected Q values from local model 132 | if self.spiking: 133 | Q_expected = self.policy_net.forward(states)[0].gather(1, actions) 134 | else: 135 | Q_expected = self.policy_net.forward(states).gather(1, actions) 136 | 137 | # Compute loss 138 | loss = F.mse_loss(Q_expected, Q_targets) 139 | # Minimize the loss 140 | self.optimizer.zero_grad() 141 | loss.backward(retain_graph=True) 142 | 143 | self.optimizer.step() 144 | if self.t_step_total % self.target_update_frequency == 0: 145 | self.soft_update() 146 | 147 | def soft_update(self): 148 | self.target_net.load_state_dict(self.policy_net.state_dict()) 149 | 150 | def transform_state(self, state): 151 | state_ = [] 152 | for i in state: 153 | if i > 0: 154 | state_.append(i) 155 | state_.append(0) 156 | else: 157 | state_.append(0) 158 | state_.append(abs(i)) 159 | return np.array(state_) 160 | 161 | def train_agent(self): 162 | best_average = -np.inf 163 | best_average_after = np.inf 164 | scores = [] 165 | smoothed_scores = [] 166 | scores_window = deque(maxlen=100) 167 | eps = self.eps_start 168 | 169 | for i_episode in range(1, self.num_episodes + 1): 170 | state, _ = self.env.reset(seed=self.seed) 171 | if self.two_neuron: 172 | state = self.transform_state(state) 173 | score = 0 174 | done = False 175 | while not done: 176 | self.t_step_total += 1 177 | action = self.select_action(state, eps) 178 | next_state, reward, done1, done2, _ = self.env.step(action) 179 | done = done1 or done2 180 | if self.two_neuron: 181 | next_state = self.transform_state(next_state) 182 | self.step(state, action, reward, next_state, done) 183 | state = next_state 184 | score += reward 185 | eps = max(self.eps_end, self.eps_decay * eps) 186 | if done: 187 | break 188 | scores_window.append(score) 189 | scores.append(score) 190 | smoothed_scores.append(np.mean(scores_window)) 191 | 192 | if smoothed_scores[-1] > best_average: 193 | best_average = smoothed_scores[-1] 194 | best_average_after = i_episode 195 | if self.spiking: 196 | torch.save(self.policy_net.state_dict(), 197 | self.result_dir + '/checkpoint_DSQN_{}.pt'.format(self.i_run)) 198 | else: 199 | torch.save(self.policy_net.state_dict(), 200 | self.result_dir + '/checkpoint_DQN_{}.pt'.format(self.i_run)) 201 | 202 | print("Episode {}\tAverage Score: {:.2f}\t Epsilon: {:.2f}". 203 | format(i_episode, np.mean(scores_window), eps), end='\r') 204 | 205 | if i_episode % 100 == 0: 206 | print("\rEpisode {}\tAverage Score: {:.2f}". 207 | format(i_episode, np.mean(scores_window))) 208 | 209 | print('Best 100 episode average: ', best_average, ' reached at episode ', 210 | best_average_after, '. 
Model saved in folder best.') 211 | return smoothed_scores, scores, best_average, best_average_after 212 | 213 | 214 | def evaluate_agent(policy_net, env, num_episodes, max_steps, gym_seeds, epsilon=0): 215 | """ 216 | 217 | """ 218 | rewards = [] 219 | 220 | for i_episode in range(num_episodes): 221 | env.seed(int(gym_seeds[i_episode])) 222 | env._max_episode_steps = max_steps 223 | state = env.reset() 224 | state = torch.from_numpy(state).float().unsqueeze(0).to(device) 225 | total_reward = 0 226 | for t in range(max_steps): 227 | if random.random() >= epsilon: 228 | final_layer_values = policy_net.forward(state.float())[0].cpu().data.numpy() 229 | action = np.argmax(final_layer_values) 230 | else: 231 | action = random.randint(0, env.action_space.n - 1) 232 | 233 | observation, reward, done, _ = env.step(action) 234 | state = torch.from_numpy(observation).float().unsqueeze(0).to(device) 235 | total_reward += reward 236 | if done: 237 | break 238 | rewards.append(total_reward) 239 | print("Episode: {}".format(i_episode), end='\r') 240 | 241 | return rewards 242 | -------------------------------------------------------------------------------- /src/dsnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import numpy as np 5 | from functools import partial 6 | 7 | from model import SurrGradSpike 8 | from rstdp import RSTDP 9 | 10 | default_device = ('cuda:0' if torch.cuda.is_available() else 'cpu') 11 | 12 | class DSNN(nn.Module): 13 | def __init__(self, 14 | state_size, 15 | operation, 16 | weights_size, 17 | alpha, 18 | beta, 19 | is_spiking=True, 20 | threshold=.1, 21 | dtype=torch.float, 22 | #device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 23 | device=default_device 24 | ): 25 | super(DSNN, self).__init__() 26 | 27 | self.dtype = dtype 28 | self.device = device 29 | 30 | self.state_size = state_size 31 | self.operation = operation 32 | 33 | self.weights = torch.zeros(weights_size, dtype=self.dtype, device=self.device, requires_grad=True) 34 | nn.init.normal_(self.weights, mean=0., std=.1) 35 | self.weights = nn.Parameter(self.weights) 36 | 37 | self.alpha = alpha 38 | self.beta = beta 39 | self.threshold = threshold 40 | 41 | self.spike_fn = SurrGradSpike.apply 42 | self.is_spiking = is_spiking 43 | 44 | 45 | def _get_initial_state(self, template, amount_state_vars): 46 | assert template.shape[1:] == self.state_size, \ 47 | "Specified 'state_size' {} does not match size of tensor resulting from operation {}." \ 48 | .format(self.state_size, template.shape[1:]) 49 | 50 | return [torch.zeros_like(template, dtype=self.dtype, device=self.device) for _ in range(amount_state_vars)] 51 | 52 | 53 | def forward(self, inputs, state=(None, None)): 54 | mem, syn = state 55 | spk_out, mem_out, syn_out = [], [], [] 56 | 57 | for t in range(inputs.shape[1]): 58 | h = self.operation(inputs[:, t, ...], self.weights) 59 | 60 | if mem is None or syn is None: 61 | mem, syn = self._get_initial_state(template=h, amount_state_vars=2) 62 | 63 | new_syn = self.alpha * syn + h 64 | new_mem = (self.beta * mem + new_syn) 65 | 66 | if self.is_spiking: 67 | mthr = new_mem - self.threshold 68 | out = self.spike_fn(mthr) 69 | c = (mthr > 0.) 70 | new_mem[c] = 0. 
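                # Hard reset: units whose membrane potential crossed the threshold
                # have just been set back to zero after emitting a spike.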
71 | else: 72 | out = torch.zeros_like(mem) 73 | 74 | spk_out.append(out) 75 | mem_out.append(new_mem) 76 | syn_out.append(new_syn) 77 | 78 | mem = new_mem 79 | syn = new_syn 80 | 81 | spk_out = torch.stack(spk_out, dim=1) 82 | mem_out = torch.stack(mem_out, dim=1) 83 | syn_out = torch.stack(syn_out, dim=1) 84 | 85 | return spk_out, (mem, syn), (mem_out, syn_out) 86 | 87 | 88 | class DSRNN(DSNN): 89 | def __init__(self, 90 | state_size, 91 | operation, 92 | weights_size, 93 | alpha, 94 | beta, 95 | is_spiking=True, 96 | threshold=.1, 97 | dtype=torch.float, 98 | device=torch.device('cpu') 99 | ): 100 | super(DSRNN, self).__init__( 101 | state_size, operation, weights_size, alpha, beta, is_spiking, threshold, dtype, device 102 | ) 103 | 104 | num_units = np.prod(self.state_size) 105 | 106 | self.recurrent_weights = torch.zeros((num_units, num_units), dtype=self.dtype, device=self.device, requires_grad=True) 107 | nn.init.normal_(self.recurrent_weights, mean=0., std=.1) 108 | self.recurrent_weights = nn.Parameter(self.recurrent_weights) 109 | 110 | 111 | def forward(self, inputs, state=(None, None, None)): 112 | mem, syn, spk = state 113 | spk_rec, mem_rec, syn_rec = [], [], [] 114 | 115 | for t in range(inputs.shape[1]): 116 | h = self.operation(inputs[:, t, ...], self.weights) 117 | 118 | if mem is None or syn is None or spk is None: 119 | mem, syn, spk = self._get_initial_state(template=h, amount_state_vars=3) 120 | 121 | # Calculate recurrent connections: 122 | # Convert spk tensor to size (batch_size, num_units) 123 | linear_spk = spk.view(spk.shape[0], -1) 124 | # Calculate recurrent inputs 125 | rec_h = torch.einsum('ab,bc->ac', linear_spk, self.recurrent_weights) 126 | # Reshape recurrent inputs back to size of spks (batch_size, state_size) 127 | rec_h = rec_h.view(spk.shape) 128 | 129 | new_syn = self.alpha * syn + h + rec_h 130 | new_mem = (self.beta * mem + new_syn) 131 | 132 | if self.is_spiking: 133 | mthr = new_mem - self.threshold 134 | out = self.spike_fn(mthr) 135 | c = (mthr > 0.) 136 | new_mem[c] = 0. 
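                # As in DSNN above, spiking units are hard-reset to zero; the spikes
                # in 'out' are also what feeds back through 'recurrent_weights' at
                # the next timestep.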
137 | else: 138 | out = torch.zeros_like(mem) 139 | 140 | spk_rec.append(out) 141 | mem_rec.append(new_mem) 142 | syn_rec.append(new_syn) 143 | 144 | mem = new_mem 145 | syn = new_syn 146 | spk = out 147 | 148 | spk_rec = torch.stack(spk_rec, dim=1) 149 | mem_rec = torch.stack(mem_rec, dim=1) 150 | syn_rec = torch.stack(syn_rec, dim=1) 151 | 152 | return spk_rec, (mem, syn, spk), (mem_rec, syn_rec) 153 | 154 | 155 | def transform_state(state): 156 | state_ = [] 157 | for i in state: 158 | if i > 0: 159 | state_.append(i) 160 | state_.append(0) 161 | else: 162 | state_.append(0) 163 | state_.append(abs(i)) 164 | return torch.tensor(state_) 165 | 166 | 167 | class RSTDPNet(nn.Module): 168 | def __init__(self, alpha, beta, threshold, architecture, simulation_time, weights, tau=16, tau_e=800, 169 | A_plus=0.00001, A_minus=0.00001, C=0.01, device=default_device, dtype=torch.float): 170 | super(RSTDPNet, self).__init__() 171 | 172 | self.simulation_time = simulation_time 173 | self.device = device 174 | self.dtype = dtype 175 | 176 | self.l1 = DSNN( 177 | state_size=(architecture[1],), 178 | operation=partial(torch.einsum, 'ab,bc->ac'), 179 | weights_size=(architecture[0], architecture[1]), 180 | alpha=alpha, 181 | beta=beta, 182 | threshold=threshold, 183 | device=self.device, 184 | dtype=self.dtype 185 | ) 186 | # Set 'requires_grad' to 'False', since value assignment is not supported otherwise 187 | self.l1.weights = nn.Parameter(weights[0][0].clone(), requires_grad=False) 188 | self.l2 = DSNN( 189 | state_size=(architecture[2],), 190 | operation=partial(torch.einsum, 'ab,bc->ac'), 191 | weights_size=(architecture[1], architecture[2]), 192 | alpha=alpha, 193 | beta=beta, 194 | threshold=threshold, 195 | device=self.device, 196 | dtype=self.dtype 197 | ) 198 | # As above, set 'requires_grad' to 'False', since value assignment is not supported otherwise 199 | self.l2.weights = nn.Parameter(weights[0][1].clone(), requires_grad=False) 200 | self.l3 = DSNN( 201 | state_size=(architecture[3],), 202 | operation=partial(torch.einsum, 'ab,bc->ac'), 203 | weights_size=(architecture[2], architecture[3]), 204 | alpha=alpha, 205 | beta=beta, 206 | threshold=threshold, 207 | is_spiking=False, 208 | device=self.device, 209 | dtype=self.dtype 210 | ) 211 | self.l3.weights = nn.Parameter(weights[0][2].clone(), requires_grad=False) 212 | 213 | self.rstdp=RSTDP(A_plus=A_plus, A_minus=A_minus, tau_plus=tau, tau_minus=tau, tau_e=tau_e, C=C, 214 | device=self.device, dtype=self.dtype) 215 | 216 | def forward(self, inputs, rstdp_state=(None, None, None)): 217 | # Two neuron encoding 218 | inputs = transform_state(inputs) 219 | # Expand analogue inputs for each timestep 220 | inputs = torch.tile(inputs, (1, self.simulation_time, 1)).to(self.device) 221 | 222 | # Calculate layers 223 | y_l1, _, _ = self.l1(inputs) 224 | y_l2, _, _ = self.l2(y_l1) 225 | _, mem_result, _ = self.l3(y_l2) 226 | 227 | # Calculate rstdp 228 | rstdp_out = self.rstdp(y_l1, y_l2, rstdp_state) 229 | 230 | return mem_result[0], rstdp_out 231 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | 5 | import numpy as np 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import torch.nn.functional as F 9 | 10 | 11 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 12 | 13 | 14 | class SurrGradSpike(torch.autograd.Function): 15 | """ 16 | 
Here we implement our spiking nonlinearity which also implements 17 | the surrogate gradient. By subclassing torch.autograd.Function, 18 | we will be able to use all of PyTorch's autograd functionality. 19 | Here we use the normalized negative part of a fast sigmoid 20 | as this was done in Zenke & Ganguli (2018). 21 | """ 22 | 23 | scale = 100.0 # controls steepness of surrogate gradient 24 | 25 | @staticmethod 26 | def forward(ctx, input): 27 | """ 28 | In the forward pass we compute a step function of the input Tensor 29 | and return it. ctx is a context object that we use to stash information which 30 | we need to later backpropagate our error signals. To achieve this we use the 31 | ctx.save_for_backward method. 32 | """ 33 | ctx.save_for_backward(input) 34 | out = torch.zeros_like(input) 35 | out[input > 0] = 1.0 36 | # Only for inhibitory spikes 37 | #out[input < 0] = -1.0 38 | return out 39 | 40 | @staticmethod 41 | def backward(ctx, grad_output): 42 | """ 43 | In the backward pass we receive a Tensor we need to compute the 44 | surrogate gradient of the loss with respect to the input. 45 | Here we use the normalized negative part of a fast sigmoid 46 | as this was done in Zenke & Ganguli (2018). 47 | """ 48 | input, = ctx.saved_tensors 49 | grad_input = grad_output.clone() 50 | grad = grad_input / (SurrGradSpike.scale * torch.abs(input) + 1.0) ** 2 51 | return grad 52 | 53 | 54 | class DSNN(nn.Module): 55 | def __init__(self, architecture, seed, alpha, beta, weight_scale, batch_size, threshold, 56 | simulation_time, learning_rate, reset_potential=0): 57 | """ 58 | 59 | """ 60 | self.architecture = architecture 61 | 62 | random.seed(seed) 63 | np.random.seed(seed) 64 | torch.manual_seed(seed) 65 | 66 | self.alpha = alpha 67 | self.beta = beta 68 | self.weight_scale = weight_scale 69 | self.batch_size = batch_size 70 | self.threshold = threshold 71 | self.simulation_time = simulation_time 72 | self.reset_potential = reset_potential 73 | 74 | self.spike_fn = SurrGradSpike.apply 75 | 76 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 77 | 78 | # Initialize the network weights 79 | self.weights = [] 80 | for i in range(len(architecture) - 1): 81 | self.weights.append(torch.empty((self.architecture[i], self.architecture[i + 1]), 82 | device=device, dtype=torch.float, requires_grad=True)) 83 | torch.nn.init.normal_(self.weights[i], mean=0.0, 84 | std=self.weight_scale/np.sqrt(self.architecture[i])) 85 | 86 | self.optimizer = optim.Adam(self.parameters(), lr=learning_rate) 87 | 88 | def forward(self, inputs): 89 | syn = [] 90 | mem = [] 91 | spk_count = [] 92 | 93 | for l in range(0, len(self.weights)): 94 | syn.append(torch.zeros((self.batch_size, self.weights[l].shape[1]), device=device, 95 | dtype=torch.float)) 96 | mem.append(torch.zeros((self.batch_size, self.weights[l].shape[1]), device=device, 97 | dtype=torch.float)) 98 | 99 | # Here we define two lists which we use to record the membrane potentials and output spikes 100 | mem_rec = [] 101 | spk_rec = [] 102 | 103 | # Here we loop over time 104 | for t in range(self.simulation_time): 105 | # append the new timestep to mem_rec and spk_rec 106 | mem_rec.append([]) 107 | spk_rec.append([]) 108 | 109 | if t == 0: 110 | for l in range(len(self.weights)): 111 | mem_rec[-1].append(mem[l]) 112 | spk_rec[-1].append(mem[l]) 113 | continue 114 | 115 | # We take the input as it is, multiply is by the weights, and we inject the outcome 116 | # as current in the neurons of the first hidden layer 117 | input = 
inputs.detach().clone() 118 | 119 | # loop over layers 120 | for l in range(len(self.weights)): 121 | if l == 0: 122 | h = torch.einsum("ab,bc->ac", [input, self.weights[0]]) 123 | new_syn = 0 * syn[l] + h 124 | else: 125 | h = torch.einsum("ab,bc->ac", [spk_rec[-1][l - 1], self.weights[l]]) 126 | new_syn = self.alpha * syn[l] + h 127 | 128 | new_mem = self.beta*mem[l] + new_syn 129 | 130 | # calculate the spikes for all layers but the last layer 131 | if l < (len(self.weights) - 1): 132 | mthr = new_mem 133 | mthr = mthr - self.threshold 134 | out = self.spike_fn(mthr) 135 | c = (mthr > 0) 136 | new_mem[c] = self.reset_potential 137 | spk_rec[-1].append(out) 138 | 139 | mem[l] = new_mem 140 | syn[l] = new_syn 141 | 142 | mem_rec[-1].append(mem[l]) 143 | 144 | # return the final recorded membrane potential in the output layer, all membrane potentials, 145 | # and spikes 146 | return mem_rec[-1][-1], mem_rec, spk_rec 147 | 148 | def load_state_dict(self, layers): 149 | """Method to load weights and biases into the network""" 150 | weights = layers[0] 151 | for l in range(0,len(weights)): 152 | self.weights[l] = weights[l].detach().clone().requires_grad_(True) 153 | 154 | def state_dict(self): 155 | """Method to copy the layers of the SQN. Makes explicit copies, no references.""" 156 | weights_copy = [] 157 | bias_copy = [] 158 | for l in range(0, len(self.weights)): 159 | weights_copy.append(self.weights[l].detach().clone()) 160 | return weights_copy, bias_copy 161 | 162 | def parameters(self): 163 | parameters = [] 164 | for l in range(0, len(self.weights)): 165 | parameters.append(self.weights[l]) 166 | 167 | return parameters 168 | 169 | 170 | class QNetwork(nn.Module): 171 | """Actor (Policy) Model.""" 172 | 173 | def __init__(self, architecture, seed): 174 | """Initialize parameters and build model. 
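        A plain fully connected network: hidden layers use ReLU activations, the
        output layer is linear, and layer sizes are taken from 'architecture'.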
175 | Params 176 | ====== 177 | architecture: 178 | """ 179 | super(QNetwork, self).__init__() 180 | self.layers = nn.ModuleList() 181 | 182 | random.seed(seed) 183 | np.random.seed(seed) 184 | torch.manual_seed(seed) 185 | 186 | for i in range(len(architecture) - 1): 187 | self.layers.append(nn.Linear(architecture[i], architecture[i + 1])) 188 | 189 | def forward(self, x): 190 | """Build a network that maps state -> action values.""" 191 | for layer in self.layers[:-1]: 192 | x = F.relu(layer(x)) 193 | 194 | # no ReLu activation in the output layer 195 | return self.layers[-1](x) 196 | 197 | 198 | class TD3ActorDSNN(DSNN): 199 | def __init__(self, architecture, seed, alpha, beta, weight_scale, batch_size, threshold, 200 | simulation_time, learning_rate, name, device, reset_potential=0, 201 | random_params=False, std=0.1): 202 | self.architecture = architecture 203 | 204 | random.seed(seed) 205 | np.random.seed(seed) 206 | torch.manual_seed(seed) 207 | 208 | self.alpha = alpha 209 | self.beta = beta 210 | self.weight_scale = weight_scale 211 | self.batch_size = batch_size 212 | self.threshold = threshold 213 | self.simulation_time = simulation_time 214 | self.reset_potential = reset_potential 215 | self.random_params = random_params 216 | self.std = std 217 | 218 | self.name = name 219 | 220 | self.spike_fn = SurrGradSpike.apply 221 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 222 | 223 | # Initialize the network weights 224 | self.weights = [] 225 | for i in range(len(architecture) - 1): 226 | self.weights.append(torch.empty((self.architecture[i], self.architecture[i + 1]), 227 | device=device, dtype=torch.float, requires_grad=True)) 228 | torch.nn.init.normal_(self.weights[i], mean=0.0, 229 | std=self.weight_scale / np.sqrt(self.architecture[i])) 230 | 231 | self.optimizer = optim.Adam(self.parameters(), lr=learning_rate) 232 | 233 | if random_params: 234 | self.randomize_parameters() 235 | 236 | def randomize_parameters(self): 237 | # Define new random values for alpha, beta, and threhsold 238 | self.alphas = [] 239 | self.betas = [] 240 | self.thresholds = [] 241 | self.reset_potentials = [] 242 | 243 | for i in range(len(self.architecture) - 1): 244 | self.alphas.append(torch.normal(torch.ones(self.architecture[i + 1])*self.alpha, 245 | self.alpha*self.std). 246 | unsqueeze(0).to(self.device)) 247 | self.alphas[i][self.alphas[i] > 1] = 1 248 | self.alphas[i][self.alphas[i] < 0] = 0 249 | self.betas.append(torch.normal(torch.ones(self.architecture[i + 1])*self.beta, 250 | self.beta*self.std). 
251 | unsqueeze(0).to(self.device)) 252 | self.betas[i][self.betas[i] > 1] = 1 253 | self.betas[i][self.betas[i] < 0] = 0 254 | 255 | self.thresholds.append( 256 | torch.normal(torch.ones(self.architecture[i + 1]) * self.threshold, 257 | self.threshold*self.std).to(self.device)) 258 | self.thresholds[i][self.thresholds[i] < 0] = 0 259 | 260 | self.reset_potentials.append( 261 | torch.normal(torch.ones(self.architecture[i + 1])*self.reset_potential, 262 | self.std).to(self.device)) 263 | 264 | def forward(self, inputs): 265 | syn = [] 266 | mem = [] 267 | 268 | for l in range(0, len(self.weights)): 269 | syn.append(torch.zeros((self.batch_size, self.weights[l].shape[1]), device=device, 270 | dtype=torch.float)) 271 | mem.append(torch.zeros((self.batch_size, self.weights[l].shape[1]), device=device, 272 | dtype=torch.float)) 273 | 274 | # Here we define two lists which we use to record the membrane potentials and output spikes 275 | mem_rec = [] 276 | spk_rec = [] 277 | 278 | # Here we loop over time 279 | for t in range(self.simulation_time): 280 | # append the new timestep to mem_rec and spk_rec 281 | mem_rec.append([]) 282 | spk_rec.append([]) 283 | 284 | if t == 0: 285 | for l in range(len(self.weights)): 286 | mem_rec[-1].append(mem[l]) 287 | spk_rec[-1].append(mem[l]) 288 | continue 289 | 290 | # We take the input as it is, multiply is by the weights, and we inject the outcome 291 | # as current in the neurons of the first hidden layer 292 | input = inputs.detach().clone() 293 | 294 | # loop over layers 295 | for l in range(len(self.weights)): 296 | if l == 0: 297 | h = torch.matmul(input, self.weights[0]) 298 | new_syn = 0 * syn[l] + h 299 | elif l == len(self.weights) - 1: 300 | h = torch.matmul(spk_rec[-1][l - 1], self.weights[l]) 301 | new_syn = 0*syn[l] + h 302 | else: 303 | h = torch.matmul(spk_rec[-1][l - 1], self.weights[l]) 304 | if self.random_params: 305 | new_syn = torch.add(torch.mul(self.alphas[l], syn[l]), h) 306 | else: 307 | new_syn = self.alpha * syn[l] + h 308 | 309 | if l == len(self.weights) - 1: 310 | new_mem = 1*mem[l] + new_syn 311 | else: 312 | if self.random_params: 313 | new_mem = torch.add(torch.mul(self.betas[l], mem[l]), new_syn) 314 | else: 315 | new_mem = self.beta*mem[l] + new_syn 316 | 317 | # calculate the spikes for all layers but the last layer (decoding='potential') 318 | if l < (len(self.weights) - 1): 319 | mthr = new_mem 320 | if self.random_params: 321 | mthr = torch.sub(mthr, self.thresholds[l]) 322 | reset = self.reset_potential 323 | else: 324 | mthr = mthr - self.threshold 325 | reset = self.reset_potential 326 | out = self.spike_fn(mthr) 327 | c = (mthr > 0) 328 | new_mem[c] = self.reset_potential 329 | 330 | spk_rec[-1].append(out) 331 | 332 | mem[l] = new_mem 333 | syn[l] = new_syn 334 | 335 | mem_rec[-1].append(mem[l]) 336 | 337 | # return the final recorded membrane potential (len(mem_rec)-1) in the output layer (-1) 338 | return torch.tanh(mem_rec[-1][-1]), mem_rec, spk_rec 339 | 340 | def save_checkpoint(self, result_dir, episode_num): 341 | #print('... 
saving checkpoint ...') 342 | torch.save(self.state_dict(), result_dir + '/checkpoint_TD3_{}_{}.pt'.format(self.name, 343 | episode_num)) 344 | 345 | 346 | class TD3CriticNetwork(nn.Module): 347 | def __init__(self, learning_rate, input_dims, fc1_dims, fc2_dims, n_actions, name, 348 | checkpoint_dir='tmp/td3'): 349 | super(TD3CriticNetwork, self).__init__() 350 | self.input_dims = input_dims 351 | self.fc1_dims = fc1_dims 352 | self.fc2_dims = fc2_dims 353 | self.n_actions = n_actions 354 | self.name = name 355 | self.checkpoint_dir = checkpoint_dir 356 | self.checkpoint_file = os.path.join(self.checkpoint_dir, name + '_td3') 357 | 358 | self.fc1 = nn.Linear(self.input_dims[0] + n_actions, self.fc1_dims) 359 | self.fc2 = nn.Linear(self.fc1_dims, self.fc2_dims) 360 | self.q1 = nn.Linear(self.fc2_dims, 1) 361 | 362 | self.optimizer = optim.Adam(self.parameters(), lr=learning_rate) 363 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 364 | 365 | self.to(self.device) 366 | 367 | def forward(self, state, action): 368 | q1_action_value = self.fc1(torch.cat([state, action], dim=1)) 369 | q1_action_value = F.relu(q1_action_value) 370 | q1_action_value = self.fc2(q1_action_value) 371 | q1_action_value = F.relu(q1_action_value) 372 | 373 | q1 = self.q1(q1_action_value) 374 | 375 | return q1 376 | 377 | def save_checkpoint(self, result_dir): 378 | torch.save(self.state_dict(), result_dir + '/checkpoint_TD3_{}.pt'.format(self.name)) 379 | 380 | def load_checkpoint(self): 381 | print('... loading checkpoint ...') 382 | self.load_state_dict(torch.load(self.checkpoint_file)) 383 | -------------------------------------------------------------------------------- /src/rstdp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class RSTDP(object): 5 | def __init__(self, 6 | A_plus=1., 7 | A_minus=1., 8 | tau_plus=10., 9 | tau_minus=10., 10 | tau_e=10., 11 | C=1., 12 | time_step=1., 13 | device='cpu', 14 | dtype=torch.float 15 | ): 16 | self._A_plus = A_plus 17 | self._A_minus = A_minus 18 | self._C = C 19 | 20 | self._exp_dec_plus = np.exp(-time_step/tau_plus) 21 | self._exp_dec_minus = np.exp(-time_step/tau_minus) 22 | self._e_trace_dec = tau_e / time_step 23 | 24 | self.device = device 25 | self.dtype = dtype 26 | 27 | def __call__(self, pre, post, state=(None, None, None)): 28 | # Get old state 29 | e, k_plus, k_minus = state 30 | 31 | # Set zero states, if input state is None (this can be used for resetting) 32 | if e is None: 33 | e = torch.zeros(pre.shape[0], pre.shape[-1], post.shape[-1], device=self.device, dtype=self.dtype) 34 | if k_plus is None: 35 | k_plus = torch.zeros(pre.shape[0], pre.shape[-1], post.shape[-1], device=self.device, dtype=self.dtype) 36 | if k_minus is None: 37 | k_minus = torch.zeros(pre.shape[0], pre.shape[-1], post.shape[-1], device=self.device, dtype=self.dtype) 38 | 39 | # Create lists for recording 40 | e_rec = [] 41 | k_plus_rec = [] 42 | k_minus_rec = [] 43 | # Loop over each timestep in spikes 44 | for t in range(pre.shape[1]): 45 | # Calculate change in eligibility trace (e_dot) 46 | e_dot = -(e/self._e_trace_dec) 47 | e_dot += self._A_plus * k_plus * post[:, t, None, :] * self._C 48 | e_dot += -self._A_minus * k_minus * pre[:, t, ..., None] * self._C 49 | 50 | # Set new eligibility trace (e) from old e and change (e_dot) 51 | new_e = e + e_dot 52 | 53 | # Calculate new k-values 54 | new_k_plus = k_plus * self._exp_dec_plus + pre[:, t, ..., None] 55 | 
new_k_minus = k_minus * self._exp_dec_minus + post[:, t, None, :] 56 | 57 | # Append new values to recordings 58 | e_rec.append(new_e) 59 | k_plus_rec.append(new_k_plus) 60 | k_minus_rec.append(new_k_minus) 61 | # Set new values to variables 62 | e = new_e 63 | k_plus = new_k_plus 64 | k_minus = new_k_minus 65 | 66 | # Stack recordings into single array, along simulation_time dimension 67 | e_rec = torch.stack(e_rec, dim=1) 68 | k_plus_rec = torch.stack(k_plus_rec, dim=1) 69 | k_minus_rec = torch.stack(k_minus_rec, dim=1) 70 | 71 | # Return state and recordings 72 | return (e, k_plus, k_minus), (e_rec, k_plus_rec, k_minus_rec) 73 | -------------------------------------------------------------------------------- /src/td3_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | 8 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | class ReplayBuffer(): 12 | def __init__(self, max_size, input_shape, n_actions, store_spikes=False, simtime=10): 13 | self.store_spikes = store_spikes 14 | 15 | self.mem_size = max_size 16 | self.mem_counter = 0 17 | self.state_memory = np.zeros((self.mem_size, *input_shape)) 18 | self.new_state_memory = np.zeros((self.mem_size, *input_shape)) 19 | self.action_memory = np.zeros((self.mem_size, n_actions)) 20 | self.reward_memory = np.zeros(self.mem_size) 21 | self.terminal_memory = np.zeros(self.mem_size, dtype=bool) 22 | 23 | if self.store_spikes: 24 | self.state_spikes_memory = np.zeros((self.mem_size, input_shape[0], simtime)) 25 | self.new_state_spikes_memory = np.zeros((self.mem_size, input_shape[0], simtime)) 26 | 27 | def store_transition(self, state, action, reward, state_, done, state_spikes=None, 28 | state_spikes_=None): 29 | index = self.mem_counter % self.mem_size 30 | self.state_memory[index] = state 31 | self.action_memory[index] = action 32 | self.reward_memory[index] = reward 33 | self.new_state_memory[index] = state_ 34 | self.terminal_memory[index] = done 35 | 36 | if self.store_spikes: 37 | self.state_spikes_memory[index] = state_spikes 38 | self.new_state_spikes_memory[index] = state_spikes_ 39 | 40 | self.mem_counter += 1 41 | 42 | def sample_buffer(self, batch_size): 43 | max_mem = min(self.mem_counter, self.mem_size) 44 | 45 | batch = np.random.choice(max_mem, batch_size) 46 | 47 | states = self.state_memory[batch] 48 | states_ = self.new_state_memory[batch] 49 | actions = self.action_memory[batch] 50 | rewards = self.reward_memory[batch] 51 | dones = self.terminal_memory[batch] 52 | 53 | if self.store_spikes: 54 | states_spikes = self.state_spikes_memory[batch] 55 | states_spikes_ = self.new_state_spikes_memory[batch] 56 | return states, states_spikes, actions, rewards, states_, states_spikes_, dones 57 | 58 | return states, actions, rewards, states_, dones 59 | 60 | 61 | class Agent(): 62 | def __init__(self, actor, critic_1, critic_2, target_actor, target_critic_1, target_critic_2, 63 | input_dims, tau, env, n_timesteps, result_dir, gamma=0.99, update_actor_interval=2, 64 | update_target_interval=2, warmup=1000, learning_starts=1000, n_actions=2, 65 | buffer_size=1000000, batch_size=100, noise=0.1, seed=0, pop_coding=False, 66 | mutually_exclusive=False, pop_size=2, obs_range=[(-1,1)], spiking=False, 67 | two_neuron=False, normalize=False, spiking_critic=False, simtime=20, 68 | encoding='current', store_spikes=False): 69 | """ 70 | :param alpha: actor network 
learning rate 71 | :param beta: critic network learning rate 72 | """ 73 | torch.manual_seed(seed) 74 | np.random.seed(seed) 75 | random.seed(seed) 76 | 77 | self.input_dims = input_dims 78 | self.tau = tau 79 | self.env = env 80 | self.n_timesteps = n_timesteps 81 | self.result_dir = result_dir 82 | self.max_action = env.action_space.high 83 | self.min_action = env.action_space.low 84 | self.max_obs = env.observation_space.high 85 | 86 | for i in range(len(self.max_obs)): 87 | if self.max_obs[i] == np.inf: 88 | self.max_obs[i] = 1 89 | self.gamma = gamma 90 | self.store_spikes = store_spikes 91 | self.memory = ReplayBuffer(buffer_size, input_dims, n_actions, store_spikes=store_spikes, 92 | simtime=simtime) 93 | self.batch_size = batch_size 94 | self.episode_counter = 0 95 | self.learn_step_counter = 0 96 | self.policy_learn_step_counter = 0 97 | self.time_step = 0 98 | self.warmup = warmup 99 | self.learning_starts = learning_starts 100 | self.n_actions = n_actions 101 | self.update_actor_interval = update_actor_interval 102 | self.update_target_interval = update_target_interval 103 | self.spiking = spiking 104 | self.spiking_critic = spiking_critic 105 | self.two_neuron = two_neuron 106 | self.normalize = normalize 107 | self.simtime = simtime 108 | self.encoding = encoding 109 | 110 | self.actor = actor 111 | self.critic_1 = critic_1 112 | self.critic_2 = critic_2 113 | self.target_actor = target_actor 114 | self.target_critic_1 = target_critic_1 115 | self.target_critic_2 = target_critic_2 116 | 117 | self.chosen_actions = [] 118 | 119 | self.noise = noise 120 | self.update_network_parameters(tau=1) 121 | 122 | self.actor_output = [] 123 | self.critic_1_output = [] 124 | self.critic_2_output = [] 125 | 126 | self.pop_coding = pop_coding 127 | self.mutually_exclusive = mutually_exclusive 128 | self.pop_size = pop_size 129 | self.obs_range = obs_range 130 | self.seed = seed 131 | 132 | if self.pop_coding: 133 | self.pop_disp = [(i[1] - i[0])/(pop_size + 1) for i in obs_range] 134 | self.pop_means = [] 135 | for i in range(int(input_dims[0]/pop_size)): 136 | self.pop_means.append([]) 137 | start = obs_range[i][0] 138 | for j in range(pop_size): 139 | self.pop_means[-1].append(start + self.pop_disp[i]) 140 | start += self.pop_disp[i] 141 | 142 | def choose_action(self, observation): 143 | if self.time_step < self.warmup: 144 | mu = torch.tensor(np.random.normal(scale=self.noise, size=self.n_actions), 145 | device=device) 146 | else: 147 | if self.normalize: 148 | observation = self.normalize_state(observation) 149 | state = observation.clone().to(device) 150 | else: 151 | state = torch.tensor(observation, dtype=torch.float).clone().to(device) 152 | if self.spiking: 153 | if self.encoding == 'poisson': 154 | state = self.generate_poisson_input(state.to('cpu')) 155 | state = state.unsqueeze(0).to(device) 156 | mu = self.actor.forward(state)[0].squeeze(0).to(device) 157 | else: 158 | mu = self.actor.forward(state).to(device) 159 | 160 | mu_prime = mu + torch.tensor(np.random.normal(scale=self.noise), dtype=torch.float, 161 | device=device).to(device) 162 | 163 | mu_prime = torch.clamp(mu_prime*self.max_action[0], self.min_action[0], self.max_action[0]) 164 | self.time_step += 1 165 | 166 | return mu_prime.cpu().detach().numpy() 167 | 168 | def learn(self): 169 | if self.memory.mem_counter < self.batch_size or self.time_step < self.learning_starts: 170 | return 171 | #if self.encoding == 'poisson': 172 | # state, state_spikes, action, reward, state_, state_spikes_, done =\ 173 | # 
self.memory.sample_buffer(self.batch_size) 174 | #state_spikes = torch.tensor(state_spikes, dtype=torch.float).to(device) 175 | #state_spikes_ = torch.tensor(state_spikes_, dtype=torch.float).to(device) 176 | #else: 177 | state, action, reward, state_, done = self.memory.sample_buffer(self.batch_size) 178 | 179 | state = torch.tensor(state, dtype=torch.float).to(device) 180 | action = torch.tensor(action, dtype=torch.float).to(device) 181 | reward = torch.tensor(reward, dtype=torch.float).to(device) 182 | state_ = torch.tensor(state_, dtype=torch.float).to(device) 183 | done = torch.tensor(done).to(device) 184 | 185 | if self.normalize: 186 | state = self.normalize_state(state.to('cpu')).float().to(device) 187 | state_ = self.normalize_state(state_.to('cpu')).float().to(device) 188 | 189 | if self.spiking: 190 | if self.encoding == 'poisson': 191 | state_spikes_ = self.generate_poisson_input(state_.to('cpu')).to(device) 192 | target_actions = self.target_actor.forward(state_spikes_)[0].squeeze(0).to(device) 193 | else: 194 | target_actions = self.target_actor.forward(state_)[0].squeeze(0).to(device) 195 | else: 196 | target_actions = self.target_actor.forward(state_) 197 | 198 | if self.spiking_critic: 199 | q1 = self.critic_1.forward(state, action)[0] 200 | q2 = self.critic_2.forward(state, action)[0] 201 | else: 202 | q1 = self.critic_1.forward(state, action) 203 | q2 = self.critic_2.forward(state, action) 204 | target_actions = target_actions + torch.clamp(torch.tensor(np.random.normal(scale=0.2)), 205 | -0.5, 0.5) 206 | target_actions = torch.clamp(target_actions, self.min_action[0], self.max_action[0]) 207 | 208 | if self.spiking_critic: 209 | q1_ = self.target_critic_1.forward(state_, target_actions)[0] 210 | q2_ = self.target_critic_2.forward(state_, target_actions)[0] 211 | else: 212 | q1_ = self.target_critic_1.forward(state_, target_actions) 213 | q2_ = self.target_critic_2.forward(state_, target_actions) 214 | 215 | q1_[done] = 0.0 216 | q2_[done] = 0.0 217 | 218 | critic_value_ = torch.min(q1_.view(-1), q2_.view(-1)).detach() 219 | target = reward + self.gamma*critic_value_ 220 | target = target.view(self.batch_size, 1) 221 | 222 | self.critic_1.optimizer.zero_grad() 223 | self.critic_2.optimizer.zero_grad() 224 | 225 | q1_loss = F.mse_loss(target, q1) 226 | q2_loss = F.mse_loss(target, q2) 227 | critic_loss = q1_loss + q2_loss 228 | critic_loss.backward(retain_graph=True) 229 | self.critic_1.optimizer.step() 230 | self.critic_2.optimizer.step() 231 | 232 | self.learn_step_counter += 1 233 | 234 | if self.learn_step_counter % self.update_actor_interval == 0: 235 | self.policy_learn_step_counter += 1 236 | self.actor.optimizer.zero_grad() 237 | if self.spiking_critic: 238 | actor_q1_loss = \ 239 | self.critic_1.forward(state, self.actor.forward(state)[0].squeeze(0))[0] 240 | elif self.spiking: 241 | if self.encoding == 'poisson': 242 | state_spikes = self.generate_poisson_input(state.to('cpu')).to(device) 243 | actor_q1_loss = \ 244 | self.critic_1.forward(state, self.actor.forward(state_spikes)[0].squeeze(0)) 245 | else: 246 | actor_q1_loss = \ 247 | self.critic_1.forward(state, self.actor.forward(state)[0].squeeze(0)) 248 | else: 249 | actor_q1_loss = self.critic_1.forward(state, self.actor.forward(state)) 250 | actor_loss = -torch.mean(actor_q1_loss) 251 | actor_loss.backward(retain_graph=True) 252 | self.actor.optimizer.step() 253 | 254 | if self.learn_step_counter % self.update_target_interval == 0: 255 | self.update_network_parameters() 256 | 257 | def 
update_network_parameters(self, tau=None): 258 | if tau is None: 259 | tau = self.tau 260 | 261 | # update actor params 262 | if self.spiking: 263 | actor_state_dict = [d.clone() for d in self.actor.state_dict()[0]] 264 | target_actor_state_dict = [d.clone() for d in self.target_actor.state_dict()[0]] 265 | new_state_dict = [[tau*a + (1 - tau)*ta for a, ta in zip(actor_state_dict, 266 | target_actor_state_dict)]] 267 | self.target_actor.load_state_dict(new_state_dict) 268 | else: 269 | actor_params = self.actor.named_parameters() 270 | target_actor_params = self.target_actor.named_parameters() 271 | actor = dict(actor_params) 272 | target_actor = dict(target_actor_params) 273 | 274 | for name in actor: 275 | actor[name] = tau*actor[name].clone() + (1 - tau)*target_actor[name].clone() 276 | 277 | self.target_actor.load_state_dict(actor) 278 | 279 | # update critic params 280 | if self.spiking_critic: 281 | critic_1_state_dict = [d.clone() for d in self.critic_1.state_dict()[0]] 282 | target_critic_1_state_dict = [d.clone() for d in self.target_critic_1.state_dict()[0]] 283 | new_state_dict = [[tau * a + (1 - tau) * ta for a, ta in 284 | zip(critic_1_state_dict, target_critic_1_state_dict)]] 285 | self.target_critic_1.load_state_dict(new_state_dict) 286 | 287 | critic_2_state_dict = [d.clone() for d in self.critic_2.state_dict()[0]] 288 | target_critic_2_state_dict = [d.clone() for d in self.target_critic_2.state_dict()[0]] 289 | new_state_dict = [[tau * a + (1 - tau) * ta for a, ta in 290 | zip(critic_2_state_dict, target_critic_2_state_dict)]] 291 | self.target_critic_2.load_state_dict(new_state_dict) 292 | else: 293 | critic_1_params = self.critic_1.named_parameters() 294 | critic_2_params = self.critic_2.named_parameters() 295 | target_critic_1_params = self.target_critic_1.named_parameters() 296 | target_critic_2_params = self.target_critic_2.named_parameters() 297 | critic_1 = dict(critic_1_params) 298 | critic_2 = dict(critic_2_params) 299 | target_critic_1 = dict(target_critic_1_params) 300 | target_critic_2 = dict(target_critic_2_params) 301 | 302 | for name in critic_1: 303 | critic_1[name] = tau * critic_1[name].clone() + (1 - tau) * target_critic_1[ 304 | name].clone() 305 | 306 | for name in critic_2: 307 | critic_2[name] = tau * critic_2[name].clone() + (1 - tau) * target_critic_2[ 308 | name].clone() 309 | 310 | self.target_critic_1.load_state_dict(critic_1) 311 | self.target_critic_2.load_state_dict(critic_2) 312 | 313 | def get_active_neurons(self, state): 314 | intervals = [] 315 | neuron_idxs = np.zeros_like(state) 316 | for i in self.obs_range: 317 | intervals.append((i[1] - i[0]) / self.pop_size) 318 | 319 | for i in range(len(state)): 320 | threshold = self.obs_range[i][0] 321 | for k in range(self.pop_size): 322 | neuron_idx = k 323 | neuron_idxs[i] = int(neuron_idx) 324 | if state[i] < threshold + intervals[i]: 325 | break 326 | else: 327 | threshold += intervals[i] 328 | neuron_idxs = neuron_idxs.reshape(state.shape) 329 | return neuron_idxs 330 | 331 | def get_mutually_exclusive_pop_input(self, state): 332 | neuron_idxs = self.get_active_neurons(state) 333 | pop_observation = np.zeros((state.shape[0] * self.pop_size)) 334 | idx = 0 335 | for i in range(len(state)): 336 | pop_idx = int(idx + neuron_idxs[i]) 337 | pop_observation[pop_idx] = abs(state[i]) 338 | idx += self.pop_size 339 | return pop_observation 340 | 341 | def get_population_input(self, state): 342 | pop_input = [] 343 | for i in range(len(state)): 344 | for j in range(self.pop_size): 345 | a_es = 
np.exp(-0.5*((state[i] - self.pop_means[i][j])/self.pop_disp[i])**2) 346 | pop_input.append(a_es) 347 | 348 | return pop_input 349 | 350 | def generate_poisson_input(self, state): 351 | return torch.bernoulli( 352 | torch.tile(state[..., None], [1] * len(state.shape) + [self.simtime]) 353 | ) 354 | 355 | def normalize_state(self, state): 356 | if self.two_neuron: 357 | two_neuron_max_obs = np.array([val for val in self.max_obs for _ in (0, 1)]) 358 | return torch.tensor(state/two_neuron_max_obs, dtype=torch.float) 359 | 360 | return torch.tensor(state/self.max_obs, dtype=torch.float) 361 | 362 | def transform_state(self, state): 363 | state_ = [] 364 | for i in state: 365 | if i > 0: 366 | state_.append(i) 367 | state_.append(0) 368 | else: 369 | state_.append(0) 370 | state_.append(abs(i)) 371 | return np.array(state_) 372 | 373 | def update_max_obs(self, state): 374 | self.max_obs = np.maximum(self.max_obs, np.abs(state)) 375 | 376 | def train_agent(self): 377 | best_average = -np.inf 378 | best_average_after = np.inf 379 | reward_history = [] 380 | smoothed_scores = [] 381 | 382 | while self.learn_step_counter < self.n_timesteps + 1: 383 | self.episode_counter += 1 384 | observation, info = self.env.reset(seed=self.seed) 385 | self.update_max_obs(observation) 386 | if self.two_neuron: 387 | observation = self.transform_state(observation) 388 | done = False 389 | score = 0 390 | while not done: 391 | action = self.choose_action(observation) 392 | observation_, reward, done1, done2, info = self.env.step(action) 393 | done = done1 or done2 394 | self.update_max_obs(observation_) 395 | if self.two_neuron: 396 | observation_ = self.transform_state(observation_) 397 | if self.spiking: 398 | observation_ = observation_.reshape(self.input_dims) 399 | self.memory.store_transition(observation, action, reward, observation_, done) 400 | score += reward 401 | observation = observation_ 402 | self.learn() 403 | 404 | reward_history.append(score) 405 | avg_score = np.mean(reward_history[-100:]) 406 | smoothed_scores.append(avg_score) 407 | 408 | if avg_score > best_average: 409 | best_average = avg_score 410 | best_average_after = self.episode_counter 411 | #self.save_models(self.result_dir) 412 | 413 | print('Episode: ', self.episode_counter, 'training steps: ', 414 | self.learn_step_counter, 'score: %.1f' % score, 415 | 'Average Score: %.1f' % avg_score, end='\r') 416 | 417 | if self.episode_counter % 100 == 0: 418 | print("\rEpisode: ", self.episode_counter, 'training steps: ', self.learn_step_counter, 419 | "Average Score: %.2f" % avg_score) 420 | self.save_models(self.result_dir, self.episode_counter) 421 | 422 | print('Best 100 episode average: ', best_average, ' reached at episode ', 423 | best_average_after, '. 
Model saved in folder best.') 424 | return smoothed_scores, reward_history, best_average, best_average_after 425 | 426 | def save_models(self, result_dir, episode_num): 427 | self.actor.save_checkpoint(result_dir, episode_num) 428 | self.target_actor.save_checkpoint(result_dir, episode_num) 429 | self.critic_1.save_checkpoint(result_dir) 430 | self.critic_2.save_checkpoint(result_dir) 431 | self.target_critic_1.save_checkpoint(result_dir) 432 | self.target_critic_2.save_checkpoint(result_dir) 433 | 434 | def load_models(self): 435 | self.actor.load_checkpoint() 436 | self.target_actor.load_checkpoint() 437 | self.critic_1.load_checkpoint() 438 | self.critic_2.load_checkpoint() 439 | self.target_critic_1.load_checkpoint() 440 | self.target_critic_2.load_checkpoint() 441 | -------------------------------------------------------------------------------- /swing-up-cartpole/checkpoint_TD3_actor_1500.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/checkpoint_TD3_actor_1500.pt -------------------------------------------------------------------------------- /swing-up-cartpole/evaluate_snn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import csv 5 | from datetime import datetime 6 | from importlib import import_module 7 | from tqdm import trange 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from typing import Any 13 | import gymnasium as gym 14 | import numpy as np 15 | import tensorflow as tf 16 | from numpy.random import SeedSequence 17 | from yaml import dump 18 | from memory_buffer import ReplayBuffer 19 | from model import TD3CriticNetwork, TD3ActorDSNN 20 | 21 | from Control_Toolkit.Controllers import template_controller 22 | from Control_Toolkit.Cost_Functions.cost_function_wrapper import CostFunctionWrapper 23 | from Control_Toolkit.others.environment import EnvironmentBatched 24 | from Environments import ENV_REGISTRY, register_envs 25 | from SI_Toolkit.computation_library import TensorFlowLibrary 26 | from Utilities.csv_helpers import save_to_csv 27 | from Utilities.generate_plots import generate_experiment_plots 28 | from Utilities.utils import ConfigManager, CurrentRunMemory, OutputPath, SeedMemory, get_logger, nested_assignment_to_ordereddict 29 | 30 | 31 | sys.path.append(os.path.join(os.path.abspath("."), "CartPoleSimulation")) # Keep allowing absolute imports within CartPoleSimulation subgit 32 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 33 | register_envs() # Gym API: Register custom environments 34 | logger = get_logger(__name__) 35 | 36 | # td3 parameters 37 | actor_learning_rate = 0.001 38 | critic_learning_rate = 0.001 39 | tau = 0.005 40 | layer1_size = 400 41 | layer2_size = 300 42 | noise = 0.1 43 | gamma = 0.99 44 | warmup = 1000 45 | batch_size = 100 46 | learning_starts = 1000 47 | update_actor_interval = 2 48 | update_target_interval = 2 49 | buffer_size = int(2e5) 50 | normalize = False 51 | episode_counter = 0 52 | learn_step_counter = 0 53 | policy_learn_step_counter = 0 54 | time_step = 0 55 | 56 | # snn parameters 57 | alpha = 0.5 58 | beta = 0.5 59 | weight_scale = 1 60 | threshold = 2.5 61 | sim_time = 5 62 | two_neuron_encoding = True 63 | spiking = True 64 | 65 | 66 | def normalize_state(state, max_obs): 67 | if two_neuron_encoding: 68 | two_neuron_max_obs = np.array([val for val 
in max_obs for _ in (0, 1)]) 69 | return torch.tensor(state / two_neuron_max_obs, dtype=torch.float) 70 | 71 | def choose_action(actor, observation, max_action, min_action, n_actions, max_obs): 72 | global time_step 73 | if normalize: 74 | observation = normalize_state(observation, max_obs) 75 | state = observation.clone().to(device) 76 | else: 77 | state = torch.tensor(observation, dtype=torch.float).clone().to(device) 78 | if spiking: 79 | state = state.unsqueeze(0).to(device) 80 | mu = actor.forward(state)[0].squeeze(0).to(device) 81 | else: 82 | mu = actor.forward(state).to(device) 83 | 84 | mu_prime = mu + torch.tensor(np.random.normal(scale=noise), dtype=torch.float, 85 | device=device).to(device) 86 | 87 | mu_prime = torch.clamp(mu_prime * max_action[0], min_action[0], max_action[0]) 88 | time_step += 1 89 | 90 | action = mu_prime.cpu().detach().numpy() 91 | action = action.astype(np.float32) 92 | return action 93 | 94 | 95 | def transform_state(state): 96 | state_ = [] 97 | for i in state: 98 | if i > 0: 99 | state_.append(i) 100 | state_.append(0) 101 | else: 102 | state_.append(0) 103 | state_.append(abs(i)) 104 | return np.array(state_) 105 | 106 | 107 | def one_neuron_encoding(state): 108 | state_ = [] 109 | for i in range(len(state)): 110 | if i == 0: 111 | continue 112 | else: 113 | if i % 2 == 0: 114 | state_.append(state[i]) 115 | else: 116 | state_.append(-state[i]) 117 | return np.array(state_) 118 | 119 | 120 | def update_max_obs(state, max_obs): 121 | max_obs = np.maximum(max_obs, np.abs(state)) 122 | return max_obs 123 | 124 | 125 | def run_data_generator(controller_name: str, environment_name: str, config_manager: ConfigManager, 126 | record_path=None): 127 | global time_step, episode_counter 128 | models = [file for file in os.listdir('snn_results/') if file.endswith(".pt")] 129 | models.sort() 130 | 131 | # Generate seeds and set timestamp 132 | timestamp = datetime.now() 133 | seed_entropy = config_manager("config")["seed_entropy"] 134 | if seed_entropy is None: 135 | seed_entropy = int(timestamp.timestamp()) 136 | logger.info("No seed entropy specified. 
Setting to posix timestamp.")
137 | 
138 |     num_experiments = config_manager("config")["num_experiments"]
139 |     seed_sequences = SeedSequence(entropy=seed_entropy).spawn(num_experiments)
140 |     timestamp_str = timestamp.strftime("%Y%m%d-%H%M%S")
141 | 
142 |     controller_short_name = controller_name.replace("controller_", "").replace("_", "-")
143 |     optimizer_short_name = config_manager("config_controllers")[controller_short_name]["optimizer"]
144 |     optimizer_name = "optimizer_" + optimizer_short_name.replace("-", "_")
145 |     CurrentRunMemory.current_optimizer_name = optimizer_name
146 |     all_metrics = dict(
147 |         total_rewards = [],
148 |         timeout = [],
149 |         terminated = [],
150 |         truncated = [],
151 |     )
152 | 
153 |     best_average = -np.inf
154 |     best_average_after = np.inf
155 |     smoothed_scores = []
156 | 
157 |     # Generate new seeds for environment and controller
158 |     seeds = seed_sequences[0].generate_state(3)
159 |     SeedMemory.set_seeds(seeds)
160 | 
161 |     config_controller = dict(config_manager("config_controllers")[controller_short_name])
162 |     config_optimizer = dict(config_manager("config_optimizers")[optimizer_short_name])
163 |     config_optimizer.update({"seed": int(seeds[1])})
164 |     config_environment = dict(config_manager("config_environments")[environment_name])
165 |     config_environment.update({"seed": int(seeds[0])})
166 |     all_rewards = []
167 | 
168 |     ##### ----------------------------------------------- #####
169 |     ##### ----------------- ENVIRONMENT ----------------- #####
170 |     ##### --- Instantiate environment and call reset ---- #####
171 |     if config_manager("config")["render_for_humans"]:
172 |         render_mode = "human"
173 |     elif config_manager("config")["save_plots_to_file"]:
174 |         render_mode = "rgb_array"
175 |     else:
176 |         render_mode = None
177 | 
178 |     import matplotlib
179 | 
180 |     matplotlib.use("Agg")
181 | 
182 |     env: EnvironmentBatched = gym.make(environment_name, **config_environment,
183 |                                        computation_lib=TensorFlowLibrary,
184 |                                        render_mode=render_mode)
185 |     CurrentRunMemory.current_environment = env
186 |     obs, obs_info = env.reset(seed=config_environment["seed"])
187 |     assert len(
188 |         env.action_space.shape) == 1, f"Action space needs to be a flat vector, is Box with shape {env.action_space.shape}"
189 | 
190 |     # td3 variables
191 |     max_action = env.action_space.high
192 |     min_action = env.action_space.low
193 |     max_obs = env.observation_space.high
194 |     if two_neuron_encoding:
195 |         input_dims = (env.observation_space.shape[0]*2)
196 |     else:
197 |         input_dims = env.observation_space.shape
198 |     n_actions = env.action_space.shape[0]
199 | 
200 |     memory = ReplayBuffer(buffer_size, input_dims, n_actions)
201 | 
202 |     for i in range(len(max_obs)):
203 |         if max_obs[i] == np.inf:
204 |             max_obs[i] = 1
205 | 
206 |     actor_architecture = [input_dims, layer1_size, layer2_size, n_actions]
207 | 
208 |     #update_network_parameters(actor, target_actor, critic_1, target_critic_1, critic_2,
209 |     #                          target_critic_2, tau=1)
210 | 
211 |     model_rewards = []
212 | 
213 |     for j in range(len(models)):
214 |         # networks
215 |         actor = TD3ActorDSNN(actor_architecture, 0, alpha, beta, weight_scale, 1, threshold,
216 |                              sim_time, actor_learning_rate, name='actor', device=device)
217 |         m = models[j]
218 |         print(m)
219 |         w = torch.load('snn_results/{}'.format(m), map_location=torch.device(device))
220 |         actor.load_state_dict(w)
221 |         reward_history = []
222 |         model_rewards.append([])
223 |         # Loop through episodes
224 |         for i in trange(num_experiments):
225 | 
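            # Every checkpoint is evaluated on the same sequence of episode seeds
            # (seed_sequences[i] is shared across the outer model loop), so the per-model
            # average scores reported below are directly comparable.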
226 |             # Generate new seeds for environment and controller
227 |             seeds = seed_sequences[i].generate_state(3)
228 |             SeedMemory.set_seeds(seeds)
229 | 
230 |             config_controller = dict(config_manager("config_controllers")[controller_short_name])
231 |             config_optimizer = dict(config_manager("config_optimizers")[optimizer_short_name])
232 |             config_optimizer.update({"seed": int(seeds[1])})
233 |             config_environment = dict(config_manager("config_environments")[environment_name])
234 |             config_environment.update({"seed": int(seeds[0])})
235 |             all_rewards = []
236 | 
237 |             episode_counter += 1
238 |             print(episode_counter)
239 |             obs, obs_info = env.reset(seed=config_environment["seed"])
240 |             max_obs = update_max_obs(obs, max_obs)
241 | 
242 |             if two_neuron_encoding:
243 |                 obs = transform_state(obs)
244 | 
245 |             score = 0
246 | 
247 |             ##### ---------------------------------------------- #####
248 |             ##### ----------------- CONTROLLER ----------------- #####
249 |             controller_module = import_module(f"Control_Toolkit.Controllers.{controller_name}")
250 |             controller: template_controller = getattr(controller_module, controller_name)(
251 |                 dt=env.dt,
252 |                 environment_name=ENV_REGISTRY[environment_name].split(":")[-1],
253 |                 control_limits=(env.action_space.low, env.action_space.high),
254 |                 initial_environment_attributes=env.environment_attributes)
255 |             controller.configure(optimizer_name=optimizer_short_name, predictor_specification=config_controller["predictor_specification"])
256 | 
257 |             ##### ----------------------------------------------------- #####
258 |             ##### ----------------- MAIN CONTROL LOOP ----------------- #####
259 |             frames = []
260 |             start_time = time.time()
261 |             num_iterations = config_manager("config")["num_iterations"]
262 |             for step in range(num_iterations):
263 |                 #action = controller.step(obs, updated_attributes=env.environment_attributes)
264 |                 action = choose_action(actor, obs, max_action, min_action, n_actions, max_obs)
265 |                 new_obs, reward, terminated, truncated, info = env.step(action)
266 |                 if two_neuron_encoding:
267 |                     new_obs_one_neuron = new_obs
268 |                     new_obs = transform_state(new_obs)
269 |                 if spiking:
270 |                     new_obs = new_obs.reshape(input_dims)
271 |                 c_fun: CostFunctionWrapper = getattr(controller, "cost_function", None)
272 |                 if c_fun is not None:
273 |                     assert isinstance(c_fun, CostFunctionWrapper)
274 |                     # Compute reward from the cost function that the controller optimized
275 |                     reward = -float(c_fun.get_stage_cost(
276 |                         tf.convert_to_tensor(new_obs_one_neuron[np.newaxis, np.newaxis, ...]),  # Add batch / MPC horizon dimensions
277 |                         tf.convert_to_tensor(action[np.newaxis, np.newaxis, ...]),
278 |                         None
279 |                     ))
280 |                 all_rewards.append(reward)
281 |                 if config_controller.get("controller_logging", False):
282 |                     controller.logs["realized_cost_logged"].append(np.array([-reward]).copy())
283 |                     env.set_logs(controller.logs)
284 |                 if config_manager("config")["render_for_humans"]:
285 |                     env.render()
286 |                 elif config_manager("config")["save_plots_to_file"]:
287 |                     frames.append(env.render())
288 | 
289 |                 done = terminated or truncated
290 |                 memory.store_transition(obs, action, reward, new_obs, done)
291 |                 score += reward
292 | 
293 |                 time.sleep(1e-6)
294 | 
295 |                 obs = new_obs
296 | 
297 |             # Print compute time statistics
298 |             end_time = time.time()
299 |             control_freq = num_iterations / (end_time - start_time)
300 | 
301 |             reward_history.append(score)
302 |             model_rewards[j].append(score)
303 |             avg_score = np.mean(reward_history[-100:])
304 |             smoothed_scores.append(avg_score)
305 | 
306 |             #if avg_score > best_average:
307 |             # 
best_average = avg_score 308 | # best_average_after = episode_counter 309 | 310 | print('Episode: ', episode_counter, 'training steps: ', learn_step_counter, 311 | 'score: %.1f' % score, 'Average Score: %.1f' % avg_score) 312 | 313 | print("Model: {}, Average Score: {}".format(j, np.mean(model_rewards[j]))) 314 | 315 | # Close the env 316 | env.close() 317 | print('Best 100 episode average: ', best_average, ' reached at episode ', best_average_after, 318 | '.') 319 | return model_rewards 320 | 321 | def prepare_and_run(): 322 | import ruamel.yaml 323 | 324 | # Create a config manager which looks for '.yml' files within the list of folders specified. 325 | # Rationale: We want GUILD AI to be able to update values in configs that we include in this list. 326 | # We might intentionally want to exclude the path to a folder which does contain configs but should not be overwritten by GUILD. 327 | config_manager = ConfigManager(".", "Control_Toolkit_ASF", "SI_Toolkit_ASF", "Environments") 328 | 329 | # Scan for any custom parameters that should overwrite the toolkits' config files: 330 | submodule_configs = ConfigManager("Control_Toolkit_ASF", "SI_Toolkit_ASF", "Environments").loaders 331 | for base_name, loader in submodule_configs.items(): 332 | if base_name in config_manager("config").get("custom_config_overwrites", {}): 333 | data: ruamel.yaml.comments.CommentedMap = loader.load() 334 | update_dict = config_manager("config")["custom_config_overwrites"][base_name] 335 | nested_assignment_to_ordereddict(data, update_dict) 336 | loader.overwrite_config(data) 337 | 338 | # Retrieve required parameters from config: 339 | CurrentRunMemory.current_controller_name = config_manager("config")["controller_name"] 340 | CurrentRunMemory.current_environment_name = config_manager("config")["environment_name"] 341 | 342 | smoothed_scores = run_data_generator(controller_name=CurrentRunMemory.current_controller_name, 343 | environment_name=CurrentRunMemory.current_environment_name, 344 | config_manager=config_manager) 345 | return smoothed_scores 346 | 347 | if __name__ == "__main__": 348 | smoothed_scores = prepare_and_run() -------------------------------------------------------------------------------- /swing-up-cartpole/memory_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class ReplayBuffer(): 5 | def __init__(self, max_size, input_shape, n_actions, store_spikes=False, simtime=10): 6 | self.store_spikes = store_spikes 7 | 8 | self.mem_size = max_size 9 | self.mem_counter = 0 10 | self.state_memory = np.zeros((self.mem_size, input_shape)) 11 | self.new_state_memory = np.zeros((self.mem_size, input_shape)) 12 | self.action_memory = np.zeros((self.mem_size, n_actions)) 13 | self.reward_memory = np.zeros(self.mem_size) 14 | self.terminal_memory = np.zeros(self.mem_size, dtype=bool) 15 | 16 | if self.store_spikes: 17 | self.state_spikes_memory = np.zeros((self.mem_size, input_shape[0], simtime)) 18 | self.new_state_spikes_memory = np.zeros((self.mem_size, input_shape[0], simtime)) 19 | 20 | def store_transition(self, state, action, reward, state_, done, state_spikes=None, 21 | state_spikes_=None): 22 | index = self.mem_counter % self.mem_size 23 | self.state_memory[index] = state 24 | self.action_memory[index] = action 25 | self.reward_memory[index] = reward 26 | self.new_state_memory[index] = state_ 27 | self.terminal_memory[index] = done 28 | 29 | if self.store_spikes: 30 | self.state_spikes_memory[index] = state_spikes 31 | 
self.new_state_spikes_memory[index] = state_spikes_ 32 | 33 | self.mem_counter += 1 34 | 35 | def sample_buffer(self, batch_size): 36 | max_mem = min(self.mem_counter, self.mem_size) 37 | 38 | batch = np.random.choice(max_mem, batch_size) 39 | 40 | states = self.state_memory[batch] 41 | states_ = self.new_state_memory[batch] 42 | actions = self.action_memory[batch] 43 | rewards = self.reward_memory[batch] 44 | dones = self.terminal_memory[batch] 45 | 46 | if self.store_spikes: 47 | states_spikes = self.state_spikes_memory[batch] 48 | states_spikes_ = self.new_state_spikes_memory[batch] 49 | return states, states_spikes, actions, rewards, states_, states_spikes_, dones 50 | 51 | return states, actions, rewards, states_, dones 52 | -------------------------------------------------------------------------------- /swing-up-cartpole/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | 5 | import numpy as np 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import torch.nn.functional as F 9 | 10 | 11 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 12 | 13 | 14 | class SurrGradSpike(torch.autograd.Function): 15 | """ 16 | Here we implement our spiking nonlinearity which also implements 17 | the surrogate gradient. By subclassing torch.autograd.Function, 18 | we will be able to use all of PyTorch's autograd functionality. 19 | Here we use the normalized negative part of a fast sigmoid 20 | as this was done in Zenke & Ganguli (2018). 21 | """ 22 | 23 | scale = 100.0 # controls steepness of surrogate gradient 24 | 25 | @staticmethod 26 | def forward(ctx, input): 27 | """ 28 | In the forward pass we compute a step function of the input Tensor 29 | and return it. ctx is a context object that we use to stash information which 30 | we need to later backpropagate our error signals. To achieve this we use the 31 | ctx.save_for_backward method. 32 | """ 33 | ctx.save_for_backward(input) 34 | out = torch.zeros_like(input) 35 | out[input > 0] = 1.0 36 | # Only for inhibitory spikes 37 | #out[input < 0] = -1.0 38 | return out 39 | 40 | @staticmethod 41 | def backward(ctx, grad_output): 42 | """ 43 | In the backward pass we receive a Tensor we need to compute the 44 | surrogate gradient of the loss with respect to the input. 45 | Here we use the normalized negative part of a fast sigmoid 46 | as this was done in Zenke & Ganguli (2018). 
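        Concretely, the incoming gradient is scaled by
        1 / (scale * |input| + 1)**2, the derivative of the fast sigmoid
        x / (1 + scale * |x|), which lets a smooth gradient flow through the
        otherwise non-differentiable spike step.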
47 | """ 48 | input, = ctx.saved_tensors 49 | grad_input = grad_output.clone() 50 | grad = grad_input / (SurrGradSpike.scale * torch.abs(input) + 1.0) ** 2 51 | return grad 52 | 53 | 54 | class DSNN(nn.Module): 55 | def __init__(self, architecture, seed, alpha, beta, weight_scale, batch_size, threshold, 56 | simulation_time, learning_rate, reset_potential=0): 57 | """ 58 | 59 | """ 60 | self.architecture = architecture 61 | 62 | random.seed(seed) 63 | np.random.seed(seed) 64 | torch.manual_seed(seed) 65 | 66 | self.alpha = alpha 67 | self.beta = beta 68 | self.weight_scale = weight_scale 69 | self.batch_size = batch_size 70 | self.threshold = threshold 71 | self.simulation_time = simulation_time 72 | self.reset_potential = reset_potential 73 | 74 | self.spike_fn = SurrGradSpike.apply 75 | 76 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 77 | 78 | # Initialize the network weights 79 | self.weights = [] 80 | for i in range(len(architecture) - 1): 81 | self.weights.append(torch.empty((self.architecture[i], self.architecture[i + 1]), 82 | device=device, dtype=torch.float, requires_grad=True)) 83 | torch.nn.init.normal_(self.weights[i], mean=0.0, 84 | std=self.weight_scale/np.sqrt(self.architecture[i])) 85 | 86 | self.optimizer = optim.Adam(self.parameters(), lr=learning_rate) 87 | 88 | def forward(self, inputs): 89 | syn = [] 90 | mem = [] 91 | spk_count = [] 92 | 93 | for l in range(0, len(self.weights)): 94 | syn.append(torch.zeros((self.batch_size, self.weights[l].shape[1]), device=device, 95 | dtype=torch.float)) 96 | mem.append(torch.zeros((self.batch_size, self.weights[l].shape[1]), device=device, 97 | dtype=torch.float)) 98 | 99 | # Here we define two lists which we use to record the membrane potentials and output spikes 100 | mem_rec = [] 101 | spk_rec = [] 102 | 103 | # Here we loop over time 104 | for t in range(self.simulation_time): 105 | # append the new timestep to mem_rec and spk_rec 106 | mem_rec.append([]) 107 | spk_rec.append([]) 108 | 109 | if t == 0: 110 | for l in range(len(self.weights)): 111 | mem_rec[-1].append(mem[l]) 112 | spk_rec[-1].append(mem[l]) 113 | continue 114 | 115 | # We take the input as it is, multiply is by the weights, and we inject the outcome 116 | # as current in the neurons of the first hidden layer 117 | input = inputs.detach().clone() 118 | 119 | # loop over layers 120 | for l in range(len(self.weights)): 121 | if l == 0: 122 | h = torch.einsum("ab,bc->ac", [input, self.weights[0]]) 123 | new_syn = 0 * syn[l] + h 124 | else: 125 | h = torch.einsum("ab,bc->ac", [spk_rec[-1][l - 1], self.weights[l]]) 126 | new_syn = self.alpha * syn[l] + h 127 | 128 | new_mem = self.beta*mem[l] + new_syn 129 | 130 | # calculate the spikes for all layers but the last layer 131 | if l < (len(self.weights) - 1): 132 | mthr = new_mem 133 | mthr = mthr - self.threshold 134 | out = self.spike_fn(mthr) 135 | c = (mthr > 0) 136 | new_mem[c] = self.reset_potential 137 | spk_rec[-1].append(out) 138 | 139 | mem[l] = new_mem 140 | syn[l] = new_syn 141 | 142 | mem_rec[-1].append(mem[l]) 143 | 144 | # return the final recorded membrane potential in the output layer, all membrane potentials, 145 | # and spikes 146 | return mem_rec[-1][-1], mem_rec, spk_rec 147 | 148 | def load_state_dict(self, layers): 149 | """Method to load weights and biases into the network""" 150 | weights = layers[0] 151 | for l in range(0,len(weights)): 152 | self.weights[l] = weights[l].detach().clone().requires_grad_(True) 153 | 154 | def state_dict(self): 155 | """Method 
to copy the layers of the SQN. Makes explicit copies, no references.""" 156 | weights_copy = [] 157 | bias_copy = [] 158 | for l in range(0, len(self.weights)): 159 | weights_copy.append(self.weights[l].detach().clone()) 160 | return weights_copy, bias_copy 161 | 162 | def parameters(self): 163 | parameters = [] 164 | for l in range(0, len(self.weights)): 165 | parameters.append(self.weights[l]) 166 | 167 | return parameters 168 | 169 | 170 | class QNetwork(nn.Module): 171 | """Actor (Policy) Model.""" 172 | 173 | def __init__(self, architecture, seed): 174 | """Initialize parameters and build model. 175 | Params 176 | ====== 177 | architecture: 178 | """ 179 | super(QNetwork, self).__init__() 180 | self.layers = nn.ModuleList() 181 | 182 | random.seed(seed) 183 | np.random.seed(seed) 184 | torch.manual_seed(seed) 185 | 186 | for i in range(len(architecture) - 1): 187 | self.layers.append(nn.Linear(architecture[i], architecture[i + 1])) 188 | 189 | def forward(self, x): 190 | """Build a network that maps state -> action values.""" 191 | for layer in self.layers[:-1]: 192 | x = F.relu(layer(x)) 193 | 194 | # no ReLu activation in the output layer 195 | return self.layers[-1](x) 196 | 197 | 198 | class TD3ActorDSNN(DSNN): 199 | def __init__(self, architecture, seed, alpha, beta, weight_scale, batch_size, threshold, 200 | simulation_time, learning_rate, name, device, reset_potential=0, 201 | random_params=False, std=0.1): 202 | self.architecture = architecture 203 | 204 | random.seed(seed) 205 | np.random.seed(seed) 206 | torch.manual_seed(seed) 207 | 208 | self.alpha = alpha 209 | self.beta = beta 210 | self.weight_scale = weight_scale 211 | self.batch_size = batch_size 212 | self.threshold = threshold 213 | self.simulation_time = simulation_time 214 | self.reset_potential = reset_potential 215 | self.random_params = random_params 216 | self.std = std 217 | 218 | self.name = name 219 | 220 | self.spike_fn = SurrGradSpike.apply 221 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 222 | 223 | # Initialize the network weights 224 | self.weights = [] 225 | for i in range(len(architecture) - 1): 226 | self.weights.append(torch.empty((self.architecture[i], self.architecture[i + 1]), 227 | device=device, dtype=torch.float, requires_grad=True)) 228 | torch.nn.init.normal_(self.weights[i], mean=0.0, 229 | std=self.weight_scale / np.sqrt(self.architecture[i])) 230 | 231 | self.optimizer = optim.Adam(self.parameters(), lr=learning_rate) 232 | 233 | if random_params: 234 | self.randomize_parameters() 235 | 236 | def randomize_parameters(self): 237 | # Define new random values for alpha, beta, and threhsold 238 | self.alphas = [] 239 | self.betas = [] 240 | self.thresholds = [] 241 | self.reset_potentials = [] 242 | 243 | for i in range(len(self.architecture) - 1): 244 | self.alphas.append(torch.normal(torch.ones(self.architecture[i + 1])*self.alpha, 245 | self.alpha*self.std). 246 | unsqueeze(0).to(self.device)) 247 | self.alphas[i][self.alphas[i] > 1] = 1 248 | self.alphas[i][self.alphas[i] < 0] = 0 249 | self.betas.append(torch.normal(torch.ones(self.architecture[i + 1])*self.beta, 250 | self.beta*self.std). 
251 | unsqueeze(0).to(self.device)) 252 | self.betas[i][self.betas[i] > 1] = 1 253 | self.betas[i][self.betas[i] < 0] = 0 254 | 255 | self.thresholds.append( 256 | torch.normal(torch.ones(self.architecture[i + 1]) * self.threshold, 257 | self.threshold*self.std).to(self.device)) 258 | self.thresholds[i][self.thresholds[i] < 0] = 0 259 | 260 | self.reset_potentials.append( 261 | torch.normal(torch.ones(self.architecture[i + 1])*self.reset_potential, 262 | self.std).to(self.device)) 263 | 264 | def forward(self, inputs): 265 | syn = [] 266 | mem = [] 267 | 268 | for l in range(0, len(self.weights)): 269 | syn.append(torch.zeros((self.batch_size, self.weights[l].shape[1]), device=device, 270 | dtype=torch.float)) 271 | mem.append(torch.zeros((self.batch_size, self.weights[l].shape[1]), device=device, 272 | dtype=torch.float)) 273 | 274 | # Here we define two lists which we use to record the membrane potentials and output spikes 275 | mem_rec = [] 276 | spk_rec = [] 277 | 278 | # Here we loop over time 279 | for t in range(self.simulation_time): 280 | # append the new timestep to mem_rec and spk_rec 281 | mem_rec.append([]) 282 | spk_rec.append([]) 283 | 284 | if t == 0: 285 | for l in range(len(self.weights)): 286 | mem_rec[-1].append(mem[l]) 287 | spk_rec[-1].append(mem[l]) 288 | continue 289 | 290 | # We take the input as it is, multiply is by the weights, and we inject the outcome 291 | # as current in the neurons of the first hidden layer 292 | input = inputs.detach().clone() 293 | 294 | # loop over layers 295 | for l in range(len(self.weights)): 296 | if l == 0: 297 | h = torch.matmul(input, self.weights[0]) 298 | new_syn = 0 * syn[l] + h 299 | elif l == len(self.weights) - 1: 300 | h = torch.matmul(spk_rec[-1][l - 1], self.weights[l]) 301 | new_syn = 0*syn[l] + h 302 | else: 303 | h = torch.matmul(spk_rec[-1][l - 1], self.weights[l]) 304 | if self.random_params: 305 | new_syn = torch.add(torch.mul(self.alphas[l], syn[l]), h) 306 | else: 307 | new_syn = self.alpha * syn[l] + h 308 | 309 | if l == len(self.weights) - 1: 310 | new_mem = 1*mem[l] + new_syn 311 | else: 312 | if self.random_params: 313 | new_mem = torch.add(torch.mul(self.betas[l], mem[l]), new_syn) 314 | else: 315 | new_mem = self.beta*mem[l] + new_syn 316 | 317 | # calculate the spikes for all layers but the last layer (decoding='potential') 318 | if l < (len(self.weights) - 1): 319 | mthr = new_mem 320 | if self.random_params: 321 | mthr = torch.sub(mthr, self.thresholds[l]) 322 | reset = self.reset_potential 323 | else: 324 | mthr = mthr - self.threshold 325 | reset = self.reset_potential 326 | out = self.spike_fn(mthr) 327 | c = (mthr > 0) 328 | new_mem[c] = self.reset_potential 329 | 330 | spk_rec[-1].append(out) 331 | 332 | mem[l] = new_mem 333 | syn[l] = new_syn 334 | 335 | mem_rec[-1].append(mem[l]) 336 | 337 | # return the final recorded membrane potential (len(mem_rec)-1) in the output layer (-1) 338 | return torch.tanh(mem_rec[-1][-1]), mem_rec, spk_rec 339 | 340 | def save_checkpoint(self, result_dir, episode_num): 341 | #print('... 
saving checkpoint ...') 342 | torch.save(self.state_dict(), result_dir + '/checkpoint_TD3_{}_{}.pt'.format(self.name, 343 | episode_num)) 344 | 345 | class TD3ActorNetwork(nn.Module): 346 | def __init__(self, learning_rate, input_dims, fc1_dims, fc2_dims, n_actions, name, 347 | checkpoint_dir='tmp/td3'): 348 | super(TD3ActorNetwork, self).__init__() 349 | self.input_dims = input_dims 350 | self.fc1_dims = fc1_dims 351 | self.fc2_dims = fc2_dims 352 | self.n_actions = n_actions 353 | self.name = name 354 | self.checkpoint_dir = checkpoint_dir 355 | self.checkpoint_file = os.path.join(self.checkpoint_dir, name + '_td3') 356 | 357 | self.fc1 = nn.Linear(self.input_dims, fc1_dims, bias=False) 358 | self.fc2 = nn.Linear(self.fc1_dims, self.fc2_dims, bias=False) 359 | self.mu = nn.Linear(self.fc2_dims, self.n_actions, bias=False) 360 | 361 | self.optimizer = optim.Adam(self.parameters(), lr=learning_rate) 362 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 363 | 364 | self.to(self.device) 365 | 366 | def forward(self, state): 367 | prob = self.fc1(state) 368 | prob = F.relu(prob) 369 | prob = self.fc2(prob) 370 | prob = F.relu(prob) 371 | 372 | mu = torch.tanh(self.mu(prob)) 373 | 374 | return mu 375 | 376 | def save_checkpoint(self, result_dir): 377 | #print('... saving checkpoint ...') 378 | torch.save(self.state_dict(), result_dir + '/checkpoint_TD3_{}.pt'.format(self.name)) 379 | 380 | def load_checkpoint(self): 381 | print('... loading checkpoint ...') 382 | self.load_state_dict(torch.load(self.checkpoint_file)) 383 | 384 | 385 | 386 | class TD3CriticNetwork(nn.Module): 387 | def __init__(self, learning_rate, input_dims, fc1_dims, fc2_dims, n_actions, name, 388 | checkpoint_dir='tmp/td3'): 389 | super(TD3CriticNetwork, self).__init__() 390 | self.input_dims = input_dims 391 | self.fc1_dims = fc1_dims 392 | self.fc2_dims = fc2_dims 393 | self.n_actions = n_actions 394 | self.name = name 395 | self.checkpoint_dir = checkpoint_dir 396 | self.checkpoint_file = os.path.join(self.checkpoint_dir, name + '_td3') 397 | 398 | self.fc1 = nn.Linear(self.input_dims + n_actions, self.fc1_dims) 399 | self.fc2 = nn.Linear(self.fc1_dims, self.fc2_dims) 400 | self.q1 = nn.Linear(self.fc2_dims, 1) 401 | 402 | self.optimizer = optim.Adam(self.parameters(), lr=learning_rate) 403 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 404 | 405 | self.to(self.device) 406 | 407 | def forward(self, state, action): 408 | q1_action_value = self.fc1(torch.cat([state, action], dim=1)) 409 | q1_action_value = F.relu(q1_action_value) 410 | q1_action_value = self.fc2(q1_action_value) 411 | q1_action_value = F.relu(q1_action_value) 412 | 413 | q1 = self.q1(q1_action_value) 414 | 415 | return q1 416 | 417 | def save_checkpoint(self, result_dir): 418 | torch.save(self.state_dict(), result_dir + '/checkpoint_TD3_{}.pt'.format(self.name)) 419 | 420 | def load_checkpoint(self): 421 | print('... 
loading checkpoint ...') 422 | self.load_state_dict(torch.load(self.checkpoint_file)) 423 | -------------------------------------------------------------------------------- /swing-up-cartpole/smoothed_rewards_500.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/smoothed_rewards_500.npy -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_0100.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_0100.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_0200.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_0200.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_0300.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_0300.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_0400.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_0400.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_0500.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_0500.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_0600.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_0600.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_0700.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_0700.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_0800.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_0800.pt -------------------------------------------------------------------------------- 
/swing-up-cartpole/snn_results/checkpoint_TD3_actor_0900.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_0900.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_1000.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_1000.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_1100.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_1100.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_1200.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_1200.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_1300.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_1300.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_1400.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_1400.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_1500.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_1500.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_1600.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_1600.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_1700.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_1700.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_1800.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_1800.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_1900.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_1900.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_2000.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_2000.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_2100.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_2100.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_2200.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_2200.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_2300.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_2300.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_2400.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_2400.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_2500.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_2500.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_2600.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_2600.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_2700.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_2700.pt 
-------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_2800.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_2800.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_2900.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_2900.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_3000.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_3000.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_3100.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_3100.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_3200.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_3200.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_3300.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_3300.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_3400.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_3400.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_3500.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_3500.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_3600.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_3600.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_3700.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_3700.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_3800.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_3800.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_3900.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_3900.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_4000.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_4000.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_4100.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_4100.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_4200.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_4200.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_4300.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_4300.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_4400.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_4400.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_4500.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_4500.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_4600.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_4600.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_4700.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_4700.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_4800.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_4800.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_4900.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_4900.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/checkpoint_TD3_actor_5000.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/checkpoint_TD3_actor_5000.pt -------------------------------------------------------------------------------- /swing-up-cartpole/snn_results/td3_snn_swing_up_cartpole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoudakl/dsrl/db7774448f4fb31e45eb7e153354b67b3a532508/swing-up-cartpole/snn_results/td3_snn_swing_up_cartpole.png -------------------------------------------------------------------------------- /swing-up-cartpole/td3_snn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import csv 5 | from datetime import datetime 6 | from importlib import import_module 7 | from tqdm import trange 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from collections import namedtuple 12 | 13 | from typing import Any 14 | import gymnasium as gym 15 | import numpy as np 16 | import tensorflow as tf 17 | from numpy.random import SeedSequence 18 | from yaml import dump 19 | from memory_buffer import ReplayBuffer 20 | from model import TD3CriticNetwork, TD3ActorDSNN 21 | 22 | from Control_Toolkit.Controllers import template_controller 23 | from Control_Toolkit.Cost_Functions.cost_function_wrapper import CostFunctionWrapper 24 | from Control_Toolkit.others.environment import EnvironmentBatched 25 | from Environments import ENV_REGISTRY, register_envs 26 | from SI_Toolkit.computation_library import TensorFlowLibrary 27 | from Utilities.csv_helpers import save_to_csv 28 | from Utilities.generate_plots import generate_experiment_plots 29 | from Utilities.utils import ConfigManager, CurrentRunMemory, OutputPath, SeedMemory, get_logger, nested_assignment_to_ordereddict 30 | 31 | 32 | sys.path.append(os.path.join(os.path.abspath("."), "CartPoleSimulation")) # Keep allowing absolute imports within CartPoleSimulation subgit 33 | device = torch.device("cuda:0" 
if torch.cuda.is_available() else "cpu") 34 | register_envs() # Gym API: Register custom environments 35 | logger = get_logger(__name__) 36 | 37 | # td3 parameters 38 | actor_learning_rate = 0.001 39 | critic_learning_rate = 0.001 40 | tau = 0.005 41 | layer1_size = 400 42 | layer2_size = 300 43 | noise = 0.1 44 | gamma = 0.99 45 | warmup = 1000 46 | batch_size = 100 47 | learning_starts = 1000 48 | update_actor_interval = 2 49 | update_target_interval = 2 50 | buffer_size = int(2e5) 51 | normalize = False 52 | episode_counter = 0 53 | learn_step_counter = 0 54 | policy_learn_step_counter = 0 55 | time_step = 0 56 | 57 | # snn parameters 58 | alpha = 0.5 59 | beta = 0.5 60 | weight_scale = 1 61 | threshold = 110 62 | sim_time = 5 63 | two_neuron_encoding = True 64 | spiking = True 65 | quantization = True 66 | 67 | QTensor = namedtuple('QTensor', ['tensor', 'scale', 'zero_point']) 68 | 69 | 70 | def normalize_state(state, max_obs): 71 | if two_neuron_encoding: 72 | two_neuron_max_obs = np.array([val for val in max_obs for _ in (0, 1)]) 73 | return torch.tensor(state / two_neuron_max_obs, dtype=torch.float) 74 | 75 | 76 | def save_models(actor, target_actor, critic_1, target_critic_1, critic_2, target_critic_2, 77 | result_dir, episode_num, q_actor=None): 78 | actor.save_checkpoint(result_dir, episode_num) 79 | target_actor.save_checkpoint(result_dir, episode_num) 80 | critic_1.save_checkpoint(result_dir) 81 | critic_2.save_checkpoint(result_dir) 82 | target_critic_1.save_checkpoint(result_dir) 83 | target_critic_2.save_checkpoint(result_dir) 84 | if q_actor is not None: 85 | q_actor.save_checkpoint(result_dir, episode_num) 86 | 87 | 88 | def quantize_tensor(x, min_val, max_val, qmin=-127, qmax=127): 89 | scale = (max_val - min_val)/(qmax - qmin) 90 | 91 | zero_point = 0 92 | q_x = zero_point + (x/scale) 93 | q_x.clamp(qmin, qmax).round_() 94 | q_x = q_x.round().int() 95 | return QTensor(tensor=q_x, scale=scale, zero_point=zero_point) 96 | 97 | 98 | def quantize_weights(weights): 99 | combined_weights = torch.cat([torch.flatten(w) for w in weights]) 100 | min_val = torch.min(combined_weights) 101 | max_val = torch.max(combined_weights) 102 | quantized_weights = [] 103 | for w in weights: 104 | w = quantize_tensor(w, min_val, max_val, qmin=-127, qmax=127).tensor 105 | quantized_weights.append(w) 106 | 107 | return quantized_weights 108 | 109 | 110 | def choose_action(actor, observation, max_action, min_action, n_actions, max_obs, q_actor=None): 111 | global time_step 112 | if time_step < warmup: 113 | mu = torch.tensor(np.random.normal(scale=noise, size=n_actions), 114 | device=device) 115 | else: 116 | if normalize: 117 | observation = normalize_state(observation, max_obs) 118 | state = observation.clone().to(device) 119 | else: 120 | state = torch.tensor(observation, dtype=torch.float).clone().to(device) 121 | if spiking: 122 | state = state.unsqueeze(0).to(device) 123 | if q_actor is not None: 124 | weights = actor.weights 125 | q_weights = [q_w.float() for q_w in quantize_weights(weights)] 126 | q_actor.weights = q_weights 127 | mu = q_actor.forward(state)[0].squeeze(0).to(device) 128 | else: 129 | mu = actor.forward(state)[0].squeeze(0).to(device) 130 | else: 131 | mu = actor.forward(state).to(device) 132 | 133 | mu_prime = mu + torch.tensor(np.random.normal(scale=noise), dtype=torch.float, 134 | device=device).to(device) 135 | 136 | mu_prime = torch.clamp(mu_prime * max_action[0], min_action[0], max_action[0]) 137 | time_step += 1 138 | 139 | action = mu_prime.cpu().detach().numpy() 
140 | action = action.astype(np.float32) 141 | return action 142 | 143 | 144 | def learn(actor, target_actor, critic_1, target_critic_1, critic_2, target_critic_2, memory, 145 | time_step, max_action, min_action, max_obs): 146 | global learn_step_counter, policy_learn_step_counter 147 | 148 | if memory.mem_counter < batch_size or time_step < learning_starts: 149 | return 150 | 151 | state, action, reward, state_, done = memory.sample_buffer(batch_size) 152 | #print(action) 153 | 154 | state = torch.tensor(state, dtype=torch.float).to(device) 155 | action = torch.tensor(action, dtype=torch.float).to(device) 156 | reward = torch.tensor(reward, dtype=torch.float).to(device) 157 | state_ = torch.tensor(state_, dtype=torch.float).to(device) 158 | done = torch.tensor(done).to(device) 159 | 160 | if normalize: 161 | state = normalize_state(state.to('cpu'), max_obs).float().to(device) 162 | state_ = normalize_state(state_.to('cpu'), max_obs).float().to(device) 163 | 164 | if spiking: 165 | target_actions = target_actor.forward(state_)[0].squeeze(0).to(device) 166 | else: 167 | target_actions = target_actor.forward(state_) 168 | 169 | q1 = critic_1.forward(state, action) 170 | q2 = critic_2.forward(state, action) 171 | target_actions = target_actions + torch.clamp(torch.tensor(np.random.normal(scale=0.2)), -0.5, 172 | 0.5) 173 | target_actions = torch.clamp(target_actions, min_action[0], max_action[0]) 174 | 175 | q1_ = target_critic_1.forward(state_, target_actions) 176 | q2_ = target_critic_2.forward(state_, target_actions) 177 | 178 | q1_[done] = 0.0 179 | q2_[done] = 0.0 180 | 181 | critic_value_ = torch.min(q1_.view(-1), q2_.view(-1)).detach() 182 | target = reward + gamma*critic_value_ 183 | target = target.view(batch_size, 1) 184 | 185 | critic_1.optimizer.zero_grad() 186 | critic_2.optimizer.zero_grad() 187 | 188 | q1_loss = F.mse_loss(target, q1) 189 | q2_loss = F.mse_loss(target, q2) 190 | critic_loss = q1_loss + q2_loss 191 | critic_loss.backward(retain_graph=True) 192 | critic_1.optimizer.step() 193 | critic_2.optimizer.step() 194 | 195 | learn_step_counter += 1 196 | 197 | if learn_step_counter % update_actor_interval == 0: 198 | policy_learn_step_counter += 1 199 | actor.optimizer.zero_grad() 200 | if spiking: 201 | actor_q1_loss = critic_1.forward(state, actor.forward(state)[0].squeeze(0)) 202 | else: 203 | actor_q1_loss = critic_1.forward(state, actor.forward(state)) 204 | actor_loss = -torch.mean(actor_q1_loss) 205 | actor_loss.backward(retain_graph=True) 206 | actor.optimizer.step() 207 | 208 | if learn_step_counter % update_target_interval == 0: 209 | update_network_parameters(actor, target_actor, critic_1, target_critic_1, critic_2, 210 | target_critic_2, tau=tau) 211 | 212 | 213 | def update_network_parameters(actor, target_actor, critic_1, target_critic_1, critic_2, 214 | target_critic_2, tau=None): 215 | 216 | # update actor params 217 | if spiking: 218 | actor_state_dict = [d.clone() for d in actor.state_dict()[0]] 219 | target_actor_state_dict = [d.clone() for d in target_actor.state_dict()[0]] 220 | new_state_dict = [[tau*a + (1 - tau)*ta for a, ta in zip(actor_state_dict, 221 | target_actor_state_dict)]] 222 | target_actor.load_state_dict(new_state_dict) 223 | else: 224 | actor_params =actor.named_parameters() 225 | target_actor_params = target_actor.named_parameters() 226 | actor = dict(actor_params) 227 | target_actor = dict(target_actor_params) 228 | 229 | for name in actor: 230 | actor[name] = tau*actor[name].clone() + (1 - tau)*target_actor[name].clone() 231 | 232 | 
target_actor.load_state_dict(actor) 233 | 234 | # update critic params 235 | critic_1_params = critic_1.named_parameters() 236 | critic_2_params = critic_2.named_parameters() 237 | target_critic_1_params = target_critic_1.named_parameters() 238 | target_critic_2_params = target_critic_2.named_parameters() 239 | critic_1 = dict(critic_1_params) 240 | critic_2 = dict(critic_2_params) 241 | target_critic_1_dict = dict(target_critic_1_params) 242 | target_critic_2_dict = dict(target_critic_2_params) 243 | 244 | for name in critic_1: 245 | critic_1[name] = tau * critic_1[name].clone() + (1 - tau) * target_critic_1_dict[name].clone() 246 | 247 | for name in critic_2: 248 | critic_2[name] = tau * critic_2[name].clone() + (1 - tau) * target_critic_2_dict[name].clone() 249 | 250 | target_critic_1.load_state_dict(critic_1) 251 | target_critic_2.load_state_dict(critic_2) 252 | 253 | 254 | def transform_state(state): 255 | state_ = [] 256 | for i in state: 257 | if i > 0: 258 | state_.append(i) 259 | state_.append(0) 260 | else: 261 | state_.append(0) 262 | state_.append(abs(i)) 263 | return np.array(state_) 264 | 265 | def one_neuron_encoding(state): 266 | state_ = [] 267 | for i in range(len(state)): 268 | if i == 0: 269 | continue 270 | else: 271 | if i % 2 == 0: 272 | state_.append(state[i]) 273 | else: 274 | state_.append(-state[i]) 275 | return np.array(state_) 276 | 277 | 278 | def update_max_obs(state, max_obs): 279 | max_obs = np.maximum(max_obs, np.abs(state)) 280 | return max_obs 281 | 282 | 283 | def run_data_generator(controller_name: str, environment_name: str, config_manager: ConfigManager, 284 | record_path=None): 285 | global time_step, episode_counter 286 | # Generate seeds and set timestamp 287 | timestamp = datetime.now() 288 | seed_entropy = config_manager("config")["seed_entropy"] 289 | if seed_entropy is None: 290 | seed_entropy = int(timestamp.timestamp()) 291 | logger.info("No seed entropy specified. 
Setting to posix timestamp.") 292 | 293 | num_experiments = config_manager("config")["num_experiments"] 294 | seed_sequences = SeedSequence(entropy=seed_entropy).spawn(num_experiments) 295 | timestamp_str = timestamp.strftime("%Y%m%d-%H%M%S") 296 | 297 | controller_short_name = controller_name.replace("controller_", "").replace("_", "-") 298 | optimizer_short_name = config_manager("config_controllers")[controller_short_name]["optimizer"] 299 | optimizer_name = "optimizer_" + optimizer_short_name.replace("-", "_") 300 | CurrentRunMemory.current_optimizer_name = optimizer_name 301 | all_metrics = dict( 302 | total_rewards = [], 303 | timeout = [], 304 | terminated = [], 305 | truncated = [], 306 | ) 307 | 308 | best_average = -np.inf 309 | best_average_after = np.inf 310 | reward_history = [] 311 | smoothed_scores = [] 312 | 313 | # Generate new seeds for environment and controller 314 | seeds = seed_sequences[0].generate_state(3) 315 | SeedMemory.set_seeds(seeds) 316 | 317 | config_controller = dict(config_manager("config_controllers")[controller_short_name]) 318 | config_optimizer = dict(config_manager("config_optimizers")[optimizer_short_name]) 319 | config_optimizer.update({"seed": int(seeds[1])}) 320 | config_environment = dict(config_manager("config_environments")[environment_name]) 321 | config_environment.update({"seed": int(seeds[0])}) 322 | all_rewards = [] 323 | 324 | ##### ----------------------------------------------- ##### 325 | ##### ----------------- ENVIRONMENT ----------------- ##### 326 | ##### --- Instantiate environment and call reset ---- ##### 327 | if config_manager("config")["render_for_humans"]: 328 | render_mode = "human" 329 | elif config_manager("config")["save_plots_to_file"]: 330 | render_mode = "rgb_array" 331 | else: 332 | render_mode = None 333 | 334 | import matplotlib 335 | 336 | matplotlib.use("Agg") 337 | 338 | env: EnvironmentBatched = gym.make(environment_name, **config_environment, 339 | computation_lib=TensorFlowLibrary, 340 | render_mode=render_mode) 341 | CurrentRunMemory.current_environment = env 342 | obs, obs_info = env.reset(seed=config_environment["seed"]) 343 | assert len( 344 | env.action_space.shape) == 1, f"Action space needs to be a flat vector, is Box with shape {env.action_space.shape}" 345 | 346 | # td3 variables 347 | max_action = env.action_space.high 348 | min_action = env.action_space.low 349 | max_obs = env.observation_space.high 350 | if two_neuron_encoding: 351 | input_dims = (env.observation_space.shape[0]*2) 352 | else: 353 | input_dims = env.observation_space.shape 354 | n_actions = env.action_space.shape[0] 355 | 356 | memory = ReplayBuffer(buffer_size, input_dims, n_actions) 357 | 358 | for i in range(len(max_obs)): 359 | if max_obs[i] == np.inf: 360 | max_obs[i] = 1 361 | 362 | actor_architecture = [input_dims, layer1_size, layer2_size, n_actions] 363 | 364 | # networks 365 | actor = TD3ActorDSNN(actor_architecture, 0, alpha, beta, weight_scale, 1, threshold, 366 | sim_time, actor_learning_rate, name='actor', device=device) 367 | target_actor = TD3ActorDSNN(actor_architecture, 0, alpha, beta, weight_scale, 1, threshold, 368 | sim_time, actor_learning_rate, name='target_actor', device=device) 369 | 370 | if quantization: 371 | q_actor = TD3ActorDSNN(actor_architecture, 0, alpha, beta, weight_scale, 1, threshold, 372 | sim_time, actor_learning_rate, name='quantized_actor', device=device) 373 | else: 374 | q_actor = None 375 | 376 | critic_1 = TD3CriticNetwork(critic_learning_rate, input_dims, layer1_size, layer2_size, 377 | 
n_actions=n_actions, name='critic_1') 378 | target_critic_1 = TD3CriticNetwork(critic_learning_rate, input_dims, layer1_size, layer2_size, 379 | n_actions=n_actions, name='target_critic_1') 380 | 381 | critic_2 = TD3CriticNetwork(critic_learning_rate, input_dims, layer1_size, layer2_size, 382 | n_actions=n_actions, name='critic_2') 383 | target_critic_2 = TD3CriticNetwork(critic_learning_rate, input_dims, layer1_size, layer2_size, 384 | n_actions=n_actions, name='target_critic_2') 385 | 386 | #update_network_parameters(actor, target_actor, critic_1, target_critic_1, critic_2, 387 | # target_critic_2, tau=1) 388 | 389 | # Loop through episodes 390 | for i in trange(num_experiments): 391 | 392 | # Generate new seeds for environment and controller 393 | seeds = seed_sequences[i].generate_state(3) 394 | SeedMemory.set_seeds(seeds) 395 | 396 | config_controller = dict(config_manager("config_controllers")[controller_short_name]) 397 | config_optimizer = dict(config_manager("config_optimizers")[optimizer_short_name]) 398 | config_optimizer.update({"seed": int(seeds[1])}) 399 | config_environment = dict(config_manager("config_environments")[environment_name]) 400 | config_environment.update({"seed": int(seeds[0])}) 401 | all_rewards = [] 402 | 403 | episode_counter += 1 404 | print(episode_counter) 405 | obs, obs_info = env.reset(seed=config_environment["seed"]) 406 | max_obs = update_max_obs(obs, max_obs) 407 | 408 | if two_neuron_encoding: 409 | obs = transform_state(obs) 410 | 411 | score = 0 412 | 413 | ##### ---------------------------------------------- ##### 414 | ##### ----------------- CONTROLLER ----------------- ##### 415 | controller_module = import_module(f"Control_Toolkit.Controllers.{controller_name}") 416 | controller: template_controller = getattr(controller_module, controller_name)( 417 | dt=env.dt, 418 | environment_name=ENV_REGISTRY[environment_name].split(":")[-1], 419 | control_limits=(env.action_space.low, env.action_space.high), 420 | initial_environment_attributes=env.environment_attributes) 421 | controller.configure(optimizer_name=optimizer_short_name, predictor_specification=config_controller["predictor_specification"]) 422 | 423 | ##### ----------------------------------------------------- ##### 424 | ##### ----------------- MAIN CONTROL LOOP ----------------- ##### 425 | frames = [] 426 | start_time = time.time() 427 | num_iterations = config_manager("config")["num_iterations"] 428 | for step in range(num_iterations): 429 | #action = controller.step(obs, updated_attributes=env.environment_attributes) 430 | action = choose_action(actor, obs, max_action, min_action, n_actions, max_obs, 431 | q_actor=q_actor) 432 | new_obs, reward, terminated, truncated, info = env.step(action) 433 | max_obs = update_max_obs(new_obs, max_obs) 434 | if two_neuron_encoding: 435 | new_obs_one_neuron = new_obs 436 | new_obs = transform_state(new_obs) 437 | if spiking: 438 | new_obs = new_obs.reshape(input_dims) 439 | c_fun: CostFunctionWrapper = getattr(controller, "cost_function", None) 440 | if c_fun is not None: 441 | assert isinstance(c_fun, CostFunctionWrapper) 442 | # Compute reward from the cost function that the controller optimized 443 | reward = -float(c_fun.get_stage_cost( 444 | tf.convert_to_tensor(new_obs_one_neuron[np.newaxis, np.newaxis, ...]), # Add batch / MPC horizon dimensions 445 | tf.convert_to_tensor(action[np.newaxis, np.newaxis, ...]), 446 | None 447 | )) 448 | all_rewards.append(reward) 449 | if config_controller.get("controller_logging", False): 450 | 
controller.logs["realized_cost_logged"].append(np.array([-reward]).copy()) 451 | env.set_logs(controller.logs) 452 | if config_manager("config")["render_for_humans"]: 453 | env.render() 454 | elif config_manager("config")["save_plots_to_file"]: 455 | frames.append(env.render()) 456 | 457 | done = terminated or truncated 458 | memory.store_transition(obs, action, reward, new_obs, done) 459 | score += reward 460 | 461 | time.sleep(1e-6) 462 | 463 | obs = new_obs 464 | 465 | learn(actor, target_actor, critic_1, target_critic_1, critic_2, target_critic_2, memory, 466 | time_step, max_action, min_action, max_obs) 467 | 468 | # Print compute time statistics 469 | end_time = time.time() 470 | control_freq = num_iterations / (end_time - start_time) 471 | 472 | reward_history.append(score) 473 | avg_score = np.mean(reward_history[-100:]) 474 | smoothed_scores.append(avg_score) 475 | 476 | if avg_score > best_average: 477 | best_average = avg_score 478 | best_average_after = episode_counter 479 | 480 | print('Episode: ', episode_counter, 'training steps: ', learn_step_counter, 481 | 'score: %.1f' % score, 'Average Score: %.1f' % avg_score) 482 | 483 | if episode_counter % 100 == 0: 484 | print("\rEpisode: ", episode_counter, 'training steps: ', learn_step_counter, 485 | "Average Score: %.2f" % avg_score) 486 | save_models(actor, target_actor, critic_1, target_critic_1, critic_2, target_critic_2, 487 | 'snn_results/', episode_counter, q_actor=q_actor) 488 | 489 | # Close the env 490 | env.close() 491 | print('Best 100 episode average: ', best_average, ' reached at episode ', best_average_after, 492 | '.') 493 | return smoothed_scores 494 | 495 | def prepare_and_run(): 496 | import ruamel.yaml 497 | 498 | # Create a config manager which looks for '.yml' files within the list of folders specified. 499 | # Rationale: We want GUILD AI to be able to update values in configs that we include in this list. 500 | # We might intentionally want to exclude the path to a folder which does contain configs but should not be overwritten by GUILD. 501 | config_manager = ConfigManager(".", "Control_Toolkit_ASF", "SI_Toolkit_ASF", "Environments") 502 | 503 | # Scan for any custom parameters that should overwrite the toolkits' config files: 504 | submodule_configs = ConfigManager("Control_Toolkit_ASF", "SI_Toolkit_ASF", "Environments").loaders 505 | for base_name, loader in submodule_configs.items(): 506 | if base_name in config_manager("config").get("custom_config_overwrites", {}): 507 | data: ruamel.yaml.comments.CommentedMap = loader.load() 508 | update_dict = config_manager("config")["custom_config_overwrites"][base_name] 509 | nested_assignment_to_ordereddict(data, update_dict) 510 | loader.overwrite_config(data) 511 | 512 | # Retrieve required parameters from config: 513 | CurrentRunMemory.current_controller_name = config_manager("config")["controller_name"] 514 | CurrentRunMemory.current_environment_name = config_manager("config")["environment_name"] 515 | 516 | smoothed_scores = run_data_generator(controller_name=CurrentRunMemory.current_controller_name, 517 | environment_name=CurrentRunMemory.current_environment_name, 518 | config_manager=config_manager) 519 | return smoothed_scores 520 | 521 | if __name__ == "__main__": 522 | smoothed_scores = prepare_and_run() 523 | --------------------------------------------------------------------------------
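Two-neuron observation encoding. transform_state above splits every signed observation component into a (positive, negative) pair of non-negative values, which is why the actor's input dimensionality is env.observation_space.shape[0]*2 when two_neuron_encoding is enabled. A minimal standalone sketch of that mapping; the function name two_neuron_encode is illustrative and not part of the repository:

import numpy as np

def two_neuron_encode(state):
    # Mirrors transform_state above: each signed component becomes a
    # (positive, negative) pair of non-negative values, so it can drive
    # two input neurons of the spiking actor.
    encoded = []
    for x in state:
        if x > 0:
            encoded.extend([x, 0.0])
        else:
            encoded.extend([0.0, abs(x)])
    return np.array(encoded)

# Example: a 2-dimensional observation becomes a 4-dimensional network input.
print(two_neuron_encode(np.array([0.5, -0.2])))   # [0.5 0.  0.  0.2]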
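Weight quantization for the quantized actor. When quantization is enabled, choose_action copies the trained actor's float weights into q_actor after mapping them to integers in [-127, 127] with a single global scale and a zero point of 0, then casts them back to float for the spiking forward pass. A rough sketch of that mapping under the same scheme, with the clamping into the integer range made explicit; quantize_weights_sketch is an illustrative name, not a function from the repository:

import torch

def quantize_weights_sketch(weights, qmin=-127, qmax=127):
    # One global scale from the min/max over all layers, zero point kept at 0,
    # then each weight tensor is mapped into the signed 8-bit integer range.
    flat = torch.cat([w.flatten() for w in weights])
    scale = (flat.max() - flat.min()) / (qmax - qmin)
    return [torch.clamp(w / scale, qmin, qmax).round().int() for w in weights], scale

# Example with two random layers standing in for the actor's weight matrices;
# the quantized copy, cast back to float, is what q_actor would run forward.
layers = [torch.randn(4, 6), torch.randn(1, 4)]
q_layers, scale = quantize_weights_sketch(layers)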
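Clipped double-Q target. In learn(), the next-state actions come from the target actor with clamped Gaussian smoothing noise, both target critics score them, and the smaller estimate is bootstrapped with gamma = 0.99 while terminal transitions contribute only their reward. The bootstrapping step in isolation, as a toy example (the smoothing noise and the actor update are omitted):

import torch

def td3_target_sketch(reward, q1_next, q2_next, done, gamma=0.99):
    # Use the smaller of the two target-critic estimates; zero out the
    # bootstrap term for terminal transitions, as in learn() above.
    min_q = torch.min(q1_next, q2_next)
    min_q[done] = 0.0
    return reward + gamma * min_q

# Toy batch of three transitions.
reward = torch.tensor([1.0, 0.5, -0.2])
q1_next = torch.tensor([10.0, 2.0, 4.0])
q2_next = torch.tensor([9.0, 3.0, 5.0])
done = torch.tensor([False, True, False])
print(td3_target_sketch(reward, q1_next, q2_next, done))   # tensor([9.9100, 0.5000, 3.7600])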
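Soft target updates. update_network_parameters blends online and target parameters with tau = 0.005 every update_target_interval (2) learning steps; for the non-spiking critic networks this is the standard Polyak update theta_target <- tau * theta + (1 - tau) * theta_target. A compact equivalent for ordinary torch.nn modules, assuming the target network is hard-copied from the online network once before training starts; soft_update_sketch is an illustrative name, not part of the repository:

import torch.nn as nn

def soft_update_sketch(net, target_net, tau=0.005):
    # Polyak averaging over every parameter and buffer in the state dict.
    target_state = target_net.state_dict()
    for name, param in net.state_dict().items():
        target_state[name] = tau * param.clone() + (1 - tau) * target_state[name].clone()
    target_net.load_state_dict(target_state)

# Example with plain linear layers standing in for the critics.
critic, target_critic = nn.Linear(8, 1), nn.Linear(8, 1)
target_critic.load_state_dict(critic.state_dict())   # hard copy once at initialization
soft_update_sketch(critic, target_critic, tau=0.005)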