├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── demo.gif
│   ├── img.png
│   └── img_2.png
├── config
│   ├── agent
│   │   └── racing.yaml
│   ├── env
│   │   └── racing.yaml
│   ├── trainer.yaml
│   └── world_model_env
│       └── fast.yaml
├── game
│   └── spawn
│       ├── 1
│       │   ├── act.npy
│       │   ├── full_res.npy
│       │   ├── low_res.npy
│       │   └── next_act.npy
│       ├── 2
│       │   ├── act.npy
│       │   ├── full_res.npy
│       │   ├── low_res.npy
│       │   └── next_act.npy
│       └── 3
│           ├── act.npy
│           ├── full_res.npy
│           ├── low_res.npy
│           └── next_act.npy
├── requirements.txt
├── scripts
│   ├── import_run.py
│   └── resume.sh
└── src
    ├── __init__.py
    ├── agent.py
    ├── coroutines
    │   ├── __init__.py
    │   ├── collector.py
    │   └── env_loop.py
    ├── data
    │   ├── __init__.py
    │   ├── batch.py
    │   ├── batch_sampler.py
    │   ├── dataset.py
    │   ├── episode.py
    │   ├── segment.py
    │   └── utils.py
    ├── envs
    │   ├── __init__.py
    │   ├── atari_preprocessing.py
    │   ├── env.py
    │   └── world_model_env.py
    ├── game
    │   ├── __init__.py
    │   ├── dataset_env.py
    │   ├── game.py
    │   └── play_env.py
    ├── main.py
    ├── models
    │   ├── __init__.py
    │   ├── actor_critic.py
    │   ├── blocks.py
    │   ├── diffusion
    │   │   ├── __init__.py
    │   │   ├── denoiser.py
    │   │   ├── diffusion_sampler.py
    │   │   └── inner_model.py
    │   └── rew_end_model.py
    ├── play.py
    ├── player
    │   ├── __init__.py
    │   ├── action_processing.py
    │   └── keymap.py
    ├── process_denoiser_files.py
    ├── process_upsampler_files.py
    ├── spawn.py
    ├── trainer.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.zip
3 | checkpoints
4 | **/__pycache__/
5 |
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Eloi Alonso
4 | Copyright (c) 2025 Enigma Labs AI
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multiverse: The First AI Multiplayer World Model
2 |
3 | 🌐 [Enigma-AI website](https://enigma-labs.io/) - 📚 [Technical Blog](https://enigma-labs.io/) - [🤗 Model on Huggingface](https://huggingface.co/Enigma-AI/multiverse) - [🤗 Datasets on Huggingface](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-low-res) - 𝕏 [Multiverse Tweet](https://x.com/j0nathanj/status/1920516649511244258)
4 |
5 |
6 | *Two human players driving cars in Multiverse*
7 |
8 | 
9 |
10 |
11 | ---
12 |
13 | ## Installation
14 | ```bash
15 | git clone https://github.com/EnigmaLabsAI/multiverse
16 | cd multiverse
17 | pip install -r requirements.txt
18 | ```
19 |
20 | ### Running the model
21 |
22 | ```bash
23 | python src/play.py --compile
24 | ```
25 |
26 | > Note: on Apple Silicon you must enable CPU fallback for the MPS backend: `PYTORCH_ENABLE_MPS_FALLBACK=1 python src/play.py`
27 |
28 | When running this command, you will be prompted with the controls. Press `enter` to start:
29 | 
30 |
31 | Then the game will start:
32 | * To control the silver car on the top screen, use the arrow keys.
33 | * To control the blue car on the bottom screen, use the WASD keys.
34 |
35 | 
36 |
37 | ---
38 |
39 |
40 | ## Training
41 |
42 | Multiverse comprises two models:
43 | * Denoiser - a world model that simulates the game
44 | * Upsampler - a model that takes the frames from the denoiser and increases their resolution
45 |
46 | ### Denoiser training
47 |
48 | #### 1. Download the dataset
49 | Download the Denoiser's training set from [🤗 Huggingface](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-low-res).
50 |
51 | #### 2. Process data for training
52 | Run the command:
53 | ```bash
54 | python src/process_denoiser_files.py
55 | ```
56 |
57 | #### 3. Edit training configuration
58 |
59 | Edit [config/env/racing.yaml](config/env/racing.yaml) and set:
60 | - `path_data_low_res` to `/low_res`
61 | - `path_data_full_res` to `/full_res`
62 |
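For example, assuming the processed files were written to `/data/racing` (a hypothetical location; use whatever directory the processing script produced), the config would look like:

```yaml
train:
  id: racing
  size: [700, 530]
  num_actions: 66
  path_data_low_res: /data/racing/low_res    # hypothetical path
  path_data_full_res: /data/racing/full_res  # hypothetical path
  keymap: racing
```
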
63 | Edit [config/trainer.yaml](config/trainer.yaml) to train the `denoiser`:
64 | ```yaml
65 | train_model: denoiser
66 | ```
67 |
68 | #### 4. Launch training run
69 |
70 | You can then launch a training run with `python src/main.py`.
71 |
72 |
73 | ### Upsampler training
74 |
75 | #### 1. Download the dataset
76 | Download the Upsampler's training set from [🤗 Huggingface](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-full-res).
77 |
78 | #### 2. Process data for training
79 | Run the command:
80 | ```bash
81 | python src/process_upsampler_files.py
82 | ```
83 |
84 | #### 3. Edit training configuration
85 |
86 | Edit [config/env/racing.yaml](config/env/racing.yaml) and set:
87 | - `path_data_low_res` to `/low_res`
88 | - `path_data_full_res` to `/full_res`
89 |
90 | Edit [config/trainer.yaml](config/trainer.yaml) to train the `upsampler`:
91 | ```yaml
92 | train_model: upsampler
93 | ```
94 |
95 | #### 4. Launch training run
96 |
97 | You can then launch a training run with `python src/main.py`.
98 |
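If you want to initialize the run from an existing checkpoint, `config/trainer.yaml` exposes an `initialization.path_to_ckpt` field; a hypothetical override (the checkpoint path is illustrative; point it at an agent checkpoint produced by a previous run):

```bash
python src/main.py train_model=upsampler initialization.path_to_ckpt=/path/to/checkpoint.pt
```
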
99 |
100 | ---
101 |
102 | ## Datasets
103 |
104 | 1. We've collected over 4 hours of multiplayer (1v1) footage from Gran Turismo 4 at a resolution of 48x64 (per player): [🤗 Huggingface link](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-low-res).
105 |
106 | 2. A sparse sampling of full-resolution, cropped frames is available for training the upsampler at a resolution of 350x530: [🤗 Huggingface link](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-full-res).
107 |
108 | The datasets contain a variety of situations: acceleration, braking, overtakes, crashes, and expert driving for both players.
109 | You can read about the data collection mechanism [here](https://enigma-labs.io/blog).
110 |
111 | Note: The full resolution dataset is only for upsampler training and is not fit for world model training.
112 |
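The episode files are HDF5, with per-frame keys in the `frame_{i}_x` (observation) / `frame_{i}_y` (action) layout read by `GameHdf5Dataset` in `src/data/dataset.py`. A minimal sketch for inspecting a downloaded file (the filename is hypothetical):

```python
import h5py

# Open one episode file from the low-res dataset (hypothetical filename).
with h5py.File("episode_0.hdf5", "r") as f:
    # Frame count, derived the same way GameHdf5Dataset computes episode lengths.
    num_frames = 1 + max(
        int(k[len("frame_"):-len("_x")])
        for k in f.keys()
        if k.startswith("frame_") and k.endswith("_x")
    )
    print(num_frames, "frames")
    print("first frame:", f["frame_0_x"][:].shape)  # H x W x C image array
    print("first action:", f["frame_0_y"][:])       # action vector for that frame
```
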
113 | ---
114 |
115 | ## Outside resources
116 |
117 | - DIAMOND - https://github.com/eloialonso/diamond
118 | - AI-MarioKart64 - https://github.com/Dere-Wah/AI-MarioKart64
119 |
120 |
--------------------------------------------------------------------------------
/assets/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/assets/demo.gif
--------------------------------------------------------------------------------
/assets/img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/assets/img.png
--------------------------------------------------------------------------------
/assets/img_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/assets/img_2.png
--------------------------------------------------------------------------------
/config/agent/racing.yaml:
--------------------------------------------------------------------------------
1 | _target_: agent.AgentConfig
2 |
3 | denoiser:
4 |   _target_: models.diffusion.DenoiserConfig
5 |   sigma_data: 0.5
6 |   sigma_offset_noise: 0.1
7 |   noise_previous_obs: true
8 |   upsampling_factor: null
9 |   frame_sampling:
10 |     - count: 4
11 |       stride: 1
12 |     - count: 4
13 |       stride: 4
14 |   inner_model:
15 |     _target_: models.diffusion.InnerModelConfig
16 |     img_channels: 6
17 |     num_steps_conditioning: 8
18 |     cond_channels: 2048
19 |     depths:
20 |       - 2
21 |       - 2
22 |       - 2
23 |       - 2
24 |     channels:
25 |       - 128
26 |       - 256
27 |       - 512
28 |       - 1024
29 |     attn_depths:
30 |       - 0
31 |       - 0
32 |       - 1
33 |       - 1
34 |
35 | upsampler:
36 |   _target_: models.diffusion.DenoiserConfig
37 |   sigma_data: 0.5
38 |   sigma_offset_noise: 0.1
39 |   noise_previous_obs: false
40 |   upsampling_factor: 10
41 |   upsampling_frame_height: 350
42 |   upsampling_frame_width: 530
43 |   inner_model:
44 |     _target_: models.diffusion.InnerModelConfig
45 |     img_channels: 6
46 |     num_steps_conditioning: 0
47 |     cond_channels: 2048
48 |     depths:
49 |       - 2
50 |       - 2
51 |       - 2
52 |       - 2
53 |     channels:
54 |       - 64
55 |       - 64
56 |       - 128
57 |       - 256
58 |     attn_depths:
59 |       - 0
60 |       - 0
61 |       - 0
62 |       - 0
63 |
64 | rew_end_model: null
65 |
66 | actor_critic: null
67 |
--------------------------------------------------------------------------------
/config/env/racing.yaml:
--------------------------------------------------------------------------------
1 | train:
2 |   id: racing
3 |   size: [700, 530]
4 |   num_actions: 66
5 |   path_data_low_res: null
6 |   path_data_full_res: null
7 |   keymap: racing
8 |
--------------------------------------------------------------------------------
/config/trainer.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - _self_
3 |   - env: racing
4 |   - agent: racing
5 |   - world_model_env: fast
6 |
7 | hydra:
8 |   job:
9 |     chdir: True
10 |
11 | wandb:
12 |   mode: offline
13 |   project: null
14 |   entity: null
15 |   name: null
16 |   group: null
17 |   tags: null
18 |
19 | initialization:
20 |   path_to_ckpt: null
21 |   load_denoiser: True
22 |   load_rew_end_model: True
23 |   load_actor_critic: True
24 |
25 | common:
26 |   devices: all # int, list of int, cpu, or all
27 |   seed: null
28 |   resume: False # do not modify, set by scripts/resume.sh only.
29 |
30 | checkpointing:
31 |   save_agent_every: 5
32 |   num_to_keep: 11 # number of checkpoints to keep, use null to disable
33 |
34 | collection:
35 |   train:
36 |     num_envs: 1
37 |     epsilon: 0.01
38 |     num_steps_total: 100000
39 |     first_epoch:
40 |       min: 5000
41 |       max: 10000 # null: no maximum
42 |       threshold_rew: 10
43 |     steps_per_epoch: 100
44 |   test:
45 |     num_envs: 1
46 |     num_episodes: 4
47 |     epsilon: 0.0
48 |     num_final_episodes: 100
49 |
50 | static_dataset:
51 |   path: ${env.path_data_low_res}
52 |   ignore_sample_weights: True
53 |
54 | training:
55 |   should: True
56 |   num_final_epochs: 600
57 |   cache_in_ram: False
58 |   num_workers_data_loaders: 1
59 |   model_free: False # if True, turn off world_model training and RL in imagination
60 |   compile_wm: False
61 |
62 | evaluation:
63 |   should: True
64 |   every: 20
65 |
66 | train_model: denoiser
67 |
68 | denoiser:
69 |   training:
70 |     num_autoregressive_steps: 8
71 |     initial_num_consecutive_page_count: 1
72 |     num_consecutive_pages:
73 |       - epoch: 400
74 |         count: 10
75 |       - epoch: 500
76 |         count: 50
77 |     start_after_epochs: 0
78 |     steps_first_epoch: 10
79 |     steps_per_epoch: 20
80 |     sample_weights: null
81 |     batch_size: 30
82 |     grad_acc_steps: 2
83 |     lr_warmup_steps: 100
84 |     max_grad_norm: 10.0
85 |
86 |   optimizer:
87 |     lr: 1e-4
88 |     weight_decay: 1e-2
89 |     eps: 1e-8
90 |
91 |   sigma_distribution: # log normal distribution for sigma during training
92 |     _target_: models.diffusion.SigmaDistributionConfig
93 |     loc: -1.2
94 |     scale: 1.2
95 |     sigma_min: 2e-3
96 |     sigma_max: 20
97 |
98 | upsampler:
99 |   training:
100 |     num_autoregressive_steps: 1
101 |     initial_num_consecutive_page_count: 1
102 |     start_after_epochs: 0
103 |     steps_first_epoch: 20
104 |     steps_per_epoch: 20
105 |     sample_weights: null
106 |     batch_size: 4
107 |     grad_acc_steps: 2
108 |     lr_warmup_steps: 100
109 |     max_grad_norm: 10.0
110 |
111 |   optimizer: ${denoiser.optimizer}
112 |   sigma_distribution: ${denoiser.sigma_distribution}
113 |
114 |
--------------------------------------------------------------------------------
/config/world_model_env/fast.yaml:
--------------------------------------------------------------------------------
1 | _target_: envs.WorldModelEnvConfig
2 | horizon: 1000
3 | num_batches_to_preload: 256
4 | diffusion_sampler_next_obs:
5 |   _target_: models.diffusion.DiffusionSamplerConfig
6 |   num_steps_denoising: 1
7 |   sigma_min: 2e-3
8 |   sigma_max: 5.0
9 |   rho: 7
10 |   order: 1 # 1: Euler, 2: Heun
11 |   s_churn: 0.0 # Amount of stochasticity
12 |   s_tmin: 0.0
13 |   s_tmax: ${eval:'float("inf")'}
14 |   s_noise: 1.0
15 |   s_cond: 0.005
16 | diffusion_sampler_upsampling:
17 |   _target_: models.diffusion.DiffusionSamplerConfig
18 |   num_steps_denoising: 1
19 |   sigma_min: 1
20 |   sigma_max: 5.0
21 |   rho: 7
22 |   order: 2 # 1: Euler, 2: Heun
23 |   s_churn: 10.0 # Amount of stochasticity
24 |   s_tmin: 1
25 |   s_tmax: 5
26 |   s_noise: 0.9
27 |   s_cond: 0
--------------------------------------------------------------------------------
/game/spawn/1/act.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/1/act.npy
--------------------------------------------------------------------------------
/game/spawn/1/full_res.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/1/full_res.npy
--------------------------------------------------------------------------------
/game/spawn/1/low_res.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/1/low_res.npy
--------------------------------------------------------------------------------
/game/spawn/1/next_act.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/1/next_act.npy
--------------------------------------------------------------------------------
/game/spawn/2/act.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/2/act.npy
--------------------------------------------------------------------------------
/game/spawn/2/full_res.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/2/full_res.npy
--------------------------------------------------------------------------------
/game/spawn/2/low_res.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/2/low_res.npy
--------------------------------------------------------------------------------
/game/spawn/2/next_act.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/2/next_act.npy
--------------------------------------------------------------------------------
/game/spawn/3/act.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/3/act.npy
--------------------------------------------------------------------------------
/game/spawn/3/full_res.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/3/full_res.npy
--------------------------------------------------------------------------------
/game/spawn/3/low_res.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/3/low_res.npy
--------------------------------------------------------------------------------
/game/spawn/3/next_act.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/game/spawn/3/next_act.npy
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gymnasium==0.29.1
2 | ale-py==0.9.0
3 | h5py==3.11.0
4 | huggingface-hub==0.17.2
5 | hydra-core==1.3
6 | numpy==1.26.0
7 | opencv-python==4.10.0.84
8 | pillow==10.3.0
9 | pygame==2.5.2
10 | torch==2.1.0
11 | torchvision==0.16.0
12 | torcheval==0.0.7
13 | tqdm==4.66.4
14 | wandb==0.17.0
15 |
--------------------------------------------------------------------------------
/scripts/import_run.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 |
3 | import argparse
4 | from functools import partial
5 | import json
6 | from pathlib import Path
7 | import subprocess
8 | from typing import Optional
9 |
10 |
11 | def main() -> None:
12 |     parser = argparse.ArgumentParser()
13 |     parser.add_argument("host", type=str)
14 |     parser.add_argument("-v", "--verbose", action="store_true")
15 |     parser.add_argument("--user", type=str, default=None)
16 |     parser.add_argument("--rootdir", type=str, default=None)
17 |     args = parser.parse_args()
18 |
19 |     run = partial(subprocess.run, shell=True, check=True, text=True)
20 |     host = args.host if args.user is None else f"{args.user}@{args.host}"
21 |
22 |     def run_remote_cmd(cmd):
23 |         return subprocess.check_output(f"ssh {host} {cmd}", shell=True, text=True)
24 |
25 |     def ls(p):
26 |         out = run_remote_cmd(f"ls {p}")
27 |         return out.strip().split("\n")[::-1]
28 |
29 |     def ask(l, info=None):
30 |         print(
31 |             "\n".join(
32 |                 [
33 |                     f"{i:{len(str(len(l)))}d}: {d}"
34 |                     + (f" ({info[d]})" if info is not None else "")
35 |                     for i, d in enumerate(l, 1)
36 |                 ]
37 |             )
38 |         )
39 |         while True:
40 |             i = input("\nEnter a number: ")
41 |             if i.isdigit() and 1 <= int(i) <= len(l):
42 |                 break
43 |             print("\n/!\\ Invalid choice\n")
44 |         return l[int(i) - 1]
45 |
46 |     def ask_if_verbose(question, default):
47 |         if not args.verbose:
48 |             return default
49 |         suffix = "[Y|n]" if default else "[y|N]"
50 |         answer = input(f"{question} {suffix} ").lower()
51 |
52 |         return (answer != "n") if default else (answer == "y")
53 |
54 |     def get_info(rundir):
55 |         return json.loads(
56 |             run_remote_cmd(f"cat {rundir}/checkpoints/info_for_import_script.json")
57 |         )
58 |
59 |     if args.rootdir is None:
60 |         for p in Path(__file__).resolve().parents:
61 |             if (p / ".git").is_dir():
62 |                 break
63 |         else:
64 |             raise RuntimeError("This file is not in a git repository")
65 |         out = run_remote_cmd(f"find -type d -name {p.name}").strip().split("\n")
66 |         assert len(out) == 1
67 |         rootdir = out[0]
68 |     else:
69 |         rootdir = f'{args.rootdir.strip().strip("/")}'
70 |
71 |     dates = ls(f"{rootdir}/outputs")
72 |     date = ask(dates)
73 |     times = ls(f"{rootdir}/outputs/{date}")
74 |
75 |     infos = {
76 |         time: get_info(rundir=f"{rootdir}/outputs/{date}/{time}") for time in times
77 |     }
78 |     time = ask(times, infos)
79 |
80 |     src = f"{rootdir}/outputs/{date}/{time}"
81 |
82 |     dst = Path(args.host) / date
83 |     dst.mkdir(exist_ok=True, parents=True)
84 |
85 |     exclude = [
86 |         "*.log",
87 |         "checkpoints/*",
88 |         "checkpoints_tmp",
89 |         ".hydra",
90 |         "media",
91 |         "__pycache__",
92 |         "wandb",
93 |     ]
94 |
95 |     include = ["checkpoints/agent_versions"]
96 |
97 |     if ask_if_verbose("Download only last checkpoint?", default=True):
98 |         last_ckpt = ls(f"{src}/checkpoints/agent_versions")[0]
99 |         exclude.append("checkpoints/agent_versions/*")
100 |         include.append(f"checkpoints/agent_versions/{last_ckpt}")
101 |
102 |     if not ask_if_verbose("Download train dataset?", default=False):
103 |         exclude.append("dataset/train")
104 |
105 |     if not ask_if_verbose("Download test dataset?", default=False):
106 |         exclude.append("dataset/test")
107 |
108 |     cmd = "rsync -av"
109 |     for i in include:
110 |         cmd += f' --include="{i}"'
111 |     for e in exclude:
112 |         cmd += f' --exclude="{e}"'
113 |
114 |     cmd += f" {host}:{src} {str(dst)}"
115 |     run(cmd)
116 |
117 |     path = (dst / time).absolute()
118 |     print(f"\n--> Run imported in:\n{path}")
119 |     run(f"echo {path} | xclip")
120 |
121 |
122 | if __name__ == "__main__":
123 |     main()
124 |
--------------------------------------------------------------------------------
/scripts/resume.sh:
--------------------------------------------------------------------------------
1 | python src/main.py common.resume=True hydra.output_subdir=null hydra.run.dir=.
2 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/src/__init__.py
--------------------------------------------------------------------------------
/src/agent.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from pathlib import Path
3 | from typing import Optional, Union
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from envs import TorchEnv, WorldModelEnv
9 | from models.actor_critic import ActorCritic, ActorCriticConfig, ActorCriticLossConfig
10 | from models.diffusion import Denoiser, DenoiserConfig, SigmaDistributionConfig
11 | from models.rew_end_model import RewEndModel, RewEndModelConfig
12 | from utils import extract_state_dict
13 |
14 |
15 | @dataclass
16 | class AgentConfig:
17 |     denoiser: DenoiserConfig
18 |     upsampler: Optional[DenoiserConfig]
19 |     rew_end_model: Optional[RewEndModelConfig]
20 |     actor_critic: Optional[ActorCriticConfig]
21 |     num_actions: int
22 |
23 |     def __post_init__(self) -> None:
24 |         self.denoiser.inner_model.num_actions = self.num_actions
25 |         if self.upsampler is not None:
26 |             self.upsampler.inner_model.num_actions = self.num_actions
27 |         if self.rew_end_model is not None:
28 |             self.rew_end_model.num_actions = self.num_actions
29 |         if self.actor_critic is not None:
30 |             self.actor_critic.num_actions = self.num_actions
31 |
32 |
33 | class Agent(nn.Module):
34 |     def __init__(self, cfg: AgentConfig) -> None:
35 |         super().__init__()
36 |         self.denoiser = Denoiser(cfg.denoiser)
37 |         self.upsampler = Denoiser(cfg.upsampler) if cfg.upsampler is not None else None
38 |         self.rew_end_model = RewEndModel(cfg.rew_end_model) if cfg.rew_end_model is not None else None
39 |         self.actor_critic = ActorCritic(cfg.actor_critic) if cfg.actor_critic is not None else None
40 |
41 |     @property
42 |     def device(self):
43 |         return self.denoiser.device
44 |
45 |     def setup_training(
46 |         self,
47 |         sigma_distribution_cfg: SigmaDistributionConfig,
48 |         sigma_distribution_cfg_upsampler: Optional[SigmaDistributionConfig],
49 |         actor_critic_loss_cfg: Optional[ActorCriticLossConfig],
50 |         rl_env: Optional[Union[TorchEnv, WorldModelEnv]],
51 |     ) -> None:
52 |         self.denoiser.setup_training(sigma_distribution_cfg)
53 |         if self.upsampler is not None:
54 |             self.upsampler.setup_training(sigma_distribution_cfg_upsampler)
55 |         if self.actor_critic is not None:
56 |             self.actor_critic.setup_training(rl_env, actor_critic_loss_cfg)
57 |
58 |     def load(
59 |         self,
60 |         path_to_ckpt: Path,
61 |         load_denoiser: bool = True,
62 |         load_upsampler: bool = True,
63 |         load_rew_end_model: bool = True,
64 |         load_actor_critic: bool = True,
65 |     ) -> None:
66 |         sd = torch.load(Path(path_to_ckpt), map_location=self.device)
67 |         if load_denoiser:
68 |             self.denoiser.load_state_dict(extract_state_dict(sd, "denoiser"))
69 |         if load_upsampler and self.upsampler is not None:
70 |             self.upsampler.load_state_dict(extract_state_dict(sd, "upsampler"))
71 |         if load_rew_end_model and self.rew_end_model is not None:
72 |             self.rew_end_model.load_state_dict(extract_state_dict(sd, "rew_end_model"))
73 |         if load_actor_critic and self.actor_critic is not None:
74 |             self.actor_critic.load_state_dict(extract_state_dict(sd, "actor_critic"))
--------------------------------------------------------------------------------
/src/coroutines/__init__.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 |
3 |
4 | def coroutine(func):
5 |     @wraps(func)
6 |     def primer(*args, **kwargs):
7 |         gen = func(*args, **kwargs)
8 |         next(gen)
9 |         return gen
10 |
11 |     return primer
12 |
--------------------------------------------------------------------------------
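The `coroutine` decorator above primes a generator by advancing it to its first `yield`, so callers can `.send()` values immediately; `make_collector` and `make_env_loop` below rely on this. A minimal, self-contained sketch of the pattern (the `running_sum` generator is hypothetical, for illustration only):

```python
from functools import wraps


def coroutine(func):
    @wraps(func)
    def primer(*args, **kwargs):
        gen = func(*args, **kwargs)
        next(gen)  # advance to the first yield so send() works right away
        return gen

    return primer


@coroutine
def running_sum():
    total = 0
    value = yield  # the first send() resumes here
    while True:
        total += value
        value = yield total


acc = running_sum()
print(acc.send(3))  # 3
print(acc.send(4))  # 7
```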
/src/coroutines/collector.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from dataclasses import dataclass
3 | from typing import Generator, Optional
4 |
5 | import torch
6 | import torch.nn as nn
7 | from tqdm import tqdm
8 |
9 | from . import coroutine
10 | from data import Episode, Dataset
11 | from envs import TorchEnv
12 | from .env_loop import make_env_loop
13 | from utils import Logs
14 |
15 |
16 | @coroutine
17 | def make_collector(
18 |     env: TorchEnv,
19 |     model: nn.Module,
20 |     dataset: Dataset,
21 |     epsilon: float = 0.0,
22 |     reset_every_collect: bool = False,
23 |     verbose: bool = True,
24 | ) -> Generator[Logs, int, None]:
25 |     num_envs = env.num_envs
26 |
27 |     env_loop, buffer, episode_ids, dead = (None,) * 4
28 |     num_steps, num_episodes, to_log, pbar = (None,) * 4
29 |
30 |     def setup_new_collect():
31 |         nonlocal num_steps, num_episodes, buffer, to_log, pbar
32 |         num_steps = 0
33 |         num_episodes = 0
34 |         buffer = defaultdict(list)
35 |         to_log = []
36 |         pbar = tqdm(
37 |             total=num_to_collect.total,
38 |             unit=num_to_collect.unit,
39 |             desc=f"Collect {dataset.name}",
40 |             disable=not verbose,
41 |         )
42 |
43 |     def reset():
44 |         nonlocal env_loop, episode_ids, dead
45 |         env_loop = make_env_loop(env, model, epsilon)
46 |         episode_ids = defaultdict(lambda: None)
47 |         dead = [None] * num_envs
48 |
49 |     num_to_collect = yield
50 |     setup_new_collect()
51 |     reset()
52 |
53 |     while True:
54 |         with torch.no_grad():
55 |             all_obs, act, rew, end, trunc, *_, [infos] = env_loop.send(1)
56 |
57 |         num_steps += num_envs
58 |         pbar.update(num_envs if num_to_collect.steps is not None else 0)
59 |
60 |         for i, (o, a, r, e, t) in enumerate(zip(all_obs, act, rew, end, trunc)):
61 |             buffer[i].append((o, a, r, e, t))
62 |             dead[i] = (e + t).clip(max=1).item()
63 |
64 |         num_episodes += sum(dead)
65 |
66 |         can_stop = num_to_collect.can_stop(num_steps, num_episodes)
67 |
68 |         count_dead = 0
69 |         for i in range(num_envs):
70 |             # Store incomplete episodes only when reset_every_collect is set to False (train)
71 |             add_to_dataset = dead[i] or (can_stop and not reset_every_collect)
72 |             if add_to_dataset:
73 |                 info = {"final_observation": infos["final_observation"][count_dead]} if dead[i] else {}
74 |                 ep = Episode(*(torch.cat(x, dim=0) for x in zip(*buffer[i])), info).to("cpu")
75 |                 if episode_ids[i] is not None:
76 |                     ep = dataset.load_episode(episode_ids[i]) + ep
77 |                 episode_ids[i] = dataset.add_episode(ep, episode_id=episode_ids[i])
78 |
79 |             if dead[i]:
80 |                 to_log.append(
81 |                     {
82 |                         f"{dataset.name}/episode_id": episode_ids[i],
83 |                         **ep.compute_metrics(),
84 |                     }
85 |                 )
86 |                 buffer[i] = []
87 |                 episode_ids[i] = None
88 |                 pbar.update(1 if num_to_collect.episodes is not None else 0)
89 |
90 |             count_dead += dead[i]
91 |
92 |         if can_stop:
93 |             pbar.close()
94 |             metrics = {
95 |                 "num_steps": dataset.num_steps,
96 |                 "counts/rew_-1": dataset.counts_rew[0],
97 |                 "counts/rew__0": dataset.counts_rew[1],
98 |                 "counts/rew_+1": dataset.counts_rew[2],
99 |                 "counts/end_0": dataset.counts_end[0],
100 |                 "counts/end_1": dataset.counts_end[1],
101 |             }
102 |             to_log.append({f"{dataset.name}/{k}": v for k, v in metrics.items()})
103 |             num_to_collect = yield to_log
104 |             setup_new_collect()
105 |             if reset_every_collect:
106 |                 reset()
107 |
108 |
109 | @dataclass
110 | class NumToCollect:
111 |     steps: Optional[int] = None
112 |     episodes: Optional[int] = None
113 |
114 |     def __post_init__(self) -> None:
115 |         assert (self.steps is None) != (self.episodes is None)
116 |
117 |     def can_stop(self, num_steps: int, num_episodes: int) -> bool:
118 |         return num_steps >= self.steps if self.steps is not None else num_episodes >= self.episodes
119 |
120 |     @property
121 |     def unit(self) -> str:
122 |         return "steps" if self.steps is not None else "eps"
123 |
124 |     @property
125 |     def total(self) -> int:
126 |         return self.steps if self.steps is not None else self.episodes
127 |
--------------------------------------------------------------------------------
/src/coroutines/env_loop.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import Generator, Tuple, Union
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.distributions.categorical import Categorical
7 |
8 | from . import coroutine
9 | from envs import TorchEnv, WorldModelEnv
10 |
11 |
12 | @coroutine
13 | def make_env_loop(
14 |     env: Union[TorchEnv, WorldModelEnv], model: nn.Module, epsilon: float = 0.0
15 | ) -> Generator[Tuple[torch.Tensor, ...], int, None]:
16 |     num_steps = yield
17 |
18 |     hx = torch.zeros(env.num_envs, model.lstm_dim, device=model.device)
19 |     cx = torch.zeros(env.num_envs, model.lstm_dim, device=model.device)
20 |
21 |     seed = random.randint(0, 2**31 - 1)
22 |     obs, _ = env.reset(seed=[seed + i for i in range(env.num_envs)])
23 |
24 |     while True:
25 |         hx, cx = hx.detach(), cx.detach()
26 |         all_ = []
27 |         infos = []
28 |         n = 0
29 |
30 |         while n < num_steps:
31 |             logits_act, val, (hx, cx) = model.predict_act_value(obs, (hx, cx))
32 |             act = Categorical(logits=logits_act).sample()
33 |
34 |             if random.random() < epsilon:
35 |                 act = torch.randint(low=0, high=env.num_actions, size=(obs.size(0),), device=obs.device)
36 |
37 |             next_obs, rew, end, trunc, info = env.step(act)
38 |
39 |             if n > 0:
40 |                 val_bootstrap = val.detach().clone()
41 |                 if dead.any():
42 |                     val_bootstrap[dead] = val_final_obs
43 |                 all_[-1][-1] = val_bootstrap
44 |
45 |             dead = torch.logical_or(end, trunc)
46 |
47 |             if dead.any():
48 |                 with torch.no_grad():
49 |                     _, val_final_obs, _ = model.predict_act_value(info["final_observation"], (hx[dead], cx[dead]))
50 |                 reset_gate = 1 - dead.float().unsqueeze(1)
51 |                 hx = hx * reset_gate
52 |                 cx = cx * reset_gate
53 |                 if "burnin_obs" in info:
54 |                     burnin_obs = info["burnin_obs"]
55 |                     for i in range(burnin_obs.size(1)):
56 |                         _, _, (hx[dead], cx[dead]) = model.predict_act_value(burnin_obs[:, i], (hx[dead], cx[dead]))
57 |
58 |             all_.append([obs, act, rew, end, trunc, logits_act, val, None])
59 |             infos.append(info)
60 |
61 |             obs = next_obs
62 |             n += 1
63 |
64 |         with torch.no_grad():
65 |             _, val_bootstrap, _ = model.predict_act_value(next_obs, (hx, cx))  # do not update hx/cx
66 |
67 |         if dead.any():
68 |             val_bootstrap[dead] = val_final_obs
69 |
70 |         all_[-1][-1] = val_bootstrap
71 |
72 |         all_obs, act, rew, end, trunc, logits_act, val, val_bootstrap = (torch.stack(x, dim=1) for x in zip(*all_))
73 |
74 |         num_steps = yield all_obs, act, rew, end, trunc, logits_act, val, val_bootstrap, infos
75 |
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .batch import Batch
2 | from .batch_sampler import BatchSampler
3 | from .dataset import Dataset, GameHdf5Dataset
4 | from .episode import Episode
5 | from .segment import Segment, SegmentId
6 | from .utils import collate_segments_to_batch, DatasetTraverser, make_segment
7 |
--------------------------------------------------------------------------------
/src/data/batch.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from dataclasses import dataclass
3 | from typing import Any, Dict, List
4 |
5 | import torch
6 |
7 | from .segment import SegmentId
8 |
9 |
10 | @dataclass
11 | class Batch:
12 |     obs: torch.ByteTensor
13 |     act: torch.LongTensor
14 |     rew: torch.FloatTensor
15 |     end: torch.LongTensor
16 |     trunc: torch.LongTensor
17 |     mask_padding: torch.BoolTensor
18 |     info: List[Dict[str, Any]]
19 |     segment_ids: List[SegmentId]
20 |
21 |     def pin_memory(self) -> Batch:
22 |         return Batch(**{k: v if k in ("segment_ids", "info") else v.pin_memory() for k, v in self.__dict__.items()})
23 |
24 |     def to(self, device: torch.device) -> Batch:
25 |         return Batch(**{k: v if k in ("segment_ids", "info") else v.to(device) for k, v in self.__dict__.items()})
26 |
--------------------------------------------------------------------------------
/src/data/batch_sampler.py:
--------------------------------------------------------------------------------
1 | from typing import Generator, List, Optional
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from .dataset import GameHdf5Dataset, Dataset
7 | from .segment import SegmentId
8 |
9 |
10 | class BatchSampler(torch.utils.data.Sampler):
11 |     def __init__(
12 |         self,
13 |         dataset: Dataset,
14 |         rank: int,
15 |         world_size: int,
16 |         batch_size: int,
17 |         seq_length: int,
18 |         sample_weights: Optional[List[float]] = None,
19 |         can_sample_beyond_end: bool = False,
20 |         autoregressive_obs: Optional[int] = None,
21 |         initial_num_consecutive_page_count: int = 1,
22 |     ) -> None:
23 |         super().__init__(dataset)
24 |         assert isinstance(dataset, (Dataset, GameHdf5Dataset))
25 |         self.dataset = dataset
26 |         self.rank = rank
27 |         self.world_size = world_size
28 |         self.sample_weights = sample_weights
29 |         self.batch_size = batch_size
30 |         self.seq_length = seq_length
31 |         self.can_sample_beyond_end = can_sample_beyond_end
32 |         self.autoregressive_obs = autoregressive_obs
33 |         self.num_consecutive_batches = initial_num_consecutive_page_count
34 |
35 |     def __len__(self):
36 |         raise NotImplementedError
37 |
38 |     def __iter__(self) -> Generator[List[SegmentId], None, None]:
39 |         segments = None
40 |         current_iter = 0
41 |
42 |         while True:
43 |             if current_iter == 0:
44 |                 segments = self.sample()
45 |             else:
46 |                 segments = self.next(segments)
47 |
48 |             current_iter = (current_iter + 1) % self.num_consecutive_batches
49 |             yield segments
50 |
51 |     def next(self, segments: List[SegmentId]):
52 |         return [
53 |             SegmentId(segment.episode_id, segment.stop, segment.stop + self.autoregressive_obs, False)
54 |             for segment in segments
55 |         ]
56 |
57 |     def sample(self) -> List[SegmentId]:
58 |         total_length = self.seq_length + (self.num_consecutive_batches - 1) * self.autoregressive_obs
59 |
60 |         num_episodes = self.dataset.num_episodes
61 |
62 |         if (self.sample_weights is None) or num_episodes < len(self.sample_weights):
63 |             weights = self.dataset.lengths / self.dataset.num_steps
64 |         else:
65 |             weights = self.sample_weights
66 |             num_weights = len(self.sample_weights)
67 |             assert all([0 <= x <= 1 for x in weights]) and sum(weights) == 1
68 |             sizes = [
69 |                 num_episodes // num_weights + (num_episodes % num_weights) * (i == num_weights - 1)
70 |                 for i in range(num_weights)
71 |             ]
72 |             weights = [w / s for (w, s) in zip(weights, sizes) for _ in range(s)]
73 |
74 |         episodes_partition = np.arange(self.rank, num_episodes, self.world_size)
75 |         episode_lengths = self.dataset.lengths[episodes_partition]
76 |         valid_mask = episode_lengths > total_length  # valid episodes must be long enough for autoregressive generation
77 |         episodes_partition = episodes_partition[valid_mask]
78 |
79 |         weights = np.array(weights[self.rank::self.world_size])
80 |         weights = weights[valid_mask]
81 |
82 |         max_eps = self.batch_size
83 |         episode_ids = np.random.choice(episodes_partition, size=max_eps, replace=True, p=weights / weights.sum())
84 |         episode_ids = episode_ids.repeat(self.batch_size // max_eps)
85 |
86 |         # choose a random timestep in each sampled episode
87 |         timesteps = np.random.randint(low=0, high=self.dataset.lengths[episode_ids])
88 |         # compute total context size + autoregressive generation length
89 |
90 |         # the stops of the first page can be at most the length of the training example minus the autoregressive generation frames in the next pages
91 |         stops = np.minimum(
92 |             self.dataset.lengths[episode_ids] - (self.num_consecutive_batches - 1) * self.seq_length,
93 |             timesteps + 1 + np.random.randint(0, total_length, len(timesteps))
94 |         )
95 |         # stops must be longer than the initial context + first page prediction size
96 |         stops = np.maximum(stops, self.seq_length)
97 |         # starts is stops minus the initial context and the first page prediction size
98 |         starts = stops - self.seq_length
99 |
100 |         return [SegmentId(*x, True) for x in zip(episode_ids, starts, stops)]
--------------------------------------------------------------------------------
/src/data/dataset.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | import multiprocessing as mp
3 | from pathlib import Path
4 | import shutil
5 | from typing import Any, Dict, List, Optional
6 |
7 | import h5py
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | from torch.utils.data import Dataset as TorchDataset
12 |
13 | from .episode import Episode
14 | from .segment import Segment, SegmentId
15 | from .utils import make_segment
16 | from utils import StateDictMixin
17 |
18 |
19 | class Dataset(StateDictMixin, TorchDataset):
20 |     def __init__(
21 |         self,
22 |         directory: Path,
23 |         dataset_full_res: Optional[TorchDataset],
24 |         name: Optional[str] = None,
25 |         cache_in_ram: bool = False,
26 |         use_manager: bool = False,
27 |         save_on_disk: bool = True,
28 |     ) -> None:
29 |         super().__init__()
30 |
31 |         # State
32 |         self.is_static = False
33 |         self.num_episodes = None
34 |         self.num_steps = None
35 |         self.start_idx = None
36 |         self.lengths = None
37 |         self.counter_rew = None
38 |         self.counter_end = None
39 |
40 |         self._directory = Path(directory).expanduser()
41 |         self._name = name if name is not None else self._directory.stem
42 |         self._cache_in_ram = cache_in_ram
43 |         self._save_on_disk = save_on_disk
44 |         self._default_path = self._directory / "info.pt"
45 |         self._cache = mp.Manager().dict() if use_manager else {}
46 |         self._reset()
47 |
48 |         self._dataset_full_res = dataset_full_res
49 |
50 |     def __len__(self) -> int:
51 |         return self.num_steps
52 |
53 |     def __getitem__(self, segment_id: SegmentId) -> Segment:
54 |         episode = self.load_episode(segment_id.episode_id)
55 |         segment = make_segment(episode, segment_id, should_pad=True)
56 |         if self._dataset_full_res is not None:
57 |             segment_id_full_res = SegmentId(episode.info["original_file_id"], segment_id.start, segment_id.stop, segment_id.is_first_batch)
58 |             segment.info["full_res"] = self._dataset_full_res[segment_id_full_res].obs
59 |         elif "full_res" in segment.info:
60 |             segment.info["full_res"] = segment.info["full_res"][segment_id.start:segment_id.stop]
61 |         return segment
62 |
63 |     def __str__(self) -> str:
64 |         return f"{self.name}: {self.num_episodes} episodes, {self.num_steps} steps."
65 |
66 |     @property
67 |     def name(self) -> str:
68 |         return self._name
69 |
70 |     @property
71 |     def counts_rew(self) -> List[int]:
72 |         return [self.counter_rew[r] for r in [-1, 0, 1]]
73 |
74 |     @property
75 |     def counts_end(self) -> List[int]:
76 |         return [self.counter_end[e] for e in [0, 1]]
77 |
78 |     def _reset(self) -> None:
79 |         self.num_episodes = 0
80 |         self.num_steps = 0
81 |         self.start_idx = np.array([], dtype=np.int64)
82 |         self.lengths = np.array([], dtype=np.int64)
83 |         self.counter_rew = Counter()
84 |         self.counter_end = Counter()
85 |         self._cache.clear()
86 |
87 |     def clear(self) -> None:
88 |         self.assert_not_static()
89 |         if self._directory.is_dir():
90 |             shutil.rmtree(self._directory)
91 |         self._reset()
92 |
93 |     def load_episode(self, episode_id: int) -> Episode:
94 |         if self._cache_in_ram and episode_id in self._cache:
95 |             episode = self._cache[episode_id]
96 |         else:
97 |             episode = Episode.load(self._get_episode_path(episode_id))
98 |             if self._cache_in_ram:
99 |                 self._cache[episode_id] = episode
100 |         return episode
101 |
102 |     def add_episode(self, episode: Episode, *, episode_id: Optional[int] = None) -> int:
103 |         self.assert_not_static()
104 |         episode = episode.to("cpu")
105 |
106 |         if episode_id is None:
107 |             episode_id = self.num_episodes
108 |             self.start_idx = np.concatenate((self.start_idx, np.array([self.num_steps])))
109 |             self.lengths = np.concatenate((self.lengths, np.array([len(episode)])))
110 |             self.num_steps += len(episode)
111 |             self.num_episodes += 1
112 |
113 |         else:
114 |             assert episode_id < self.num_episodes
115 |             old_episode = self.load_episode(episode_id)
116 |             incr_num_steps = len(episode) - len(old_episode)
117 |             self.lengths[episode_id] = len(episode)
118 |             self.start_idx[episode_id + 1 :] += incr_num_steps
119 |             self.num_steps += incr_num_steps
120 |             self.counter_rew.subtract(old_episode.rew.sign().tolist())
121 |             self.counter_end.subtract(old_episode.end.tolist())
122 |
123 |         self.counter_rew.update(episode.rew.sign().tolist())
124 |         self.counter_end.update(episode.end.tolist())
125 |
126 |         if self._save_on_disk:
127 |             episode.save(self._get_episode_path(episode_id))
128 |
129 |         if self._cache_in_ram:
130 |             self._cache[episode_id] = episode
131 |
132 |         return episode_id
133 |
134 |     def _get_episode_path(self, episode_id: int) -> Path:
135 |         n = 3  # number of hierarchies
136 |         powers = np.arange(n)
137 |         subfolders = np.floor((episode_id % 10 ** (1 + powers)) / 10**powers) * 10**powers
138 |         subfolders = [int(x) for x in subfolders[::-1]]
139 |         subfolders = "/".join([f"{x:0{n - i}d}" for i, x in enumerate(subfolders)])
140 |         return self._directory / subfolders / f"{episode_id}.pt"
141 |
142 |     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
143 |         super().load_state_dict(state_dict)
144 |         self._cache.clear()
145 |
146 |     def assert_not_static(self) -> None:
147 |         assert not self.is_static, "Trying to modify a static dataset."
148 |
149 |     def save_to_default_path(self) -> None:
150 |         self._default_path.parent.mkdir(exist_ok=True, parents=True)
151 |         torch.save(self.state_dict(), self._default_path)
152 |
153 |     def load_from_default_path(self) -> None:
154 |         print(self._default_path)
155 |         if self._default_path.is_file():
156 |             self.load_state_dict(torch.load(self._default_path, weights_only=False))
157 |
158 |
159 | class GameHdf5Dataset(StateDictMixin, TorchDataset):
160 |     def __init__(self, directory: Path) -> None:
161 |         super().__init__()
162 |         filenames = sorted(Path(directory).rglob("*.hdf5"), key=lambda x: int(x.stem.split("_")[-1]))
163 |         self._filenames = {f"{x.parent.name}/{x.name}": x for x in filenames}
164 |
165 |         self._length_one_episode = self._episode_lengths(self._filenames)
166 |
167 |         self.num_episodes = len(self._filenames)
168 |
169 |         self.num_steps = sum(list(self._length_one_episode.values()))
170 |         self.lengths = np.array(list(self._length_one_episode.values()), dtype=np.int64)
171 |
172 |     def _episode_lengths(self, filenames):
173 |         length_one_episode = {}
174 |
175 |         for filename in filenames:
176 |             with h5py.File(filenames[filename], "r") as f:
177 |                 keys = f.keys()
178 |                 max_frame_index = max(int(key[len('frame_'):-len('_x')]) for key in keys if key.endswith('_x') and key.startswith('frame_'))
179 |                 length_one_episode[filename] = max_frame_index + 1
180 |
181 |         return length_one_episode
182 |
183 |     def __len__(self) -> int:
184 |         return self.num_steps
185 |
186 |     def save_to_default_path(self) -> None:
187 |         pass
188 |
189 |     def __getitem__(self, segment_id: SegmentId) -> Segment:
190 |         episode_length = self._length_one_episode[segment_id.episode_id]
191 |         assert segment_id.start < episode_length and segment_id.stop > 0 and segment_id.start < segment_id.stop
192 |
193 |         pad_len_right = max(0, segment_id.stop - episode_length)
194 |         pad_len_left = max(0, -segment_id.start)
195 |
196 |         start = max(0, segment_id.start)
197 |         stop = min(episode_length, segment_id.stop)
198 |         mask_padding = torch.cat((torch.zeros(pad_len_left), torch.ones(stop - start), torch.zeros(pad_len_right))).bool()
199 |
200 |         # print(self._filenames[segment_id.episode_id])
201 |         with h5py.File(self._filenames[segment_id.episode_id], "r") as f:
202 |             obs = torch.stack([torch.tensor(f[f"frame_{i}_x"][:]).flip(2).permute(2, 0, 1).div(255).mul(2).sub(1) for i in range(start, stop)])
203 |             act = torch.tensor(np.array([f[f"frame_{i}_y"][:] for i in range(start, stop)]))
204 |
205 |         def pad(x):
206 |             right = F.pad(x, [0 for _ in range(2 * x.ndim - 1)] + [pad_len_right]) if pad_len_right > 0 else x
207 |             return F.pad(right, [0 for _ in range(2 * x.ndim - 2)] + [pad_len_left, 0]) if pad_len_left > 0 else right
208 |
209 |         obs = pad(obs)
210 |         act = pad(act)
211 |         rew = torch.zeros(obs.size(0))
212 |         end = torch.zeros(obs.size(0), dtype=torch.uint8)
213 |         trunc = torch.zeros(obs.size(0), dtype=torch.uint8)
214 |         return Segment(obs, act, rew, end, trunc, mask_padding, info={}, id=SegmentId(segment_id.episode_id, start, stop, segment_id.is_first_batch))
215 |
216 |     def load_episode(self, episode_id: int) -> Episode:  # used by DatasetTraverser
217 |         episode_length = self._length_one_episode[episode_id]
218 |         s = self[SegmentId(episode_id, 0, episode_length, None)]
219 |         return Episode(s.obs, s.act, s.rew, s.end, s.trunc, s.info)
--------------------------------------------------------------------------------
/src/data/episode.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from dataclasses import dataclass
3 | from pathlib import Path
4 | from typing import Any, Dict, Optional
5 |
6 | import torch
7 |
8 |
9 | @dataclass
10 | class Episode:
11 |     obs: torch.FloatTensor
12 |     act: torch.LongTensor
13 |     rew: torch.FloatTensor
14 |     end: torch.ByteTensor
15 |     trunc: torch.ByteTensor
16 |     info: Dict[str, Any]
17 |
18 |     def __len__(self) -> int:
19 |         return self.obs.size(0)
20 |
21 |     def __add__(self, other: Episode) -> Episode:
22 |         assert self.dead.sum() == 0
23 |         d = {k: torch.cat((v, other.__dict__[k]), dim=0) for k, v in self.__dict__.items() if k != "info"}
24 |         return Episode(**d, info=merge_info(self.info, other.info))
25 |
26 |     def to(self, device) -> Episode:
27 |         return Episode(**{k: v.to(device) if k != "info" else v for k, v in self.__dict__.items()})
28 |
29 |     @property
30 |     def dead(self) -> torch.ByteTensor:
31 |         return (self.end + self.trunc).clip(max=1)
32 |
33 |     def compute_metrics(self) -> Dict[str, Any]:
34 |         return {"length": len(self), "return": self.rew.sum().item()}
35 |
36 |     @classmethod
37 |     def load(cls, path: Path, map_location: Optional[torch.device] = None) -> Episode:
38 |         return cls(
39 |             **{
40 |                 k: v.div(255).mul(2).sub(1) if k == "obs" else v
41 |                 for k, v in torch.load(Path(path), map_location=map_location).items()
42 |             }
43 |         )
44 |
45 |     def save(self, path: Path) -> None:
46 |         path = Path(path)
47 |         path.parent.mkdir(parents=True, exist_ok=True)
48 |         d = {k: v.add(1).div(2).mul(255).byte() if k == "obs" else v for k, v in self.__dict__.items()}
49 |         torch.save(d, path.with_suffix(".tmp"))
50 |         path.with_suffix(".tmp").rename(path)
51 |
52 |
53 | def merge_info(info_a, info_b):
54 |     keys_a = set(info_a)
55 |     keys_b = set(info_b)
56 |     intersection = keys_a & keys_b
57 |     info = {
58 |         **{k: info_a[k] for k in keys_a if k not in intersection},
59 |         **{k: info_b[k] for k in keys_b if k not in intersection},
60 |         **{k: torch.cat((info_a[k], info_b[k]), dim=0) for k in intersection},
61 |     }
62 |     return info
63 |
--------------------------------------------------------------------------------
/src/data/segment.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from dataclasses import dataclass
3 | from typing import Any, Dict, Union
4 |
5 | import torch
6 |
7 |
8 | @dataclass
9 | class SegmentId:
10 |     episode_id: Union[int, str]
11 |     start: int
12 |     stop: int
13 |     is_first_batch: bool
14 |
15 |
16 | @dataclass
17 | class Segment:
18 |     obs: torch.FloatTensor
19 |     act: torch.LongTensor
20 |     rew: torch.FloatTensor
21 |     end: torch.ByteTensor
22 |     trunc: torch.ByteTensor
23 |     mask_padding: torch.BoolTensor
24 |     info: Dict[str, Any]
25 |     id: SegmentId
26 |
27 |     @property
28 |     def effective_size(self):
29 |         return self.mask_padding.sum().item()
30 |
--------------------------------------------------------------------------------
/src/data/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Generator, List
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | from .batch import Batch
8 | from .episode import Episode
9 | from .segment import Segment, SegmentId
10 |
11 |
12 | def collate_segments_to_batch(segments: List[Segment]) -> Batch:
13 |     attrs = ("obs", "act", "rew", "end", "trunc", "mask_padding")
14 |     stack = (torch.stack([getattr(s, x) for s in segments]) for x in attrs)
15 |     return Batch(*stack, [s.info for s in segments], [s.id for s in segments])
16 |
17 |
18 | def make_segment(episode: Episode, segment_id: SegmentId, should_pad: bool = True) -> Segment:
19 |     if not (segment_id.start < len(episode) and segment_id.stop > 0 and segment_id.start < segment_id.stop):
20 |         print(f'Failed assertion because: start={segment_id.start}, stop={segment_id.stop}, len(episode)={len(episode)}')
21 |
22 |     assert segment_id.start < len(episode) and segment_id.stop > 0 and segment_id.start < segment_id.stop
23 |     pad_len_right = max(0, segment_id.stop - len(episode))
24 |     pad_len_left = max(0, -segment_id.start)
25 |     assert pad_len_right == pad_len_left == 0 or should_pad
26 |
27 |     def pad(x):
28 |         right = F.pad(x, [0 for _ in range(2 * x.ndim - 1)] + [pad_len_right]) if pad_len_right > 0 else x
29 |         return F.pad(right, [0 for _ in range(2 * x.ndim - 2)] + [pad_len_left, 0]) if pad_len_left > 0 else right
30 |
31 |     start = max(0, segment_id.start)
32 |     stop = min(len(episode), segment_id.stop)
33 |     mask_padding = torch.cat((torch.zeros(pad_len_left), torch.ones(stop - start), torch.zeros(pad_len_right))).bool()
34 |
35 |     return Segment(
36 |         pad(episode.obs[start:stop]),
37 |         pad(episode.act[start:stop]),
38 |         pad(episode.rew[start:stop]),
39 |         pad(episode.end[start:stop]),
40 |         pad(episode.trunc[start:stop]),
41 |         mask_padding,
42 |         info=episode.info,
43 |         id=SegmentId(segment_id.episode_id, start, stop, segment_id.is_first_batch),
44 |     )
45 |
46 |
47 | class DatasetTraverser:
48 |     def __init__(self, dataset, batch_num_samples: int, chunk_size: int) -> None:
49 |         self.dataset = dataset
50 |         self.batch_num_samples = batch_num_samples
51 |         self.chunk_size = chunk_size
52 |
53 |     def __len__(self):
54 |         return math.ceil(
55 |             sum(
56 |                 [
57 |                     math.ceil(self.dataset.lengths[episode_id] / self.chunk_size)
58 |                     - int(self.dataset.lengths[episode_id] % self.chunk_size == 1)
59 |                     for episode_id in range(self.dataset.num_episodes)
60 |                 ]
61 |             )
62 |             / self.batch_num_samples
63 |         )
64 |
65 |     def __iter__(self) -> Generator[Batch, None, None]:
66 |         chunks = []
67 |         for episode_id in range(self.dataset.num_episodes):
68 |             episode = self.dataset.load_episode(episode_id)
69 |             segments = []
70 |             for i in range(math.ceil(len(episode) / self.chunk_size)):
71 |                 start = i * self.chunk_size
72 |                 stop = (i + 1) * self.chunk_size
73 |                 segment = make_segment(
74 |                     episode,
75 |                     SegmentId(episode_id, start, stop, None),
76 |                     should_pad=True,
77 |                 )
78 |                 segment_id_full_res = SegmentId(episode.info["original_file_id"], start, stop, None)
79 |                 segment.info["full_res"] = self.dataset._dataset_full_res[segment_id_full_res].obs
80 |                 chunks.append(segment)
81 |             if chunks[-1].effective_size < 2:
82 |                 chunks.pop()
83 |
84 |             while len(chunks) >= self.batch_num_samples:
85 |                 yield collate_segments_to_batch(chunks[: self.batch_num_samples])
86 |                 chunks = chunks[self.batch_num_samples :]
87 |
88 |         if len(chunks) > 0:
89 |             yield collate_segments_to_batch(chunks)
90 |
91 |
--------------------------------------------------------------------------------
/src/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .env import make_atari_env, TorchEnv
2 | from .world_model_env import WorldModelEnv, WorldModelEnvConfig
3 |
--------------------------------------------------------------------------------
/src/envs/atari_preprocessing.py:
--------------------------------------------------------------------------------
1 | """
2 | Derived from https://github.com/openai/gym/blob/master/gym/wrappers/atari_preprocessing.py
3 | Implementation of Atari 2600 Preprocessing following the guidelines of Machado et al., 2018.
4 | """
5 |
6 | from __future__ import annotations
7 |
8 | from typing import Any, SupportsFloat
9 |
10 | import cv2
11 | import numpy as np
12 |
13 | import gymnasium as gym
14 | from gymnasium.core import WrapperActType, WrapperObsType
15 | from gymnasium.spaces import Box
16 |
17 |
18 | class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
19 |     def __init__(
20 |         self,
21 |         env: gym.Env,
22 |         noop_max: int,
23 |         frame_skip: int,
24 |         screen_size: int,
25 |     ):
26 |         gym.utils.RecordConstructorArgs.__init__(
27 |             self,
28 |             noop_max=noop_max,
29 |             frame_skip=frame_skip,
30 |             screen_size=screen_size,
31 |         )
32 |         gym.Wrapper.__init__(self, env)
33 |
34 |         assert frame_skip > 0
35 |         assert screen_size > 0
36 |         assert noop_max >= 0
37 |         if frame_skip > 1 and getattr(env.unwrapped, "_frameskip", None) != 1:
38 |             raise ValueError(
39 |                 "Disable frame-skipping in the original env. Otherwise, more than one frame-skip will happen as through this wrapper"
40 |             )
41 |         self.noop_max = noop_max
42 |         assert env.unwrapped.get_action_meanings()[0] == "NOOP"
43 |
44 |         self.frame_skip = frame_skip
45 |         self.screen_size = screen_size
46 |
47 |         # buffer of most recent two observations for max pooling
48 |         assert isinstance(env.observation_space, Box)
49 |         self.obs_buffer = [
50 |             np.empty(env.observation_space.shape, dtype=np.uint8),
51 |             np.empty(env.observation_space.shape, dtype=np.uint8),
52 |         ]
53 |
54 |         self.lives = 0
55 |         self.game_over = False
56 |
57 |         _low, _high, _obs_dtype = (0, 255, np.uint8)
58 |         _shape = (screen_size, screen_size, 3)
59 |         self.observation_space = Box(low=_low, high=_high, shape=_shape, dtype=_obs_dtype)
60 |
61 |     @property
62 |     def ale(self):
63 |         """Make ale as a class property to avoid serialization error."""
64 |         return self.env.unwrapped.ale
65 |
66 |     def step(self, action: WrapperActType) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
67 |         total_reward, terminated, truncated, info = 0.0, False, False, {}
68 |
69 |         life_loss = False
70 |
71 |         for t in range(self.frame_skip):
72 |             _, reward, terminated, truncated, info = self.env.step(action)
73 |             total_reward += reward
74 |             self.game_over = terminated
75 |
76 |             if self.ale.lives() < self.lives:
77 |                 life_loss = True
78 |                 self.lives = self.ale.lives()
79 |
80 |             if terminated or truncated:
81 |                 break
82 |
83 |             if t == self.frame_skip - 2:
84 |                 self.ale.getScreenRGB(self.obs_buffer[1])
85 |             elif t == self.frame_skip - 1:
86 |                 self.ale.getScreenRGB(self.obs_buffer[0])
87 |
88 |         info["life_loss"] = life_loss
89 |
90 |         obs, original_obs = self._get_obs()
91 |         info["original_obs"] = original_obs
92 |
93 |         return obs, total_reward, terminated, truncated, info
94 |
95 |     def reset(
96 |         self, *, seed: int | None = None, options: dict[str, Any] | None = None
97 |     ) -> tuple[WrapperObsType, dict[str, Any]]:
98 |         """Resets the environment using preprocessing."""
99 |         # NoopReset
100 |         _, reset_info = self.env.reset(seed=seed, options=options)
101 |
102 |         reset_info["life_loss"] = False
103 |
104 |         noops = self.env.unwrapped.np_random.integers(1, self.noop_max + 1) if self.noop_max > 0 else 0
105 |         for _ in range(noops):
106 |             _, _, terminated, truncated, step_info = self.env.step(0)
107 |             reset_info.update(step_info)
108 |             if terminated or truncated:
109 |                 _, reset_info = self.env.reset(seed=seed, options=options)
110 |
111 |         self.lives = self.ale.lives()
112 |         self.ale.getScreenRGB(self.obs_buffer[0])
113 |         self.obs_buffer[1].fill(0)
114 |
115 |         obs, original_obs = self._get_obs()
116 |         reset_info["original_obs"] = original_obs
117 |
118 |         return obs, reset_info
119 |
120 |     def _get_obs(self):
121 |         if self.frame_skip > 1:  # more efficient in-place pooling
122 |             np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])
123 |
124 |         original_obs = self.obs_buffer[0]
125 |         obs = cv2.resize(
126 |             original_obs,
127 |             (self.screen_size, self.screen_size),
128 |             interpolation=cv2.INTER_AREA,
129 |         )
130 |
131 |         obs = np.asarray(obs, dtype=np.uint8)
132 |
133 |         return obs, original_obs
--------------------------------------------------------------------------------
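
A note on the two-frame buffer in `AtariPreprocessing`: Atari games render some sprites only on alternate frames, so the wrapper keeps the last two raw frames of each skip window and max-pools them before downscaling. A minimal sketch of that pooling-and-resize path, with random arrays standing in for ALE screens (210x160 is the native Atari resolution; `screen_size` is an example value):

```python
import cv2
import numpy as np

# Stand-ins for the last two raw frames of a skip window.
frame_a = np.random.randint(0, 256, (210, 160, 3), dtype=np.uint8)
frame_b = np.random.randint(0, 256, (210, 160, 3), dtype=np.uint8)

# Element-wise max of the two frames, written in place as in _get_obs.
np.maximum(frame_a, frame_b, out=frame_a)

# Downscale to the square observation consumed by the models.
screen_size = 64  # example value; the actual size comes from the env config
obs = cv2.resize(frame_a, (screen_size, screen_size), interpolation=cv2.INTER_AREA)
assert obs.shape == (screen_size, screen_size, 3) and obs.dtype == np.uint8
```
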
/src/envs/env.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Any, Dict, Optional, Tuple
3 |
4 | import ale_py
5 | import gymnasium
6 | from gymnasium.vector import AsyncVectorEnv
7 | import numpy as np
8 | import torch
9 | from torch import Tensor
10 |
11 | from .atari_preprocessing import AtariPreprocessing
12 |
13 |
14 | def make_atari_env(
15 | id: str,
16 | num_envs: int,
17 | device: torch.device,
18 | done_on_life_loss: bool,
19 | size: int,
20 | max_episode_steps: Optional[int],
21 | ) -> TorchEnv:
22 | def env_fn():
23 | env = gymnasium.make(
24 | id,
25 | full_action_space=False,
26 | frameskip=1,
27 | render_mode="rgb_array",
28 | max_episode_steps=max_episode_steps,
29 | )
30 | env = AtariPreprocessing(
31 | env=env,
32 | noop_max=30,
33 | frame_skip=4,
34 | screen_size=size,
35 | )
36 | return env
37 |
38 | env = AsyncVectorEnv([env_fn for _ in range(num_envs)])
39 |
40 |     # The AsyncVectorEnv resets its sub-envs on termination, so using gymnasium's
41 |     # default AtariPreprocessing with terminate_on_life_loss=True would reset after
42 |     # the first lost life, and we would only ever see the first life.
43 |     # Hence life loss is handled by a separate wrapper applied after the AsyncVectorEnv.
44 |
45 | if done_on_life_loss:
46 | env = DoneOnLifeLoss(env)
47 |
48 | env = TorchEnv(env, device)
49 |
50 | return env
51 |
52 |
53 | class DoneOnLifeLoss(gymnasium.Wrapper):
54 | def __init__(self, env: AsyncVectorEnv) -> None:
55 | super().__init__(env)
56 |
57 | def step(self, actions: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, Dict[str, Any]]:
58 | obs, rew, end, trunc, info = self.env.step(actions)
59 | life_loss = info["life_loss"]
60 | if life_loss.any():
61 | end[life_loss] = True
62 | info["final_observation"] = obs
63 | return obs, rew, end, trunc, info
64 |
65 |
66 | class TorchEnv(gymnasium.Wrapper):
67 | def __init__(self, env: gymnasium.Env, device: torch.device) -> None:
68 | super().__init__(env)
69 | self.device = device
70 | self.num_envs = env.observation_space.shape[0]
71 | self.num_actions = env.unwrapped.single_action_space.n
72 | b, h, w, c = env.observation_space.shape
73 | self.observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(b, c, h, w))
74 |
75 | def reset(self, *args, **kwargs) -> Tuple[Tensor, Dict[str, Any]]:
76 | obs, info = self.env.reset(*args, **kwargs)
77 | return self._to_tensor(obs), info
78 |
79 | def step(self, actions: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Dict[str, Any]]:
80 | obs, rew, end, trunc, info = self.env.step(actions.cpu().numpy())
81 | dead = np.logical_or(end, trunc)
82 | if dead.any():
83 | info["final_observation"] = self._to_tensor(np.stack(info["final_observation"][dead]))
84 | obs, rew, end, trunc = (self._to_tensor(x) for x in (obs, rew, end, trunc))
85 | return obs, rew, end, trunc, info
86 |
87 |     def _to_tensor(self, x: np.ndarray) -> Tensor:
88 | if x.ndim == 4:
89 | return torch.tensor(x, device=self.device).div(255).mul(2).sub(1).permute(0, 3, 1, 2).contiguous()
90 | elif x.dtype is np.dtype("bool"):
91 | return torch.tensor(x, dtype=torch.uint8, device=self.device)
92 | else:
93 | return torch.tensor(x, dtype=torch.float32, device=self.device)
94 |
--------------------------------------------------------------------------------
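
`TorchEnv._to_tensor` maps uint8 frames in [0, 255] to float tensors in [-1, 1] and moves channels first; the renderer later inverts it with `add(1).div(2).mul(255)`. A quick sketch of the round trip (shapes are example values):

```python
import numpy as np
import torch

frames = np.random.randint(0, 256, (2, 64, 64, 3), dtype=np.uint8)  # (b, h, w, c)

# Forward: [0, 255] -> [-1, 1], NHWC -> NCHW, as in TorchEnv._to_tensor.
x = torch.tensor(frames).div(255).mul(2).sub(1).permute(0, 3, 1, 2).contiguous()
assert x.shape == (2, 3, 64, 64) and x.min() >= -1 and x.max() <= 1

# Inverse used for display: [-1, 1] -> [0, 255].
restored = x.add(1).div(2).mul(255)
assert torch.allclose(restored, torch.tensor(frames).permute(0, 3, 1, 2).float(), atol=1e-3)
```
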
/src/envs/world_model_env.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from itertools import cycle
3 | from pathlib import Path
4 | from typing import Any, Dict, Generator, List, Optional, Tuple
5 |
6 | import numpy as np
7 | import torch
8 | from torch import Tensor
9 | from torch.distributions.categorical import Categorical
10 | import torch.nn.functional as F
11 |
12 | from coroutines import coroutine
13 | from models.diffusion import Denoiser, DiffusionSampler, DiffusionSamplerConfig
14 | from models.rew_end_model import RewEndModel
15 |
16 | from utils import get_frame_indices
17 |
18 | ResetOutput = Tuple[torch.FloatTensor, Dict[str, Any]]
19 | StepOutput = Tuple[Tensor, Tensor, Tensor, Tensor, Dict[str, Any]]
20 | InitialCondition = Tuple[Tensor, Tensor, Tuple[Tensor, Tensor]]
21 |
22 |
23 | crop_frame = {
24 | 'left_top': (0.04, 0.18),
25 | 'right_bottom': (0.92, 0.95)
26 | }
27 |
28 | def extract_roi(image, rect):
29 |     _, _, img_height, img_width = image.shape
30 |     top = int(rect['left_top'][1] * img_height)
31 |     bottom = int(rect['right_bottom'][1] * img_height)
32 |     left = int(rect['left_top'][0] * img_width)
33 |     right = int(rect['right_bottom'][0] * img_width)
34 |     roi = image[:, :, top:bottom, left:right]
35 |     return roi
36 |
37 | @dataclass
38 | class WorldModelEnvConfig:
39 | horizon: int
40 | num_batches_to_preload: int
41 | diffusion_sampler_next_obs: DiffusionSamplerConfig
42 | diffusion_sampler_upsampling: Optional[DiffusionSamplerConfig] = None
43 |
44 |
45 | class WorldModelEnv:
46 | def __init__(
47 | self,
48 | denoiser: Denoiser,
49 | upsampler: Optional[Denoiser],
50 | rew_end_model: Optional[RewEndModel],
51 | spawn_dir: Path,
52 | num_envs: int,
53 | seq_length: int,
54 | cfg: WorldModelEnvConfig,
55 | return_denoising_trajectory: bool = False,
56 | ) -> None:
57 | assert num_envs == 1 # for human play only
58 | self.sampler_next_obs = DiffusionSampler(denoiser, cfg.diffusion_sampler_next_obs)
59 | self.sampler_upsampling = None if upsampler is None else DiffusionSampler(upsampler, cfg.diffusion_sampler_upsampling)
60 | self.rew_end_model = rew_end_model
61 | self.horizon = cfg.horizon
62 | self.return_denoising_trajectory = return_denoising_trajectory
63 | self.num_envs = num_envs
64 | self.generator_init = self.make_generator_init(spawn_dir, cfg.num_batches_to_preload)
65 |
66 | self.n_skip_next_obs = seq_length - self.sampler_next_obs.denoiser.cfg.inner_model.num_steps_conditioning
67 | self.n_skip_upsampling = None if upsampler is None else seq_length - self.sampler_upsampling.denoiser.cfg.inner_model.num_steps_conditioning
68 |
69 | self.context_indicies = get_frame_indices(denoiser.cfg.frame_sampling)
70 |
71 | @property
72 | def device(self) -> torch.device:
73 | return self.sampler_next_obs.denoiser.device
74 |
75 | @torch.no_grad()
76 | def reset(self, **kwargs) -> ResetOutput:
77 | obs, obs_full_res, act, next_act, (hx, cx) = self.generator_init.send(self.num_envs)
78 | self.obs_buffer = obs
79 | self.act_buffer = act
80 | self.next_act = next_act[0]
81 | self.obs_full_res_buffer = obs_full_res
82 | self.ep_len = torch.zeros(self.num_envs, dtype=torch.long, device=obs.device)
83 | self.hx_rew_end = hx
84 | self.cx_rew_end = cx
85 | obs_to_return = self.obs_buffer[:, -1] if self.sampler_upsampling is None else self.obs_full_res_buffer[:, -1]
86 | return obs_to_return, {}
87 |
88 | @torch.no_grad()
89 | def step(self, act: torch.LongTensor) -> StepOutput:
90 | self.act_buffer[:, -1] = act
91 |
92 | next_obs, denoising_trajectory = self.predict_next_obs()
93 |
94 | if self.sampler_upsampling is not None:
95 | next_obs_full, denoising_trajectory_upsampling = self.upsample_next_obs(next_obs)
96 |
97 | if self.rew_end_model is not None:
98 | rew, end = self.predict_rew_end(next_obs.unsqueeze(1))
99 | else:
100 | rew = torch.zeros(next_obs.size(0), dtype=torch.float32, device=self.device)
101 | end = torch.zeros(next_obs.size(0), dtype=torch.int64, device=self.device)
102 |
103 | self.ep_len += 1
104 | trunc = (self.ep_len >= self.horizon).long()
105 |
106 | self.obs_buffer = self.obs_buffer.roll(-1, dims=1)
107 | self.act_buffer = self.act_buffer.roll(-1, dims=1)
108 | self.obs_buffer[:, -1] = next_obs
109 |
110 | if self.sampler_upsampling is not None:
111 | self.obs_full_res_buffer = self.obs_full_res_buffer.roll(-1, dims=1)
112 | self.obs_full_res_buffer[:, -1] = next_obs_full
113 |
114 | info = {}
115 | if self.return_denoising_trajectory:
116 | info["denoising_trajectory"] = torch.stack(denoising_trajectory, dim=1)
117 |
118 | if self.sampler_upsampling is not None:
119 | info["obs_low_res"] = next_obs
120 | if self.return_denoising_trajectory:
121 | info["denoising_trajectory_upsampling"] = torch.stack(denoising_trajectory_upsampling, dim=1)
122 |
123 | obs_to_return = self.obs_buffer[:, -1] if self.sampler_upsampling is None else self.obs_full_res_buffer[:, -1]
124 | return obs_to_return, rew, end, trunc, info
125 |
126 | @torch.no_grad()
127 | def predict_next_obs(self) -> Tuple[Tensor, List[Tensor]]:
128 | obs = self.obs_buffer[:, self.n_skip_next_obs:][:,self.context_indicies]
129 | act = self.act_buffer[:, self.n_skip_next_obs:][:,self.context_indicies]
130 |
131 | return self.sampler_next_obs.sample(obs, act)
132 |
133 | @torch.no_grad()
134 | def upsample_next_obs(self, next_obs: Tensor) -> Tuple[Tensor, List[Tensor]]:
135 | # Upsampling the low resolution frame from the world model
136 | low_res = F.interpolate(next_obs, scale_factor=self.sampler_upsampling.denoiser.cfg.upsampling_factor, mode="bicubic")
137 |
138 | # Cropping the frame to remove the sides of the screen
139 | low_res = extract_roi(low_res, crop_frame)
140 |
141 | # Reshape the frame to the upsampler's expected size
142 | size = (self.sampler_upsampling.denoiser.cfg.upsampling_frame_height, self.sampler_upsampling.denoiser.cfg.upsampling_frame_width)
143 | low_res = F.interpolate(low_res, size=size, mode='bicubic')
144 |
145 | return self.sampler_upsampling.sample(low_res.unsqueeze(1), None)
146 |
147 | @torch.no_grad()
148 | def predict_rew_end(self, next_obs: Tensor) -> Tuple[Tensor, Tensor]:
149 | logits_rew, logits_end, (self.hx_rew_end, self.cx_rew_end) = self.rew_end_model.predict_rew_end(
150 | self.obs_buffer[:, -1:],
151 | self.act_buffer[:, -1:],
152 | next_obs,
153 | (self.hx_rew_end, self.cx_rew_end),
154 | )
155 | rew = Categorical(logits=logits_rew).sample().squeeze(1) - 1.0 # in {-1, 0, 1}
156 | end = Categorical(logits=logits_end).sample().squeeze(1)
157 | return rew, end
158 |
159 | @coroutine
160 | def make_generator_init(
161 | self,
162 | spawn_dir: Path,
163 | num_batches_to_preload: int,
164 | ) -> Generator[InitialCondition, None, None]:
165 | num_dead = yield
166 |
167 | spawn_dirs = cycle(sorted(list(spawn_dir.iterdir())))
168 |
169 | while True:
170 | # Preload on device and burnin rew/end model
171 | obs_, obs_full_res_, act_, next_act_, hx_, cx_ = [], [], [], [], [], []
172 | for _ in range(num_batches_to_preload):
173 | d = next(spawn_dirs)
174 | obs = torch.tensor(np.load(d / "low_res.npy"), device=self.device).div(255).mul(2).sub(1).unsqueeze(0)
175 | obs_full_res = torch.tensor(np.load(d / "full_res.npy"), device=self.device).div(255).mul(2).sub(1).unsqueeze(0)
176 | act = torch.tensor(np.load(d / "act.npy"), dtype=torch.long, device=self.device).unsqueeze(0)
177 | next_act = torch.tensor(np.load(d / "next_act.npy"), dtype=torch.long, device=self.device).unsqueeze(0)
178 |
179 | obs_.extend(list(obs))
180 | obs_full_res_.extend(list(obs_full_res))
181 | act_.extend(list(act))
182 | next_act_.extend(list(next_act))
183 |
184 | if self.rew_end_model is not None:
185 | with torch.no_grad():
186 |                         *_, (hx, cx) = self.rew_end_model.predict_rew_end(obs[:, :-1], act[:, :-1], obs[:, 1:])  # Burn-in of rew/end model
187 | assert hx.size(0) == cx.size(0) == 1
188 | hx_.extend(list(hx[0]))
189 | cx_.extend(list(cx[0]))
190 |
191 | # Yield new initial conditions for dead envs
192 | c = 0
193 | while c + num_dead <= len(obs_):
194 | obs = torch.stack(obs_[c : c + num_dead])
195 | act = torch.stack(act_[c : c + num_dead])
196 | next_act = next_act_[c : c + num_dead]
197 | obs_full_res = torch.stack(obs_full_res_[c : c + num_dead]) if self.sampler_upsampling is not None else None
198 | hx = torch.stack(hx_[c : c + num_dead]).unsqueeze(0) if self.rew_end_model is not None else None
199 | cx = torch.stack(cx_[c : c + num_dead]).unsqueeze(0) if self.rew_end_model is not None else None
200 | c += num_dead
201 | num_dead = yield obs, obs_full_res, act, next_act, (hx, cx)
202 |
--------------------------------------------------------------------------------
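
`WorldModelEnv.step` keeps a fixed-length context window by rolling each buffer one slot to the left and overwriting the last slot with the newest frame or action. The pattern in isolation:

```python
import torch

buffer = torch.arange(4).unsqueeze(0)  # (num_envs=1, seq_length=4): [[0, 1, 2, 3]]

# One step: shift the history left, then write the newest entry at the end,
# mirroring obs_buffer.roll(-1, dims=1) followed by obs_buffer[:, -1] = next_obs.
buffer = buffer.roll(-1, dims=1)
buffer[:, -1] = 4
assert buffer.tolist() == [[1, 2, 3, 4]]
```
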
/src/game/__init__.py:
--------------------------------------------------------------------------------
1 | from .game import Game
2 | from .play_env import PlayEnv
3 |
--------------------------------------------------------------------------------
/src/game/dataset_env.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Tuple
2 |
3 | import torch
4 | from torch import Tensor
5 |
6 | from data import Dataset
7 |
8 |
9 | class DatasetEnv:
10 | def __init__(self, datasets: List[Dataset], action_names: List[str]) -> None:
11 | self.datasets = [d for d in datasets if len(d) > 0]
12 | assert len(self.datasets) > 0
13 | self.action_names = action_names
14 | self.dataset_id = 0
15 | self.dataset = self.datasets[0]
16 | self.episode_id = None
17 | self.episode = None
18 | self.t = None
19 | self.ep_return = None
20 | self.ep_length = None
21 | self.pos_return = None
22 | self.neg_return = None
23 | self.load_episode(0)
24 |
25 | def print_controls(self) -> None:
26 | print("\nControls (dataset mode):\n")
27 | print(f"m : datasets ({'/'.join([d.name for d in self.datasets])})")
28 | print("↑ : next episode")
29 | print("↓ : prev episode")
30 | print("→ : next timestep")
31 | print("← : prev timestep")
32 |
33 | def next_mode(self) -> bool:
34 | self.switch_dataset()
35 | return True
36 |
37 | def next_axis_1(self) -> bool:
38 | self.load_episode(self.episode_id + 1)
39 | return True
40 |
41 | def prev_axis_1(self) -> bool:
42 | self.load_episode(self.episode_id - 1)
43 | return True
44 |
45 | def next_axis_2(self) -> bool:
46 | return False
47 |
48 | def prev_axis_2(self) -> bool:
49 | return False
50 |
51 | def load_episode(self, episode_id: int) -> None:
52 | self.episode_id = episode_id % self.dataset.num_episodes
53 | self.episode = self.dataset.load_episode(self.episode_id)
54 | self.set_timestep(0)
55 | metrics = self.episode.compute_metrics()
56 | self.ep_return = metrics["return"]
57 | self.ep_length = metrics["length"]
58 | self.pos_return = self.episode.rew[self.episode.rew > 0].sum().item()
59 | self.neg_return = self.episode.rew[self.episode.rew < 0].sum().abs().item()
60 |
61 | def set_timestep(self, timestep: int) -> None:
62 | self.t = timestep % len(self.episode)
63 | self.obs = self.episode.obs[self.t].unsqueeze(0)
64 | self.act = self.episode.act[self.t]
65 | self.rew = self.episode.rew[self.t]
66 | self.end = self.episode.end[self.t]
67 | self.trunc = self.episode.trunc[self.t]
68 |
69 | def switch_dataset(self) -> None:
70 | self.dataset_id = (self.dataset_id + 1) % len(self.datasets)
71 | self.dataset = self.datasets[self.dataset_id]
72 | self.load_episode(0)
73 |
74 |     def reset(self) -> Tuple[Tensor, None]:
75 | self.set_timestep(0)
76 | return self.obs, None
77 |
78 | @torch.no_grad()
79 | def step(self, act: int) -> Tuple[Tensor, Tensor, bool, bool, Dict[str, Any]]:
80 | match act:
81 | case 1:
82 | self.set_timestep(self.t - 1)
83 | case 2:
84 | self.set_timestep(self.t + 1)
85 | case 3:
86 | self.set_timestep(self.t - 10)
87 | case 4:
88 | self.set_timestep(self.t + 10)
89 |
90 | n_digits = len(str(self.ep_length))
91 |
92 | header = [
93 | [
94 | f"Dataset: {self.dataset.name}",
95 | f"Episode: {self.episode_id}",
96 | "--------",
97 | f"Return (+): +{self.pos_return:4.1f}",
98 | f"Return (-): -{self.neg_return:4.1f}",
99 | f"Total : {self.ep_return:4.1f}",
100 | ],
101 | [
102 | f"Action: {self.action_names[self.act]}",
103 | f"Trunc : {bool(self.trunc)}",
104 | f"Done : {bool(self.end)}",
105 | f"Reward: {self.rew.item():.2f}",
106 | "-------",
107 | f"To here: {self.episode.rew[:self.t + 1].sum().item():.2f}",
108 | f"To go : {self.episode.rew[self.t + 1:].sum().item():.2f}",
109 | ],
110 | [
111 | f"Timestep: {self.t:{n_digits}d}",
112 | f"Length : {self.ep_length}",
113 | ],
114 | ]
115 | info = {"header": header}
116 | return self.obs, torch.tensor(0), False, False, info
117 |
--------------------------------------------------------------------------------
/src/game/game.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Union
2 |
3 | import numpy as np
4 | import pygame
5 | from PIL import Image
6 |
7 | from player.action_processing import GameAction
8 | from .dataset_env import DatasetEnv
9 | from .play_env import PlayEnv
10 |
11 | class Game:
12 | def __init__(
13 | self,
14 | play_env: Union[PlayEnv, DatasetEnv],
15 | size: Tuple[int, int],
16 | fps: int,
17 | verbose: bool,
18 | ) -> None:
19 | self.env = play_env
20 | self.height, self.width = size
21 | self.fps = fps
22 | self.verbose = verbose
23 | self.env.print_controls()
24 | print("\nControls:\n")
25 | print(" m : switch control (human/replay)") # Not for main as Game can use either PlayEnv or DatasetEnv
26 | print(" . : pause/unpause")
27 | print(" e : step-by-step (when paused)")
28 | print(" ⏎ : reset env")
29 | print("Esc : quit")
30 | print("\n")
31 | input("Press enter to start")
32 |
33 | def run(self) -> None:
34 | pygame.init()
35 |
36 | header_height = 150 if self.verbose else 0
37 | header_width = 540
38 | font_size = 16
39 | screen = pygame.display.set_mode((0, 0), pygame.FULLSCREEN)
40 | pygame.mouse.set_visible(False)
41 | pygame.event.set_grab(True)
42 | clock = pygame.time.Clock()
43 | font = pygame.font.SysFont("mono", font_size)
44 | x_center, y_center = screen.get_rect().center
45 | x_header = x_center - header_width // 2
46 | y_header = y_center - self.height // 2 - header_height - 10
47 | header_rect = pygame.Rect(x_header, y_header, header_width, header_height)
48 |
49 | def clear_header():
50 | pygame.draw.rect(screen, pygame.Color("black"), header_rect)
51 | pygame.draw.rect(screen, pygame.Color("white"), header_rect, 1)
52 |
53 | def draw_text(text, idx_line, idx_column, num_cols):
54 | x_pos = 5 + idx_column * int(header_width // num_cols)
55 | y_pos = 5 + idx_line * font_size
56 | assert (0 <= x_pos <= header_width) and (0 <= y_pos <= header_height)
57 | screen.blit(font.render(text, True, pygame.Color("white")), (x_header + x_pos, y_header + y_pos))
58 |
59 | def draw_obs(obs, obs_low_res=None):
60 | assert obs.ndim == 4 and obs.size(0) == 1
61 | two_players = obs[0].add(1).div(2).mul(255).byte().permute(1, 2, 0).cpu().numpy()
62 | player_1_frame, player_2_frame = two_players[:,:,:3], two_players[:, :, 3:]
63 |
64 | # Stack images vertically
65 | stacked_arr = np.vstack((player_1_frame, player_2_frame))
66 |
67 | # Convert back to image
68 | img = Image.fromarray(stacked_arr)
69 |
70 |             # Resize the image and transpose it to the (width, height) layout pygame expects
71 | pygame_image = np.array(img.resize((self.width, self.height), resample=Image.BOX)).transpose((1, 0, 2))
72 |
73 | surface = pygame.surfarray.make_surface(pygame_image)
74 | screen.blit(surface, (x_center - self.width // 2, y_center - self.height // 2))
75 |
76 | if obs_low_res is not None:
77 | assert obs_low_res.ndim == 4 and obs_low_res.size(0) == 1
78 | img = Image.fromarray(obs_low_res[0].add(1).div(2).mul(255).byte().permute(1, 2, 0).cpu().numpy())
79 | h = self.height * obs_low_res.size(2) // obs.size(2)
80 | w = self.width * obs_low_res.size(3) // obs.size(3)
81 | pygame_image = np.array(img)
82 | surface = pygame.surfarray.make_surface(pygame_image)
83 | screen.blit(surface, (x_center - w // 2, y_center + self.height // 2))
84 |
85 | def reset():
86 | nonlocal obs, info, do_reset, ep_return, ep_length, keys_pressed
87 | obs, info = self.env.reset()
88 | pygame.event.clear()
89 | do_reset = False
90 | ep_return = 0
91 | ep_length = 0
92 | keys_pressed = []
93 |
94 |         obs, info, do_reset, ep_return, ep_length, keys_pressed = (None,) * 6
95 |
96 | reset()
97 | do_wait = False
98 | should_stop = False
99 |
100 | while not should_stop:
101 | do_one_step = False
102 | pygame.event.pump()
103 |
104 | for event in pygame.event.get():
105 | if event.type == pygame.QUIT or (event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE):
106 | should_stop = True
107 |
108 | if event.type == pygame.KEYDOWN:
109 | keys_pressed.append(event.key)
110 |
111 | elif event.type == pygame.KEYUP and event.key in keys_pressed:
112 | keys_pressed.remove(event.key)
113 |
114 | if event.type != pygame.KEYDOWN:
115 | continue
116 |
117 | if event.key == pygame.K_RETURN:
118 | do_reset = True
119 |
120 | if event.key == pygame.K_PERIOD:
121 | do_wait = not do_wait
122 | print("Game paused." if do_wait else "Game resumed.")
123 |
124 | if event.key == pygame.K_e:
125 | do_one_step = True
126 |
127 | if event.key == pygame.K_m:
128 | do_reset = self.env.next_mode()
129 |
130 | if event.key == pygame.K_UP:
131 | do_reset = self.env.next_axis_1()
132 |
133 | if event.key == pygame.K_DOWN:
134 | do_reset = self.env.prev_axis_1()
135 |
136 | if event.key == pygame.K_RIGHT:
137 | do_reset = self.env.next_axis_2()
138 |
139 | if event.key == pygame.K_LEFT:
140 | do_reset = self.env.prev_axis_2()
141 |
142 | if do_reset:
143 | reset()
144 |
145 | if do_wait and not do_one_step:
146 | continue
147 |
148 | game_action = GameAction(keys_pressed)
149 | next_obs, rew, end, trunc, info = self.env.step(game_action)
150 |
151 | ep_return += rew.item()
152 | ep_length += 1
153 |
154 | if self.verbose and info is not None:
155 | clear_header()
156 | assert isinstance(info, dict) and "header" in info
157 | header = info["header"]
158 | num_cols = len(header)
159 | for j, col in enumerate(header):
160 | for i, row in enumerate(col):
161 | draw_text(row, idx_line=i, idx_column=j, num_cols=num_cols)
162 |
163 | draw_low_res = self.verbose and "obs_low_res" in info and self.width == 280
164 | if draw_low_res:
165 | draw_obs(obs, info["obs_low_res"])
166 | draw_text(" Pre-upsampling:", 0, 2, 3)
167 | else:
168 | draw_obs(obs, None)
169 |
170 | pygame.display.flip() # update screen
171 | clock.tick(self.fps) # ensures game maintains the given frame rate
172 |
173 | if end or trunc:
174 | reset()
175 |
176 | else:
177 | obs = next_obs
178 |
179 | pygame.quit()
180 |
--------------------------------------------------------------------------------
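
In `draw_obs`, a single observation carries both players' views as six channels: the first three are player 1's RGB frame and the last three are player 2's, and the two frames are stacked vertically for display. A standalone sketch of the split-and-stack (sizes are example values):

```python
import numpy as np

# A fake 6-channel observation in HWC layout, as obtained after the permute in draw_obs.
two_players = np.random.randint(0, 256, (48, 64, 6), dtype=np.uint8)

player_1_frame = two_players[:, :, :3]  # player 1's RGB view
player_2_frame = two_players[:, :, 3:]  # player 2's RGB view

stacked = np.vstack((player_1_frame, player_2_frame))  # player 2 rendered below player 1
assert stacked.shape == (96, 64, 3)
```
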
/src/game/play_env.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict, namedtuple
2 | import math
3 | from pathlib import Path
4 | from typing import Any, Dict, List, Tuple
5 |
6 | import pygame
7 | import torch
8 | from torch import Tensor
9 |
10 | from agent import Agent
11 | from player.action_processing import GameAction, decode_game_action, encode_game_action, print_game_action
12 | from player.keymap import GAME_KEYMAP
13 | from data import Dataset, Episode
14 | from envs import WorldModelEnv
15 |
16 |
17 | NamedEnv = namedtuple("NamedEnv", "name env")
18 | OneStepData = namedtuple("OneStepData", "obs act rew end trunc")
19 |
20 |
21 | class PlayEnv:
22 | def __init__(
23 | self,
24 | agent: Agent,
25 | wm_env: WorldModelEnv,
26 | recording_mode: bool,
27 | store_denoising_trajectory: bool,
28 | store_original_obs: bool,
29 | ) -> None:
30 | self.agent = agent
31 | self.keymap = GAME_KEYMAP
32 | self.recording_mode = recording_mode
33 | self.store_denoising_trajectory = store_denoising_trajectory
34 | self.store_original_obs = store_original_obs
35 | self.is_human_player = True
36 | self.env_id = 0
37 | self.env_name = "world model"
38 | self.env = wm_env
39 | self.obs, self.t, self.buffer, self.rec_dataset = (None,) * 4
40 |
41 | def print_controls(self) -> None:
42 | print("\nEnvironment actions:\n")
43 | for key, action_name in self.keymap.items():
44 | if key is not None:
45 | key_name = pygame.key.name(key)
46 | key_name = "⎵" if key_name == "space" else key_name
47 | print(f"{key_name} : {action_name}")
48 |
49 | def next_mode(self) -> bool:
50 | self.switch_controller()
51 | return True
52 |
53 | def next_axis_1(self) -> bool:
54 | return False
55 |
56 | def prev_axis_1(self) -> bool:
57 | return False
58 |
59 | def next_axis_2(self) -> bool:
60 | return False
61 |
62 | def prev_axis_2(self) -> bool:
63 | return False
64 |
65 | def print_env(self) -> None:
66 | print(f"> Environment: {self.env_name}")
67 |
68 | def str_control(self) -> str:
69 | return "human" if self.is_human_player else "replay actions (test dataset)"
70 |
71 | def print_control(self) -> None:
72 | print(f"> Control: {self.str_control()}")
73 |
74 | def switch_controller(self) -> None:
75 | self.is_human_player = not self.is_human_player
76 | self.print_control()
77 |
78 | def update_wm_horizon(self, incr: int) -> None:
79 | self.env.horizon = max(1, self.env.horizon + incr)
80 |
81 | def reset_recording(self) -> None:
82 | self.buffer = defaultdict(list)
83 | self.buffer["info"] = defaultdict(list)
84 | dir = Path("dataset") / f"rec_{self.env_name}_{'H' if self.is_human_player else 'R'}"
85 | self.rec_dataset = Dataset(dir, None)
86 | self.rec_dataset.load_from_default_path()
87 |
88 | def reset(self) -> Tuple[Tensor, None]:
89 | self.obs, _ = self.env.reset()
90 | self.t = 0
91 | if self.recording_mode:
92 | self.reset_recording()
93 | return self.obs, None
94 |
95 | @torch.no_grad()
96 | def step(self, game_action: GameAction) -> Tuple[Tensor, Tensor, Tensor, Tensor, Dict[str, Any]]:
97 | if self.is_human_player:
98 | action = encode_game_action(game_action, device=self.agent.device)
99 | else:
100 | action = self.env.next_act[self.t - 1] if self.t > 0 else self.env.act_buffer[0, -1].clone()
101 | game_action = decode_game_action(action.cpu())
102 |
103 | next_obs, rew, end, trunc, env_info = self.env.step(action)
104 |
105 | if not self.is_human_player and self.t == self.env.next_act.size(0):
106 | trunc[0] = 1
107 |
108 | data = OneStepData(self.obs, action, rew, end, trunc)
109 | keys = print_game_action(game_action)
110 | horizon = self.env.horizon if self.is_human_player else min(self.env.horizon, self.env.next_act.size(0))
111 | header = [
112 | [
113 | f"Env : {self.env_name}",
114 | f"Control : {self.str_control()}",
115 | f"Timestep: {self.t + 1}",
116 | f"Horizon : {horizon}",
117 | f"Keys : {keys}",
118 | ],
119 | ]
120 | info = {"header": header}
121 | if "obs_low_res" in env_info:
122 | info["obs_low_res"] = env_info["obs_low_res"]
123 |
124 | if self.recording_mode:
125 | for k, v in data._asdict().items():
126 | self.buffer[k].append(v)
127 | if "obs_low_res" in env_info:
128 | self.buffer["info"]["obs_low_res"].append(env_info["obs_low_res"])
129 | if self.store_denoising_trajectory and "denoising_trajectory" in env_info:
130 | self.buffer["info"]["denoising_trajectory"].append(env_info["denoising_trajectory"])
131 | if self.store_original_obs and "original_obs" in env_info:
132 | original_obs = (torch.tensor(env_info["original_obs"][0]).permute(2, 0, 1).unsqueeze(0).contiguous())
133 | self.buffer["info"]["original_obs"].append(original_obs)
134 | if end or trunc:
135 | ep_dict = {k: torch.cat(v, dim=0) for k, v in self.buffer.items() if k != "info"}
136 | ep_info = {k: torch.cat(v, dim=0) for k, v in self.buffer["info"].items()}
137 | ep = Episode(**ep_dict, info=ep_info).to("cpu")
138 | self.rec_dataset.add_episode(ep)
139 | self.rec_dataset.save_to_default_path()
140 |
141 | self.obs = next_obs
142 | self.t += 1
143 |
144 | return next_obs, rew, end, trunc, info
145 |
--------------------------------------------------------------------------------
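
The recording path appends one `OneStepData` per step to per-field lists and concatenates each field along the time axis when the episode ends. A minimal sketch of that buffering (the real `Episode` class lives in `src/data`; the tensor shapes here are placeholders):

```python
from collections import defaultdict, namedtuple

import torch

OneStepData = namedtuple("OneStepData", "obs act rew end trunc")

buffer = defaultdict(list)
for _ in range(3):  # three fake steps
    data = OneStepData(
        obs=torch.zeros(1, 3, 8, 8),
        act=torch.zeros(1, dtype=torch.long),
        rew=torch.zeros(1),
        end=torch.zeros(1, dtype=torch.uint8),
        trunc=torch.zeros(1, dtype=torch.uint8),
    )
    for k, v in data._asdict().items():
        buffer[k].append(v)

ep_dict = {k: torch.cat(v, dim=0) for k, v in buffer.items()}
assert ep_dict["obs"].shape == (3, 3, 8, 8) and ep_dict["act"].shape == (3,)
```
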
/src/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from typing import List, Union
4 |
5 | import hydra
6 | from omegaconf import DictConfig, OmegaConf
7 | import torch
8 | from torch.distributed import init_process_group, destroy_process_group
9 | import torch.multiprocessing as mp
10 |
11 | from trainer import Trainer
12 | from utils import skip_if_run_is_over
13 |
14 |
15 | OmegaConf.register_new_resolver("eval", eval)
16 |
17 |
18 | @hydra.main(config_path="../config", config_name="trainer", version_base="1.3")
19 | def main(cfg: DictConfig) -> None:
20 | setup_visible_cuda_devices(cfg.common.devices)
21 | world_size = torch.cuda.device_count()
22 | root_dir = Path(hydra.utils.get_original_cwd())
23 | if world_size < 2:
24 | run(cfg, root_dir)
25 | else:
26 | mp.spawn(main_ddp, args=(world_size, cfg, root_dir), nprocs=world_size)
27 |
28 |
29 | def main_ddp(rank: int, world_size: int, cfg: DictConfig, root_dir: Path) -> None:
30 | setup_ddp(rank, world_size)
31 | run(cfg, root_dir)
32 | destroy_process_group()
33 |
34 |
35 | @skip_if_run_is_over
36 | def run(cfg: DictConfig, root_dir: Path) -> None:
37 | trainer = Trainer(cfg, root_dir)
38 | trainer.run()
39 |
40 |
41 | def setup_ddp(rank: int, world_size: int) -> None:
42 | os.environ["MASTER_ADDR"] = "localhost"
43 | os.environ["MASTER_PORT"] = "6006"
44 | init_process_group(backend="nccl", rank=rank, world_size=world_size)
45 |
46 |
47 | def setup_visible_cuda_devices(devices: Union[str, int, List[int]]) -> None:
48 | if isinstance(devices, str):
49 | if devices == "cpu":
50 | devices = []
51 | else:
52 | assert devices == "all"
53 | return
54 | elif isinstance(devices, int):
55 | devices = [devices]
56 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, devices))
57 |
58 |
59 | if __name__ == "__main__":
60 | main()
61 |
--------------------------------------------------------------------------------
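
For reference, `setup_visible_cuda_devices` accepts four shapes of `cfg.common.devices`. A sketch mirroring its logic, returning the `CUDA_VISIBLE_DEVICES` value it would set instead of mutating the environment:

```python
def visible_devices(devices):
    # "cpu" hides every GPU; "all" leaves CUDA_VISIBLE_DEVICES untouched (None here);
    # an int selects one device; a list selects several.
    if isinstance(devices, str):
        if devices == "cpu":
            devices = []
        else:
            assert devices == "all"
            return None
    elif isinstance(devices, int):
        devices = [devices]
    return ",".join(map(str, devices))

assert visible_devices("cpu") == ""
assert visible_devices("all") is None
assert visible_devices(1) == "1"
assert visible_devices([0, 2]) == "0,2"
```
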
/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/src/models/__init__.py
--------------------------------------------------------------------------------
/src/models/actor_critic.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from dataclasses import dataclass
3 | import math
4 | from typing import List, Optional, Tuple, Union
5 |
6 | import torch
7 | from torch import Tensor
8 | import torch.nn as nn
9 | from torch.distributions.categorical import Categorical
10 | import torch.nn.functional as F
11 |
12 | from .blocks import Conv3x3, SmallResBlock
13 | from coroutines.env_loop import make_env_loop
14 | from envs import TorchEnv, WorldModelEnv
15 | from utils import init_lstm, LossAndLogs
16 |
17 |
18 | ActorCriticOutput = namedtuple("ActorCriticOutput", "logits_act val hx_cx")
19 |
20 |
21 | @dataclass
22 | class ActorCriticLossConfig:
23 | backup_every: int
24 | gamma: float
25 | lambda_: float
26 | weight_value_loss: float
27 | weight_entropy_loss: float
28 |
29 |
30 | @dataclass
31 | class ActorCriticConfig:
32 | lstm_dim: int
33 | img_channels: int
34 | img_size: int
35 | channels: List[int]
36 | down: List[int]
37 | num_actions: Optional[int] = None
38 |
39 |
40 | class ActorCritic(nn.Module):
41 | def __init__(self, cfg: ActorCriticConfig) -> None:
42 | super().__init__()
43 | self.encoder = ActorCriticEncoder(cfg)
44 | self.lstm_dim = cfg.lstm_dim
45 | input_dim_lstm = cfg.channels[-1] * (cfg.img_size // 2 ** (sum(cfg.down))) ** 2
46 | self.lstm = nn.LSTMCell(input_dim_lstm, cfg.lstm_dim)
47 | self.critic_linear = nn.Linear(cfg.lstm_dim, 1)
48 | self.actor_linear = nn.Linear(cfg.lstm_dim, cfg.num_actions)
49 |
50 | self.actor_linear.weight.data.fill_(0)
51 | self.actor_linear.bias.data.fill_(0)
52 | self.critic_linear.weight.data.fill_(0)
53 | self.critic_linear.bias.data.fill_(0)
54 | init_lstm(self.lstm)
55 |
56 | self.env_loop = None
57 | self.loss_cfg = None
58 |
59 | @property
60 | def device(self) -> torch.device:
61 | return self.lstm.weight_hh.device
62 |
63 | def setup_training(self, rl_env: Union[TorchEnv, WorldModelEnv], loss_cfg: ActorCriticLossConfig) -> None:
64 | assert self.env_loop is None and self.loss_cfg is None
65 | self.env_loop = make_env_loop(rl_env, self)
66 | self.loss_cfg = loss_cfg
67 |
68 | def predict_act_value(self, obs: Tensor, hx_cx: Tuple[Tensor, Tensor]) -> ActorCriticOutput:
69 | assert obs.ndim == 4
70 | x = self.encoder(obs)
71 | x = x.flatten(start_dim=1)
72 | hx, cx = self.lstm(x, hx_cx)
73 | return ActorCriticOutput(self.actor_linear(hx), self.critic_linear(hx).squeeze(dim=1), (hx, cx))
74 |
75 | def forward(self) -> LossAndLogs:
76 | c = self.loss_cfg
77 | _, act, rew, end, trunc, logits_act, val, val_bootstrap, _ = self.env_loop.send(c.backup_every)
78 |
79 | d = Categorical(logits=logits_act)
80 | entropy = d.entropy().mean()
81 |
82 | lambda_returns = compute_lambda_returns(rew, end, trunc, val_bootstrap, c.gamma, c.lambda_)
83 |
84 | loss_actions = (-d.log_prob(act) * (lambda_returns - val).detach()).mean()
85 | loss_values = c.weight_value_loss * F.mse_loss(val, lambda_returns)
86 | loss_entropy = -c.weight_entropy_loss * entropy
87 |
88 | loss = loss_actions + loss_entropy + loss_values
89 |
90 | metrics = {
91 | "policy_entropy": entropy.detach() / math.log(2),
92 | "loss_actions": loss_actions.detach(),
93 | "loss_entropy": loss_entropy.detach(),
94 | "loss_values": loss_values.detach(),
95 | "loss_total": loss.detach(),
96 | }
97 |
98 | return loss, metrics
99 |
100 |
101 | class ActorCriticEncoder(nn.Module):
102 | def __init__(self, cfg: ActorCriticConfig) -> None:
103 | super().__init__()
104 | assert len(cfg.channels) == len(cfg.down)
105 | encoder_layers = [Conv3x3(cfg.img_channels, cfg.channels[0])]
106 | for i in range(len(cfg.channels)):
107 | encoder_layers.append(SmallResBlock(cfg.channels[max(0, i - 1)], cfg.channels[i]))
108 | if cfg.down[i]:
109 | encoder_layers.append(nn.MaxPool2d(2))
110 | self.encoder = nn.Sequential(*encoder_layers)
111 |
112 | def forward(self, x: Tensor) -> Tensor:
113 | return self.encoder(x)
114 |
115 |
116 | @torch.no_grad()
117 | def compute_lambda_returns(
118 | rew: Tensor,
119 | end: Tensor,
120 | trunc: Tensor,
121 | val_bootstrap: Tensor,
122 | gamma: float,
123 | lambda_: float,
124 | ) -> Tensor:
125 | assert rew.ndim == 2 and rew.size() == end.size() == trunc.size() == val_bootstrap.size()
126 |
127 | rew = rew.sign() # clip reward
128 |
129 | end_or_trunc = (end + trunc).clip(max=1)
130 | not_end = 1 - end
131 | not_trunc = 1 - trunc
132 |
133 | lambda_returns = rew + not_end * gamma * (not_trunc * (1 - lambda_) + trunc) * val_bootstrap
134 |
135 | if lambda_ == 0:
136 | return lambda_returns
137 |
138 | last = val_bootstrap[:, -1]
139 | for t in reversed(range(rew.size(1))):
140 | lambda_returns[:, t] += end_or_trunc[:, t].logical_not() * gamma * lambda_ * last
141 | last = lambda_returns[:, t]
142 |
143 | return lambda_returns
144 |
--------------------------------------------------------------------------------
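
`compute_lambda_returns` is the standard λ-return recursion G_t = r_t + γ((1 − λ)·V(s_{t+1}) + λ·G_{t+1}), bootstrapping from the critic at truncations and zeroing future terms after true terminals. A hand-checkable sketch of the happy path (no termination or truncation, so the masks drop out; numbers are example values):

```python
import torch

gamma, lambda_ = 0.9, 0.5
rew = torch.tensor([[1.0, 1.0, 1.0]])
val_bootstrap = torch.tensor([[2.0, 2.0, 2.0]])  # V(s_{t+1}) at each step

# Base term: r_t + gamma * (1 - lambda) * V(s_{t+1}).
returns = rew + gamma * (1 - lambda_) * val_bootstrap

# Backward recursion adds gamma * lambda * G_{t+1}, seeded with the final bootstrap value.
last = val_bootstrap[:, -1]
for t in reversed(range(rew.size(1))):
    returns[:, t] += gamma * lambda_ * last
    last = returns[:, t]

# Hand check at t=2: 1 + 0.9 * (0.5 * 2 + 0.5 * 2) = 2.8
assert torch.allclose(returns[:, 2], torch.tensor([2.8]))
```
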
/src/models/blocks.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import math
3 | from typing import List, Optional
4 |
5 | import torch
6 | from torch import Tensor
7 | from torch import nn
8 | from torch.nn import functional as F
9 |
10 | # Settings for GroupNorm and Attention
11 |
12 | GN_GROUP_SIZE = 32
13 | GN_EPS = 1e-5
14 | ATTN_HEAD_DIM = 8
15 |
16 | # Convs
17 |
18 | Conv1x1 = partial(nn.Conv2d, kernel_size=1, stride=1, padding=0)
19 | Conv3x3 = partial(nn.Conv2d, kernel_size=3, stride=1, padding=1)
20 |
21 | # GroupNorm and conditional GroupNorm
22 |
23 |
24 | class GroupNorm(nn.Module):
25 | def __init__(self, in_channels: int) -> None:
26 | super().__init__()
27 | num_groups = max(1, in_channels // GN_GROUP_SIZE)
28 | self.norm = nn.GroupNorm(num_groups, in_channels, eps=GN_EPS)
29 |
30 | def forward(self, x: Tensor) -> Tensor:
31 | return self.norm(x)
32 |
33 |
34 | class AdaGroupNorm(nn.Module):
35 | def __init__(self, in_channels: int, cond_channels: int) -> None:
36 | super().__init__()
37 | self.in_channels = in_channels
38 | self.num_groups = max(1, in_channels // GN_GROUP_SIZE)
39 | self.linear = nn.Linear(cond_channels, in_channels * 2)
40 |
41 | def forward(self, x: Tensor, cond: Tensor) -> Tensor:
42 | assert x.size(1) == self.in_channels
43 | x = F.group_norm(x, self.num_groups, eps=GN_EPS)
44 | scale, shift = self.linear(cond)[:, :, None, None].chunk(2, dim=1)
45 | return x * (1 + scale) + shift
46 |
47 |
48 | # Self Attention
49 |
50 |
51 | class SelfAttention2d(nn.Module):
52 | def __init__(self, in_channels: int, head_dim: int = ATTN_HEAD_DIM) -> None:
53 | super().__init__()
54 | self.n_head = max(1, in_channels // head_dim)
55 | assert in_channels % self.n_head == 0
56 | self.norm = GroupNorm(in_channels)
57 | self.qkv_proj = Conv1x1(in_channels, in_channels * 3)
58 | self.out_proj = Conv1x1(in_channels, in_channels)
59 | nn.init.zeros_(self.out_proj.weight)
60 | nn.init.zeros_(self.out_proj.bias)
61 |
62 | def forward(self, x: Tensor) -> Tensor:
63 | n, c, h, w = x.shape
64 | x = self.norm(x)
65 | qkv = self.qkv_proj(x)
66 | qkv = qkv.view(n, self.n_head * 3, c // self.n_head, h * w).transpose(2, 3).contiguous()
67 |         q, k, v = qkv.chunk(3, dim=1)
68 | att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
69 | att = F.softmax(att, dim=-1)
70 | y = att @ v
71 | y = y.transpose(2, 3).reshape(n, c, h, w)
72 | return x + self.out_proj(y)
73 |
74 |
75 | # Embedding of the noise level
76 |
77 |
78 | class FourierFeatures(nn.Module):
79 | def __init__(self, cond_channels: int) -> None:
80 | super().__init__()
81 | assert cond_channels % 2 == 0
82 | self.register_buffer("weight", torch.randn(1, cond_channels // 2))
83 |
84 | def forward(self, input: Tensor) -> Tensor:
85 | assert input.ndim == 1
86 | f = 2 * math.pi * input.unsqueeze(1) @ self.weight
87 | return torch.cat([f.cos(), f.sin()], dim=-1)
88 |
89 |
90 | # [Down|Up]sampling
91 |
92 |
93 | class Downsample(nn.Module):
94 | def __init__(self, in_channels: int) -> None:
95 | super().__init__()
96 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1)
97 | nn.init.orthogonal_(self.conv.weight)
98 |
99 | def forward(self, x: Tensor) -> Tensor:
100 | return self.conv(x)
101 |
102 |
103 | class Upsample(nn.Module):
104 | def __init__(self, in_channels: int) -> None:
105 | super().__init__()
106 | self.conv = Conv3x3(in_channels, in_channels)
107 |
108 | def forward(self, x: Tensor) -> Tensor:
109 | x = F.interpolate(x, scale_factor=2.0, mode="nearest")
110 | return self.conv(x)
111 |
112 |
113 | # Small Residual block
114 |
115 |
116 | class SmallResBlock(nn.Module):
117 | def __init__(self, in_channels: int, out_channels: int) -> None:
118 | super().__init__()
119 | self.f = nn.Sequential(GroupNorm(in_channels), nn.SiLU(inplace=True), Conv3x3(in_channels, out_channels))
120 | self.skip_projection = nn.Identity() if in_channels == out_channels else Conv1x1(in_channels, out_channels)
121 |
122 | def forward(self, x: Tensor) -> Tensor:
123 | return self.skip_projection(x) + self.f(x)
124 |
125 |
126 | # Residual block (conditioning with AdaGroupNorm, no [down|up]sampling, optional self-attention)
127 |
128 |
129 | class ResBlock(nn.Module):
130 | def __init__(self, in_channels: int, out_channels: int, cond_channels: int, attn: bool) -> None:
131 | super().__init__()
132 | should_proj = in_channels != out_channels
133 | self.proj = Conv1x1(in_channels, out_channels) if should_proj else nn.Identity()
134 | self.norm1 = AdaGroupNorm(in_channels, cond_channels)
135 | self.conv1 = Conv3x3(in_channels, out_channels)
136 | self.norm2 = AdaGroupNorm(out_channels, cond_channels)
137 | self.conv2 = Conv3x3(out_channels, out_channels)
138 | self.attn = SelfAttention2d(out_channels) if attn else nn.Identity()
139 | nn.init.zeros_(self.conv2.weight)
140 |
141 | def forward(self, x: Tensor, cond: Tensor) -> Tensor:
142 | r = self.proj(x)
143 | x = self.conv1(F.silu(self.norm1(x, cond)))
144 | x = self.conv2(F.silu(self.norm2(x, cond)))
145 | x = x + r
146 | x = self.attn(x)
147 | return x
148 |
149 |
150 | # Sequence of residual blocks (in_channels -> mid_channels -> ... -> mid_channels -> out_channels)
151 |
152 |
153 | class ResBlocks(nn.Module):
154 | def __init__(
155 | self,
156 | list_in_channels: List[int],
157 | list_out_channels: List[int],
158 | cond_channels: int,
159 | attn: bool,
160 | ) -> None:
161 | super().__init__()
162 | assert len(list_in_channels) == len(list_out_channels)
163 | self.in_channels = list_in_channels[0]
164 | self.resblocks = nn.ModuleList(
165 | [
166 | ResBlock(in_ch, out_ch, cond_channels, attn)
167 | for (in_ch, out_ch) in zip(list_in_channels, list_out_channels)
168 | ]
169 | )
170 |
171 | def forward(self, x: Tensor, cond: Tensor, to_cat: Optional[List[Tensor]] = None) -> Tensor:
172 | outputs = []
173 | for i, resblock in enumerate(self.resblocks):
174 | x = x if to_cat is None else torch.cat((x, to_cat[i]), dim=1)
175 | x = resblock(x, cond)
176 | outputs.append(x)
177 | return x, outputs
178 |
179 |
180 | # UNet
181 |
182 |
183 | class UNet(nn.Module):
184 | def __init__(self, cond_channels: int, depths: List[int], channels: List[int], attn_depths: List[int]) -> None:
185 | super().__init__()
186 | assert len(depths) == len(channels) == len(attn_depths)
187 | self._num_down = len(channels) - 1
188 |
189 | d_blocks, u_blocks = [], []
190 | for i, n in enumerate(depths):
191 | c1 = channels[max(0, i - 1)]
192 | c2 = channels[i]
193 | d_blocks.append(
194 | ResBlocks(
195 | list_in_channels=[c1] + [c2] * (n - 1),
196 | list_out_channels=[c2] * n,
197 | cond_channels=cond_channels,
198 | attn=attn_depths[i],
199 | )
200 | )
201 | u_blocks.append(
202 | ResBlocks(
203 | list_in_channels=[2 * c2] * n + [c1 + c2],
204 | list_out_channels=[c2] * n + [c1],
205 | cond_channels=cond_channels,
206 | attn=attn_depths[i],
207 | )
208 | )
209 | self.d_blocks = nn.ModuleList(d_blocks)
210 | self.u_blocks = nn.ModuleList(reversed(u_blocks))
211 |
212 | self.mid_blocks = ResBlocks(
213 | list_in_channels=[channels[-1]] * 2,
214 | list_out_channels=[channels[-1]] * 2,
215 | cond_channels=cond_channels,
216 | attn=True,
217 | )
218 |
219 | downsamples = [nn.Identity()] + [Downsample(c) for c in channels[:-1]]
220 | upsamples = [nn.Identity()] + [Upsample(c) for c in reversed(channels[:-1])]
221 | self.downsamples = nn.ModuleList(downsamples)
222 | self.upsamples = nn.ModuleList(upsamples)
223 |
224 | def forward(self, x: Tensor, cond: Tensor) -> Tensor:
225 | *_, h, w = x.size()
226 | n = self._num_down
227 | padding_h = math.ceil(h / 2 ** n) * 2 ** n - h
228 | padding_w = math.ceil(w / 2 ** n) * 2 ** n - w
229 | x = F.pad(x, (0, padding_w, 0, padding_h))
230 |
231 | d_outputs = []
232 | for block, down in zip(self.d_blocks, self.downsamples):
233 | x_down = down(x)
234 | x, block_outputs = block(x_down, cond)
235 | d_outputs.append((x_down, *block_outputs))
236 |
237 | x, _ = self.mid_blocks(x, cond)
238 |
239 | u_outputs = []
240 | for block, up, skip in zip(self.u_blocks, self.upsamples, reversed(d_outputs)):
241 | x_up = up(x)
242 | x, block_outputs = block(x_up, cond, skip[::-1])
243 | u_outputs.append((x_up, *block_outputs))
244 |
245 | x = x[..., :h, :w]
246 | return x, d_outputs, u_outputs
247 |
--------------------------------------------------------------------------------
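
`UNet.forward` pads the input so height and width become multiples of 2^num_down (each `Downsample` halves the resolution, and the skip connections require matching shapes), then crops the padding away at the end. The padding arithmetic in isolation:

```python
import math

def padded_size(size: int, num_down: int) -> int:
    # Round up to the next multiple of 2**num_down, as in UNet.forward.
    return math.ceil(size / 2 ** num_down) * 2 ** num_down

assert padded_size(35, 3) == 40  # 35 -> next multiple of 8
assert padded_size(64, 3) == 64  # already divisible, no padding needed
```
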
/src/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from .denoiser import Denoiser, DenoiserConfig, SigmaDistributionConfig
2 | from .inner_model import InnerModelConfig
3 | from .diffusion_sampler import DiffusionSampler, DiffusionSamplerConfig
4 |
--------------------------------------------------------------------------------
/src/models/diffusion/denoiser.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional, List
3 |
4 | import torch
5 | from torch import Tensor
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from data import Batch
10 | from .inner_model import InnerModel, InnerModelConfig
11 | from utils import LossAndLogs, get_frame_indices, LossLogsData
12 |
13 |
14 | def add_dims(input: Tensor, n: int) -> Tensor:
15 | return input.reshape(input.shape + (1,) * (n - input.ndim))
16 |
17 |
18 | @dataclass
19 | class Conditioners:
20 | c_in: Tensor
21 | c_out: Tensor
22 | c_skip: Tensor
23 | c_noise: Tensor
24 | c_noise_cond: Tensor
25 |
26 |
27 | @dataclass
28 | class SigmaDistributionConfig:
29 | loc: float
30 | scale: float
31 | sigma_min: float
32 | sigma_max: float
33 |
34 | @dataclass
35 | class FrameStrides:
36 | count: int
37 | stride: int
38 |
39 |
40 | @dataclass
41 | class DenoiserConfig:
42 | inner_model: InnerModelConfig
43 | sigma_data: float
44 | sigma_offset_noise: float
45 | noise_previous_obs: bool
46 | upsampling_factor: Optional[int] = None
47 | upsampling_frame_height: Optional[int] = None
48 | upsampling_frame_width: Optional[int] = None
49 | frame_sampling: Optional[List[FrameStrides]] = None
50 |
51 |
52 | class Denoiser(nn.Module):
53 | def __init__(self, cfg: DenoiserConfig) -> None:
54 | super().__init__()
55 | self.cfg = cfg
56 | self.is_upsampler = cfg.upsampling_factor is not None
57 | cfg.inner_model.is_upsampler = self.is_upsampler
58 | self.inner_model = InnerModel(cfg.inner_model)
59 | self.sample_sigma_training = None
60 | self.context_indicies = None if self.is_upsampler else get_frame_indices(cfg.frame_sampling)
61 |
62 | @property
63 | def device(self) -> torch.device:
64 | return self.inner_model.noise_emb.weight.device
65 |
66 | def setup_training(self, cfg: SigmaDistributionConfig) -> None:
67 | assert self.sample_sigma_training is None
68 |
69 | def sample_sigma(n: int, device: torch.device):
70 | s = torch.randn(n, device=device) * cfg.scale + cfg.loc
71 | return s.exp().clip(cfg.sigma_min, cfg.sigma_max)
72 |
73 | self.sample_sigma_training = sample_sigma
74 |
75 | def apply_noise(self, x: Tensor, sigma: Tensor, sigma_offset_noise: float) -> Tensor:
76 | b, c, _, _ = x.shape
77 | offset_noise = sigma_offset_noise * torch.randn(b, c, 1, 1, device=self.device)
78 | return x + offset_noise + torch.randn_like(x) * add_dims(sigma, x.ndim)
79 |
80 | def compute_conditioners(self, sigma: Tensor, sigma_cond: Optional[Tensor]) -> Conditioners:
81 | sigma = (sigma ** 2 + self.cfg.sigma_offset_noise ** 2).sqrt()
82 | c_in = 1 / (sigma ** 2 + self.cfg.sigma_data ** 2).sqrt()
83 | c_skip = self.cfg.sigma_data ** 2 / (sigma ** 2 + self.cfg.sigma_data ** 2)
84 | c_out = sigma * c_skip.sqrt()
85 | c_noise = sigma.log() / 4
86 | c_noise_cond = sigma_cond.log() / 4 if sigma_cond is not None else torch.zeros_like(c_noise)
87 | return Conditioners(
88 | *(add_dims(c, n) for c, n in zip((c_in, c_out, c_skip, c_noise, c_noise_cond), (4, 4, 4, 1, 1))))
89 |
90 | def compute_model_output(self, noisy_next_obs: Tensor, obs: Tensor, act: Optional[Tensor],
91 | cs: Conditioners) -> Tensor:
92 | rescaled_obs = obs / self.cfg.sigma_data
93 | rescaled_noise = noisy_next_obs * cs.c_in
94 | return self.inner_model(rescaled_noise, cs.c_noise, cs.c_noise_cond, rescaled_obs, act)
95 |
96 | @torch.no_grad()
97 | def wrap_model_output(self, noisy_next_obs: Tensor, model_output: Tensor, cs: Conditioners) -> Tensor:
98 | d = cs.c_skip * noisy_next_obs + cs.c_out * model_output
99 | # Quantize to {0, ..., 255}, then back to [-1, 1]
100 | d = d.clamp(-1, 1).add(1).div(2).mul(255).byte().div(255).mul(2).sub(1)
101 | return d
102 |
103 | @torch.no_grad()
104 | def denoise(self, noisy_next_obs: Tensor, sigma: Tensor, sigma_cond: Optional[Tensor], obs: Tensor,
105 | act: Optional[Tensor]) -> Tensor:
106 | cs = self.compute_conditioners(sigma, sigma_cond)
107 | model_output = self.compute_model_output(noisy_next_obs, obs, act, cs)
108 | denoised = self.wrap_model_output(noisy_next_obs, model_output, cs)
109 | return denoised
110 |
111 | def get_prev_obs(self, all_obs, act, i, b, n, c, H, W):
112 | prev_obs = all_obs[:, i:i + n].reshape(b, n * c, H, W) if self.is_upsampler else all_obs[:, i + self.context_indicies].reshape(b, n * c, H, W)
113 | prev_act = None if self.is_upsampler else act[:,i+self.context_indicies]
114 | return prev_obs, prev_act
115 |
116 | def forward(self, batch: Batch) -> LossLogsData:
117 | b, t, c, h, w = batch.obs.size()
118 | H, W = (self.cfg.upsampling_factor * h, self.cfg.upsampling_factor * w) if self.is_upsampler else (h, w)
119 | n = 0 if self.is_upsampler else self.context_indicies[-1] + 1
120 | seq_length = t - n # t = n + 1 + num_autoregressive_steps
121 |
122 | if self.is_upsampler:
123 | all_obs = torch.stack([x["full_res"] for x in batch.info]).to(self.device)
124 | low_res = F.interpolate(batch.obs.reshape(b * t, c, h, w), scale_factor=self.cfg.upsampling_factor,
125 | mode="bicubic").reshape(b, t, c, H, W)
126 | all_acts = None
127 | assert all_obs.shape == low_res.shape
128 | else:
129 | all_obs = batch.obs.clone()
130 | all_acts = batch.act.clone()
131 |
132 | loss = 0
133 | for i in range(seq_length):
134 | prev_obs, prev_act = self.get_prev_obs(all_obs, all_acts, i, b, self.cfg.inner_model.num_steps_conditioning, c, H, W)
135 | obs = all_obs[:, n + i]
136 | mask = batch.mask_padding[:, n + i]
137 |
138 | if self.cfg.noise_previous_obs:
139 | sigma_cond = self.sample_sigma_training(b, self.device)
140 | prev_obs = self.apply_noise(prev_obs, sigma_cond, self.cfg.sigma_offset_noise)
141 | else:
142 | sigma_cond = None
143 |
144 | if self.is_upsampler:
145 | prev_obs = torch.cat((prev_obs, low_res[:, n + i]), dim=1)
146 |
147 | sigma = self.sample_sigma_training(b, self.device)
148 | noisy_obs = self.apply_noise(obs, sigma, self.cfg.sigma_offset_noise)
149 |
150 | cs = self.compute_conditioners(sigma, sigma_cond)
151 | model_output = self.compute_model_output(noisy_obs, prev_obs, prev_act, cs)
152 |
153 | target = (obs - cs.c_skip * noisy_obs) / cs.c_out
154 | loss += F.mse_loss(model_output[mask], target[mask])
155 |
156 | denoised = self.wrap_model_output(noisy_obs, model_output, cs)
157 | all_obs[:, n + i] = denoised
158 |
159 |         metrics = {"loss_denoising": loss.item() / seq_length}
160 | batch_data = {"obs": all_obs[:, -seq_length:], 'act': batch.act[:, -seq_length:], 'mask_padding': batch.mask_padding[:, -seq_length:]}
161 | return loss, metrics, batch_data
162 |
--------------------------------------------------------------------------------
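
`compute_conditioners` follows the EDM preconditioning of Karras et al. (2022): after folding the offset noise into σ, c_in = 1/√(σ² + σ_data²), c_skip = σ_data²/(σ² + σ_data²), c_out = σ·√c_skip, and c_noise = ln(σ)/4. A numeric sketch (σ_data and σ_offset_noise are example values; the real ones come from the config):

```python
import torch

sigma_data, sigma_offset_noise = 0.5, 0.1  # example config values
sigma = torch.tensor([1.0])

sigma_eff = (sigma ** 2 + sigma_offset_noise ** 2).sqrt()
c_in = 1 / (sigma_eff ** 2 + sigma_data ** 2).sqrt()
c_skip = sigma_data ** 2 / (sigma_eff ** 2 + sigma_data ** 2)
c_out = sigma_eff * c_skip.sqrt()
c_noise = sigma_eff.log() / 4

# The denoised prediction is D(x) = c_skip * x + c_out * F(c_in * x, c_noise, ...),
# chosen so the network's inputs and regression targets stay near unit variance
# across noise levels.
```
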
/src/models/diffusion/diffusion_sampler.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Optional, Tuple
3 |
4 | import torch
5 | from torch import Tensor
6 |
7 | from .denoiser import Denoiser
8 |
9 |
10 | @dataclass
11 | class DiffusionSamplerConfig:
12 | num_steps_denoising: int
13 | sigma_min: float = 2e-3
14 | sigma_max: float = 5
15 | rho: int = 7
16 | order: int = 1
17 | s_churn: float = 0
18 | s_tmin: float = 0
19 | s_tmax: float = float("inf")
20 | s_noise: float = 1
21 | s_cond: float = 0
22 |
23 |
24 | class DiffusionSampler:
25 | def __init__(self, denoiser: Denoiser, cfg: DiffusionSamplerConfig) -> None:
26 | self.denoiser = denoiser
27 | self.cfg = cfg
28 | self.sigmas = build_sigmas(cfg.num_steps_denoising, cfg.sigma_min, cfg.sigma_max, cfg.rho, denoiser.device)
29 |
30 | @torch.no_grad()
31 | def sample(self, prev_obs: Tensor, prev_act: Optional[Tensor]) -> Tuple[Tensor, List[Tensor]]:
32 | device = prev_obs.device
33 |
34 | b, t, c, h, w = prev_obs.size()
35 |
36 | prev_obs = prev_obs.reshape(b, t * c, h, w)
37 | s_in = torch.ones(b, device=device)
38 | gamma_ = min(self.cfg.s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
39 | x = torch.randn(b, c, h, w, device=device)
40 | trajectory = [x]
41 | for sigma, next_sigma in zip(self.sigmas[:-1], self.sigmas[1:]):
42 | gamma = gamma_ if self.cfg.s_tmin <= sigma <= self.cfg.s_tmax else 0
43 | sigma_hat = sigma * (gamma + 1)
44 | if gamma > 0:
45 | eps = torch.randn_like(x) * self.cfg.s_noise
46 | x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
47 | if self.cfg.s_cond > 0:
48 | sigma_cond = torch.full((b,), fill_value=self.cfg.s_cond, device=device)
49 | prev_obs = self.denoiser.apply_noise(prev_obs, sigma_cond, sigma_offset_noise=0)
50 | else:
51 | sigma_cond = None
52 | denoised = self.denoiser.denoise(x, sigma, sigma_cond, prev_obs, prev_act)
53 | d = (x - denoised) / sigma_hat
54 | dt = next_sigma - sigma_hat
55 | if self.cfg.order == 1 or next_sigma == 0:
56 | # Euler method
57 | x = x + d * dt
58 | else:
59 | # Heun's method
60 | x_2 = x + d * dt
61 | denoised_2 = self.denoiser.denoise(x_2, next_sigma * s_in, sigma_cond, prev_obs, prev_act)
62 | d_2 = (x_2 - denoised_2) / next_sigma
63 | d_prime = (d + d_2) / 2
64 | x = x + d_prime * dt
65 | trajectory.append(x)
66 | return x, trajectory
67 |
68 |
69 | def build_sigmas(num_steps: int, sigma_min: float, sigma_max: float, rho: int, device: torch.device) -> Tensor:
70 | min_inv_rho = sigma_min ** (1 / rho)
71 | max_inv_rho = sigma_max ** (1 / rho)
72 |     ramp = torch.linspace(0, 1, num_steps, device=device)
73 |     sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
74 | return torch.cat((sigmas, sigmas.new_zeros(1)))
75 |
--------------------------------------------------------------------------------
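
`build_sigmas` is the Karras et al. (2022) noise schedule: interpolate linearly in σ^(1/ρ) space from σ_max down to σ_min, then append a final zero so the last integration step lands on clean data. A quick endpoint check against the defaults above:

```python
import torch

rho = 7
sigma_min, sigma_max = 2e-3, 5.0
ramp = torch.linspace(0, 1, 10)
min_inv_rho, max_inv_rho = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
sigmas = torch.cat((sigmas, sigmas.new_zeros(1)))

assert torch.isclose(sigmas[0], torch.tensor(5.0))    # starts at sigma_max
assert torch.isclose(sigmas[-2], torch.tensor(2e-3))  # ends at sigma_min
assert sigmas[-1] == 0                                # trailing zero for the final step
```
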
/src/models/diffusion/inner_model.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Optional
3 |
4 | import torch
5 | from torch import Tensor
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from ..blocks import Conv3x3, FourierFeatures, GroupNorm, UNet
10 |
11 |
12 | @dataclass
13 | class InnerModelConfig:
14 | img_channels: int
15 | num_steps_conditioning: int
16 | cond_channels: int
17 | depths: List[int]
18 | channels: List[int]
19 | attn_depths: List[bool]
20 | num_actions: Optional[int] = None # set by trainer after env creation
21 | is_upsampler: Optional[bool] = None # set by Denoiser
22 |
23 |
24 | class InnerModel(nn.Module):
25 | def __init__(self, cfg: InnerModelConfig) -> None:
26 | super().__init__()
27 | self.noise_emb = FourierFeatures(cfg.cond_channels)
28 | self.noise_cond_emb = FourierFeatures(cfg.cond_channels)
29 | self.act_emb = None if cfg.is_upsampler else nn.Sequential(
30 | nn.Embedding(cfg.num_actions, cfg.cond_channels // cfg.num_steps_conditioning),
31 | nn.Flatten(), # b t e -> b (t e)
32 | )
33 | self.cond_proj = nn.Sequential(
34 | nn.Linear(cfg.cond_channels, cfg.cond_channels),
35 | nn.SiLU(),
36 | nn.Linear(cfg.cond_channels, cfg.cond_channels),
37 | )
38 | self.conv_in = Conv3x3((cfg.num_steps_conditioning + int(cfg.is_upsampler) + 1) * cfg.img_channels, cfg.channels[0])
39 |
40 | self.unet = UNet(cfg.cond_channels, cfg.depths, cfg.channels, cfg.attn_depths)
41 |
42 | self.norm_out = GroupNorm(cfg.channels[0])
43 | self.conv_out = Conv3x3(cfg.channels[0], cfg.img_channels)
44 | nn.init.zeros_(self.conv_out.weight)
45 |
46 | def forward(self, noisy_next_obs: Tensor, c_noise: Tensor, c_noise_cond: Tensor, obs: Tensor, act: Optional[Tensor]) -> Tensor:
47 | if self.act_emb is not None:
48 | assert act.ndim == 2 or (act.ndim == 3 and act.size(2) == self.act_emb[0].num_embeddings and set(act.unique().tolist()).issubset(set([0, 1])))
49 | act_emb = self.act_emb(act) if act.ndim == 2 else self.act_emb[1]((act.float() @ self.act_emb[0].weight))
50 | else:
51 | assert act is None
52 | act_emb = 0
53 |
54 | cond = self.cond_proj(self.noise_emb(c_noise) + self.noise_cond_emb(c_noise_cond) + act_emb)
55 | x = self.conv_in(torch.cat((obs, noisy_next_obs), dim=1))
56 | x, _, _ = self.unet(x, cond)
57 | x = self.conv_out(F.silu(self.norm_out(x)))
58 | return x
59 |
--------------------------------------------------------------------------------
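
A note on the action conditioning in `InnerModel`: each of the `num_steps_conditioning` context actions gets an embedding of size `cond_channels // num_steps_conditioning`, and `Flatten` concatenates them so the result adds directly onto the noise embeddings. A shape check with made-up sizes:

```python
import torch
import torch.nn as nn

num_actions, cond_channels, num_steps_conditioning = 18, 256, 4  # example values

act_emb = nn.Sequential(
    nn.Embedding(num_actions, cond_channels // num_steps_conditioning),
    nn.Flatten(),  # (b, t, e) -> (b, t * e)
)

act = torch.randint(num_actions, (2, num_steps_conditioning))  # (b, t) integer actions
assert act_emb(act).shape == (2, cond_channels)  # matches the noise embedding width
```
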
/src/models/rew_end_model.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Optional, Tuple
3 |
4 | import torch
5 | from torch import Tensor
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torcheval.metrics.functional import multiclass_confusion_matrix
9 |
10 | from .blocks import Conv3x3, Downsample, ResBlocks
11 | from data import Batch
12 | from utils import init_lstm, LossAndLogs
13 |
14 |
15 | @dataclass
16 | class RewEndModelConfig:
17 | lstm_dim: int
18 | img_channels: int
19 | img_size: int
20 | cond_channels: int
21 | depths: List[int]
22 | channels: List[int]
23 | attn_depths: List[int]
24 | num_actions: Optional[int] = None
25 |
26 |
27 | class RewEndModel(nn.Module):
28 | def __init__(self, cfg: RewEndModelConfig) -> None:
29 | super().__init__()
30 | self.cfg = cfg
31 | self.encoder = RewEndEncoder(2 * cfg.img_channels, cfg.cond_channels, cfg.depths, cfg.channels, cfg.attn_depths)
32 | self.act_emb = nn.Embedding(cfg.num_actions, cfg.cond_channels)
33 | input_dim_lstm = cfg.channels[-1] * (cfg.img_size // 2 ** (len(cfg.depths) - 1)) ** 2
34 | self.lstm = nn.LSTM(input_dim_lstm, cfg.lstm_dim, batch_first=True)
35 | self.head = nn.Sequential(
36 | nn.Linear(cfg.lstm_dim, cfg.lstm_dim),
37 | nn.SiLU(),
38 | nn.Linear(cfg.lstm_dim, 3 + 2, bias=False),  # 3 reward classes (-1/0/+1) + 2 end-of-episode logits
39 | )
40 | init_lstm(self.lstm)
41 |
42 | def predict_rew_end(
43 | self,
44 | obs: Tensor,
45 | act: Tensor,
46 | next_obs: Tensor,
47 | hx_cx: Optional[Tuple[Tensor, Tensor]] = None,
48 | ) -> Tuple[Tensor, Tensor, Tuple[Tensor, Tensor]]:
49 | b, t, c, h, w = obs.shape
50 | obs, act, next_obs = obs.reshape(b * t, c, h, w), act.reshape(b * t), next_obs.reshape(b * t, c, h, w)
51 | x = self.encoder(torch.cat((obs, next_obs), dim=1), self.act_emb(act))
52 | x = x.reshape(b, t, -1) # (b t) e h w -> b t (e h w)
53 | x, hx_cx = self.lstm(x, hx_cx)
54 | logits = self.head(x)
55 | return logits[:, :, :-2], logits[:, :, -2:], hx_cx
56 |
57 | def forward(self, batch: Batch) -> LossAndLogs:
58 | obs = batch.obs[:, :-1]
59 | act = batch.act[:, :-1]
60 | next_obs = batch.obs[:, 1:]
61 | rew = batch.rew[:, :-1]
62 | end = batch.end[:, :-1]
63 | mask = batch.mask_padding[:, :-1]
64 |
65 | # When the episode has ended, replace the padded (gray) frame with the true final obs
66 | dead = end.bool().any(dim=1)
67 | if dead.any():
68 | final_obs = torch.stack([i["final_observation"] for i, d in zip(batch.info, dead) if d]).to(obs.device)
69 | next_obs[dead, end[dead].argmax(dim=1)] = final_obs
70 |
71 | logits_rew, logits_end, _ = self.predict_rew_end(obs, act, next_obs)
72 | logits_rew = logits_rew[mask]
73 | logits_end = logits_end[mask]
74 | target_rew = rew[mask].sign().long().add(1)  # sign-clip to {-1, 0, 1}, then shift to class ids {0, 1, 2}
75 | target_end = end[mask]
76 |
77 | loss_rew = F.cross_entropy(logits_rew, target_rew)
78 | loss_end = F.cross_entropy(logits_end, target_end)
79 | loss = loss_rew + loss_end
80 |
81 | metrics = {
82 | "loss_rew": loss_rew.detach(),
83 | "loss_end": loss_end.detach(),
84 | "loss_total": loss.detach(),
85 | "confusion_matrix": {
86 | "rew": multiclass_confusion_matrix(logits_rew, target_rew, num_classes=3),
87 | "end": multiclass_confusion_matrix(logits_end, target_end, num_classes=2),
88 | },
89 | }
90 | return loss, metrics
91 |
92 |
93 | class RewEndEncoder(nn.Module):
94 | def __init__(
95 | self,
96 | in_channels: int,
97 | cond_channels: int,
98 | depths: List[int],
99 | channels: List[int],
100 | attn_depths: List[int],
101 | ) -> None:
102 | super().__init__()
103 | assert len(depths) == len(channels) == len(attn_depths)
104 | self.conv_in = Conv3x3(in_channels, channels[0])
105 | blocks = []
106 | for i, n in enumerate(depths):
107 | c1 = channels[max(0, i - 1)]
108 | c2 = channels[i]
109 | blocks.append(
110 | ResBlocks(
111 | list_in_channels=[c1] + [c2] * (n - 1),
112 | list_out_channels=[c2] * n,
113 | cond_channels=cond_channels,
114 | attn=attn_depths[i],
115 | )
116 | )
117 | blocks.append(
118 | ResBlocks(
119 | list_in_channels=[channels[-1]] * 2,
120 | list_out_channels=[channels[-1]] * 2,
121 | cond_channels=cond_channels,
122 | attn=True,
123 | )
124 | )
125 | self.blocks = nn.ModuleList(blocks)
126 | self.downsamples = nn.ModuleList([nn.Identity()] + [Downsample(c) for c in channels[:-1]] + [nn.Identity()])
127 |
128 | def forward(self, x: Tensor, cond: Tensor) -> Tensor:
129 | x = self.conv_in(x)
130 | for block, down in zip(self.blocks, self.downsamples):
131 | x = down(x)
132 | x, _ = block(x, cond)
133 | return x
134 |
--------------------------------------------------------------------------------
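The reward head above is trained as 3-way classification: `rew.sign().long().add(1)` folds raw reward magnitudes into {negative, zero, positive} mapped to class ids {0, 1, 2}, which is the `3` in the `3 + 2` output layer. A minimal sketch of the mapping:

```python
import torch

rew = torch.tensor([-3.5, -0.1, 0.0, 0.4, 12.0])
target_rew = rew.sign().long().add(1)  # -1/0/+1 -> 0/1/2
print(target_rew)                      # tensor([0, 0, 1, 2, 2])
```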
/src/play.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | from huggingface_hub import snapshot_download
5 | from hydra import compose, initialize
6 | from hydra.utils import instantiate
7 | from omegaconf import DictConfig, OmegaConf
8 | import torch
9 |
10 | from agent import Agent
11 | from envs import WorldModelEnv
12 | from game import Game, PlayEnv
13 |
14 |
15 | OmegaConf.register_new_resolver("eval", eval)
16 |
17 |
18 | def parse_args() -> argparse.Namespace:
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("-r", "--record", action="store_true", help="Record episodes in PlayEnv.")
21 | parser.add_argument("--store-denoising-trajectory", action="store_true", help="Save denoising steps in info.")
22 | parser.add_argument("--store-original-obs", action="store_true", help="Save original obs (pre resizing) in info.")
23 | parser.add_argument("--mouse-multiplier", type=int, default=10, help="Multiplication factor for the mouse movement.")
24 | parser.add_argument("--compile", action="store_true", help="Turn on model compilation.")
25 | parser.add_argument("--fps", type=int, default=30, help="Frame rate.")
26 | parser.add_argument("--no-header", action="store_true")
27 | return parser.parse_args()
28 |
29 |
30 | def check_args(args: argparse.Namespace) -> bool:
31 | if not args.record and (args.store_denoising_trajectory or args.store_original_obs):
32 | print("Warning: not in recording mode, ignoring --store* options")
33 | return True
34 |
35 |
36 | def prepare_play_mode(cfg: DictConfig, args: argparse.Namespace) -> PlayEnv:
37 | path_hf = Path(snapshot_download(repo_id="Enigma-AI/multiverse"))
38 |
39 | path_ckpt = path_hf / 'agent.pt'
40 | spawn_dir = Path('.') / 'game/spawn'
41 | # Override config
42 | cfg.agent = OmegaConf.load("config/agent/racing.yaml")
43 | cfg.env = OmegaConf.load("config/env/racing.yaml")
44 |
45 | if torch.cuda.is_available():
46 | device = torch.device("cuda:0")
47 | elif torch.backends.mps.is_available():
48 | device = torch.device("mps")
49 | else:
50 | device = torch.device("cpu")
51 |
52 | print("----------------------------------------------------------------------")
53 | print(f"Using {device} for rendering.")
54 | if device.type == "cpu":  # warn when neither CUDA nor MPS is available
55 | print("If you have a CUDA GPU available and it is not being used, please follow the instructions at https://pytorch.org/get-started/locally/ to reinstall torch with CUDA support and try again.")
56 | print("----------------------------------------------------------------------")
57 |
58 | assert cfg.env.train.id == "racing"
59 | num_actions = cfg.env.num_actions
60 |
61 | # Models
62 | agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(device).eval()
63 | agent.load(path_ckpt)
64 |
65 | # World model environment
66 | sl = cfg.agent.denoiser.inner_model.num_steps_conditioning
67 | if agent.upsampler is not None:
68 | sl = max(sl, cfg.agent.upsampler.inner_model.num_steps_conditioning)
69 | wm_env_cfg = instantiate(cfg.world_model_env, num_batches_to_preload=1)
70 | wm_env = WorldModelEnv(agent.denoiser, agent.upsampler, agent.rew_end_model, spawn_dir, 1, sl, wm_env_cfg, return_denoising_trajectory=True)
71 |
72 | if device.type == "cuda" and args.compile:
73 | print("Compiling models...")
74 | wm_env.predict_next_obs = torch.compile(wm_env.predict_next_obs, mode="reduce-overhead")
75 | wm_env.upsample_next_obs = torch.compile(wm_env.upsample_next_obs, mode="reduce-overhead")
76 |
77 | play_env = PlayEnv(
78 | agent,
79 | wm_env,
80 | args.record,
81 | args.store_denoising_trajectory,
82 | args.store_original_obs,
83 | )
84 |
85 | return play_env
86 |
87 |
88 | @torch.no_grad()
89 | def main():
90 | args = parse_args()
91 | ok = check_args(args)
92 | if not ok:
93 | return
94 |
95 | with initialize(version_base="1.3", config_path="../config"):
96 | cfg = compose(config_name="trainer")
97 |
98 | # window size
99 | size = cfg.env.train.size
100 | size_h, size_w = (size, size) if isinstance(size, int) else size
101 | env = prepare_play_mode(cfg, args)
102 | game = Game(env, (size_h, size_w), fps=args.fps, verbose=not args.no_header)
103 | game.run()
104 |
105 |
106 | if __name__ == "__main__":
107 | main()
108 |
--------------------------------------------------------------------------------
/src/player/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EnigmaLabsAI/multiverse/39d9a01079eae196a8e132f2fcfd9c549f151e58/src/player/__init__.py
--------------------------------------------------------------------------------
/src/player/action_processing.py:
--------------------------------------------------------------------------------
1 | """
2 | Credits: some parts are taken and modified from the file `config.py` from https://github.com/TeaPearce/Counter-Strike_Behavioural_Cloning/
3 | """
4 |
5 | from dataclasses import dataclass
6 | from typing import Dict, List, Set
7 |
8 | import numpy as np
9 | import pygame
10 | import torch
11 |
12 | from .keymap import GAME_FORBIDDEN_COMBINATIONS, GAME_KEYMAP
13 |
14 |
15 | @dataclass
16 | class GameAction:
17 | keys: List[int]
18 |
19 | def __post_init__(self) -> None:
20 | self.keys = filter_keys_pressed_forbidden(self.keys)
21 | # self.process_mouse()
22 |
23 | @property
24 | def key_names(self) -> List[str]:
25 | return [pygame.key.name(key) for key in self.keys]
26 |
27 |
28 | def print_game_action(action: GameAction) -> str:
29 | action_names = [GAME_KEYMAP[k] for k in action.keys] if len(action.keys) > 0 else []
30 | action_names = [x for x in action_names if not x.startswith("camera_")]
31 | keys = " + ".join(action_names)
32 | return keys
33 |
34 |
35 | N_PLAYERS = 2
36 | N_GAS_KAVIM = 11  # number of gas discretization bins per player
37 | N_BREX_KAVIM = 11  # number of brake discretization bins per player
38 | N_STEERING = 11
39 | N_KEYS = 8 # number of keyboard outputs, w,s,a,d,up,down,left,right
40 |
41 |
42 | def encode_game_action(game_action: GameAction, device: torch.device) -> torch.Tensor:
43 | p1_gas = torch.zeros(N_GAS_KAVIM)
44 | p2_gas = torch.zeros(N_GAS_KAVIM)
45 |
46 | p1_brex = torch.zeros(N_BREX_KAVIM)
47 | p2_brex = torch.zeros(N_BREX_KAVIM)
48 |
49 | p1_steer = torch.zeros(N_STEERING)
50 | p2_steer = torch.zeros(N_STEERING)
51 |
52 | p1_is_steer = False
53 | p2_is_steer = False
54 |
55 | for key in game_action.key_names:
56 | if key == "w":
57 | p1_gas[N_GAS_KAVIM - 1] = 1
58 | if key == "a":
59 | p1_steer[3] = 1
60 | p1_is_steer = True
61 | if key == "s":
62 | p1_brex[N_BREX_KAVIM - 1] = 1
63 | if key == "d":
64 | p1_steer[N_STEERING - 3] = 1
65 | p1_is_steer = True
66 |
67 | if key == "up":
68 | p2_gas[N_GAS_KAVIM - 1] = 1
69 | if key == "left":
70 | p2_steer[3] = 1
71 | p2_is_steer = True
72 | if key == "down":
73 | p2_brex[N_BREX_KAVIM - 1] = 1
74 | if key == "right":
75 | p2_steer[N_STEERING - 3] = 1
76 | p2_is_steer = True
77 |
78 | if not p1_is_steer:
79 | p1_steer[len(p1_steer) // 2] = 1
80 |
81 | if not any(p1_gas):
82 | p1_gas[0] = 1
83 |
84 | if not any(p1_brex):
85 | p1_brex[0] = 1
86 |
87 | if not p2_is_steer:
88 | p2_steer[len(p2_steer) // 2] = 1
89 |
90 | if not any(p2_gas):
91 | p2_gas[0] = 1
92 |
93 | if not any(p2_brex):
94 | p2_brex[0] = 1
95 |
96 | return torch.cat([p1_gas, p1_brex, p1_steer, p2_gas, p2_brex, p2_steer]).float().to(device)
97 |
98 |
99 | def decode_game_action(y_preds: torch.Tensor) -> GameAction:
100 | y_preds = y_preds.squeeze()
101 | keys_pred = y_preds[0:N_KEYS]
102 |
103 | keys_pressed = []
104 | keys_pressed_onehot = np.round(keys_pred)
105 | if keys_pressed_onehot[0] == 1:
106 | keys_pressed.append("w")
107 | if keys_pressed_onehot[1] == 1:
108 | keys_pressed.append("a")
109 | if keys_pressed_onehot[2] == 1:
110 | keys_pressed.append("s")
111 | if keys_pressed_onehot[3] == 1:
112 | keys_pressed.append("d")
113 | if keys_pressed_onehot[4] == 1:
114 | keys_pressed.append("up")
115 | if keys_pressed_onehot[5] == 1:
116 | keys_pressed.append("left")
117 | if keys_pressed_onehot[6] == 1:
118 | keys_pressed.append("down")
119 | if keys_pressed_onehot[7] == 1:
120 | keys_pressed.append("right")
121 |
122 | keys_pressed = [pygame.key.key_code(x) for x in keys_pressed]
123 |
124 | return GameAction(keys_pressed)
125 |
126 |
127 | def filter_keys_pressed_forbidden(keys_pressed: List[int], keymap: Dict[int, str] = GAME_KEYMAP,
128 | forbidden_combinations: List[Set[str]] = GAME_FORBIDDEN_COMBINATIONS) -> List[int]:
129 | keys = set()
130 | names = set()
131 | for key in keys_pressed:
132 | if key not in keymap:
133 | continue
134 | name = keymap[key]
135 | keys.add(key)
136 | names.add(name)
137 | for forbidden in forbidden_combinations:
138 | if forbidden.issubset(names):
139 | keys.remove(key)
140 | names.remove(name)
141 | break
142 | return list(filter(lambda key: key in keys, keys_pressed))
143 |
--------------------------------------------------------------------------------
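For reference, `encode_game_action` emits a 66-dimensional multi-hot vector laid out as six consecutive 11-way blocks: [p1_gas | p1_brake | p1_steer | p2_gas | p2_brake | p2_steer], with one active bin per block for ordinary inputs (idle defaults to bin 0 for gas/brake and to the center bin for steering). A standalone sketch reconstructing the encoding for player 1 holding `w` + `a` while player 2 is idle, with bin indices taken from the code above:

```python
import torch

N = 11                   # bins per sub-vector (N_GAS_KAVIM == N_BREX_KAVIM == N_STEERING)
vec = torch.zeros(6 * N)
vec[N - 1] = 1           # p1 gas ("w"): last gas bin
vec[N + 0] = 1           # p1 brake not pressed: first brake bin
vec[2 * N + 3] = 1       # p1 steering left ("a"): steer bin 3

vec[3 * N + 0] = 1       # p2 idle: no gas
vec[4 * N + 0] = 1       # p2 idle: no brake
vec[5 * N + N // 2] = 1  # p2 idle: centered steering

assert int(vec.sum()) == 6  # one active bin per sub-vector
```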
/src/player/keymap.py:
--------------------------------------------------------------------------------
1 | import pygame
2 |
3 |
4 | GAME_KEYMAP = {
5 | pygame.K_w: "p1_up",
6 | pygame.K_d: "p1_right",
7 | pygame.K_a: "p1_left",
8 | pygame.K_s: "p1_down",
9 |
10 | pygame.K_UP: "p2_up",
11 | pygame.K_RIGHT: "p2_right",
12 | pygame.K_LEFT: "p2_left",
13 | pygame.K_DOWN: "p2_down",
14 | }
15 |
16 | GAME_FORBIDDEN_COMBINATIONS = []  # list of sets of key names that must not be pressed together
--------------------------------------------------------------------------------
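`GAME_FORBIDDEN_COMBINATIONS` ships empty, but `filter_keys_pressed_forbidden` in `action_processing.py` drops the most recently pressed key whenever the pressed key names cover a declared forbidden combination. A minimal sketch with a hypothetical rule (the rule is illustrative, not part of the repo; assumes `src/` is on the path):

```python
import pygame

from player.action_processing import filter_keys_pressed_forbidden

keys = [pygame.K_w, pygame.K_s]  # "p1_up" then "p1_down" in GAME_KEYMAP

# Default: nothing is forbidden, so nothing is filtered.
assert filter_keys_pressed_forbidden(keys) == keys

# Hypothetical rule: up/down may not be held together; the later key is dropped.
filtered = filter_keys_pressed_forbidden(keys, forbidden_combinations=[{"p1_up", "p1_down"}])
assert filtered == [pygame.K_w]
```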
/src/process_denoiser_files.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from functools import partial
3 | from pathlib import Path
4 | from multiprocessing import Pool
5 | import shutil
6 |
7 | import torchvision.transforms.functional as T
8 | from tqdm import tqdm
9 |
10 | from data.dataset import Dataset, GameHdf5Dataset
11 | from data.episode import Episode
12 | from data.segment import SegmentId
13 |
14 | import os
15 |
16 | PREFIX = ""
17 |
18 |
19 | def parse_args():
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument(
22 | "tar_dir",
23 | type=Path,
24 | help="folder containing the .tar files from `dataset_tars` folder from Huggingface",
25 | )
26 | parser.add_argument(
27 | "out_dir",
28 | type=Path,
29 | help="a new directory (should not exist already), the script will untar and process data there",
30 | )
31 | return parser.parse_args()
32 |
33 |
34 | def process_tar(path_tar: Path, out_dir: Path, remove_tar: bool) -> None:
35 | # Name the output folder after the archive stem, e.g. "a_to_b" -> "a-b"
36 | assert path_tar.stem.startswith(PREFIX)
37 | d = out_dir / "-".join(path_tar.stem[len(PREFIX) :].split("_to_"))
38 | d.mkdir(exist_ok=False, parents=True)
39 | shutil.copy(path_tar, d / os.path.basename(path_tar))
40 | new_path_tar = d / path_tar.name
41 | if remove_tar:
42 | new_path_tar.unlink()
43 | else:
44 | shutil.copy(new_path_tar, path_tar.parent)
45 |
46 |
47 | def main():
48 | args = parse_args()
49 |
50 | tar_dir = args.tar_dir.absolute()
51 | out_dir = args.out_dir.absolute()
52 |
53 | if not tar_dir.exists():
54 | print(
55 | "Wrong usage: the tar directory should exist (and contain the downloaded .tar files)"
56 | )
57 | return
58 |
59 | if out_dir.exists():
60 | print(f"Wrong usage: the output directory should not exist ({args.out_dir})")
61 | return
62 |
63 | with Path("test_split.txt").open("r") as f:
64 | test_files = f.read().split("\n")
65 |
66 | full_res_dir = out_dir / "full_res"
67 | low_res_dir = out_dir / "low_res"
68 |
69 | hdf5_files = [
70 | x for x in tar_dir.iterdir() if x.suffix == ".hdf5" and x.stem.startswith(PREFIX)
71 | ]
72 | n = len(hdf5_files)
73 |
74 | str_files = "\n".join(map(str, hdf5_files))
75 | print(f"Ready to untar {n} tar files:\n{str_files}")
76 |
77 | remove_tar = False
78 |
79 | # Copy each game file into its own folder
80 | f = partial(process_tar, out_dir=full_res_dir, remove_tar=remove_tar)
81 | with Pool(n) as p:
82 | p.map(f, hdf5_files)
83 |
84 | print(f"{n} .tar files unpacked in {full_res_dir}")
85 |
86 | #
87 | # Create low-res data
88 | #
89 |
90 | game_dataset = GameHdf5Dataset(full_res_dir)
91 |
92 | train_dataset = Dataset(low_res_dir / "train", None)
93 | test_dataset = Dataset(low_res_dir / "test", None)
94 |
95 | for i in tqdm(game_dataset._filenames, desc="Creating low_res"):
96 | episode_length = game_dataset._length_one_episode[i]
97 | episode = Episode(
98 | **{
99 | k: v
100 | for k, v in game_dataset[SegmentId(i, 0, episode_length, None)].__dict__.items()
101 | if k not in ("mask_padding", "id")
102 | }
103 | )
104 | episode.obs = T.resize(
105 | episode.obs, (48, 64), interpolation=T.InterpolationMode.BICUBIC
106 | )
107 | filename = game_dataset._filenames[i]
108 | file_id = f"{filename.parent.stem}/{filename.name}"
109 | episode.info = {"original_file_id": file_id}
110 | dataset = test_dataset if filename.name in test_files else train_dataset
111 | dataset.add_episode(episode)
112 |
113 | train_dataset.save_to_default_path()
114 | test_dataset.save_to_default_path()
115 |
116 | print(
117 | f"Split train/test data ({train_dataset.num_episodes}/{test_dataset.num_episodes} episodes)\n"
118 | )
119 |
120 | print("You can now edit `config/env/racing.yaml` and set:")
121 | print(f"path_data_low_res: {low_res_dir}")
122 | print(f"path_data_full_res: {full_res_dir}")
123 |
124 |
125 | if __name__ == "__main__":
126 | main()
127 |
--------------------------------------------------------------------------------
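The low-res conversion above is a plain bicubic downscale of every frame in an episode to the denoiser's working resolution of 48x64 (the upsampler variant of this script, next, targets 35x53 instead). A minimal sketch with illustrative tensor sizes:

```python
import torch
import torchvision.transforms.functional as T

# (time, channels, H, W); sizes are illustrative, channels = two stacked RGB views
obs = torch.rand(20, 6, 350, 530)
low_res = T.resize(obs, (48, 64), interpolation=T.InterpolationMode.BICUBIC)
assert low_res.shape == (20, 6, 48, 64)
```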
/src/process_upsampler_files.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from functools import partial
3 | from pathlib import Path
4 | from multiprocessing import Pool
5 | import shutil
6 |
7 | import torchvision.transforms.functional as T
8 | from tqdm import tqdm
9 |
10 | from data.dataset import Dataset, GameHdf5Dataset
11 | from data.episode import Episode
12 | from data.segment import SegmentId
13 |
14 | import os
15 |
16 | PREFIX = ""
17 |
18 |
19 | def parse_args():
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument(
22 | "tar_dir",
23 | type=Path,
24 | help="folder containing the .tar files from `dataset_tars` folder from Huggingface",
25 | )
26 | parser.add_argument(
27 | "out_dir",
28 | type=Path,
29 | help="a new directory (should not exist already), the script will untar and process data there",
30 | )
31 | return parser.parse_args()
32 |
33 |
34 | def process_tar(path_tar: Path, out_dir: Path, remove_tar: bool) -> None:
35 | # Name the output folder after the archive stem, e.g. "a_to_b" -> "a-b"
36 | assert path_tar.stem.startswith(PREFIX)
37 | d = out_dir / "-".join(path_tar.stem[len(PREFIX) :].split("_to_"))
38 | d.mkdir(exist_ok=False, parents=True)
39 | shutil.copy(path_tar, d / os.path.basename(path_tar))
40 | new_path_tar = d / path_tar.name
41 | if remove_tar:
42 | new_path_tar.unlink()
43 | else:
44 | shutil.copy(new_path_tar, path_tar.parent)
45 |
46 |
47 | def main():
48 | args = parse_args()
49 |
50 | tar_dir = args.tar_dir.absolute()
51 | out_dir = args.out_dir.absolute()
52 |
53 | if not tar_dir.exists():
54 | print(
55 | "Wrong usage: the tar directory should exist (and contain the downloaded .tar files)"
56 | )
57 | return
58 |
59 | if out_dir.exists():
60 | print(f"Wrong usage: the output directory should not exist ({args.out_dir})")
61 | return
62 |
63 | with Path("test_split.txt").open("r") as f:
64 | test_files = f.read().split("\n")
65 |
66 | full_res_dir = out_dir / "full_res"
67 | low_res_dir = out_dir / "low_res"
68 |
69 | hdf5_files = [
70 | x for x in tar_dir.iterdir() if x.suffix == ".hdf5" and x.stem.startswith(PREFIX)
71 | ]
72 | n = len(hdf5_files)
73 |
74 | str_files = "\n".join(map(str, hdf5_files))
75 | print(f"Ready to untar {n} tar files:\n{str_files}")
76 |
77 | remove_tar = False
78 |
79 | # Copy each game file into its own folder
80 | f = partial(process_tar, out_dir=full_res_dir, remove_tar=remove_tar)
81 | with Pool(n) as p:
82 | p.map(f, hdf5_files)
83 |
84 | print(f"{n} .tar files unpacked in {full_res_dir}")
85 |
86 | #
87 | # Create low-res data
88 | #
89 |
90 | game_dataset = GameHdf5Dataset(full_res_dir)
91 |
92 | train_dataset = Dataset(low_res_dir / "train", None)
93 | test_dataset = Dataset(low_res_dir / "test", None)
94 |
95 | for i in tqdm(game_dataset._filenames, desc="Creating low_res"):
96 | episode_length = game_dataset._length_one_episode[i]
97 | episode = Episode(
98 | **{
99 | k: v
100 | for k, v in game_dataset[SegmentId(i, 0, episode_length, None)].__dict__.items()
101 | if k not in ("mask_padding", "id")
102 | }
103 | )
104 | episode.obs = T.resize(
105 | episode.obs, (35, 53), interpolation=T.InterpolationMode.BICUBIC
106 | )
107 | filename = game_dataset._filenames[i]
108 | file_id = f"{filename.parent.stem}/{filename.name}"
109 | episode.info = {"original_file_id": file_id}
110 | dataset = test_dataset if filename.name in test_files else train_dataset
111 | dataset.add_episode(episode)
112 |
113 | train_dataset.save_to_default_path()
114 | test_dataset.save_to_default_path()
115 |
116 | print(
117 | f"Split train/test data ({train_dataset.num_episodes}/{test_dataset.num_episodes} episodes)\n"
118 | )
119 |
120 | print("You can now edit `config/env/racing.yaml` and set:")
121 | print(f"path_data_low_res: {low_res_dir}")
122 | print(f"path_data_full_res: {full_res_dir}")
123 |
124 |
125 | if __name__ == "__main__":
126 | main()
127 |
--------------------------------------------------------------------------------
/src/spawn.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | import os
4 | import random
5 | import h5py
6 | import numpy as np
7 | import cv2
8 |
9 | low_res_w = 64
10 | low_res_h = 48
11 |
12 | crop_frame = {
13 | 'left_top': (0.04, 0.18),
14 | 'right_bottom': (0.92, 0.95)
15 | }
16 |
17 | def extract_roi(image, rect):
18 | img_height, img_width, _ = image.shape
19 | min_row = int(rect['left_top'][1] * img_height)
20 | max_row = int(rect['right_bottom'][1] * img_height)
21 | min_col = int(rect['left_top'][0] * img_width)
22 | max_col = int(rect['right_bottom'][0] * img_width)
23 | roi = image[min_row:max_row, min_col:max_col]
24 | return roi
25 |
26 | def rescale_image(image, scale_factor):
27 | # Get new dimensions
28 | new_width = int(image.shape[1] * scale_factor)
29 | new_height = int(image.shape[0] * scale_factor)
30 |
31 | # Resize the image
32 | return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
33 |
34 |
35 | def parse_args():
36 | parser = argparse.ArgumentParser()
37 | parser.add_argument(
38 | "full_res_directory",
39 | type=Path,
40 | help="Specify your full_res directory.",
41 | )
42 |
43 | parser.add_argument(
44 | "model_directory",
45 | type=Path,
46 | help="Specify ",
47 | )
48 | return parser.parse_args()
49 |
50 |
51 | def main():
52 | args = parse_args()
53 |
54 | full_res_directory = args.full_res_directory.absolute()
55 | model_directory = args.model_directory.absolute()
56 |
57 | spawn_dir = model_directory / "game/spawn"
58 | existing_spawns = len(os.listdir(spawn_dir))
59 | os.makedirs(spawn_dir / str(existing_spawns), exist_ok=True)
60 | full_res = os.path.join(full_res_directory, random.choice(os.listdir(full_res_directory)))
61 | h5file = h5py.File(full_res, 'r')
62 | i = random.randint(0, 1000)
63 |
64 | data_x_frames = []
65 | data_y_frames = []
66 | next_act_frames = []
67 | low_res_frames = []
68 |
69 | for _ in range(20):  # gather the 20 spawn context frames
70 | frame_x = f'frame_{i}_x'
71 | frame_y = f'frame_{i}_y'
72 |
73 | # Check if the datasets exist in the file
74 | if frame_x in h5file and frame_y in h5file:
75 | # Append each frame to the lists
76 | data_x = h5file[frame_x][:]
77 | data_y = h5file[frame_y][:]
78 |
79 | img1 = cv2.cvtColor(data_x[:,:,3:], cv2.COLOR_BGR2RGB)
80 | img2 = cv2.cvtColor(data_x[:,:,:3], cv2.COLOR_BGR2RGB)
81 |
82 | img1_cropped = extract_roi(img1, crop_frame)
83 | img2_cropped = extract_roi(img2, crop_frame)
84 | img1_cropped = cv2.resize(img1_cropped, (530, 350), interpolation=cv2.INTER_AREA)
85 | img2_cropped = cv2.resize(img2_cropped, (530, 350), interpolation=cv2.INTER_AREA)
86 |
87 | data_x_frames.append(np.concatenate([img1_cropped, img2_cropped], axis=2))
88 | data_y_frames.append(data_y)
89 |
90 | img1 = cv2.resize(img1, (low_res_w, low_res_h), interpolation=cv2.INTER_AREA)
91 | img2 = cv2.resize(img2, (low_res_w, low_res_h), interpolation=cv2.INTER_AREA)
92 |
93 | low_res_frames.append(np.concatenate([img1, img2], axis=2).astype(np.uint8))
94 | else:
95 | print(f"One or both of {frame_x} or {frame_y} do not exist in the file.")
96 | i += 1
97 | for _ in range(200):  # then gather up to 200 future action frames
98 | next_act = f'frame_{i}_y'
99 | if next_act in h5file:
100 | next_act_data = h5file[next_act][:]
101 | next_act_frames.append(next_act_data)
102 |
103 | data_x_stacked = np.stack(data_x_frames)
104 | data_y_stacked = np.stack(data_y_frames)
105 | next_act_stacked = np.stack(next_act_frames)
106 | low_res_stacked = np.stack(low_res_frames)
107 |
108 | low_res_stacked = np.transpose(low_res_stacked, (0, 3, 1, 2))
109 | data_x_stacked = np.transpose(data_x_stacked, (0, 3, 1, 2))
110 |
111 | print(f"Saving act.npy of size {data_y_stacked.shape}")
112 | np.save(spawn_dir / f"{existing_spawns}/act.npy", data_y_stacked)
113 | print(f"Saving full_res.npy of size {data_x_stacked.shape}")
114 | np.save(spawn_dir / f"{existing_spawns}/full_res.npy", data_x_stacked)
115 | print(f"Saving next_act.npy of size {next_act_stacked.shape}")
116 | np.save(spawn_dir / f"{existing_spawns}/next_act.npy", next_act_stacked)
117 | print(f"Saving low_res.npy of size {low_res_stacked.shape}")
118 | np.save(spawn_dir / f"{existing_spawns}/low_res.npy", low_res_stacked)
119 |
120 | h5file.close()
121 |
122 |
123 | if __name__ == "__main__":
124 | main()
--------------------------------------------------------------------------------
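`extract_roi` above crops by image-relative fractions: `crop_frame` stores (x, y) fractions, so the y components scale rows and the x components scale columns. A minimal sketch with an illustrative frame size:

```python
import numpy as np

image = np.zeros((720, 1280, 3), dtype=np.uint8)  # (H, W, C); size illustrative
top, bottom = int(0.18 * 720), int(0.95 * 720)    # y fractions scale rows
left, right = int(0.04 * 1280), int(0.92 * 1280)  # x fractions scale columns
roi = image[top:bottom, left:right]
assert roi.shape == (bottom - top, right - left, 3)
```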
/src/trainer.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from pathlib import Path
3 | import shutil
4 | import time
5 | from typing import List, Optional, Tuple
6 |
7 | from hydra.utils import instantiate
8 | import numpy as np
9 | from omegaconf import DictConfig, OmegaConf
10 | import torch
11 | import torch.distributed as dist
12 | from torch.utils.data import DataLoader
13 | from tqdm import tqdm, trange
14 | import wandb
15 |
16 | from agent import Agent
17 | from coroutines.collector import make_collector, NumToCollect
18 | from data import BatchSampler, collate_segments_to_batch, Dataset, DatasetTraverser, GameHdf5Dataset
19 | from envs import make_atari_env, WorldModelEnv
20 | from utils import (
21 | broadcast_if_needed,
22 | build_ddp_wrapper,
23 | CommonTools,
24 | configure_opt,
25 | count_parameters,
26 | get_lr_sched,
27 | keep_agent_copies_every,
28 | Logs,
29 | move_opt_to,
30 | process_confusion_matrices_if_any_and_compute_classification_metrics,
31 | save_info_for_import_script,
32 | save_with_backup,
33 | set_seed,
34 | StateDictMixin,
35 | try_until_no_except,
36 | wandb_log,
37 | get_frame_indices,
38 | build_pages_per_epoch,
39 | find_maximum_key_below_threshold
40 | )
41 |
42 |
43 | class Trainer(StateDictMixin):
44 | def __init__(self, cfg: DictConfig, root_dir: Path) -> None:
45 | torch.backends.cuda.matmul.allow_tf32 = True
46 | OmegaConf.resolve(cfg)
47 | self._cfg = cfg
48 | self._rank = dist.get_rank() if dist.is_initialized() else 0
49 | self._world_size = dist.get_world_size() if dist.is_initialized() else 1
50 |
51 | # Pick a random seed
52 | set_seed(torch.seed() % 10 ** 9)
53 |
54 | # Device
55 | self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu", self._rank)
56 | print(f"Starting on {self._device}")
57 | self._use_cuda = self._device.type == "cuda"
58 | if self._use_cuda:
59 | torch.cuda.set_device(self._rank) # fix compilation error on multi-gpu nodes
60 |
61 | # Init wandb
62 | if self._rank == 0:
63 | try_until_no_except(
64 | partial(wandb.init, config=OmegaConf.to_container(cfg, resolve=True), reinit=True, resume=True,
65 | **cfg.wandb)
66 | )
67 |
68 | # Flags
69 | self._is_static_dataset = cfg.static_dataset.path is not None
70 | self._is_model_free = cfg.training.model_free
71 |
72 | # Checkpointing
73 | self._path_ckpt_dir = Path("checkpoints")
74 | self._path_state_ckpt = self._path_ckpt_dir / "state.pt"
75 | self._keep_agent_copies = partial(
76 | keep_agent_copies_every,
77 | every=cfg.checkpointing.save_agent_every,
78 | path_ckpt_dir=self._path_ckpt_dir,
79 | num_to_keep=cfg.checkpointing.num_to_keep,
80 | )
81 | self._save_info_for_import_script = partial(
82 | save_info_for_import_script, run_name=cfg.wandb.name, path_ckpt_dir=self._path_ckpt_dir
83 | )
84 |
85 | # First time, init files hierarchy
86 | if not cfg.common.resume and self._rank == 0:
87 | self._path_ckpt_dir.mkdir(exist_ok=False, parents=False)
88 | path_config = Path("config") / "trainer.yaml"
89 | path_config.parent.mkdir(exist_ok=False, parents=False)
90 | shutil.move(".hydra/config.yaml", path_config)
91 | wandb.save(str(path_config))
92 | shutil.copytree(src=root_dir / "src", dst="./src")
93 | shutil.copytree(src=root_dir / "scripts", dst="./scripts")
94 |
95 | if cfg.env.train.id == "racing":
96 | assert cfg.env.path_data_low_res is not None and cfg.env.path_data_full_res is not None, "Make sure to download GT4 data and set the relevant paths in cfg.env"
97 | assert self._is_static_dataset
98 | num_actions = cfg.env.num_actions
99 | dataset_full_res = GameHdf5Dataset(Path(cfg.env.path_data_full_res))
100 |
101 | # Envs (atari only)
102 | else:
103 | if self._rank == 0:
104 | train_env = make_atari_env(num_envs=cfg.collection.train.num_envs, device=self._device, **cfg.env.train)
105 | test_env = make_atari_env(num_envs=cfg.collection.test.num_envs, device=self._device, **cfg.env.test)
106 | num_actions = int(test_env.num_actions)
107 | else:
108 | num_actions = None
109 | num_actions, = broadcast_if_needed(num_actions)
110 | dataset_full_res = None
111 |
112 | num_workers = cfg.training.num_workers_data_loaders
113 | use_manager = cfg.training.cache_in_ram and (num_workers > 0)
114 | p = Path(cfg.static_dataset.path) if self._is_static_dataset else Path("dataset")
115 | self.train_dataset = Dataset(p / "train", dataset_full_res, "train_dataset", cfg.training.cache_in_ram,
116 | use_manager)
117 | self.test_dataset = Dataset(p / "test", dataset_full_res, "test_dataset", cache_in_ram=True)
118 | self.train_dataset.load_from_default_path()
119 | self.test_dataset.load_from_default_path()
120 |
121 | # Create models
122 | self.agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(self._device)
123 | self._agent = build_ddp_wrapper(**self.agent._modules) if dist.is_initialized() else self.agent
124 |
125 | if cfg.initialization.path_to_ckpt is not None:
126 | self.agent.load(**cfg.initialization)
127 |
128 | # Collectors
129 | if not self._is_static_dataset and self._rank == 0:
130 | self._train_collector = make_collector(
131 | train_env, self.agent.actor_critic, self.train_dataset, cfg.collection.train.epsilon
132 | )
133 | self._test_collector = make_collector(
134 | test_env, self.agent.actor_critic, self.test_dataset, cfg.collection.test.epsilon,
135 | reset_every_collect=True
136 | )
137 |
138 | ######################################################
139 |
140 | # Optimizers and LR schedulers
141 |
142 | def build_opt(name: str) -> torch.optim.AdamW:
143 | return configure_opt(getattr(self.agent, name), **getattr(cfg, name).optimizer)
144 |
145 | def build_lr_sched(name: str) -> torch.optim.lr_scheduler.LambdaLR:
146 | return get_lr_sched(self.opt.get(name), getattr(cfg, name).training.lr_warmup_steps)
147 |
148 | model_names = [self._cfg.train_model]
149 | self._model_names = ["actor_critic"] if self._is_model_free else [name for name in model_names if
150 | getattr(self.agent, name) is not None]
151 |
152 | self.opt = CommonTools(**{name: build_opt(name) for name in self._model_names})
153 | self.lr_sched = CommonTools(**{name: build_lr_sched(name) for name in self._model_names})
154 |
155 | # Data loaders
156 |
157 | make_data_loader = partial(
158 | DataLoader,
159 | dataset=self.train_dataset,
160 | collate_fn=collate_segments_to_batch,
161 | num_workers=num_workers,
162 | persistent_workers=(num_workers > 0),
163 | pin_memory=self._use_cuda,
164 | pin_memory_device=str(self._device) if self._use_cuda else "",
165 | )
166 |
167 | make_batch_sampler = partial(BatchSampler, self.train_dataset, self._rank, self._world_size)
168 |
169 | def get_sample_weights(sample_weights: List[float]) -> Optional[List[float]]:
170 | return None if (self._is_static_dataset and cfg.static_dataset.ignore_sample_weights) else sample_weights
171 |
172 | c = cfg.denoiser.training
173 | # in case of distributed training, the cfg needs to be taken from the module attribute
174 | agent_cfg = self._agent.denoiser.cfg if not hasattr(self._agent.denoiser, "module") else self._agent.denoiser.module.cfg
175 | effective_context_length = int(get_frame_indices(agent_cfg.frame_sampling)[-1]) + 1
176 | seq_length = effective_context_length + 1 + c.num_autoregressive_steps
177 | bs = make_batch_sampler(c.batch_size, seq_length, get_sample_weights(c.sample_weights), False, c.num_autoregressive_steps, c.initial_num_consecutive_page_count)
178 | dl_denoiser_train = make_data_loader(batch_sampler=bs)
179 | dl_denoiser_test = DatasetTraverser(self.test_dataset, c.batch_size, seq_length)
180 |
181 | self.pages_per_epoch = build_pages_per_epoch(c.num_consecutive_pages)
182 |
183 | if self.agent.upsampler is not None:
184 | c = cfg.upsampler.training
185 | seq_length = cfg.agent.upsampler.inner_model.num_steps_conditioning + 1 + c.num_autoregressive_steps
186 | bs = make_batch_sampler(c.batch_size, seq_length, get_sample_weights(c.sample_weights), False, c.num_autoregressive_steps, c.initial_num_consecutive_page_count)
187 | dl_upsampler_train = make_data_loader(batch_sampler=bs)
188 | dl_upsampler_test = DatasetTraverser(self.test_dataset, c.batch_size, seq_length)
189 | else:
190 | dl_upsampler_train = dl_upsampler_test = None
191 |
192 | if self.agent.rew_end_model is not None:
193 | c = cfg.rew_end_model.training
194 | bs = make_batch_sampler(c.batch_size, c.seq_length, get_sample_weights(c.sample_weights),
195 | can_sample_beyond_end=True)
196 | dl_rew_end_model_train = make_data_loader(batch_sampler=bs)
197 | dl_rew_end_model_test = DatasetTraverser(self.test_dataset, c.batch_size, c.seq_length)
198 | else:
199 | dl_rew_end_model_train = dl_rew_end_model_test = None
200 |
201 | self._data_loader_train = CommonTools(dl_denoiser_train, dl_upsampler_train, dl_rew_end_model_train, None)
202 | self._data_loader_test = CommonTools(dl_denoiser_test, dl_upsampler_test, dl_rew_end_model_test, None)
203 |
204 | # RL env
205 |
206 | if self.agent.actor_critic is not None:
207 | actor_critic_loss_cfg = instantiate(cfg.actor_critic.actor_critic_loss)
208 |
209 | if self._is_model_free:
210 | assert self.agent.actor_critic is not None
211 | rl_env = make_atari_env(num_envs=cfg.actor_critic.training.batch_size, device=self._device,
212 | **cfg.env.train)
213 |
214 | else:
215 | c = cfg.actor_critic.training
216 | sl = cfg.agent.denoiser.inner_model.num_steps_conditioning
217 | if self.agent.upsampler is not None:
218 | sl = max(sl, cfg.agent.upsampler.inner_model.num_steps_conditioning)
219 | bs = make_batch_sampler(c.batch_size, sl, get_sample_weights(c.sample_weights))
220 | dl_actor_critic = make_data_loader(batch_sampler=bs)
221 | wm_env_cfg = instantiate(cfg.world_model_env)
222 | rl_env = WorldModelEnv(self.agent.denoiser, self.agent.upsampler, self.agent.rew_end_model,
223 | dl_actor_critic, wm_env_cfg)
224 |
225 | if cfg.training.compile_wm:
226 | rl_env.predict_next_obs = torch.compile(rl_env.predict_next_obs, mode="reduce-overhead")
227 | rl_env.predict_rew_end = torch.compile(rl_env.predict_rew_end, mode="reduce-overhead")
228 | else:
229 | actor_critic_loss_cfg = None
230 | rl_env = None
231 |
232 | # Setup training
233 | sigma_distribution_cfg = instantiate(cfg.denoiser.sigma_distribution)
234 | sigma_distribution_cfg_upsampler = instantiate(
235 | cfg.upsampler.sigma_distribution) if self.agent.upsampler is not None else None
236 | self.agent.setup_training(sigma_distribution_cfg, sigma_distribution_cfg_upsampler, actor_critic_loss_cfg,
237 | rl_env)
238 |
239 | # Training state (things to be saved/restored)
240 | self.epoch = 0
241 | self.num_epochs_collect = None
242 | self.num_episodes_test = 0
243 | self.num_batch_train = CommonTools(0, 0, 0)
244 | self.num_batch_test = CommonTools(0, 0, 0)
245 |
246 | if cfg.common.resume:
247 | # self.load_state_checkpoint()
248 | self.load_agent_state_checkpoint()
249 | else:
250 | self.save_checkpoint()
251 |
252 | if self._rank == 0:
253 | for name in self._model_names:
254 | print(f"{count_parameters(getattr(self.agent, name))} parameters in {name}")
255 | print(self.train_dataset)
256 | print(self.test_dataset)
257 |
258 | def run(self) -> None:
259 | to_log = []
260 |
261 | if self.epoch == 0:
262 | if self._is_model_free or self._is_static_dataset:
263 | self.num_epochs_collect = 0
264 | else:
265 | if self._rank == 0:
266 | self.num_epochs_collect, to_log_ = self.collect_initial_dataset()
267 | to_log += to_log_
268 | self.num_epochs_collect, sd_train_dataset = broadcast_if_needed(self.num_epochs_collect,
269 | self.train_dataset.state_dict())
270 | self.train_dataset.load_state_dict(sd_train_dataset)
271 |
272 | num_epochs = self.num_epochs_collect + self._cfg.training.num_final_epochs
273 |
274 | while self.epoch < num_epochs:
275 | self.epoch += 1
276 | start_time = time.time()
277 |
278 | if self._rank == 0:
279 | print(f"\nEpoch {self.epoch} / {num_epochs}\n")
280 |
281 | # Training
282 | should_collect_train = (
283 | self._rank == 0 and not self._is_model_free and not self._is_static_dataset and self.epoch <= self.num_epochs_collect)
284 |
285 | if should_collect_train:
286 | c = self._cfg.collection.train
287 | to_log += self._train_collector.send(NumToCollect(steps=c.steps_per_epoch))
288 | sd_train_dataset, = broadcast_if_needed(self.train_dataset.state_dict()) # update dataset for ranks > 0
289 | self.train_dataset.load_state_dict(sd_train_dataset)
290 |
291 | if self._cfg.training.should:
292 | to_log += self.train_agent()
293 |
294 | # Evaluation
295 | should_test = self._rank == 0 and self._cfg.evaluation.should and (
296 | self.epoch % self._cfg.evaluation.every == 0)
297 | should_collect_test = should_test and not self._is_static_dataset
298 |
299 | if should_collect_test:
300 | to_log += self.collect_test()
301 |
302 | if should_test and not self._is_model_free:
303 | to_log += self.test_agent()
304 |
305 | # Logging
306 | to_log.append({"duration": (time.time() - start_time) / 3600})
307 | if self._rank == 0:
308 | wandb_log(to_log, self.epoch)
309 | to_log = []
310 |
311 | # Checkpointing
312 | self.save_checkpoint()
313 |
314 | if dist.is_initialized():
315 | dist.barrier()
316 |
317 | smallest_page_epoch = find_maximum_key_below_threshold(self.pages_per_epoch, self.epoch)
318 | if smallest_page_epoch is not None:
319 | self._data_loader_train.denoiser.batch_sampler.num_consecutive_batches = self.pages_per_epoch[smallest_page_epoch]
320 |
321 | # Last collect
322 | if self._rank == 0 and not self._is_static_dataset:
323 | wandb_log(self.collect_test(final=True), self.epoch)
324 |
325 | def collect_initial_dataset(self) -> Tuple[int, Logs]:
326 | print("\nInitial collect\n")
327 | to_log = []
328 | c = self._cfg.collection.train
329 | min_steps = c.first_epoch.min
330 | steps_per_epoch = c.steps_per_epoch
331 | max_steps = c.first_epoch.max
332 | threshold_rew = c.first_epoch.threshold_rew
333 | assert min_steps % steps_per_epoch == 0
334 |
335 | steps = min_steps
336 | while True:
337 | to_log += self._train_collector.send(NumToCollect(steps=steps))
338 | num_steps = self.train_dataset.num_steps
339 | total_minority_rew = sum(sorted(self.train_dataset.counts_rew)[:-1])
340 | if total_minority_rew >= threshold_rew:
341 | break
342 | if (max_steps is not None) and num_steps >= max_steps:
343 | print("Reached the specified maximum for initial collect")
344 | break
345 | print(f"Minority reward: {total_minority_rew}/{threshold_rew} -> Keep collecting\n")
346 | steps = steps_per_epoch
347 |
348 | print("\nSummary of initial collect:")
349 | print(f"Num steps: {num_steps} / {c.num_steps_total}")
350 | print(f"Reward counts: {dict(self.train_dataset.counter_rew)}")
351 |
352 | remaining_steps = c.num_steps_total - num_steps
353 | assert remaining_steps % c.steps_per_epoch == 0
354 | num_epochs_collect = remaining_steps // c.steps_per_epoch
355 |
356 | return num_epochs_collect, to_log
357 |
358 | def collect_test(self, final: bool = False) -> Logs:
359 | c = self._cfg.collection.test
360 | episodes = c.num_final_episodes if final else c.num_episodes
361 | td = self.test_dataset
362 | td.clear()
363 | to_log = self._test_collector.send(NumToCollect(episodes=episodes))
364 | key_ep_id = f"{td.name}/episode_id"
365 | to_log = [{k: v + self.num_episodes_test if k == key_ep_id else v for k, v in x.items()} for x in to_log]
366 |
367 | print(f"\nSummary of {'final' if final else 'test'} collect: {td.num_episodes} episodes ({td.num_steps} steps)")
368 | keys = [key_ep_id, "return", "length"]
369 | to_log_episodes = [x for x in to_log if set(x.keys()) == set(keys)]
370 | episode_ids, returns, lengths = [[d[k] for d in to_log_episodes] for k in keys]
371 | for i, (ep_id, ret, length) in enumerate(zip(episode_ids, returns, lengths)):
372 | print(f" Episode {ep_id}: return = {ret} length = {length}\n", end="\n" if i == episodes - 1 else "")
373 |
374 | self.num_episodes_test += episodes
375 |
376 | if final:
377 | to_log.append({"final_return_mean": np.mean(returns), "final_return_std": np.std(returns)})
378 | print(to_log[-1])
379 |
380 | return to_log
381 |
382 | def train_agent(self) -> Logs:
383 | self.agent.train()
384 | self.agent.zero_grad()
385 | to_log = []
386 | for name in self._model_names:
387 | cfg = getattr(self._cfg, name).training
388 | if self.epoch > cfg.start_after_epochs:
389 | steps = cfg.steps_first_epoch if self.epoch == 1 else cfg.steps_per_epoch
390 | to_log += self.train_component(name, steps)
391 | return to_log
392 |
393 | @torch.no_grad()
394 | def test_agent(self) -> Logs:
395 | self.agent.eval()
396 | to_log = []
397 | for name in self._model_names:
398 | if name == "actor_critic":
399 | continue
400 | cfg = getattr(self._cfg, name).training
401 | if self.epoch > cfg.start_after_epochs:
402 | to_log += self.test_component(name)
403 | return to_log
404 |
405 | def train_component(self, name: str, steps: int) -> Logs:
406 | cfg = getattr(self._cfg, name).training
407 | model = getattr(self._agent, name)
408 | opt = self.opt.get(name)
409 | lr_sched = self.lr_sched.get(name)
410 | data_loader = self._data_loader_train.get(name)
411 |
412 | torch.cuda.empty_cache()
413 | model.to(self._device)
414 | move_opt_to(opt, self._device)
415 |
416 | model.train()
417 | opt.zero_grad()
418 | data_iterator = iter(data_loader) if data_loader is not None else None
419 | to_log = []
420 |
421 | num_steps = cfg.grad_acc_steps * steps
422 |
423 | # in case of distributed training, the cfg needs to be taken from the module attribute
424 | agent_cfg = self._agent.denoiser.cfg if not hasattr(self._agent.denoiser, "module") else self._agent.denoiser.module.cfg
425 | effective_context_length = 0 if name == 'upsampler' else get_frame_indices(agent_cfg.frame_sampling)[-1] + 1
426 | context_obs = None
427 | context_act = None
428 | context_mask_padding = None
429 |
430 | for i in trange(num_steps, desc=f"Training {name}", disable=self._rank > 0):
431 | curr_iter = 0
432 | total_length = 0
433 | while curr_iter < data_loader.batch_sampler.num_consecutive_batches:
434 | batch = next(data_iterator).to(self._device) if data_iterator is not None else None
435 | curr_iter += 1
436 |
437 | # initialize variables in the first segment
438 | if batch.segment_ids[0].is_first_batch:
439 | context_obs = batch.obs
440 | context_act = batch.act
441 | context_mask_padding = batch.mask_padding
442 |
443 | # build the observations until there's enough context
444 | while context_obs.shape[1] < effective_context_length + 1:
445 | batch = next(data_iterator).to(self._device) if data_iterator is not None else None
446 | curr_iter += 1
447 |
448 | context_obs = torch.concat([context_obs, batch.obs], dim=1)
449 | context_act = torch.concat([context_act, batch.act], dim=1)
450 | context_mask_padding = torch.concat([context_mask_padding, batch.mask_padding], dim=1)
451 |
452 | # split the context into the context obs and the obs to predict
453 |
454 | # future obs (after the effective context length)
455 | predict_obs = context_obs[:, effective_context_length:]
456 | predict_act = context_act[:, effective_context_length:]
457 | predict_mask_padding = context_mask_padding[:, effective_context_length:]
458 |
459 | # previous obs (before the effective context length)
460 | context_obs = context_obs[:, :effective_context_length]
461 | context_act = context_act[:, :effective_context_length]
462 | context_mask_padding = context_mask_padding[:, :effective_context_length]
463 |
464 | total_length = (data_loader.batch_sampler.num_consecutive_batches-1)*data_loader.batch_sampler.autoregressive_obs + data_loader.batch_sampler.seq_length - effective_context_length
465 | else: # set batch to be the frames to be predicted next
466 | predict_obs = batch.obs
467 | predict_act = batch.act
468 | predict_mask_padding = batch.mask_padding
469 |
470 | # build batch for prediction
471 | batch.obs = torch.cat([context_obs[:, -effective_context_length:], predict_obs], dim=1)
472 | batch.act = torch.cat([context_act[:, -effective_context_length:], predict_act], dim=1)
473 | batch.mask_padding = torch.cat([context_mask_padding[:, -effective_context_length:], predict_mask_padding], dim=1)
474 |
475 | # train on the batch
476 | loss, metrics, batch_data = model(batch) if batch is not None else model()
477 | loss /= total_length
478 | loss.backward()
479 |
480 | # collect the predicted frames and prepare the next context
481 | if 'obs' in batch_data and 'act' in batch_data and 'mask_padding' in batch_data and name != 'upsampler':
482 | obs = batch_data['obs']
483 | act = batch_data['act']
484 | mask_padding = batch_data['mask_padding']
485 |
486 | # the next context begins at the end of the current predictions
487 | context_obs = torch.concat([context_obs, obs], dim=1)
488 | context_act = torch.concat([context_act, act], dim=1)
489 | context_mask_padding = torch.cat([context_mask_padding, mask_padding], dim=1)
490 |
491 | # saving up to effective context length frames back
492 | context_obs = context_obs[:, -effective_context_length:]
493 | context_act = context_act[:, -effective_context_length:]
494 | context_mask_padding = context_mask_padding[:, -effective_context_length:]
495 |
496 | num_batch = self.num_batch_train.get(name)
497 | metrics[f"num_batch_train_{name}"] = num_batch
498 | self.num_batch_train.set(name, num_batch + 1)
499 | to_log.append(metrics)
500 |
501 | if (i + 1) % cfg.grad_acc_steps == 0:
502 | if cfg.max_grad_norm is not None:
503 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm).item()
504 | metrics["grad_norm_before_clip"] = grad_norm
505 | opt.step()
506 | opt.zero_grad()
507 | if lr_sched is not None:
508 | metrics["lr"] = lr_sched.get_last_lr()[0]
509 | lr_sched.step()
510 |
511 | process_confusion_matrices_if_any_and_compute_classification_metrics(to_log)
512 | to_log = [{f"{name}/train/{k}": v for k, v in d.items()} for d in to_log]
513 |
514 | model.to("cpu")
515 | move_opt_to(opt, "cpu")
516 |
517 | return to_log
518 |
519 | @torch.no_grad()
520 | def test_component(self, name: str) -> Logs:
521 | model = getattr(self.agent, name)
522 | data_loader = self._data_loader_test.get(name)
523 | model.eval()
524 | model.to(self._device)
525 | to_log = []
526 | for batch in tqdm(data_loader, desc=f"Evaluating {name}"):
527 | batch = batch.to(self._device)
528 | _, metrics, _ = model(batch)
529 | num_batch = self.num_batch_test.get(name)
530 | metrics[f"num_batch_test_{name}"] = num_batch
531 | self.num_batch_test.set(name, num_batch + 1)
532 | to_log.append(metrics)
533 |
534 | process_confusion_matrices_if_any_and_compute_classification_metrics(to_log)
535 | to_log = [{f"{name}/test/{k}": v for k, v in d.items()} for d in to_log]
536 | model.to("cpu")
537 | return to_log
538 |
539 | def load_state_checkpoint(self) -> None:
540 | self.load_state_dict(torch.load(self._path_state_ckpt, map_location=self._device))
541 |
542 | def load_agent_state_checkpoint(self) -> None:
543 | agent_state_dict = torch.load(self._path_state_ckpt, map_location=self._device)
544 | self.agent.load_state_dict(agent_state_dict)
545 |
546 | def save_checkpoint(self) -> None:
547 | if self._rank == 0:
548 | save_with_backup(self.state_dict(), self._path_state_ckpt)
549 | self.train_dataset.save_to_default_path()
550 | self.test_dataset.save_to_default_path()
551 | self._keep_agent_copies(self.agent.state_dict(), self.epoch)
552 | self._save_info_for_import_script(self.epoch)
--------------------------------------------------------------------------------
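`train_component` above follows the standard gradient-accumulation pattern: `backward()` runs on every micro-batch, while gradient clipping, `opt.step()` and `opt.zero_grad()` run once every `cfg.grad_acc_steps` iterations. A standalone sketch of that skeleton (the model, data, and step counts are stand-ins, not the repo's configuration):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
grad_acc_steps, steps = 4, 2

opt.zero_grad()
for i in range(grad_acc_steps * steps):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()                      # gradients accumulate across micro-batches
    if (i + 1) % grad_acc_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()                       # one optimizer step per grad_acc_steps
        opt.zero_grad()
```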
/src/utils.py:
--------------------------------------------------------------------------------
1 |
2 | from argparse import Namespace
3 | from collections import OrderedDict
4 | from dataclasses import dataclass
5 | from functools import partial
6 | import json
7 | from pathlib import Path
8 | import random
9 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
10 |
11 | from omegaconf import OmegaConf
12 | import numpy as np
13 | import torch
14 | import torch.distributed as dist
15 | from torch import Tensor
16 | from torch.optim.lr_scheduler import LambdaLR
17 | import torch.nn as nn
18 | from torch.nn.parallel import DistributedDataParallel as DDP
19 | from torch.optim import AdamW
20 | import wandb
21 |
22 |
23 | ATARI_100K_GAMES = [
24 | "Alien",
25 | "Amidar",
26 | "Assault",
27 | "Asterix",
28 | "BankHeist",
29 | "BattleZone",
30 | "Boxing",
31 | "Breakout",
32 | "ChopperCommand",
33 | "CrazyClimber",
34 | "DemonAttack",
35 | "Freeway",
36 | "Frostbite",
37 | "Gopher",
38 | "Hero",
39 | "Jamesbond",
40 | "Kangaroo",
41 | "Krull",
42 | "KungFuMaster",
43 | "MsPacman",
44 | "Pong",
45 | "PrivateEye",
46 | "Qbert",
47 | "RoadRunner",
48 | "Seaquest",
49 | "UpNDown",
50 | ]
51 |
52 |
53 | Logs = List[Dict[str, float]]
54 | LossAndLogs = Tuple[Tensor, Dict[str, Any]]
55 | LossLogsData = Tuple[Tensor, Dict[str, Any], Dict[str, Any]]
56 |
57 |
58 | class StateDictMixin:
59 | def _init_fields(self) -> None:
60 | def has_sd(x: Any) -> bool:
61 | return callable(getattr(x, "state_dict", None)) and callable(getattr(x, "load_state_dict", None))
62 |
63 | self._all_fields = {k for k in vars(self) if not k.startswith("_")}
64 | self._fields_sd = {k for k in self._all_fields if has_sd(getattr(self, k))}
65 |
66 | def _get_field(self, k: str) -> Any:
67 | return getattr(self, k).state_dict() if k in self._fields_sd else getattr(self, k)
68 |
69 | def _set_field(self, k: str, v: Any) -> None:
70 | getattr(self, k).load_state_dict(v) if k in self._fields_sd else setattr(self, k, v)
71 |
72 | def state_dict(self) -> Dict[str, Any]:
73 | if not hasattr(self, "_all_fields"):
74 | self._init_fields()
75 | return {k: self._get_field(k) for k in self._all_fields}
76 |
77 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
78 | if not hasattr(self, "_all_fields"):
79 | self._init_fields()
80 | assert set(list(state_dict.keys())) == self._all_fields
81 | for k, v in state_dict.items():
82 | self._set_field(k, v)
83 |
84 |
85 | @dataclass
86 | class CommonTools(StateDictMixin):
87 | denoiser: Optional[Any] = None
88 | upsampler: Optional[Any] = None
89 | rew_end_model: Optional[Any] = None
90 | actor_critic: Optional[Any] = None
91 |
92 | def get(self, name: str) -> Any:
93 | return getattr(self, name)
94 |
95 | def set(self, name: str, value: Any):
96 | return setattr(self, name, value)
97 |
98 |
99 | def broadcast_if_needed(*args):
100 | objects = list(args)
101 | if dist.is_initialized():
102 | dist.broadcast_object_list(objects, src=0)
103 | # the list `objects` now contains the version of rank 0
104 | return objects
105 |
106 |
107 | def build_ddp_wrapper(**modules_dict: nn.Module) -> Namespace:
108 | return Namespace(**{name: DDP(module) for name, module in modules_dict.items()})
109 |
110 |
111 | def compute_classification_metrics(confusion_matrix: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
112 | num_classes = confusion_matrix.size(0)
113 | precision = torch.zeros(num_classes)
114 | recall = torch.zeros(num_classes)
115 | f1_score = torch.zeros(num_classes)
116 |
117 | for i in range(num_classes):
118 | true_positive = confusion_matrix[i, i].item()
119 | false_positive = confusion_matrix[:, i].sum().item() - true_positive
120 | false_negative = confusion_matrix[i, :].sum().item() - true_positive
121 |
122 | precision[i] = true_positive / (true_positive + false_positive) if (true_positive + false_positive) != 0 else 0
123 | recall[i] = true_positive / (true_positive + false_negative) if (true_positive + false_negative) != 0 else 0
124 | f1_score[i] = (
125 | 2 * (precision[i] * recall[i]) / (precision[i] + recall[i]) if (precision[i] + recall[i]) != 0 else 0
126 | )
127 |
128 | return precision, recall, f1_score
129 |
130 |
131 | def configure_opt(model: nn.Module, lr: float, weight_decay: float, eps: float, *blacklist_module_names: str) -> AdamW:
132 | """Credits to https://github.com/karpathy/minGPT"""
133 | # separate out all parameters to those that will and won't experience regularizing weight decay
134 | decay = set()
135 | no_decay = set()
136 | whitelist_weight_modules = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.LSTMCell, nn.LSTM)
137 | blacklist_weight_modules = (nn.LayerNorm, nn.Embedding, nn.GroupNorm)
138 | for mn, m in model.named_modules():
139 | for pn, p in m.named_parameters():
140 | fpn = "%s.%s" % (mn, pn) if mn else pn # full param name
141 | if any([fpn.startswith(module_name) for module_name in blacklist_module_names]):
142 | no_decay.add(fpn)
143 | elif "bias" in pn:
144 | # all biases will not be decayed
145 | no_decay.add(fpn)
146 | elif (pn.endswith("weight") or pn.startswith("weight_")) and isinstance(m, whitelist_weight_modules):
147 | # weights of whitelist modules will be weight decayed
148 | decay.add(fpn)
149 | elif (pn.endswith("weight") or pn.startswith("weight_")) and isinstance(m, blacklist_weight_modules):
150 | # weights of blacklist modules will NOT be weight decayed
151 | no_decay.add(fpn)
152 |
153 | # validate that we considered every parameter
154 | param_dict = {pn: p for pn, p in model.named_parameters()}
155 | inter_params = decay & no_decay
156 | union_params = decay | no_decay
157 | assert len(inter_params) == 0, f"parameters {str(inter_params)} made it into both decay/no_decay sets!"
158 | assert (
159 | len(param_dict.keys() - union_params) == 0
160 | ), f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!"
161 |
162 | # create the pytorch optimizer object
163 | optim_groups = [
164 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
165 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
166 | ]
167 | optimizer = AdamW(optim_groups, lr=lr, eps=eps)
168 | return optimizer
169 |
170 |
171 | def count_parameters(model: nn.Module) -> int:
172 | return sum(p.numel() for p in model.parameters())
173 |
174 |
175 | def extract_state_dict(state_dict: OrderedDict, module_name: str) -> OrderedDict:
176 | return OrderedDict({k.split(".", 1)[1]: v for k, v in state_dict.items() if k.startswith(module_name)})
177 |
178 |
179 | def get_lr_sched(opt: torch.optim.Optimizer, num_warmup_steps: int) -> LambdaLR:
180 | def lr_lambda(current_step: int):
181 | return 1 if current_step >= num_warmup_steps else current_step / max(1, num_warmup_steps)
182 |
183 | return LambdaLR(opt, lr_lambda, last_epoch=-1)
184 |
185 |
186 | def init_lstm(model: nn.Module) -> None:
187 | for name, p in model.named_parameters():
188 | if "weight_ih" in name:
189 | nn.init.xavier_uniform_(p.data)
190 | elif "weight_hh" in name:
191 | nn.init.orthogonal_(p.data)
192 | elif "bias_ih" in name:
193 | p.data.fill_(0)
194 | # Set forget-gate bias to 1
195 | n = p.size(0)
196 | p.data[(n // 4) : (n // 2)].fill_(1)
197 | elif "bias_hh" in name:
198 | p.data.fill_(0)
199 |
200 |
201 | def get_path_agent_ckpt(path_ckpt_dir: Union[str, Path], epoch: int, num_zeros: int = 5) -> Path:
202 | d = Path(path_ckpt_dir) / "agent_versions"
203 | if epoch >= 0:
204 | return d / f"agent_epoch_{epoch:0{num_zeros}d}.pt"
205 | else:
206 |         all_ = sorted(d.iterdir())
207 | assert len(all_) >= -epoch
208 | return all_[epoch]
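    |
    | # Negative epochs index from the end, Python-style: get_path_agent_ckpt(d, -1)
    | # returns the most recent checkpoint in d/agent_versions (filenames sort by
    | # epoch thanks to the zero padding).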
209 |
210 |
211 | def keep_agent_copies_every(
212 | agent_sd: Dict[str, Any],
213 | epoch: int,
214 | path_ckpt_dir: Path,
215 | every: int,
216 | num_to_keep: Optional[int],
217 | ) -> None:
218 | assert every > 0
219 | assert num_to_keep is None or num_to_keep > 0
220 | get_path = partial(get_path_agent_ckpt, path_ckpt_dir)
221 | get_path(0).parent.mkdir(parents=False, exist_ok=True)
222 |
223 | # Save agent
224 | save_with_backup(agent_sd, get_path(epoch))
225 |
226 | # Clean oldest
227 | if (num_to_keep is not None) and (epoch % every == 0):
228 | get_path(max(0, epoch - num_to_keep * every)).unlink(missing_ok=True)
229 |
230 | # Clean previous
231 | if (epoch - 1) % every != 0:
232 | get_path(max(0, epoch - 1)).unlink(missing_ok=True)
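    |
    | # Net effect: every epoch is saved, the previous epoch's copy is deleted unless
    | # it falls on a multiple of `every`, and at most `num_to_keep` of those
    | # periodic copies are retained (plus the latest one).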
233 |
234 |
235 | def move_opt_to(opt: AdamW, device: torch.device) -> None:
236 |     for param_state in opt.state.values():
237 |         for state_name, state_value in param_state.items():
238 |             if torch.is_tensor(state_value) and state_name != "step":
239 |                 param_state[state_name] = state_value.to(device)
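    |
    | # Moves AdamW state tensors (exp_avg, exp_avg_sq) between devices after loading
    | # a checkpoint; the "step" counter is skipped so it stays wherever PyTorch put it.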
240 |
241 |
242 | def process_confusion_matrices_if_any_and_compute_classification_metrics(logs: Logs) -> None:
243 | cm = [x.pop("confusion_matrix") for x in logs if "confusion_matrix" in x]
244 | if len(cm) > 0:
245 |         confusion_matrices = {k: sum(d[k] for d in cm) for k in cm[0]}  # accumulate confusion matrices
246 | metrics = {}
247 | for key, confusion_matrix in confusion_matrices.items():
248 | precision, recall, f1_score = compute_classification_metrics(confusion_matrix)
249 | metrics.update(
250 | {
251 | **{f"classification_metrics/{key}_precision_class_{i}": v for i, v in enumerate(precision)},
252 | **{f"classification_metrics/{key}_recall_class_{i}": v for i, v in enumerate(recall)},
253 | **{f"classification_metrics/{key}_f1_score_class_{i}": v for i, v in enumerate(f1_score)},
254 | }
255 | )
256 |
257 | logs.append(metrics) # Append the obtained metrics to logs (in place)
258 |
259 |
260 | def prompt_atari_game():
261 | for i, game in enumerate(ATARI_100K_GAMES):
262 | print(f"{i:2d}: {game}")
263 | while True:
264 | x = input("\nEnter a number: ")
265 | if not x.isdigit():
266 | print("Invalid.")
267 | continue
268 | x = int(x)
269 |         if x < 0 or x >= len(ATARI_100K_GAMES):
270 | print("Invalid.")
271 | continue
272 | break
273 | game = ATARI_100K_GAMES[x]
274 | return game
275 |
276 |
277 | def prompt_run_name(game):
278 | cfg_file = Path("config/trainer.yaml")
279 | cfg_name = OmegaConf.load(cfg_file).wandb.name
280 | suffix = f"-{cfg_name}" if cfg_name is not None else ""
281 | name = game + suffix
282 | name_ = input(f"Confirm run name by pressing Enter (or enter a new name): {name}\n")
283 | if name_ != "":
284 | name = name_
285 | return name
286 |
287 |
288 | def save_info_for_import_script(epoch: int, run_name: str, path_ckpt_dir: Path) -> None:
289 | with (path_ckpt_dir / "info_for_import_script.json").open("w") as f:
290 | json.dump({"epoch": epoch, "name": run_name}, f)
291 |
292 |
293 | def save_with_backup(obj: Any, path: Path) -> None:
294 | bk = path.with_suffix(".bk")
295 | if path.is_file():
296 | path.rename(bk)
297 | torch.save(obj, path)
298 | bk.unlink(missing_ok=True)
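    |
    | # Crash-safe save: the old file is renamed to *.bk before writing, so if
    | # torch.save dies midway the previous checkpoint survives; the backup is only
    | # removed once the new file has been fully written.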
299 |
300 |
301 | def set_seed(seed: int) -> None:
302 | np.random.seed(seed)
303 | torch.manual_seed(seed)
304 | torch.cuda.manual_seed(seed)
305 | random.seed(seed)
306 |
307 |
308 | def skip_if_run_is_over(func: Callable) -> Callable:
309 | def inner(*args, **kwargs):
310 | path_run_is_over = Path(".run_is_over")
311 | if not path_run_is_over.is_file():
312 | func(*args, **kwargs)
313 | path_run_is_over.touch()
314 | else:
315 | print(f"Run is marked as finished. To unmark, remove '{str(path_run_is_over)}'.")
316 |
317 | return inner
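    |
    | # Decorator usage sketch: wrap the training entry point so that re-launching a
    | # finished run (e.g. after a cluster requeue) is a no-op until the .run_is_over
    | # sentinel file is deleted.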
318 |
319 |
320 | def try_until_no_except(func: Callable) -> None:
321 | while True:
322 | try:
323 | func()
324 | except KeyboardInterrupt:
325 | break
326 | except Exception:
327 | continue
328 | else:
329 | break
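    |
    | # Retries func forever until one call completes without raising (Ctrl-C still
    | # aborts) -- handy for flaky setup calls, at the cost of silently hiding errors.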
330 |
331 |
332 | def wandb_log(logs: Logs, epoch: int):
333 | for d in logs:
334 | wandb.log({"epoch": epoch, **d})
335 |
336 |
337 | def get_frame_indices(frame_sampling):
338 | indexes = []
339 | current_index = 0
340 | for group in frame_sampling[::-1]:
341 | for _ in range(group['count']):
342 | indexes.append(current_index)
343 | current_index += group['stride']
344 |
345 | return torch.tensor(indexes)
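    |
    | # Worked example -- groups are consumed in reverse, each contributing `count`
    | # offsets spaced by its `stride`:
    | #   get_frame_indices([{'count': 2, 'stride': 4}, {'count': 3, 'stride': 1}])
    | #   -> tensor([0, 1, 2, 3, 7])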
346 |
347 | def build_pages_per_epoch(pages_per_epoch):
348 | mapping = {}
349 | for group in pages_per_epoch[::-1]:
350 | mapping[group['epoch']] = group['count']
351 |
352 | return mapping
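    |
    | # Flattens a list of {'epoch': e, 'count': c} groups into {e: c}; iterating in
    | # reverse means earlier entries win if the same epoch appears twice, e.g.
    | #   build_pages_per_epoch([{'epoch': 0, 'count': 4}, {'epoch': 100, 'count': 8}])
    | #   -> {100: 8, 0: 4}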
353 |
354 | def find_maximum_key_below_threshold(d, threshold):
    |     """Largest key of `d` that is <= threshold (an inclusive bound, despite the name), or None if none qualify."""
355 | if d is None:
356 | return None
357 |
358 |     eligible_keys = [k for k in d if k <= threshold]
359 | if not eligible_keys:
360 | return None
361 | return max(eligible_keys)
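    |
    | # Presumably paired with build_pages_per_epoch: given {0: 4, 100: 8} and the
    | # current epoch as the threshold, epoch 50 resolves to key 0 (count 4) and
    | # epoch 150 to key 100 (count 8).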
--------------------------------------------------------------------------------