├── .gitignore
├── README.md
├── env.yaml
├── img
│   ├── ReDo_algorithm_pseudocode.png
│   ├── redo_episodic_returns.png
│   ├── redo_tau_0_0_dormant_fraction.png
│   └── redo_tau_0_1_dormant_fraction.png
├── redo.sh
├── redo_dqn.py
└── src
    ├── agent.py
    ├── benchmark.py
    ├── buffer.py
    ├── config.py
    ├── evaluate.py
    ├── redo.py
    ├── utils.py
    └── wrappers.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Experiment tracking
2 | wandb/
3 | log/
4 | runs/
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *.pyc
10 |
11 | # package
12 | eggs/
13 | develop-eggs/
14 | *.egg-info/
15 | *.egg
16 |
17 | # editor
18 | .idea/
19 | .vscode/
20 | .ipynb_checkpoints/
21 |
22 | # large files
23 | data/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Recycling dormant neurons
2 |
3 | PyTorch reimplementation of [ReDo](https://arxiv.org/abs/2302.12902) (The Dormant Neuron Phenomenon in Deep Reinforcement Learning).
4 | The paper establishes the _dormant neuron phenomenon_: over the course of training a network on nonstationary targets, a significant portion of the neurons in a deep network become _dormant_, i.e. their activations become minimal compared to the other neurons in the layer.
5 | This phenomenon is particularly prevalent in value-based deep reinforcement learning algorithms, such as DQN and its variants. As a solution, the authors propose to periodically check for dormant neurons and reinitialize them.
6 |
7 | ## Dormant neurons
8 |
9 | The score $s_i^{\ell}$ of a neuron $i$ in layer $\ell$ is defined as its expected absolute activation $`\mathbb{E}_{x \in D} |h_i^{\ell}(x)|`$ divided by the average expected absolute activation within the layer, $`\frac{1}{H^{\ell}} \sum_{k \in h} \mathbb{E}_{x \in D}|h_k^{\ell}(x)|`$:
10 |
11 | $$`s_i^{\ell}=\frac{\mathbb{E}_{x \in D}|h_i^{\ell}(x)|}{\frac{1}{H^{\ell}} \sum_{k \in h} \mathbb{E}_{x \in D}|h_k^{\ell}(x)|}`$$
12 |
13 | A neuron is defined as $\tau$-dormant when $s_i^{\ell} \leq \tau$.
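
A minimal sketch of this score computation, mirroring what `_get_redo_masks` in `src/redo.py` does per layer (the helper name below is only for illustration and is not part of the repo):

```python
import torch


def dormancy_mask(activations: torch.Tensor, tau: float) -> torch.Tensor:
    """Boolean mask of tau-dormant neurons for one layer.

    `activations` are post-ReLU activations collected over a batch D, shaped
    (batch, neurons) for a linear layer or (batch, channels, H, W) for a conv layer.
    """
    # E_{x in D} |h_i(x)|: mean absolute activation per neuron / channel
    dims = (0,) if activations.ndim == 2 else (0, 2, 3)
    score = activations.abs().mean(dim=dims)
    # Normalize by the layer average so tau is comparable across layers
    normalized_score = score / (score.mean() + 1e-9)
    return normalized_score <= tau
```

For $\tau = 0$, `src/redo.py` instead compares the normalized score against zero with `torch.isclose`, presumably to be robust to floating-point noise.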
14 |
15 | ## ReDo
16 |
17 | ![ReDo algorithm pseudocode](img/ReDo_algorithm_pseudocode.png)
18 |
19 | Every $F$-th time step:
20 |
21 | 1. Check whether a neuron $i$ is $\tau$-dormant.
22 | 2. If a neuron $i$ is $\tau$-dormant:
23 | **Re-initialize the incoming weights and bias** of $i$.
24 | _Set_ the **outgoing weights** of $i$ to $0$.
25 |
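In code, the two steps above amount to the following minimal sketch for a pair of fully-connected layers (the helper is illustrative only; the actual implementation in `src/redo.py` also covers conv layers, the conv-to-linear transition, and additionally resets the Adam moments):

```python
import torch
import torch.nn as nn


@torch.no_grad()
def recycle_dormant(layer: nn.Linear, next_layer: nn.Linear, mask: torch.Tensor) -> None:
    """Recycle the dormant neurons of `layer` given a boolean `mask` of shape (out_features,)."""
    # 1. Re-initialize the incoming weights and bias of dormant neurons
    #    (here via a freshly constructed layer, i.e. PyTorch's default init)
    fresh = nn.Linear(layer.in_features, layer.out_features)
    layer.weight.data[mask] = fresh.weight.data[mask]
    layer.bias.data[mask] = fresh.bias.data[mask]
    # 2. Zero the outgoing weights so the recycled neurons initially have no effect
    next_layer.weight.data[:, mask] = 0.0
```

Zeroing the outgoing weights means the recycled neurons initially do not change the network's predictions; they are gradually re-integrated as their outgoing weights receive gradient updates.
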
26 | ## Results
27 |
28 | These results were generated using 3 seeds on DemonAttack-v4. Note that I was not using the typical DQN hyperparameters, but instead chose a hyperparameter set that exaggerates the dormant neuron phenomenon.
29 | In particular:
30 |
31 | - Updates are done every environment step instead of every 4 steps.
32 | - Target network updates every 2000 steps instead of every 8000.
33 | - Fewer random samples before learning starts.
34 | - $\tau=0.1$ instead of $\tau=0.025$.
35 |
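For illustration, the exaggerated setting roughly corresponds to the following `Config` values (the `learning_starts` value here is only a guess based on the comments in `src/config.py`; in practice the fields are passed on the command line via `tyro`, cf. `redo.sh`):

```python
from src.config import Config

cfg = Config(
    env_id="DemonAttack-v4",
    enable_redo=True,
    redo_tau=0.1,                   # instead of the paper default 0.025
    train_frequency=1,              # update every environment step
    target_network_frequency=2000,  # instead of 8000
    learning_starts=20_000,         # fewer random samples before learning starts
)
```
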
36 | #### Episodic Return
37 |
38 | ![Episodic returns](img/redo_episodic_returns.png)
39 |
40 | #### Dormant count $\tau=0.0$
41 |
42 | ![Dormant fraction at tau=0.0](img/redo_tau_0_0_dormant_fraction.png)
43 |
44 | #### Dormant count $\tau=0.1$
45 |
46 | ![Dormant fraction at tau=0.1](img/redo_tau_0_1_dormant_fraction.png)
47 |
48 |
49 | I've skipped the 10M- and 100M-step experiments because they are very expensive in terms of compute.
50 |
51 | ## Implementation progress
52 |
53 | Update 1:
54 | Fixed and simplified the for-loop in the redo resets.
55 |
56 | Update 2:
57 | The reset check in the main function was at the wrong indentation level; the re-initializations are now properly done in-place and work.
58 |
59 | Update 3:
60 | Resetting the step count of the Adam moments is crucial for performance; otherwise the Adam updates immediately create dead neurons again.
61 | Preliminary results now look promising.
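
For reference, a condensed sketch of what `_reset_adam_moments` in `src/redo.py` does for a single weight tensor (the helper is illustrative; the full version also handles biases and the moments of the outgoing weights):

```python
import torch
from torch import optim


@torch.no_grad()
def reset_adam_state(optimizer: optim.Adam, param: torch.nn.Parameter, mask: torch.Tensor) -> None:
    """Zero Adam's moments for the recycled rows of `param` and reset its step count.

    Assumes at least one optimizer step has been taken so the state exists.
    """
    state = optimizer.state[param]
    state["exp_avg"][mask] = 0.0     # first moment
    state["exp_avg_sq"][mask] = 0.0  # second moment
    state["step"].zero_()            # step count used for bias correction
```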
62 |
63 | Update 4:
64 | Fixed the outgoing weight resets: the mask was generated incorrectly and was not applied to the outgoing weights. See [this issue](https://github.com/timoklein/redo/issues/3). Thanks @SaminYeasar!
65 |
66 | ## Citations
67 |
68 | Paper:
69 |
70 | ```bibtex
71 | @inproceedings{sokar2023dormant,
72 | title={The dormant neuron phenomenon in deep reinforcement learning},
73 | author={Sokar, Ghada and Agarwal, Rishabh and Castro, Pablo Samuel and Evci, Utku},
74 | booktitle={International Conference on Machine Learning},
75 | pages={32145--32168},
76 | year={2023},
77 | organization={PMLR}
78 | }
79 | ```
80 |
81 | Training code is based on [cleanRL](https://github.com/vwxyzjn/cleanrl):
82 |
83 | ```bibtex
84 | @article{huang2022cleanrl,
85 | author = {Shengyi Huang and Rousslan Fernand Julien Dossa and Chang Ye and Jeff Braga and Dipam Chakraborty and Kinal Mehta and João G.M. Araújo},
86 | title = {CleanRL: High-quality Single-file Implementations of Deep Reinforcement Learning Algorithms},
87 | journal = {Journal of Machine Learning Research},
88 | year = {2022},
89 | volume = {23},
90 | number = {274},
91 | pages = {1--18},
92 | url = {http://jmlr.org/papers/v23/21-1342.html}
93 | }
94 | ```
95 |
96 | Replay buffer and wrappers are from [Stable Baselines 3](https://github.com/DLR-RM/stable-baselines3):
97 |
98 | ```bibtex
99 | @misc{raffin2019stable,
100 | title={Stable baselines3},
101 | author={Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah},
102 | year={2019}
103 | }
104 | ```
105 |
--------------------------------------------------------------------------------
/env.yaml:
--------------------------------------------------------------------------------
1 | name: env
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | dependencies:
7 | - _libgcc_mutex=0.1=conda_forge
8 | - _openmp_mutex=4.5=2_kmp_llvm
9 | - appdirs=1.4.4=pyh9f0ad1d_0
10 | - blas=2.116=mkl
11 | - blas-devel=3.9.0=16_linux64_mkl
12 | - brotli-python=1.0.9=py310hd8f1fbe_9
13 | - bzip2=1.0.8=h7f98852_4
14 | - ca-certificates=2023.5.7=hbcca054_0
15 | - certifi=2023.5.7=pyhd8ed1ab_0
16 | - charset-normalizer=3.2.0=pyhd8ed1ab_0
17 | - click=8.1.4=unix_pyh707e725_0
18 | - cuda-cudart=11.7.99=0
19 | - cuda-cupti=11.7.101=0
20 | - cuda-libraries=11.7.1=0
21 | - cuda-nvrtc=11.7.99=0
22 | - cuda-nvtx=11.7.91=0
23 | - cuda-runtime=11.7.1=0
24 | - docker-pycreds=0.4.0=py_0
25 | - ffmpeg=4.3=hf484d3e_0
26 | - filelock=3.12.2=pyhd8ed1ab_0
27 | - freetype=2.12.1=hca18f0e_1
28 | - gitdb=4.0.10=pyhd8ed1ab_0
29 | - gitpython=3.1.31=pyhd8ed1ab_0
30 | - gmp=6.2.1=h58526e2_0
31 | - gmpy2=2.1.2=py310h3ec546c_1
32 | - gnutls=3.6.13=h85f3911_1
33 | - icu=72.1=hcb278e6_0
34 | - idna=3.4=pyhd8ed1ab_0
35 | - jinja2=3.1.2=pyhd8ed1ab_1
36 | - jpeg=9e=h0b41bf4_3
37 | - lame=3.100=h166bdaf_1003
38 | - lcms2=2.15=hfd0df8a_0
39 | - ld_impl_linux-64=2.40=h41732ed_0
40 | - lerc=4.0.0=h27087fc_0
41 | - libabseil=20230125.3=cxx17_h59595ed_0
42 | - libblas=3.9.0=16_linux64_mkl
43 | - libcblas=3.9.0=16_linux64_mkl
44 | - libcublas=11.10.3.66=0
45 | - libcufft=10.7.2.124=h4fbf590_0
46 | - libcufile=1.7.0.149=0
47 | - libcurand=10.3.3.53=0
48 | - libcusolver=11.4.0.1=0
49 | - libcusparse=11.7.4.91=0
50 | - libdeflate=1.17=h0b41bf4_0
51 | - libffi=3.4.2=h7f98852_5
52 | - libgcc-ng=13.1.0=he5830b7_0
53 | - libgfortran-ng=13.1.0=h69a702a_0
54 | - libgfortran5=13.1.0=h15d22d2_0
55 | - libgomp=13.1.0=he5830b7_0
56 | - libhwloc=2.9.1=nocuda_h7313eea_6
57 | - libiconv=1.17=h166bdaf_0
58 | - liblapack=3.9.0=16_linux64_mkl
59 | - liblapacke=3.9.0=16_linux64_mkl
60 | - libnpp=11.7.4.75=0
61 | - libnsl=2.0.0=h7f98852_0
62 | - libnvjpeg=11.8.0.2=0
63 | - libpng=1.6.39=h753d276_0
64 | - libprotobuf=4.23.3=hd1fb520_0
65 | - libsqlite=3.42.0=h2797004_0
66 | - libstdcxx-ng=13.1.0=hfd8a6a1_0
67 | - libtiff=4.5.0=h6adf6a1_2
68 | - libuuid=2.38.1=h0b41bf4_0
69 | - libwebp-base=1.3.1=hd590300_0
70 | - libxcb=1.13=h7f98852_1004
71 | - libxml2=2.11.4=h0d562d8_0
72 | - libzlib=1.2.13=hd590300_5
73 | - llvm-openmp=16.0.6=h4dfa4b3_0
74 | - markupsafe=2.1.3=py310h2372a71_0
75 | - mkl=2022.1.0=h84fe81f_915
76 | - mkl-devel=2022.1.0=ha770c72_916
77 | - mkl-include=2022.1.0=h84fe81f_915
78 | - mpc=1.3.1=hfe3b2da_0
79 | - mpfr=4.2.0=hb012696_0
80 | - mpmath=1.3.0=pyhd8ed1ab_0
81 | - ncurses=6.4=hcb278e6_0
82 | - nettle=3.6=he412f7d_0
83 | - networkx=3.1=pyhd8ed1ab_0
84 | - numpy=1.25.1=py310ha4c1d20_0
85 | - openh264=2.1.1=h780b84a_0
86 | - openjpeg=2.5.0=hfec8fc6_2
87 | - openssl=3.1.1=hd590300_1
88 | - pathtools=0.1.2=py_1
89 | - pillow=9.4.0=py310h023d228_1
90 | - pip=23.1.2=pyhd8ed1ab_0
91 | - protobuf=4.23.3=py310hb875b13_0
92 | - psutil=5.9.5=py310h1fa729e_0
93 | - pthread-stubs=0.4=h36c2ea0_1001
94 | - pysocks=1.7.1=pyha2e5f31_6
95 | - python=3.10.12=hd12c33a_0_cpython
96 | - python_abi=3.10=3_cp310
97 | - pytorch=2.0.1=py3.10_cuda11.7_cudnn8.5.0_0
98 | - pytorch-cuda=11.7=h778d358_5
99 | - pytorch-mutex=1.0=cuda
100 | - pyyaml=6.0=py310h5764c6d_5
101 | - readline=8.2=h8228510_1
102 | - requests=2.31.0=pyhd8ed1ab_0
103 | - sentry-sdk=1.21.1=pyhd8ed1ab_0
104 | - setproctitle=1.3.2=py310h5764c6d_1
105 | - setuptools=68.0.0=pyhd8ed1ab_0
106 | - six=1.16.0=pyh6c4a22f_0
107 | - smmap=3.0.5=pyh44b312d_0
108 | - sympy=1.12=pypyh9d50eac_103
109 | - tbb=2021.9.0=hf52228f_0
110 | - tk=8.6.12=h27826a3_0
111 | - torchaudio=2.0.2=py310_cu117
112 | - torchtriton=2.0.0=py310
113 | - torchvision=0.15.2=py310_cu117
114 | - typing_extensions=4.7.1=pyha770c72_0
115 | - tzdata=2023c=h71feb2d_0
116 | - urllib3=2.0.3=pyhd8ed1ab_1
117 | - wandb=0.15.5=pyhd8ed1ab_0
118 | - wheel=0.40.0=pyhd8ed1ab_0
119 | - xorg-libxau=1.0.11=hd590300_0
120 | - xorg-libxdmcp=1.1.3=h7f98852_0
121 | - xz=5.2.6=h166bdaf_0
122 | - yaml=0.2.5=h7f98852_2
123 | - zlib=1.2.13=hd590300_5
124 | - zstd=1.5.2=hfc55251_7
125 | - pip:
126 | - ale-py==0.8.1
127 | - autorom==0.4.2
128 | - autorom-accept-rom-license==0.6.1
129 | - cloudpickle==2.2.1
130 | - contourpy==1.1.0
131 | - cycler==0.11.0
132 | - decorator==4.4.2
133 | - farama-notifications==0.0.4
134 | - fonttools==4.40.0
135 | - gymnasium==0.28.1
136 | - imageio==2.31.1
137 | - imageio-ffmpeg==0.4.8
138 | - importlib-resources==6.0.0
139 | - jax-jumpy==1.0.0
140 | - kiwisolver==1.4.4
141 | - lz4==4.3.2
142 | - matplotlib==3.7.2
143 | - moviepy==1.0.3
144 | - mypy-extensions==1.0.0
145 | - opencv-python==4.8.0.74
146 | - packaging==23.1
147 | - proglog==0.1.10
148 | - pyparsing==3.0.9
149 | - pyrallis==0.3.1
150 | - python-dateutil==2.8.2
151 | - shimmy==0.2.1
152 | - tqdm==4.65.0
153 | - typing-inspect==0.9.0
154 | prefix: /export/home/timok34dm/mambaforge/envs/env
--------------------------------------------------------------------------------
/img/ReDo_algorithm_pseudocode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/timoklein/redo/3c0e2a9661ad85ca80523fd205da5eed8eb7b859/img/ReDo_algorithm_pseudocode.png
--------------------------------------------------------------------------------
/img/redo_episodic_returns.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/timoklein/redo/3c0e2a9661ad85ca80523fd205da5eed8eb7b859/img/redo_episodic_returns.png
--------------------------------------------------------------------------------
/img/redo_tau_0_0_dormant_fraction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/timoklein/redo/3c0e2a9661ad85ca80523fd205da5eed8eb7b859/img/redo_tau_0_0_dormant_fraction.png
--------------------------------------------------------------------------------
/img/redo_tau_0_1_dormant_fraction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/timoklein/redo/3c0e2a9661ad85ca80523fd205da5eed8eb7b859/img/redo_tau_0_1_dormant_fraction.png
--------------------------------------------------------------------------------
/redo.sh:
--------------------------------------------------------------------------------
1 | # Simplified benchmarking script from cleanRL: https://github.com/vwxyzjn/cleanrl/blob/master/benchmark/dqn.sh
2 |
3 | OMP_NUM_THREADS=1 python -m src.benchmark \
4 | --env-ids DemonAttack-v4 \
5 | --command "python redo_dqn.py --track --enable_redo" \
6 | --num-seeds 3 \
7 | --workers 3
--------------------------------------------------------------------------------
/redo_dqn.py:
--------------------------------------------------------------------------------
1 | # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_ataripy
2 | import os
3 | import random
4 | import time
5 | from pathlib import Path
6 |
7 | import gymnasium as gym
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | import torch.optim as optim
12 | import tyro
13 | import wandb
14 |
15 | from src.agent import QNetwork, linear_schedule
16 | from src.buffer import ReplayBuffer
17 | from src.config import Config
18 | from src.redo import run_redo
19 | from src.utils import lecun_normal_initializer, make_env, set_cuda_configuration
20 |
21 |
22 | def dqn_loss(
23 | q_network: QNetwork,
24 | target_network: QNetwork,
25 | obs: torch.Tensor,
26 | next_obs: torch.Tensor,
27 | actions: torch.Tensor,
28 | rewards: torch.Tensor,
29 | dones: torch.Tensor,
30 | gamma: float,
31 | ) -> tuple[torch.Tensor, torch.Tensor]:
32 | """Compute the double DQN loss."""
33 | with torch.no_grad():
34 | # Get value estimates from the target network
35 | target_vals = target_network.forward(next_obs)
36 | # Select actions through the policy network
37 | policy_actions = q_network(next_obs).argmax(dim=1)
38 | target_max = target_vals[range(len(target_vals)), policy_actions]
39 | # Calculate Q-target
40 | td_target = rewards.flatten() + gamma * target_max * (1 - dones.flatten())
41 |
42 | old_val = q_network(obs).gather(1, actions).squeeze()
43 | return F.mse_loss(td_target, old_val), old_val
44 |
45 |
46 | def main(cfg: Config) -> None:
47 | """Main training method for ReDO DQN."""
48 | run_name = f"{cfg.env_id}__{cfg.exp_name}__{cfg.seed}__{int(time.time())}"
49 |
50 | wandb.init(
51 | project=cfg.wandb_project_name,
52 | entity=cfg.wandb_entity,
53 | config=vars(cfg),
54 | name=run_name,
55 | monitor_gym=True,
56 | save_code=True,
57 | mode="online" if cfg.track else "disabled",
58 | )
59 |
60 | if cfg.save_model:
61 | evaluation_episode = 0
62 | wandb.define_metric("evaluation_episode")
63 | wandb.define_metric("eval/episodic_return", step_metric="evaluation_episode")
64 |
65 | # To get deterministic pytorch to work
66 | if cfg.torch_deterministic:
67 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
68 | torch.use_deterministic_algorithms(True)
69 |
70 | # TRY NOT TO MODIFY: seeding
71 | random.seed(cfg.seed)
72 | np.random.seed(cfg.seed)
73 | torch.manual_seed(cfg.seed)
74 | torch.set_float32_matmul_precision("high")
75 |
76 | device = set_cuda_configuration(cfg.gpu)
77 |
78 | # env setup
79 | envs = gym.vector.SyncVectorEnv(
80 | [make_env(cfg.env_id, cfg.seed + i, i, cfg.capture_video, run_name) for i in range(cfg.num_envs)]
81 | )
82 | assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
83 |
84 | q_network = QNetwork(envs).to(device)
85 | if cfg.use_lecun_init:
86 | # Use the same initialization scheme as jax/flax
87 | q_network.apply(lecun_normal_initializer)
88 | optimizer = optim.Adam(q_network.parameters(), lr=cfg.learning_rate, eps=cfg.adam_eps)
89 | target_network = QNetwork(envs).to(device)
90 | target_network.load_state_dict(q_network.state_dict())
91 |
92 | rb = ReplayBuffer(
93 | cfg.buffer_size,
94 | envs.single_observation_space,
95 | envs.single_action_space,
96 | device,
97 | optimize_memory_usage=True,
98 | handle_timeout_termination=False,
99 | )
100 | start_time = time.time()
101 |
102 | # TRY NOT TO MODIFY: start the game
103 | obs, _ = envs.reset(seed=cfg.seed)
104 | for global_step in range(cfg.total_timesteps):
105 | # ALGO LOGIC: put action logic here
106 | epsilon = linear_schedule(cfg.start_e, cfg.end_e, cfg.exploration_fraction * cfg.total_timesteps, global_step)
107 | if random.random() < epsilon:
108 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
109 | else:
110 | q_values = q_network(torch.Tensor(obs).to(device))
111 | actions = torch.argmax(q_values, dim=1).cpu().numpy()
112 |
113 | # TRY NOT TO MODIFY: execute the game and log data.
114 | next_obs, rewards, terminated, truncated, infos = envs.step(actions)
115 |
116 | # TRY NOT TO MODIFY: record rewards for plotting purposes
117 | if "final_info" in infos:
118 | for info in infos["final_info"]:
119 | # Skip the envs that are not done
120 | if "episode" not in info:
121 | continue
122 | epi_return = info["episode"]["r"].item()
123 | print(f"global_step={global_step}, episodic_return={epi_return}")
124 | wandb.log(
125 | {
126 | "charts/episodic_return": epi_return,
127 | "charts/episodic_length": info["episode"]["l"].item(),
128 | "charts/epsilon": epsilon,
129 | },
130 | step=global_step,
131 | )
132 |
133 |         # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
134 | real_next_obs = next_obs.copy()
135 | for idx, d in enumerate(truncated):
136 | if d:
137 | real_next_obs[idx] = infos["final_observation"][idx]
138 | rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
139 |
140 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
141 | obs = next_obs
142 |
143 | # ALGO LOGIC: training.
144 | if global_step > cfg.learning_starts:
145 | # Flag for logging
146 | done_update = False
147 | if done_update := global_step % cfg.train_frequency == 0:
148 | data = rb.sample(cfg.batch_size)
149 | loss, old_val = dqn_loss(
150 | q_network=q_network,
151 | target_network=target_network,
152 | obs=data.observations,
153 | next_obs=data.next_observations,
154 | actions=data.actions,
155 | rewards=data.rewards,
156 | dones=data.dones,
157 | gamma=cfg.gamma,
158 | )
159 | # optimize the model
160 | optimizer.zero_grad()
161 | loss.backward()
162 | optimizer.step()
163 |
164 | logs = {
165 | "losses/td_loss": loss,
166 | "losses/q_values": old_val.mean().item(),
167 | "charts/SPS": int(global_step / (time.time() - start_time)),
168 | }
169 |
170 | if global_step % cfg.redo_check_interval == 0:
171 | redo_samples = rb.sample(cfg.redo_bs)
172 | redo_out = run_redo(
173 | redo_samples.observations,
174 | model=q_network,
175 | optimizer=optimizer,
176 | tau=cfg.redo_tau,
177 | re_initialize=cfg.enable_redo,
178 | use_lecun_init=cfg.use_lecun_init,
179 | )
180 |
181 | q_network = redo_out["model"]
182 | optimizer = redo_out["optimizer"]
183 |
184 | logs |= {
185 | f"regularization/dormant_t={cfg.redo_tau}_fraction": redo_out["dormant_fraction"],
186 | f"regularization/dormant_t={cfg.redo_tau}_count": redo_out["dormant_count"],
187 | "regularization/dormant_t=0.0_fraction": redo_out["zero_fraction"],
188 | "regularization/dormant_t=0.0_count": redo_out["zero_count"],
189 | }
190 |
191 | if global_step % 100 == 0 and done_update:
192 | print("SPS:", int(global_step / (time.time() - start_time)))
193 | wandb.log(
194 | logs,
195 | step=global_step,
196 | )
197 |
198 | # update target network
199 | if global_step % cfg.target_network_frequency == 0:
200 | for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
201 | target_network_param.data.copy_(
202 | cfg.tau * q_network_param.data + (1.0 - cfg.tau) * target_network_param.data
203 | )
204 |
205 | if cfg.save_model:
206 | model_path = Path(f"runs/{run_name}/{cfg.exp_name}")
207 | model_path.mkdir(parents=True, exist_ok=True)
208 | torch.save(q_network.state_dict(), model_path / ".cleanrl_model")
209 | print(f"model saved to {model_path}")
210 | from src.evaluate import evaluate
211 |
212 | episodic_returns = evaluate(
213 |             model_path=model_path / ".cleanrl_model",  # load the saved file, not the run directory
214 | make_env=make_env,
215 | env_id=cfg.env_id,
216 | eval_episodes=10,
217 | run_name=f"{run_name}-eval",
218 | Model=QNetwork,
219 | device=device,
220 | epsilon=0.05,
221 | capture_video=False,
222 | )
223 | for episodic_return in episodic_returns:
224 | wandb.log({"evaluation_episode": evaluation_episode, "eval/episodic_return": episodic_return})
225 | evaluation_episode += 1
226 |
227 | envs.close()
228 | wandb.finish()
229 |
230 |
231 | if __name__ == "__main__":
232 | cfg = tyro.cli(Config)
233 | main(cfg)
234 |
--------------------------------------------------------------------------------
/src/agent.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class QNetwork(nn.Module):
7 | """Basic nature DQN agent."""
8 |
9 | def __init__(self, env):
10 | super().__init__()
11 | self.conv1 = nn.Conv2d(4, 32, 8, stride=4)
12 | self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
13 | self.conv3 = nn.Conv2d(64, 64, 3, stride=1)
14 |         self.fc1 = nn.Linear(3136, 512)  # 3136 = 64 channels * 7 * 7 after the conv stack on 84x84 inputs
15 | self.q = nn.Linear(512, int(env.single_action_space.n))
16 |
17 | def forward(self, x):
18 | x = F.relu(self.conv1(x / 255.0))
19 | x = F.relu(self.conv2(x))
20 | x = F.relu(self.conv3(x))
21 | x = torch.flatten(x, start_dim=1)
22 | x = F.relu(self.fc1(x))
23 | x = self.q(x)
24 | return x
25 |
26 |
27 | def linear_schedule(start_e: float, end_e: float, duration: float, t: int):
28 | slope = (end_e - start_e) / duration
29 | return max(slope * t + start_e, end_e)
30 |
--------------------------------------------------------------------------------
/src/benchmark.py:
--------------------------------------------------------------------------------
1 | """
2 | Simplified version of cleanRL's benchmarking script.
3 | The original version can be found here: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl_utils/benchmark.py
4 | """
5 |
6 | import shlex
7 | import subprocess
8 | from dataclasses import dataclass
9 |
10 | import tyro
11 |
12 |
13 | @dataclass
14 | class BenchmarkConfig:
15 | env_ids: tuple[str] = ("CartPole-v1", "Acrobot-v1", "MountainCar-v0")
16 | command: str = "python redo_dqn.py"
17 | start_seed: int = 1
18 | num_seeds: int = 3
19 | workers: int = 3
20 |
21 |
22 | def run_experiment(command: str):
23 | command_list = shlex.split(command)
24 | print(f"running {command}")
25 | fd = subprocess.Popen(command_list)
26 | return_code = fd.wait()
27 | assert return_code == 0
28 |
29 |
30 | if __name__ == "__main__":
31 | args = tyro.cli(BenchmarkConfig)
32 |
33 | commands = []
34 | for seed in range(0, args.num_seeds):
35 | for env_id in args.env_ids:
36 | commands += [" ".join([args.command, "--env_id", env_id, "--seed", str(args.start_seed + seed)])]
37 |
38 | print("======= commands to run:")
39 | for command in commands:
40 | print(command)
41 |
42 | if args.workers > 0:
43 | from concurrent.futures import ThreadPoolExecutor
44 |
45 | executor = ThreadPoolExecutor(max_workers=args.workers, thread_name_prefix="cleanrl-benchmark-worker-")
46 | for command in commands:
47 | executor.submit(run_experiment, command)
48 | executor.shutdown(wait=True)
49 | else:
50 | print("not running the experiments because --workers is set to 0; just printing the commands to run")
51 |
--------------------------------------------------------------------------------
/src/buffer.py:
--------------------------------------------------------------------------------
1 | """
2 | This is a simplified version of the stable-baselines3 replay buffer taken from
3 | https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/buffers.py
4 |
5 | I've removed unneeded functionality and put all dependencies into a single file.
6 | """
7 |
8 | import warnings
9 | from abc import ABC, abstractmethod
10 | from typing import Any, Dict, NamedTuple, Union
11 |
12 | import numpy as np
13 | import psutil
14 | import torch
15 | from gymnasium import spaces
16 |
17 |
18 | class ReplayBufferSamples(NamedTuple):
19 | """Container for replay buffer samples."""
20 |
21 | observations: torch.Tensor
22 | actions: torch.Tensor
23 | next_observations: torch.Tensor
24 | dones: torch.Tensor
25 | rewards: torch.Tensor
26 |
27 |
28 | def get_obs_shape(observation_space: spaces.Space) -> Union[tuple[int, ...], Dict[str, tuple[int, ...]]]:
29 | """
30 | Get the shape of the observation (useful for the buffers).
31 | :param observation_space:
32 | :return:
33 | """
34 | if isinstance(observation_space, spaces.Box):
35 | return observation_space.shape
36 | elif isinstance(observation_space, spaces.Discrete):
37 | # Observation is an int
38 | return (1,)
39 | elif isinstance(observation_space, spaces.MultiDiscrete):
40 | # Number of discrete features
41 | return (int(len(observation_space.nvec)),)
42 | elif isinstance(observation_space, spaces.MultiBinary):
43 | # Number of binary features
44 | if type(observation_space.n) in [tuple, list, np.ndarray]:
45 | return tuple(observation_space.n)
46 | else:
47 | return (int(observation_space.n),)
48 | elif isinstance(observation_space, spaces.Dict):
49 | return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} # type: ignore[misc]
50 | else:
51 | raise NotImplementedError(f"{observation_space} observation space is not supported")
52 |
53 |
54 | def get_action_dim(action_space: spaces.Space) -> int:
55 | """
56 | Get the dimension of the action space.
57 | :param action_space:
58 | :return:
59 | """
60 | if isinstance(action_space, spaces.Box):
61 | return int(np.prod(action_space.shape))
62 | elif isinstance(action_space, spaces.Discrete):
63 | # Action is an int
64 | return 1
65 | elif isinstance(action_space, spaces.MultiDiscrete):
66 | # Number of discrete actions
67 | return int(len(action_space.nvec))
68 | elif isinstance(action_space, spaces.MultiBinary):
69 | # Number of binary actions
70 | return int(action_space.n)
71 | else:
72 | raise NotImplementedError(f"{action_space} action space is not supported")
73 |
74 |
75 | def get_device(device: Union[torch.device, str] = "auto") -> torch.device:
76 | """
77 | Retrieve PyTorch device.
78 | It checks that the requested device is available first.
79 | For now, it supports only cpu and cuda.
80 | By default, it tries to use the gpu.
81 |     :param device: One of 'auto', 'cuda', 'cpu'
82 |     :return: Supported PyTorch device
83 | """
84 | # Cuda by default
85 | if device == "auto":
86 | device = "cuda"
87 | # Force conversion to torch.device
88 | device = torch.device(device)
89 |
90 | # Cuda not available
91 | if device.type == torch.device("cuda").type and not torch.cuda.is_available():
92 | return torch.device("cpu")
93 |
94 | return device
95 |
96 |
97 | class BaseBuffer(ABC):
98 | """
99 |     Base class that represents a buffer (rollout or replay)
100 |     :param buffer_size: Max number of elements in the buffer
101 | :param observation_space: Observation space
102 | :param action_space: Action space
103 | :param device: PyTorch device
104 | to which the values will be converted
105 | :param n_envs: Number of parallel environments
106 | """
107 |
108 | def __init__(
109 | self,
110 | buffer_size: int,
111 | observation_space: spaces.Space,
112 | action_space: spaces.Space,
113 | device: Union[torch.device, str] = "auto",
114 | n_envs: int = 1,
115 | ):
116 | super().__init__()
117 | self.buffer_size = buffer_size
118 | self.observation_space = observation_space
119 | self.action_space = action_space
120 | self.obs_shape = get_obs_shape(observation_space)
121 |
122 | self.action_dim = get_action_dim(action_space)
123 | self.pos = 0
124 | self.full = False
125 | self.device = get_device(device)
126 | self.n_envs = n_envs
127 |
128 | @staticmethod
129 | def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
130 | """
131 | Swap and then flatten axes 0 (buffer_size) and 1 (n_envs)
132 |         to convert shape from [n_steps, n_envs, ...] (where ... is the shape of the features)
133 |         to [n_steps * n_envs, ...] (which maintains the order)
134 | :param arr:
135 | :return:
136 | """
137 | shape = arr.shape
138 | if len(shape) < 3:
139 | shape = (*shape, 1)
140 | return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])
141 |
142 | def size(self) -> int:
143 | """
144 | :return: The current size of the buffer
145 | """
146 | if self.full:
147 | return self.buffer_size
148 | return self.pos
149 |
150 | def add(self, *args, **kwargs) -> None:
151 | """
152 | Add elements to the buffer.
153 | """
154 | raise NotImplementedError()
155 |
156 | def extend(self, *args, **kwargs) -> None:
157 | """
158 | Add a new batch of transitions to the buffer
159 | """
160 | # Do a for loop along the batch axis
161 | for data in zip(*args):
162 | self.add(*data)
163 |
164 | def reset(self) -> None:
165 | """
166 | Reset the buffer.
167 | """
168 | self.pos = 0
169 | self.full = False
170 |
171 | def sample(self, batch_size: int):
172 | """
173 |         :param batch_size: Number of elements to sample
174 | :param env: associated gym VecEnv
175 | to normalize the observations/rewards when sampling
176 | :return:
177 | """
178 | upper_bound = self.buffer_size if self.full else self.pos
179 | batch_inds = np.random.randint(0, upper_bound, size=batch_size)
180 | return self._get_samples(batch_inds)
181 |
182 | @abstractmethod
183 | def _get_samples(self, batch_inds: np.ndarray) -> ReplayBufferSamples:
184 | """
185 | :param batch_inds:
186 | :param env:
187 | :return:
188 | """
189 | raise NotImplementedError()
190 |
191 | def to_torch(self, array: np.ndarray, copy: bool = True) -> torch.Tensor:
192 | """
193 | Convert a numpy array to a PyTorch tensor.
194 | Note: it copies the data by default
195 | :param array:
196 | :param copy: Whether to copy or not the data (may be useful to avoid changing things
197 | by reference). This argument is inoperative if the device is not the CPU.
198 | :return:
199 | """
200 | if copy:
201 | return torch.tensor(array, device=self.device)
202 | return torch.as_tensor(array, device=self.device)
203 |
204 |
205 | class ReplayBuffer(BaseBuffer):
206 | """
207 | Replay buffer used in off-policy algorithms like SAC/TD3.
208 |     :param buffer_size: Max number of elements in the buffer
209 | :param observation_space: Observation space
210 | :param action_space: Action space
211 | :param device: PyTorch device
212 | :param n_envs: Number of parallel environments
213 | :param optimize_memory_usage: Enable a memory efficient variant
214 |         of the replay buffer which reduces the memory used by almost a factor of two,
215 |         at the cost of more complexity.
216 | See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
217 | and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274
218 | Cannot be used in combination with handle_timeout_termination.
219 | :param handle_timeout_termination: Handle timeout termination (due to timelimit)
220 |         separately and treat the task as an infinite-horizon task.
221 | https://github.com/DLR-RM/stable-baselines3/issues/284
222 | """
223 |
224 | def __init__(
225 | self,
226 | buffer_size: int,
227 | observation_space: spaces.Space,
228 | action_space: spaces.Space,
229 | device: Union[torch.device, str] = "auto",
230 | n_envs: int = 1,
231 | optimize_memory_usage: bool = False,
232 | handle_timeout_termination: bool = True,
233 | ):
234 | super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
235 |
236 | # Adjust buffer size
237 | self.buffer_size = max(buffer_size // n_envs, 1)
238 |
239 | # Check that the replay buffer can fit into the memory
240 | if psutil is not None:
241 | mem_available = psutil.virtual_memory().available
242 |
243 | # there is a bug if both optimize_memory_usage and handle_timeout_termination are true
244 | # see https://github.com/DLR-RM/stable-baselines3/issues/934
245 | if optimize_memory_usage and handle_timeout_termination:
246 | raise ValueError(
247 | "ReplayBuffer does not support optimize_memory_usage = True "
248 | "and handle_timeout_termination = True simultaneously."
249 | )
250 | self.optimize_memory_usage = optimize_memory_usage
251 |
252 | self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype)
253 |
254 | if optimize_memory_usage:
255 | # `observations` contains also the next observation
256 | self.next_observations = None
257 | else:
258 | self.next_observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype)
259 |
260 | self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)
261 |
262 | self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
263 | self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
264 | # Handle timeouts termination properly if needed
265 | # see https://github.com/DLR-RM/stable-baselines3/issues/284
266 | self.handle_timeout_termination = handle_timeout_termination
267 | self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
268 |
269 | if psutil is not None:
270 | total_memory_usage = self.observations.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
271 |
272 | if self.next_observations is not None:
273 | total_memory_usage += self.next_observations.nbytes
274 |
275 | if total_memory_usage > mem_available:
276 | # Convert to GB
277 | total_memory_usage /= 1e9
278 | mem_available /= 1e9
279 | warnings.warn(
280 |                     "This system does not appear to have enough memory to store the complete "
281 | f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
282 | )
283 |
284 | def add(
285 | self,
286 | obs: np.ndarray,
287 | next_obs: np.ndarray,
288 | action: np.ndarray,
289 | reward: np.ndarray,
290 | done: np.ndarray,
291 | infos: list[Dict[str, Any]],
292 | ) -> None:
293 | # Reshape needed when using multiple envs with discrete observations
294 | # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
295 | if isinstance(self.observation_space, spaces.Discrete):
296 | obs = obs.reshape((self.n_envs, *self.obs_shape))
297 | next_obs = next_obs.reshape((self.n_envs, *self.obs_shape))
298 |
299 | # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
300 | action = action.reshape((self.n_envs, self.action_dim))
301 |
302 | # Copy to avoid modification by reference
303 | self.observations[self.pos] = np.array(obs).copy()
304 |
305 | if self.optimize_memory_usage:
306 | self.observations[(self.pos + 1) % self.buffer_size] = np.array(next_obs).copy()
307 | else:
308 | self.next_observations[self.pos] = np.array(next_obs).copy()
309 |
310 | self.actions[self.pos] = np.array(action).copy()
311 | self.rewards[self.pos] = np.array(reward).copy()
312 | self.dones[self.pos] = np.array(done).copy()
313 |
314 | if self.handle_timeout_termination:
315 | self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])
316 |
317 | self.pos += 1
318 | if self.pos == self.buffer_size:
319 | self.full = True
320 | self.pos = 0
321 |
322 | def sample(self, batch_size: int) -> ReplayBufferSamples:
323 | """
324 | Sample elements from the replay buffer.
325 | Custom sampling when using memory efficient variant,
326 | as we should not sample the element with index `self.pos`
327 | See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274
328 |         :param batch_size: Number of elements to sample
329 | :return:
330 | """
331 | if not self.optimize_memory_usage:
332 | return super().sample(batch_size=batch_size)
333 |         # Do not sample the element with index `self.pos` as the transition is invalid
334 | # (we use only one array to store `obs` and `next_obs`)
335 | if self.full:
336 | batch_inds = (np.random.randint(1, self.buffer_size, size=batch_size) + self.pos) % self.buffer_size
337 | else:
338 | batch_inds = np.random.randint(0, self.pos, size=batch_size)
339 | return self._get_samples(batch_inds)
340 |
341 | def _get_samples(self, batch_inds: np.ndarray) -> ReplayBufferSamples:
342 |         # Randomly sample the env index
343 | env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))
344 |
345 | if self.optimize_memory_usage:
346 | next_obs = self.observations[(batch_inds + 1) % self.buffer_size, env_indices, :]
347 | else:
348 | next_obs = self.next_observations[batch_inds, env_indices, :]
349 |
350 | data = (
351 | self.observations[batch_inds, env_indices, :],
352 | self.actions[batch_inds, env_indices, :],
353 | next_obs,
354 | # Only use dones that are not due to timeouts
355 | # deactivated by default (timeouts is initialized as an array of False)
356 | (self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(-1, 1),
357 | self.rewards[batch_inds, env_indices].reshape(-1, 1),
358 | )
359 | return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
360 |
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 |
4 | @dataclass
5 | class Config:
6 | """Configuration for a ReDo DQN agent."""
7 |
8 | # Experiment settings
9 | exp_name: str = "ReDo DQN"
10 | tags: tuple[str, ...] | str | None = None
11 | seed: int = 0
12 | torch_deterministic: bool = True
13 | gpu: int | None = 0
14 | track: bool = False
15 | wandb_project_name: str = "ReDo"
16 | wandb_entity: str | None = None
17 | capture_video: bool = False
18 | save_model: bool = False
19 |
20 | # Environment settings
21 | env_id: str = "DemonAttackNoFrameskip-v4"
22 | total_timesteps: int = 10_000_000
23 | num_envs: int = 1
24 |
25 | # DQN settings
26 | buffer_size: int = 1_000_000
27 | batch_size: int = 32
28 | learning_rate: float = 6.25 * 1e-5 # cleanRL default: 1e-4, theirs: 6.25 * 1e-5
29 | adam_eps: float = 1.5 * 1e-4
30 | use_lecun_init: bool = False # ReDO uses lecun_normal initializer, cleanRL uses the pytorch default (kaiming_uniform)
31 | gamma: float = 0.99
32 | tau: float = 1.0
33 |     target_network_frequency: int = 8000 # cleanRL default: 8000 (with train_frequency=4); use 2000 with train_frequency=1
34 | start_e: float = 1.0
35 | end_e: float = 0.01
36 | exploration_fraction: float = 0.10
37 | learning_starts: int = 80_000 # cleanRL default: 80000, theirs 20000
38 | train_frequency: int = 4 # cleanRL default: 4, theirs 1
39 |
40 | # ReDo settings
41 | enable_redo: bool = False
42 |     redo_tau: float = 0.025 # 0.025 is the paper default; 0.1 is used for the exaggerated setting in the README
43 | redo_check_interval: int = 1000
44 | redo_bs: int = 64
45 |
--------------------------------------------------------------------------------
/src/evaluate.py:
--------------------------------------------------------------------------------
1 | """
2 | Taken from cleanRL: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl_utils/evals/dqn_eval.py.
3 | """
4 |
5 | import random
6 | from typing import Callable
7 |
8 | import gymnasium as gym
9 | import numpy as np
10 | import torch
11 |
12 |
13 | def evaluate(
14 | model_path: str,
15 | make_env: Callable,
16 | env_id: str,
17 | eval_episodes: int,
18 | run_name: str,
19 |     Model: type[torch.nn.Module],
20 | device: torch.device = torch.device("cpu"),
21 | epsilon: float = 0.05,
22 | capture_video: bool = False,
23 | ):
24 | envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)])
25 | model = Model(envs).to(device)
26 | model.load_state_dict(torch.load(model_path, map_location=device))
27 | model.eval()
28 |
29 |     obs, _ = envs.reset()
30 | episodic_returns = []
31 | while len(episodic_returns) < eval_episodes:
32 | if random.random() < epsilon:
33 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
34 | else:
35 | q_values = model(torch.Tensor(obs).to(device))
36 | actions = torch.argmax(q_values, dim=1).cpu().numpy()
37 |         next_obs, _, _, _, infos = envs.step(actions)
38 |         if "final_info" in infos:
39 |             for info in infos["final_info"]:
40 |                 # Skip the envs that are not done
41 |                 if "episode" not in info:
42 |                     continue
43 |                 print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}")
44 |                 episodic_returns += [info["episode"]["r"]]
45 |         obs = next_obs
46 | 
47 |     return episodic_returns
48 | 
--------------------------------------------------------------------------------
/src/redo.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import partial
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch import optim
8 |
9 | from .agent import QNetwork
10 | from .buffer import ReplayBufferSamples
11 |
12 |
13 | @torch.no_grad()
14 | def _kaiming_uniform_reinit(layer: nn.Linear | nn.Conv2d, mask: torch.Tensor) -> None:
15 |     """Partially re-initializes the weights and bias of a layer according to the Kaiming uniform scheme."""
16 |
17 | # This is adapted from https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
18 | fan_in = nn.init._calculate_correct_fan(tensor=layer.weight, mode="fan_in")
19 | gain = nn.init.calculate_gain(nonlinearity="relu", param=math.sqrt(5))
20 | std = gain / math.sqrt(fan_in)
21 | # Calculate uniform bounds from standard deviation
22 | bound = math.sqrt(3.0) * std
23 | layer.weight.data[mask, ...] = torch.empty_like(layer.weight.data[mask, ...]).uniform_(-bound, bound)
24 |
25 | if layer.bias is not None:
26 | # The original code resets the bias to 0.0 because it uses a different initialization scheme
27 | # layer.bias.data[mask] = 0.0
28 | if isinstance(layer, nn.Conv2d):
29 | if fan_in != 0:
30 | bound = 1 / math.sqrt(fan_in)
31 | layer.bias.data[mask, ...] = torch.empty_like(layer.bias.data[mask, ...]).uniform_(-bound, bound)
32 | else:
33 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
34 | layer.bias.data[mask, ...] = torch.empty_like(layer.bias.data[mask, ...]).uniform_(-bound, bound)
35 |
36 |
37 | @torch.no_grad()
38 | def _lecun_normal_reinit(layer: nn.Linear | nn.Conv2d, mask: torch.Tensor) -> None:
39 |     """Partially re-initializes the weights and bias of a layer according to the LeCun normal scheme."""
40 |
41 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(layer.weight)
42 |
43 | # This implementation follows the jax one
44 | # https://github.com/google/jax/blob/366a16f8ba59fe1ab59acede7efd160174134e01/jax/_src/nn/initializers.py#L260
45 | variance = 1.0 / fan_in
46 | stddev = math.sqrt(variance) / 0.87962566103423978
47 | layer.weight[mask] = nn.init._no_grad_trunc_normal_(layer.weight[mask], mean=0.0, std=1.0, a=-2.0, b=2.0)
48 | layer.weight[mask] *= stddev
49 | if layer.bias is not None:
50 | layer.bias.data[mask] = 0.0
51 |
52 |
53 | @torch.inference_mode()
54 | def _get_activation(name: str, activations: dict[str, torch.Tensor]):
55 | """Fetches and stores the activations of a network layer."""
56 |
57 | def hook(layer: nn.Linear | nn.Conv2d, input: tuple[torch.Tensor], output: torch.Tensor) -> None:
58 | """
59 | Get the activations of a layer with relu nonlinearity.
60 | ReLU has to be called explicitly here because the hook is attached to the conv/linear layer.
61 | """
62 | activations[name] = F.relu(output)
63 |
64 | return hook
65 |
66 |
67 | @torch.inference_mode()
68 | def _get_redo_masks(activations: dict[str, torch.Tensor], tau: float) -> list[torch.Tensor]:
69 | """
70 | Computes the ReDo mask for a given set of activations.
71 | The returned mask has True where neurons are dormant and False where they are active.
72 | """
73 | masks = []
74 |
75 |     # The last activations are the q-values, which are never reset
76 | for name, activation in list(activations.items())[:-1]:
77 | # Taking the mean here conforms to the expectation under D in the main paper's formula
78 | if activation.ndim == 4:
79 | # Conv layer
80 | score = activation.abs().mean(dim=(0, 2, 3))
81 | else:
82 | # Linear layer
83 | score = activation.abs().mean(dim=0)
84 |
85 | # Divide by activation mean to make the threshold independent of the layer size
86 | # see https://github.com/google/dopamine/blob/ce36aab6528b26a699f5f1cefd330fdaf23a5d72/dopamine/labs/redo/weight_recyclers.py#L314
87 | # https://github.com/google/dopamine/issues/209
88 | normalized_score = score / (score.mean() + 1e-9)
89 |
90 | layer_mask = torch.zeros_like(normalized_score, dtype=torch.bool)
91 | if tau > 0.0:
92 | layer_mask[normalized_score <= tau] = 1
93 | else:
94 | layer_mask[torch.isclose(normalized_score, torch.zeros_like(normalized_score))] = 1
95 | masks.append(layer_mask)
96 | return masks
97 |
98 |
99 | @torch.no_grad()
100 | def _reset_dormant_neurons(model: QNetwork, redo_masks: list[torch.Tensor], use_lecun_init: bool) -> QNetwork:
101 | """Re-initializes the dormant neurons of a model."""
102 |
103 | layers = [(name, layer) for name, layer in list(model.named_modules())[1:]]
104 | assert len(redo_masks) == len(layers) - 1, "Number of masks must match the number of layers"
105 |
106 | # Reset the ingoing weights
107 | # Here the mask size always matches the layer weight size
108 | for i in range(len(layers[:-1])):
109 | mask = redo_masks[i]
110 | layer = layers[i][1]
111 | next_layer = layers[i + 1][1]
112 | # Can be used to not reset outgoing weights in the Q-function
113 | next_layer_name = layers[i + 1][0]
114 |
115 | # Skip if there are no dead neurons
116 | if torch.all(~mask):
117 | # No dormant neurons in this layer
118 | continue
119 |
120 | # The initialization scheme is the same for conv2d and linear
121 | # 1. Reset the ingoing weights using the initialization distribution
122 | if use_lecun_init:
123 | _lecun_normal_reinit(layer, mask)
124 | else:
125 | _kaiming_uniform_reinit(layer, mask)
126 |
127 | # 2. Reset the outgoing weights to 0
128 | # NOTE: Don't reset the bias for the following layer or else you will create new dormant neurons
129 | # To not reset in the last layer: and not next_layer_name == 'q'
130 | if isinstance(layer, nn.Conv2d) and isinstance(next_layer, nn.Linear):
131 | # Special case: Transition from conv to linear layer
132 | # Reset the outgoing weights to 0 with a mask created from the conv filters
133 |             num_repetition = next_layer.weight.data.shape[1] // mask.shape[0]
134 |             linear_mask = torch.repeat_interleave(mask, num_repetition)
135 | next_layer.weight.data[:, linear_mask] = 0.0
136 | else:
137 | # Standard case: layer and next_layer are both conv or both linear
138 | # Reset the outgoing weights to 0
139 | next_layer.weight.data[:, mask, ...] = 0.0
140 |
141 | return model
142 |
143 |
144 | @torch.no_grad()
145 | def _reset_adam_moments(optimizer: optim.Adam, reset_masks: list[torch.Tensor]) -> optim.Adam:
146 | """Resets the moments of the Adam optimizer for the dormant neurons."""
147 |
148 | assert isinstance(optimizer, optim.Adam), "Moment resetting currently only supported for Adam optimizer"
149 | for i, mask in enumerate(reset_masks):
150 | # Reset the moments for the weights
151 | optimizer.state_dict()["state"][i * 2]["exp_avg"][mask, ...] = 0.0
152 | optimizer.state_dict()["state"][i * 2]["exp_avg_sq"][mask, ...] = 0.0
153 | # NOTE: Step count resets are key to the algorithm's performance
154 |         # It's possible to reset the step count only for the moments that are being reset
155 | optimizer.state_dict()["state"][i * 2]["step"].zero_()
156 |
157 | # Reset the moments for the bias
158 | optimizer.state_dict()["state"][i * 2 + 1]["exp_avg"][mask] = 0.0
159 | optimizer.state_dict()["state"][i * 2 + 1]["exp_avg_sq"][mask] = 0.0
160 | optimizer.state_dict()["state"][i * 2 + 1]["step"].zero_()
161 |
162 | # Reset the moments for the output weights
163 | if (
164 | len(optimizer.state_dict()["state"][i * 2]["exp_avg"].shape) == 4
165 | and len(optimizer.state_dict()["state"][i * 2 + 2]["exp_avg"].shape) == 2
166 | ):
167 | # Catch transition from conv to linear layer through moment shapes
168 |             num_repetition = optimizer.state_dict()["state"][i * 2 + 2]["exp_avg"].shape[1] // mask.shape[0]
169 |             linear_mask = torch.repeat_interleave(mask, num_repetition)
170 | optimizer.state_dict()["state"][i * 2 + 2]["exp_avg"][:, linear_mask] = 0.0
171 | optimizer.state_dict()["state"][i * 2 + 2]["exp_avg_sq"][:, linear_mask] = 0.0
172 | optimizer.state_dict()["state"][i * 2 + 2]["step"].zero_()
173 | else:
174 | # Standard case: layer and next_layer are both conv or both linear
175 |             # Reset the moments of the outgoing weights
176 | optimizer.state_dict()["state"][i * 2 + 2]["exp_avg"][:, mask, ...] = 0.0
177 | optimizer.state_dict()["state"][i * 2 + 2]["exp_avg_sq"][:, mask, ...] = 0.0
178 | optimizer.state_dict()["state"][i * 2 + 2]["step"].zero_()
179 |
180 | return optimizer
181 |
182 |
183 | @torch.inference_mode()
184 | def run_redo(
185 | obs: torch.Tensor,
186 | model: QNetwork,
187 | optimizer: optim.Adam,
188 | tau: float,
189 | re_initialize: bool,
190 | use_lecun_init: bool,
191 | ) -> dict[str, QNetwork | optim.Adam | torch.Tensor]:
192 | """
193 | Checks the number of dormant neurons for a given model.
194 | If re_initialize is True, then the dormant neurons are re-initialized according to the scheme in
195 | https://arxiv.org/abs/2302.12902
196 |
197 |     Returns a dict with the (possibly re-initialized) model and optimizer as well as dormant neuron statistics.
198 | """
199 |
200 | activations = {}
201 | activation_getter = partial(_get_activation, activations=activations)
202 |
203 | # Register hooks for all Conv2d and Linear layers to calculate activations
204 | handles = []
205 | for name, module in model.named_modules():
206 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
207 | handles.append(module.register_forward_hook(activation_getter(name)))
208 |
209 | # Calculate activations
210 | _ = model(obs)
211 |
212 | # Masks for tau=0 logging
213 | zero_masks = _get_redo_masks(activations, 0.0)
214 | total_neurons = sum([torch.numel(mask) for mask in zero_masks])
215 | zero_count = sum([torch.sum(mask) for mask in zero_masks])
216 | zero_fraction = (zero_count / total_neurons) * 100
217 |
218 | # Calculate the masks actually used for resetting
219 | masks = _get_redo_masks(activations, tau)
220 | dormant_count = sum([torch.sum(mask) for mask in masks])
221 | dormant_fraction = (dormant_count / sum([torch.numel(mask) for mask in masks])) * 100
222 |
223 | # Re-initialize the dormant neurons and reset the Adam moments
224 | if re_initialize:
225 | print("Re-initializing dormant neurons")
226 | print(f"Total neurons: {total_neurons} | Dormant neurons: {dormant_count} | Dormant fraction: {dormant_fraction:.2f}%")
227 | model = _reset_dormant_neurons(model, masks, use_lecun_init)
228 | optimizer = _reset_adam_moments(optimizer, masks)
229 |
230 | # Remove the hooks again
231 | for handle in handles:
232 | handle.remove()
233 |
234 | return {
235 | "model": model,
236 | "optimizer": optimizer,
237 | "zero_fraction": zero_fraction,
238 | "zero_count": zero_count,
239 | "dormant_fraction": dormant_fraction,
240 | "dormant_count": dormant_count,
241 | }
242 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import typing
3 |
4 | import gymnasium as gym
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .wrappers import (
9 | ClipRewardEnv,
10 | EpisodicLifeEnv,
11 | FireResetEnv,
12 | MaxAndSkipEnv,
13 | NoopResetEnv,
14 | )
15 |
16 |
17 | def make_env(env_id, seed, idx, capture_video, run_name):
18 | """Helper function to create an environment with some standard wrappers.
19 |
20 | Taken from cleanRL's DQN Atari implementation: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari.py.
21 | """
22 |
23 | def thunk():
24 | if capture_video and idx == 0:
25 | env = gym.make(env_id, render_mode="rgb_array")
26 | env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
27 | else:
28 | env = gym.make(env_id)
29 | env = gym.wrappers.RecordEpisodeStatistics(env)
30 | env = NoopResetEnv(env, noop_max=30)
31 | env = MaxAndSkipEnv(env, skip=4)
32 | env = EpisodicLifeEnv(env)
33 | if "FIRE" in env.unwrapped.get_action_meanings():
34 | env = FireResetEnv(env)
35 | env = ClipRewardEnv(env)
36 | env = gym.wrappers.ResizeObservation(env, (84, 84))
37 | env = gym.wrappers.GrayScaleObservation(env)
38 | env = gym.wrappers.FrameStack(env, 4)
39 | env.action_space.seed(seed)
40 |
41 | return env
42 |
43 | return thunk
44 |
45 |
46 | def set_cuda_configuration(gpu: typing.Any) -> torch.device:
47 | """Set up the device for the desired GPU or all GPUs."""
48 |
49 | if gpu is None or gpu == -1 or gpu is False:
50 | device = torch.device("cpu")
51 | elif isinstance(gpu, int):
52 |         assert gpu < torch.cuda.device_count(), "Invalid CUDA index specified."
53 | device = torch.device(f"cuda:{gpu}")
54 | else:
55 | device = torch.device("cuda")
56 |
57 | return device
58 |
59 |
60 | @torch.no_grad()
61 | def lecun_normal_initializer(layer: nn.Module) -> None:
62 | """
63 | Initialization according to LeCun et al. (1998).
64 | See here https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.initializers.lecun_normal.html
65 | and here https://github.com/google/jax/blob/366a16f8ba59fe1ab59acede7efd160174134e01/jax/_src/nn/initializers.py#L460 .
66 | Initializes bias terms to 0.
67 | """
68 |
69 | # Catch case where the whole network is passed
70 | if not isinstance(layer, nn.Linear | nn.Conv2d):
71 | return
72 |
73 | # For a conv layer, this is num_channels*kernel_height*kernel_width
74 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(layer.weight)
75 |
76 | # This implementation follows the jax one
77 | # https://github.com/google/jax/blob/366a16f8ba59fe1ab59acede7efd160174134e01/jax/_src/nn/initializers.py#L260
78 | variance = 1.0 / fan_in
79 | stddev = math.sqrt(variance) / 0.87962566103423978
80 | torch.nn.init.trunc_normal_(layer.weight)
81 | layer.weight *= stddev
82 | if layer.bias is not None:
83 | torch.nn.init.zeros_(layer.bias)
84 |
--------------------------------------------------------------------------------
/src/wrappers.py:
--------------------------------------------------------------------------------
1 | """
2 | The wrappers are taken from stable-baselines3 with unnecessary ones removed.
3 | https://github.com/DLR-RM/stable-baselines3/blob/feat/gymnasium-support/stable_baselines3/common/type_aliases.py
4 | """
5 |
6 | from typing import Any, SupportsFloat, Union
7 |
8 | import gymnasium as gym
9 | import numpy as np
10 |
11 | GymObs = Union[tuple, dict[str, Any], np.ndarray, int]
12 | GymStepReturn = tuple[GymObs, float, bool, dict]
13 |
14 |
15 | AtariResetReturn = tuple[np.ndarray, dict[str, Any]]
16 | AtariStepReturn = tuple[np.ndarray, SupportsFloat, bool, bool, dict[str, Any]]
17 |
18 |
19 | class StickyActionEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
20 | """
21 | Sticky action.
22 | Paper: https://arxiv.org/abs/1709.06009
23 | Official implementation: https://github.com/mgbellemare/Arcade-Learning-Environment
24 | :param env: Environment to wrap
25 | :param action_repeat_probability: Probability of repeating the last action
26 | """
27 |
28 | def __init__(self, env: gym.Env, action_repeat_probability: float) -> None:
29 | super().__init__(env)
30 | self.action_repeat_probability = action_repeat_probability
31 | assert env.unwrapped.get_action_meanings()[0] == "NOOP" # type: ignore[attr-defined]
32 |
33 | def reset(self, **kwargs) -> AtariResetReturn:
34 | self._sticky_action = 0 # NOOP
35 | return self.env.reset(**kwargs)
36 |
37 | def step(self, action: int) -> AtariStepReturn:
38 | if self.np_random.random() >= self.action_repeat_probability:
39 | self._sticky_action = action
40 | return self.env.step(self._sticky_action)
41 |
42 |
43 | class NoopResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
44 | """
45 | Sample initial states by taking random number of no-ops on reset.
46 | No-op is assumed to be action 0.
47 | :param env: Environment to wrap
48 | :param noop_max: Maximum value of no-ops to run
49 | """
50 |
51 | def __init__(self, env: gym.Env, noop_max: int = 30) -> None:
52 | super().__init__(env)
53 | self.noop_max = noop_max
54 | self.override_num_noops = None
55 | self.noop_action = 0
56 | assert env.unwrapped.get_action_meanings()[0] == "NOOP" # type: ignore[attr-defined]
57 |
58 | def reset(self, **kwargs) -> AtariResetReturn:
59 | self.env.reset(**kwargs)
60 | if self.override_num_noops is not None:
61 | noops = self.override_num_noops
62 | else:
63 | noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
64 | assert noops > 0
65 | obs = np.zeros(0)
66 | info: dict = {}
67 | for _ in range(noops):
68 | obs, _, terminated, truncated, info = self.env.step(self.noop_action)
69 | if terminated or truncated:
70 | obs, info = self.env.reset(**kwargs)
71 | return obs, info
72 |
73 |
74 | class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
75 | """
76 | Take action on reset for environments that are fixed until firing.
77 | :param env: Environment to wrap
78 | """
79 |
80 | def __init__(self, env: gym.Env) -> None:
81 | super().__init__(env)
82 | assert env.unwrapped.get_action_meanings()[1] == "FIRE" # type: ignore[attr-defined]
83 | assert len(env.unwrapped.get_action_meanings()) >= 3 # type: ignore[attr-defined]
84 |
85 | def reset(self, **kwargs) -> AtariResetReturn:
86 | self.env.reset(**kwargs)
87 | obs, _, terminated, truncated, _ = self.env.step(1)
88 | if terminated or truncated:
89 | self.env.reset(**kwargs)
90 | obs, _, terminated, truncated, _ = self.env.step(2)
91 | if terminated or truncated:
92 | self.env.reset(**kwargs)
93 | return obs, {}
94 |
95 |
96 | class EpisodicLifeEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
97 | """
98 | Make end-of-life == end-of-episode, but only reset on true game over.
99 | Done by DeepMind for the DQN and co. since it helps value estimation.
100 | :param env: Environment to wrap
101 | """
102 |
103 | def __init__(self, env: gym.Env) -> None:
104 | super().__init__(env)
105 | self.lives = 0
106 | self.was_real_done = True
107 |
108 | def step(self, action: int) -> AtariStepReturn:
109 | obs, reward, terminated, truncated, info = self.env.step(action)
110 | self.was_real_done = terminated or truncated
111 | # check current lives, make loss of life terminal,
112 | # then update lives to handle bonus lives
113 | lives = self.env.unwrapped.ale.lives() # type: ignore[attr-defined]
114 | if 0 < lives < self.lives:
115 | # for Qbert sometimes we stay in lives == 0 condition for a few frames
116 |             # so it's important to keep lives > 0, so that we only reset once
117 | # the environment advertises done.
118 | terminated = True
119 | self.lives = lives
120 | return obs, reward, terminated, truncated, info
121 |
122 | def reset(self, **kwargs) -> AtariResetReturn:
123 | """
124 | Calls the Gym environment reset, only when lives are exhausted.
125 | This way all states are still reachable even though lives are episodic,
126 | and the learner need not know about any of this behind-the-scenes.
127 | :param kwargs: Extra keywords passed to env.reset() call
128 | :return: the first observation of the environment
129 | """
130 | if self.was_real_done:
131 | obs, info = self.env.reset(**kwargs)
132 | else:
133 | # no-op step to advance from terminal/lost life state
134 | obs, _, terminated, truncated, info = self.env.step(0)
135 |
136 | # The no-op step can lead to a game over, so we need to check it again
137 | # to see if we should reset the environment and avoid the
138 | # monitor.py `RuntimeError: Tried to step environment that needs reset`
139 | if terminated or truncated:
140 | obs, info = self.env.reset(**kwargs)
141 | self.lives = self.env.unwrapped.ale.lives() # type: ignore[attr-defined]
142 | return obs, info
143 |
144 |
145 | class MaxAndSkipEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
146 | """
147 | Return only every ``skip``-th frame (frameskipping)
148 | and return the max between the two last frames.
149 | :param env: Environment to wrap
150 |     :param skip: Number of frames to skip
151 | The same action will be taken ``skip`` times.
152 | """
153 |
154 | def __init__(self, env: gym.Env, skip: int = 4) -> None:
155 | super().__init__(env)
156 | # most recent raw observations (for max pooling across time steps)
157 | assert env.observation_space.dtype is not None, "No dtype specified for the observation space"
158 | assert env.observation_space.shape is not None, "No shape defined for the observation space"
159 | self._obs_buffer = np.zeros((2, *env.observation_space.shape), dtype=env.observation_space.dtype)
160 | self._skip = skip
161 |
162 | def step(self, action: int) -> AtariStepReturn:
163 | """
164 | Step the environment with the given action
165 | Repeat action, sum reward, and max over last observations.
166 | :param action: the action
167 | :return: observation, reward, terminated, truncated, information
168 | """
169 | total_reward = 0.0
170 | terminated = truncated = False
171 | for i in range(self._skip):
172 | obs, reward, terminated, truncated, info = self.env.step(action)
173 | done = terminated or truncated
174 | if i == self._skip - 2:
175 | self._obs_buffer[0] = obs
176 | if i == self._skip - 1:
177 | self._obs_buffer[1] = obs
178 | total_reward += float(reward)
179 | if done:
180 | break
181 | # Note that the observation on the done=True frame
182 | # doesn't matter
183 | max_frame = self._obs_buffer.max(axis=0)
184 |
185 | return max_frame, total_reward, terminated, truncated, info
186 |
187 |
188 | class ClipRewardEnv(gym.RewardWrapper):
189 | """
190 | Clip the reward to {+1, 0, -1} by its sign.
191 | :param env: Environment to wrap
192 | """
193 |
194 | def __init__(self, env: gym.Env) -> None:
195 | super().__init__(env)
196 |
197 | def reward(self, reward: SupportsFloat) -> float:
198 | """
199 | Bin reward to {+1, 0, -1} by its sign.
200 | :param reward:
201 | :return:
202 | """
203 | return np.sign(float(reward))
204 |
--------------------------------------------------------------------------------