├── .gitignore ├── README.md ├── env.yaml ├── img ├── ReDo_algorithm_pseudocode.png ├── redo_episodic_returns.png ├── redo_tau_0_0_dormant_fraction.png └── redo_tau_0_1_dormant_fraction.png ├── redo.sh ├── redo_dqn.py └── src ├── agent.py ├── benchmark.py ├── buffer.py ├── config.py ├── evaluate.py ├── redo.py ├── utils.py └── wrappers.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Experiment tracking 2 | wandb/ 3 | log/ 4 | runs/ 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *.pyc 10 | 11 | # package 12 | eggs/ 13 | develop-eggs/ 14 | *.egg-info/ 15 | *.egg 16 | 17 | # editor 18 | .idea/ 19 | .vscode/ 20 | .ipynb_checkpoints/ 21 | 22 | # large files 23 | data/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Recycling dormant neurons 2 | 3 | PyTorch reimplementation of [ReDo](https://arxiv.org/abs/2302.12902) (The Dormant Neuron Phenomenon in Deep Reinforcement Learning). 4 | The paper establishes the _dormant neuron phenomenon_: over the course of training a network with nonstationary targets, a significant portion of the neurons in a deep network become _dormant_, i.e. their activations become minimal compared to the other neurons in the layer. 5 | This phenomenon is particularly prevalent in value-based deep reinforcement learning algorithms, such as DQN and its variants. As a solution, the authors propose to periodically check for dormant neurons and reinitialize them. 6 | 7 | ## Dormant neurons 8 | 9 | The score $s_i^{\ell}$ of a neuron $i$ in layer $\ell$ is defined as its expected absolute activation $`\mathbb{E}_{x \in D} |h_i^{\ell}(x)|`$ divided by the average expected absolute activation within the layer $`\frac{1}{H^{\ell}} \sum_{k \in h} \mathbb{E}_{x \in D}|h_k^{\ell}(x)|`$: 10 | 11 | $$`s_i^{\ell}=\frac{\mathbb{E}_{x \in D}|h_i^{\ell}(x)|}{\frac{1}{H^{\ell}} \sum_{k \in h} \mathbb{E}_{x \in D}|h_k^{\ell}(x)|}`$$ 12 | 13 | A neuron is defined as $\tau$-dormant when $s_i^{\ell} \leq \tau$. 14 | 15 | ## ReDo 16 | 17 | 18 | 19 | Every $F$-th time step: 20 | 21 | 1. Check whether a neuron $i$ is $\tau$-dormant. 22 | 2. If a neuron $i$ is $\tau$-dormant: 23 | **Re-initialize input weights and bias** of $i$. 24 | _Set_ the **outgoing weights** of $i$ to $0$. 25 | 26 | ## Results 27 | 28 | These results were generated using 3 seeds on DemonAttack-v4. Note that I was not using typical DQN hyperparameters, but instead chose a hyperparameter set that exaggerates the dormant neuron phenomenon. 29 | In particular: 30 | 31 | - Updates are done every environment step instead of every 4 steps. 32 | - Target network updates every 2000 steps instead of every 8000. 33 | - Fewer random samples before learning starts. 34 | - $\tau=0.1$ instead of $\tau=0.025$. 35 | 36 | #### Episodic Return 37 | 38 | 39 | 40 | #### Dormant fraction $\tau=0.0$ 41 | 42 | 43 | 44 | #### Dormant fraction $\tau=0.1$ 45 | 46 | 47 | 48 | 49 | I've skipped running 10M or 100M experiments because they are very expensive in terms of compute. 50 | 51 | ## Implementation progress 52 | 53 | Update 1: 54 | Fixed and simplified the for-loop in the ReDo resets. 55 | 56 | Update 2: 57 | The reset check in the main function was at the wrong level; the re-initializations are now properly done in-place and work. 58 | 59 | Update 3: 60 | The Adam moment step-count reset is crucial for performance. Otherwise the Adam updates will immediately create dead neurons again. 61 | Preliminary results now look promising.
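To make this concrete, here is a minimal, self-contained sketch of what the moment and step-count reset amounts to. It mirrors `_reset_adam_moments` in `src/redo.py`, but the layer sizes and the `dormant_mask` below are made up purely for illustration:

```python
import torch
import torch.nn as nn
from torch import optim

# Toy two-layer block: the neurons of `layer` feed into `next_layer`.
layer = nn.Linear(8, 16)
next_layer = nn.Linear(16, 4)
optimizer = optim.Adam(list(layer.parameters()) + list(next_layer.parameters()), lr=1e-4)

# One ordinary update so that Adam has populated exp_avg, exp_avg_sq and step.
loss = next_layer(torch.relu(layer(torch.randn(32, 8)))).pow(2).mean()
loss.backward()
optimizer.step()

# Pretend neurons 0 and 3 of `layer` were flagged as dormant and re-initialized.
dormant_mask = torch.zeros(16, dtype=torch.bool)
dormant_mask[[0, 3]] = True

with torch.no_grad():
    # Incoming weights and bias of the recycled neurons: clear both moments.
    for param in (layer.weight, layer.bias):
        state = optimizer.state[param]
        state["exp_avg"][dormant_mask] = 0.0
        state["exp_avg_sq"][dormant_mask] = 0.0
        state["step"].zero_()  # the step-count reset this update is about
    # The outgoing weights of those neurons are the corresponding columns of the next layer.
    state = optimizer.state[next_layer.weight]
    state["exp_avg"][:, dormant_mask] = 0.0
    state["exp_avg_sq"][:, dormant_mask] = 0.0
    state["step"].zero_()
```

Without the step reset, Adam's bias correction treats the zeroed moments as if they had already been warmed up for many updates, so the first gradient steps on the recycled weights come out several times larger than intended, which is exactly how those neurons die again.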
62 | 63 | Update 4: 64 | Fixed the outgoing weight resets: the mask was generated incorrectly and was not applied to the outgoing weights. See [this issue](https://github.com/timoklein/redo/issues/3). Thanks @SaminYeasar! 65 | 66 | ## Citations 67 | 68 | Paper: 69 | 70 | ```bibtex 71 | @inproceedings{sokar2023dormant, 72 | title={The dormant neuron phenomenon in deep reinforcement learning}, 73 | author={Sokar, Ghada and Agarwal, Rishabh and Castro, Pablo Samuel and Evci, Utku}, 74 | booktitle={International Conference on Machine Learning}, 75 | pages={32145--32168}, 76 | year={2023}, 77 | organization={PMLR} 78 | } 79 | ``` 80 | 81 | Training code is based on [cleanRL](https://github.com/vwxyzjn/cleanrl): 82 | 83 | ```bibtex 84 | @article{huang2022cleanrl, 85 | author = {Shengyi Huang and Rousslan Fernand Julien Dossa and Chang Ye and Jeff Braga and Dipam Chakraborty and Kinal Mehta and João G.M. Araújo}, 86 | title = {CleanRL: High-quality Single-file Implementations of Deep Reinforcement Learning Algorithms}, 87 | journal = {Journal of Machine Learning Research}, 88 | year = {2022}, 89 | volume = {23}, 90 | number = {274}, 91 | pages = {1--18}, 92 | url = {http://jmlr.org/papers/v23/21-1342.html} 93 | } 94 | ``` 95 | 96 | Replay buffer and wrappers are from [Stable Baselines 3](https://github.com/DLR-RM/stable-baselines3): 97 | 98 | ```bibtex 99 | @misc{raffin2019stable, 100 | title={Stable baselines3}, 101 | author={Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah}, 102 | year={2019} 103 | } 104 | ``` 105 | -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: env 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=2_kmp_llvm 9 | - appdirs=1.4.4=pyh9f0ad1d_0 10 | - blas=2.116=mkl 11 | - blas-devel=3.9.0=16_linux64_mkl 12 | - brotli-python=1.0.9=py310hd8f1fbe_9 13 | - bzip2=1.0.8=h7f98852_4 14 | - ca-certificates=2023.5.7=hbcca054_0 15 | - certifi=2023.5.7=pyhd8ed1ab_0 16 | - charset-normalizer=3.2.0=pyhd8ed1ab_0 17 | - click=8.1.4=unix_pyh707e725_0 18 | - cuda-cudart=11.7.99=0 19 | - cuda-cupti=11.7.101=0 20 | - cuda-libraries=11.7.1=0 21 | - cuda-nvrtc=11.7.99=0 22 | - cuda-nvtx=11.7.91=0 23 | - cuda-runtime=11.7.1=0 24 | - docker-pycreds=0.4.0=py_0 25 | - ffmpeg=4.3=hf484d3e_0 26 | - filelock=3.12.2=pyhd8ed1ab_0 27 | - freetype=2.12.1=hca18f0e_1 28 | - gitdb=4.0.10=pyhd8ed1ab_0 29 | - gitpython=3.1.31=pyhd8ed1ab_0 30 | - gmp=6.2.1=h58526e2_0 31 | - gmpy2=2.1.2=py310h3ec546c_1 32 | - gnutls=3.6.13=h85f3911_1 33 | - icu=72.1=hcb278e6_0 34 | - idna=3.4=pyhd8ed1ab_0 35 | - jinja2=3.1.2=pyhd8ed1ab_1 36 | - jpeg=9e=h0b41bf4_3 37 | - lame=3.100=h166bdaf_1003 38 | - lcms2=2.15=hfd0df8a_0 39 | - ld_impl_linux-64=2.40=h41732ed_0 40 | - lerc=4.0.0=h27087fc_0 41 | - libabseil=20230125.3=cxx17_h59595ed_0 42 | - libblas=3.9.0=16_linux64_mkl 43 | - libcblas=3.9.0=16_linux64_mkl 44 | - libcublas=11.10.3.66=0 45 | - libcufft=10.7.2.124=h4fbf590_0 46 | - libcufile=1.7.0.149=0 47 | - libcurand=10.3.3.53=0 48 | - libcusolver=11.4.0.1=0 49 | - libcusparse=11.7.4.91=0 50 | - libdeflate=1.17=h0b41bf4_0 51 | - libffi=3.4.2=h7f98852_5 52 | - libgcc-ng=13.1.0=he5830b7_0 53 | - libgfortran-ng=13.1.0=h69a702a_0 54 | -
libgfortran5=13.1.0=h15d22d2_0 55 | - libgomp=13.1.0=he5830b7_0 56 | - libhwloc=2.9.1=nocuda_h7313eea_6 57 | - libiconv=1.17=h166bdaf_0 58 | - liblapack=3.9.0=16_linux64_mkl 59 | - liblapacke=3.9.0=16_linux64_mkl 60 | - libnpp=11.7.4.75=0 61 | - libnsl=2.0.0=h7f98852_0 62 | - libnvjpeg=11.8.0.2=0 63 | - libpng=1.6.39=h753d276_0 64 | - libprotobuf=4.23.3=hd1fb520_0 65 | - libsqlite=3.42.0=h2797004_0 66 | - libstdcxx-ng=13.1.0=hfd8a6a1_0 67 | - libtiff=4.5.0=h6adf6a1_2 68 | - libuuid=2.38.1=h0b41bf4_0 69 | - libwebp-base=1.3.1=hd590300_0 70 | - libxcb=1.13=h7f98852_1004 71 | - libxml2=2.11.4=h0d562d8_0 72 | - libzlib=1.2.13=hd590300_5 73 | - llvm-openmp=16.0.6=h4dfa4b3_0 74 | - markupsafe=2.1.3=py310h2372a71_0 75 | - mkl=2022.1.0=h84fe81f_915 76 | - mkl-devel=2022.1.0=ha770c72_916 77 | - mkl-include=2022.1.0=h84fe81f_915 78 | - mpc=1.3.1=hfe3b2da_0 79 | - mpfr=4.2.0=hb012696_0 80 | - mpmath=1.3.0=pyhd8ed1ab_0 81 | - ncurses=6.4=hcb278e6_0 82 | - nettle=3.6=he412f7d_0 83 | - networkx=3.1=pyhd8ed1ab_0 84 | - numpy=1.25.1=py310ha4c1d20_0 85 | - openh264=2.1.1=h780b84a_0 86 | - openjpeg=2.5.0=hfec8fc6_2 87 | - openssl=3.1.1=hd590300_1 88 | - pathtools=0.1.2=py_1 89 | - pillow=9.4.0=py310h023d228_1 90 | - pip=23.1.2=pyhd8ed1ab_0 91 | - protobuf=4.23.3=py310hb875b13_0 92 | - psutil=5.9.5=py310h1fa729e_0 93 | - pthread-stubs=0.4=h36c2ea0_1001 94 | - pysocks=1.7.1=pyha2e5f31_6 95 | - python=3.10.12=hd12c33a_0_cpython 96 | - python_abi=3.10=3_cp310 97 | - pytorch=2.0.1=py3.10_cuda11.7_cudnn8.5.0_0 98 | - pytorch-cuda=11.7=h778d358_5 99 | - pytorch-mutex=1.0=cuda 100 | - pyyaml=6.0=py310h5764c6d_5 101 | - readline=8.2=h8228510_1 102 | - requests=2.31.0=pyhd8ed1ab_0 103 | - sentry-sdk=1.21.1=pyhd8ed1ab_0 104 | - setproctitle=1.3.2=py310h5764c6d_1 105 | - setuptools=68.0.0=pyhd8ed1ab_0 106 | - six=1.16.0=pyh6c4a22f_0 107 | - smmap=3.0.5=pyh44b312d_0 108 | - sympy=1.12=pypyh9d50eac_103 109 | - tbb=2021.9.0=hf52228f_0 110 | - tk=8.6.12=h27826a3_0 111 | - torchaudio=2.0.2=py310_cu117 112 | - torchtriton=2.0.0=py310 113 | - torchvision=0.15.2=py310_cu117 114 | - typing_extensions=4.7.1=pyha770c72_0 115 | - tzdata=2023c=h71feb2d_0 116 | - urllib3=2.0.3=pyhd8ed1ab_1 117 | - wandb=0.15.5=pyhd8ed1ab_0 118 | - wheel=0.40.0=pyhd8ed1ab_0 119 | - xorg-libxau=1.0.11=hd590300_0 120 | - xorg-libxdmcp=1.1.3=h7f98852_0 121 | - xz=5.2.6=h166bdaf_0 122 | - yaml=0.2.5=h7f98852_2 123 | - zlib=1.2.13=hd590300_5 124 | - zstd=1.5.2=hfc55251_7 125 | - pip: 126 | - ale-py==0.8.1 127 | - autorom==0.4.2 128 | - autorom-accept-rom-license==0.6.1 129 | - cloudpickle==2.2.1 130 | - contourpy==1.1.0 131 | - cycler==0.11.0 132 | - decorator==4.4.2 133 | - farama-notifications==0.0.4 134 | - fonttools==4.40.0 135 | - gymnasium==0.28.1 136 | - imageio==2.31.1 137 | - imageio-ffmpeg==0.4.8 138 | - importlib-resources==6.0.0 139 | - jax-jumpy==1.0.0 140 | - kiwisolver==1.4.4 141 | - lz4==4.3.2 142 | - matplotlib==3.7.2 143 | - moviepy==1.0.3 144 | - mypy-extensions==1.0.0 145 | - opencv-python==4.8.0.74 146 | - packaging==23.1 147 | - proglog==0.1.10 148 | - pyparsing==3.0.9 149 | - pyrallis==0.3.1 150 | - python-dateutil==2.8.2 151 | - shimmy==0.2.1 152 | - tqdm==4.65.0 153 | - typing-inspect==0.9.0 154 | prefix: /export/home/timok34dm/mambaforge/envs/env -------------------------------------------------------------------------------- /img/ReDo_algorithm_pseudocode.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/timoklein/redo/3c0e2a9661ad85ca80523fd205da5eed8eb7b859/img/ReDo_algorithm_pseudocode.png -------------------------------------------------------------------------------- /img/redo_episodic_returns.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timoklein/redo/3c0e2a9661ad85ca80523fd205da5eed8eb7b859/img/redo_episodic_returns.png -------------------------------------------------------------------------------- /img/redo_tau_0_0_dormant_fraction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timoklein/redo/3c0e2a9661ad85ca80523fd205da5eed8eb7b859/img/redo_tau_0_0_dormant_fraction.png -------------------------------------------------------------------------------- /img/redo_tau_0_1_dormant_fraction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timoklein/redo/3c0e2a9661ad85ca80523fd205da5eed8eb7b859/img/redo_tau_0_1_dormant_fraction.png -------------------------------------------------------------------------------- /redo.sh: -------------------------------------------------------------------------------- 1 | # Simplified benchmarking script from cleanRL: https://github.com/vwxyzjn/cleanrl/blob/master/benchmark/dqn.sh 2 | 3 | OMP_NUM_THREADS=1 python -m src.benchmark \ 4 | --env-ids DemonAttack-v4 \ 5 | --command "python redo_dqn.py --track --enable_redo" \ 6 | --num-seeds 3 \ 7 | --workers 3 -------------------------------------------------------------------------------- /redo_dqn.py: -------------------------------------------------------------------------------- 1 | # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_ataripy 2 | import os 3 | import random 4 | import time 5 | from pathlib import Path 6 | 7 | import gymnasium as gym 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | import tyro 13 | import wandb 14 | 15 | from src.agent import QNetwork, linear_schedule 16 | from src.buffer import ReplayBuffer 17 | from src.config import Config 18 | from src.redo import run_redo 19 | from src.utils import lecun_normal_initializer, make_env, set_cuda_configuration 20 | 21 | 22 | def dqn_loss( 23 | q_network: QNetwork, 24 | target_network: QNetwork, 25 | obs: torch.Tensor, 26 | next_obs: torch.Tensor, 27 | actions: torch.Tensor, 28 | rewards: torch.Tensor, 29 | dones: torch.Tensor, 30 | gamma: float, 31 | ) -> tuple[torch.Tensor, torch.Tensor]: 32 | """Compute the double DQN loss.""" 33 | with torch.no_grad(): 34 | # Get value estimates from the target network 35 | target_vals = target_network.forward(next_obs) 36 | # Select actions through the policy network 37 | policy_actions = q_network(next_obs).argmax(dim=1) 38 | target_max = target_vals[range(len(target_vals)), policy_actions] 39 | # Calculate Q-target 40 | td_target = rewards.flatten() + gamma * target_max * (1 - dones.flatten()) 41 | 42 | old_val = q_network(obs).gather(1, actions).squeeze() 43 | return F.mse_loss(td_target, old_val), old_val 44 | 45 | 46 | def main(cfg: Config) -> None: 47 | """Main training method for ReDO DQN.""" 48 | run_name = f"{cfg.env_id}__{cfg.exp_name}__{cfg.seed}__{int(time.time())}" 49 | 50 | wandb.init( 51 | project=cfg.wandb_project_name, 52 | entity=cfg.wandb_entity, 53 | config=vars(cfg), 54 | name=run_name, 55 | monitor_gym=True, 56 | save_code=True, 
57 | mode="online" if cfg.track else "disabled", 58 | ) 59 | 60 | if cfg.save_model: 61 | evaluation_episode = 0 62 | wandb.define_metric("evaluation_episode") 63 | wandb.define_metric("eval/episodic_return", step_metric="evaluation_episode") 64 | 65 | # To get deterministic pytorch to work 66 | if cfg.torch_deterministic: 67 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" 68 | torch.use_deterministic_algorithms(True) 69 | 70 | # TRY NOT TO MODIFY: seeding 71 | random.seed(cfg.seed) 72 | np.random.seed(cfg.seed) 73 | torch.manual_seed(cfg.seed) 74 | torch.set_float32_matmul_precision("high") 75 | 76 | device = set_cuda_configuration(cfg.gpu) 77 | 78 | # env setup 79 | envs = gym.vector.SyncVectorEnv( 80 | [make_env(cfg.env_id, cfg.seed + i, i, cfg.capture_video, run_name) for i in range(cfg.num_envs)] 81 | ) 82 | assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" 83 | 84 | q_network = QNetwork(envs).to(device) 85 | if cfg.use_lecun_init: 86 | # Use the same initialization scheme as jax/flax 87 | q_network.apply(lecun_normal_initializer) 88 | optimizer = optim.Adam(q_network.parameters(), lr=cfg.learning_rate, eps=cfg.adam_eps) 89 | target_network = QNetwork(envs).to(device) 90 | target_network.load_state_dict(q_network.state_dict()) 91 | 92 | rb = ReplayBuffer( 93 | cfg.buffer_size, 94 | envs.single_observation_space, 95 | envs.single_action_space, 96 | device, 97 | optimize_memory_usage=True, 98 | handle_timeout_termination=False, 99 | ) 100 | start_time = time.time() 101 | 102 | # TRY NOT TO MODIFY: start the game 103 | obs, _ = envs.reset(seed=cfg.seed) 104 | for global_step in range(cfg.total_timesteps): 105 | # ALGO LOGIC: put action logic here 106 | epsilon = linear_schedule(cfg.start_e, cfg.end_e, cfg.exploration_fraction * cfg.total_timesteps, global_step) 107 | if random.random() < epsilon: 108 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) 109 | else: 110 | q_values = q_network(torch.Tensor(obs).to(device)) 111 | actions = torch.argmax(q_values, dim=1).cpu().numpy() 112 | 113 | # TRY NOT TO MODIFY: execute the game and log data. 114 | next_obs, rewards, terminated, truncated, infos = envs.step(actions) 115 | 116 | # TRY NOT TO MODIFY: record rewards for plotting purposes 117 | if "final_info" in infos: 118 | for info in infos["final_info"]: 119 | # Skip the envs that are not done 120 | if "episode" not in info: 121 | continue 122 | epi_return = info["episode"]["r"].item() 123 | print(f"global_step={global_step}, episodic_return={epi_return}") 124 | wandb.log( 125 | { 126 | "charts/episodic_return": epi_return, 127 | "charts/episodic_length": info["episode"]["l"].item(), 128 | "charts/epsilon": epsilon, 129 | }, 130 | step=global_step, 131 | ) 132 | 133 | # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` 134 | real_next_obs = next_obs.copy() 135 | for idx, d in enumerate(truncated): 136 | if d: 137 | real_next_obs[idx] = infos["final_observation"][idx] 138 | rb.add(obs, real_next_obs, actions, rewards, terminated, infos) 139 | 140 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook 141 | obs = next_obs 142 | 143 | # ALGO LOGIC: training. 
144 | if global_step > cfg.learning_starts: 145 | # Flag for logging 146 | done_update = False 147 | if done_update := global_step % cfg.train_frequency == 0: 148 | data = rb.sample(cfg.batch_size) 149 | loss, old_val = dqn_loss( 150 | q_network=q_network, 151 | target_network=target_network, 152 | obs=data.observations, 153 | next_obs=data.next_observations, 154 | actions=data.actions, 155 | rewards=data.rewards, 156 | dones=data.dones, 157 | gamma=cfg.gamma, 158 | ) 159 | # optimize the model 160 | optimizer.zero_grad() 161 | loss.backward() 162 | optimizer.step() 163 | 164 | logs = { 165 | "losses/td_loss": loss, 166 | "losses/q_values": old_val.mean().item(), 167 | "charts/SPS": int(global_step / (time.time() - start_time)), 168 | } 169 | 170 | if global_step % cfg.redo_check_interval == 0: 171 | redo_samples = rb.sample(cfg.redo_bs) 172 | redo_out = run_redo( 173 | redo_samples.observations, 174 | model=q_network, 175 | optimizer=optimizer, 176 | tau=cfg.redo_tau, 177 | re_initialize=cfg.enable_redo, 178 | use_lecun_init=cfg.use_lecun_init, 179 | ) 180 | 181 | q_network = redo_out["model"] 182 | optimizer = redo_out["optimizer"] 183 | 184 | logs |= { 185 | f"regularization/dormant_t={cfg.redo_tau}_fraction": redo_out["dormant_fraction"], 186 | f"regularization/dormant_t={cfg.redo_tau}_count": redo_out["dormant_count"], 187 | "regularization/dormant_t=0.0_fraction": redo_out["zero_fraction"], 188 | "regularization/dormant_t=0.0_count": redo_out["zero_count"], 189 | } 190 | 191 | if global_step % 100 == 0 and done_update: 192 | print("SPS:", int(global_step / (time.time() - start_time))) 193 | wandb.log( 194 | logs, 195 | step=global_step, 196 | ) 197 | 198 | # update target network 199 | if global_step % cfg.target_network_frequency == 0: 200 | for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()): 201 | target_network_param.data.copy_( 202 | cfg.tau * q_network_param.data + (1.0 - cfg.tau) * target_network_param.data 203 | ) 204 | 205 | if cfg.save_model: 206 | model_path = Path(f"runs/{run_name}/{cfg.exp_name}") 207 | model_path.mkdir(parents=True, exist_ok=True) 208 | torch.save(q_network.state_dict(), model_path / ".cleanrl_model") 209 | print(f"model saved to {model_path}") 210 | from src.evaluate import evaluate 211 | 212 | episodic_returns = evaluate( 213 | model_path=model_path, 214 | make_env=make_env, 215 | env_id=cfg.env_id, 216 | eval_episodes=10, 217 | run_name=f"{run_name}-eval", 218 | Model=QNetwork, 219 | device=device, 220 | epsilon=0.05, 221 | capture_video=False, 222 | ) 223 | for episodic_return in episodic_returns: 224 | wandb.log({"evaluation_episode": evaluation_episode, "eval/episodic_return": episodic_return}) 225 | evaluation_episode += 1 226 | 227 | envs.close() 228 | wandb.finish() 229 | 230 | 231 | if __name__ == "__main__": 232 | cfg = tyro.cli(Config) 233 | main(cfg) 234 | -------------------------------------------------------------------------------- /src/agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class QNetwork(nn.Module): 7 | """Basic nature DQN agent.""" 8 | 9 | def __init__(self, env): 10 | super().__init__() 11 | self.conv1 = nn.Conv2d(4, 32, 8, stride=4) 12 | self.conv2 = nn.Conv2d(32, 64, 4, stride=2) 13 | self.conv3 = nn.Conv2d(64, 64, 3, stride=1) 14 | self.fc1 = nn.Linear(3136, 512) 15 | self.q = nn.Linear(512, int(env.single_action_space.n)) 16 | 17 | def 
forward(self, x): 18 | x = F.relu(self.conv1(x / 255.0)) 19 | x = F.relu(self.conv2(x)) 20 | x = F.relu(self.conv3(x)) 21 | x = torch.flatten(x, start_dim=1) 22 | x = F.relu(self.fc1(x)) 23 | x = self.q(x) 24 | return x 25 | 26 | 27 | def linear_schedule(start_e: float, end_e: float, duration: float, t: int): 28 | slope = (end_e - start_e) / duration 29 | return max(slope * t + start_e, end_e) 30 | -------------------------------------------------------------------------------- /src/benchmark.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simplified version of cleanRL's benchmarking script. 3 | The original version can be found here: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl_utils/benchmark.py 4 | """ 5 | 6 | import shlex 7 | import subprocess 8 | from dataclasses import dataclass 9 | 10 | import tyro 11 | 12 | 13 | @dataclass 14 | class BenchmarkConfig: 15 | env_ids: tuple[str] = ("CartPole-v1", "Acrobot-v1", "MountainCar-v0") 16 | command: str = "python redo_dqn.py" 17 | start_seed: int = 1 18 | num_seeds: int = 3 19 | workers: int = 3 20 | 21 | 22 | def run_experiment(command: str): 23 | command_list = shlex.split(command) 24 | print(f"running {command}") 25 | fd = subprocess.Popen(command_list) 26 | return_code = fd.wait() 27 | assert return_code == 0 28 | 29 | 30 | if __name__ == "__main__": 31 | args = tyro.cli(BenchmarkConfig) 32 | 33 | commands = [] 34 | for seed in range(0, args.num_seeds): 35 | for env_id in args.env_ids: 36 | commands += [" ".join([args.command, "--env_id", env_id, "--seed", str(args.start_seed + seed)])] 37 | 38 | print("======= commands to run:") 39 | for command in commands: 40 | print(command) 41 | 42 | if args.workers > 0: 43 | from concurrent.futures import ThreadPoolExecutor 44 | 45 | executor = ThreadPoolExecutor(max_workers=args.workers, thread_name_prefix="cleanrl-benchmark-worker-") 46 | for command in commands: 47 | executor.submit(run_experiment, command) 48 | executor.shutdown(wait=True) 49 | else: 50 | print("not running the experiments because --workers is set to 0; just printing the commands to run") 51 | -------------------------------------------------------------------------------- /src/buffer.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a simplified version of the stable-baselines3 replay buffer taken from 3 | https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/buffers.py 4 | 5 | I've removed unneeded functionality and put all dependencies into a single file. 6 | """ 7 | 8 | import warnings 9 | from abc import ABC, abstractmethod 10 | from typing import Any, Dict, NamedTuple, Union 11 | 12 | import numpy as np 13 | import psutil 14 | import torch 15 | from gymnasium import spaces 16 | 17 | 18 | class ReplayBufferSamples(NamedTuple): 19 | """Container for replay buffer samples.""" 20 | 21 | observations: torch.Tensor 22 | actions: torch.Tensor 23 | next_observations: torch.Tensor 24 | dones: torch.Tensor 25 | rewards: torch.Tensor 26 | 27 | 28 | def get_obs_shape(observation_space: spaces.Space) -> Union[tuple[int, ...], Dict[str, tuple[int, ...]]]: 29 | """ 30 | Get the shape of the observation (useful for the buffers). 
31 | :param observation_space: 32 | :return: 33 | """ 34 | if isinstance(observation_space, spaces.Box): 35 | return observation_space.shape 36 | elif isinstance(observation_space, spaces.Discrete): 37 | # Observation is an int 38 | return (1,) 39 | elif isinstance(observation_space, spaces.MultiDiscrete): 40 | # Number of discrete features 41 | return (int(len(observation_space.nvec)),) 42 | elif isinstance(observation_space, spaces.MultiBinary): 43 | # Number of binary features 44 | if type(observation_space.n) in [tuple, list, np.ndarray]: 45 | return tuple(observation_space.n) 46 | else: 47 | return (int(observation_space.n),) 48 | elif isinstance(observation_space, spaces.Dict): 49 | return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} # type: ignore[misc] 50 | else: 51 | raise NotImplementedError(f"{observation_space} observation space is not supported") 52 | 53 | 54 | def get_action_dim(action_space: spaces.Space) -> int: 55 | """ 56 | Get the dimension of the action space. 57 | :param action_space: 58 | :return: 59 | """ 60 | if isinstance(action_space, spaces.Box): 61 | return int(np.prod(action_space.shape)) 62 | elif isinstance(action_space, spaces.Discrete): 63 | # Action is an int 64 | return 1 65 | elif isinstance(action_space, spaces.MultiDiscrete): 66 | # Number of discrete actions 67 | return int(len(action_space.nvec)) 68 | elif isinstance(action_space, spaces.MultiBinary): 69 | # Number of binary actions 70 | return int(action_space.n) 71 | else: 72 | raise NotImplementedError(f"{action_space} action space is not supported") 73 | 74 | 75 | def get_device(device: Union[torch.device, str] = "auto") -> torch.device: 76 | """ 77 | Retrieve PyTorch device. 78 | It checks that the requested device is available first. 79 | For now, it supports only cpu and cuda. 80 | By default, it tries to use the gpu. 81 | :param device: One for 'auto', 'cuda', 'cpu' 82 | :return: Supported Pytorch device 83 | """ 84 | # Cuda by default 85 | if device == "auto": 86 | device = "cuda" 87 | # Force conversion to torch.device 88 | device = torch.device(device) 89 | 90 | # Cuda not available 91 | if device.type == torch.device("cuda").type and not torch.cuda.is_available(): 92 | return torch.device("cpu") 93 | 94 | return device 95 | 96 | 97 | class BaseBuffer(ABC): 98 | """ 99 | Base class that represent a buffer (rollout or replay) 100 | :param buffer_size: Max number of element in the buffer 101 | :param observation_space: Observation space 102 | :param action_space: Action space 103 | :param device: PyTorch device 104 | to which the values will be converted 105 | :param n_envs: Number of parallel environments 106 | """ 107 | 108 | def __init__( 109 | self, 110 | buffer_size: int, 111 | observation_space: spaces.Space, 112 | action_space: spaces.Space, 113 | device: Union[torch.device, str] = "auto", 114 | n_envs: int = 1, 115 | ): 116 | super().__init__() 117 | self.buffer_size = buffer_size 118 | self.observation_space = observation_space 119 | self.action_space = action_space 120 | self.obs_shape = get_obs_shape(observation_space) 121 | 122 | self.action_dim = get_action_dim(action_space) 123 | self.pos = 0 124 | self.full = False 125 | self.device = get_device(device) 126 | self.n_envs = n_envs 127 | 128 | @staticmethod 129 | def swap_and_flatten(arr: np.ndarray) -> np.ndarray: 130 | """ 131 | Swap and then flatten axes 0 (buffer_size) and 1 (n_envs) 132 | to convert shape from [n_steps, n_envs, ...] (when ... 
is the shape of the features) 133 | to [n_steps * n_envs, ...] (which maintain the order) 134 | :param arr: 135 | :return: 136 | """ 137 | shape = arr.shape 138 | if len(shape) < 3: 139 | shape = (*shape, 1) 140 | return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:]) 141 | 142 | def size(self) -> int: 143 | """ 144 | :return: The current size of the buffer 145 | """ 146 | if self.full: 147 | return self.buffer_size 148 | return self.pos 149 | 150 | def add(self, *args, **kwargs) -> None: 151 | """ 152 | Add elements to the buffer. 153 | """ 154 | raise NotImplementedError() 155 | 156 | def extend(self, *args, **kwargs) -> None: 157 | """ 158 | Add a new batch of transitions to the buffer 159 | """ 160 | # Do a for loop along the batch axis 161 | for data in zip(*args): 162 | self.add(*data) 163 | 164 | def reset(self) -> None: 165 | """ 166 | Reset the buffer. 167 | """ 168 | self.pos = 0 169 | self.full = False 170 | 171 | def sample(self, batch_size: int): 172 | """ 173 | :param batch_size: Number of element to sample 174 | :param env: associated gym VecEnv 175 | to normalize the observations/rewards when sampling 176 | :return: 177 | """ 178 | upper_bound = self.buffer_size if self.full else self.pos 179 | batch_inds = np.random.randint(0, upper_bound, size=batch_size) 180 | return self._get_samples(batch_inds) 181 | 182 | @abstractmethod 183 | def _get_samples(self, batch_inds: np.ndarray) -> ReplayBufferSamples: 184 | """ 185 | :param batch_inds: 186 | :param env: 187 | :return: 188 | """ 189 | raise NotImplementedError() 190 | 191 | def to_torch(self, array: np.ndarray, copy: bool = True) -> torch.Tensor: 192 | """ 193 | Convert a numpy array to a PyTorch tensor. 194 | Note: it copies the data by default 195 | :param array: 196 | :param copy: Whether to copy or not the data (may be useful to avoid changing things 197 | by reference). This argument is inoperative if the device is not the CPU. 198 | :return: 199 | """ 200 | if copy: 201 | return torch.tensor(array, device=self.device) 202 | return torch.as_tensor(array, device=self.device) 203 | 204 | 205 | class ReplayBuffer(BaseBuffer): 206 | """ 207 | Replay buffer used in off-policy algorithms like SAC/TD3. 208 | :param buffer_size: Max number of element in the buffer 209 | :param observation_space: Observation space 210 | :param action_space: Action space 211 | :param device: PyTorch device 212 | :param n_envs: Number of parallel environments 213 | :param optimize_memory_usage: Enable a memory efficient variant 214 | of the replay buffer which reduces by almost a factor two the memory used, 215 | at a cost of more complexity. 216 | See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 217 | and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274 218 | Cannot be used in combination with handle_timeout_termination. 219 | :param handle_timeout_termination: Handle timeout termination (due to timelimit) 220 | separately and treat the task as infinite horizon task. 
221 | https://github.com/DLR-RM/stable-baselines3/issues/284 222 | """ 223 | 224 | def __init__( 225 | self, 226 | buffer_size: int, 227 | observation_space: spaces.Space, 228 | action_space: spaces.Space, 229 | device: Union[torch.device, str] = "auto", 230 | n_envs: int = 1, 231 | optimize_memory_usage: bool = False, 232 | handle_timeout_termination: bool = True, 233 | ): 234 | super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) 235 | 236 | # Adjust buffer size 237 | self.buffer_size = max(buffer_size // n_envs, 1) 238 | 239 | # Check that the replay buffer can fit into the memory 240 | if psutil is not None: 241 | mem_available = psutil.virtual_memory().available 242 | 243 | # there is a bug if both optimize_memory_usage and handle_timeout_termination are true 244 | # see https://github.com/DLR-RM/stable-baselines3/issues/934 245 | if optimize_memory_usage and handle_timeout_termination: 246 | raise ValueError( 247 | "ReplayBuffer does not support optimize_memory_usage = True " 248 | "and handle_timeout_termination = True simultaneously." 249 | ) 250 | self.optimize_memory_usage = optimize_memory_usage 251 | 252 | self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype) 253 | 254 | if optimize_memory_usage: 255 | # `observations` contains also the next observation 256 | self.next_observations = None 257 | else: 258 | self.next_observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype) 259 | 260 | self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype) 261 | 262 | self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) 263 | self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) 264 | # Handle timeouts termination properly if needed 265 | # see https://github.com/DLR-RM/stable-baselines3/issues/284 266 | self.handle_timeout_termination = handle_timeout_termination 267 | self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) 268 | 269 | if psutil is not None: 270 | total_memory_usage = self.observations.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes 271 | 272 | if self.next_observations is not None: 273 | total_memory_usage += self.next_observations.nbytes 274 | 275 | if total_memory_usage > mem_available: 276 | # Convert to GB 277 | total_memory_usage /= 1e9 278 | mem_available /= 1e9 279 | warnings.warn( 280 | "This system does not have apparently enough memory to store the complete " 281 | f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB" 282 | ) 283 | 284 | def add( 285 | self, 286 | obs: np.ndarray, 287 | next_obs: np.ndarray, 288 | action: np.ndarray, 289 | reward: np.ndarray, 290 | done: np.ndarray, 291 | infos: list[Dict[str, Any]], 292 | ) -> None: 293 | # Reshape needed when using multiple envs with discrete observations 294 | # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) 295 | if isinstance(self.observation_space, spaces.Discrete): 296 | obs = obs.reshape((self.n_envs, *self.obs_shape)) 297 | next_obs = next_obs.reshape((self.n_envs, *self.obs_shape)) 298 | 299 | # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392 300 | action = action.reshape((self.n_envs, self.action_dim)) 301 | 302 | # Copy to avoid modification by reference 303 | self.observations[self.pos] = np.array(obs).copy() 304 | 305 | if self.optimize_memory_usage: 306 | self.observations[(self.pos + 1) % 
self.buffer_size] = np.array(next_obs).copy() 307 | else: 308 | self.next_observations[self.pos] = np.array(next_obs).copy() 309 | 310 | self.actions[self.pos] = np.array(action).copy() 311 | self.rewards[self.pos] = np.array(reward).copy() 312 | self.dones[self.pos] = np.array(done).copy() 313 | 314 | if self.handle_timeout_termination: 315 | self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos]) 316 | 317 | self.pos += 1 318 | if self.pos == self.buffer_size: 319 | self.full = True 320 | self.pos = 0 321 | 322 | def sample(self, batch_size: int) -> ReplayBufferSamples: 323 | """ 324 | Sample elements from the replay buffer. 325 | Custom sampling when using memory efficient variant, 326 | as we should not sample the element with index `self.pos` 327 | See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274 328 | :param batch_size: Number of element to sample 329 | :return: 330 | """ 331 | if not self.optimize_memory_usage: 332 | return super().sample(batch_size=batch_size) 333 | # Do not sample the element with index `self.pos` as the transitions is invalid 334 | # (we use only one array to store `obs` and `next_obs`) 335 | if self.full: 336 | batch_inds = (np.random.randint(1, self.buffer_size, size=batch_size) + self.pos) % self.buffer_size 337 | else: 338 | batch_inds = np.random.randint(0, self.pos, size=batch_size) 339 | return self._get_samples(batch_inds) 340 | 341 | def _get_samples(self, batch_inds: np.ndarray) -> ReplayBufferSamples: 342 | # Sample randomly the env idx 343 | env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),)) 344 | 345 | if self.optimize_memory_usage: 346 | next_obs = self.observations[(batch_inds + 1) % self.buffer_size, env_indices, :] 347 | else: 348 | next_obs = self.next_observations[batch_inds, env_indices, :] 349 | 350 | data = ( 351 | self.observations[batch_inds, env_indices, :], 352 | self.actions[batch_inds, env_indices, :], 353 | next_obs, 354 | # Only use dones that are not due to timeouts 355 | # deactivated by default (timeouts is initialized as an array of False) 356 | (self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(-1, 1), 357 | self.rewards[batch_inds, env_indices].reshape(-1, 1), 358 | ) 359 | return ReplayBufferSamples(*tuple(map(self.to_torch, data))) 360 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class Config: 6 | """Configuration for a ReDo DQN agent.""" 7 | 8 | # Experiment settings 9 | exp_name: str = "ReDo DQN" 10 | tags: tuple[str, ...] 
| str | None = None 11 | seed: int = 0 12 | torch_deterministic: bool = True 13 | gpu: int | None = 0 14 | track: bool = False 15 | wandb_project_name: str = "ReDo" 16 | wandb_entity: str | None = None 17 | capture_video: bool = False 18 | save_model: bool = False 19 | 20 | # Environment settings 21 | env_id: str = "DemonAttackNoFrameskip-v4" 22 | total_timesteps: int = 10_000_000 23 | num_envs: int = 1 24 | 25 | # DQN settings 26 | buffer_size: int = 1_000_000 27 | batch_size: int = 32 28 | learning_rate: float = 6.25 * 1e-5 # cleanRL default: 1e-4, theirs: 6.25 * 1e-5 29 | adam_eps: float = 1.5 * 1e-4 30 | use_lecun_init: bool = False # ReDO uses lecun_normal initializer, cleanRL uses the pytorch default (kaiming_uniform) 31 | gamma: float = 0.99 32 | tau: float = 1.0 33 | target_network_frequency: int = 8000 # cleanRL default: 8000, 4 freq -> 8000, 1 -> 2000 34 | start_e: float = 1.0 35 | end_e: float = 0.01 36 | exploration_fraction: float = 0.10 37 | learning_starts: int = 80_000 # cleanRL default: 80000, theirs 20000 38 | train_frequency: int = 4 # cleanRL default: 4, theirs 1 39 | 40 | # ReDo settings 41 | enable_redo: bool = False 42 | redo_tau: float = 0.025 # 0.025 for default, else 0.1 43 | redo_check_interval: int = 1000 44 | redo_bs: int = 64 45 | -------------------------------------------------------------------------------- /src/evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Taken from cleanRL: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl_utils/evals/dqn_eval.py. 3 | """ 4 | 5 | import random 6 | from typing import Callable 7 | 8 | import gymnasium as gym 9 | import numpy as np 10 | import torch 11 | 12 | 13 | def evaluate( 14 | model_path: str, 15 | make_env: Callable, 16 | env_id: str, 17 | eval_episodes: int, 18 | run_name: str, 19 | Model: torch.nn.Module, 20 | device: torch.device = torch.device("cpu"), 21 | epsilon: float = 0.05, 22 | capture_video: bool = False, 23 | ): 24 | envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)]) 25 | model = Model(envs).to(device) 26 | model.load_state_dict(torch.load(model_path, map_location=device)) 27 | model.eval() 28 | 29 | obs = envs.reset() 30 | episodic_returns = [] 31 | while len(episodic_returns) < eval_episodes: 32 | if random.random() < epsilon: 33 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) 34 | else: 35 | q_values = model(torch.Tensor(obs).to(device)) 36 | actions = torch.argmax(q_values, dim=1).cpu().numpy() 37 | next_obs, _, _, infos = envs.step(actions) 38 | for info in infos: 39 | if "episode" in info.keys(): 40 | print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}") 41 | episodic_returns += [info["episode"]["r"]] 42 | obs = next_obs 43 | 44 | return episodic_returns 45 | -------------------------------------------------------------------------------- /src/redo.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import optim 8 | 9 | from .agent import QNetwork 10 | from .buffer import ReplayBufferSamples 11 | 12 | 13 | @torch.no_grad() 14 | def _kaiming_uniform_reinit(layer: nn.Linear | nn.Conv2d, mask: torch.Tensor) -> None: 15 | """Partially re-initializes the bias of a layer according to the Kaiming uniform scheme.""" 16 | 17 | # This is adapted from 
https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_ 18 | fan_in = nn.init._calculate_correct_fan(tensor=layer.weight, mode="fan_in") 19 | gain = nn.init.calculate_gain(nonlinearity="relu", param=math.sqrt(5)) 20 | std = gain / math.sqrt(fan_in) 21 | # Calculate uniform bounds from standard deviation 22 | bound = math.sqrt(3.0) * std 23 | layer.weight.data[mask, ...] = torch.empty_like(layer.weight.data[mask, ...]).uniform_(-bound, bound) 24 | 25 | if layer.bias is not None: 26 | # The original code resets the bias to 0.0 because it uses a different initialization scheme 27 | # layer.bias.data[mask] = 0.0 28 | if isinstance(layer, nn.Conv2d): 29 | if fan_in != 0: 30 | bound = 1 / math.sqrt(fan_in) 31 | layer.bias.data[mask, ...] = torch.empty_like(layer.bias.data[mask, ...]).uniform_(-bound, bound) 32 | else: 33 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 34 | layer.bias.data[mask, ...] = torch.empty_like(layer.bias.data[mask, ...]).uniform_(-bound, bound) 35 | 36 | 37 | @torch.no_grad() 38 | def _lecun_normal_reinit(layer: nn.Linear | nn.Conv2d, mask: torch.Tensor) -> None: 39 | """Partially re-initializes the bias of a layer according to the Lecun normal scheme.""" 40 | 41 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(layer.weight) 42 | 43 | # This implementation follows the jax one 44 | # https://github.com/google/jax/blob/366a16f8ba59fe1ab59acede7efd160174134e01/jax/_src/nn/initializers.py#L260 45 | variance = 1.0 / fan_in 46 | stddev = math.sqrt(variance) / 0.87962566103423978 47 | layer.weight[mask] = nn.init._no_grad_trunc_normal_(layer.weight[mask], mean=0.0, std=1.0, a=-2.0, b=2.0) 48 | layer.weight[mask] *= stddev 49 | if layer.bias is not None: 50 | layer.bias.data[mask] = 0.0 51 | 52 | 53 | @torch.inference_mode() 54 | def _get_activation(name: str, activations: dict[str, torch.Tensor]): 55 | """Fetches and stores the activations of a network layer.""" 56 | 57 | def hook(layer: nn.Linear | nn.Conv2d, input: tuple[torch.Tensor], output: torch.Tensor) -> None: 58 | """ 59 | Get the activations of a layer with relu nonlinearity. 60 | ReLU has to be called explicitly here because the hook is attached to the conv/linear layer. 61 | """ 62 | activations[name] = F.relu(output) 63 | 64 | return hook 65 | 66 | 67 | @torch.inference_mode() 68 | def _get_redo_masks(activations: dict[str, torch.Tensor], tau: float) -> torch.Tensor: 69 | """ 70 | Computes the ReDo mask for a given set of activations. 71 | The returned mask has True where neurons are dormant and False where they are active. 
72 | """ 73 | masks = [] 74 | 75 | # Last activation are the q-values, which are never reset 76 | for name, activation in list(activations.items())[:-1]: 77 | # Taking the mean here conforms to the expectation under D in the main paper's formula 78 | if activation.ndim == 4: 79 | # Conv layer 80 | score = activation.abs().mean(dim=(0, 2, 3)) 81 | else: 82 | # Linear layer 83 | score = activation.abs().mean(dim=0) 84 | 85 | # Divide by activation mean to make the threshold independent of the layer size 86 | # see https://github.com/google/dopamine/blob/ce36aab6528b26a699f5f1cefd330fdaf23a5d72/dopamine/labs/redo/weight_recyclers.py#L314 87 | # https://github.com/google/dopamine/issues/209 88 | normalized_score = score / (score.mean() + 1e-9) 89 | 90 | layer_mask = torch.zeros_like(normalized_score, dtype=torch.bool) 91 | if tau > 0.0: 92 | layer_mask[normalized_score <= tau] = 1 93 | else: 94 | layer_mask[torch.isclose(normalized_score, torch.zeros_like(normalized_score))] = 1 95 | masks.append(layer_mask) 96 | return masks 97 | 98 | 99 | @torch.no_grad() 100 | def _reset_dormant_neurons(model: QNetwork, redo_masks: torch.Tensor, use_lecun_init: bool) -> QNetwork: 101 | """Re-initializes the dormant neurons of a model.""" 102 | 103 | layers = [(name, layer) for name, layer in list(model.named_modules())[1:]] 104 | assert len(redo_masks) == len(layers) - 1, "Number of masks must match the number of layers" 105 | 106 | # Reset the ingoing weights 107 | # Here the mask size always matches the layer weight size 108 | for i in range(len(layers[:-1])): 109 | mask = redo_masks[i] 110 | layer = layers[i][1] 111 | next_layer = layers[i + 1][1] 112 | # Can be used to not reset outgoing weights in the Q-function 113 | next_layer_name = layers[i + 1][0] 114 | 115 | # Skip if there are no dead neurons 116 | if torch.all(~mask): 117 | # No dormant neurons in this layer 118 | continue 119 | 120 | # The initialization scheme is the same for conv2d and linear 121 | # 1. Reset the ingoing weights using the initialization distribution 122 | if use_lecun_init: 123 | _lecun_normal_reinit(layer, mask) 124 | else: 125 | _kaiming_uniform_reinit(layer, mask) 126 | 127 | # 2. Reset the outgoing weights to 0 128 | # NOTE: Don't reset the bias for the following layer or else you will create new dormant neurons 129 | # To not reset in the last layer: and not next_layer_name == 'q' 130 | if isinstance(layer, nn.Conv2d) and isinstance(next_layer, nn.Linear): 131 | # Special case: Transition from conv to linear layer 132 | # Reset the outgoing weights to 0 with a mask created from the conv filters 133 | num_repeatition = next_layer.weight.data.shape[1] // mask.shape[0] 134 | linear_mask = torch.repeat_interleave(mask, num_repeatition) 135 | next_layer.weight.data[:, linear_mask] = 0.0 136 | else: 137 | # Standard case: layer and next_layer are both conv or both linear 138 | # Reset the outgoing weights to 0 139 | next_layer.weight.data[:, mask, ...] = 0.0 140 | 141 | return model 142 | 143 | 144 | @torch.no_grad() 145 | def _reset_adam_moments(optimizer: optim.Adam, reset_masks: dict[str, torch.Tensor]) -> optim.Adam: 146 | """Resets the moments of the Adam optimizer for the dormant neurons.""" 147 | 148 | assert isinstance(optimizer, optim.Adam), "Moment resetting currently only supported for Adam optimizer" 149 | for i, mask in enumerate(reset_masks): 150 | # Reset the moments for the weights 151 | optimizer.state_dict()["state"][i * 2]["exp_avg"][mask, ...] 
= 0.0 152 | optimizer.state_dict()["state"][i * 2]["exp_avg_sq"][mask, ...] = 0.0 153 | # NOTE: Step count resets are key to the algorithm's performance 154 | # It's possible to just reset the step for moment that's being reset 155 | optimizer.state_dict()["state"][i * 2]["step"].zero_() 156 | 157 | # Reset the moments for the bias 158 | optimizer.state_dict()["state"][i * 2 + 1]["exp_avg"][mask] = 0.0 159 | optimizer.state_dict()["state"][i * 2 + 1]["exp_avg_sq"][mask] = 0.0 160 | optimizer.state_dict()["state"][i * 2 + 1]["step"].zero_() 161 | 162 | # Reset the moments for the output weights 163 | if ( 164 | len(optimizer.state_dict()["state"][i * 2]["exp_avg"].shape) == 4 165 | and len(optimizer.state_dict()["state"][i * 2 + 2]["exp_avg"].shape) == 2 166 | ): 167 | # Catch transition from conv to linear layer through moment shapes 168 | num_repeatition = optimizer.state_dict()["state"][i * 2 + 2]["exp_avg"].shape[1] // mask.shape[0] 169 | linear_mask = torch.repeat_interleave(mask, num_repeatition) 170 | optimizer.state_dict()["state"][i * 2 + 2]["exp_avg"][:, linear_mask] = 0.0 171 | optimizer.state_dict()["state"][i * 2 + 2]["exp_avg_sq"][:, linear_mask] = 0.0 172 | optimizer.state_dict()["state"][i * 2 + 2]["step"].zero_() 173 | else: 174 | # Standard case: layer and next_layer are both conv or both linear 175 | # Reset the outgoing weights to 0 176 | optimizer.state_dict()["state"][i * 2 + 2]["exp_avg"][:, mask, ...] = 0.0 177 | optimizer.state_dict()["state"][i * 2 + 2]["exp_avg_sq"][:, mask, ...] = 0.0 178 | optimizer.state_dict()["state"][i * 2 + 2]["step"].zero_() 179 | 180 | return optimizer 181 | 182 | 183 | @torch.inference_mode() 184 | def run_redo( 185 | obs: torch.Tensor, 186 | model: QNetwork, 187 | optimizer: optim.Adam, 188 | tau: float, 189 | re_initialize: bool, 190 | use_lecun_init: bool, 191 | ) -> tuple[nn.Module, optim.Adam, float, int]: 192 | """ 193 | Checks the number of dormant neurons for a given model. 194 | If re_initialize is True, then the dormant neurons are re-initialized according to the scheme in 195 | https://arxiv.org/abs/2302.12902 196 | 197 | Returns the number of dormant neurons. 
198 | """ 199 | 200 | activations = {} 201 | activation_getter = partial(_get_activation, activations=activations) 202 | 203 | # Register hooks for all Conv2d and Linear layers to calculate activations 204 | handles = [] 205 | for name, module in model.named_modules(): 206 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear): 207 | handles.append(module.register_forward_hook(activation_getter(name))) 208 | 209 | # Calculate activations 210 | _ = model(obs) 211 | 212 | # Masks for tau=0 logging 213 | zero_masks = _get_redo_masks(activations, 0.0) 214 | total_neurons = sum([torch.numel(mask) for mask in zero_masks]) 215 | zero_count = sum([torch.sum(mask) for mask in zero_masks]) 216 | zero_fraction = (zero_count / total_neurons) * 100 217 | 218 | # Calculate the masks actually used for resetting 219 | masks = _get_redo_masks(activations, tau) 220 | dormant_count = sum([torch.sum(mask) for mask in masks]) 221 | dormant_fraction = (dormant_count / sum([torch.numel(mask) for mask in masks])) * 100 222 | 223 | # Re-initialize the dormant neurons and reset the Adam moments 224 | if re_initialize: 225 | print("Re-initializing dormant neurons") 226 | print(f"Total neurons: {total_neurons} | Dormant neurons: {dormant_count} | Dormant fraction: {dormant_fraction:.2f}%") 227 | model = _reset_dormant_neurons(model, masks, use_lecun_init) 228 | optimizer = _reset_adam_moments(optimizer, masks) 229 | 230 | # Remove the hooks again 231 | for handle in handles: 232 | handle.remove() 233 | 234 | return { 235 | "model": model, 236 | "optimizer": optimizer, 237 | "zero_fraction": zero_fraction, 238 | "zero_count": zero_count, 239 | "dormant_fraction": dormant_fraction, 240 | "dormant_count": dormant_count, 241 | } 242 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import typing 3 | 4 | import gymnasium as gym 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .wrappers import ( 9 | ClipRewardEnv, 10 | EpisodicLifeEnv, 11 | FireResetEnv, 12 | MaxAndSkipEnv, 13 | NoopResetEnv, 14 | ) 15 | 16 | 17 | def make_env(env_id, seed, idx, capture_video, run_name): 18 | """Helper function to create an environment with some standard wrappers. 19 | 20 | Taken from cleanRL's DQN Atari implementation: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari.py. 21 | """ 22 | 23 | def thunk(): 24 | if capture_video and idx == 0: 25 | env = gym.make(env_id, render_mode="rgb_array") 26 | env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") 27 | else: 28 | env = gym.make(env_id) 29 | env = gym.wrappers.RecordEpisodeStatistics(env) 30 | env = NoopResetEnv(env, noop_max=30) 31 | env = MaxAndSkipEnv(env, skip=4) 32 | env = EpisodicLifeEnv(env) 33 | if "FIRE" in env.unwrapped.get_action_meanings(): 34 | env = FireResetEnv(env) 35 | env = ClipRewardEnv(env) 36 | env = gym.wrappers.ResizeObservation(env, (84, 84)) 37 | env = gym.wrappers.GrayScaleObservation(env) 38 | env = gym.wrappers.FrameStack(env, 4) 39 | env.action_space.seed(seed) 40 | 41 | return env 42 | 43 | return thunk 44 | 45 | 46 | def set_cuda_configuration(gpu: typing.Any) -> torch.device: 47 | """Set up the device for the desired GPU or all GPUs.""" 48 | 49 | if gpu is None or gpu == -1 or gpu is False: 50 | device = torch.device("cpu") 51 | elif isinstance(gpu, int): 52 | assert gpu <= torch.cuda.device_count(), "Invalid CUDA index specified." 
53 | device = torch.device(f"cuda:{gpu}") 54 | else: 55 | device = torch.device("cuda") 56 | 57 | return device 58 | 59 | 60 | @torch.no_grad() 61 | def lecun_normal_initializer(layer: nn.Module) -> None: 62 | """ 63 | Initialization according to LeCun et al. (1998). 64 | See here https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.initializers.lecun_normal.html 65 | and here https://github.com/google/jax/blob/366a16f8ba59fe1ab59acede7efd160174134e01/jax/_src/nn/initializers.py#L460 . 66 | Initializes bias terms to 0. 67 | """ 68 | 69 | # Catch case where the whole network is passed 70 | if not isinstance(layer, nn.Linear | nn.Conv2d): 71 | return 72 | 73 | # For a conv layer, this is num_channels*kernel_height*kernel_width 74 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(layer.weight) 75 | 76 | # This implementation follows the jax one 77 | # https://github.com/google/jax/blob/366a16f8ba59fe1ab59acede7efd160174134e01/jax/_src/nn/initializers.py#L260 78 | variance = 1.0 / fan_in 79 | stddev = math.sqrt(variance) / 0.87962566103423978 80 | torch.nn.init.trunc_normal_(layer.weight) 81 | layer.weight *= stddev 82 | if layer.bias is not None: 83 | torch.nn.init.zeros_(layer.bias) 84 | -------------------------------------------------------------------------------- /src/wrappers.py: -------------------------------------------------------------------------------- 1 | """ 2 | The wrappers are taken from stable-baselines3 with unnecessary ones removed. 3 | https://github.com/DLR-RM/stable-baselines3/blob/feat/gymnasium-support/stable_baselines3/common/type_aliases.py 4 | """ 5 | 6 | from typing import Any, SupportsFloat, Union 7 | 8 | import gymnasium as gym 9 | import numpy as np 10 | 11 | GymObs = Union[tuple, dict[str, Any], np.ndarray, int] 12 | GymStepReturn = tuple[GymObs, float, bool, dict] 13 | 14 | 15 | AtariResetReturn = tuple[np.ndarray, dict[str, Any]] 16 | AtariStepReturn = tuple[np.ndarray, SupportsFloat, bool, bool, dict[str, Any]] 17 | 18 | 19 | class StickyActionEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]): 20 | """ 21 | Sticky action. 22 | Paper: https://arxiv.org/abs/1709.06009 23 | Official implementation: https://github.com/mgbellemare/Arcade-Learning-Environment 24 | :param env: Environment to wrap 25 | :param action_repeat_probability: Probability of repeating the last action 26 | """ 27 | 28 | def __init__(self, env: gym.Env, action_repeat_probability: float) -> None: 29 | super().__init__(env) 30 | self.action_repeat_probability = action_repeat_probability 31 | assert env.unwrapped.get_action_meanings()[0] == "NOOP" # type: ignore[attr-defined] 32 | 33 | def reset(self, **kwargs) -> AtariResetReturn: 34 | self._sticky_action = 0 # NOOP 35 | return self.env.reset(**kwargs) 36 | 37 | def step(self, action: int) -> AtariStepReturn: 38 | if self.np_random.random() >= self.action_repeat_probability: 39 | self._sticky_action = action 40 | return self.env.step(self._sticky_action) 41 | 42 | 43 | class NoopResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]): 44 | """ 45 | Sample initial states by taking random number of no-ops on reset. 46 | No-op is assumed to be action 0. 
47 | :param env: Environment to wrap 48 | :param noop_max: Maximum value of no-ops to run 49 | """ 50 | 51 | def __init__(self, env: gym.Env, noop_max: int = 30) -> None: 52 | super().__init__(env) 53 | self.noop_max = noop_max 54 | self.override_num_noops = None 55 | self.noop_action = 0 56 | assert env.unwrapped.get_action_meanings()[0] == "NOOP" # type: ignore[attr-defined] 57 | 58 | def reset(self, **kwargs) -> AtariResetReturn: 59 | self.env.reset(**kwargs) 60 | if self.override_num_noops is not None: 61 | noops = self.override_num_noops 62 | else: 63 | noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) 64 | assert noops > 0 65 | obs = np.zeros(0) 66 | info: dict = {} 67 | for _ in range(noops): 68 | obs, _, terminated, truncated, info = self.env.step(self.noop_action) 69 | if terminated or truncated: 70 | obs, info = self.env.reset(**kwargs) 71 | return obs, info 72 | 73 | 74 | class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]): 75 | """ 76 | Take action on reset for environments that are fixed until firing. 77 | :param env: Environment to wrap 78 | """ 79 | 80 | def __init__(self, env: gym.Env) -> None: 81 | super().__init__(env) 82 | assert env.unwrapped.get_action_meanings()[1] == "FIRE" # type: ignore[attr-defined] 83 | assert len(env.unwrapped.get_action_meanings()) >= 3 # type: ignore[attr-defined] 84 | 85 | def reset(self, **kwargs) -> AtariResetReturn: 86 | self.env.reset(**kwargs) 87 | obs, _, terminated, truncated, _ = self.env.step(1) 88 | if terminated or truncated: 89 | self.env.reset(**kwargs) 90 | obs, _, terminated, truncated, _ = self.env.step(2) 91 | if terminated or truncated: 92 | self.env.reset(**kwargs) 93 | return obs, {} 94 | 95 | 96 | class EpisodicLifeEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]): 97 | """ 98 | Make end-of-life == end-of-episode, but only reset on true game over. 99 | Done by DeepMind for the DQN and co. since it helps value estimation. 100 | :param env: Environment to wrap 101 | """ 102 | 103 | def __init__(self, env: gym.Env) -> None: 104 | super().__init__(env) 105 | self.lives = 0 106 | self.was_real_done = True 107 | 108 | def step(self, action: int) -> AtariStepReturn: 109 | obs, reward, terminated, truncated, info = self.env.step(action) 110 | self.was_real_done = terminated or truncated 111 | # check current lives, make loss of life terminal, 112 | # then update lives to handle bonus lives 113 | lives = self.env.unwrapped.ale.lives() # type: ignore[attr-defined] 114 | if 0 < lives < self.lives: 115 | # for Qbert sometimes we stay in lives == 0 condition for a few frames 116 | # so its important to keep lives > 0, so that we only reset once 117 | # the environment advertises done. 118 | terminated = True 119 | self.lives = lives 120 | return obs, reward, terminated, truncated, info 121 | 122 | def reset(self, **kwargs) -> AtariResetReturn: 123 | """ 124 | Calls the Gym environment reset, only when lives are exhausted. 125 | This way all states are still reachable even though lives are episodic, 126 | and the learner need not know about any of this behind-the-scenes. 
127 | :param kwargs: Extra keywords passed to env.reset() call 128 | :return: the first observation of the environment 129 | """ 130 | if self.was_real_done: 131 | obs, info = self.env.reset(**kwargs) 132 | else: 133 | # no-op step to advance from terminal/lost life state 134 | obs, _, terminated, truncated, info = self.env.step(0) 135 | 136 | # The no-op step can lead to a game over, so we need to check it again 137 | # to see if we should reset the environment and avoid the 138 | # monitor.py `RuntimeError: Tried to step environment that needs reset` 139 | if terminated or truncated: 140 | obs, info = self.env.reset(**kwargs) 141 | self.lives = self.env.unwrapped.ale.lives() # type: ignore[attr-defined] 142 | return obs, info 143 | 144 | 145 | class MaxAndSkipEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]): 146 | """ 147 | Return only every ``skip``-th frame (frameskipping) 148 | and return the max between the two last frames. 149 | :param env: Environment to wrap 150 | :param skip: Number of ``skip``-th frame 151 | The same action will be taken ``skip`` times. 152 | """ 153 | 154 | def __init__(self, env: gym.Env, skip: int = 4) -> None: 155 | super().__init__(env) 156 | # most recent raw observations (for max pooling across time steps) 157 | assert env.observation_space.dtype is not None, "No dtype specified for the observation space" 158 | assert env.observation_space.shape is not None, "No shape defined for the observation space" 159 | self._obs_buffer = np.zeros((2, *env.observation_space.shape), dtype=env.observation_space.dtype) 160 | self._skip = skip 161 | 162 | def step(self, action: int) -> AtariStepReturn: 163 | """ 164 | Step the environment with the given action 165 | Repeat action, sum reward, and max over last observations. 166 | :param action: the action 167 | :return: observation, reward, terminated, truncated, information 168 | """ 169 | total_reward = 0.0 170 | terminated = truncated = False 171 | for i in range(self._skip): 172 | obs, reward, terminated, truncated, info = self.env.step(action) 173 | done = terminated or truncated 174 | if i == self._skip - 2: 175 | self._obs_buffer[0] = obs 176 | if i == self._skip - 1: 177 | self._obs_buffer[1] = obs 178 | total_reward += float(reward) 179 | if done: 180 | break 181 | # Note that the observation on the done=True frame 182 | # doesn't matter 183 | max_frame = self._obs_buffer.max(axis=0) 184 | 185 | return max_frame, total_reward, terminated, truncated, info 186 | 187 | 188 | class ClipRewardEnv(gym.RewardWrapper): 189 | """ 190 | Clip the reward to {+1, 0, -1} by its sign. 191 | :param env: Environment to wrap 192 | """ 193 | 194 | def __init__(self, env: gym.Env) -> None: 195 | super().__init__(env) 196 | 197 | def reward(self, reward: SupportsFloat) -> float: 198 | """ 199 | Bin reward to {+1, 0, -1} by its sign. 200 | :param reward: 201 | :return: 202 | """ 203 | return np.sign(float(reward)) 204 | --------------------------------------------------------------------------------
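For reference, the wrappers above are composed in `make_env` in `src/utils.py`. A minimal sketch of the equivalent preprocessing pipeline (assuming the Atari ROMs have been installed, e.g. via the `autorom` packages pinned in `env.yaml`) could look like this; it is an illustration, not the repository's entry point:

```python
import gymnasium as gym

from src.wrappers import ClipRewardEnv, EpisodicLifeEnv, FireResetEnv, MaxAndSkipEnv, NoopResetEnv

env = gym.make("DemonAttackNoFrameskip-v4")
env = gym.wrappers.RecordEpisodeStatistics(env)
env = NoopResetEnv(env, noop_max=30)  # random number of no-ops on reset
env = MaxAndSkipEnv(env, skip=4)      # frame skipping, max-pooling over the last two frames
env = EpisodicLifeEnv(env)            # a lost life ends the episode for the learner
if "FIRE" in env.unwrapped.get_action_meanings():
    env = FireResetEnv(env)           # press FIRE after reset where the game requires it
env = ClipRewardEnv(env)              # clip rewards to {-1, 0, +1}
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
```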