├── .gitignore ├── LICENSE ├── README.md ├── assets ├── demo.gif ├── img.png └── img_2.png ├── config ├── agent │ └── racing.yaml ├── env │ └── racing.yaml ├── trainer.yaml └── world_model_env │ └── fast.yaml ├── game └── spawn │ ├── 1 │ ├── act.npy │ ├── full_res.npy │ ├── low_res.npy │ └── next_act.npy │ ├── 2 │ ├── act.npy │ ├── full_res.npy │ ├── low_res.npy │ └── next_act.npy │ └── 3 │ ├── act.npy │ ├── full_res.npy │ ├── low_res.npy │ └── next_act.npy ├── requirements.txt ├── scripts ├── import_run.py └── resume.sh └── src ├── __init__.py ├── agent.py ├── coroutines ├── __init__.py ├── collector.py └── env_loop.py ├── data ├── __init__.py ├── batch.py ├── batch_sampler.py ├── dataset.py ├── episode.py ├── segment.py └── utils.py ├── envs ├── __init__.py ├── atari_preprocessing.py ├── env.py └── world_model_env.py ├── game ├── __init__.py ├── dataset_env.py ├── game.py └── play_env.py ├── main.py ├── models ├── __init__.py ├── actor_critic.py ├── blocks.py ├── diffusion │ ├── __init__.py │ ├── denoiser.py │ ├── diffusion_sampler.py │ └── inner_model.py └── rew_end_model.py ├── play.py ├── player ├── __init__.py ├── action_processing.py └── keymap.py ├── process_denoiser_files.py ├── process_upsampler_files.py ├── spawn.py ├── trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.zip 3 | checkpoints 4 | **/__pycache__/ 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Eloi Alonso 4 | Copyright (c) 2025 Enigma Labs AI 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multiverse: The First AI Multiplayer World Model 2 | 3 | 🌐 [Enigma-AI website](https://enigma-labs.io/) - 📚 [Technical Blog](https://enigma-labs.io/) - [🤗 Model on Huggingface](https://huggingface.co/Enigma-AI/multiverse) - [🤗 Datasets on Huggingface](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-low-res) - 𝕏 [Multiverse Tweet](https://x.com/j0nathanj/status/1920516649511244258) 4 | 5 |
6 | Two human players driving cars in Multiverse 7 |
8 | Cars in Multiverse 9 |
10 | 11 | --- 12 | 13 | ## Installation 14 | ```bash 15 | git clone https://github.com/EnigmaLabsAI/multiverse 16 | cd multiverse 17 | pip install -r requirements.txt 18 | ``` 19 | 20 | ### Running the model 21 | 22 | ```bash 23 | python src/play.py --compile 24 | ``` 25 | 26 | > Note: on Apple Silicon, you must enable CPU fallback for the MPS backend: `PYTORCH_ENABLE_MPS_FALLBACK=1 python src/play.py` 27 | 28 | When running this command, you will be prompted with the controls. Press `enter` to start: 29 | ![img.png](assets/img.png) 30 | 31 | Then the game will start: 32 | * To control the silver car at the top of the screen, use the arrow keys. 33 | * To control the blue car at the bottom, use the WASD keys. 34 | 35 | ![img_2.png](assets/img_2.png) 36 | 37 | --- 38 | 39 | 40 | ## Training 41 | 42 | Multiverse comprises two models: 43 | * Denoiser - a world model that simulates the game 44 | * Upsampler - a model that takes the frames from the denoiser and increases their resolution 45 | 46 | ### Denoiser training 47 | 48 | #### 1. Download the dataset 49 | Download the Denoiser's training set from [🤗 Huggingface](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-low-res). 50 | 51 | #### 2. Process data for training 52 | Run the command: 53 | ```bash 54 | python src/process_denoiser_files.py 55 | ``` 56 | 57 | #### 3. Edit training configuration 58 | 59 | Edit [config/env/racing.yaml](config/env/racing.yaml) and set: 60 | - `path_data_low_res` to `/low_res` 61 | - `path_data_full_res` to `/full_res` 62 | 63 | Edit [config/trainer.yaml](config/trainer.yaml) to train the `denoiser`: 64 | ```yaml 65 | train_model: denoiser 66 | ``` 67 | 68 | #### 4. Launch training run 69 | 70 | You can then launch a training run with `python src/main.py`. 71 | 72 | 73 | ### Upsampler training 74 | 75 | #### 1. Download the dataset 76 | Download the Upsampler's training set from [🤗 Huggingface](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-full-res). 77 | 78 | #### 2. Process data for training 79 | Run the command: 80 | ```bash 81 | python src/process_upsampler_files.py 82 | ``` 83 | 84 | #### 3. Edit training configuration 85 | 86 | Edit [config/env/racing.yaml](config/env/racing.yaml) and set: 87 | - `path_data_low_res` to `/low_res` 88 | - `path_data_full_res` to `/full_res` 89 | 90 | Edit [config/trainer.yaml](config/trainer.yaml) to train the `upsampler`: 91 | ```yaml 92 | train_model: upsampler 93 | ``` 94 | 95 | #### 4. Launch training run 96 | 97 | You can then launch a training run with `python src/main.py`. 98 | 99 | 100 | --- 101 | 102 | ## Datasets 103 | 104 | 1. We've collected over 4 hours of multiplayer (1v1) footage from Gran Turismo 4 at a resolution of 48x64 (per player): [🤗 Huggingface link](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-low-res). 105 | 106 | 2. A sparse sampling of full-resolution, cropped frames is available for training the upsampler at a resolution of 350x530: [🤗 Huggingface link](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-full-res). 107 | 108 | The datasets contain a variety of situations: acceleration, braking, overtakes, crashes, and expert driving for both players. 109 | You can read about the data collection mechanism [here](https://enigma-labs.io/blog). 110 | 111 | Note: The full resolution dataset is only for upsampler training and is not fit for world model training. 
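For a quick sanity check after downloading, the snippet below shows one way to open a single episode file and read a frame/action pair. This is a minimal sketch that assumes the `frame_{i}_x` / `frame_{i}_y` HDF5 key layout read by `GameHdf5Dataset` in [src/data/dataset.py](src/data/dataset.py); the filename is a placeholder.

```python
# Minimal sketch: inspect one episode of the low-res dataset.
# Assumes the frame_{i}_x (image) / frame_{i}_y (action) keys used by
# GameHdf5Dataset; "some_episode.hdf5" is a placeholder filename.
import h5py

with h5py.File("some_episode.hdf5", "r") as f:
    frame_keys = [k for k in f.keys() if k.startswith("frame_") and k.endswith("_x")]
    num_frames = 1 + max(int(k[len("frame_"):-len("_x")]) for k in frame_keys)
    frame = f["frame_0_x"][:]   # uint8 image array for the first timestep
    action = f["frame_0_y"][:]  # action data for the same timestep
    print(num_frames, frame.shape, action.shape)
```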
112 | 113 | --- 114 | 115 | ## Outside resources 116 | 117 | - DIAMOND - https://github.com/eloialonso/diamond 118 | - AI-MarioKart64 - https://github.com/Dere-Wah/AI-MarioKart64 119 | 120 | -------------------------------------------------------------------------------- /assets/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/assets/demo.gif -------------------------------------------------------------------------------- /assets/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/assets/img.png -------------------------------------------------------------------------------- /assets/img_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/assets/img_2.png -------------------------------------------------------------------------------- /config/agent/racing.yaml: -------------------------------------------------------------------------------- 1 | _target_: agent.AgentConfig 2 | 3 | denoiser: 4 | _target_: models.diffusion.DenoiserConfig 5 | sigma_data: 0.5 6 | sigma_offset_noise: 0.1 7 | noise_previous_obs: true 8 | upsampling_factor: null 9 | frame_sampling: 10 | - count: 4 11 | stride: 1 12 | - count: 4 13 | stride: 4 14 | inner_model: 15 | _target_: models.diffusion.InnerModelConfig 16 | img_channels: 6 17 | num_steps_conditioning: 8 18 | cond_channels: 2048 19 | depths: 20 | - 2 21 | - 2 22 | - 2 23 | - 2 24 | channels: 25 | - 128 26 | - 256 27 | - 512 28 | - 1024 29 | attn_depths: 30 | - 0 31 | - 0 32 | - 1 33 | - 1 34 | 35 | upsampler: 36 | _target_: models.diffusion.DenoiserConfig 37 | sigma_data: 0.5 38 | sigma_offset_noise: 0.1 39 | noise_previous_obs: false 40 | upsampling_factor: 10 41 | upsampling_frame_height: 350 42 | upsampling_frame_width: 530 43 | inner_model: 44 | _target_: models.diffusion.InnerModelConfig 45 | img_channels: 6 46 | num_steps_conditioning: 0 47 | cond_channels: 2048 48 | depths: 49 | - 2 50 | - 2 51 | - 2 52 | - 2 53 | channels: 54 | - 64 55 | - 64 56 | - 128 57 | - 256 58 | attn_depths: 59 | - 0 60 | - 0 61 | - 0 62 | - 0 63 | 64 | rew_end_model: null 65 | 66 | actor_critic: null 67 | -------------------------------------------------------------------------------- /config/env/racing.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | id: racing 3 | size: [700, 530] 4 | num_actions: 66 5 | path_data_low_res: null 6 | path_data_full_res: null 7 | keymap: racing 8 | -------------------------------------------------------------------------------- /config/trainer.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - env: racing 4 | - agent: racing 5 | - world_model_env: fast 6 | 7 | hydra: 8 | job: 9 | chdir: True 10 | 11 | wandb: 12 | mode: offline 13 | project: null 14 | entity: null 15 | name: null 16 | group: null 17 | tags: null 18 | 19 | initialization: 20 | path_to_ckpt: null 21 | load_denoiser: True 22 | load_rew_end_model: True 23 | load_actor_critic: True 24 | 25 | common: 26 | devices: all # int, list of int, cpu, or all 27 | seed: null 28 | resume: False # do not modify, set by scripts/resume.sh only. 
29 | 30 | checkpointing: 31 | save_agent_every: 5 32 | num_to_keep: 11 # number of checkpoints to keep, use null to disable 33 | 34 | collection: 35 | train: 36 | num_envs: 1 37 | epsilon: 0.01 38 | num_steps_total: 100000 39 | first_epoch: 40 | min: 5000 41 | max: 10000 # null: no maximum 42 | threshold_rew: 10 43 | steps_per_epoch: 100 44 | test: 45 | num_envs: 1 46 | num_episodes: 4 47 | epsilon: 0.0 48 | num_final_episodes: 100 49 | 50 | static_dataset: 51 | path: ${env.path_data_low_res} 52 | ignore_sample_weights: True 53 | 54 | training: 55 | should: True 56 | num_final_epochs: 600 57 | cache_in_ram: False 58 | num_workers_data_loaders: 1 59 | model_free: False # if True, turn off world_model training and RL in imagination 60 | compile_wm: False 61 | 62 | evaluation: 63 | should: True 64 | every: 20 65 | 66 | train_model: denoiser 67 | 68 | denoiser: 69 | training: 70 | num_autoregressive_steps: 8 71 | initial_num_consecutive_page_count: 1 72 | num_consecutive_pages: 73 | - epoch: 400 74 | count: 10 75 | - epoch: 500 76 | count: 50 77 | start_after_epochs: 0 78 | steps_first_epoch: 10 79 | steps_per_epoch: 20 80 | sample_weights: null 81 | batch_size: 30 82 | grad_acc_steps: 2 83 | lr_warmup_steps: 100 84 | max_grad_norm: 10.0 85 | 86 | optimizer: 87 | lr: 1e-4 88 | weight_decay: 1e-2 89 | eps: 1e-8 90 | 91 | sigma_distribution: # log normal distribution for sigma during training 92 | _target_: models.diffusion.SigmaDistributionConfig 93 | loc: -1.2 94 | scale: 1.2 95 | sigma_min: 2e-3 96 | sigma_max: 20 97 | 98 | upsampler: 99 | training: 100 | num_autoregressive_steps: 1 101 | initial_num_consecutive_page_count: 1 102 | start_after_epochs: 0 103 | steps_first_epoch: 20 104 | steps_per_epoch: 20 105 | sample_weights: null 106 | batch_size: 4 107 | grad_acc_steps: 2 108 | lr_warmup_steps: 100 109 | max_grad_norm: 10.0 110 | 111 | optimizer: ${denoiser.optimizer} 112 | sigma_distribution: ${denoiser.sigma_distribution} 113 | 114 | -------------------------------------------------------------------------------- /config/world_model_env/fast.yaml: -------------------------------------------------------------------------------- 1 | _target_: envs.WorldModelEnvConfig 2 | horizon: 1000 3 | num_batches_to_preload: 256 4 | diffusion_sampler_next_obs: 5 | _target_: models.diffusion.DiffusionSamplerConfig 6 | num_steps_denoising: 1 7 | sigma_min: 2e-3 8 | sigma_max: 5.0 9 | rho: 7 10 | order: 1 # 1: Euler, 2: Heun 11 | s_churn: 0.0 # Amount of stochasticity 12 | s_tmin: 0.0 13 | s_tmax: ${eval:'float("inf")'} 14 | s_noise: 1.0 15 | s_cond: 0.005 16 | diffusion_sampler_upsampling: 17 | _target_: models.diffusion.DiffusionSamplerConfig 18 | num_steps_denoising: 1 19 | sigma_min: 1 20 | sigma_max: 5.0 21 | rho: 7 22 | order: 2 # 1: Euler, 2: Heun 23 | s_churn: 10.0 # Amount of stochasticity 24 | s_tmin: 1 25 | s_tmax: 5 26 | s_noise: 0.9 27 | s_cond: 0 -------------------------------------------------------------------------------- /game/spawn/1/act.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/1/act.npy -------------------------------------------------------------------------------- /game/spawn/1/full_res.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/1/full_res.npy 
-------------------------------------------------------------------------------- /game/spawn/1/low_res.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/1/low_res.npy -------------------------------------------------------------------------------- /game/spawn/1/next_act.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/1/next_act.npy -------------------------------------------------------------------------------- /game/spawn/2/act.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/2/act.npy -------------------------------------------------------------------------------- /game/spawn/2/full_res.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/2/full_res.npy -------------------------------------------------------------------------------- /game/spawn/2/low_res.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/2/low_res.npy -------------------------------------------------------------------------------- /game/spawn/2/next_act.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/2/next_act.npy -------------------------------------------------------------------------------- /game/spawn/3/act.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/3/act.npy -------------------------------------------------------------------------------- /game/spawn/3/full_res.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/3/full_res.npy -------------------------------------------------------------------------------- /game/spawn/3/low_res.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/3/low_res.npy -------------------------------------------------------------------------------- /game/spawn/3/next_act.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/3/next_act.npy -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gymnasium==0.29.1 2 | ale-py==0.9.0 3 | h5py==3.11.0 4 | huggingface-hub==0.17.2 5 | hydra-core==1.3 6 | numpy==1.26.0 7 | opencv-python==4.10.0.84 8 | pillow==10.3.0 9 | pygame==2.5.2 10 | torch==2.1.0 11 | torchvision==0.16.0 12 | 
torcheval==0.0.7 13 | tqdm==4.66.4 14 | wandb==0.17.0 15 | -------------------------------------------------------------------------------- /scripts/import_run.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | 3 | import argparse 4 | from functools import partial 5 | import json 6 | from pathlib import Path 7 | import subprocess 8 | from typing import Optional 9 | 10 | 11 | def main() -> None: 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("host", type=str) 14 | parser.add_argument("-v", "--verbose", action="store_true") 15 | parser.add_argument("--user", type=str, default=None) 16 | parser.add_argument("--rootdir", type=str, default=None) 17 | args = parser.parse_args() 18 | 19 | run = partial(subprocess.run, shell=True, check=True, text=True) 20 | host = args.host if args.user is None else f"{args.user}@{args.host}" 21 | 22 | def run_remote_cmd(cmd): 23 | return subprocess.check_output(f"ssh {host} {cmd}", shell=True, text=True) 24 | 25 | def ls(p): 26 | out = run_remote_cmd(f"ls {p}") 27 | return out.strip().split("\n")[::-1] 28 | 29 | def ask(l, info=None): 30 | print( 31 | "\n".join( 32 | [ 33 | f"{i:{len(str(len(l)))}d}: {d}" 34 | + (f" ({info[d]})" if info is not None else "") 35 | for i, d in enumerate(l, 1) 36 | ] 37 | ) 38 | ) 39 | while True: 40 | i = input("\nEnter a number: ") 41 | if i.isdigit() and 1 <= int(i) <= len(l): 42 | break 43 | print("\n/!\\ Invalid choice\n") 44 | return l[int(i) - 1] 45 | 46 | def ask_if_verbose(question, default): 47 | if not args.verbose: 48 | return default 49 | suffix = "[Y|n]" if default else "[y|N]" 50 | answer = input(f"{question} {suffix} ").lower() 51 | 52 | return (answer != "n") if default else (answer == "y") 53 | 54 | def get_info(rundir): 55 | return json.loads( 56 | run_remote_cmd(f"cat {rundir}/checkpoints/info_for_import_script.json") 57 | ) 58 | 59 | if args.rootdir is None: 60 | for p in Path(__file__).resolve().parents: 61 | if (p / ".git").is_dir(): 62 | break 63 | else: 64 | raise RuntimeError("This file is not in a git repository") 65 | out = run_remote_cmd(f"find -type d -name {p.name}").strip().split("\n") 66 | assert len(out) == 1 67 | rootdir = out[0] 68 | else: 69 | rootdir = f'{args.rootdir.strip().strip("/")}' 70 | 71 | dates = ls(f"{rootdir}/outputs") 72 | date = ask(dates) 73 | times = ls(f"{rootdir}/outputs/{date}") 74 | 75 | infos = { 76 | time: get_info(rundir=f"{rootdir}/outputs/{date}/{time}") for time in times 77 | } 78 | time = ask(times, infos) 79 | 80 | src = f"{rootdir}/outputs/{date}/{time}" 81 | 82 | dst = Path(args.host) / date 83 | dst.mkdir(exist_ok=True, parents=True) 84 | 85 | exclude = [ 86 | "*.log", 87 | "checkpoints/*", 88 | "checkpoints_tmp", 89 | ".hydra", 90 | "media", 91 | "__pycache__", 92 | "wandb", 93 | ] 94 | 95 | include = ["checkpoints/agent_versions"] 96 | 97 | if ask_if_verbose("Download only last checkpoint?", default=True): 98 | last_ckpt = ls(f"{src}/checkpoints/agent_versions")[0] 99 | exclude.append("checkpoints/agent_versions/*") 100 | include.append(f"checkpoints/agent_versions/{last_ckpt}") 101 | 102 | if not ask_if_verbose("Download train dataset?", default=False): 103 | exclude.append("dataset/train") 104 | 105 | if not ask_if_verbose("Download test dataset?", default=False): 106 | exclude.append("dataset/test") 107 | 108 | cmd = "rsync -av" 109 | for i in include: 110 | cmd += f' --include="{i}"' 111 | for e in exclude: 112 | cmd += f' --exclude="{e}"' 113 | 114 | cmd += f" {host}:{src} {str(dst)}" 
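# rsync filter rules are order-dependent (the first matching rule wins), so the --include patterns must be appended before the --exclude patterns they carve exceptions out of.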
115 | run(cmd) 116 | 117 | path = (dst / time).absolute() 118 | print(f"\n--> Run imported in:\n{path}") 119 | run(f"echo {path} | xclip") 120 | 121 | 122 | if __name__ == "__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /scripts/resume.sh: -------------------------------------------------------------------------------- 1 | python src/main.py common.resume=True hydra.output_subdir=null hydra.run.dir=. 2 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/src/__init__.py -------------------------------------------------------------------------------- /src/agent.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import Optional, Union 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from envs import TorchEnv, WorldModelEnv 9 | from models.actor_critic import ActorCritic, ActorCriticConfig, ActorCriticLossConfig 10 | from models.diffusion import Denoiser, DenoiserConfig, SigmaDistributionConfig 11 | from models.rew_end_model import RewEndModel, RewEndModelConfig 12 | from utils import extract_state_dict 13 | 14 | 15 | @dataclass 16 | class AgentConfig: 17 | denoiser: DenoiserConfig 18 | upsampler: Optional[DenoiserConfig] 19 | rew_end_model: Optional[RewEndModelConfig] 20 | actor_critic: Optional[ActorCriticConfig] 21 | num_actions: int 22 | 23 | def __post_init__(self) -> None: 24 | self.denoiser.inner_model.num_actions = self.num_actions 25 | if self.upsampler is not None: 26 | self.upsampler.inner_model.num_actions = self.num_actions 27 | if self.rew_end_model is not None: 28 | self.rew_end_model.num_actions = self.num_actions 29 | if self.actor_critic is not None: 30 | self.actor_critic.num_actions = self.num_actions 31 | 32 | 33 | class Agent(nn.Module): 34 | def __init__(self, cfg: AgentConfig) -> None: 35 | super().__init__() 36 | self.denoiser = Denoiser(cfg.denoiser) 37 | self.upsampler = Denoiser(cfg.upsampler) if cfg.upsampler is not None else None 38 | self.rew_end_model = RewEndModel(cfg.rew_end_model) if cfg.rew_end_model is not None else None 39 | self.actor_critic = ActorCritic(cfg.actor_critic) if cfg.actor_critic is not None else None 40 | 41 | @property 42 | def device(self): 43 | return self.denoiser.device 44 | 45 | def setup_training( 46 | self, 47 | sigma_distribution_cfg: SigmaDistributionConfig, 48 | sigma_distribution_cfg_upsampler: Optional[SigmaDistributionConfig], 49 | actor_critic_loss_cfg: Optional[ActorCriticLossConfig], 50 | rl_env: Optional[Union[TorchEnv, WorldModelEnv]], 51 | ) -> None: 52 | self.denoiser.setup_training(sigma_distribution_cfg) 53 | if self.upsampler is not None: 54 | self.upsampler.setup_training(sigma_distribution_cfg_upsampler) 55 | if self.actor_critic is not None: 56 | self.actor_critic.setup_training(rl_env, actor_critic_loss_cfg) 57 | 58 | def load( 59 | self, 60 | path_to_ckpt: Path, 61 | load_denoiser: bool = True, 62 | load_upsampler: bool = True, 63 | load_rew_end_model: bool = True, 64 | load_actor_critic: bool = True, 65 | ) -> None: 66 | sd = torch.load(Path(path_to_ckpt), map_location=self.device) 67 | if load_denoiser: 68 | self.denoiser.load_state_dict(extract_state_dict(sd, "denoiser")) 69 | if load_upsampler and self.upsampler is not None: 70 | 
self.upsampler.load_state_dict(extract_state_dict(sd, "upsampler")) 71 | if load_rew_end_model and self.rew_end_model is not None: 72 | self.rew_end_model.load_state_dict(extract_state_dict(sd, "rew_end_model")) 73 | if load_actor_critic and self.actor_critic is not None: 74 | self.actor_critic.load_state_dict(extract_state_dict(sd, "actor_critic")) 75 | -------------------------------------------------------------------------------- /src/coroutines/__init__.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | 4 | def coroutine(func): 5 | @wraps(func) 6 | def primer(*args, **kwargs): 7 | gen = func(*args, **kwargs) 8 | next(gen) 9 | return gen 10 | 11 | return primer 12 | -------------------------------------------------------------------------------- /src/coroutines/collector.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from dataclasses import dataclass 3 | from typing import Generator, Optional 4 | 5 | import torch 6 | import torch.nn as nn 7 | from tqdm import tqdm 8 | 9 | from . import coroutine 10 | from data import Episode, Dataset 11 | from envs import TorchEnv 12 | from .env_loop import make_env_loop 13 | from utils import Logs 14 | 15 | 16 | @coroutine 17 | def make_collector( 18 | env: TorchEnv, 19 | model: nn.Module, 20 | dataset: Dataset, 21 | epsilon: float = 0.0, 22 | reset_every_collect: bool = False, 23 | verbose: bool = True, 24 | ) -> Generator[Logs, int, None]: 25 | num_envs = env.num_envs 26 | 27 | env_loop, buffer, episode_ids, dead = (None,) * 4 28 | num_steps, num_episodes, to_log, pbar = (None,) * 4 29 | 30 | def setup_new_collect(): 31 | nonlocal num_steps, num_episodes, buffer, to_log, pbar 32 | num_steps = 0 33 | num_episodes = 0 34 | buffer = defaultdict(list) 35 | to_log = [] 36 | pbar = tqdm( 37 | total=num_to_collect.total, 38 | unit=num_to_collect.unit, 39 | desc=f"Collect {dataset.name}", 40 | disable=not verbose, 41 | ) 42 | 43 | def reset(): 44 | nonlocal env_loop, episode_ids, dead 45 | env_loop = make_env_loop(env, model, epsilon) 46 | episode_ids = defaultdict(lambda: None) 47 | dead = [None] * num_envs 48 | 49 | num_to_collect = yield 50 | setup_new_collect() 51 | reset() 52 | 53 | while True: 54 | with torch.no_grad(): 55 | all_obs, act, rew, end, trunc, *_, [infos] = env_loop.send(1) 56 | 57 | num_steps += num_envs 58 | pbar.update(num_envs if num_to_collect.steps is not None else 0) 59 | 60 | for i, (o, a, r, e, t) in enumerate(zip(all_obs, act, rew, end, trunc)): 61 | buffer[i].append((o, a, r, e, t)) 62 | dead[i] = (e + t).clip(max=1).item() 63 | 64 | num_episodes += sum(dead) 65 | 66 | can_stop = num_to_collect.can_stop(num_steps, num_episodes) 67 | 68 | count_dead = 0 69 | for i in range(num_envs): 70 | # Store incomplete episodes only when reset_every_collect is set to False (train) 71 | add_to_dataset = dead[i] or (can_stop and not reset_every_collect) 72 | if add_to_dataset: 73 | info = {"final_observation": infos["final_observation"][count_dead]} if dead[i] else {} 74 | ep = Episode(*(torch.cat(x, dim=0) for x in zip(*buffer[i])), info).to("cpu") 75 | if episode_ids[i] is not None: 76 | ep = dataset.load_episode(episode_ids[i]) + ep 77 | episode_ids[i] = dataset.add_episode(ep, episode_id=episode_ids[i]) 78 | 79 | if dead[i]: 80 | to_log.append( 81 | { 82 | f"{dataset.name}/episode_id": episode_ids[i], 83 | **ep.compute_metrics(), 84 | } 85 | ) 86 | buffer[i] = [] 87 | episode_ids[i] = 
None 88 | pbar.update(1 if num_to_collect.episodes is not None else 0) 89 | 90 | count_dead += dead[i] 91 | 92 | if can_stop: 93 | pbar.close() 94 | metrics = { 95 | "num_steps": dataset.num_steps, 96 | "counts/rew_-1": dataset.counts_rew[0], 97 | "counts/rew__0": dataset.counts_rew[1], 98 | "counts/rew_+1": dataset.counts_rew[2], 99 | "counts/end_0": dataset.counts_end[0], 100 | "counts/end_1": dataset.counts_end[1], 101 | } 102 | to_log.append({f"{dataset.name}/{k}": v for k, v in metrics.items()}) 103 | num_to_collect = yield to_log 104 | setup_new_collect() 105 | if reset_every_collect: 106 | reset() 107 | 108 | 109 | @dataclass 110 | class NumToCollect: 111 | steps: Optional[int] = None 112 | episodes: Optional[int] = None 113 | 114 | def __post_init__(self) -> None: 115 | assert (self.steps is None) != (self.episodes is None) 116 | 117 | def can_stop(self, num_steps: int, num_episodes: int) -> bool: 118 | return num_steps >= self.steps if self.steps is not None else num_episodes >= self.episodes 119 | 120 | @property 121 | def unit(self) -> str: 122 | return "steps" if self.steps is not None else "eps" 123 | 124 | @property 125 | def total(self) -> int: 126 | return self.steps if self.steps is not None else self.episodes 127 | -------------------------------------------------------------------------------- /src/coroutines/env_loop.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Generator, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.distributions.categorical import Categorical 7 | 8 | from . import coroutine 9 | from envs import TorchEnv, WorldModelEnv 10 | 11 | 12 | @coroutine 13 | def make_env_loop( 14 | env: Union[TorchEnv, WorldModelEnv], model: nn.Module, epsilon: float = 0.0 15 | ) -> Generator[Tuple[torch.Tensor, ...], int, None]: 16 | num_steps = yield 17 | 18 | hx = torch.zeros(env.num_envs, model.lstm_dim, device=model.device) 19 | cx = torch.zeros(env.num_envs, model.lstm_dim, device=model.device) 20 | 21 | seed = random.randint(0, 2**31 - 1) 22 | obs, _ = env.reset(seed=[seed + i for i in range(env.num_envs)]) 23 | 24 | while True: 25 | hx, cx = hx.detach(), cx.detach() 26 | all_ = [] 27 | infos = [] 28 | n = 0 29 | 30 | while n < num_steps: 31 | logits_act, val, (hx, cx) = model.predict_act_value(obs, (hx, cx)) 32 | act = Categorical(logits=logits_act).sample() 33 | 34 | if random.random() < epsilon: 35 | act = torch.randint(low=0, high=env.num_actions, size=(obs.size(0),), device=obs.device) 36 | 37 | next_obs, rew, end, trunc, info = env.step(act) 38 | 39 | if n > 0: 40 | val_bootstrap = val.detach().clone() 41 | if dead.any(): 42 | val_bootstrap[dead] = val_final_obs 43 | all_[-1][-1] = val_bootstrap 44 | 45 | dead = torch.logical_or(end, trunc) 46 | 47 | if dead.any(): 48 | with torch.no_grad(): 49 | _, val_final_obs, _ = model.predict_act_value(info["final_observation"], (hx[dead], cx[dead])) 50 | reset_gate = 1 - dead.float().unsqueeze(1) 51 | hx = hx * reset_gate 52 | cx = cx * reset_gate 53 | if "burnin_obs" in info: 54 | burnin_obs = info["burnin_obs"] 55 | for i in range(burnin_obs.size(1)): 56 | _, _, (hx[dead], cx[dead]) = model.predict_act_value(burnin_obs[:, i], (hx[dead], cx[dead])) 57 | 58 | all_.append([obs, act, rew, end, trunc, logits_act, val, None]) 59 | infos.append(info) 60 | 61 | obs = next_obs 62 | n += 1 63 | 64 | with torch.no_grad(): 65 | _, val_bootstrap, _ = model.predict_act_value(next_obs, (hx, cx)) # do not update hx/cx 66 | 67 | if 
dead.any(): 68 | val_bootstrap[dead] = val_final_obs 69 | 70 | all_[-1][-1] = val_bootstrap 71 | 72 | all_obs, act, rew, end, trunc, logits_act, val, val_bootstrap = (torch.stack(x, dim=1) for x in zip(*all_)) 73 | 74 | num_steps = yield all_obs, act, rew, end, trunc, logits_act, val, val_bootstrap, infos 75 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .batch import Batch 2 | from .batch_sampler import BatchSampler 3 | from .dataset import Dataset, GameHdf5Dataset 4 | from .episode import Episode 5 | from .segment import Segment, SegmentId 6 | from .utils import collate_segments_to_batch, DatasetTraverser, make_segment 7 | -------------------------------------------------------------------------------- /src/data/batch.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from dataclasses import dataclass 3 | from typing import Any, Dict, List 4 | 5 | import torch 6 | 7 | from .segment import SegmentId 8 | 9 | 10 | @dataclass 11 | class Batch: 12 | obs: torch.ByteTensor 13 | act: torch.LongTensor 14 | rew: torch.FloatTensor 15 | end: torch.LongTensor 16 | trunc: torch.LongTensor 17 | mask_padding: torch.BoolTensor 18 | info: List[Dict[str, Any]] 19 | segment_ids: List[SegmentId] 20 | 21 | def pin_memory(self) -> Batch: 22 | return Batch(**{k: v if k in ("segment_ids", "info") else v.pin_memory() for k, v in self.__dict__.items()}) 23 | 24 | def to(self, device: torch.device) -> Batch: 25 | return Batch(**{k: v if k in ("segment_ids", "info") else v.to(device) for k, v in self.__dict__.items()}) 26 | -------------------------------------------------------------------------------- /src/data/batch_sampler.py: -------------------------------------------------------------------------------- 1 | from typing import Generator, List, Optional 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from .dataset import GameHdf5Dataset, Dataset 7 | from .segment import SegmentId 8 | 9 | 10 | class BatchSampler(torch.utils.data.Sampler): 11 | def __init__( 12 | self, 13 | dataset: Dataset, 14 | rank: int, 15 | world_size: int, 16 | batch_size: int, 17 | seq_length: int, 18 | sample_weights: Optional[List[float]] = None, 19 | can_sample_beyond_end: bool = False, 20 | autoregressive_obs: int = None, 21 | initial_num_consecutive_page_count: int = 1 22 | ) -> None: 23 | super().__init__(dataset) 24 | assert isinstance(dataset, (Dataset, GameHdf5Dataset)) 25 | self.dataset = dataset 26 | self.rank = rank 27 | self.world_size = world_size 28 | self.sample_weights = sample_weights 29 | self.batch_size = batch_size 30 | self.seq_length = seq_length 31 | self.can_sample_beyond_end = can_sample_beyond_end 32 | self.autoregressive_obs = autoregressive_obs 33 | self.num_consecutive_batches = initial_num_consecutive_page_count 34 | 35 | def __len__(self): 36 | raise NotImplementedError 37 | 38 | def __iter__(self) -> Generator[List[SegmentId], None, None]: 39 | segments = None 40 | current_iter = 0 41 | 42 | while True: 43 | if current_iter == 0: 44 | segments = self.sample() 45 | else: 46 | segments = self.next(segments) 47 | 48 | current_iter = (current_iter + 1) % self.num_consecutive_batches 49 | yield segments 50 | 51 | def next(self, segments: List[SegmentId]): 52 | return [ 53 | SegmentId(segment.episode_id, segment.stop, segment.stop + self.autoregressive_obs, False) 54 | for segment in 
segments 55 | ] 56 | 57 | def sample(self) -> List[SegmentId]: 58 | total_length = self.seq_length + (self.num_consecutive_batches - 1) * self.autoregressive_obs 59 | 60 | num_episodes = self.dataset.num_episodes 61 | 62 | if (self.sample_weights is None) or num_episodes < len(self.sample_weights): 63 | weights = self.dataset.lengths / self.dataset.num_steps 64 | else: 65 | weights = self.sample_weights 66 | num_weights = len(self.sample_weights) 67 | assert all([0 <= x <= 1 for x in weights]) and sum(weights) == 1 68 | sizes = [ 69 | num_episodes // num_weights + (num_episodes % num_weights) * (i == num_weights - 1) 70 | for i in range(num_weights) 71 | ] 72 | weights = [w / s for (w, s) in zip(weights, sizes) for _ in range(s)] 73 | 74 | episodes_partition = np.arange(self.rank, num_episodes, self.world_size) 75 | episode_lengths = self.dataset.lengths[episodes_partition] 76 | valid_mask = episode_lengths > total_length # valid episodes must be long enough for autoregressive generation 77 | episodes_partition = episodes_partition[valid_mask] 78 | 79 | weights = np.array(weights[self.rank::self.world_size]) 80 | weights = weights[valid_mask] 81 | 82 | max_eps = self.batch_size 83 | episode_ids = np.random.choice(episodes_partition, size=max_eps, replace=True, p=weights / weights.sum()) 84 | episode_ids = episode_ids.repeat(self.batch_size // max_eps) 85 | 86 | # choose a random timestep within each sampled episode 87 | timesteps = np.random.randint(low=0, high=self.dataset.lengths[episode_ids]) 88 | # total_length above is the initial context size plus the autoregressive generation length 89 | 90 | # the stops of the first page can be at most the length of the training example minus the autoregressive generation frames in the next pages 91 | stops = np.minimum( 92 | self.dataset.lengths[episode_ids] - (self.num_consecutive_batches - 1) * self.seq_length, 93 | timesteps + 1 + np.random.randint(0, total_length, len(timesteps)) 94 | ) 95 | # stops must be longer than the initial context + first page prediction size 96 | stops = np.maximum(stops, self.seq_length) 97 | # starts is stops minus the initial context and the first page prediction size 98 | starts = stops - self.seq_length 99 | 100 | return [SegmentId(*x, True) for x in zip(episode_ids, starts, stops)] -------------------------------------------------------------------------------- /src/data/dataset.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import multiprocessing as mp 3 | from pathlib import Path 4 | import shutil 5 | from typing import Any, Dict, List, Optional 6 | 7 | import h5py 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.data import Dataset as TorchDataset 12 | 13 | from .episode import Episode 14 | from .segment import Segment, SegmentId 15 | from .utils import make_segment 16 | from utils import StateDictMixin 17 | 18 | 19 | class Dataset(StateDictMixin, TorchDataset): 20 | def __init__( 21 | self, 22 | directory: Path, 23 | dataset_full_res: Optional[TorchDataset], 24 | name: Optional[str] = None, 25 | cache_in_ram: bool = False, 26 | use_manager: bool = False, 27 | save_on_disk: bool = True, 28 | ) -> None: 29 | super().__init__() 30 | 31 | # State 32 | self.is_static = False 33 | self.num_episodes = None 34 | self.num_steps = None 35 | self.start_idx = None 36 | self.lengths = None 37 | self.counter_rew = None 38 | self.counter_end = None 39 | 40 | self._directory = Path(directory).expanduser() 41 | self._name = name if name
is not None else self._directory.stem 42 | self._cache_in_ram = cache_in_ram 43 | self._save_on_disk = save_on_disk 44 | self._default_path = self._directory / "info.pt" 45 | self._cache = mp.Manager().dict() if use_manager else {} 46 | self._reset() 47 | 48 | self._dataset_full_res = dataset_full_res 49 | 50 | def __len__(self) -> int: 51 | return self.num_steps 52 | 53 | def __getitem__(self, segment_id: SegmentId) -> Segment: 54 | episode = self.load_episode(segment_id.episode_id) 55 | segment = make_segment(episode, segment_id, should_pad=True) 56 | if self._dataset_full_res is not None: 57 | segment_id_full_res = SegmentId(episode.info["original_file_id"], segment_id.start, segment_id.stop, segment_id.is_first_batch) 58 | segment.info["full_res"] = self._dataset_full_res[segment_id_full_res].obs 59 | elif "full_res" in segment.info: 60 | segment.info["full_res"] = segment.info["full_res"][segment_id.start:segment_id.stop] 61 | return segment 62 | 63 | def __str__(self) -> str: 64 | return f"{self.name}: {self.num_episodes} episodes, {self.num_steps} steps." 65 | 66 | @property 67 | def name(self) -> str: 68 | return self._name 69 | 70 | @property 71 | def counts_rew(self) -> List[int]: 72 | return [self.counter_rew[r] for r in [-1, 0, 1]] 73 | 74 | @property 75 | def counts_end(self) -> List[int]: 76 | return [self.counter_end[e] for e in [0, 1]] 77 | 78 | def _reset(self) -> None: 79 | self.num_episodes = 0 80 | self.num_steps = 0 81 | self.start_idx = np.array([], dtype=np.int64) 82 | self.lengths = np.array([], dtype=np.int64) 83 | self.counter_rew = Counter() 84 | self.counter_end = Counter() 85 | self._cache.clear() 86 | 87 | def clear(self) -> None: 88 | self.assert_not_static() 89 | if self._directory.is_dir(): 90 | shutil.rmtree(self._directory) 91 | self._reset() 92 | 93 | def load_episode(self, episode_id: int) -> Episode: 94 | if self._cache_in_ram and episode_id in self._cache: 95 | episode = self._cache[episode_id] 96 | else: 97 | episode = Episode.load(self._get_episode_path(episode_id)) 98 | if self._cache_in_ram: 99 | self._cache[episode_id] = episode 100 | return episode 101 | 102 | def add_episode(self, episode: Episode, *, episode_id: Optional[int] = None) -> int: 103 | self.assert_not_static() 104 | episode = episode.to("cpu") 105 | 106 | if episode_id is None: 107 | episode_id = self.num_episodes 108 | self.start_idx = np.concatenate((self.start_idx, np.array([self.num_steps]))) 109 | self.lengths = np.concatenate((self.lengths, np.array([len(episode)]))) 110 | self.num_steps += len(episode) 111 | self.num_episodes += 1 112 | 113 | else: 114 | assert episode_id < self.num_episodes 115 | old_episode = self.load_episode(episode_id) 116 | incr_num_steps = len(episode) - len(old_episode) 117 | self.lengths[episode_id] = len(episode) 118 | self.start_idx[episode_id + 1 :] += incr_num_steps 119 | self.num_steps += incr_num_steps 120 | self.counter_rew.subtract(old_episode.rew.sign().tolist()) 121 | self.counter_end.subtract(old_episode.end.tolist()) 122 | 123 | self.counter_rew.update(episode.rew.sign().tolist()) 124 | self.counter_end.update(episode.end.tolist()) 125 | 126 | if self._save_on_disk: 127 | episode.save(self._get_episode_path(episode_id)) 128 | 129 | if self._cache_in_ram: 130 | self._cache[episode_id] = episode 131 | 132 | return episode_id 133 | 134 | def _get_episode_path(self, episode_id: int) -> Path: 135 | n = 3 # number of hierarchies 136 | powers = np.arange(n) 137 | subfolders = np.floor((episode_id % 10 ** (1 + powers)) / 10**powers) * 10**powers 
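# e.g. episode_id 1234 maps to subfolders 200/30/4, spreading saved episodes across a fixed three-level directory tree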
138 | subfolders = [int(x) for x in subfolders[::-1]] 139 | subfolders = "/".join([f"{x:0{n - i}d}" for i, x in enumerate(subfolders)]) 140 | return self._directory / subfolders / f"{episode_id}.pt" 141 | 142 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 143 | super().load_state_dict(state_dict) 144 | self._cache.clear() 145 | 146 | def assert_not_static(self) -> None: 147 | assert not self.is_static, "Trying to modify a static dataset." 148 | 149 | def save_to_default_path(self) -> None: 150 | self._default_path.parent.mkdir(exist_ok=True, parents=True) 151 | torch.save(self.state_dict(), self._default_path) 152 | 153 | def load_from_default_path(self) -> None: 154 | print(self._default_path) 155 | if self._default_path.is_file(): 156 | self.load_state_dict(torch.load(self._default_path, weights_only=False)) 157 | 158 | 159 | class GameHdf5Dataset(StateDictMixin, TorchDataset): 160 | def __init__(self, directory: Path) -> None: 161 | super().__init__() 162 | filenames = sorted(Path(directory).rglob("*.hdf5"), key=lambda x: int(x.stem.split("_")[-1])) 163 | self._filenames = {f"{x.parent.name}/{x.name}": x for x in filenames} 164 | 165 | self._length_one_episode = self._episode_lengths(self._filenames) 166 | 167 | self.num_episodes = len(self._filenames) 168 | 169 | self.num_steps = sum(list(self._length_one_episode.values())) 170 | self.lengths = np.array(list(self._length_one_episode.values()), dtype=np.int64) 171 | 172 | def _episode_lengths(self, filenames): 173 | length_one_episode = {} 174 | 175 | for filename in filenames: 176 | with h5py.File(filenames[filename], "r") as f: 177 | keys = f.keys() 178 | max_frame_index = max(int(key[len('frame_'):-len('_x')]) for key in keys if key.endswith('_x') and key.startswith('frame_')) 179 | length_one_episode[filename] = max_frame_index + 1 180 | 181 | return length_one_episode 182 | 183 | def __len__(self) -> int: 184 | return self.num_steps 185 | 186 | def save_to_default_path(self) -> None: 187 | pass 188 | 189 | def __getitem__(self, segment_id: SegmentId) -> Segment: 190 | episode_length = self._length_one_episode[segment_id.episode_id] 191 | assert segment_id.start < episode_length and segment_id.stop > 0 and segment_id.start < segment_id.stop 192 | 193 | pad_len_right = max(0, segment_id.stop - episode_length) 194 | pad_len_left = max(0, -segment_id.start) 195 | 196 | start = max(0, segment_id.start) 197 | stop = min(episode_length, segment_id.stop) 198 | mask_padding = torch.cat((torch.zeros(pad_len_left), torch.ones(stop - start), torch.zeros(pad_len_right))).bool() 199 | 200 | #print(self._filenames[segment_id.episode_id]) 201 | with h5py.File(self._filenames[segment_id.episode_id], "r") as f: 202 | obs = torch.stack([torch.tensor(f[f"frame_{i}_x"][:]).flip(2).permute(2, 0, 1).div(255).mul(2).sub(1) for i in range(start, stop)]) 203 | act = torch.tensor(np.array([f[f"frame_{i}_y"][:] for i in range(start, stop)])) 204 | 205 | def pad(x): 206 | right = F.pad(x, [0 for _ in range(2 * x.ndim - 1)] + [pad_len_right]) if pad_len_right > 0 else x 207 | return F.pad(right, [0 for _ in range(2 * x.ndim - 2)] + [pad_len_left, 0]) if pad_len_left > 0 else right 208 | 209 | obs = pad(obs) 210 | act = pad(act) 211 | rew = torch.zeros(obs.size(0)) 212 | end = torch.zeros(obs.size(0), dtype=torch.uint8) 213 | trunc = torch.zeros(obs.size(0), dtype=torch.uint8) 214 | return Segment(obs, act, rew, end, trunc, mask_padding, info={}, id=SegmentId(segment_id.episode_id, start, stop, segment_id.is_first_batch)) 215 | 216 | def 
load_episode(self, episode_id: int) -> Episode: # used by DatasetTraverser 217 | episode_length = self._length_one_episode[episode_id] 218 | s = self[SegmentId(episode_id, 0, episode_length, None)] 219 | return Episode(s.obs, s.act, s.rew, s.end, s.trunc, s.info) -------------------------------------------------------------------------------- /src/data/episode.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Any, Dict, Optional 5 | 6 | import torch 7 | 8 | 9 | @dataclass 10 | class Episode: 11 | obs: torch.FloatTensor 12 | act: torch.LongTensor 13 | rew: torch.FloatTensor 14 | end: torch.ByteTensor 15 | trunc: torch.ByteTensor 16 | info: Dict[str, Any] 17 | 18 | def __len__(self) -> int: 19 | return self.obs.size(0) 20 | 21 | def __add__(self, other: Episode) -> Episode: 22 | assert self.dead.sum() == 0 23 | d = {k: torch.cat((v, other.__dict__[k]), dim=0) for k, v in self.__dict__.items() if k != "info"} 24 | return Episode(**d, info=merge_info(self.info, other.info)) 25 | 26 | def to(self, device) -> Episode: 27 | return Episode(**{k: v.to(device) if k != "info" else v for k, v in self.__dict__.items()}) 28 | 29 | @property 30 | def dead(self) -> torch.ByteTensor: 31 | return (self.end + self.trunc).clip(max=1) 32 | 33 | def compute_metrics(self) -> Dict[str, Any]: 34 | return {"length": len(self), "return": self.rew.sum().item()} 35 | 36 | @classmethod 37 | def load(cls, path: Path, map_location: Optional[torch.device] = None) -> Episode: 38 | return cls( 39 | **{ 40 | k: v.div(255).mul(2).sub(1) if k == "obs" else v 41 | for k, v in torch.load(Path(path), map_location=map_location).items() 42 | } 43 | ) 44 | 45 | def save(self, path: Path) -> None: 46 | path = Path(path) 47 | path.parent.mkdir(parents=True, exist_ok=True) 48 | d = {k: v.add(1).div(2).mul(255).byte() if k == "obs" else v for k, v in self.__dict__.items()} 49 | torch.save(d, path.with_suffix(".tmp")) 50 | path.with_suffix(".tmp").rename(path) 51 | 52 | 53 | def merge_info(info_a, info_b): 54 | keys_a = set(info_a) 55 | keys_b = set(info_b) 56 | intersection = keys_a & keys_b 57 | info = { 58 | **{k: info_a[k] for k in keys_a if k not in intersection}, 59 | **{k: info_b[k] for k in keys_b if k not in intersection}, 60 | **{k: torch.cat((info_a[k], info_b[k]), dim=0) for k in intersection}, 61 | } 62 | return info 63 | -------------------------------------------------------------------------------- /src/data/segment.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from dataclasses import dataclass 3 | from typing import Any, Dict, Union 4 | 5 | import torch 6 | 7 | 8 | @dataclass 9 | class SegmentId: 10 | episode_id: Union[int, str] 11 | start: int 12 | stop: int 13 | is_first_batch: bool 14 | 15 | 16 | @dataclass 17 | class Segment: 18 | obs: torch.FloatTensor 19 | act: torch.LongTensor 20 | rew: torch.FloatTensor 21 | end: torch.ByteTensor 22 | trunc: torch.ByteTensor 23 | mask_padding: torch.BoolTensor 24 | info: Dict[str, Any] 25 | id: SegmentId 26 | 27 | @property 28 | def effective_size(self): 29 | return self.mask_padding.sum().item() 30 | -------------------------------------------------------------------------------- /src/data/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Generator, List 3 | 
4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from .batch import Batch 8 | from .episode import Episode 9 | from .segment import Segment, SegmentId 10 | 11 | 12 | def collate_segments_to_batch(segments: List[Segment]) -> Batch: 13 | attrs = ("obs", "act", "rew", "end", "trunc", "mask_padding") 14 | stack = (torch.stack([getattr(s, x) for s in segments]) for x in attrs) 15 | return Batch(*stack, [s.info for s in segments], [s.id for s in segments]) 16 | 17 | 18 | def make_segment(episode: Episode, segment_id: SegmentId, should_pad: bool = True) -> Segment: 19 | if not (segment_id.start < len(episode) and segment_id.stop > 0 and segment_id.start < segment_id.stop): 20 | print(f'Failed assertion because: start={segment_id.start}, stop={segment_id.stop}, len(episode)={len(episode)}') 21 | 22 | assert segment_id.start < len(episode) and segment_id.stop > 0 and segment_id.start < segment_id.stop 23 | pad_len_right = max(0, segment_id.stop - len(episode)) 24 | pad_len_left = max(0, -segment_id.start) 25 | assert pad_len_right == pad_len_left == 0 or should_pad 26 | 27 | def pad(x): 28 | right = F.pad(x, [0 for _ in range(2 * x.ndim - 1)] + [pad_len_right]) if pad_len_right > 0 else x 29 | return F.pad(right, [0 for _ in range(2 * x.ndim - 2)] + [pad_len_left, 0]) if pad_len_left > 0 else right 30 | 31 | start = max(0, segment_id.start) 32 | stop = min(len(episode), segment_id.stop) 33 | mask_padding = torch.cat((torch.zeros(pad_len_left), torch.ones(stop - start), torch.zeros(pad_len_right))).bool() 34 | 35 | return Segment( 36 | pad(episode.obs[start:stop]), 37 | pad(episode.act[start:stop]), 38 | pad(episode.rew[start:stop]), 39 | pad(episode.end[start:stop]), 40 | pad(episode.trunc[start:stop]), 41 | mask_padding, 42 | info=episode.info, 43 | id=SegmentId(segment_id.episode_id, start, stop, segment_id.is_first_batch), 44 | ) 45 | 46 | 47 | class DatasetTraverser: 48 | def __init__(self, dataset, batch_num_samples: int, chunk_size: int) -> None: 49 | self.dataset = dataset 50 | self.batch_num_samples = batch_num_samples 51 | self.chunk_size = chunk_size 52 | 53 | def __len__(self): 54 | return math.ceil( 55 | sum( 56 | [ 57 | math.ceil(self.dataset.lengths[episode_id] / self.chunk_size) 58 | - int(self.dataset.lengths[episode_id] % self.chunk_size == 1) 59 | for episode_id in range(self.dataset.num_episodes) 60 | ] 61 | ) 62 | / self.batch_num_samples 63 | ) 64 | 65 | def __iter__(self) -> Generator[Batch, None, None]: 66 | chunks = [] 67 | for episode_id in range(self.dataset.num_episodes): 68 | episode = self.dataset.load_episode(episode_id) 69 | segments = [] 70 | for i in range(math.ceil(len(episode) / self.chunk_size)): 71 | start = i * self.chunk_size 72 | stop = (i + 1) * self.chunk_size 73 | segment = make_segment( 74 | episode, 75 | SegmentId(episode_id, start, stop, None), 76 | should_pad=True, 77 | ) 78 | segment_id_full_res = SegmentId(episode.info["original_file_id"], start, stop, None) 79 | segment.info["full_res"] = self.dataset._dataset_full_res[segment_id_full_res].obs 80 | chunks.append(segment) 81 | if chunks[-1].effective_size < 2: 82 | chunks.pop() 83 | 84 | while len(chunks) >= self.batch_num_samples: 85 | yield collate_segments_to_batch(chunks[: self.batch_num_samples]) 86 | chunks = chunks[self.batch_num_samples :] 87 | 88 | if len(chunks) > 0: 89 | yield collate_segments_to_batch(chunks) 90 | 91 | -------------------------------------------------------------------------------- /src/envs/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .env import make_atari_env, TorchEnv 2 | from .world_model_env import WorldModelEnv, WorldModelEnvConfig 3 | -------------------------------------------------------------------------------- /src/envs/atari_preprocessing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Derived from https://github.com/openai/gym/blob/master/gym/wrappers/atari_preprocessing.py 3 | Implementation of Atari 2600 Preprocessing following the guidelines of Machado et al., 2018. 4 | """ 5 | 6 | from __future__ import annotations 7 | 8 | from typing import Any, SupportsFloat 9 | 10 | import cv2 11 | import numpy as np 12 | 13 | import gymnasium as gym 14 | from gymnasium.core import WrapperActType, WrapperObsType 15 | from gymnasium.spaces import Box 16 | 17 | 18 | class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs): 19 | def __init__( 20 | self, 21 | env: gym.Env, 22 | noop_max: int, 23 | frame_skip: int, 24 | screen_size: int, 25 | ): 26 | gym.utils.RecordConstructorArgs.__init__( 27 | self, 28 | noop_max=noop_max, 29 | frame_skip=frame_skip, 30 | screen_size=screen_size, 31 | ) 32 | gym.Wrapper.__init__(self, env) 33 | 34 | assert frame_skip > 0 35 | assert screen_size > 0 36 | assert noop_max >= 0 37 | if frame_skip > 1 and getattr(env.unwrapped, "_frameskip", None) != 1: 38 | raise ValueError( 39 | "Disable frame-skipping in the original env. Otherwise, more than one frame-skip will happen as through this wrapper" 40 | ) 41 | self.noop_max = noop_max 42 | assert env.unwrapped.get_action_meanings()[0] == "NOOP" 43 | 44 | self.frame_skip = frame_skip 45 | self.screen_size = screen_size 46 | 47 | # buffer of most recent two observations for max pooling 48 | assert isinstance(env.observation_space, Box) 49 | self.obs_buffer = [ 50 | np.empty(env.observation_space.shape, dtype=np.uint8), 51 | np.empty(env.observation_space.shape, dtype=np.uint8), 52 | ] 53 | 54 | self.lives = 0 55 | self.game_over = False 56 | 57 | _low, _high, _obs_dtype = (0, 255, np.uint8) 58 | _shape = (screen_size, screen_size, 3) 59 | self.observation_space = Box(low=_low, high=_high, shape=_shape, dtype=_obs_dtype) 60 | 61 | @property 62 | def ale(self): 63 | """Make ale as a class property to avoid serialization error.""" 64 | return self.env.unwrapped.ale 65 | 66 | def step(self, action: WrapperActType) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]: 67 | total_reward, terminated, truncated, info = 0.0, False, False, {} 68 | 69 | life_loss = False 70 | 71 | for t in range(self.frame_skip): 72 | _, reward, terminated, truncated, info = self.env.step(action) 73 | total_reward += reward 74 | self.game_over = terminated 75 | 76 | if self.ale.lives() < self.lives: 77 | life_loss = True 78 | self.lives = self.ale.lives() 79 | 80 | if terminated or truncated: 81 | break 82 | 83 | if t == self.frame_skip - 2: 84 | self.ale.getScreenRGB(self.obs_buffer[1]) 85 | elif t == self.frame_skip - 1: 86 | self.ale.getScreenRGB(self.obs_buffer[0]) 87 | 88 | info["life_loss"] = life_loss 89 | 90 | obs, original_obs = self._get_obs() 91 | info["original_obs"] = original_obs 92 | 93 | return obs, total_reward, terminated, truncated, info 94 | 95 | def reset( 96 | self, *, seed: int | None = None, options: dict[str, Any] | None = None 97 | ) -> tuple[WrapperObsType, dict[str, Any]]: 98 | """Resets the environment using preprocessing.""" 99 | # NoopReset 100 | _, reset_info = 
self.env.reset(seed=seed, options=options) 101 | 102 | reset_info["life_loss"] = False 103 | 104 | noops = self.env.unwrapped.np_random.integers(1, self.noop_max + 1) if self.noop_max > 0 else 0 105 | for _ in range(noops): 106 | _, _, terminated, truncated, step_info = self.env.step(0) 107 | reset_info.update(step_info) 108 | if terminated or truncated: 109 | _, reset_info = self.env.reset(seed=seed, options=options) 110 | 111 | self.lives = self.ale.lives() 112 | self.ale.getScreenRGB(self.obs_buffer[0]) 113 | self.obs_buffer[1].fill(0) 114 | 115 | obs, original_obs = self._get_obs() 116 | reset_info["original_obs"] = original_obs 117 | 118 | return obs, reset_info 119 | 120 | def _get_obs(self): 121 | if self.frame_skip > 1: # more efficient in-place pooling 122 | np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0]) 123 | 124 | original_obs = self.obs_buffer[0] 125 | obs = cv2.resize( 126 | original_obs, 127 | (self.screen_size, self.screen_size), 128 | interpolation=cv2.INTER_AREA, 129 | ) 130 | 131 | obs = np.asarray(obs, dtype=np.uint8) 132 | 133 | return obs, original_obs 134 | -------------------------------------------------------------------------------- /src/envs/env.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Any, Dict, Optional, Tuple 3 | 4 | import ale_py 5 | import gymnasium 6 | from gymnasium.vector import AsyncVectorEnv 7 | import numpy as np 8 | import torch 9 | from torch import Tensor 10 | 11 | from .atari_preprocessing import AtariPreprocessing 12 | 13 | 14 | def make_atari_env( 15 | id: str, 16 | num_envs: int, 17 | device: torch.device, 18 | done_on_life_loss: bool, 19 | size: int, 20 | max_episode_steps: Optional[int], 21 | ) -> TorchEnv: 22 | def env_fn(): 23 | env = gymnasium.make( 24 | id, 25 | full_action_space=False, 26 | frameskip=1, 27 | render_mode="rgb_array", 28 | max_episode_steps=max_episode_steps, 29 | ) 30 | env = AtariPreprocessing( 31 | env=env, 32 | noop_max=30, 33 | frame_skip=4, 34 | screen_size=size, 35 | ) 36 | return env 37 | 38 | env = AsyncVectorEnv([env_fn for _ in range(num_envs)]) 39 | 40 | # The AsyncVectorEnv resets the env on termination, which means that it will 41 | # reset the environment if we use the default AtariPreprocessing of gymnasium with 42 | # terminate_on_life_loss=True (which means that we will only see the first life). 43 | # Hence a separate wrapper for life_loss, coming after the AsyncVectorEnv. 
44 | 45 | if done_on_life_loss: 46 | env = DoneOnLifeLoss(env) 47 | 48 | env = TorchEnv(env, device) 49 | 50 | return env 51 | 52 | 53 | class DoneOnLifeLoss(gymnasium.Wrapper): 54 | def __init__(self, env: AsyncVectorEnv) -> None: 55 | super().__init__(env) 56 | 57 | def step(self, actions: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, Dict[str, Any]]: 58 | obs, rew, end, trunc, info = self.env.step(actions) 59 | life_loss = info["life_loss"] 60 | if life_loss.any(): 61 | end[life_loss] = True 62 | info["final_observation"] = obs 63 | return obs, rew, end, trunc, info 64 | 65 | 66 | class TorchEnv(gymnasium.Wrapper): 67 | def __init__(self, env: gymnasium.Env, device: torch.device) -> None: 68 | super().__init__(env) 69 | self.device = device 70 | self.num_envs = env.observation_space.shape[0] 71 | self.num_actions = env.unwrapped.single_action_space.n 72 | b, h, w, c = env.observation_space.shape 73 | self.observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(b, c, h, w)) 74 | 75 | def reset(self, *args, **kwargs) -> Tuple[Tensor, Dict[str, Any]]: 76 | obs, info = self.env.reset(*args, **kwargs) 77 | return self._to_tensor(obs), info 78 | 79 | def step(self, actions: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Dict[str, Any]]: 80 | obs, rew, end, trunc, info = self.env.step(actions.cpu().numpy()) 81 | dead = np.logical_or(end, trunc) 82 | if dead.any(): 83 | info["final_observation"] = self._to_tensor(np.stack(info["final_observation"][dead])) 84 | obs, rew, end, trunc = (self._to_tensor(x) for x in (obs, rew, end, trunc)) 85 | return obs, rew, end, trunc, info 86 | 87 | def _to_tensor(self, x: Tensor) -> Tensor: 88 | if x.ndim == 4: 89 | return torch.tensor(x, device=self.device).div(255).mul(2).sub(1).permute(0, 3, 1, 2).contiguous() 90 | elif x.dtype is np.dtype("bool"): 91 | return torch.tensor(x, dtype=torch.uint8, device=self.device) 92 | else: 93 | return torch.tensor(x, dtype=torch.float32, device=self.device) 94 | -------------------------------------------------------------------------------- /src/envs/world_model_env.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from itertools import cycle 3 | from pathlib import Path 4 | from typing import Any, Dict, Generator, List, Optional, Tuple 5 | 6 | import numpy as np 7 | import torch 8 | from torch import Tensor 9 | from torch.distributions.categorical import Categorical 10 | import torch.nn.functional as F 11 | 12 | from coroutines import coroutine 13 | from models.diffusion import Denoiser, DiffusionSampler, DiffusionSamplerConfig 14 | from models.rew_end_model import RewEndModel 15 | 16 | from utils import get_frame_indices 17 | 18 | ResetOutput = Tuple[torch.FloatTensor, Dict[str, Any]] 19 | StepOutput = Tuple[Tensor, Tensor, Tensor, Tensor, Dict[str, Any]] 20 | InitialCondition = Tuple[Tensor, Tensor, Tuple[Tensor, Tensor]] 21 | 22 | 23 | crop_frame = { 24 | 'left_top': (0.04, 0.18), 25 | 'right_bottom': (0.92, 0.95) 26 | } 27 | 28 | def extract_roi(image, rect): 29 | _, _, img_height, img_width = image.shape 30 | top = int(rect['left_top'][1] * img_height) 31 | bottom = int(rect['right_bottom'][1] * img_height) 32 | left = int(rect['left_top'][0] * img_width) 33 | right = int(rect['right_bottom'][0] * img_width) 34 | roi = image[:, :, top:bottom, left:right] 35 | return roi 36 | 37 | @dataclass 38 | class WorldModelEnvConfig: 39 | horizon: int 40 | num_batches_to_preload: int 41 | diffusion_sampler_next_obs: 
DiffusionSamplerConfig 42 | diffusion_sampler_upsampling: Optional[DiffusionSamplerConfig] = None 43 | 44 | 45 | class WorldModelEnv: 46 | def __init__( 47 | self, 48 | denoiser: Denoiser, 49 | upsampler: Optional[Denoiser], 50 | rew_end_model: Optional[RewEndModel], 51 | spawn_dir: Path, 52 | num_envs: int, 53 | seq_length: int, 54 | cfg: WorldModelEnvConfig, 55 | return_denoising_trajectory: bool = False, 56 | ) -> None: 57 | assert num_envs == 1 # for human play only 58 | self.sampler_next_obs = DiffusionSampler(denoiser, cfg.diffusion_sampler_next_obs) 59 | self.sampler_upsampling = None if upsampler is None else DiffusionSampler(upsampler, cfg.diffusion_sampler_upsampling) 60 | self.rew_end_model = rew_end_model 61 | self.horizon = cfg.horizon 62 | self.return_denoising_trajectory = return_denoising_trajectory 63 | self.num_envs = num_envs 64 | self.generator_init = self.make_generator_init(spawn_dir, cfg.num_batches_to_preload) 65 | 66 | self.n_skip_next_obs = seq_length - self.sampler_next_obs.denoiser.cfg.inner_model.num_steps_conditioning 67 | self.n_skip_upsampling = None if upsampler is None else seq_length - self.sampler_upsampling.denoiser.cfg.inner_model.num_steps_conditioning 68 | 69 | self.context_indicies = get_frame_indices(denoiser.cfg.frame_sampling) 70 | 71 | @property 72 | def device(self) -> torch.device: 73 | return self.sampler_next_obs.denoiser.device 74 | 75 | @torch.no_grad() 76 | def reset(self, **kwargs) -> ResetOutput: 77 | obs, obs_full_res, act, next_act, (hx, cx) = self.generator_init.send(self.num_envs) 78 | self.obs_buffer = obs 79 | self.act_buffer = act 80 | self.next_act = next_act[0] 81 | self.obs_full_res_buffer = obs_full_res 82 | self.ep_len = torch.zeros(self.num_envs, dtype=torch.long, device=obs.device) 83 | self.hx_rew_end = hx 84 | self.cx_rew_end = cx 85 | obs_to_return = self.obs_buffer[:, -1] if self.sampler_upsampling is None else self.obs_full_res_buffer[:, -1] 86 | return obs_to_return, {} 87 | 88 | @torch.no_grad() 89 | def step(self, act: torch.LongTensor) -> StepOutput: 90 | self.act_buffer[:, -1] = act 91 | 92 | next_obs, denoising_trajectory = self.predict_next_obs() 93 | 94 | if self.sampler_upsampling is not None: 95 | next_obs_full, denoising_trajectory_upsampling = self.upsample_next_obs(next_obs) 96 | 97 | if self.rew_end_model is not None: 98 | rew, end = self.predict_rew_end(next_obs.unsqueeze(1)) 99 | else: 100 | rew = torch.zeros(next_obs.size(0), dtype=torch.float32, device=self.device) 101 | end = torch.zeros(next_obs.size(0), dtype=torch.int64, device=self.device) 102 | 103 | self.ep_len += 1 104 | trunc = (self.ep_len >= self.horizon).long() 105 | 106 | self.obs_buffer = self.obs_buffer.roll(-1, dims=1) 107 | self.act_buffer = self.act_buffer.roll(-1, dims=1) 108 | self.obs_buffer[:, -1] = next_obs 109 | 110 | if self.sampler_upsampling is not None: 111 | self.obs_full_res_buffer = self.obs_full_res_buffer.roll(-1, dims=1) 112 | self.obs_full_res_buffer[:, -1] = next_obs_full 113 | 114 | info = {} 115 | if self.return_denoising_trajectory: 116 | info["denoising_trajectory"] = torch.stack(denoising_trajectory, dim=1) 117 | 118 | if self.sampler_upsampling is not None: 119 | info["obs_low_res"] = next_obs 120 | if self.return_denoising_trajectory: 121 | info["denoising_trajectory_upsampling"] = torch.stack(denoising_trajectory_upsampling, dim=1) 122 | 123 | obs_to_return = self.obs_buffer[:, -1] if self.sampler_upsampling is None else self.obs_full_res_buffer[:, -1] 124 | return obs_to_return, rew, end, trunc, info 125 | 
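# Note: obs_buffer / act_buffer behave as fixed-length FIFO context windows.
# step() rolls each buffer one slot to the left and writes the newest
# frame/action at index -1, so predict_next_obs() below always conditions on
# the most recent timesteps (subsampled via context_indicies).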
126 | @torch.no_grad() 127 | def predict_next_obs(self) -> Tuple[Tensor, List[Tensor]]: 128 | obs = self.obs_buffer[:, self.n_skip_next_obs:][:,self.context_indicies] 129 | act = self.act_buffer[:, self.n_skip_next_obs:][:,self.context_indicies] 130 | 131 | return self.sampler_next_obs.sample(obs, act) 132 | 133 | @torch.no_grad() 134 | def upsample_next_obs(self, next_obs: Tensor) -> Tuple[Tensor, List[Tensor]]: 135 | # Upsampling the low resolution frame from the world model 136 | low_res = F.interpolate(next_obs, scale_factor=self.sampler_upsampling.denoiser.cfg.upsampling_factor, mode="bicubic") 137 | 138 | # Cropping the frame to remove the sides of the screen 139 | low_res = extract_roi(low_res, crop_frame) 140 | 141 | # Reshape the frame to the upsampler's expected size 142 | size = (self.sampler_upsampling.denoiser.cfg.upsampling_frame_height, self.sampler_upsampling.denoiser.cfg.upsampling_frame_width) 143 | low_res = F.interpolate(low_res, size=size, mode='bicubic') 144 | 145 | return self.sampler_upsampling.sample(low_res.unsqueeze(1), None) 146 | 147 | @torch.no_grad() 148 | def predict_rew_end(self, next_obs: Tensor) -> Tuple[Tensor, Tensor]: 149 | logits_rew, logits_end, (self.hx_rew_end, self.cx_rew_end) = self.rew_end_model.predict_rew_end( 150 | self.obs_buffer[:, -1:], 151 | self.act_buffer[:, -1:], 152 | next_obs, 153 | (self.hx_rew_end, self.cx_rew_end), 154 | ) 155 | rew = Categorical(logits=logits_rew).sample().squeeze(1) - 1.0 # in {-1, 0, 1} 156 | end = Categorical(logits=logits_end).sample().squeeze(1) 157 | return rew, end 158 | 159 | @coroutine 160 | def make_generator_init( 161 | self, 162 | spawn_dir: Path, 163 | num_batches_to_preload: int, 164 | ) -> Generator[InitialCondition, None, None]: 165 | num_dead = yield 166 | 167 | spawn_dirs = cycle(sorted(list(spawn_dir.iterdir()))) 168 | 169 | while True: 170 | # Preload on device and burn-in rew/end model 171 | obs_, obs_full_res_, act_, next_act_, hx_, cx_ = [], [], [], [], [], [] 172 | for _ in range(num_batches_to_preload): 173 | d = next(spawn_dirs) 174 | obs = torch.tensor(np.load(d / "low_res.npy"), device=self.device).div(255).mul(2).sub(1).unsqueeze(0) 175 | obs_full_res = torch.tensor(np.load(d / "full_res.npy"), device=self.device).div(255).mul(2).sub(1).unsqueeze(0) 176 | act = torch.tensor(np.load(d / "act.npy"), dtype=torch.long, device=self.device).unsqueeze(0) 177 | next_act = torch.tensor(np.load(d / "next_act.npy"), dtype=torch.long, device=self.device).unsqueeze(0) 178 | 179 | obs_.extend(list(obs)) 180 | obs_full_res_.extend(list(obs_full_res)) 181 | act_.extend(list(act)) 182 | next_act_.extend(list(next_act)) 183 | 184 | if self.rew_end_model is not None: 185 | with torch.no_grad(): 186 | *_, (hx, cx) = self.rew_end_model.predict_rew_end(obs[:, :-1], act[:, :-1], obs[:, 1:]) # Burn-in of rew/end model 187 | assert hx.size(0) == cx.size(0) == 1 188 | hx_.extend(list(hx[0])) 189 | cx_.extend(list(cx[0])) 190 | 191 | # Yield new initial conditions for dead envs, sliced from the preloaded pool 192 | c = 0 193 | while c + num_dead <= len(obs_): 194 | obs = torch.stack(obs_[c : c + num_dead]) 195 | act = torch.stack(act_[c : c + num_dead]) 196 | next_act = next_act_[c : c + num_dead] 197 | obs_full_res = torch.stack(obs_full_res_[c : c + num_dead]) if self.sampler_upsampling is not None else None 198 | hx = torch.stack(hx_[c : c + num_dead]).unsqueeze(0) if self.rew_end_model is not None else None 199 | cx = torch.stack(cx_[c : c + num_dead]).unsqueeze(0) if self.rew_end_model is not None else None 200 | c += num_dead 201 | 
num_dead = yield obs, obs_full_res, act, next_act, (hx, cx) 202 | -------------------------------------------------------------------------------- /src/game/__init__.py: -------------------------------------------------------------------------------- 1 | from .game import Game 2 | from .play_env import PlayEnv 3 | -------------------------------------------------------------------------------- /src/game/dataset_env.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Tuple 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | from data import Dataset 7 | 8 | 9 | class DatasetEnv: 10 | def __init__(self, datasets: List[Dataset], action_names: List[str]) -> None: 11 | self.datasets = [d for d in datasets if len(d) > 0] 12 | assert len(self.datasets) > 0 13 | self.action_names = action_names 14 | self.dataset_id = 0 15 | self.dataset = self.datasets[0] 16 | self.episode_id = None 17 | self.episode = None 18 | self.t = None 19 | self.ep_return = None 20 | self.ep_length = None 21 | self.pos_return = None 22 | self.neg_return = None 23 | self.load_episode(0) 24 | 25 | def print_controls(self) -> None: 26 | print("\nControls (dataset mode):\n") 27 | print(f"m : datasets ({'/'.join([d.name for d in self.datasets])})") 28 | print("↑ : next episode") 29 | print("↓ : prev episode") 30 | print("→ : next timestep") 31 | print("← : prev timestep") 32 | 33 | def next_mode(self) -> bool: 34 | self.switch_dataset() 35 | return True 36 | 37 | def next_axis_1(self) -> bool: 38 | self.load_episode(self.episode_id + 1) 39 | return True 40 | 41 | def prev_axis_1(self) -> bool: 42 | self.load_episode(self.episode_id - 1) 43 | return True 44 | 45 | def next_axis_2(self) -> bool: 46 | return False 47 | 48 | def prev_axis_2(self) -> bool: 49 | return False 50 | 51 | def load_episode(self, episode_id: int) -> None: 52 | self.episode_id = episode_id % self.dataset.num_episodes 53 | self.episode = self.dataset.load_episode(self.episode_id) 54 | self.set_timestep(0) 55 | metrics = self.episode.compute_metrics() 56 | self.ep_return = metrics["return"] 57 | self.ep_length = metrics["length"] 58 | self.pos_return = self.episode.rew[self.episode.rew > 0].sum().item() 59 | self.neg_return = self.episode.rew[self.episode.rew < 0].sum().abs().item() 60 | 61 | def set_timestep(self, timestep: int) -> None: 62 | self.t = timestep % len(self.episode) 63 | self.obs = self.episode.obs[self.t].unsqueeze(0) 64 | self.act = self.episode.act[self.t] 65 | self.rew = self.episode.rew[self.t] 66 | self.end = self.episode.end[self.t] 67 | self.trunc = self.episode.trunc[self.t] 68 | 69 | def switch_dataset(self) -> None: 70 | self.dataset_id = (self.dataset_id + 1) % len(self.datasets) 71 | self.dataset = self.datasets[self.dataset_id] 72 | self.load_episode(0) 73 | 74 | def reset(self) -> None: 75 | self.set_timestep(0) 76 | return self.obs, None 77 | 78 | @torch.no_grad() 79 | def step(self, act: int) -> Tuple[Tensor, Tensor, bool, bool, Dict[str, Any]]: 80 | match act: 81 | case 1: 82 | self.set_timestep(self.t - 1) 83 | case 2: 84 | self.set_timestep(self.t + 1) 85 | case 3: 86 | self.set_timestep(self.t - 10) 87 | case 4: 88 | self.set_timestep(self.t + 10) 89 | 90 | n_digits = len(str(self.ep_length)) 91 | 92 | header = [ 93 | [ 94 | f"Dataset: {self.dataset.name}", 95 | f"Episode: {self.episode_id}", 96 | "--------", 97 | f"Return (+): +{self.pos_return:4.1f}", 98 | f"Return (-): -{self.neg_return:4.1f}", 99 | f"Total : {self.ep_return:4.1f}", 100 | ], 101 
| [ 102 | f"Action: {self.action_names[self.act]}", 103 | f"Trunc : {bool(self.trunc)}", 104 | f"Done : {bool(self.end)}", 105 | f"Reward: {self.rew.item():.2f}", 106 | "-------", 107 | f"To here: {self.episode.rew[:self.t + 1].sum().item():.2f}", 108 | f"To go : {self.episode.rew[self.t + 1:].sum().item():.2f}", 109 | ], 110 | [ 111 | f"Timestep: {self.t:{n_digits}d}", 112 | f"Length : {self.ep_length}", 113 | ], 114 | ] 115 | info = {"header": header} 116 | return self.obs, torch.tensor(0), False, False, info 117 | -------------------------------------------------------------------------------- /src/game/game.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | import numpy as np 4 | import pygame 5 | from PIL import Image 6 | 7 | from player.action_processing import GameAction 8 | from .dataset_env import DatasetEnv 9 | from .play_env import PlayEnv 10 | 11 | class Game: 12 | def __init__( 13 | self, 14 | play_env: Union[PlayEnv, DatasetEnv], 15 | size: Tuple[int, int], 16 | fps: int, 17 | verbose: bool, 18 | ) -> None: 19 | self.env = play_env 20 | self.height, self.width = size 21 | self.fps = fps 22 | self.verbose = verbose 23 | self.env.print_controls() 24 | print("\nControls:\n") 25 | print(" m : switch control (human/replay)") # Not for main as Game can use either PlayEnv or DatasetEnv 26 | print(" . : pause/unpause") 27 | print(" e : step-by-step (when paused)") 28 | print(" ⏎ : reset env") 29 | print("Esc : quit") 30 | print("\n") 31 | input("Press enter to start") 32 | 33 | def run(self) -> None: 34 | pygame.init() 35 | 36 | header_height = 150 if self.verbose else 0 37 | header_width = 540 38 | font_size = 16 39 | screen = pygame.display.set_mode((0, 0), pygame.FULLSCREEN) 40 | pygame.mouse.set_visible(False) 41 | pygame.event.set_grab(True) 42 | clock = pygame.time.Clock() 43 | font = pygame.font.SysFont("mono", font_size) 44 | x_center, y_center = screen.get_rect().center 45 | x_header = x_center - header_width // 2 46 | y_header = y_center - self.height // 2 - header_height - 10 47 | header_rect = pygame.Rect(x_header, y_header, header_width, header_height) 48 | 49 | def clear_header(): 50 | pygame.draw.rect(screen, pygame.Color("black"), header_rect) 51 | pygame.draw.rect(screen, pygame.Color("white"), header_rect, 1) 52 | 53 | def draw_text(text, idx_line, idx_column, num_cols): 54 | x_pos = 5 + idx_column * int(header_width // num_cols) 55 | y_pos = 5 + idx_line * font_size 56 | assert (0 <= x_pos <= header_width) and (0 <= y_pos <= header_height) 57 | screen.blit(font.render(text, True, pygame.Color("white")), (x_header + x_pos, y_header + y_pos)) 58 | 59 | def draw_obs(obs, obs_low_res=None): 60 | assert obs.ndim == 4 and obs.size(0) == 1 61 | two_players = obs[0].add(1).div(2).mul(255).byte().permute(1, 2, 0).cpu().numpy() 62 | player_1_frame, player_2_frame = two_players[:,:,:3], two_players[:, :, 3:] 63 | 64 | # Stack images vertically 65 | stacked_arr = np.vstack((player_1_frame, player_2_frame)) 66 | 67 | # Convert back to image 68 | img = Image.fromarray(stacked_arr) 69 | 70 | # resize the images, and prepare it for display 71 | pygame_image = np.array(img.resize((self.width, self.height), resample=Image.BOX)).transpose((1, 0, 2)) 72 | 73 | surface = pygame.surfarray.make_surface(pygame_image) 74 | screen.blit(surface, (x_center - self.width // 2, y_center - self.height // 2)) 75 | 76 | if obs_low_res is not None: 77 | assert obs_low_res.ndim == 4 and obs_low_res.size(0) == 1 78 | img = 
Image.fromarray(obs_low_res[0].add(1).div(2).mul(255).byte().permute(1, 2, 0).cpu().numpy()) 79 | h = self.height * obs_low_res.size(2) // obs.size(2) 80 | w = self.width * obs_low_res.size(3) // obs.size(3) 81 | pygame_image = np.array(img) 82 | surface = pygame.surfarray.make_surface(pygame_image) 83 | screen.blit(surface, (x_center - w // 2, y_center + self.height // 2)) 84 | 85 | def reset(): 86 | nonlocal obs, info, do_reset, ep_return, ep_length, keys_pressed 87 | obs, info = self.env.reset() 88 | pygame.event.clear() 89 | do_reset = False 90 | ep_return = 0 91 | ep_length = 0 92 | keys_pressed = [] 93 | 94 | obs, info, do_reset, ep_return, ep_length, keys_pressed= (None,) * 6 95 | 96 | reset() 97 | do_wait = False 98 | should_stop = False 99 | 100 | while not should_stop: 101 | do_one_step = False 102 | pygame.event.pump() 103 | 104 | for event in pygame.event.get(): 105 | if event.type == pygame.QUIT or (event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE): 106 | should_stop = True 107 | 108 | if event.type == pygame.KEYDOWN: 109 | keys_pressed.append(event.key) 110 | 111 | elif event.type == pygame.KEYUP and event.key in keys_pressed: 112 | keys_pressed.remove(event.key) 113 | 114 | if event.type != pygame.KEYDOWN: 115 | continue 116 | 117 | if event.key == pygame.K_RETURN: 118 | do_reset = True 119 | 120 | if event.key == pygame.K_PERIOD: 121 | do_wait = not do_wait 122 | print("Game paused." if do_wait else "Game resumed.") 123 | 124 | if event.key == pygame.K_e: 125 | do_one_step = True 126 | 127 | if event.key == pygame.K_m: 128 | do_reset = self.env.next_mode() 129 | 130 | if event.key == pygame.K_UP: 131 | do_reset = self.env.next_axis_1() 132 | 133 | if event.key == pygame.K_DOWN: 134 | do_reset = self.env.prev_axis_1() 135 | 136 | if event.key == pygame.K_RIGHT: 137 | do_reset = self.env.next_axis_2() 138 | 139 | if event.key == pygame.K_LEFT: 140 | do_reset = self.env.prev_axis_2() 141 | 142 | if do_reset: 143 | reset() 144 | 145 | if do_wait and not do_one_step: 146 | continue 147 | 148 | game_action = GameAction(keys_pressed) 149 | next_obs, rew, end, trunc, info = self.env.step(game_action) 150 | 151 | ep_return += rew.item() 152 | ep_length += 1 153 | 154 | if self.verbose and info is not None: 155 | clear_header() 156 | assert isinstance(info, dict) and "header" in info 157 | header = info["header"] 158 | num_cols = len(header) 159 | for j, col in enumerate(header): 160 | for i, row in enumerate(col): 161 | draw_text(row, idx_line=i, idx_column=j, num_cols=num_cols) 162 | 163 | draw_low_res = self.verbose and "obs_low_res" in info and self.width == 280 164 | if draw_low_res: 165 | draw_obs(obs, info["obs_low_res"]) 166 | draw_text(" Pre-upsampling:", 0, 2, 3) 167 | else: 168 | draw_obs(obs, None) 169 | 170 | pygame.display.flip() # update screen 171 | clock.tick(self.fps) # ensures game maintains the given frame rate 172 | 173 | if end or trunc: 174 | reset() 175 | 176 | else: 177 | obs = next_obs 178 | 179 | pygame.quit() 180 | -------------------------------------------------------------------------------- /src/game/play_env.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, namedtuple 2 | import math 3 | from pathlib import Path 4 | from typing import Any, Dict, List, Tuple 5 | 6 | import pygame 7 | import torch 8 | from torch import Tensor 9 | 10 | from agent import Agent 11 | from player.action_processing import GameAction, decode_game_action, encode_game_action, 
print_game_action 12 | from player.keymap import GAME_KEYMAP 13 | from data import Dataset, Episode 14 | from envs import WorldModelEnv 15 | 16 | 17 | NamedEnv = namedtuple("NamedEnv", "name env") 18 | OneStepData = namedtuple("OneStepData", "obs act rew end trunc") 19 | 20 | 21 | class PlayEnv: 22 | def __init__( 23 | self, 24 | agent: Agent, 25 | wm_env: WorldModelEnv, 26 | recording_mode: bool, 27 | store_denoising_trajectory: bool, 28 | store_original_obs: bool, 29 | ) -> None: 30 | self.agent = agent 31 | self.keymap = GAME_KEYMAP 32 | self.recording_mode = recording_mode 33 | self.store_denoising_trajectory = store_denoising_trajectory 34 | self.store_original_obs = store_original_obs 35 | self.is_human_player = True 36 | self.env_id = 0 37 | self.env_name = "world model" 38 | self.env = wm_env 39 | self.obs, self.t, self.buffer, self.rec_dataset = (None,) * 4 40 | 41 | def print_controls(self) -> None: 42 | print("\nEnvironment actions:\n") 43 | for key, action_name in self.keymap.items(): 44 | if key is not None: 45 | key_name = pygame.key.name(key) 46 | key_name = "⎵" if key_name == "space" else key_name 47 | print(f"{key_name} : {action_name}") 48 | 49 | def next_mode(self) -> bool: 50 | self.switch_controller() 51 | return True 52 | 53 | def next_axis_1(self) -> bool: 54 | return False 55 | 56 | def prev_axis_1(self) -> bool: 57 | return False 58 | 59 | def next_axis_2(self) -> bool: 60 | return False 61 | 62 | def prev_axis_2(self) -> bool: 63 | return False 64 | 65 | def print_env(self) -> None: 66 | print(f"> Environment: {self.env_name}") 67 | 68 | def str_control(self) -> str: 69 | return "human" if self.is_human_player else "replay actions (test dataset)" 70 | 71 | def print_control(self) -> None: 72 | print(f"> Control: {self.str_control()}") 73 | 74 | def switch_controller(self) -> None: 75 | self.is_human_player = not self.is_human_player 76 | self.print_control() 77 | 78 | def update_wm_horizon(self, incr: int) -> None: 79 | self.env.horizon = max(1, self.env.horizon + incr) 80 | 81 | def reset_recording(self) -> None: 82 | self.buffer = defaultdict(list) 83 | self.buffer["info"] = defaultdict(list) 84 | dir = Path("dataset") / f"rec_{self.env_name}_{'H' if self.is_human_player else 'R'}" 85 | self.rec_dataset = Dataset(dir, None) 86 | self.rec_dataset.load_from_default_path() 87 | 88 | def reset(self) -> Tuple[Tensor, None]: 89 | self.obs, _ = self.env.reset() 90 | self.t = 0 91 | if self.recording_mode: 92 | self.reset_recording() 93 | return self.obs, None 94 | 95 | @torch.no_grad() 96 | def step(self, game_action: GameAction) -> Tuple[Tensor, Tensor, Tensor, Tensor, Dict[str, Any]]: 97 | if self.is_human_player: 98 | action = encode_game_action(game_action, device=self.agent.device) 99 | else: 100 | action = self.env.next_act[self.t - 1] if self.t > 0 else self.env.act_buffer[0, -1].clone() 101 | game_action = decode_game_action(action.cpu()) 102 | 103 | next_obs, rew, end, trunc, env_info = self.env.step(action) 104 | 105 | if not self.is_human_player and self.t == self.env.next_act.size(0): 106 | trunc[0] = 1 107 | 108 | data = OneStepData(self.obs, action, rew, end, trunc) 109 | keys = print_game_action(game_action) 110 | horizon = self.env.horizon if self.is_human_player else min(self.env.horizon, self.env.next_act.size(0)) 111 | header = [ 112 | [ 113 | f"Env : {self.env_name}", 114 | f"Control : {self.str_control()}", 115 | f"Timestep: {self.t + 1}", 116 | f"Horizon : {horizon}", 117 | f"Keys : {keys}", 118 | ], 119 | ] 120 | info = {"header": header} 121 | 
if "obs_low_res" in env_info: 122 | info["obs_low_res"] = env_info["obs_low_res"] 123 | 124 | if self.recording_mode: 125 | for k, v in data._asdict().items(): 126 | self.buffer[k].append(v) 127 | if "obs_low_res" in env_info: 128 | self.buffer["info"]["obs_low_res"].append(env_info["obs_low_res"]) 129 | if self.store_denoising_trajectory and "denoising_trajectory" in env_info: 130 | self.buffer["info"]["denoising_trajectory"].append(env_info["denoising_trajectory"]) 131 | if self.store_original_obs and "original_obs" in env_info: 132 | original_obs = (torch.tensor(env_info["original_obs"][0]).permute(2, 0, 1).unsqueeze(0).contiguous()) 133 | self.buffer["info"]["original_obs"].append(original_obs) 134 | if end or trunc: 135 | ep_dict = {k: torch.cat(v, dim=0) for k, v in self.buffer.items() if k != "info"} 136 | ep_info = {k: torch.cat(v, dim=0) for k, v in self.buffer["info"].items()} 137 | ep = Episode(**ep_dict, info=ep_info).to("cpu") 138 | self.rec_dataset.add_episode(ep) 139 | self.rec_dataset.save_to_default_path() 140 | 141 | self.obs = next_obs 142 | self.t += 1 143 | 144 | return next_obs, rew, end, trunc, info 145 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import List, Union 4 | 5 | import hydra 6 | from omegaconf import DictConfig, OmegaConf 7 | import torch 8 | from torch.distributed import init_process_group, destroy_process_group 9 | import torch.multiprocessing as mp 10 | 11 | from trainer import Trainer 12 | from utils import skip_if_run_is_over 13 | 14 | 15 | OmegaConf.register_new_resolver("eval", eval) 16 | 17 | 18 | @hydra.main(config_path="../config", config_name="trainer", version_base="1.3") 19 | def main(cfg: DictConfig) -> None: 20 | setup_visible_cuda_devices(cfg.common.devices) 21 | world_size = torch.cuda.device_count() 22 | root_dir = Path(hydra.utils.get_original_cwd()) 23 | if world_size < 2: 24 | run(cfg, root_dir) 25 | else: 26 | mp.spawn(main_ddp, args=(world_size, cfg, root_dir), nprocs=world_size) 27 | 28 | 29 | def main_ddp(rank: int, world_size: int, cfg: DictConfig, root_dir: Path) -> None: 30 | setup_ddp(rank, world_size) 31 | run(cfg, root_dir) 32 | destroy_process_group() 33 | 34 | 35 | @skip_if_run_is_over 36 | def run(cfg: DictConfig, root_dir: Path) -> None: 37 | trainer = Trainer(cfg, root_dir) 38 | trainer.run() 39 | 40 | 41 | def setup_ddp(rank: int, world_size: int) -> None: 42 | os.environ["MASTER_ADDR"] = "localhost" 43 | os.environ["MASTER_PORT"] = "6006" 44 | init_process_group(backend="nccl", rank=rank, world_size=world_size) 45 | 46 | 47 | def setup_visible_cuda_devices(devices: Union[str, int, List[int]]) -> None: 48 | if isinstance(devices, str): 49 | if devices == "cpu": 50 | devices = [] 51 | else: 52 | assert devices == "all" 53 | return 54 | elif isinstance(devices, int): 55 | devices = [devices] 56 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, devices)) 57 | 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/actor_critic.py: 
-------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from dataclasses import dataclass 3 | import math 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import torch 7 | from torch import Tensor 8 | import torch.nn as nn 9 | from torch.distributions.categorical import Categorical 10 | import torch.nn.functional as F 11 | 12 | from .blocks import Conv3x3, SmallResBlock 13 | from coroutines.env_loop import make_env_loop 14 | from envs import TorchEnv, WorldModelEnv 15 | from utils import init_lstm, LossAndLogs 16 | 17 | 18 | ActorCriticOutput = namedtuple("ActorCriticOutput", "logits_act val hx_cx") 19 | 20 | 21 | @dataclass 22 | class ActorCriticLossConfig: 23 | backup_every: int 24 | gamma: float 25 | lambda_: float 26 | weight_value_loss: float 27 | weight_entropy_loss: float 28 | 29 | 30 | @dataclass 31 | class ActorCriticConfig: 32 | lstm_dim: int 33 | img_channels: int 34 | img_size: int 35 | channels: List[int] 36 | down: List[int] 37 | num_actions: Optional[int] = None 38 | 39 | 40 | class ActorCritic(nn.Module): 41 | def __init__(self, cfg: ActorCriticConfig) -> None: 42 | super().__init__() 43 | self.encoder = ActorCriticEncoder(cfg) 44 | self.lstm_dim = cfg.lstm_dim 45 | input_dim_lstm = cfg.channels[-1] * (cfg.img_size // 2 ** (sum(cfg.down))) ** 2 46 | self.lstm = nn.LSTMCell(input_dim_lstm, cfg.lstm_dim) 47 | self.critic_linear = nn.Linear(cfg.lstm_dim, 1) 48 | self.actor_linear = nn.Linear(cfg.lstm_dim, cfg.num_actions) 49 | 50 | self.actor_linear.weight.data.fill_(0) 51 | self.actor_linear.bias.data.fill_(0) 52 | self.critic_linear.weight.data.fill_(0) 53 | self.critic_linear.bias.data.fill_(0) 54 | init_lstm(self.lstm) 55 | 56 | self.env_loop = None 57 | self.loss_cfg = None 58 | 59 | @property 60 | def device(self) -> torch.device: 61 | return self.lstm.weight_hh.device 62 | 63 | def setup_training(self, rl_env: Union[TorchEnv, WorldModelEnv], loss_cfg: ActorCriticLossConfig) -> None: 64 | assert self.env_loop is None and self.loss_cfg is None 65 | self.env_loop = make_env_loop(rl_env, self) 66 | self.loss_cfg = loss_cfg 67 | 68 | def predict_act_value(self, obs: Tensor, hx_cx: Tuple[Tensor, Tensor]) -> ActorCriticOutput: 69 | assert obs.ndim == 4 70 | x = self.encoder(obs) 71 | x = x.flatten(start_dim=1) 72 | hx, cx = self.lstm(x, hx_cx) 73 | return ActorCriticOutput(self.actor_linear(hx), self.critic_linear(hx).squeeze(dim=1), (hx, cx)) 74 | 75 | def forward(self) -> LossAndLogs: 76 | c = self.loss_cfg 77 | _, act, rew, end, trunc, logits_act, val, val_bootstrap, _ = self.env_loop.send(c.backup_every) 78 | 79 | d = Categorical(logits=logits_act) 80 | entropy = d.entropy().mean() 81 | 82 | lambda_returns = compute_lambda_returns(rew, end, trunc, val_bootstrap, c.gamma, c.lambda_) 83 | 84 | loss_actions = (-d.log_prob(act) * (lambda_returns - val).detach()).mean() 85 | loss_values = c.weight_value_loss * F.mse_loss(val, lambda_returns) 86 | loss_entropy = -c.weight_entropy_loss * entropy 87 | 88 | loss = loss_actions + loss_entropy + loss_values 89 | 90 | metrics = { 91 | "policy_entropy": entropy.detach() / math.log(2), 92 | "loss_actions": loss_actions.detach(), 93 | "loss_entropy": loss_entropy.detach(), 94 | "loss_values": loss_values.detach(), 95 | "loss_total": loss.detach(), 96 | } 97 | 98 | return loss, metrics 99 | 100 | 101 | class ActorCriticEncoder(nn.Module): 102 | def __init__(self, cfg: ActorCriticConfig) -> None: 103 | super().__init__() 104 | assert len(cfg.channels) == 
len(cfg.down) 105 | encoder_layers = [Conv3x3(cfg.img_channels, cfg.channels[0])] 106 | for i in range(len(cfg.channels)): 107 | encoder_layers.append(SmallResBlock(cfg.channels[max(0, i - 1)], cfg.channels[i])) 108 | if cfg.down[i]: 109 | encoder_layers.append(nn.MaxPool2d(2)) 110 | self.encoder = nn.Sequential(*encoder_layers) 111 | 112 | def forward(self, x: Tensor) -> Tensor: 113 | return self.encoder(x) 114 | 115 | 116 | @torch.no_grad() 117 | def compute_lambda_returns( 118 | rew: Tensor, 119 | end: Tensor, 120 | trunc: Tensor, 121 | val_bootstrap: Tensor, 122 | gamma: float, 123 | lambda_: float, 124 | ) -> Tensor: 125 | assert rew.ndim == 2 and rew.size() == end.size() == trunc.size() == val_bootstrap.size() 126 | 127 | rew = rew.sign() # clip reward 128 | 129 | end_or_trunc = (end + trunc).clip(max=1) 130 | not_end = 1 - end 131 | not_trunc = 1 - trunc 132 | 133 | lambda_returns = rew + not_end * gamma * (not_trunc * (1 - lambda_) + trunc) * val_bootstrap 134 | 135 | if lambda_ == 0: 136 | return lambda_returns 137 | 138 | last = val_bootstrap[:, -1] 139 | for t in reversed(range(rew.size(1))): 140 | lambda_returns[:, t] += end_or_trunc[:, t].logical_not() * gamma * lambda_ * last 141 | last = lambda_returns[:, t] 142 | 143 | return lambda_returns 144 | -------------------------------------------------------------------------------- /src/models/blocks.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import math 3 | from typing import List, Optional 4 | 5 | import torch 6 | from torch import Tensor 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | # Settings for GroupNorm and Attention 11 | 12 | GN_GROUP_SIZE = 32 13 | GN_EPS = 1e-5 14 | ATTN_HEAD_DIM = 8 15 | 16 | # Convs 17 | 18 | Conv1x1 = partial(nn.Conv2d, kernel_size=1, stride=1, padding=0) 19 | Conv3x3 = partial(nn.Conv2d, kernel_size=3, stride=1, padding=1) 20 | 21 | # GroupNorm and conditional GroupNorm 22 | 23 | 24 | class GroupNorm(nn.Module): 25 | def __init__(self, in_channels: int) -> None: 26 | super().__init__() 27 | num_groups = max(1, in_channels // GN_GROUP_SIZE) 28 | self.norm = nn.GroupNorm(num_groups, in_channels, eps=GN_EPS) 29 | 30 | def forward(self, x: Tensor) -> Tensor: 31 | return self.norm(x) 32 | 33 | 34 | class AdaGroupNorm(nn.Module): 35 | def __init__(self, in_channels: int, cond_channels: int) -> None: 36 | super().__init__() 37 | self.in_channels = in_channels 38 | self.num_groups = max(1, in_channels // GN_GROUP_SIZE) 39 | self.linear = nn.Linear(cond_channels, in_channels * 2) 40 | 41 | def forward(self, x: Tensor, cond: Tensor) -> Tensor: 42 | assert x.size(1) == self.in_channels 43 | x = F.group_norm(x, self.num_groups, eps=GN_EPS) 44 | scale, shift = self.linear(cond)[:, :, None, None].chunk(2, dim=1) 45 | return x * (1 + scale) + shift 46 | 47 | 48 | # Self Attention 49 | 50 | 51 | class SelfAttention2d(nn.Module): 52 | def __init__(self, in_channels: int, head_dim: int = ATTN_HEAD_DIM) -> None: 53 | super().__init__() 54 | self.n_head = max(1, in_channels // head_dim) 55 | assert in_channels % self.n_head == 0 56 | self.norm = GroupNorm(in_channels) 57 | self.qkv_proj = Conv1x1(in_channels, in_channels * 3) 58 | self.out_proj = Conv1x1(in_channels, in_channels) 59 | nn.init.zeros_(self.out_proj.weight) 60 | nn.init.zeros_(self.out_proj.bias) 61 | 62 | def forward(self, x: Tensor) -> Tensor: 63 | n, c, h, w = x.shape 64 | x = self.norm(x) 65 | qkv = self.qkv_proj(x) 66 | qkv = qkv.view(n, 
self.n_head * 3, c // self.n_head, h * w).transpose(2, 3).contiguous() 67 | q, k, v = [x for x in qkv.chunk(3, dim=1)] 68 | att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1)) 69 | att = F.softmax(att, dim=-1) 70 | y = att @ v 71 | y = y.transpose(2, 3).reshape(n, c, h, w) 72 | return x + self.out_proj(y) 73 | 74 | 75 | # Embedding of the noise level 76 | 77 | 78 | class FourierFeatures(nn.Module): 79 | def __init__(self, cond_channels: int) -> None: 80 | super().__init__() 81 | assert cond_channels % 2 == 0 82 | self.register_buffer("weight", torch.randn(1, cond_channels // 2)) 83 | 84 | def forward(self, input: Tensor) -> Tensor: 85 | assert input.ndim == 1 86 | f = 2 * math.pi * input.unsqueeze(1) @ self.weight 87 | return torch.cat([f.cos(), f.sin()], dim=-1) 88 | 89 | 90 | # [Down|Up]sampling 91 | 92 | 93 | class Downsample(nn.Module): 94 | def __init__(self, in_channels: int) -> None: 95 | super().__init__() 96 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1) 97 | nn.init.orthogonal_(self.conv.weight) 98 | 99 | def forward(self, x: Tensor) -> Tensor: 100 | return self.conv(x) 101 | 102 | 103 | class Upsample(nn.Module): 104 | def __init__(self, in_channels: int) -> None: 105 | super().__init__() 106 | self.conv = Conv3x3(in_channels, in_channels) 107 | 108 | def forward(self, x: Tensor) -> Tensor: 109 | x = F.interpolate(x, scale_factor=2.0, mode="nearest") 110 | return self.conv(x) 111 | 112 | 113 | # Small Residual block 114 | 115 | 116 | class SmallResBlock(nn.Module): 117 | def __init__(self, in_channels: int, out_channels: int) -> None: 118 | super().__init__() 119 | self.f = nn.Sequential(GroupNorm(in_channels), nn.SiLU(inplace=True), Conv3x3(in_channels, out_channels)) 120 | self.skip_projection = nn.Identity() if in_channels == out_channels else Conv1x1(in_channels, out_channels) 121 | 122 | def forward(self, x: Tensor) -> Tensor: 123 | return self.skip_projection(x) + self.f(x) 124 | 125 | 126 | # Residual block (conditioning with AdaGroupNorm, no [down|up]sampling, optional self-attention) 127 | 128 | 129 | class ResBlock(nn.Module): 130 | def __init__(self, in_channels: int, out_channels: int, cond_channels: int, attn: bool) -> None: 131 | super().__init__() 132 | should_proj = in_channels != out_channels 133 | self.proj = Conv1x1(in_channels, out_channels) if should_proj else nn.Identity() 134 | self.norm1 = AdaGroupNorm(in_channels, cond_channels) 135 | self.conv1 = Conv3x3(in_channels, out_channels) 136 | self.norm2 = AdaGroupNorm(out_channels, cond_channels) 137 | self.conv2 = Conv3x3(out_channels, out_channels) 138 | self.attn = SelfAttention2d(out_channels) if attn else nn.Identity() 139 | nn.init.zeros_(self.conv2.weight) 140 | 141 | def forward(self, x: Tensor, cond: Tensor) -> Tensor: 142 | r = self.proj(x) 143 | x = self.conv1(F.silu(self.norm1(x, cond))) 144 | x = self.conv2(F.silu(self.norm2(x, cond))) 145 | x = x + r 146 | x = self.attn(x) 147 | return x 148 | 149 | 150 | # Sequence of residual blocks (in_channels -> mid_channels -> ... 
-> mid_channels -> out_channels) 151 | 152 | 153 | class ResBlocks(nn.Module): 154 | def __init__( 155 | self, 156 | list_in_channels: List[int], 157 | list_out_channels: List[int], 158 | cond_channels: int, 159 | attn: bool, 160 | ) -> None: 161 | super().__init__() 162 | assert len(list_in_channels) == len(list_out_channels) 163 | self.in_channels = list_in_channels[0] 164 | self.resblocks = nn.ModuleList( 165 | [ 166 | ResBlock(in_ch, out_ch, cond_channels, attn) 167 | for (in_ch, out_ch) in zip(list_in_channels, list_out_channels) 168 | ] 169 | ) 170 | 171 | def forward(self, x: Tensor, cond: Tensor, to_cat: Optional[List[Tensor]] = None) -> Tensor: 172 | outputs = [] 173 | for i, resblock in enumerate(self.resblocks): 174 | x = x if to_cat is None else torch.cat((x, to_cat[i]), dim=1) 175 | x = resblock(x, cond) 176 | outputs.append(x) 177 | return x, outputs 178 | 179 | 180 | # UNet 181 | 182 | 183 | class UNet(nn.Module): 184 | def __init__(self, cond_channels: int, depths: List[int], channels: List[int], attn_depths: List[int]) -> None: 185 | super().__init__() 186 | assert len(depths) == len(channels) == len(attn_depths) 187 | self._num_down = len(channels) - 1 188 | 189 | d_blocks, u_blocks = [], [] 190 | for i, n in enumerate(depths): 191 | c1 = channels[max(0, i - 1)] 192 | c2 = channels[i] 193 | d_blocks.append( 194 | ResBlocks( 195 | list_in_channels=[c1] + [c2] * (n - 1), 196 | list_out_channels=[c2] * n, 197 | cond_channels=cond_channels, 198 | attn=attn_depths[i], 199 | ) 200 | ) 201 | u_blocks.append( 202 | ResBlocks( 203 | list_in_channels=[2 * c2] * n + [c1 + c2], 204 | list_out_channels=[c2] * n + [c1], 205 | cond_channels=cond_channels, 206 | attn=attn_depths[i], 207 | ) 208 | ) 209 | self.d_blocks = nn.ModuleList(d_blocks) 210 | self.u_blocks = nn.ModuleList(reversed(u_blocks)) 211 | 212 | self.mid_blocks = ResBlocks( 213 | list_in_channels=[channels[-1]] * 2, 214 | list_out_channels=[channels[-1]] * 2, 215 | cond_channels=cond_channels, 216 | attn=True, 217 | ) 218 | 219 | downsamples = [nn.Identity()] + [Downsample(c) for c in channels[:-1]] 220 | upsamples = [nn.Identity()] + [Upsample(c) for c in reversed(channels[:-1])] 221 | self.downsamples = nn.ModuleList(downsamples) 222 | self.upsamples = nn.ModuleList(upsamples) 223 | 224 | def forward(self, x: Tensor, cond: Tensor) -> Tensor: 225 | *_, h, w = x.size() 226 | n = self._num_down 227 | padding_h = math.ceil(h / 2 ** n) * 2 ** n - h 228 | padding_w = math.ceil(w / 2 ** n) * 2 ** n - w 229 | x = F.pad(x, (0, padding_w, 0, padding_h)) 230 | 231 | d_outputs = [] 232 | for block, down in zip(self.d_blocks, self.downsamples): 233 | x_down = down(x) 234 | x, block_outputs = block(x_down, cond) 235 | d_outputs.append((x_down, *block_outputs)) 236 | 237 | x, _ = self.mid_blocks(x, cond) 238 | 239 | u_outputs = [] 240 | for block, up, skip in zip(self.u_blocks, self.upsamples, reversed(d_outputs)): 241 | x_up = up(x) 242 | x, block_outputs = block(x_up, cond, skip[::-1]) 243 | u_outputs.append((x_up, *block_outputs)) 244 | 245 | x = x[..., :h, :w] 246 | return x, d_outputs, u_outputs 247 | -------------------------------------------------------------------------------- /src/models/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .denoiser import Denoiser, DenoiserConfig, SigmaDistributionConfig 2 | from .inner_model import InnerModelConfig 3 | from .diffusion_sampler import DiffusionSampler, DiffusionSamplerConfig 4 | 
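# Minimal usage sketch of these exports (illustrative values only; the
# shipped hyperparameters live in config/agent/racing.yaml and
# config/world_model_env/fast.yaml):
#
#   denoiser = Denoiser(DenoiserConfig(...))
#   sampler = DiffusionSampler(denoiser, DiffusionSamplerConfig(num_steps_denoising=3))
#   next_obs, denoising_trajectory = sampler.sample(prev_obs, prev_act)
#
# (The Denoiser follows the EDM preconditioning of Karras et al. 2022:
# the denoised prediction is D(x) = c_skip * x + c_out * F(c_in * x, c_noise).)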
-------------------------------------------------------------------------------- /src/models/diffusion/denoiser.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, List 3 | 4 | import torch 5 | from torch import Tensor 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from data import Batch 10 | from .inner_model import InnerModel, InnerModelConfig 11 | from utils import LossAndLogs, get_frame_indices, LossLogsData 12 | 13 | 14 | def add_dims(input: Tensor, n: int) -> Tensor: 15 | return input.reshape(input.shape + (1,) * (n - input.ndim)) 16 | 17 | 18 | @dataclass 19 | class Conditioners: 20 | c_in: Tensor 21 | c_out: Tensor 22 | c_skip: Tensor 23 | c_noise: Tensor 24 | c_noise_cond: Tensor 25 | 26 | 27 | @dataclass 28 | class SigmaDistributionConfig: 29 | loc: float 30 | scale: float 31 | sigma_min: float 32 | sigma_max: float 33 | 34 | @dataclass 35 | class FrameStrides: 36 | count: int 37 | stride: int 38 | 39 | 40 | @dataclass 41 | class DenoiserConfig: 42 | inner_model: InnerModelConfig 43 | sigma_data: float 44 | sigma_offset_noise: float 45 | noise_previous_obs: bool 46 | upsampling_factor: Optional[int] = None 47 | upsampling_frame_height: Optional[int] = None 48 | upsampling_frame_width: Optional[int] = None 49 | frame_sampling: Optional[List[FrameStrides]] = None 50 | 51 | 52 | class Denoiser(nn.Module): 53 | def __init__(self, cfg: DenoiserConfig) -> None: 54 | super().__init__() 55 | self.cfg = cfg 56 | self.is_upsampler = cfg.upsampling_factor is not None 57 | cfg.inner_model.is_upsampler = self.is_upsampler 58 | self.inner_model = InnerModel(cfg.inner_model) 59 | self.sample_sigma_training = None 60 | self.context_indicies = None if self.is_upsampler else get_frame_indices(cfg.frame_sampling) 61 | 62 | @property 63 | def device(self) -> torch.device: 64 | return self.inner_model.noise_emb.weight.device 65 | 66 | def setup_training(self, cfg: SigmaDistributionConfig) -> None: 67 | assert self.sample_sigma_training is None 68 | 69 | def sample_sigma(n: int, device: torch.device): 70 | s = torch.randn(n, device=device) * cfg.scale + cfg.loc 71 | return s.exp().clip(cfg.sigma_min, cfg.sigma_max) 72 | 73 | self.sample_sigma_training = sample_sigma 74 | 75 | def apply_noise(self, x: Tensor, sigma: Tensor, sigma_offset_noise: float) -> Tensor: 76 | b, c, _, _ = x.shape 77 | offset_noise = sigma_offset_noise * torch.randn(b, c, 1, 1, device=self.device) 78 | return x + offset_noise + torch.randn_like(x) * add_dims(sigma, x.ndim) 79 | 80 | def compute_conditioners(self, sigma: Tensor, sigma_cond: Optional[Tensor]) -> Conditioners: 81 | sigma = (sigma ** 2 + self.cfg.sigma_offset_noise ** 2).sqrt() 82 | c_in = 1 / (sigma ** 2 + self.cfg.sigma_data ** 2).sqrt() 83 | c_skip = self.cfg.sigma_data ** 2 / (sigma ** 2 + self.cfg.sigma_data ** 2) 84 | c_out = sigma * c_skip.sqrt() 85 | c_noise = sigma.log() / 4 86 | c_noise_cond = sigma_cond.log() / 4 if sigma_cond is not None else torch.zeros_like(c_noise) 87 | return Conditioners( 88 | *(add_dims(c, n) for c, n in zip((c_in, c_out, c_skip, c_noise, c_noise_cond), (4, 4, 4, 1, 1)))) 89 | 90 | def compute_model_output(self, noisy_next_obs: Tensor, obs: Tensor, act: Optional[Tensor], 91 | cs: Conditioners) -> Tensor: 92 | rescaled_obs = obs / self.cfg.sigma_data 93 | rescaled_noise = noisy_next_obs * cs.c_in 94 | return self.inner_model(rescaled_noise, cs.c_noise, cs.c_noise_cond, rescaled_obs, act) 95 | 96 | 
@torch.no_grad() 97 | def wrap_model_output(self, noisy_next_obs: Tensor, model_output: Tensor, cs: Conditioners) -> Tensor: 98 | d = cs.c_skip * noisy_next_obs + cs.c_out * model_output 99 | # Quantize to {0, ..., 255}, then back to [-1, 1] 100 | d = d.clamp(-1, 1).add(1).div(2).mul(255).byte().div(255).mul(2).sub(1) 101 | return d 102 | 103 | @torch.no_grad() 104 | def denoise(self, noisy_next_obs: Tensor, sigma: Tensor, sigma_cond: Optional[Tensor], obs: Tensor, 105 | act: Optional[Tensor]) -> Tensor: 106 | cs = self.compute_conditioners(sigma, sigma_cond) 107 | model_output = self.compute_model_output(noisy_next_obs, obs, act, cs) 108 | denoised = self.wrap_model_output(noisy_next_obs, model_output, cs) 109 | return denoised 110 | 111 | def get_prev_obs(self, all_obs, act, i, b, n, c, H, W): 112 | prev_obs = all_obs[:, i:i + n].reshape(b, n * c, H, W) if self.is_upsampler else all_obs[:, i + self.context_indicies].reshape(b, n * c, H, W) 113 | prev_act = None if self.is_upsampler else act[:,i+self.context_indicies] 114 | return prev_obs, prev_act 115 | 116 | def forward(self, batch: Batch) -> LossLogsData: 117 | b, t, c, h, w = batch.obs.size() 118 | H, W = (self.cfg.upsampling_factor * h, self.cfg.upsampling_factor * w) if self.is_upsampler else (h, w) 119 | n = 0 if self.is_upsampler else self.context_indicies[-1] + 1 120 | seq_length = t - n # t = n + 1 + num_autoregressive_steps 121 | 122 | if self.is_upsampler: 123 | all_obs = torch.stack([x["full_res"] for x in batch.info]).to(self.device) 124 | low_res = F.interpolate(batch.obs.reshape(b * t, c, h, w), scale_factor=self.cfg.upsampling_factor, 125 | mode="bicubic").reshape(b, t, c, H, W) 126 | all_acts = None 127 | assert all_obs.shape == low_res.shape 128 | else: 129 | all_obs = batch.obs.clone() 130 | all_acts = batch.act.clone() 131 | 132 | loss = 0 133 | for i in range(seq_length): 134 | prev_obs, prev_act = self.get_prev_obs(all_obs, all_acts, i, b, self.cfg.inner_model.num_steps_conditioning, c, H, W) 135 | obs = all_obs[:, n + i] 136 | mask = batch.mask_padding[:, n + i] 137 | 138 | if self.cfg.noise_previous_obs: 139 | sigma_cond = self.sample_sigma_training(b, self.device) 140 | prev_obs = self.apply_noise(prev_obs, sigma_cond, self.cfg.sigma_offset_noise) 141 | else: 142 | sigma_cond = None 143 | 144 | if self.is_upsampler: 145 | prev_obs = torch.cat((prev_obs, low_res[:, n + i]), dim=1) 146 | 147 | sigma = self.sample_sigma_training(b, self.device) 148 | noisy_obs = self.apply_noise(obs, sigma, self.cfg.sigma_offset_noise) 149 | 150 | cs = self.compute_conditioners(sigma, sigma_cond) 151 | model_output = self.compute_model_output(noisy_obs, prev_obs, prev_act, cs) 152 | 153 | target = (obs - cs.c_skip * noisy_obs) / cs.c_out 154 | loss += F.mse_loss(model_output[mask], target[mask]) 155 | 156 | denoised = self.wrap_model_output(noisy_obs, model_output, cs) 157 | all_obs[:, n + i] = denoised 158 | 159 | metrics = {"loss_denoising": loss.item()/seq_length} 160 | batch_data = {"obs": all_obs[:, -seq_length:], 'act': batch.act[:, -seq_length:], 'mask_padding': batch.mask_padding[:, -seq_length:]} 161 | return loss, metrics, batch_data 162 | -------------------------------------------------------------------------------- /src/models/diffusion/diffusion_sampler.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Tuple 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | from .denoiser import Denoiser 8 | 9 | 10 
| @dataclass 11 | class DiffusionSamplerConfig: 12 | num_steps_denoising: int 13 | sigma_min: float = 2e-3 14 | sigma_max: float = 5 15 | rho: int = 7 16 | order: int = 1 17 | s_churn: float = 0 18 | s_tmin: float = 0 19 | s_tmax: float = float("inf") 20 | s_noise: float = 1 21 | s_cond: float = 0 22 | 23 | 24 | class DiffusionSampler: 25 | def __init__(self, denoiser: Denoiser, cfg: DiffusionSamplerConfig) -> None: 26 | self.denoiser = denoiser 27 | self.cfg = cfg 28 | self.sigmas = build_sigmas(cfg.num_steps_denoising, cfg.sigma_min, cfg.sigma_max, cfg.rho, denoiser.device) 29 | 30 | @torch.no_grad() 31 | def sample(self, prev_obs: Tensor, prev_act: Optional[Tensor]) -> Tuple[Tensor, List[Tensor]]: 32 | device = prev_obs.device 33 | 34 | b, t, c, h, w = prev_obs.size() 35 | 36 | prev_obs = prev_obs.reshape(b, t * c, h, w) 37 | s_in = torch.ones(b, device=device) 38 | gamma_ = min(self.cfg.s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) 39 | x = torch.randn(b, c, h, w, device=device) 40 | trajectory = [x] 41 | for sigma, next_sigma in zip(self.sigmas[:-1], self.sigmas[1:]): 42 | gamma = gamma_ if self.cfg.s_tmin <= sigma <= self.cfg.s_tmax else 0 43 | sigma_hat = sigma * (gamma + 1) 44 | if gamma > 0: 45 | eps = torch.randn_like(x) * self.cfg.s_noise 46 | x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5 47 | if self.cfg.s_cond > 0: 48 | sigma_cond = torch.full((b,), fill_value=self.cfg.s_cond, device=device) 49 | prev_obs = self.denoiser.apply_noise(prev_obs, sigma_cond, sigma_offset_noise=0) 50 | else: 51 | sigma_cond = None 52 | denoised = self.denoiser.denoise(x, sigma, sigma_cond, prev_obs, prev_act) 53 | d = (x - denoised) / sigma_hat 54 | dt = next_sigma - sigma_hat 55 | if self.cfg.order == 1 or next_sigma == 0: 56 | # Euler method 57 | x = x + d * dt 58 | else: 59 | # Heun's method 60 | x_2 = x + d * dt 61 | denoised_2 = self.denoiser.denoise(x_2, next_sigma * s_in, sigma_cond, prev_obs, prev_act) 62 | d_2 = (x_2 - denoised_2) / next_sigma 63 | d_prime = (d + d_2) / 2 64 | x = x + d_prime * dt 65 | trajectory.append(x) 66 | return x, trajectory 67 | 68 | 69 | def build_sigmas(num_steps: int, sigma_min: float, sigma_max: float, rho: int, device: torch.device) -> Tensor: 70 | min_inv_rho = sigma_min ** (1 / rho) 71 | max_inv_rho = sigma_max ** (1 / rho) 72 | l = torch.linspace(0, 1, num_steps, device=device) 73 | sigmas = (max_inv_rho + l * (min_inv_rho - max_inv_rho)) ** rho 74 | return torch.cat((sigmas, sigmas.new_zeros(1))) 75 | -------------------------------------------------------------------------------- /src/models/diffusion/inner_model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional 3 | 4 | import torch 5 | from torch import Tensor 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from ..blocks import Conv3x3, FourierFeatures, GroupNorm, UNet 10 | 11 | 12 | @dataclass 13 | class InnerModelConfig: 14 | img_channels: int 15 | num_steps_conditioning: int 16 | cond_channels: int 17 | depths: List[int] 18 | channels: List[int] 19 | attn_depths: List[bool] 20 | num_actions: Optional[int] = None # set by trainer after env creation 21 | is_upsampler: Optional[bool] = None # set by Denoiser 22 | 23 | 24 | class InnerModel(nn.Module): 25 | def __init__(self, cfg: InnerModelConfig) -> None: 26 | super().__init__() 27 | self.noise_emb = FourierFeatures(cfg.cond_channels) 28 | self.noise_cond_emb = FourierFeatures(cfg.cond_channels) 29 | self.act_emb = None 
if cfg.is_upsampler else nn.Sequential( 30 | nn.Embedding(cfg.num_actions, cfg.cond_channels // cfg.num_steps_conditioning), 31 | nn.Flatten(), # b t e -> b (t e) 32 | ) 33 | self.cond_proj = nn.Sequential( 34 | nn.Linear(cfg.cond_channels, cfg.cond_channels), 35 | nn.SiLU(), 36 | nn.Linear(cfg.cond_channels, cfg.cond_channels), 37 | ) 38 | self.conv_in = Conv3x3((cfg.num_steps_conditioning + int(cfg.is_upsampler) + 1) * cfg.img_channels, cfg.channels[0]) 39 | 40 | self.unet = UNet(cfg.cond_channels, cfg.depths, cfg.channels, cfg.attn_depths) 41 | 42 | self.norm_out = GroupNorm(cfg.channels[0]) 43 | self.conv_out = Conv3x3(cfg.channels[0], cfg.img_channels) 44 | nn.init.zeros_(self.conv_out.weight) 45 | 46 | def forward(self, noisy_next_obs: Tensor, c_noise: Tensor, c_noise_cond: Tensor, obs: Tensor, act: Optional[Tensor]) -> Tensor: 47 | if self.act_emb is not None: 48 | assert act.ndim == 2 or (act.ndim == 3 and act.size(2) == self.act_emb[0].num_embeddings and set(act.unique().tolist()).issubset(set([0, 1]))) 49 | act_emb = self.act_emb(act) if act.ndim == 2 else self.act_emb[1]((act.float() @ self.act_emb[0].weight)) 50 | else: 51 | assert act is None 52 | act_emb = 0 53 | 54 | cond = self.cond_proj(self.noise_emb(c_noise) + self.noise_cond_emb(c_noise_cond) + act_emb) 55 | x = self.conv_in(torch.cat((obs, noisy_next_obs), dim=1)) 56 | x, _, _ = self.unet(x, cond) 57 | x = self.conv_out(F.silu(self.norm_out(x))) 58 | return x 59 | -------------------------------------------------------------------------------- /src/models/rew_end_model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Tuple 3 | 4 | import torch 5 | from torch import Tensor 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torcheval.metrics.functional import multiclass_confusion_matrix 9 | 10 | from .blocks import Conv3x3, Downsample, ResBlocks 11 | from data import Batch 12 | from utils import init_lstm, LossAndLogs 13 | 14 | 15 | @dataclass 16 | class RewEndModelConfig: 17 | lstm_dim: int 18 | img_channels: int 19 | img_size: int 20 | cond_channels: int 21 | depths: List[int] 22 | channels: List[int] 23 | attn_depths: List[int] 24 | num_actions: Optional[int] = None 25 | 26 | 27 | class RewEndModel(nn.Module): 28 | def __init__(self, cfg: RewEndModelConfig) -> None: 29 | super().__init__() 30 | self.cfg = cfg 31 | self.encoder = RewEndEncoder(2 * cfg.img_channels, cfg.cond_channels, cfg.depths, cfg.channels, cfg.attn_depths) 32 | self.act_emb = nn.Embedding(cfg.num_actions, cfg.cond_channels) 33 | input_dim_lstm = cfg.channels[-1] * (cfg.img_size // 2 ** (len(cfg.depths) - 1)) ** 2 34 | self.lstm = nn.LSTM(input_dim_lstm, cfg.lstm_dim, batch_first=True) 35 | self.head = nn.Sequential( 36 | nn.Linear(cfg.lstm_dim, cfg.lstm_dim), 37 | nn.SiLU(), 38 | nn.Linear(cfg.lstm_dim, 3 + 2, bias=False), 39 | ) 40 | init_lstm(self.lstm) 41 | 42 | def predict_rew_end( 43 | self, 44 | obs: Tensor, 45 | act: Tensor, 46 | next_obs: Tensor, 47 | hx_cx: Optional[Tuple[Tensor, Tensor]] = None, 48 | ) -> Tuple[Tensor, Tensor, Tuple[Tensor, Tensor]]: 49 | b, t, c, h, w = obs.shape 50 | obs, act, next_obs = obs.reshape(b * t, c, h, w), act.reshape(b * t), next_obs.reshape(b * t, c, h, w) 51 | x = self.encoder(torch.cat((obs, next_obs), dim=1), self.act_emb(act)) 52 | x = x.reshape(b, t, -1) # (b t) e h w -> b t (e h w) 53 | x, hx_cx = self.lstm(x, hx_cx) 54 | logits = self.head(x) 55 | return logits[:, :, 
:-2], logits[:, :, -2:], hx_cx 56 | 57 | def forward(self, batch: Batch) -> LossAndLogs: 58 | obs = batch.obs[:, :-1] 59 | act = batch.act[:, :-1] 60 | next_obs = batch.obs[:, 1:] 61 | rew = batch.rew[:, :-1] 62 | end = batch.end[:, :-1] 63 | mask = batch.mask_padding[:, :-1] 64 | 65 | # When dead, replace frame (gray padding) by true final obs 66 | dead = end.bool().any(dim=1) 67 | if dead.any(): 68 | final_obs = torch.stack([i["final_observation"] for i, d in zip(batch.info, dead) if d]).to(obs.device) 69 | next_obs[dead, end[dead].argmax(dim=1)] = final_obs 70 | 71 | logits_rew, logits_end, _ = self.predict_rew_end(obs, act, next_obs) 72 | logits_rew = logits_rew[mask] 73 | logits_end = logits_end[mask] 74 | target_rew = rew[mask].sign().long().add(1) # clipped to {-1, 0, 1} 75 | target_end = end[mask] 76 | 77 | loss_rew = F.cross_entropy(logits_rew, target_rew) 78 | loss_end = F.cross_entropy(logits_end, target_end) 79 | loss = loss_rew + loss_end 80 | 81 | metrics = { 82 | "loss_rew": loss_rew.detach(), 83 | "loss_end": loss_end.detach(), 84 | "loss_total": loss.detach(), 85 | "confusion_matrix": { 86 | "rew": multiclass_confusion_matrix(logits_rew, target_rew, num_classes=3), 87 | "end": multiclass_confusion_matrix(logits_end, target_end, num_classes=2), 88 | }, 89 | } 90 | return loss, metrics 91 | 92 | 93 | class RewEndEncoder(nn.Module): 94 | def __init__( 95 | self, 96 | in_channels: int, 97 | cond_channels: int, 98 | depths: List[int], 99 | channels: List[int], 100 | attn_depths: List[int], 101 | ) -> None: 102 | super().__init__() 103 | assert len(depths) == len(channels) == len(attn_depths) 104 | self.conv_in = Conv3x3(in_channels, channels[0]) 105 | blocks = [] 106 | for i, n in enumerate(depths): 107 | c1 = channels[max(0, i - 1)] 108 | c2 = channels[i] 109 | blocks.append( 110 | ResBlocks( 111 | list_in_channels=[c1] + [c2] * (n - 1), 112 | list_out_channels=[c2] * n, 113 | cond_channels=cond_channels, 114 | attn=attn_depths[i], 115 | ) 116 | ) 117 | blocks.append( 118 | ResBlocks( 119 | list_in_channels=[channels[-1]] * 2, 120 | list_out_channels=[channels[-1]] * 2, 121 | cond_channels=cond_channels, 122 | attn=True, 123 | ) 124 | ) 125 | self.blocks = nn.ModuleList(blocks) 126 | self.downsamples = nn.ModuleList([nn.Identity()] + [Downsample(c) for c in channels[:-1]] + [nn.Identity()]) 127 | 128 | def forward(self, x: Tensor, cond: Tensor) -> Tensor: 129 | x = self.conv_in(x) 130 | for block, down in zip(self.blocks, self.downsamples): 131 | x = down(x) 132 | x, _ = block(x, cond) 133 | return x 134 | -------------------------------------------------------------------------------- /src/play.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | from huggingface_hub import snapshot_download 5 | from hydra import compose, initialize 6 | from hydra.utils import instantiate 7 | from omegaconf import DictConfig, OmegaConf 8 | import torch 9 | 10 | from agent import Agent 11 | from envs import WorldModelEnv 12 | from game import Game, PlayEnv 13 | 14 | 15 | OmegaConf.register_new_resolver("eval", eval) 16 | 17 | 18 | def parse_args() -> argparse.Namespace: 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("-r", "--record", action="store_true", help="Record episodes in PlayEnv.") 21 | parser.add_argument("--store-denoising-trajectory", action="store_true", help="Save denoising steps in info.") 22 | parser.add_argument("--store-original-obs", action="store_true", help="Save 
original obs (pre-resizing) in info.")
23 |     parser.add_argument("--mouse-multiplier", type=int, default=10, help="Multiplication factor for the mouse movement.")
24 |     parser.add_argument("--compile", action="store_true", help="Turn on model compilation.")
25 |     parser.add_argument("--fps", type=int, default=30, help="Frame rate.")
26 |     parser.add_argument("--no-header", action="store_true", help="Do not print the startup header.")
27 |     return parser.parse_args()
28 | 
29 | 
30 | def check_args(args: argparse.Namespace) -> bool:
31 |     if not args.record and (args.store_denoising_trajectory or args.store_original_obs):
32 |         print("Warning: not in recording mode, ignoring --store* options")
33 |     return True
34 | 
35 | 
36 | def prepare_play_mode(cfg: DictConfig, args: argparse.Namespace) -> PlayEnv:
37 |     path_hf = Path(snapshot_download(repo_id="Enigma-AI/multiverse"))
38 | 
39 |     path_ckpt = path_hf / 'agent.pt'
40 |     spawn_dir = Path('.') / 'game/spawn'
41 |     # Override config
42 |     cfg.agent = OmegaConf.load("config/agent/racing.yaml")
43 |     cfg.env = OmegaConf.load("config/env/racing.yaml")
44 | 
45 |     if torch.cuda.is_available():
46 |         device = torch.device("cuda:0")
47 |     elif torch.backends.mps.is_available():
48 |         device = torch.device("mps")
49 |     else:
50 |         device = torch.device("cpu")
51 | 
52 |     print("----------------------------------------------------------------------")
53 |     print(f"Using {device} for rendering.")
54 |     if not torch.cuda.is_available() and not torch.backends.mps.is_available():  # warn when neither CUDA nor MPS is available
55 |         print("If you have a CUDA GPU available and it is not being used, please follow the instructions at https://pytorch.org/get-started/locally/ to reinstall torch with CUDA support and try again.")
56 |     print("----------------------------------------------------------------------")
57 | 
58 |     assert cfg.env.train.id == "racing"
59 |     num_actions = cfg.env.num_actions
60 | 
61 |     # Models
62 |     agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(device).eval()
63 |     agent.load(path_ckpt)
64 | 
65 |     # World model environment
66 |     sl = cfg.agent.denoiser.inner_model.num_steps_conditioning
67 |     if agent.upsampler is not None:
68 |         sl = max(sl, cfg.agent.upsampler.inner_model.num_steps_conditioning)
69 |     wm_env_cfg = instantiate(cfg.world_model_env, num_batches_to_preload=1)
70 |     wm_env = WorldModelEnv(agent.denoiser, agent.upsampler, agent.rew_end_model, spawn_dir, 1, sl, wm_env_cfg, return_denoising_trajectory=True)
71 | 
72 |     if device.type == "cuda" and args.compile:
73 |         print("Compiling models...")
74 |         wm_env.predict_next_obs = torch.compile(wm_env.predict_next_obs, mode="reduce-overhead")
75 |         wm_env.upsample_next_obs = torch.compile(wm_env.upsample_next_obs, mode="reduce-overhead")
76 | 
77 |     play_env = PlayEnv(
78 |         agent,
79 |         wm_env,
80 |         args.record,
81 |         args.store_denoising_trajectory,
82 |         args.store_original_obs,
83 |     )
84 | 
85 |     return play_env
86 | 
87 | 
88 | @torch.no_grad()
89 | def main():
90 |     args = parse_args()
91 |     ok = check_args(args)
92 |     if not ok:
93 |         return
94 | 
95 |     with initialize(version_base="1.3", config_path="../config"):
96 |         cfg = compose(config_name="trainer")
97 | 
98 |     # window size: the config may give a single int or an (h, w) pair
99 |     h, w = (cfg.env.train.size,) * 2 if isinstance(cfg.env.train.size, int) else cfg.env.train.size
100 |     size_h, size_w = h, w
101 |     env = prepare_play_mode(cfg, args)
102 |     game = Game(env, (size_h, size_w), fps=args.fps, verbose=not args.no_header)
103 |     game.run()
104 | 
105 | 
106 | if __name__ == "__main__":
107 |     main()
108 | 
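Aside: the sketch below is not a file in this repo; it isolates the Hydra composition that `play.py` performs, which can be handy for inspecting the resolved configuration without opening the pygame window. Like `play.py`, it assumes the file lives under `src/` (Hydra resolves `config_path` relative to the calling file) and is run from the repository root (the `OmegaConf.load` paths are relative to the working directory).

```python
# Minimal sketch (assumption: saved under src/, executed from the repo root):
# reproduce play.py's config composition and print the resolved env settings.
from hydra import compose, initialize
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("eval", eval)  # play.py registers the same resolver

with initialize(version_base="1.3", config_path="../config"):
    cfg = compose(config_name="trainer")

# Same overrides that prepare_play_mode() applies before building the agent
cfg.agent = OmegaConf.load("config/agent/racing.yaml")
cfg.env = OmegaConf.load("config/env/racing.yaml")

print(cfg.env.train.id)                 # expected: "racing"
print(OmegaConf.to_yaml(cfg.env.train)) # window size, data paths, etc.
```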
-------------------------------------------------------------------------------- /src/player/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/src/player/__init__.py -------------------------------------------------------------------------------- /src/player/action_processing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Credits: some parts are taken and modified from the file `config.py` from https://github.com/TeaPearce/Counter-Strike_Behavioural_Cloning/ 3 | """ 4 | 5 | from dataclasses import dataclass 6 | from typing import Dict, List, Set 7 | 8 | import numpy as np 9 | import pygame 10 | import torch 11 | 12 | from .keymap import GAME_FORBIDDEN_COMBINATIONS, GAME_KEYMAP 13 | 14 | 15 | @dataclass 16 | class GameAction: 17 | keys: List[int] 18 | 19 | def __post_init__(self) -> None: 20 | self.keys = filter_keys_pressed_forbidden(self.keys) 21 | # self.process_mouse() 22 | 23 | @property 24 | def key_names(self) -> List[str]: 25 | return [pygame.key.name(key) for key in self.keys] 26 | 27 | 28 | def print_game_action(action: GameAction) -> str: 29 | action_names = [GAME_KEYMAP[k] for k in action.keys] if len(action.keys) > 0 else [] 30 | action_names = [x for x in action_names if not x.startswith("camera_")] 31 | keys = " + ".join(action_names) 32 | return keys 33 | 34 | 35 | N_PLAYERS = 2 36 | N_GAS_KAVIM = 11 37 | N_BREX_KAVIM = 11 38 | N_STEERING = 11 39 | N_KEYS = 8 # number of keyboard outputs, w,s,a,d,up,down,left,right 40 | 41 | 42 | def encode_game_action(game_action: GameAction, device: torch.device) -> torch.Tensor: 43 | p1_gas = torch.zeros(N_GAS_KAVIM) 44 | p2_gas = torch.zeros(N_GAS_KAVIM) 45 | 46 | p1_brex = torch.zeros(N_BREX_KAVIM) 47 | p2_brex = torch.zeros(N_BREX_KAVIM) 48 | 49 | p1_steer = torch.zeros(N_STEERING) 50 | p2_steer = torch.zeros(N_STEERING) 51 | 52 | p1_is_steer = False 53 | p2_is_steer = False 54 | 55 | for key in game_action.key_names: 56 | if key == "w": 57 | p1_gas[N_GAS_KAVIM - 1] = 1 58 | if key == "a": 59 | p1_steer[3] = 1 60 | p1_is_steer = True 61 | if key == "s": 62 | p1_brex[N_BREX_KAVIM - 1] = 1 63 | if key == "d": 64 | p1_steer[N_STEERING - 3] = 1 65 | p1_is_steer = True 66 | 67 | if key == "up": 68 | p2_gas[N_GAS_KAVIM - 1] = 1 69 | if key == "left": 70 | p2_steer[3] = 1 71 | p2_is_steer = True 72 | if key == "down": 73 | p2_brex[N_BREX_KAVIM - 1] = 1 74 | if key == "right": 75 | p2_steer[N_STEERING - 3] = 1 76 | p2_is_steer = True 77 | 78 | if not p1_is_steer: 79 | p1_steer[len(p1_steer) // 2] = 1 80 | 81 | if not any(p1_gas): 82 | p1_gas[0] = 1 83 | 84 | if not any(p1_brex): 85 | p1_brex[0] = 1 86 | 87 | if not p2_is_steer: 88 | p2_steer[len(p2_steer) // 2] = 1 89 | 90 | if not any(p2_gas): 91 | p2_gas[0] = 1 92 | 93 | if not any(p2_brex): 94 | p2_brex[0] = 1 95 | 96 | return torch.cat([p1_gas, p1_brex, p1_steer, p2_gas, p2_brex, p2_steer]).float().to(device) 97 | 98 | 99 | def decode_game_action(y_preds: torch.Tensor) -> GameAction: 100 | y_preds = y_preds.squeeze() 101 | keys_pred = y_preds[0:N_KEYS] 102 | 103 | keys_pressed = [] 104 | keys_pressed_onehot = np.round(keys_pred) 105 | if keys_pressed_onehot[0] == 1: 106 | keys_pressed.append("w") 107 | if keys_pressed_onehot[1] == 1: 108 | keys_pressed.append("a") 109 | if keys_pressed_onehot[2] == 1: 110 | keys_pressed.append("s") 111 | if keys_pressed_onehot[3] == 1: 112 | keys_pressed.append("d") 
113 |     if keys_pressed_onehot[4] == 1:
114 |         keys_pressed.append("up")
115 |     if keys_pressed_onehot[5] == 1:
116 |         keys_pressed.append("left")
117 |     if keys_pressed_onehot[6] == 1:
118 |         keys_pressed.append("down")
119 |     if keys_pressed_onehot[7] == 1:
120 |         keys_pressed.append("right")
121 | 
122 |     keys_pressed = [pygame.key.key_code(x) for x in keys_pressed]
123 | 
124 |     return GameAction(keys_pressed)
125 | 
126 | 
127 | def filter_keys_pressed_forbidden(keys_pressed: List[int], keymap: Dict[int, str] = GAME_KEYMAP,
128 |                                   forbidden_combinations: List[Set[str]] = GAME_FORBIDDEN_COMBINATIONS) -> List[int]:
129 |     keys = set()
130 |     names = set()
131 |     for key in keys_pressed:
132 |         if key not in keymap:
133 |             continue
134 |         name = keymap[key]
135 |         keys.add(key)
136 |         names.add(name)
137 |         for forbidden in forbidden_combinations:
138 |             if forbidden.issubset(names):
139 |                 keys.remove(key)
140 |                 names.remove(name)
141 |                 break
142 |     return list(filter(lambda key: key in keys, keys_pressed))
143 | 
--------------------------------------------------------------------------------
/src/player/keymap.py:
--------------------------------------------------------------------------------
1 | import pygame
2 | 
3 | 
4 | GAME_KEYMAP = {
5 |     pygame.K_w: "p1_up",
6 |     pygame.K_d: "p1_right",
7 |     pygame.K_a: "p1_left",
8 |     pygame.K_s: "p1_down",
9 | 
10 |     pygame.K_UP: "p2_up",
11 |     pygame.K_RIGHT: "p2_right",
12 |     pygame.K_LEFT: "p2_left",
13 |     pygame.K_DOWN: "p2_down",
14 | }
15 | 
16 | GAME_FORBIDDEN_COMBINATIONS = []  # no forbidden key combinations for racing; matches the List[Set[str]] annotation in action_processing.py
--------------------------------------------------------------------------------
/src/process_denoiser_files.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from functools import partial
3 | from pathlib import Path
4 | from multiprocessing import Pool
5 | import shutil
6 | 
7 | import torchvision.transforms.functional as T
8 | from tqdm import tqdm
9 | 
10 | from data.dataset import Dataset, GameHdf5Dataset
11 | from data.episode import Episode
12 | from data.segment import SegmentId
13 | 
14 | import os
15 | 
16 | PREFIX = ""
17 | 
18 | 
19 | def parse_args():
20 |     parser = argparse.ArgumentParser()
21 |     parser.add_argument(
22 |         "tar_dir",
23 |         type=Path,
24 |         help="folder containing the .tar files from the `dataset_tars` folder on Huggingface",
25 |     )
26 |     parser.add_argument(
27 |         "out_dir",
28 |         type=Path,
29 |         help="a new directory (must not already exist); the script will unpack and process the data there",
30 |     )
31 |     return parser.parse_args()
32 | 
33 | 
34 | def process_tar(path_tar: Path, out_dir: Path, remove_tar: bool) -> None:
35 |     # derive a per-file output directory from the filename
36 |     assert path_tar.stem.startswith(PREFIX)
37 |     d = out_dir / "-".join(path_tar.stem[len(PREFIX) :].split("_to_"))
38 |     d.mkdir(exist_ok=False, parents=True)
39 |     shutil.copy(path_tar, d / os.path.basename(path_tar))
40 |     new_path_tar = d / path_tar.name
41 |     if remove_tar:
42 |         new_path_tar.unlink()
43 |     else:
44 |         shutil.copy(new_path_tar, path_tar.parent)
45 | 
46 | 
47 | def main():
48 |     args = parse_args()
49 | 
50 |     tar_dir = args.tar_dir.absolute()
51 |     out_dir = args.out_dir.absolute()
52 | 
53 |     if not tar_dir.exists():
54 |         print(
55 |             "Wrong usage: the tar directory should exist (and contain the downloaded .tar files)"
56 |         )
57 |         return
58 | 
59 |     if out_dir.exists():
60 |         print(f"Wrong usage: the output directory should not exist ({args.out_dir})")
61 |         return
62 | 
63 |     with Path("test_split.txt").open("r") as f:
64 |         test_files = f.read().split("\n")
65 | 
66 |     full_res_dir = out_dir 
/ "full_res" 67 | low_res_dir = out_dir / "low_res" 68 | 69 | hdf5_files = [ 70 | x for x in tar_dir.iterdir() if x.suffix == ".hdf5" and x.stem.startswith(PREFIX) 71 | ] 72 | n = len(hdf5_files) 73 | 74 | str_files = "\n".join(map(str, hdf5_files)) 75 | print(f"Ready to untar {n} tar files:\n{str_files}") 76 | 77 | remove_tar = False 78 | 79 | # Untar game files 80 | f = partial(process_tar, out_dir=full_res_dir, remove_tar=remove_tar) 81 | with Pool(n) as p: 82 | p.map(f, hdf5_files) 83 | 84 | print(f"{n} .tar files unpacked in {full_res_dir}") 85 | 86 | # 87 | # Create low-res data 88 | # 89 | 90 | game_dataset = GameHdf5Dataset(full_res_dir) 91 | 92 | train_dataset = Dataset(low_res_dir / "train", None) 93 | test_dataset = Dataset(low_res_dir / "test", None) 94 | 95 | for i in tqdm(game_dataset._filenames, desc="Creating low_res"): 96 | episode_length = game_dataset._length_one_episode[i] 97 | episode = Episode( 98 | **{ 99 | k: v 100 | for k, v in game_dataset[SegmentId(i, 0, episode_length, None)].__dict__.items() 101 | if k not in ("mask_padding", "id") 102 | } 103 | ) 104 | episode.obs = T.resize( 105 | episode.obs, (48, 64), interpolation=T.InterpolationMode.BICUBIC 106 | ) 107 | filename = game_dataset._filenames[i] 108 | file_id = f"{filename.parent.stem}/{filename.name}" 109 | episode.info = {"original_file_id": file_id} 110 | dataset = test_dataset if filename.name in test_files else train_dataset 111 | dataset.add_episode(episode) 112 | 113 | train_dataset.save_to_default_path() 114 | test_dataset.save_to_default_path() 115 | 116 | print( 117 | f"Split train/test data ({train_dataset.num_episodes}/{test_dataset.num_episodes} episodes)\n" 118 | ) 119 | 120 | print("You can now edit `config/env/racing.yaml` and set:") 121 | print(f"path_data_low_res: {low_res_dir}") 122 | print(f"path_data_full_res: {full_res_dir}") 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /src/process_upsampler_files.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from functools import partial 3 | from pathlib import Path 4 | from multiprocessing import Pool 5 | import shutil 6 | 7 | import torchvision.transforms.functional as T 8 | from tqdm import tqdm 9 | 10 | from data.dataset import Dataset, GameHdf5Dataset 11 | from data.episode import Episode 12 | from data.segment import SegmentId 13 | 14 | import os 15 | 16 | PREFIX = "" 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | "tar_dir", 23 | type=Path, 24 | help="folder containing the .tar files from `dataset_tars` folder from Huggingface", 25 | ) 26 | parser.add_argument( 27 | "out_dir", 28 | type=Path, 29 | help="a new directory (should not exist already), the script will untar and process data there", 30 | ) 31 | return parser.parse_args() 32 | 33 | 34 | def process_tar(path_tar: Path, out_dir: Path, remove_tar: bool) -> None: 35 | d = path_tar.stem 36 | assert path_tar.stem.startswith(PREFIX) 37 | d = out_dir / "-".join(path_tar.stem[len(PREFIX) :].split("_to_")) 38 | d.mkdir(exist_ok=False, parents=True) 39 | shutil.copy(path_tar, d / os.path.basename(path_tar)) 40 | new_path_tar = d / path_tar.name 41 | if remove_tar: 42 | new_path_tar.unlink() 43 | else: 44 | shutil.copy(new_path_tar, path_tar.parent) 45 | 46 | 47 | def main(): 48 | args = parse_args() 49 | 50 | tar_dir = args.tar_dir.absolute() 51 | out_dir = args.out_dir.absolute() 
52 | 53 | if not tar_dir.exists(): 54 | print( 55 | "Wrong usage: the tar directory should exist (and contain the downloaded .tar files)" 56 | ) 57 | return 58 | 59 | if out_dir.exists(): 60 | print(f"Wrong usage: the output directory should not exist ({args.out_dir})") 61 | return 62 | 63 | with Path("test_split.txt").open("r") as f: 64 | test_files = f.read().split("\n") 65 | 66 | full_res_dir = out_dir / "full_res" 67 | low_res_dir = out_dir / "low_res" 68 | 69 | hdf5_files = [ 70 | x for x in tar_dir.iterdir() if x.suffix == ".hdf5" and x.stem.startswith(PREFIX) 71 | ] 72 | n = len(hdf5_files) 73 | 74 | str_files = "\n".join(map(str, hdf5_files)) 75 | print(f"Ready to untar {n} tar files:\n{str_files}") 76 | 77 | remove_tar = False 78 | 79 | # Untar game files 80 | f = partial(process_tar, out_dir=full_res_dir, remove_tar=remove_tar) 81 | with Pool(n) as p: 82 | p.map(f, hdf5_files) 83 | 84 | print(f"{n} .tar files unpacked in {full_res_dir}") 85 | 86 | # 87 | # Create low-res data 88 | # 89 | 90 | game_dataset = GameHdf5Dataset(full_res_dir) 91 | 92 | train_dataset = Dataset(low_res_dir / "train", None) 93 | test_dataset = Dataset(low_res_dir / "test", None) 94 | 95 | for i in tqdm(game_dataset._filenames, desc="Creating low_res"): 96 | episode_length = game_dataset._length_one_episode[i] 97 | episode = Episode( 98 | **{ 99 | k: v 100 | for k, v in game_dataset[SegmentId(i, 0, episode_length, None)].__dict__.items() 101 | if k not in ("mask_padding", "id") 102 | } 103 | ) 104 | episode.obs = T.resize( 105 | episode.obs, (35, 53), interpolation=T.InterpolationMode.BICUBIC 106 | ) 107 | filename = game_dataset._filenames[i] 108 | file_id = f"{filename.parent.stem}/{filename.name}" 109 | episode.info = {"original_file_id": file_id} 110 | dataset = test_dataset if filename.name in test_files else train_dataset 111 | dataset.add_episode(episode) 112 | 113 | train_dataset.save_to_default_path() 114 | test_dataset.save_to_default_path() 115 | 116 | print( 117 | f"Split train/test data ({train_dataset.num_episodes}/{test_dataset.num_episodes} episodes)\n" 118 | ) 119 | 120 | print("You can now edit `config/env/racing.yaml` and set:") 121 | print(f"path_data_low_res: {low_res_dir}") 122 | print(f"path_data_full_res: {full_res_dir}") 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /src/spawn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import os 4 | import random 5 | import h5py 6 | import numpy as np 7 | import cv2 8 | 9 | low_res_w = 64 10 | low_res_h = 48 11 | 12 | crop_frame = { 13 | 'left_top': (0.04, 0.18), 14 | 'right_bottom': (0.92, 0.95) 15 | } 16 | 17 | def extract_roi(image, rect): 18 | img_width, img_height, _ = image.shape 19 | min_x = int(rect['left_top'][1] * img_width) 20 | max_x = int(rect['right_bottom'][1] * img_width) 21 | min_y = int(rect['left_top'][0] * img_height) 22 | max_y = int(rect['right_bottom'][0] * img_height) 23 | roi = image[min_x:max_x, min_y:max_y] 24 | return roi 25 | 26 | def rescale_image(image, scale_factor): 27 | # Get new dimensions 28 | new_width = int(image.shape[1] * scale_factor) 29 | new_height = int(image.shape[0] * scale_factor) 30 | 31 | # Resize the image 32 | return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA) 33 | 34 | 35 | def parse_args(): 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument( 38 | 
"full_res_directory", 39 | type=Path, 40 | help="Specify your full_res directory.", 41 | ) 42 | 43 | parser.add_argument( 44 | "model_directory", 45 | type=Path, 46 | help="Specify ", 47 | ) 48 | return parser.parse_args() 49 | 50 | 51 | def main(): 52 | args = parse_args() 53 | 54 | full_res_directory = args.full_res_directory.absolute() 55 | model_directory = args.model_directory.absolute() 56 | 57 | spawn_dir = model_directory / "game/spawn" 58 | existing_spawns = len(os.listdir(spawn_dir)) 59 | os.makedirs(spawn_dir / str(existing_spawns), exist_ok=True) 60 | full_res = os.path.join(full_res_directory, random.choice(os.listdir(full_res_directory))) 61 | h5file = h5py.File(full_res, 'r') 62 | i = random.randint(0, 1000) 63 | 64 | data_x_frames = [] 65 | data_y_frames = [] 66 | next_act_frames = [] 67 | low_res_frames = [] 68 | 69 | for j in range(20): 70 | frame_x = f'frame_{i}_x' 71 | frame_y = f'frame_{i}_y' 72 | 73 | # Check if the datasets exist in the file 74 | if frame_x in h5file and frame_y in h5file: 75 | # Append each frame to the lists 76 | data_x = h5file[frame_x][:] 77 | data_y = h5file[frame_y][:] 78 | 79 | img1 = cv2.cvtColor(data_x[:,:,3:], cv2.COLOR_BGR2RGB) 80 | img2 = cv2.cvtColor(data_x[:,:,:3], cv2.COLOR_BGR2RGB) 81 | 82 | img1_cropped = extract_roi(img1, crop_frame) 83 | img2_cropped = extract_roi(img2, crop_frame) 84 | img1_cropped = cv2.resize(img1_cropped, (530, 350), interpolation=cv2.INTER_AREA) 85 | img2_cropped = cv2.resize(img2_cropped, (530, 350), interpolation=cv2.INTER_AREA) 86 | 87 | data_x_frames.append(np.concatenate([img1_cropped, img2_cropped], axis=2)) 88 | data_y_frames.append(data_y) 89 | 90 | img1 = cv2.resize(img1, (low_res_w, low_res_h), interpolation=cv2.INTER_AREA) 91 | img2 = cv2.resize(img2, (low_res_w, low_res_h), interpolation=cv2.INTER_AREA) 92 | 93 | low_res_frames.append(np.concatenate([img1, img2], axis=2).astype(np.uint8)) 94 | else: 95 | print(f"One or both of {frame_x} or {frame_y} do not exist in the file.") 96 | i += 1 97 | for _ in range(200): 98 | next_act = f'frame_{i}_y' 99 | if next_act in h5file: 100 | next_act_data = h5file[next_act][:] 101 | next_act_frames.append(next_act_data) 102 | 103 | data_x_stacked = np.stack(data_x_frames) 104 | data_y_stacked = np.stack(data_y_frames) 105 | next_act_stacked = np.stack(next_act_frames) 106 | low_res_stacked = np.stack(low_res_frames) 107 | 108 | low_res_stacked = np.transpose(low_res_stacked, (0, 3, 1, 2)) 109 | data_x_stacked = np.transpose(data_x_stacked, (0, 3, 1, 2)) 110 | 111 | print(f"Saving act.npy of size {data_y_stacked.shape}") 112 | np.save(spawn_dir / f"{existing_spawns}/act.npy", data_y_stacked) 113 | print(f"Saving full_res.npy of size {data_x_stacked.shape}") 114 | np.save(spawn_dir / f"{existing_spawns}/full_res.npy", data_x_stacked) 115 | print(f"Saving next_act.npy of size {next_act_stacked.shape}") 116 | np.save(spawn_dir / f"{existing_spawns}/next_act.npy", next_act_stacked) 117 | print(f"Saving low_res.npy of size {low_res_stacked.shape}") 118 | np.save(spawn_dir / f"{existing_spawns}/low_res.npy", low_res_stacked) 119 | 120 | h5file.close() 121 | 122 | 123 | if __name__ == "__main__": 124 | main() -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from pathlib import Path 3 | import shutil 4 | import time 5 | from typing import List, Optional, Tuple 6 | 7 | from hydra.utils import instantiate 8 
| import numpy as np 9 | from omegaconf import DictConfig, OmegaConf 10 | import torch 11 | import torch.distributed as dist 12 | from torch.utils.data import DataLoader 13 | from tqdm import tqdm, trange 14 | import wandb 15 | 16 | from agent import Agent 17 | from coroutines.collector import make_collector, NumToCollect 18 | from data import BatchSampler, collate_segments_to_batch, Dataset, DatasetTraverser, GameHdf5Dataset 19 | from envs import make_atari_env, WorldModelEnv 20 | from utils import ( 21 | broadcast_if_needed, 22 | build_ddp_wrapper, 23 | CommonTools, 24 | configure_opt, 25 | count_parameters, 26 | get_lr_sched, 27 | keep_agent_copies_every, 28 | Logs, 29 | move_opt_to, 30 | process_confusion_matrices_if_any_and_compute_classification_metrics, 31 | save_info_for_import_script, 32 | save_with_backup, 33 | set_seed, 34 | StateDictMixin, 35 | try_until_no_except, 36 | wandb_log, 37 | get_frame_indices, 38 | build_pages_per_epoch, 39 | find_maximum_key_below_threshold 40 | ) 41 | 42 | 43 | class Trainer(StateDictMixin): 44 | def __init__(self, cfg: DictConfig, root_dir: Path) -> None: 45 | torch.backends.cuda.matmul.allow_tf32 = True 46 | OmegaConf.resolve(cfg) 47 | self._cfg = cfg 48 | self._rank = dist.get_rank() if dist.is_initialized() else 0 49 | self._world_size = dist.get_world_size() if dist.is_initialized() else 1 50 | 51 | # Pick a random seed 52 | set_seed(torch.seed() % 10 ** 9) 53 | 54 | # Device 55 | self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu", self._rank) 56 | print(f"Starting on {self._device}") 57 | self._use_cuda = self._device.type == "cuda" 58 | if self._use_cuda: 59 | torch.cuda.set_device(self._rank) # fix compilation error on multi-gpu nodes 60 | 61 | # Init wandb 62 | if self._rank == 0: 63 | try_until_no_except( 64 | partial(wandb.init, config=OmegaConf.to_container(cfg, resolve=True), reinit=True, resume=True, 65 | **cfg.wandb) 66 | ) 67 | 68 | # Flags 69 | self._is_static_dataset = cfg.static_dataset.path is not None 70 | self._is_model_free = cfg.training.model_free 71 | 72 | # Checkpointing 73 | self._path_ckpt_dir = Path("checkpoints") 74 | self._path_state_ckpt = self._path_ckpt_dir / "state.pt" 75 | self._keep_agent_copies = partial( 76 | keep_agent_copies_every, 77 | every=cfg.checkpointing.save_agent_every, 78 | path_ckpt_dir=self._path_ckpt_dir, 79 | num_to_keep=cfg.checkpointing.num_to_keep, 80 | ) 81 | self._save_info_for_import_script = partial( 82 | save_info_for_import_script, run_name=cfg.wandb.name, path_ckpt_dir=self._path_ckpt_dir 83 | ) 84 | 85 | # First time, init files hierarchy 86 | if not cfg.common.resume and self._rank == 0: 87 | self._path_ckpt_dir.mkdir(exist_ok=False, parents=False) 88 | path_config = Path("config") / "trainer.yaml" 89 | path_config.parent.mkdir(exist_ok=False, parents=False) 90 | shutil.move(".hydra/config.yaml", path_config) 91 | wandb.save(str(path_config)) 92 | shutil.copytree(src=root_dir / "src", dst="./src") 93 | shutil.copytree(src=root_dir / "scripts", dst="./scripts") 94 | 95 | if cfg.env.train.id == "racing": 96 | assert cfg.env.path_data_low_res is not None and cfg.env.path_data_full_res is not None, "Make sure to download GT4 data and set the relevant paths in cfg.env" 97 | assert self._is_static_dataset 98 | num_actions = cfg.env.num_actions 99 | dataset_full_res = GameHdf5Dataset(Path(cfg.env.path_data_full_res)) 100 | 101 | # Envs (atari only) 102 | else: 103 | if self._rank == 0: 104 | train_env = make_atari_env(num_envs=cfg.collection.train.num_envs, 
device=self._device, **cfg.env.train) 105 | test_env = make_atari_env(num_envs=cfg.collection.test.num_envs, device=self._device, **cfg.env.test) 106 | num_actions = int(test_env.num_actions) 107 | else: 108 | num_actions = None 109 | num_actions, = broadcast_if_needed(num_actions) 110 | dataset_full_res = None 111 | 112 | num_workers = cfg.training.num_workers_data_loaders 113 | use_manager = cfg.training.cache_in_ram and (num_workers > 0) 114 | p = Path(cfg.static_dataset.path) if self._is_static_dataset else Path("dataset") 115 | self.train_dataset = Dataset(p / "train", dataset_full_res, "train_dataset", cfg.training.cache_in_ram, 116 | use_manager) 117 | self.test_dataset = Dataset(p / "test", dataset_full_res, "test_dataset", cache_in_ram=True) 118 | self.train_dataset.load_from_default_path() 119 | self.test_dataset.load_from_default_path() 120 | 121 | # Create models 122 | self.agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(self._device) 123 | self._agent = build_ddp_wrapper(**self.agent._modules) if dist.is_initialized() else self.agent 124 | 125 | if cfg.initialization.path_to_ckpt is not None: 126 | self.agent.load(**cfg.initialization) 127 | 128 | # Collectors 129 | if not self._is_static_dataset and self._rank == 0: 130 | self._train_collector = make_collector( 131 | train_env, self.agent.actor_critic, self.train_dataset, cfg.collection.train.epsilon 132 | ) 133 | self._test_collector = make_collector( 134 | test_env, self.agent.actor_critic, self.test_dataset, cfg.collection.test.epsilon, 135 | reset_every_collect=True 136 | ) 137 | 138 | ###################################################### 139 | 140 | # Optimizers and LR schedulers 141 | 142 | def build_opt(name: str) -> torch.optim.AdamW: 143 | return configure_opt(getattr(self.agent, name), **getattr(cfg, name).optimizer) 144 | 145 | def build_lr_sched(name: str) -> torch.optim.lr_scheduler.LambdaLR: 146 | return get_lr_sched(self.opt.get(name), getattr(cfg, name).training.lr_warmup_steps) 147 | 148 | model_names = [self._cfg.train_model] 149 | self._model_names = ["actor_critic"] if self._is_model_free else [name for name in model_names if 150 | getattr(self.agent, name) is not None] 151 | 152 | self.opt = CommonTools(**{name: build_opt(name) for name in self._model_names}) 153 | self.lr_sched = CommonTools(**{name: build_lr_sched(name) for name in self._model_names}) 154 | 155 | # Data loaders 156 | 157 | make_data_loader = partial( 158 | DataLoader, 159 | dataset=self.train_dataset, 160 | collate_fn=collate_segments_to_batch, 161 | num_workers=num_workers, 162 | persistent_workers=(num_workers > 0), 163 | pin_memory=self._use_cuda, 164 | pin_memory_device=str(self._device) if self._use_cuda else "", 165 | ) 166 | 167 | make_batch_sampler = partial(BatchSampler, self.train_dataset, self._rank, self._world_size) 168 | 169 | def get_sample_weights(sample_weights: List[float]) -> Optional[List[float]]: 170 | return None if (self._is_static_dataset and cfg.static_dataset.ignore_sample_weights) else sample_weights 171 | 172 | c = cfg.denoiser.training 173 | # in case of distributed training, the cfg needs to be taken from the module attribute 174 | agent_cfg = self._agent.denoiser.cfg if not hasattr(self._agent.denoiser, "module") else self._agent.denoiser.module.cfg 175 | effective_context_length = int(get_frame_indices(agent_cfg.frame_sampling)[-1]) + 1 176 | seq_length = effective_context_length + 1 + c.num_autoregressive_steps 177 | bs = make_batch_sampler(c.batch_size, seq_length, 
get_sample_weights(c.sample_weights), False, c.num_autoregressive_steps, c.initial_num_consecutive_page_count) 178 | dl_denoiser_train = make_data_loader(batch_sampler=bs) 179 | dl_denoiser_test = DatasetTraverser(self.test_dataset, c.batch_size, seq_length) 180 | 181 | self.pages_per_epoch = build_pages_per_epoch(c.num_consecutive_pages) 182 | 183 | if self.agent.upsampler is not None: 184 | c = cfg.upsampler.training 185 | seq_length = cfg.agent.upsampler.inner_model.num_steps_conditioning + 1 + c.num_autoregressive_steps 186 | bs = make_batch_sampler(c.batch_size, seq_length, get_sample_weights(c.sample_weights), False, c.num_autoregressive_steps, c.initial_num_consecutive_page_count) 187 | dl_upsampler_train = make_data_loader(batch_sampler=bs) 188 | dl_upsampler_test = DatasetTraverser(self.test_dataset, c.batch_size, seq_length) 189 | else: 190 | dl_upsampler_train = dl_upsampler_test = None 191 | 192 | if self.agent.rew_end_model is not None: 193 | c = cfg.rew_end_model.training 194 | bs = make_batch_sampler(c.batch_size, c.seq_length, get_sample_weights(c.sample_weights), 195 | can_sample_beyond_end=True) 196 | dl_rew_end_model_train = make_data_loader(batch_sampler=bs) 197 | dl_rew_end_model_test = DatasetTraverser(self.test_dataset, c.batch_size, c.seq_length) 198 | else: 199 | dl_rew_end_model_train = dl_rew_end_model_test = None 200 | 201 | self._data_loader_train = CommonTools(dl_denoiser_train, dl_upsampler_train, dl_rew_end_model_train, None) 202 | self._data_loader_test = CommonTools(dl_denoiser_test, dl_upsampler_test, dl_rew_end_model_test, None) 203 | 204 | # RL env 205 | 206 | if self.agent.actor_critic is not None: 207 | actor_critic_loss_cfg = instantiate(cfg.actor_critic.actor_critic_loss) 208 | 209 | if self._is_model_free: 210 | assert self.agent.actor_critic is not None 211 | rl_env = make_atari_env(num_envs=cfg.actor_critic.training.batch_size, device=self._device, 212 | **cfg.env.train) 213 | 214 | else: 215 | c = cfg.actor_critic.training 216 | sl = cfg.agent.denoiser.inner_model.num_steps_conditioning 217 | if self.agent.upsampler is not None: 218 | sl = max(sl, cfg.agent.upsampler.inner_model.num_steps_conditioning) 219 | bs = make_batch_sampler(c.batch_size, sl, get_sample_weights(c.sample_weights)) 220 | dl_actor_critic = make_data_loader(batch_sampler=bs) 221 | wm_env_cfg = instantiate(cfg.world_model_env) 222 | rl_env = WorldModelEnv(self.agent.denoiser, self.agent.upsampler, self.agent.rew_end_model, 223 | dl_actor_critic, wm_env_cfg) 224 | 225 | if cfg.training.compile_wm: 226 | rl_env.predict_next_obs = torch.compile(rl_env.predict_next_obs, mode="reduce-overhead") 227 | rl_env.predict_rew_end = torch.compile(rl_env.predict_rew_end, mode="reduce-overhead") 228 | else: 229 | actor_critic_loss_cfg = None 230 | rl_env = None 231 | 232 | # Setup training 233 | sigma_distribution_cfg = instantiate(cfg.denoiser.sigma_distribution) 234 | sigma_distribution_cfg_upsampler = instantiate( 235 | cfg.upsampler.sigma_distribution) if self.agent.upsampler is not None else None 236 | self.agent.setup_training(sigma_distribution_cfg, sigma_distribution_cfg_upsampler, actor_critic_loss_cfg, 237 | rl_env) 238 | 239 | # Training state (things to be saved/restored) 240 | self.epoch = 0 241 | self.num_epochs_collect = None 242 | self.num_episodes_test = 0 243 | self.num_batch_train = CommonTools(0, 0, 0) 244 | self.num_batch_test = CommonTools(0, 0, 0) 245 | 246 | if cfg.common.resume: 247 | # self.load_state_checkpoint() 248 | self.load_agent_state_checkpoint() 249 | else: 
250 | self.save_checkpoint() 251 | 252 | if self._rank == 0: 253 | for name in self._model_names: 254 | print(f"{count_parameters(getattr(self.agent, name))} parameters in {name}") 255 | print(self.train_dataset) 256 | print(self.test_dataset) 257 | 258 | def run(self) -> None: 259 | to_log = [] 260 | 261 | if self.epoch == 0: 262 | if self._is_model_free or self._is_static_dataset: 263 | self.num_epochs_collect = 0 264 | else: 265 | if self._rank == 0: 266 | self.num_epochs_collect, to_log_ = self.collect_initial_dataset() 267 | to_log += to_log_ 268 | self.num_epochs_collect, sd_train_dataset = broadcast_if_needed(self.num_epochs_collect, 269 | self.train_dataset.state_dict()) 270 | self.train_dataset.load_state_dict(sd_train_dataset) 271 | 272 | num_epochs = self.num_epochs_collect + self._cfg.training.num_final_epochs 273 | 274 | while self.epoch < num_epochs: 275 | self.epoch += 1 276 | start_time = time.time() 277 | 278 | if self._rank == 0: 279 | print(f"\nEpoch {self.epoch} / {num_epochs}\n") 280 | 281 | # Training 282 | should_collect_train = ( 283 | self._rank == 0 and not self._is_model_free and not self._is_static_dataset and self.epoch <= self.num_epochs_collect) 284 | 285 | if should_collect_train: 286 | c = self._cfg.collection.train 287 | to_log += self._train_collector.send(NumToCollect(steps=c.steps_per_epoch)) 288 | sd_train_dataset, = broadcast_if_needed(self.train_dataset.state_dict()) # update dataset for ranks > 0 289 | self.train_dataset.load_state_dict(sd_train_dataset) 290 | 291 | if self._cfg.training.should: 292 | to_log += self.train_agent() 293 | 294 | # Evaluation 295 | should_test = self._rank == 0 and self._cfg.evaluation.should and ( 296 | self.epoch % self._cfg.evaluation.every == 0) 297 | should_collect_test = should_test and not self._is_static_dataset 298 | 299 | if should_collect_test: 300 | to_log += self.collect_test() 301 | 302 | if should_test and not self._is_model_free: 303 | to_log += self.test_agent() 304 | 305 | # Logging 306 | to_log.append({"duration": (time.time() - start_time) / 3600}) 307 | if self._rank == 0: 308 | wandb_log(to_log, self.epoch) 309 | to_log = [] 310 | 311 | # Checkpointing 312 | self.save_checkpoint() 313 | 314 | if dist.is_initialized(): 315 | dist.barrier() 316 | 317 | smallest_page_epoch = find_maximum_key_below_threshold(self.pages_per_epoch, self.epoch) 318 | if smallest_page_epoch is not None: 319 | self._data_loader_train.denoiser.batch_sampler.num_consecutive_batches = self.pages_per_epoch[smallest_page_epoch] 320 | 321 | # Last collect 322 | if self._rank == 0 and not self._is_static_dataset: 323 | wandb_log(self.collect_test(final=True), self.epoch) 324 | 325 | def collect_initial_dataset(self) -> Tuple[int, Logs]: 326 | print("\nInitial collect\n") 327 | to_log = [] 328 | c = self._cfg.collection.train 329 | min_steps = c.first_epoch.min 330 | steps_per_epoch = c.steps_per_epoch 331 | max_steps = c.first_epoch.max 332 | threshold_rew = c.first_epoch.threshold_rew 333 | assert min_steps % steps_per_epoch == 0 334 | 335 | steps = min_steps 336 | while True: 337 | to_log += self._train_collector.send(NumToCollect(steps=steps)) 338 | num_steps = self.train_dataset.num_steps 339 | total_minority_rew = sum(sorted(self.train_dataset.counts_rew)[:-1]) 340 | if total_minority_rew >= threshold_rew: 341 | break 342 | if (max_steps is not None) and num_steps >= max_steps: 343 | print("Reached the specified maximum for initial collect") 344 | break 345 | print(f"Minority reward: {total_minority_rew}/{threshold_rew} -> 
Keep collecting\n") 346 | steps = steps_per_epoch 347 | 348 | print("\nSummary of initial collect:") 349 | print(f"Num steps: {num_steps} / {c.num_steps_total}") 350 | print(f"Reward counts: {dict(self.train_dataset.counter_rew)}") 351 | 352 | remaining_steps = c.num_steps_total - num_steps 353 | assert remaining_steps % c.steps_per_epoch == 0 354 | num_epochs_collect = remaining_steps // c.steps_per_epoch 355 | 356 | return num_epochs_collect, to_log 357 | 358 | def collect_test(self, final: bool = False) -> Logs: 359 | c = self._cfg.collection.test 360 | episodes = c.num_final_episodes if final else c.num_episodes 361 | td = self.test_dataset 362 | td.clear() 363 | to_log = self._test_collector.send(NumToCollect(episodes=episodes)) 364 | key_ep_id = f"{td.name}/episode_id" 365 | to_log = [{k: v + self.num_episodes_test if k == key_ep_id else v for k, v in x.items()} for x in to_log] 366 | 367 | print(f"\nSummary of {'final' if final else 'test'} collect: {td.num_episodes} episodes ({td.num_steps} steps)") 368 | keys = [key_ep_id, "return", "length"] 369 | to_log_episodes = [x for x in to_log if set(x.keys()) == set(keys)] 370 | episode_ids, returns, lengths = [[d[k] for d in to_log_episodes] for k in keys] 371 | for i, (ep_id, ret, length) in enumerate(zip(episode_ids, returns, lengths)): 372 | print(f" Episode {ep_id}: return = {ret} length = {length}\n", end="\n" if i == episodes - 1 else "") 373 | 374 | self.num_episodes_test += episodes 375 | 376 | if final: 377 | to_log.append({"final_return_mean": np.mean(returns), "final_return_std": np.std(returns)}) 378 | print(to_log[-1]) 379 | 380 | return to_log 381 | 382 | def train_agent(self) -> Logs: 383 | self.agent.train() 384 | self.agent.zero_grad() 385 | to_log = [] 386 | for name in self._model_names: 387 | cfg = getattr(self._cfg, name).training 388 | if self.epoch > cfg.start_after_epochs: 389 | steps = cfg.steps_first_epoch if self.epoch == 1 else cfg.steps_per_epoch 390 | to_log += self.train_component(name, steps) 391 | return to_log 392 | 393 | @torch.no_grad() 394 | def test_agent(self) -> Logs: 395 | self.agent.eval() 396 | to_log = [] 397 | for name in self._model_names: 398 | if name == "actor_critic": 399 | continue 400 | cfg = getattr(self._cfg, name).training 401 | if self.epoch > cfg.start_after_epochs: 402 | to_log += self.test_component(name) 403 | return to_log 404 | 405 | def train_component(self, name: str, steps: int) -> Logs: 406 | cfg = getattr(self._cfg, name).training 407 | model = getattr(self._agent, name) 408 | opt = self.opt.get(name) 409 | lr_sched = self.lr_sched.get(name) 410 | data_loader = self._data_loader_train.get(name) 411 | 412 | torch.cuda.empty_cache() 413 | model.to(self._device) 414 | move_opt_to(opt, self._device) 415 | 416 | model.train() 417 | opt.zero_grad() 418 | data_iterator = iter(data_loader) if data_loader is not None else None 419 | to_log = [] 420 | 421 | num_steps = cfg.grad_acc_steps * steps 422 | 423 | # in case of distributed training, the cfg needs to be taken from the module attribute 424 | agent_cfg = self._agent.denoiser.cfg if not hasattr(self._agent.denoiser, "module") else self._agent.denoiser.module.cfg 425 | effective_context_length = 0 if name == 'upsampler' else get_frame_indices(agent_cfg.frame_sampling)[-1] + 1 426 | context_obs = None 427 | context_act = None 428 | context_mask_padding = None 429 | 430 | for i in trange(num_steps, desc=f"Training {name}", disable=self._rank > 0): 431 | curr_iter = 0 432 | total_length = 0 433 | while curr_iter < 
data_loader.batch_sampler.num_consecutive_batches: 434 | batch = next(data_iterator).to(self._device) if data_iterator is not None else None 435 | curr_iter += 1 436 | 437 | # initialize variables in the first segment 438 | if batch.segment_ids[0].is_first_batch: 439 | context_obs = batch.obs 440 | context_act = batch.act 441 | context_mask_padding = batch.mask_padding 442 | 443 | # build the observations until there's enough context 444 | while context_obs.shape[1] < effective_context_length + 1: 445 | batch = next(data_iterator).to(self._device) if data_iterator is not None else None 446 | curr_iter += 1 447 | 448 | context_obs = torch.concat([context_obs, batch.obs], dim=1) 449 | context_act = torch.concat([context_act, batch.act], dim=1) 450 | context_mask_padding = torch.concat([context_mask_padding, batch.mask_padding], dim=1) 451 | 452 | # split the context into the context obs and the obs to predict 453 | 454 | # future obs (after the effective context length) 455 | predict_obs = context_obs[:, effective_context_length:] 456 | predict_act = context_act[:, effective_context_length:] 457 | predict_mask_padding = context_mask_padding[:, effective_context_length:] 458 | 459 | # previous obs (before the effective context length) 460 | context_obs = context_obs[:, :effective_context_length] 461 | context_act = context_act[:, :effective_context_length] 462 | context_mask_padding = context_mask_padding[:, :effective_context_length] 463 | 464 | total_length = (data_loader.batch_sampler.num_consecutive_batches-1)*data_loader.batch_sampler.autoregressive_obs + data_loader.batch_sampler.seq_length - effective_context_length 465 | else: # set batch to be the frames to be predicted next 466 | predict_obs = batch.obs 467 | predict_act = batch.act 468 | predict_mask_padding = batch.mask_padding 469 | 470 | # build batch for prediction 471 | batch.obs = torch.cat([context_obs[:, -effective_context_length:], predict_obs], dim=1) 472 | batch.act = torch.cat([context_act[:, -effective_context_length:], predict_act], dim=1) 473 | batch.mask_padding = torch.cat([context_mask_padding[:, -effective_context_length:], predict_mask_padding], dim=1) 474 | 475 | # train on the batch 476 | loss, metrics, batch_data = model(batch) if batch is not None else model() 477 | loss /= total_length 478 | loss.backward() 479 | 480 | # collect the predicted frames and prepare the next context 481 | if 'obs' in batch_data and 'act' in batch_data and 'mask_padding' in batch_data and name != 'upsampler': 482 | obs = batch_data['obs'] 483 | act = batch_data['act'] 484 | mask_padding = batch_data['mask_padding'] 485 | 486 | # the next context begins at the end of the current predictions 487 | context_obs = torch.concat([context_obs, obs], dim=1) 488 | context_act = torch.concat([context_act, act], dim=1) 489 | context_mask_padding = torch.cat([context_mask_padding, mask_padding], dim=1) 490 | 491 | # saving up to effective context length frames back 492 | context_obs = context_obs[:, -effective_context_length:] 493 | context_act = context_act[:, -effective_context_length:] 494 | context_mask_padding = context_mask_padding[:, -effective_context_length:] 495 | 496 | num_batch = self.num_batch_train.get(name) 497 | metrics[f"num_batch_train_{name}"] = num_batch 498 | self.num_batch_train.set(name, num_batch + 1) 499 | to_log.append(metrics) 500 | 501 | if (i + 1) % cfg.grad_acc_steps == 0: 502 | if cfg.max_grad_norm is not None: 503 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm).item() 504 | 
metrics["grad_norm_before_clip"] = grad_norm 505 | opt.step() 506 | opt.zero_grad() 507 | if lr_sched is not None: 508 | metrics["lr"] = lr_sched.get_last_lr()[0] 509 | lr_sched.step() 510 | 511 | process_confusion_matrices_if_any_and_compute_classification_metrics(to_log) 512 | to_log = [{f"{name}/train/{k}": v for k, v in d.items()} for d in to_log] 513 | 514 | model.to("cpu") 515 | move_opt_to(opt, "cpu") 516 | 517 | return to_log 518 | 519 | @torch.no_grad() 520 | def test_component(self, name: str) -> Logs: 521 | model = getattr(self.agent, name) 522 | data_loader = self._data_loader_test.get(name) 523 | model.eval() 524 | model.to(self._device) 525 | to_log = [] 526 | for batch in tqdm(data_loader, desc=f"Evaluating {name}"): 527 | batch = batch.to(self._device) 528 | _, metrics, _ = model(batch) 529 | num_batch = self.num_batch_test.get(name) 530 | metrics[f"num_batch_test_{name}"] = num_batch 531 | self.num_batch_test.set(name, num_batch + 1) 532 | to_log.append(metrics) 533 | 534 | process_confusion_matrices_if_any_and_compute_classification_metrics(to_log) 535 | to_log = [{f"{name}/test/{k}": v for k, v in d.items()} for d in to_log] 536 | model.to("cpu") 537 | return to_log 538 | 539 | def load_state_checkpoint(self) -> None: 540 | self.load_state_dict(torch.load(self._path_state_ckpt, map_location=self._device)) 541 | 542 | def load_agent_state_checkpoint(self) -> None: 543 | agent_state_dict = torch.load(self._path_state_ckpt, map_location=self._device) 544 | self.agent.load_state_dict(agent_state_dict) 545 | 546 | def save_checkpoint(self) -> None: 547 | if self._rank == 0: 548 | save_with_backup(self.state_dict(), self._path_state_ckpt) 549 | self.train_dataset.save_to_default_path() 550 | self.test_dataset.save_to_default_path() 551 | self._keep_agent_copies(self.agent.state_dict(), self.epoch) 552 | self._save_info_for_import_script(self.epoch) -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | 2 | from argparse import Namespace 3 | from collections import OrderedDict 4 | from dataclasses import dataclass 5 | from functools import partial 6 | import json 7 | from pathlib import Path 8 | import random 9 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 10 | 11 | from omegaconf import OmegaConf 12 | import numpy as np 13 | import torch 14 | import torch.distributed as dist 15 | from torch import Tensor 16 | from torch.optim.lr_scheduler import LambdaLR 17 | import torch.nn as nn 18 | from torch.nn.parallel import DistributedDataParallel as DDP 19 | from torch.optim import AdamW 20 | import wandb 21 | 22 | 23 | ATARI_100K_GAMES = [ 24 | "Alien", 25 | "Amidar", 26 | "Assault", 27 | "Asterix", 28 | "BankHeist", 29 | "BattleZone", 30 | "Boxing", 31 | "Breakout", 32 | "ChopperCommand", 33 | "CrazyClimber", 34 | "DemonAttack", 35 | "Freeway", 36 | "Frostbite", 37 | "Gopher", 38 | "Hero", 39 | "Jamesbond", 40 | "Kangaroo", 41 | "Krull", 42 | "KungFuMaster", 43 | "MsPacman", 44 | "Pong", 45 | "PrivateEye", 46 | "Qbert", 47 | "RoadRunner", 48 | "Seaquest", 49 | "UpNDown", 50 | ] 51 | 52 | 53 | Logs = List[Dict[str, float]] 54 | LossAndLogs = Tuple[Tensor, Dict[str, Any]] 55 | LossLogsData = Tuple[Tensor, Dict[str, Any], Dict[str, Any]] 56 | 57 | 58 | class StateDictMixin: 59 | def _init_fields(self) -> None: 60 | def has_sd(x: str) -> bool: 61 | return callable(getattr(x, "state_dict", None)) and callable(getattr(x, "load_state_dict", 
None)) 62 | 63 | self._all_fields = {k for k in vars(self) if not k.startswith("_")} 64 | self._fields_sd = {k for k in self._all_fields if has_sd(getattr(self, k))} 65 | 66 | def _get_field(self, k: str) -> Any: 67 | return getattr(self, k).state_dict() if k in self._fields_sd else getattr(self, k) 68 | 69 | def _set_field(self, k: str, v: Any) -> None: 70 | getattr(self, k).load_state_dict(v) if k in self._fields_sd else setattr(self, k, v) 71 | 72 | def state_dict(self) -> Dict[str, Any]: 73 | if not hasattr(self, "_all_fields"): 74 | self._init_fields() 75 | return {k: self._get_field(k) for k in self._all_fields} 76 | 77 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 78 | if not hasattr(self, "_all_fields"): 79 | self._init_fields() 80 | assert set(list(state_dict.keys())) == self._all_fields 81 | for k, v in state_dict.items(): 82 | self._set_field(k, v) 83 | 84 | 85 | @dataclass 86 | class CommonTools(StateDictMixin): 87 | denoiser: Optional[Any] = None 88 | upsampler: Optional[Any] = None 89 | rew_end_model: Optional[Any] = None 90 | actor_critic: Optional[Any] = None 91 | 92 | def get(self, name: str) -> Any: 93 | return getattr(self, name) 94 | 95 | def set(self, name: str, value: Any): 96 | return setattr(self, name, value) 97 | 98 | 99 | def broadcast_if_needed(*args): 100 | objects = list(args) 101 | if dist.is_initialized(): 102 | dist.broadcast_object_list(objects, src=0) 103 | # the list `objects` now contains the version of rank 0 104 | return objects 105 | 106 | 107 | def build_ddp_wrapper(**modules_dict: Dict[str, nn.Module]) -> Namespace: 108 | return Namespace(**{name: DDP(module) for name, module in modules_dict.items()}) 109 | 110 | 111 | def compute_classification_metrics(confusion_matrix: Tensor) -> Tuple[Tensor, Tensor, Tensor]: 112 | num_classes = confusion_matrix.size(0) 113 | precision = torch.zeros(num_classes) 114 | recall = torch.zeros(num_classes) 115 | f1_score = torch.zeros(num_classes) 116 | 117 | for i in range(num_classes): 118 | true_positive = confusion_matrix[i, i].item() 119 | false_positive = confusion_matrix[:, i].sum().item() - true_positive 120 | false_negative = confusion_matrix[i, :].sum().item() - true_positive 121 | 122 | precision[i] = true_positive / (true_positive + false_positive) if (true_positive + false_positive) != 0 else 0 123 | recall[i] = true_positive / (true_positive + false_negative) if (true_positive + false_negative) != 0 else 0 124 | f1_score[i] = ( 125 | 2 * (precision[i] * recall[i]) / (precision[i] + recall[i]) if (precision[i] + recall[i]) != 0 else 0 126 | ) 127 | 128 | return precision, recall, f1_score 129 | 130 | 131 | def configure_opt(model: nn.Module, lr: float, weight_decay: float, eps: float, *blacklist_module_names: str) -> AdamW: 132 | """Credits to https://github.com/karpathy/minGPT""" 133 | # separate out all parameters to those that will and won't experience regularizing weight decay 134 | decay = set() 135 | no_decay = set() 136 | whitelist_weight_modules = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.LSTMCell, nn.LSTM) 137 | blacklist_weight_modules = (nn.LayerNorm, nn.Embedding, nn.GroupNorm) 138 | for mn, m in model.named_modules(): 139 | for pn, p in m.named_parameters(): 140 | fpn = "%s.%s" % (mn, pn) if mn else pn # full param name 141 | if any([fpn.startswith(module_name) for module_name in blacklist_module_names]): 142 | no_decay.add(fpn) 143 | elif "bias" in pn: 144 | # all biases will not be decayed 145 | no_decay.add(fpn) 146 | elif (pn.endswith("weight") or pn.startswith("weight_")) 
and isinstance(m, whitelist_weight_modules): 147 | # weights of whitelist modules will be weight decayed 148 | decay.add(fpn) 149 | elif (pn.endswith("weight") or pn.startswith("weight_")) and isinstance(m, blacklist_weight_modules): 150 | # weights of blacklist modules will NOT be weight decayed 151 | no_decay.add(fpn) 152 | 153 | # validate that we considered every parameter 154 | param_dict = {pn: p for pn, p in model.named_parameters()} 155 | inter_params = decay & no_decay 156 | union_params = decay | no_decay 157 | assert len(inter_params) == 0, f"parameters {str(inter_params)} made it into both decay/no_decay sets!" 158 | assert ( 159 | len(param_dict.keys() - union_params) == 0 160 | ), f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!" 161 | 162 | # create the pytorch optimizer object 163 | optim_groups = [ 164 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, 165 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 166 | ] 167 | optimizer = AdamW(optim_groups, lr=lr, eps=eps) 168 | return optimizer 169 | 170 | 171 | def count_parameters(model: nn.Module) -> int: 172 | return sum(p.numel() for p in model.parameters()) 173 | 174 | 175 | def extract_state_dict(state_dict: OrderedDict, module_name: str) -> OrderedDict: 176 | return OrderedDict({k.split(".", 1)[1]: v for k, v in state_dict.items() if k.startswith(module_name)}) 177 | 178 | 179 | def get_lr_sched(opt: torch.optim.Optimizer, num_warmup_steps: int) -> LambdaLR: 180 | def lr_lambda(current_step: int): 181 | return 1 if current_step >= num_warmup_steps else current_step / max(1, num_warmup_steps) 182 | 183 | return LambdaLR(opt, lr_lambda, last_epoch=-1) 184 | 185 | 186 | def init_lstm(model: nn.Module) -> None: 187 | for name, p in model.named_parameters(): 188 | if "weight_ih" in name: 189 | nn.init.xavier_uniform_(p.data) 190 | elif "weight_hh" in name: 191 | nn.init.orthogonal_(p.data) 192 | elif "bias_ih" in name: 193 | p.data.fill_(0) 194 | # Set forget-gate bias to 1 195 | n = p.size(0) 196 | p.data[(n // 4) : (n // 2)].fill_(1) 197 | elif "bias_hh" in name: 198 | p.data.fill_(0) 199 | 200 | 201 | def get_path_agent_ckpt(path_ckpt_dir: Union[str, Path], epoch: int, num_zeros: int = 5) -> Path: 202 | d = Path(path_ckpt_dir) / "agent_versions" 203 | if epoch >= 0: 204 | return d / f"agent_epoch_{epoch:0{num_zeros}d}.pt" 205 | else: 206 | all_ = sorted(list(d.iterdir())) 207 | assert len(all_) >= -epoch 208 | return all_[epoch] 209 | 210 | 211 | def keep_agent_copies_every( 212 | agent_sd: Dict[str, Any], 213 | epoch: int, 214 | path_ckpt_dir: Path, 215 | every: int, 216 | num_to_keep: Optional[int], 217 | ) -> None: 218 | assert every > 0 219 | assert num_to_keep is None or num_to_keep > 0 220 | get_path = partial(get_path_agent_ckpt, path_ckpt_dir) 221 | get_path(0).parent.mkdir(parents=False, exist_ok=True) 222 | 223 | # Save agent 224 | save_with_backup(agent_sd, get_path(epoch)) 225 | 226 | # Clean oldest 227 | if (num_to_keep is not None) and (epoch % every == 0): 228 | get_path(max(0, epoch - num_to_keep * every)).unlink(missing_ok=True) 229 | 230 | # Clean previous 231 | if (epoch - 1) % every != 0: 232 | get_path(max(0, epoch - 1)).unlink(missing_ok=True) 233 | 234 | 235 | def move_opt_to(opt: AdamW, device: torch.device): 236 | for optimizer_metrics in opt.state.values(): 237 | for metric_name, metric in optimizer_metrics.items(): 238 | if torch.is_tensor(metric) and metric_name != 
"step": 239 | optimizer_metrics[metric_name] = metric.to(device) 240 | 241 | 242 | def process_confusion_matrices_if_any_and_compute_classification_metrics(logs: Logs) -> None: 243 | cm = [x.pop("confusion_matrix") for x in logs if "confusion_matrix" in x] 244 | if len(cm) > 0: 245 | confusion_matrices = {k: sum([d[k] for d in cm]) for k in cm[0]} # accumulate confusion matrices 246 | metrics = {} 247 | for key, confusion_matrix in confusion_matrices.items(): 248 | precision, recall, f1_score = compute_classification_metrics(confusion_matrix) 249 | metrics.update( 250 | { 251 | **{f"classification_metrics/{key}_precision_class_{i}": v for i, v in enumerate(precision)}, 252 | **{f"classification_metrics/{key}_recall_class_{i}": v for i, v in enumerate(recall)}, 253 | **{f"classification_metrics/{key}_f1_score_class_{i}": v for i, v in enumerate(f1_score)}, 254 | } 255 | ) 256 | 257 | logs.append(metrics) # Append the obtained metrics to logs (in place) 258 | 259 | 260 | def prompt_atari_game(): 261 | for i, game in enumerate(ATARI_100K_GAMES): 262 | print(f"{i:2d}: {game}") 263 | while True: 264 | x = input("\nEnter a number: ") 265 | if not x.isdigit(): 266 | print("Invalid.") 267 | continue 268 | x = int(x) 269 | if x < 0 or x > 25: 270 | print("Invalid.") 271 | continue 272 | break 273 | game = ATARI_100K_GAMES[x] 274 | return game 275 | 276 | 277 | def prompt_run_name(game): 278 | cfg_file = Path("config/trainer.yaml") 279 | cfg_name = OmegaConf.load(cfg_file).wandb.name 280 | suffix = f"-{cfg_name}" if cfg_name is not None else "" 281 | name = game + suffix 282 | name_ = input(f"Confirm run name by pressing Enter (or enter a new name): {name}\n") 283 | if name_ != "": 284 | name = name_ 285 | return name 286 | 287 | 288 | def save_info_for_import_script(epoch: int, run_name: str, path_ckpt_dir: Path) -> None: 289 | with (path_ckpt_dir / "info_for_import_script.json").open("w") as f: 290 | json.dump({"epoch": epoch, "name": run_name}, f) 291 | 292 | 293 | def save_with_backup(obj: Any, path: Path): 294 | bk = path.with_suffix(".bk") 295 | if path.is_file(): 296 | path.rename(bk) 297 | torch.save(obj, path) 298 | bk.unlink(missing_ok=True) 299 | 300 | 301 | def set_seed(seed: int) -> None: 302 | np.random.seed(seed) 303 | torch.manual_seed(seed) 304 | torch.cuda.manual_seed(seed) 305 | random.seed(seed) 306 | 307 | 308 | def skip_if_run_is_over(func: Callable) -> Callable: 309 | def inner(*args, **kwargs): 310 | path_run_is_over = Path(".run_is_over") 311 | if not path_run_is_over.is_file(): 312 | func(*args, **kwargs) 313 | path_run_is_over.touch() 314 | else: 315 | print(f"Run is marked as finished. 
To unmark, remove '{str(path_run_is_over)}'.") 316 | 317 | return inner 318 | 319 | 320 | def try_until_no_except(func: Callable) -> None: 321 | while True: 322 | try: 323 | func() 324 | except KeyboardInterrupt: 325 | break 326 | except Exception: 327 | continue 328 | else: 329 | break 330 | 331 | 332 | def wandb_log(logs: Logs, epoch: int): 333 | for d in logs: 334 | wandb.log({"epoch": epoch, **d}) 335 | 336 | 337 | def get_frame_indices(frame_sampling): 338 | indexes = [] 339 | current_index = 0 340 | for group in frame_sampling[::-1]: 341 | for _ in range(group['count']): 342 | indexes.append(current_index) 343 | current_index += group['stride'] 344 | 345 | return torch.tensor(indexes) 346 | 347 | def build_pages_per_epoch(pages_per_epoch): 348 | mapping = {} 349 | for group in pages_per_epoch[::-1]: 350 | mapping[group['epoch']] = group['count'] 351 | 352 | return mapping 353 | 354 | def find_maximum_key_below_threshold(d, threshold): 355 | if d is None: 356 | return None 357 | 358 | eligible_keys = [k for k in d.keys() if k <= threshold] 359 | if not eligible_keys: 360 | return None 361 | return max(eligible_keys) --------------------------------------------------------------------------------
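Since the frame-sampling helpers at the bottom of `src/utils.py` drive the denoiser's context schedule in `trainer.py` (`effective_context_length` and the consecutive-page count per epoch), here is a small illustrative sketch of their behavior. The `frame_sampling` groups and page schedule below are made-up example values, not the repo's actual configuration, and the import assumes `src/` is on the Python path, as for the other entry points.

```python
# Illustrative values only -- the real groups come from the Hydra configs.
from utils import build_pages_per_epoch, find_maximum_key_below_threshold, get_frame_indices

frame_sampling = [
    {"count": 2, "stride": 4},  # two distant context frames (groups are consumed in reverse)
    {"count": 4, "stride": 1},  # four consecutive recent frames
]
print(get_frame_indices(frame_sampling))
# tensor([0, 1, 2, 3, 4, 8]) -> relative frame indices;
# the trainer derives effective_context_length = 8 + 1 = 9 from the last index.

pages_cfg = [{"epoch": 1, "count": 1}, {"epoch": 50, "count": 4}]
schedule = build_pages_per_epoch(pages_cfg)             # {50: 4, 1: 1}
key = find_maximum_key_below_threshold(schedule, 30)    # -> 1 (largest epoch key <= 30)
print(schedule[key])                                    # 1 consecutive page per batch at epoch 30
print(schedule[find_maximum_key_below_threshold(schedule, 60)])  # 4 from epoch 50 onwards
```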