├── requirements.txt
├── README.md
├── .gitignore
├── environment.py
├── utils.py
└── dreamerv3.py

/requirements.txt:
--------------------------------------------------------------------------------
1 | swig
2 | box2d-py
3 | torch
4 | matplotlib
5 | numpy
6 | gymnasium[all]
7 | gymnasium[accept-rom-license]
8 | imageio
9 | termcolor
10 | opencv-python
11 | tensorboard
12 | wandb
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DreamerV3
2 | 
3 | 🚫🚧👷‍♀️ Warning: Under Construction 👷‍♂️🚧🚫
4 | 
5 | This repository contains a PyTorch implementation of the DreamerV3 algorithm, aimed at providing a more readable and accessible alternative to the official implementations.
6 | 
7 | ## Overview
8 | 
9 | DreamerV3 is a model-based reinforcement learning algorithm that learns a world model of the environment dynamics and uses it to train an actor-critic policy from imagined trajectories. The algorithm consists of several key components:
10 | 
11 | 1. A world model that encodes sensory inputs into discrete latent representations and predicts future latent states, rewards, and episode continuation.
12 | 2. An actor network that learns to take actions in the imagined environment.
13 | 3. A critic network that estimates the expected return (value) of imagined states.
14 | 4. An imagination process that generates trajectories using the learned world model.
15 | 
16 | This implementation aims to break down these components into clear, modular parts, making it easier to understand and modify. A condensed sketch of how the components fit together is included at the end of this README.
17 | 
18 | Note: This code is written for the Atari environments, where observations are images; you will need to modify the encoder and decoder networks for environments whose observations are vectors.
19 | 
20 | ## Setup and Installation
21 | 
22 | To set up the environment and install the required dependencies, follow these steps:
23 | 
24 | 1. Create a new conda environment:
25 | 
26 | ```bash
27 | conda create -n dreamerv3 python=3.11
28 | conda activate dreamerv3
29 | ```
30 | 
31 | 2. Install the required packages:
32 | 
33 | ```bash
34 | pip install -r requirements.txt
35 | ```
36 | 
37 | ## Running the Code
38 | 
39 | ```bash
40 | python dreamerv3.py
41 | ```
42 | 
43 | ### Command-line Arguments
44 | 
45 | - `--env`: Atari environment ID, e.g. `ALE/MsPacman-v5` (optional)
46 | - `--wandb_key`: Path to a text file containing your WandB API key (defaults to `../wandb.txt`)
47 | 
48 | Example:
49 | 
50 | ```bash
51 | python dreamerv3.py --env ALE/MsPacman-v5 --wandb_key "../wandb.txt"
52 | ```
53 | 
54 | If no environment is specified, the code loops through all Atari environments in random order (see the full list in `environment.py`).
55 | 
56 | For a full list of available options, run:
57 | 
58 | ```bash
59 | python dreamerv3.py --help
60 | ```
61 | 
62 | ## Acknowledgements
63 | 
64 | This implementation draws inspiration from the following repositories:
65 | 
66 | - [Official DreamerV3 JAX Implementation](https://github.com/danijar/dreamerv3)
67 | - [DreamerV3 PyTorch Implementation by NM512](https://github.com/NM512/dreamerv3-torch)
68 | 
69 | These resources have been invaluable in understanding the DreamerV3 algorithm and creating this more accessible implementation.
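
## How the Pieces Fit Together

The snippet below is a condensed sketch of the training loop in `dreamerv3.py` (WandB logging, video capture, and checkpointing omitted). `args` stands in for the parsed argparse namespace and `num_steps` for a training budget, so treat it as illustrative rather than a drop-in script.

```python
import numpy as np

import utils
from dreamerv3 import Config, DreamerV3

config = Config(args)  # `args` is a placeholder; it must provide .wandb_key
env = utils.make_vec_env("ALE/MsPacman-v5", num_envs=16)
agent = DreamerV3(env.single_observation_space.shape, env.single_action_space.n, config)

states, _ = env.reset()
agent.init_hidden_state()
for step in range(num_steps):  # `num_steps` is a placeholder training budget
    actions = agent.act(states)  # actor picks actions from the current latent state (component 2)
    next_states, rewards, terms, truncs, _ = env.step(actions)
    dones = np.logical_or(terms, truncs)
    agent.store_transition(states, actions, rewards, dones)  # fill the replay buffer
    agent.reset_hidden_states(np.where(dones)[0])  # re-initialise latents for finished episodes
    if len(agent.replay_buffer) >= config.min_buffer_size and step % config.update_interval == 0:
        agent.train()  # world-model update (1), imagination (4), actor-critic update (2, 3)
    states = next_states
env.close()
```

`agent.train()` samples sequences from the replay buffer, updates the world model, imagines rollouts with the RSSM, and then updates the actor and critic; see `DreamerV3.train` for the exact schedule.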
70 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.pt 3 | weights/ 4 | tovi/ 5 | environments/ 6 | metrics/ 7 | results/ 8 | csv/ 9 | runs/ 10 | wandb/ 11 | videos/ 12 | *.csv 13 | 14 | old.py 15 | 16 | # Byte-compiled / optimized / DLL files 17 | __pycache__/ 18 | *.py[cod] 19 | *$py.class 20 | 21 | # C extensions 22 | *.so 23 | 24 | # Distribution / packaging 25 | .Python 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | wheels/ 38 | share/python-wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .nox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *.cover 64 | *.py,cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | cover/ 68 | 69 | # Translations 70 | *.mo 71 | *.pot 72 | 73 | # Django stuff: 74 | *.log 75 | local_settings.py 76 | db.sqlite3 77 | db.sqlite3-journal 78 | 79 | # Flask stuff: 80 | instance/ 81 | .webassets-cache 82 | 83 | # Scrapy stuff: 84 | .scrapy 85 | 86 | # Sphinx documentation 87 | docs/_build/ 88 | 89 | # PyBuilder 90 | .pybuilder/ 91 | target/ 92 | 93 | # Jupyter Notebook 94 | .ipynb_checkpoints 95 | 96 | # IPython 97 | profile_default/ 98 | ipython_config.py 99 | 100 | # pyenv 101 | # For a library or package, you might want to ignore these files since the code is 102 | # intended to run in multiple environments; otherwise, check them in: 103 | # .python-version 104 | 105 | # pipenv 106 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 107 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 108 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 109 | # install all needed dependencies. 110 | #Pipfile.lock 111 | 112 | # poetry 113 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 114 | # This is especially recommended for binary packages to ensure reproducibility, and is more 115 | # commonly ignored for libraries. 116 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 117 | #poetry.lock 118 | 119 | # pdm 120 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 121 | #pdm.lock 122 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 123 | # in version control. 124 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 125 | .pdm.toml 126 | .pdm-python 127 | .pdm-build/ 128 | 129 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 130 | __pypackages__/ 131 | 132 | # Celery stuff 133 | celerybeat-schedule 134 | celerybeat.pid 135 | 136 | # SageMath parsed files 137 | *.sage.py 138 | 139 | # Environments 140 | .env 141 | .venv 142 | env/ 143 | venv/ 144 | ENV/ 145 | env.bak/ 146 | venv.bak/ 147 | 148 | # Spyder project settings 149 | .spyderproject 150 | .spyproject 151 | 152 | # Rope project settings 153 | .ropeproject 154 | 155 | # mkdocs documentation 156 | /site 157 | 158 | # mypy 159 | .mypy_cache/ 160 | .dmypy.json 161 | dmypy.json 162 | 163 | # Pyre type checker 164 | .pyre/ 165 | 166 | # pytype static type analyzer 167 | .pytype/ 168 | 169 | # Cython debug symbols 170 | cython_debug/ 171 | 172 | # PyCharm 173 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 174 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 175 | # and can be added to the global gitignore or merged into this file. For a more nuclear 176 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 177 | #.idea/ 178 | -------------------------------------------------------------------------------- /environment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import gymnasium as gym 4 | from collections import deque 5 | from utils import preprocess 6 | 7 | 8 | class AtariEnv: 9 | def __init__( 10 | self, 11 | env_id, 12 | shape=(64, 64), 13 | repeat=4, 14 | clip_rewards=False, 15 | no_ops=0, 16 | fire_first=False, 17 | ): 18 | base_env = gym.make(env_id, render_mode="rgb_array") 19 | env = RepeatActionAndMaxFrame( 20 | base_env, repeat, clip_rewards, no_ops, fire_first 21 | ) 22 | env = PreprocessFrame(env, shape) 23 | env = StackFrames(env, repeat) 24 | self.env = env 25 | 26 | def make(self): 27 | return self.env 28 | 29 | 30 | class RepeatActionAndMaxFrame(gym.Wrapper): 31 | def __init__(self, env, repeat=4, clip_reward=True, no_ops=0, fire_first=False): 32 | super().__init__(env) 33 | self.repeat = repeat 34 | self.clip_reward = clip_reward 35 | self.no_ops = no_ops 36 | self.fire_first = fire_first 37 | self.frame_buffer = np.zeros( 38 | (2, *env.observation_space.shape), dtype=np.float32 39 | ) 40 | 41 | def step(self, action): 42 | total_reward = 0 43 | term, trunc = False, False 44 | for i in range(self.repeat): 45 | state, reward, term, trunc, info = self.env.step(action) 46 | if self.clip_reward: 47 | reward = np.clip(reward, -1, 1) 48 | total_reward += reward 49 | self.frame_buffer[i % 2] = state 50 | if term or trunc: 51 | break 52 | max_frame = np.maximum(self.frame_buffer[0], self.frame_buffer[1]) 53 | return max_frame, total_reward, term, trunc, info 54 | 55 | def reset(self, seed=None, options=None): 56 | state, info = self.env.reset(seed=seed, options=options) 57 | no_ops = np.random.randint(self.no_ops) + 1 if self.no_ops > 0 else 0 58 | for _ in range(no_ops): 59 | _, _, term, trunc, info = self.env.step(0) 60 | if term or trunc: 61 | state, info = self.env.reset() 62 | if self.fire_first: 63 | assert self.env.unwrapped.get_action_meanings()[1] == "FIRE" 64 | state, _, term, trunc, info = self.env.step(1) 65 | self.frame_buffer = np.zeros( 66 | (2, *self.env.observation_space.shape), dtype=np.float32 67 | ) 68 | self.frame_buffer[0] = state 69 | return state, info 70 | 71 | 72 | class PreprocessFrame(gym.ObservationWrapper): 73 | def __init__(self, env, shape=(64, 64)): 74 | 
super().__init__(env) 75 | self.shape = shape 76 | self.observation_space = gym.spaces.Box(0.0, 1.0, self.shape, dtype=np.float32) 77 | 78 | def observation(self, state): 79 | state = cv2.cvtColor(state, cv2.COLOR_RGB2GRAY) 80 | state = cv2.resize(state, self.shape, interpolation=cv2.INTER_AREA) 81 | return preprocess(state) 82 | 83 | 84 | class StackFrames(gym.ObservationWrapper): 85 | def __init__(self, env, size=4): 86 | super().__init__(env) 87 | self.size = int(size) 88 | self.stack = deque([], maxlen=self.size) 89 | shape = self.env.observation_space.shape 90 | self.observation_space = gym.spaces.Box( 91 | 0.0, 1.0, (self.size, *shape), dtype=np.float32 92 | ) 93 | 94 | def reset(self, seed=None, options=None): 95 | state, info = self.env.reset(seed=seed, options=options) 96 | self.stack = deque([state] * self.size, maxlen=self.size) 97 | return np.array(self.stack), info 98 | 99 | def observation(self, state): 100 | self.stack.append(state) 101 | return np.array(self.stack) 102 | 103 | 104 | ENV_LIST = [ 105 | "ALE/Adventure-v5", 106 | "ALE/AirRaid-v5", 107 | "ALE/Alien-v5", 108 | "ALE/Amidar-v5", 109 | "ALE/Assault-v5", 110 | "ALE/Asterix-v5", 111 | "ALE/Asteroids-v5", 112 | "ALE/Atlantis-v5", 113 | "ALE/BankHeist-v5", 114 | "ALE/BattleZone-v5", 115 | "ALE/BeamRider-v5", 116 | "ALE/Berzerk-v5", 117 | "ALE/Bowling-v5", 118 | "ALE/Boxing-v5", 119 | "ALE/Breakout-v5", 120 | "ALE/Carnival-v5", 121 | "ALE/Centipede-v5", 122 | "ALE/ChopperCommand-v5", 123 | "ALE/CrazyClimber-v5", 124 | "ALE/Defender-v5", 125 | "ALE/DemonAttack-v5", 126 | "ALE/DoubleDunk-v5", 127 | "ALE/ElevatorAction-v5", 128 | "ALE/Enduro-v5", 129 | "ALE/FishingDerby-v5", 130 | "ALE/Freeway-v5", 131 | "ALE/Frostbite-v5", 132 | "ALE/Gopher-v5", 133 | "ALE/Gravitar-v5", 134 | "ALE/Hero-v5", 135 | "ALE/IceHockey-v5", 136 | "ALE/Jamesbond-v5", 137 | "ALE/JourneyEscape-v5", 138 | "ALE/Kangaroo-v5", 139 | "ALE/Krull-v5", 140 | "ALE/KungFuMaster-v5", 141 | "ALE/MontezumaRevenge-v5", 142 | "ALE/MsPacman-v5", 143 | "ALE/NameThisGame-v5", 144 | "ALE/Phoenix-v5", 145 | "ALE/Pitfall-v5", 146 | "ALE/Pong-v5", 147 | "ALE/Pooyan-v5", 148 | "ALE/PrivateEye-v5", 149 | "ALE/Qbert-v5", 150 | "ALE/Riverraid-v5", 151 | "ALE/RoadRunner-v5", 152 | "ALE/Robotank-v5", 153 | "ALE/Seaquest-v5", 154 | "ALE/Skiing-v5", 155 | "ALE/Solaris-v5", 156 | "ALE/SpaceInvaders-v5", 157 | "ALE/StarGunner-v5", 158 | "ALE/Tennis-v5", 159 | "ALE/TimePilot-v5", 160 | "ALE/Tutankham-v5", 161 | "ALE/UpNDown-v5", 162 | "ALE/Venture-v5", 163 | "ALE/VideoPinball-v5", 164 | "ALE/WizardOfWor-v5", 165 | "ALE/YarsRevenge-v5", 166 | "ALE/Zaxxon-v5", 167 | ] 168 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | import imageio 5 | import wandb 6 | import os 7 | import gymnasium as gym 8 | import time 9 | 10 | 11 | def make_env( 12 | env_name, record_video=False, video_folder="videos", video_interval=100, test=False 13 | ): 14 | env = gym.make(env_name, render_mode="rgb_array" if record_video else None) 15 | 16 | env = gym.wrappers.AtariPreprocessing( 17 | env, 18 | frame_skip=1, 19 | screen_size=64, 20 | grayscale_obs=False, 21 | scale_obs=True, 22 | noop_max=0 if test else 30, 23 | ) 24 | env = gym.wrappers.TransformObservation( 25 | env, lambda obs: np.transpose(obs, (2, 0, 1)), None 26 | ) 27 | env.observation_space = gym.spaces.Box( 28 | low=0, high=1, shape=(3, 64, 64), 
dtype=np.float32 29 | ) 30 | 31 | if record_video: 32 | env = gym.wrappers.RecordVideo( 33 | env, 34 | video_folder=video_folder, 35 | episode_trigger=lambda x: x % video_interval == 0, 36 | name_prefix=env_name.split("/")[-1], 37 | ) 38 | 39 | return env 40 | 41 | 42 | def make_vec_env(env_name, num_envs=16, video_folder="videos"): 43 | os.makedirs(video_folder, exist_ok=True) 44 | 45 | env_fns = [ 46 | lambda i=i: make_env( 47 | env_name, 48 | record_video=(i == 0), 49 | video_folder=video_folder, 50 | video_interval=1000, 51 | ) 52 | for i in range(num_envs) 53 | ] 54 | 55 | vec_env = gym.vector.AsyncVectorEnv(env_fns) 56 | return vec_env 57 | 58 | 59 | class VideoLoggerWrapper(gym.vector.VectorWrapper): 60 | def __init__(self, env, video_folder, get_step_callback): 61 | super().__init__(env) 62 | self.video_folder = video_folder 63 | self.last_logged = 0 64 | self.get_step = get_step_callback 65 | 66 | def step(self, action): 67 | obs, rewards, terminated, truncated, infos = super().step(action) 68 | 69 | current_step = self.get_step() 70 | 71 | new_videos = [ 72 | f 73 | for f in os.listdir(self.video_folder) 74 | if f.endswith(".mp4") 75 | and os.path.getmtime(os.path.join(self.video_folder, f)) > self.last_logged 76 | ] 77 | 78 | for video_file in sorted( 79 | new_videos, 80 | key=lambda x: os.path.getctime(os.path.join(self.video_folder, x)), 81 | ): 82 | video_path = os.path.join(self.video_folder, video_file) 83 | wandb.log({"video": wandb.Video(video_path)}, step=current_step) 84 | os.remove(video_path) 85 | self.last_logged = time.time() 86 | 87 | return obs, rewards, terminated, truncated, infos 88 | 89 | 90 | def preprocess(image): 91 | return image.astype(np.float32) / 255.0 92 | 93 | 94 | def quantize(image): 95 | return (image * 255).clip(0, 255).astype(np.uint8) 96 | 97 | 98 | def symlog(x): 99 | return torch.sign(x) * torch.log(torch.abs(x) + 1e-5) 100 | 101 | 102 | def symexp(x): 103 | return torch.sign(x) * (torch.exp(torch.clamp(torch.abs(x), max=20.0)) - 1) 104 | 105 | 106 | def init_weights(m): 107 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)): 108 | nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) 109 | if m.bias is not None: 110 | nn.init.zeros_(m.bias) 111 | 112 | 113 | def adaptive_gradient_clip(model, clip_factor=0.3, eps=1e-3): 114 | for param in model.parameters(): 115 | if param.grad is not None: 116 | weight_norm = torch.norm(param.detach(), p=2) # L2 norm of weights 117 | grad_norm = torch.norm(param.grad.detach(), p=2) # L2 norm of gradients 118 | max_norm = clip_factor * weight_norm + eps 119 | if grad_norm > max_norm: 120 | scale = max_norm / (grad_norm + 1e-8) # Avoid division by zero 121 | param.grad.mul_(scale) # Scale gradients in-place 122 | 123 | 124 | def save_animation(frames, filename): 125 | with imageio.get_writer(filename, mode="I", loop=0) as writer: 126 | for frame in frames: 127 | writer.append_data(frame) 128 | 129 | 130 | def create_animation(env, agent, save_prefix, seeds=10): 131 | for mod in ["best", "best_avg", "final"]: 132 | save_path = f"environments/{save_prefix}_{mod}.gif" 133 | agent.load_checkpoint(save_prefix, mod) 134 | best_total_reward, best_frames = float("-inf"), None 135 | 136 | for _ in range(seeds): 137 | state, _ = env.reset() 138 | frames, total_reward = [], 0 139 | term, trunc = False, False 140 | 141 | while not (term or trunc): 142 | frames.append(env.render()) 143 | action = agent.act(state) 144 | next_state, reward, term, trunc, _ = env.step(action) 145 | total_reward += reward 146 | state = 
next_state 147 | 148 | if total_reward > best_total_reward: 149 | best_total_reward = total_reward 150 | best_frames = frames 151 | 152 | save_animation(best_frames, save_path) 153 | wandb.log({f"Animation/{mod}": wandb.Video(save_path, format="gif")}) 154 | 155 | 156 | def log_hparams(config, run_name): 157 | with open(config.wandb_key, "r", encoding="utf-8") as f: 158 | os.environ["WANDB_API_KEY"] = f.read().strip() 159 | 160 | wandb.init( 161 | project="dreamerv3-atari-v2", 162 | name=run_name, 163 | config=wandb.helper.parse_config(config, exclude=("wandb_key",)), 164 | save_code=True, 165 | ) 166 | 167 | 168 | def log_losses(ep: int, losses: dict): 169 | wandb.log( 170 | { 171 | "Loss/World": losses["world_loss"], 172 | "Loss/Recon": losses["recon_loss"], 173 | "Loss/Reward": losses["reward_loss"], 174 | "Loss/Continue": losses["continue_loss"], 175 | "Loss/KL": losses["kl_loss"], 176 | "Loss/Actor": losses["actor_loss"], 177 | "Loss/Critic": losses["critic_loss"], 178 | "Entropy/Actor": losses["actor_entropy"], 179 | "Entropy/Prior": losses["prior_entropy"], 180 | "Entropy/Posterior": losses["posterior_entropy"], 181 | }, 182 | step=ep, 183 | ) 184 | 185 | 186 | def log_rewards( 187 | step: int, 188 | avg_score: float, 189 | best_score: float, 190 | mem_size: int, 191 | episode: int, 192 | total_episodes: int, 193 | ): 194 | wandb.log( 195 | { 196 | "Reward/Average": avg_score, 197 | "Reward/Best": best_score, 198 | "Memory/Size": mem_size, 199 | }, 200 | step=step, 201 | ) 202 | 203 | e_str = f"[Ep {episode:05d}/{total_episodes}]" 204 | a_str = f"Avg.Score = {avg_score:8.2f}" 205 | b_str = f"Best.Score = {best_score:8.2f}" 206 | s_str = f"Step = {step:8d}" 207 | m_str = f"Mem.Size = {mem_size:7d}" 208 | print(f"{e_str} {a_str} {b_str} {m_str} {s_str}", end="\r") 209 | -------------------------------------------------------------------------------- /dreamerv3.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ale_py 3 | import gymnasium as gym 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from datetime import datetime 9 | from torch.distributions import Categorical, Independent 10 | from environment import ENV_LIST 11 | import utils 12 | import warnings 13 | 14 | warnings.simplefilter("ignore") 15 | gym.register_envs(ale_py) 16 | torch.backends.cudnn.benchmark = True 17 | 18 | 19 | class Config: 20 | def __init__(self, args): 21 | self.capacity = 2_000_000 22 | self.batch_size = 16 23 | self.sequence_length = 64 24 | self.embed_dim = 1024 25 | self.latent_dim = 32 26 | self.num_classes = 32 27 | self.deter_dim = 4096 28 | self.lr = 4e-5 29 | self.eps = 1e-20 30 | self.actor_lr = 4e-5 31 | self.critic_lr = 4e-5 32 | self.discount = 0.997 33 | self.gae_lambda = 0.95 34 | self.rep_loss_scale = 0.1 35 | self.imagination_horizon = 15 36 | self.min_buffer_size = 500 37 | self.episodes = 100_000 38 | self.device = torch.device("cuda") 39 | self.free_bits = 1.0 40 | self.entropy_coef = 3e-4 41 | self.retnorm_scale = 1.0 42 | self.retnorm_limit = 1.0 43 | self.retnorm_decay = 0.99 44 | self.critic_ema_decay = 0.98 45 | self.update_interval = 2 46 | self.updates_per_step = 1 47 | self.mixed_precision = True 48 | self.wandb_key = args.wandb_key 49 | 50 | 51 | class ReplayBuffer: 52 | def __init__(self, config, device, obs_shape): 53 | self.num_envs = 16 54 | self.capacity = config.capacity // self.num_envs 55 | self.batch_size = config.batch_size 56 | self.sequence_length = 
config.sequence_length 57 | self.device = device 58 | self.obs_shape = obs_shape 59 | 60 | self.obs_buf = np.zeros( 61 | (self.num_envs, self.capacity, *obs_shape), dtype=np.uint8 62 | ) 63 | self.act_buf = np.zeros((self.num_envs, self.capacity), dtype=np.uint8) 64 | self.rew_buf = np.zeros((self.num_envs, self.capacity), dtype=np.float16) 65 | self.done_buf = np.zeros((self.num_envs, self.capacity), dtype=np.bool_) 66 | self.stoch_buf = np.zeros( 67 | (self.num_envs, self.capacity, config.latent_dim, config.num_classes), 68 | dtype=np.float16, 69 | ) 70 | self.deter_buf = np.zeros( 71 | (self.num_envs, self.capacity, config.deter_dim), dtype=np.float16 72 | ) 73 | self.positions = np.zeros(self.num_envs, dtype=np.int64) 74 | self.full = [False] * self.num_envs 75 | 76 | def store(self, obs, act, rew, done, stoch, deter): 77 | for env_idx in range(self.num_envs): 78 | pos = self.positions[env_idx] 79 | idx = pos % self.capacity 80 | 81 | self.obs_buf[env_idx, idx] = obs[env_idx] 82 | self.act_buf[env_idx, idx] = act[env_idx] 83 | self.rew_buf[env_idx, idx] = rew[env_idx].astype(np.float16) 84 | self.done_buf[env_idx, idx] = done[env_idx] 85 | self.stoch_buf[env_idx, idx] = ( 86 | stoch[env_idx].cpu().numpy().astype(np.float16) 87 | ) 88 | self.deter_buf[env_idx, idx] = ( 89 | deter[env_idx].cpu().numpy().astype(np.float16) 90 | ) 91 | 92 | self.positions[env_idx] += 1 93 | if self.positions[env_idx] >= self.capacity: 94 | self.full[env_idx] = True 95 | self.positions[env_idx] = 0 96 | 97 | def sample(self): 98 | # Sample one sequence from each environment's buffer 99 | indices = [] 100 | start_indices = [] 101 | for env_idx in range(self.num_envs): 102 | current_size = ( 103 | self.capacity if self.full[env_idx] else self.positions[env_idx] 104 | ) 105 | valid_end = current_size - self.sequence_length 106 | 107 | if valid_end <= 0: 108 | start = 0 109 | else: 110 | start = np.random.randint(0, valid_end) 111 | 112 | env_indices = (start + np.arange(self.sequence_length)) % self.capacity 113 | indices.append(env_indices) 114 | start_indices.append(start) 115 | 116 | # Stack indices across environments 117 | indices = np.stack(indices) 118 | 119 | return { 120 | "initial_stoch": torch.as_tensor( 121 | self.stoch_buf[np.arange(self.num_envs), start_indices], 122 | device=self.device, 123 | dtype=torch.float32, 124 | ), 125 | "initial_deter": torch.as_tensor( 126 | self.deter_buf[np.arange(self.num_envs), start_indices], 127 | device=self.device, 128 | dtype=torch.float32, 129 | ), 130 | "observation": torch.as_tensor( 131 | self.obs_buf[np.arange(self.num_envs)[:, None], indices], 132 | dtype=torch.float32, 133 | device=self.device, 134 | ) 135 | .div_(255.0) 136 | .permute(1, 0, 2, 3, 4), 137 | "action": torch.as_tensor( 138 | self.act_buf[np.arange(self.num_envs)[:, None], indices], 139 | dtype=torch.long, 140 | device=self.device, 141 | ).permute(1, 0), 142 | "reward": torch.as_tensor( 143 | self.rew_buf[np.arange(self.num_envs)[:, None], indices], 144 | dtype=torch.float32, 145 | device=self.device, 146 | ).permute(1, 0), 147 | "done": torch.as_tensor( 148 | self.done_buf[np.arange(self.num_envs)[:, None], indices], 149 | dtype=torch.float32, 150 | device=self.device, 151 | ).permute(1, 0), 152 | } 153 | 154 | def __len__(self): 155 | # Return minimum available length across all environments 156 | return min( 157 | pos if not full else self.capacity 158 | for pos, full in zip(self.positions, self.full) 159 | ) 160 | 161 | def size(self): 162 | # Return total size of the buffer 163 | 
return sum( 164 | pos if not full else self.capacity 165 | for pos, full in zip(self.positions, self.full) 166 | ) 167 | 168 | 169 | class LAProp(torch.optim.Optimizer): 170 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-20): 171 | defaults = dict(lr=lr, betas=betas, eps=eps) 172 | super().__init__(params, defaults) 173 | self.state["step"] = 0 174 | 175 | @torch.no_grad() 176 | def step(self, closure=None): 177 | loss = None 178 | if closure is not None: 179 | loss = closure() 180 | 181 | for group in self.param_groups: 182 | beta1, beta2 = group["betas"] 183 | eps = group["eps"] 184 | 185 | for p in group["params"]: 186 | if p.grad is None: 187 | continue 188 | 189 | grad = p.grad 190 | state = self.state[p] 191 | 192 | if len(state) == 0: 193 | state["exp_avg_sq"] = torch.zeros_like(p) 194 | state["momentum_buffer"] = torch.zeros_like(p) 195 | 196 | exp_avg_sq = state["exp_avg_sq"] 197 | momentum_buffer = state["momentum_buffer"] 198 | 199 | # RMSProp update 200 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 201 | denom = exp_avg_sq.sqrt().add_(eps) 202 | normalized_grad = grad / denom 203 | 204 | # Momentum update 205 | momentum_buffer.mul_(beta1).add_(normalized_grad, alpha=1 - beta1) 206 | 207 | # Parameter update 208 | p.add_(momentum_buffer, alpha=-group["lr"]) 209 | 210 | return loss 211 | 212 | 213 | class OneHotCategoricalStraightThrough(Categorical): 214 | def sample(self, sample_shape=torch.Size()): 215 | samples = super().sample(sample_shape) 216 | return self.probs + (samples - self.probs).detach() 217 | 218 | 219 | class ObservationEncoder(nn.Module): 220 | def __init__(self, in_channels=3, embed_dim=1024): 221 | super().__init__() 222 | self.conv = nn.Sequential( 223 | nn.Conv2d(in_channels, 32, 4, 2), 224 | nn.ReLU(), 225 | nn.Conv2d(32, 64, 4, 2), 226 | nn.ReLU(), 227 | nn.Conv2d(64, 128, 4, 2), 228 | nn.ReLU(), 229 | nn.Conv2d(128, 256, 4, 2), 230 | nn.ReLU(), 231 | nn.Flatten(), 232 | nn.Linear(256 * 2 * 2, embed_dim), 233 | nn.LayerNorm(embed_dim), 234 | ) 235 | self.apply(utils.init_weights) 236 | 237 | def forward(self, x): 238 | return torch.utils.checkpoint.checkpoint(self.conv, x) 239 | 240 | 241 | class ObservationDecoder(nn.Module): 242 | def __init__(self, feature_dim, out_channels=3, output_size=(64, 64)): 243 | super().__init__() 244 | self.out_channels = out_channels 245 | self.output_size = output_size 246 | 247 | self.net = nn.Sequential( 248 | nn.Linear(feature_dim, 256 * 8 * 8), 249 | nn.LayerNorm(256 * 8 * 8), 250 | nn.SiLU(), 251 | nn.Unflatten(1, (256, 8, 8)), 252 | nn.ConvTranspose2d(256, 128, 4, 2, 1), 253 | nn.SiLU(), 254 | nn.ConvTranspose2d(128, 64, 4, 2, 1), 255 | nn.SiLU(), 256 | nn.ConvTranspose2d(64, 32, 4, 2, 1), 257 | nn.SiLU(), 258 | nn.Conv2d(32, out_channels, 3, padding=1), 259 | nn.Sigmoid(), 260 | ) 261 | self.apply(utils.init_weights) 262 | 263 | def forward(self, x): 264 | return torch.utils.checkpoint.checkpoint(self.net, x) 265 | 266 | 267 | class TwoHotCategoricalStraightThrough(torch.distributions.Distribution): 268 | def __init__(self, logits, bins=255, low=-20.0, high=20.0): 269 | super().__init__(validate_args=False) 270 | self.logits = logits 271 | self.bin_centers = torch.linspace(low, high, bins, device=logits.device) 272 | 273 | def log_prob(self, value): 274 | value = utils.symlog(value).clamp(self.bin_centers[0], self.bin_centers[-1]) 275 | indices = ( 276 | (value - self.bin_centers[0]) / (self.bin_centers[1] - self.bin_centers[0]) 277 | ).clamp(0, len(self.bin_centers) - 1) 278 | 279 | 
lower = indices.floor().long().unsqueeze(-1) 280 | upper = indices.ceil().long().unsqueeze(-1) 281 | alpha = (indices - lower.squeeze(-1)).unsqueeze(-1) 282 | 283 | probs = F.softmax(self.logits, dim=-1) 284 | return torch.log( 285 | (1 - alpha) * probs.gather(-1, lower) + alpha * probs.gather(-1, upper) 286 | ).squeeze(-1) 287 | 288 | @property 289 | def mean(self): 290 | return utils.symexp( 291 | (F.softmax(self.logits, dim=-1) * self.bin_centers).sum(-1, keepdim=True) 292 | ) 293 | 294 | 295 | class RSSM(nn.Module): 296 | def __init__(self, action_dim, latent_dim, num_classes, deter_dim, embed_dim): 297 | super().__init__() 298 | self.action_dim = action_dim 299 | self.latent_dim = latent_dim 300 | self.num_classes = num_classes 301 | self.deter_dim = deter_dim 302 | self.embed_dim = embed_dim 303 | 304 | self.prior_net = nn.Sequential( 305 | nn.Linear(deter_dim, 512), 306 | nn.LayerNorm(512), 307 | nn.SiLU(), 308 | nn.Linear(512, latent_dim * num_classes), 309 | ) 310 | 311 | self.post_net = nn.Sequential( 312 | nn.Linear(deter_dim + embed_dim, 512), 313 | nn.LayerNorm(512), 314 | nn.SiLU(), 315 | nn.Linear(512, latent_dim * num_classes), 316 | ) 317 | 318 | self.gru = nn.GRUCell(latent_dim * num_classes + action_dim, deter_dim) 319 | self.deter_init = nn.Parameter(torch.zeros(1, deter_dim)) 320 | self.stoch_init = nn.Parameter(torch.zeros(1, latent_dim * num_classes)) 321 | self.apply(utils.init_weights) 322 | 323 | def init_state(self, batch_size, device): 324 | stoch = ( 325 | F.one_hot( 326 | torch.zeros(batch_size, self.latent_dim, dtype=torch.long), 327 | self.num_classes, 328 | ) 329 | .float() 330 | .to(device) 331 | ) 332 | deter = self.deter_init.repeat(batch_size, 1) 333 | return (stoch, deter) 334 | 335 | def imagine_step(self, stoch, deter, action): 336 | # Use prior network 337 | stoch = utils.symlog(stoch) 338 | action_oh = F.one_hot(action, self.action_dim).float() 339 | gru_input = torch.cat([stoch.flatten(1), action_oh], dim=1) 340 | deter = self.gru(gru_input, deter) 341 | if torch.isnan(deter).any(): 342 | print("NaN detected in GRU input") 343 | deter = torch.nan_to_num(deter, nan=0.0) 344 | prior_logits = self.prior_net(deter).view(-1, self.latent_dim, self.num_classes) 345 | prior_logits = prior_logits - torch.logsumexp(prior_logits, -1, keepdim=True) 346 | prior_logits = torch.log( 347 | 0.99 * torch.softmax(prior_logits, -1) + 0.01 / self.num_classes 348 | ) 349 | stoch = F.gumbel_softmax(prior_logits, tau=1.0, hard=True) 350 | return utils.symexp(stoch), deter 351 | 352 | def observe_step(self, deter, embed): 353 | # Use posterior network 354 | post_logits = self.post_net(torch.cat([deter, embed], dim=1)) 355 | post_logits = post_logits.view(-1, self.latent_dim, self.num_classes) 356 | post_logits = post_logits - torch.logsumexp(post_logits, -1, keepdim=True) 357 | post_logits = torch.log( 358 | 0.99 * torch.softmax(post_logits, -1) + 0.01 / self.num_classes 359 | ) 360 | return post_logits 361 | 362 | def imagine(self, init_state, actor, horizon): 363 | stoch, deter = init_state 364 | features, actions = [], [] 365 | 366 | for _ in range(horizon): 367 | feature = torch.cat([deter, stoch.flatten(1)], dim=1) 368 | with torch.no_grad(): 369 | action = actor(feature).sample() 370 | 371 | stoch, deter = self.imagine_step(stoch, deter, action) 372 | features.append(feature) 373 | actions.append(action) 374 | 375 | return torch.stack(features), torch.stack(actions) 376 | 377 | def observe(self, embed_seq, action_seq, init_state): 378 | priors, posteriors = [], [] 
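        # Roll the GRU forward with the recorded action sequence: at each step the
        # prior is predicted from the deterministic state alone, while the posterior
        # also conditions on the encoder embedding of the current observation.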
379 | features = [] 380 | stoch, deter = init_state 381 | 382 | for t in range(action_seq.size(0)): 383 | gru_input = torch.cat([stoch.flatten(1), action_seq[t]], dim=1) 384 | deter = self.gru(gru_input, deter) 385 | 386 | prior_logits, post_logits = self.observe_step(deter, embed_seq[t]) 387 | 388 | prior_dist = Independent( 389 | OneHotCategoricalStraightThrough(logits=prior_logits), 1 390 | ) 391 | post_dist = Independent( 392 | OneHotCategoricalStraightThrough(logits=post_logits), 1 393 | ) 394 | 395 | stoch = F.gumbel_softmax(post_logits, tau=1.0, hard=True) 396 | features.append(torch.cat([deter, stoch.flatten(1)], dim=1)) 397 | 398 | priors.append(prior_dist) 399 | posteriors.append(post_dist) 400 | 401 | return (priors, posteriors), torch.stack(features) 402 | 403 | 404 | class WorldModel(nn.Module): 405 | def __init__( 406 | self, 407 | in_channels, 408 | action_dim, 409 | embed_dim, 410 | latent_dim, 411 | num_classes, 412 | deter_dim, 413 | obs_size, 414 | ): 415 | super().__init__() 416 | self.encoder = ObservationEncoder(in_channels, embed_dim) 417 | self.rssm = RSSM(action_dim, latent_dim, num_classes, deter_dim, embed_dim) 418 | self.decoder = ObservationDecoder( 419 | deter_dim + latent_dim * num_classes, in_channels, obs_size[1:] 420 | ) 421 | self.reward_decoder = nn.Sequential( 422 | nn.Linear(deter_dim + latent_dim * num_classes, 255) 423 | ) 424 | self.continue_decoder = nn.Sequential( 425 | nn.Linear(deter_dim + latent_dim * num_classes, 1) 426 | ) 427 | 428 | self.apply(utils.init_weights) 429 | self.reward_decoder[-1].weight.data.zero_() 430 | self.reward_decoder[-1].bias.data.zero_() 431 | 432 | def observe(self, observations, actions, stoch, deter): 433 | embed = self.encoder(observations.flatten(0, 1)).view( 434 | actions.size(0), actions.size(1), -1 435 | ) 436 | actions_onehot = F.one_hot(actions, self.rssm.action_dim).float() 437 | 438 | priors, posteriors = [], [] 439 | features = [] 440 | 441 | for t in range(actions.size(0)): 442 | deter = self.rssm.gru( 443 | torch.cat([stoch.flatten(1), actions_onehot[t]], dim=1), deter 444 | ) 445 | 446 | prior_logits = self.rssm.prior_net(deter).view( 447 | deter.size(0), self.rssm.latent_dim, self.rssm.num_classes 448 | ) 449 | prior_dist = Independent( 450 | OneHotCategoricalStraightThrough(logits=prior_logits), 1 451 | ) 452 | 453 | post_logits = self.rssm.post_net(torch.cat([deter, embed[t]], dim=1)) 454 | post_logits = post_logits.view( 455 | deter.size(0), self.rssm.latent_dim, self.rssm.num_classes 456 | ) 457 | post_dist = Independent( 458 | OneHotCategoricalStraightThrough(logits=post_logits), 1 459 | ) 460 | 461 | stoch = F.gumbel_softmax(post_logits, tau=1.0, hard=True) 462 | features.append(torch.cat([deter, stoch.flatten(1)], dim=1)) 463 | priors.append(prior_dist) 464 | posteriors.append(post_dist) 465 | 466 | features = torch.stack(features) 467 | recon_dist = self.decoder(features.flatten(0, 1)) 468 | reward_dist = TwoHotCategoricalStraightThrough( 469 | self.reward_decoder(features.flatten(0, 1)) 470 | ) 471 | continue_pred = self.continue_decoder(features.flatten(0, 1)) 472 | 473 | return (priors, posteriors), features, recon_dist, reward_dist, continue_pred 474 | 475 | 476 | class Actor(nn.Module): 477 | def __init__(self, feature_dim, action_dim): 478 | super().__init__() 479 | self.net = nn.Sequential( 480 | nn.Linear(feature_dim, 512), 481 | nn.LayerNorm(512), 482 | nn.SiLU(), 483 | nn.Linear(512, 512), 484 | nn.LayerNorm(512), 485 | nn.SiLU(), 486 | nn.Linear(512, 512), 487 | nn.LayerNorm(512), 
488 | nn.SiLU(), 489 | nn.Linear(512, action_dim), 490 | ) 491 | self.apply(utils.init_weights) 492 | 493 | def forward(self, x): 494 | return Categorical(logits=self.net(x)) 495 | 496 | 497 | class Critic(nn.Module): 498 | def __init__(self, feature_dim): 499 | super().__init__() 500 | self.net = nn.Sequential( 501 | nn.Linear(feature_dim, 512), 502 | nn.LayerNorm(512), 503 | nn.SiLU(), 504 | nn.Linear(512, 512), 505 | nn.LayerNorm(512), 506 | nn.SiLU(), 507 | nn.Linear(512, 512), 508 | nn.LayerNorm(512), 509 | nn.SiLU(), 510 | nn.Linear(512, 255), 511 | ) 512 | self.apply(utils.init_weights) 513 | # Initialize last layer to zeros as per the paper 514 | self.net[-1].weight.data.zero_() 515 | self.net[-1].bias.data.zero_() 516 | 517 | def forward(self, x): 518 | return self.net(x) 519 | 520 | 521 | class DreamerV3: 522 | def __init__(self, obs_shape, action_dim, config): 523 | self.obs_shape = obs_shape 524 | self.action_dim = action_dim 525 | self.config = config 526 | self.replay_buffer = ReplayBuffer(config, config.device, obs_shape) 527 | self.device = config.device 528 | self.num_envs = 16 529 | 530 | self.world_model = WorldModel( 531 | obs_shape[0], 532 | action_dim, 533 | config.embed_dim, 534 | config.latent_dim, 535 | config.num_classes, 536 | config.deter_dim, 537 | obs_shape, 538 | ).to(self.device) 539 | 540 | feature_dim = config.deter_dim + config.latent_dim * config.num_classes 541 | self.actor = Actor(feature_dim, action_dim).to(self.device) 542 | self.critic = Critic(feature_dim).to(self.device) 543 | self.target_critic = Critic(feature_dim).to(self.device) 544 | self.target_critic.load_state_dict(self.critic.state_dict()) 545 | for param in self.target_critic.parameters(): 546 | param.requires_grad = False 547 | 548 | self.bin_centers = torch.linspace(-20.0, 20.0, 255, device=self.device) 549 | 550 | self.optimizers = { 551 | "world": LAProp( 552 | self.world_model.parameters(), 553 | lr=config.lr, 554 | betas=(0.9, 0.99), 555 | eps=config.eps, 556 | ), 557 | "actor": LAProp( 558 | self.actor.parameters(), 559 | lr=config.actor_lr, 560 | betas=(0.9, 0.99), 561 | eps=config.eps, 562 | ), 563 | "critic": LAProp( 564 | self.critic.parameters(), 565 | lr=config.critic_lr, 566 | betas=(0.9, 0.99), 567 | eps=config.eps, 568 | ), 569 | } 570 | self.scalers = { 571 | "world": torch.amp.GradScaler("cuda"), 572 | "actor": torch.amp.GradScaler("cuda"), 573 | "critic": torch.amp.GradScaler("cuda"), 574 | } 575 | 576 | self.init_hidden_state() 577 | self._reset_stoch, self._reset_deter = self.world_model.rssm.init_state( 578 | self.num_envs, self.device 579 | ) 580 | self.step = 0 581 | 582 | def init_hidden_state(self): 583 | self.hidden_state = self.world_model.rssm.init_state(self.num_envs, self.device) 584 | 585 | def reset_hidden_states(self, done_indices): 586 | """Reset hidden states for specified environment indices""" 587 | if not done_indices.any(): 588 | return 589 | 590 | stoch, deter = self.hidden_state 591 | stoch[done_indices] = self._reset_stoch[done_indices] 592 | deter[done_indices] = self._reset_deter[done_indices] 593 | 594 | def act(self, observations): 595 | obs = torch.tensor(observations, dtype=torch.float32, device=self.device) 596 | with torch.no_grad(): 597 | stoch, deter = self.hidden_state 598 | embed = self.world_model.encoder(obs) 599 | 600 | # Get posteriors 601 | post_logits = self.world_model.rssm.observe_step(deter, embed) 602 | post_logits = post_logits.view( 603 | self.num_envs, self.config.latent_dim, self.config.num_classes 604 | ) 605 | stoch = 
F.gumbel_softmax(post_logits, tau=1.0, hard=True) 606 | 607 | # Get actions 608 | feature = torch.cat([deter, stoch.flatten(1)], dim=1) 609 | action_dist = self.actor(feature) 610 | actions = action_dist.sample() 611 | 612 | # Update hidden states 613 | _, deter = self.world_model.rssm.imagine_step(stoch, deter, actions) 614 | self.hidden_state = (stoch, deter) 615 | 616 | return actions.cpu().numpy() 617 | 618 | def store_transition(self, obs, actions, rewards, dones): 619 | stoch, deter = self.hidden_state 620 | with torch.no_grad(): 621 | obs_tensor = torch.as_tensor(obs, device=self.device) 622 | quantized_obs = (obs_tensor * 255).clamp(0, 255).byte().cpu().numpy() 623 | self.replay_buffer.store( 624 | quantized_obs, actions, rewards, dones, stoch.detach(), deter.detach() 625 | ) 626 | 627 | def update_world_model(self, batch): 628 | self.optimizers["world"].zero_grad() 629 | with torch.amp.autocast(device_type="cuda", dtype=torch.float16): 630 | init_stoch = batch["initial_stoch"] 631 | init_deter = batch["initial_deter"] 632 | obs, actions = batch["observation"], batch["action"] 633 | 634 | (priors, posteriors), features, recon_dist, reward_dist, continue_pred = ( 635 | self.world_model.observe(obs, actions, init_stoch, init_deter) 636 | ) 637 | 638 | prior_entropy = torch.stack([p.entropy() for p in priors]).mean() 639 | post_entropy = torch.stack([q.entropy() for q in posteriors]).mean() 640 | 641 | flat_feat = features.permute(1, 0, 2).flatten(0, 1) 642 | obs = batch["observation"] 643 | recon_target = obs.permute(1, 0, *range(2, obs.ndim)).flatten(0, 1) 644 | recon_pred = self.world_model.decoder(flat_feat) 645 | if obs.dtype == torch.uint8: 646 | recon_target = recon_target.float() / 255.0 647 | recon_loss = F.mse_loss(recon_pred, recon_target, reduction="mean") 648 | 649 | reward_loss = -reward_dist.log_prob(batch["reward"].flatten(0, 1)).mean() 650 | continue_loss = F.binary_cross_entropy_with_logits( 651 | continue_pred.flatten(0, 1), (1 - batch["done"].flatten(0, 1)) 652 | ) 653 | 654 | dyn_loss = torch.stack( 655 | [ 656 | torch.maximum( 657 | torch.tensor(self.config.free_bits, device=self.device), 658 | torch.distributions.kl_divergence( 659 | Independent( # Detach posterior logits 660 | OneHotCategoricalStraightThrough( 661 | logits=posterior.base_dist.logits.detach() 662 | ), 663 | 1, 664 | ), 665 | prior, 666 | ).sum(dim=-1), 667 | ) 668 | for prior, posterior in zip(priors, posteriors) 669 | ] 670 | ).mean() 671 | 672 | rep_loss = torch.stack( 673 | [ 674 | torch.maximum( 675 | torch.tensor(self.config.free_bits, device=self.device), 676 | torch.distributions.kl_divergence( 677 | posterior, 678 | Independent( # Detach prior logits 679 | OneHotCategoricalStraightThrough( 680 | logits=prior.base_dist.logits.detach() 681 | ), 682 | 1, 683 | ), 684 | ).sum(dim=-1), 685 | ) 686 | for prior, posterior in zip(priors, posteriors) 687 | ] 688 | ).mean() 689 | 690 | kl_loss = dyn_loss + rep_loss * self.config.rep_loss_scale 691 | total_loss = recon_loss + reward_loss + continue_loss + kl_loss 692 | 693 | self.scalers["world"].scale(total_loss).backward() 694 | self.scalers["world"].unscale_(self.optimizers["world"]) 695 | utils.adaptive_gradient_clip(self.world_model, clip_factor=0.3, eps=1e-3) 696 | self.scalers["world"].step(self.optimizers["world"]) 697 | self.scalers["world"].update() 698 | 699 | return { 700 | "world_loss": total_loss.item(), 701 | "recon_loss": recon_loss.item(), 702 | "reward_loss": reward_loss.item(), 703 | "continue_loss": continue_loss.item(), 704 | 
"kl_loss": kl_loss.item(), 705 | "prior_entropy": prior_entropy.item(), 706 | "posterior_entropy": post_entropy.item(), 707 | } 708 | 709 | def update_actor_and_critic(self, replay_batch): 710 | B = self.config.batch_size 711 | init_state = (replay_batch["initial_stoch"], replay_batch["initial_deter"]) 712 | 713 | with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.float16): 714 | # Imagination rollout 715 | features, actions = self.world_model.rssm.imagine( 716 | init_state, self.actor, self.config.imagination_horizon 717 | ) 718 | 719 | # Predict rewards and continues 720 | flat_features = features.flatten(0, 1) 721 | reward_logits = self.world_model.reward_decoder(flat_features) 722 | reward_dist = TwoHotCategoricalStraightThrough(reward_logits) 723 | rewards = utils.symlog(reward_dist.mean.view(features.shape[0], B)) 724 | 725 | continue_pred = self.world_model.continue_decoder(flat_features) 726 | continues = continue_pred.view(features.shape[0], B) 727 | discounts = self.config.discount * continues 728 | 729 | # Compute values from target critic 730 | T, B, _ = features.shape 731 | critic_logits = self.target_critic(features.flatten(0, 1)) 732 | probs = F.softmax(critic_logits, dim=-1) 733 | values = (probs * self.bin_centers).sum(-1) 734 | values = values.view(T, B) 735 | 736 | # Compute lambda returns 737 | lambda_returns = torch.zeros_like(values) 738 | lambda_returns[-1] = values[-1] 739 | for t in reversed(range(T - 1)): 740 | blended = (1 - self.config.gae_lambda) * values[ 741 | t 742 | ] + self.config.gae_lambda * lambda_returns[t + 1] 743 | lambda_returns[t] = rewards[t] + discounts[t] * blended 744 | 745 | # Return normalization 746 | returns_flat = lambda_returns.flatten() 747 | current_scale = torch.quantile(returns_flat, 0.95) - torch.quantile( 748 | returns_flat, 0.05 749 | ) 750 | current_scale = current_scale.clamp(min=self.config.retnorm_limit) 751 | self.config.retnorm_scale = ( 752 | self.config.retnorm_decay * self.config.retnorm_scale 753 | + (1 - self.config.retnorm_decay) * current_scale.item() 754 | ) 755 | lambda_returns = lambda_returns / max(1.0, self.config.retnorm_scale) 756 | 757 | # Process replay buffer samples 758 | _, replay_features, _, _, _ = self.world_model.observe( 759 | replay_batch["observation"], 760 | replay_batch["action"], 761 | replay_batch["initial_stoch"], 762 | replay_batch["initial_deter"], 763 | ) 764 | replay_rewards = replay_batch["reward"] 765 | replay_dones = replay_batch["done"] 766 | replay_continues = (1 - replay_dones.float()) * self.config.discount 767 | 768 | # Compute replay returns 769 | replay_values = ( 770 | F.softmax(self.target_critic(replay_features.flatten(0, 1)), -1) 771 | * self.bin_centers 772 | ).sum(-1) 773 | replay_values = replay_values.view(replay_features.shape[0], B) 774 | 775 | replay_lambda_returns = torch.zeros_like(replay_values) 776 | replay_lambda_returns[-1] = replay_values[-1] 777 | for t in reversed(range(replay_features.shape[0] - 1)): 778 | blended = (1 - self.config.gae_lambda) * replay_values[ 779 | t 780 | ] + self.config.gae_lambda * replay_lambda_returns[t + 1] 781 | replay_lambda_returns[t] = ( 782 | replay_rewards[t] + replay_continues[t] * blended 783 | ) 784 | 785 | # Normalize replay returns using the same scale 786 | replay_lambda_returns = replay_lambda_returns / max( 787 | 1.0, self.config.retnorm_scale 788 | ) 789 | 790 | # Critic update 791 | self.optimizers["critic"].zero_grad() 792 | 793 | # Imagination loss 794 | critic_logits = self.critic(features.flatten(0, 1)) 795 
| critic_dist = TwoHotCategoricalStraightThrough(critic_logits) 796 | imagination_loss = -critic_dist.log_prob(lambda_returns.flatten(0, 1)).mean() 797 | 798 | # Replay loss 799 | replay_critic_logits = self.critic(replay_features.flatten(0, 1)) 800 | replay_critic_dist = TwoHotCategoricalStraightThrough(replay_critic_logits) 801 | replay_loss = -replay_critic_dist.log_prob( 802 | replay_lambda_returns.flatten(0, 1) 803 | ).mean() 804 | 805 | total_critic_loss = imagination_loss + 0.3 * replay_loss 806 | 807 | self.scalers["critic"].scale(total_critic_loss).backward() 808 | self.scalers["critic"].unscale_(self.optimizers["critic"]) 809 | utils.adaptive_gradient_clip(self.critic, clip_factor=0.3, eps=1e-3) 810 | self.scalers["critic"].step(self.optimizers["critic"]) 811 | self.scalers["critic"].update() 812 | 813 | # Update target critic 814 | with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.float16): 815 | for online_param, target_param in zip( 816 | self.critic.parameters(), self.target_critic.parameters() 817 | ): 818 | target_param.data.mul_(self.config.critic_ema_decay).add_( 819 | online_param.data, alpha=1 - self.config.critic_ema_decay 820 | ) 821 | 822 | # Actor update 823 | self.optimizers["actor"].zero_grad() 824 | with torch.amp.autocast(device_type="cuda", dtype=torch.float16): 825 | advantages = (lambda_returns - values.detach()).flatten(0, 1) 826 | 827 | action_dist = self.actor(features.flatten(0, 1)) 828 | log_probs = action_dist.log_prob(actions.flatten(0, 1)) 829 | entropy = action_dist.entropy().mean() 830 | 831 | actor_loss = ( 832 | -(log_probs * advantages.detach()).mean() 833 | - self.config.entropy_coef * entropy 834 | ) 835 | 836 | self.scalers["actor"].scale(actor_loss).backward() 837 | self.scalers["actor"].unscale_(self.optimizers["actor"]) 838 | utils.adaptive_gradient_clip(self.actor, clip_factor=0.3, eps=1e-3) 839 | self.scalers["actor"].step(self.optimizers["actor"]) 840 | self.scalers["actor"].update() 841 | 842 | return { 843 | "actor_loss": actor_loss.item(), 844 | "critic_loss": total_critic_loss.item(), 845 | "actor_entropy": entropy.item(), 846 | } 847 | 848 | def train(self): 849 | if len(self.replay_buffer) < self.config.min_buffer_size: 850 | return None 851 | 852 | losses = { 853 | "world_loss": 0, 854 | "recon_loss": 0, 855 | "reward_loss": 0, 856 | "continue_loss": 0, 857 | "kl_loss": 0, 858 | "actor_loss": 0, 859 | "critic_loss": 0, 860 | "actor_entropy": 0, 861 | "prior_entropy": 0, 862 | "posterior_entropy": 0, 863 | } 864 | 865 | for _ in range(self.config.updates_per_step): 866 | batch = self.replay_buffer.sample() 867 | 868 | # World model update 869 | wm_losses = self.update_world_model(batch) 870 | for k, v in wm_losses.items(): 871 | losses[k] += v / self.config.updates_per_step 872 | 873 | # Actor-critic update 874 | ac_losses = self.update_actor_and_critic(batch) 875 | for k, v in ac_losses.items(): 876 | losses[k] += v / self.config.updates_per_step 877 | 878 | self.step += 1 879 | return losses 880 | 881 | def save_checkpoint(self, env_name): 882 | os.makedirs("weights", exist_ok=True) 883 | torch.save( 884 | { 885 | "world_model": self.world_model.state_dict(), 886 | "actor": self.actor.state_dict(), 887 | "critic": self.critic.state_dict(), 888 | }, 889 | f"weights/{env_name}_dreamerv3.pt", 890 | ) 891 | 892 | def load_checkpoint(self, env_name, mod="best"): 893 | checkpoint = torch.load(f"weights/{env_name}_{mod}_dreamerv3.pt") 894 | self.world_model.load_state_dict(checkpoint["world_model"]) 895 | 
self.actor.load_state_dict(checkpoint["actor"]) 896 | self.critic.load_state_dict(checkpoint["critic"]) 897 | 898 | 899 | def train_dreamer(args): 900 | config = Config(args) 901 | 902 | step_counter = 0 903 | env = utils.make_vec_env(args.env, num_envs=16) 904 | env = utils.VideoLoggerWrapper(env, "videos", lambda: step_counter) 905 | 906 | obs_shape = env.single_observation_space.shape 907 | act_dim = env.single_action_space.n 908 | save_prefix = args.env.split("/")[-1] 909 | print(f"Env: {save_prefix}, Obs: {obs_shape}, Act: {act_dim}") 910 | 911 | agent = DreamerV3(obs_shape, act_dim, config) 912 | 913 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 914 | run_name = f"{save_prefix}_{timestamp}" 915 | utils.log_hparams(config, run_name) 916 | 917 | episode_history = [] 918 | avg_reward_window = 100 919 | best_score, best_avg = float("-inf"), float("-inf") 920 | episode_scores = np.zeros(16) 921 | 922 | states, _ = env.reset() 923 | agent.init_hidden_state() 924 | 925 | while len(episode_history) < config.episodes: 926 | actions = agent.act(states) 927 | next_states, rewards, terms, truncs, _ = env.step(actions) 928 | dones = np.logical_or(terms, truncs) 929 | agent.store_transition(states, actions, rewards, dones) 930 | episode_scores += rewards 931 | 932 | reset_indices = np.where(dones)[0] 933 | if len(reset_indices) > 0: 934 | agent.reset_hidden_states(reset_indices) 935 | for idx in reset_indices: 936 | episode_history.append(episode_scores[idx]) 937 | episode_scores[idx] = 0 938 | 939 | step_counter += 1 940 | states = next_states 941 | 942 | if len(agent.replay_buffer) >= config.min_buffer_size: 943 | if step_counter % config.update_interval == 0: 944 | losses = agent.train() 945 | utils.log_losses(step_counter, losses) 946 | 947 | avg_score = np.mean(episode_history[-avg_reward_window:]) 948 | mem_size = agent.replay_buffer.size() 949 | utils.log_rewards( 950 | step_counter, 951 | avg_score, 952 | best_score, 953 | mem_size, 954 | len(episode_history), 955 | config.episodes, 956 | ) 957 | 958 | if max(episode_history, default=float("-inf")) > best_score: 959 | best_score = max(episode_history) 960 | # agent.save_checkpoint(save_prefix + "_best") 961 | 962 | if avg_score > best_avg: 963 | best_avg = avg_score 964 | # agent.save_checkpoint(save_prefix + "_best_avg") 965 | 966 | print(f"\nFinished training. Best Avg.Score = {best_avg:.2f}") 967 | agent.save_checkpoint(save_prefix + "_final") 968 | env.close() 969 | 970 | 971 | if __name__ == "__main__": 972 | import argparse 973 | 974 | parser = argparse.ArgumentParser() 975 | parser.add_argument("--env", type=str, default=None) 976 | parser.add_argument("--wandb_key", type=str, default="../wandb.txt") 977 | args = parser.parse_args() 978 | for folder in ["videos", "weights"]: 979 | os.makedirs(folder, exist_ok=True) 980 | if args.env: 981 | train_dreamer(args) 982 | else: 983 | rand_order = np.random.permutation(ENV_LIST) 984 | for env in rand_order: 985 | args.env = env 986 | train_dreamer(args) 987 | --------------------------------------------------------------------------------