├── D_TRAJ.7z ├── TensorBoard.sh ├── agents.py ├── config_files └── STORM.yaml ├── env_wrapper.py ├── eval.py ├── eval.sh ├── readme.md ├── replay_buffer.py ├── requirements.txt ├── sub_models ├── attention_blocks.py ├── functions_losses.py ├── transformer_model.py └── world_models.py ├── train.py ├── train.sh └── utils.py /D_TRAJ.7z: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weipu-zhang/STORM/e0b3fd44320d7e213ec905c673ad3f35b61b89f4/D_TRAJ.7z -------------------------------------------------------------------------------- /TensorBoard.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | tensorboard --logdir runs --port 6006 --host 0.0.0.0 3 | -------------------------------------------------------------------------------- /agents.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.distributions as distributions 5 | from einops import rearrange, repeat 6 | from einops.layers.torch import Rearrange 7 | import copy 8 | from torch.cuda.amp import autocast 9 | 10 | from sub_models.functions_losses import SymLogTwoHotLoss 11 | from utils import EMAScalar 12 | 13 | 14 | def percentile(x, percentage): 15 | flat_x = torch.flatten(x) 16 | kth = int(percentage*len(flat_x)) 17 | per = torch.kthvalue(flat_x, kth).values 18 | return per 19 | 20 | 21 | def calc_lambda_return(rewards, values, termination, gamma, lam, dtype=torch.float32): 22 | # Invert termination to have 0 if the episode ended and 1 otherwise 23 | inv_termination = (termination * -1) + 1 24 | 25 | batch_size, batch_length = rewards.shape[:2] 26 | # gae_step = torch.zeros((batch_size, ), dtype=dtype, device="cuda") 27 | gamma_return = torch.zeros((batch_size, batch_length+1), dtype=dtype, device="cuda") 28 | gamma_return[:, -1] = values[:, -1] 29 | for t in reversed(range(batch_length)): # with last bootstrap 30 | gamma_return[:, t] = \ 31 | rewards[:, t] + \ 32 | gamma * inv_termination[:, t] * (1-lam) * values[:, t] + \ 33 | gamma * inv_termination[:, t] * lam * gamma_return[:, t+1] 34 | return gamma_return[:, :-1] 35 | 36 | 37 | class ActorCriticAgent(nn.Module): 38 | def __init__(self, feat_dim, num_layers, hidden_dim, action_dim, gamma, lambd, entropy_coef) -> None: 39 | super().__init__() 40 | self.gamma = gamma 41 | self.lambd = lambd 42 | self.entropy_coef = entropy_coef 43 | self.use_amp = True 44 | self.tensor_dtype = torch.bfloat16 if self.use_amp else torch.float32 45 | 46 | self.symlog_twohot_loss = SymLogTwoHotLoss(255, -20, 20) 47 | 48 | actor = [ 49 | nn.Linear(feat_dim, hidden_dim, bias=False), 50 | nn.LayerNorm(hidden_dim), 51 | nn.ReLU() 52 | ] 53 | for i in range(num_layers - 1): 54 | actor.extend([ 55 | nn.Linear(hidden_dim, hidden_dim, bias=False), 56 | nn.LayerNorm(hidden_dim), 57 | nn.ReLU() 58 | ]) 59 | self.actor = nn.Sequential( 60 | *actor, 61 | nn.Linear(hidden_dim, action_dim) 62 | ) 63 | 64 | critic = [ 65 | nn.Linear(feat_dim, hidden_dim, bias=False), 66 | nn.LayerNorm(hidden_dim), 67 | nn.ReLU() 68 | ] 69 | for i in range(num_layers - 1): 70 | critic.extend([ 71 | nn.Linear(hidden_dim, hidden_dim, bias=False), 72 | nn.LayerNorm(hidden_dim), 73 | nn.ReLU() 74 | ]) 75 | 76 | self.critic = nn.Sequential( 77 | *critic, 78 | nn.Linear(hidden_dim, 255) 79 | ) 80 | self.slow_critic = copy.deepcopy(self.critic) 81 | 82 | self.lowerbound_ema = EMAScalar(decay=0.99) 
83 | self.upperbound_ema = EMAScalar(decay=0.99) 84 | 85 | self.optimizer = torch.optim.Adam(self.parameters(), lr=3e-5, eps=1e-5) 86 | self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp) 87 | 88 | @torch.no_grad() 89 | def update_slow_critic(self, decay=0.98): 90 | for slow_param, param in zip(self.slow_critic.parameters(), self.critic.parameters()): 91 | slow_param.data.copy_(slow_param.data * decay + param.data * (1 - decay)) 92 | 93 | def policy(self, x): 94 | logits = self.actor(x) 95 | return logits 96 | 97 | def value(self, x): 98 | value = self.critic(x) 99 | value = self.symlog_twohot_loss.decode(value) 100 | return value 101 | 102 | @torch.no_grad() 103 | def slow_value(self, x): 104 | value = self.slow_critic(x) 105 | value = self.symlog_twohot_loss.decode(value) 106 | return value 107 | 108 | def get_logits_raw_value(self, x): 109 | logits = self.actor(x) 110 | raw_value = self.critic(x) 111 | return logits, raw_value 112 | 113 | @torch.no_grad() 114 | def sample(self, latent, greedy=False): 115 | self.eval() 116 | with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=self.use_amp): 117 | logits = self.policy(latent) 118 | dist = distributions.Categorical(logits=logits) 119 | if greedy: 120 | action = dist.probs.argmax(dim=-1) 121 | else: 122 | action = dist.sample() 123 | return action 124 | 125 | def sample_as_env_action(self, latent, greedy=False): 126 | action = self.sample(latent, greedy) 127 | return action.detach().cpu().squeeze(-1).numpy() 128 | 129 | def update(self, latent, action, old_logprob, old_value, reward, termination, logger=None): 130 | ''' 131 | Update policy and value model 132 | ''' 133 | self.train() 134 | with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=self.use_amp): 135 | logits, raw_value = self.get_logits_raw_value(latent) 136 | dist = distributions.Categorical(logits=logits[:, :-1]) 137 | log_prob = dist.log_prob(action) 138 | entropy = dist.entropy() 139 | 140 | # decode value, calc lambda return 141 | slow_value = self.slow_value(latent) 142 | slow_lambda_return = calc_lambda_return(reward, slow_value, termination, self.gamma, self.lambd) 143 | value = self.symlog_twohot_loss.decode(raw_value) 144 | lambda_return = calc_lambda_return(reward, value, termination, self.gamma, self.lambd) 145 | 146 | # update value function with slow critic regularization 147 | value_loss = self.symlog_twohot_loss(raw_value[:, :-1], lambda_return.detach()) 148 | slow_value_regularization_loss = self.symlog_twohot_loss(raw_value[:, :-1], slow_lambda_return.detach()) 149 | 150 | lower_bound = self.lowerbound_ema(percentile(lambda_return, 0.05)) 151 | upper_bound = self.upperbound_ema(percentile(lambda_return, 0.95)) 152 | S = upper_bound-lower_bound 153 | norm_ratio = torch.max(torch.ones(1).cuda(), S) # max(1, S) in the paper 154 | norm_advantage = (lambda_return-value[:, :-1]) / norm_ratio 155 | policy_loss = -(log_prob * norm_advantage.detach()).mean() 156 | 157 | entropy_loss = entropy.mean() 158 | 159 | loss = policy_loss + value_loss + slow_value_regularization_loss - self.entropy_coef * entropy_loss 160 | 161 | # gradient descent 162 | self.scaler.scale(loss).backward() 163 | self.scaler.unscale_(self.optimizer) # for clip grad 164 | torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=100.0) 165 | self.scaler.step(self.optimizer) 166 | self.scaler.update() 167 | self.optimizer.zero_grad(set_to_none=True) 168 | 169 | self.update_slow_critic() 170 | 171 | if logger is not None: 172 | 
logger.log('ActorCritic/policy_loss', policy_loss.item()) 173 | logger.log('ActorCritic/value_loss', value_loss.item()) 174 | logger.log('ActorCritic/entropy_loss', entropy_loss.item()) 175 | logger.log('ActorCritic/S', S.item()) 176 | logger.log('ActorCritic/norm_ratio', norm_ratio.item()) 177 | logger.log('ActorCritic/total_loss', loss.item()) 178 | -------------------------------------------------------------------------------- /config_files/STORM.yaml: -------------------------------------------------------------------------------- 1 | Task: "JointTrainAgent" 2 | 3 | BasicSettings: 4 | Seed: 0 5 | ImageSize: 64 6 | ReplayBufferOnGPU: True 7 | 8 | JointTrainAgent: 9 | SampleMaxSteps: 102000 10 | BufferMaxLength: 100000 11 | BufferWarmUp: 1024 12 | NumEnvs: 1 13 | BatchSize: 16 14 | DemonstrationBatchSize: 4 15 | BatchLength: 64 16 | ImagineBatchSize: 1024 17 | ImagineDemonstrationBatchSize: 256 18 | ImagineContextLength: 8 19 | ImagineBatchLength: 16 20 | TrainDynamicsEverySteps: 1 21 | TrainAgentEverySteps: 1 22 | UseDemonstration: False 23 | SaveEverySteps: 2500 24 | 25 | Models: 26 | WorldModel: 27 | InChannels: 3 28 | TransformerMaxLength: 64 29 | TransformerHiddenDim: 512 30 | TransformerNumLayers: 2 31 | TransformerNumHeads: 8 32 | 33 | Agent: 34 | NumLayers: 2 35 | HiddenDim: 512 36 | Gamma: 0.985 37 | Lambda: 0.95 38 | EntropyCoef: 3E-4 -------------------------------------------------------------------------------- /env_wrapper.py: -------------------------------------------------------------------------------- 1 | import gymnasium 2 | import numpy as np 3 | from collections import deque 4 | import cv2 5 | from einops import rearrange 6 | import copy 7 | 8 | 9 | class LifeLossInfo(gymnasium.Wrapper): 10 | def __init__(self, env): 11 | super().__init__(env) 12 | self.lives_info = None 13 | 14 | def step(self, action): 15 | observation, reward, terminated, truncated, info = self.env.step(action) 16 | current_lives_info = info["lives"] 17 | if current_lives_info < self.lives_info: 18 | info["life_loss"] = True 19 | self.lives_info = info["lives"] 20 | else: 21 | info["life_loss"] = False 22 | 23 | return observation, reward, terminated, truncated, info 24 | 25 | def reset(self, **kwargs): 26 | observation, info = self.env.reset(**kwargs) 27 | self.lives_info = info["lives"] 28 | info["life_loss"] = False 29 | return observation, info 30 | 31 | 32 | class SeedEnvWrapper(gymnasium.Wrapper): 33 | def __init__(self, env, seed): 34 | super().__init__(env) 35 | self.seed = seed 36 | self.env.action_space.seed(seed) 37 | 38 | def reset(self, **kwargs): 39 | kwargs["seed"] = self.seed 40 | obs, _ = self.env.reset(**kwargs) 41 | return obs, _ 42 | 43 | def step(self, action): 44 | return self.env.step(action) 45 | 46 | 47 | class MaxLast2FrameSkipWrapper(gymnasium.Wrapper): 48 | def __init__(self, env, skip=4): 49 | super().__init__(env) 50 | self.skip = skip 51 | 52 | def reset(self, **kwargs): 53 | obs, _ = self.env.reset(**kwargs) 54 | return obs, _ 55 | 56 | def step(self, action): 57 | total_reward = 0 58 | self.obs_buffer = deque(maxlen=2) 59 | for _ in range(self.skip): 60 | obs, reward, done, truncated, info = self.env.step(action) 61 | self.obs_buffer.append(obs) 62 | total_reward += reward 63 | if done or truncated: 64 | break 65 | if len(self.obs_buffer) == 1: 66 | obs = self.obs_buffer[0] 67 | else: 68 | obs = np.max(np.stack(self.obs_buffer), axis=0) 69 | return obs, total_reward, done, truncated, info 70 | 71 | def build_single_env(env_name, image_size): 72 | env = 
gymnasium.make(env_name, full_action_space=True, frameskip=1) 73 | from gymnasium.wrappers import AtariPreprocessing 74 | env = AtariPreprocessing(env, screen_size=image_size, grayscale_obs=False) 75 | return env 76 | 77 | 78 | def build_vec_env(env_list, image_size, num_envs): 79 | # lambda pitfall refs to: https://python.plainenglish.io/python-pitfalls-with-variable-capture-dcfc113f39b7 80 | assert num_envs % len(env_list) == 0 81 | env_fns = [] 82 | vec_env_names = [] 83 | for env_name in env_list: 84 | def lambda_generator(env_name, image_size): 85 | return lambda: build_single_env(env_name, image_size) 86 | env_fns += [lambda_generator(env_name, image_size) for i in range(num_envs//len(env_list))] 87 | vec_env_names += [env_name for i in range(num_envs//len(env_list))] 88 | vec_env = gymnasium.vector.AsyncVectorEnv(env_fns=env_fns) 89 | return vec_env, vec_env_names 90 | 91 | 92 | if __name__ == "__main__": 93 | vec_env, vec_env_names = build_vec_env(['ALE/Pong-v5', 'ALE/IceHockey-v5', 'ALE/Breakout-v5', 'ALE/Tennis-v5'], 64, num_envs=8) 94 | current_obs, _ = vec_env.reset() 95 | while True: 96 | action = vec_env.action_space.sample() 97 | obs, reward, done, truncated, info = vec_env.step(action) 98 | # done = done or truncated 99 | if done.any(): 100 | print("---------") 101 | print(reward) 102 | print(info["episode_frame_number"]) 103 | cv2.imshow("Pong", current_obs[0]) 104 | cv2.imshow("IceHockey", current_obs[2]) 105 | cv2.imshow("Breakout", current_obs[4]) 106 | cv2.imshow("Tennis", current_obs[6]) 107 | cv2.waitKey(40) 108 | current_obs = obs 109 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import gymnasium 2 | import argparse 3 | from tensorboardX import SummaryWriter 4 | import cv2 5 | import numpy as np 6 | from einops import rearrange 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from collections import deque 11 | from tqdm import tqdm 12 | import copy 13 | import colorama 14 | import random 15 | import json 16 | import shutil 17 | import pickle 18 | import os 19 | 20 | from utils import seed_np_torch, Logger, load_config 21 | from replay_buffer import ReplayBuffer 22 | import env_wrapper 23 | import agents 24 | from sub_models.functions_losses import symexp 25 | from sub_models.world_models import WorldModel, MSELoss 26 | 27 | 28 | def process_visualize(img): 29 | img = img.astype('uint8') 30 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 31 | img = cv2.resize(img, (640, 640)) 32 | return img 33 | 34 | 35 | def build_single_env(env_name, image_size): 36 | env = gymnasium.make(env_name, full_action_space=False, render_mode="rgb_array", frameskip=1) 37 | env = env_wrapper.MaxLast2FrameSkipWrapper(env, skip=4) 38 | env = gymnasium.wrappers.ResizeObservation(env, shape=image_size) 39 | return env 40 | 41 | 42 | def build_vec_env(env_name, image_size, num_envs): 43 | # lambda pitfall refs to: https://python.plainenglish.io/python-pitfalls-with-variable-capture-dcfc113f39b7 44 | def lambda_generator(env_name, image_size): 45 | return lambda: build_single_env(env_name, image_size) 46 | env_fns = [] 47 | env_fns = [lambda_generator(env_name, image_size) for i in range(num_envs)] 48 | vec_env = gymnasium.vector.AsyncVectorEnv(env_fns=env_fns) 49 | return vec_env 50 | 51 | 52 | def eval_episodes(num_episode, env_name, max_steps, num_envs, image_size, 53 | world_model: WorldModel, agent: agents.ActorCriticAgent): 54 | 
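    '''
    Roll out `num_episode` episodes across `num_envs` parallel environments and return the mean episodic return.
    The agent acts from the world model's latents: the most recent (up to 16) observation/action pairs are kept
    as context, encoded, and passed through the transformer (calc_last_dist_feat) to obtain the prior sample and
    hidden feature that are concatenated and fed to the policy. `max_steps` is currently unused; the loop runs
    until `num_episode` episodes have finished.
    '''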
world_model.eval() 55 | agent.eval() 56 | vec_env = build_vec_env(env_name, image_size, num_envs=num_envs) 57 | print("Current env: " + colorama.Fore.YELLOW + f"{env_name}" + colorama.Style.RESET_ALL) 58 | sum_reward = np.zeros(num_envs) 59 | current_obs, current_info = vec_env.reset() 60 | context_obs = deque(maxlen=16) 61 | context_action = deque(maxlen=16) 62 | 63 | final_rewards = [] 64 | # for total_steps in tqdm(range(max_steps//num_envs)): 65 | while True: 66 | # sample part >>> 67 | with torch.no_grad(): 68 | if len(context_action) == 0: 69 | action = vec_env.action_space.sample() 70 | else: 71 | context_latent = world_model.encode_obs(torch.cat(list(context_obs), dim=1)) 72 | model_context_action = np.stack(list(context_action), axis=1) 73 | model_context_action = torch.Tensor(model_context_action).cuda() 74 | prior_flattened_sample, last_dist_feat = world_model.calc_last_dist_feat(context_latent, model_context_action) 75 | action = agent.sample_as_env_action( 76 | torch.cat([prior_flattened_sample, last_dist_feat], dim=-1), 77 | greedy=False 78 | ) 79 | 80 | context_obs.append(rearrange(torch.Tensor(current_obs).cuda(), "B H W C -> B 1 C H W")/255) 81 | context_action.append(action) 82 | 83 | obs, reward, done, truncated, info = vec_env.step(action) 84 | # cv2.imshow("current_obs", process_visualize(obs[0])) 85 | # cv2.waitKey(10) 86 | 87 | done_flag = np.logical_or(done, truncated) 88 | if done_flag.any(): 89 | for i in range(num_envs): 90 | if done_flag[i]: 91 | final_rewards.append(sum_reward[i]) 92 | sum_reward[i] = 0 93 | if len(final_rewards) == num_episode: 94 | print("Mean reward: " + colorama.Fore.YELLOW + f"{np.mean(final_rewards)}" + colorama.Style.RESET_ALL) 95 | return np.mean(final_rewards) 96 | 97 | # update current_obs, current_info and sum_reward 98 | sum_reward += reward 99 | current_obs = obs 100 | current_info = info 101 | # <<< sample part 102 | 103 | 104 | if __name__ == "__main__": 105 | # ignore warnings 106 | import warnings 107 | warnings.filterwarnings('ignore') 108 | torch.backends.cuda.matmul.allow_tf32 = True 109 | torch.backends.cudnn.allow_tf32 = True 110 | 111 | # parse arguments 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument("-config_path", type=str, required=True) 114 | parser.add_argument("-env_name", type=str, required=True) 115 | parser.add_argument("-run_name", type=str, required=True) 116 | args = parser.parse_args() 117 | conf = load_config(args.config_path) 118 | print(colorama.Fore.RED + str(args) + colorama.Style.RESET_ALL) 119 | # print(colorama.Fore.RED + str(conf) + colorama.Style.RESET_ALL) 120 | 121 | # set seed 122 | seed_np_torch(seed=conf.BasicSettings.Seed) 123 | 124 | # build and load model/agent 125 | import train 126 | dummy_env = build_single_env(args.env_name, conf.BasicSettings.ImageSize) 127 | action_dim = dummy_env.action_space.n 128 | world_model = train.build_world_model(conf, action_dim) 129 | agent = train.build_agent(conf, action_dim) 130 | root_path = f"ckpt/{args.run_name}" 131 | 132 | import glob 133 | pathes = glob.glob(f"{root_path}/world_model_*.pth") 134 | steps = [int(path.split("_")[-1].split(".")[0]) for path in pathes] 135 | steps.sort() 136 | steps = steps[-1:] 137 | print(steps) 138 | results = [] 139 | for step in tqdm(steps): 140 | world_model.load_state_dict(torch.load(f"{root_path}/world_model_{step}.pth")) 141 | agent.load_state_dict(torch.load(f"{root_path}/agent_{step}.pth")) 142 | # # eval 143 | episode_avg_return = eval_episodes( 144 | num_episode=20, 145 | 
env_name=args.env_name, 146 | num_envs=5, 147 | max_steps=conf.JointTrainAgent.SampleMaxSteps, 148 | image_size=conf.BasicSettings.ImageSize, 149 | world_model=world_model, 150 | agent=agent 151 | ) 152 | results.append([step, episode_avg_return]) 153 | with open(f"eval_result/{args.run_name}.csv", "w") as fout: 154 | fout.write("step, episode_avg_return\n") 155 | for step, episode_avg_return in results: 156 | fout.write(f"{step},{episode_avg_return}\n") 157 | -------------------------------------------------------------------------------- /eval.sh: -------------------------------------------------------------------------------- 1 | env_name=MsPacman 2 | python -u eval.py \ 3 | -env_name "ALE/${env_name}-v5" \ 4 | -run_name "${env_name}-life_done-wm_2L512D8H-100k-seed1"\ 5 | -config_path "config_files/STORM.yaml" 6 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Implementation of STORM: Efficient Stochastic Transformer based World Models for Reinforcement Learning 2 | 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/storm-efficient-stochastic-transformer-based-1/atari-games-100k-on-atari-100k)](https://paperswithcode.com/sota/atari-games-100k-on-atari-100k?p=storm-efficient-stochastic-transformer-based-1) 4 | 5 | [Paper & OpenReview](https://openreview.net/forum?id=WxnrX42rnS); you may find some useful discussion there. 6 | 7 | This repo contains an implementation of STORM. 8 | 9 | Follow the **Training and Evaluating Instructions** to reproduce the main results presented in our paper. You may also find **Additional Useful Information** helpful when debugging and observing intermediate results. To reproduce the speed metrics mentioned in the paper, please see **Reproducing Speed Metrics**. 10 | 11 | ## Training and Evaluating Instructions 12 | 13 | 1. Install the necessary dependencies. Note that we conducted our experiments using `python 3.10`. 14 | ```shell 15 | pip install -r requirements.txt 16 | ``` 17 | Installing `AutoROM.accept-rom-license` may take several minutes. 18 | 19 | 2. Train the agent. 20 | ```shell 21 | chmod +x train.sh 22 | ./train.sh 23 | ``` 24 | 25 | The `train.sh` file controls the environment and the running name of a training process. 26 | ```shell 27 | env_name=MsPacman 28 | python -u train.py \ 29 | -n "${env_name}-life_done-wm_2L512D8H-100k-seed1" \ 30 | -seed 1 \ 31 | -config_path "config_files/STORM.yaml" \ 32 | -env_name "ALE/${env_name}-v5" \ 33 | -trajectory_path "trajectory/${env_name}.pkl" 34 | ``` 35 | 36 | - The `env_name` on the first line can be any Atari game, which can be found [here](https://gymnasium.farama.org/environments/atari/). 37 | 38 | - The `-n` option is the name for the TensorBoard logger and the checkpoint folder. You can change it to your preference, but we recommend keeping the environment's name first. The TensorBoard logging folder is `runs`, and the checkpoint folder is `ckpt`. 39 | 40 | - The `-seed` parameter controls the random seed used during training. We evaluated our method using 5 seeds and report the mean return in Table 1. 41 | 42 | - The `-config_path` option points to a YAML file that controls the model's hyperparameters. The configuration in `config_files/STORM.yaml` is the same as in our paper. 43 | 44 | - The `-trajectory_path` option is only useful when `UseDemonstration` in the YAML file is set to `True` (by default it's `False`).
This corresponds to the ablation studies in Section 5.3. We provide the pre-collected trajectories in the `D_TRAJ.7z` file, and you need to decompress it before use. 45 | 46 | 47 | 3. Evaluate the agent. The evaluation results will be presented in a CSV file located in the `eval_result` folder. 48 | ```shell 49 | chmod +x eval.sh 50 | ./eval.sh 51 | ``` 52 | 53 | The `eval.sh` file controls the environment and the running name when testing an agent. 54 | 55 | ```shell 56 | env_name=MsPacman 57 | python -u eval.py \ 58 | -env_name "ALE/${env_name}-v5" \ 59 | -run_name "${env_name}-life_done-wm_2L512D8H-100k-seed1"\ 60 | -config_path "config_files/STORM.yaml" 61 | ``` 62 | 63 | The `-run_name` option corresponds to the `-n` option in `train.sh` and must match the name used during training. 64 | 65 | ## Additional Useful Information 66 | You can use TensorBoard to visualize the training curves and the imagination videos: 67 | ```shell 68 | chmod +x TensorBoard.sh 69 | ./TensorBoard.sh 70 | ``` 71 | 72 | 73 | ## Reproducing Speed Metrics 74 | To reproduce the speed metrics mentioned in the paper, please consider the following: 75 | - Hardware requirements: an NVIDIA GeForce RTX 3090 with a high-frequency CPU; we used an `11th Gen Intel(R) Core(TM) i9-11900K` in our experiments. A low-frequency CPU may leave the GPU idle and slow down training. To make full use of a powerful GPU, one can train several agents at the same time on one device. 76 | - Software requirements: `PyTorch>=2.0.0` is required. 77 | 78 | ## Troubleshooting 79 | ### Mixed precision on other devices 80 | - Our experiments used bfloat16 to accelerate training. To train on devices that do not support bfloat16, such as the NVIDIA V100, you need to change `torch.bfloat16` to `torch.float16` in both `agents.py` and `sub_models/world_models.py`. Additionally, modify the line `attn = attn.masked_fill(mask == 0, -1e9)` to `attn = attn.masked_fill(mask == 0, -6e4)` to prevent overflow (see the sketch below). 81 | - On devices like the NVIDIA A100, using bfloat16 may slow down training. In this case, you can disable mixed precision by setting `self.use_amp = False` in both `agents.py` and `sub_models/world_models.py`. 82 | 83 | ### Windows and WSL 84 | We've recently observed that if one **clones the repo** from `PowerShell` and then calls `train.sh` under a `WSL` shell, it may throw an argument-parsing error. This is likely caused by Windows-style line endings that git introduces when cloning on Windows. The solution is to download the zip or clone directly inside `WSL`. 85 | 86 | ## Code references 87 | We've referenced several other projects during the development of this code: 88 | - [Attention is all you need pytorch](https://github.com/jadore801120/attention-is-all-you-need-pytorch) For Transformer structure, attention operation, and other building blocks. 89 | - [Hugging Face Diffusers](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py) For trainable positional encoding. 90 | - [DreamerV3](https://github.com/danijar/dreamerv3) For Symlog loss, layer & kernel configuration in VAE.
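As a companion to the **Troubleshooting** notes above, the float16 fallback amounts to one-line edits of the following kind. This is an illustrative sketch rather than a drop-in patch: the dtype change has to be applied to every `torch.autocast(...)` call and `self.tensor_dtype` assignment in `agents.py` and `sub_models/world_models.py`.

```python
# agents.py / sub_models/world_models.py: fall back from bfloat16 to float16
self.tensor_dtype = torch.float16 if self.use_amp else torch.float32
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.use_amp):
    ...

# sub_models/attention_blocks.py: keep the mask fill value inside the float16 range
attn = attn.masked_fill(mask == 0, -6e4)
```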
91 | 92 | ## Bibtex 93 | 94 | ``` 95 | @inproceedings{ 96 | zhang2023storm, 97 | title={{STORM}: Efficient Stochastic Transformer based World Models for Reinforcement Learning}, 98 | author={Weipu Zhang and Gang Wang and Jian Sun and Yetian Yuan and Gao Huang}, 99 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems}, 100 | year={2023}, 101 | url={https://openreview.net/forum?id=WxnrX42rnS} 102 | } 103 | ``` 104 | -------------------------------------------------------------------------------- /replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import unittest 4 | import torch 5 | from einops import rearrange 6 | import copy 7 | import pickle 8 | 9 | 10 | class ReplayBuffer(): 11 | def __init__(self, obs_shape, num_envs, max_length=int(1E6), warmup_length=50000, store_on_gpu=False) -> None: 12 | self.store_on_gpu = store_on_gpu 13 | if store_on_gpu: 14 | self.obs_buffer = torch.empty((max_length//num_envs, num_envs, *obs_shape), dtype=torch.uint8, device="cuda", requires_grad=False) 15 | self.action_buffer = torch.empty((max_length//num_envs, num_envs), dtype=torch.float32, device="cuda", requires_grad=False) 16 | self.reward_buffer = torch.empty((max_length//num_envs, num_envs), dtype=torch.float32, device="cuda", requires_grad=False) 17 | self.termination_buffer = torch.empty((max_length//num_envs, num_envs), dtype=torch.float32, device="cuda", requires_grad=False) 18 | else: 19 | self.obs_buffer = np.empty((max_length//num_envs, num_envs, *obs_shape), dtype=np.uint8) 20 | self.action_buffer = np.empty((max_length//num_envs, num_envs), dtype=np.float32) 21 | self.reward_buffer = np.empty((max_length//num_envs, num_envs), dtype=np.float32) 22 | self.termination_buffer = np.empty((max_length//num_envs, num_envs), dtype=np.float32) 23 | 24 | self.length = 0 25 | self.num_envs = num_envs 26 | self.last_pointer = -1 27 | self.max_length = max_length 28 | self.warmup_length = warmup_length 29 | self.external_buffer_length = None 30 | 31 | def load_trajectory(self, path): 32 | buffer = pickle.load(open(path, "rb")) 33 | if self.store_on_gpu: 34 | self.external_buffer = {name: torch.from_numpy(buffer[name]).to("cuda") for name in buffer} 35 | else: 36 | self.external_buffer = buffer 37 | self.external_buffer_length = self.external_buffer["obs"].shape[0] 38 | 39 | def sample_external(self, batch_size, batch_length, to_device="cuda"): 40 | indexes = np.random.randint(0, self.external_buffer_length+1-batch_length, size=batch_size) 41 | if self.store_on_gpu: 42 | obs = torch.stack([self.external_buffer["obs"][idx:idx+batch_length] for idx in indexes]) 43 | action = torch.stack([self.external_buffer["action"][idx:idx+batch_length] for idx in indexes]) 44 | reward = torch.stack([self.external_buffer["reward"][idx:idx+batch_length] for idx in indexes]) 45 | termination = torch.stack([self.external_buffer["done"][idx:idx+batch_length] for idx in indexes]) 46 | else: 47 | obs = np.stack([self.external_buffer["obs"][idx:idx+batch_length] for idx in indexes]) 48 | action = np.stack([self.external_buffer["action"][idx:idx+batch_length] for idx in indexes]) 49 | reward = np.stack([self.external_buffer["reward"][idx:idx+batch_length] for idx in indexes]) 50 | termination = np.stack([self.external_buffer["done"][idx:idx+batch_length] for idx in indexes]) 51 | return obs, action, reward, termination 52 | 53 | def ready(self): 54 | return self.length * self.num_envs > self.warmup_length 55 | 56 
| @torch.no_grad() 57 | def sample(self, batch_size, external_batch_size, batch_length, to_device="cuda"): 58 | if self.store_on_gpu: 59 | obs, action, reward, termination = [], [], [], [] 60 | if batch_size > 0: 61 | for i in range(self.num_envs): 62 | indexes = np.random.randint(0, self.length+1-batch_length, size=batch_size//self.num_envs) 63 | obs.append(torch.stack([self.obs_buffer[idx:idx+batch_length, i] for idx in indexes])) 64 | action.append(torch.stack([self.action_buffer[idx:idx+batch_length, i] for idx in indexes])) 65 | reward.append(torch.stack([self.reward_buffer[idx:idx+batch_length, i] for idx in indexes])) 66 | termination.append(torch.stack([self.termination_buffer[idx:idx+batch_length, i] for idx in indexes])) 67 | 68 | if self.external_buffer_length is not None and external_batch_size > 0: 69 | external_obs, external_action, external_reward, external_termination = self.sample_external( 70 | external_batch_size, batch_length, to_device) 71 | obs.append(external_obs) 72 | action.append(external_action) 73 | reward.append(external_reward) 74 | termination.append(external_termination) 75 | 76 | obs = torch.cat(obs, dim=0).float() / 255 77 | obs = rearrange(obs, "B T H W C -> B T C H W") 78 | action = torch.cat(action, dim=0) 79 | reward = torch.cat(reward, dim=0) 80 | termination = torch.cat(termination, dim=0) 81 | else: 82 | obs, action, reward, termination = [], [], [], [] 83 | if batch_size > 0: 84 | for i in range(self.num_envs): 85 | indexes = np.random.randint(0, self.length+1-batch_length, size=batch_size//self.num_envs) 86 | obs.append(np.stack([self.obs_buffer[idx:idx+batch_length, i] for idx in indexes])) 87 | action.append(np.stack([self.action_buffer[idx:idx+batch_length, i] for idx in indexes])) 88 | reward.append(np.stack([self.reward_buffer[idx:idx+batch_length, i] for idx in indexes])) 89 | termination.append(np.stack([self.termination_buffer[idx:idx+batch_length, i] for idx in indexes])) 90 | 91 | if self.external_buffer_length is not None and external_batch_size > 0: 92 | external_obs, external_action, external_reward, external_termination = self.sample_external( 93 | external_batch_size, batch_length, to_device) 94 | obs.append(external_obs) 95 | action.append(external_action) 96 | reward.append(external_reward) 97 | termination.append(external_termination) 98 | 99 | obs = torch.from_numpy(np.concatenate(obs, axis=0)).float().cuda() / 255 100 | obs = rearrange(obs, "B T H W C -> B T C H W") 101 | action = torch.from_numpy(np.concatenate(action, axis=0)).cuda() 102 | reward = torch.from_numpy(np.concatenate(reward, axis=0)).cuda() 103 | termination = torch.from_numpy(np.concatenate(termination, axis=0)).cuda() 104 | 105 | return obs, action, reward, termination 106 | 107 | def append(self, obs, action, reward, termination): 108 | # obs/nex_obs: torch Tensor 109 | # action/reward/termination: int or float or bool 110 | self.last_pointer = (self.last_pointer + 1) % (self.max_length//self.num_envs) 111 | if self.store_on_gpu: 112 | self.obs_buffer[self.last_pointer] = torch.from_numpy(obs) 113 | self.action_buffer[self.last_pointer] = torch.from_numpy(action) 114 | self.reward_buffer[self.last_pointer] = torch.from_numpy(reward) 115 | self.termination_buffer[self.last_pointer] = torch.from_numpy(termination) 116 | else: 117 | self.obs_buffer[self.last_pointer] = obs 118 | self.action_buffer[self.last_pointer] = action 119 | self.reward_buffer[self.last_pointer] = reward 120 | self.termination_buffer[self.last_pointer] = termination 121 | 122 | if len(self) 
< self.max_length: 123 | self.length += 1 124 | 125 | def __len__(self): 126 | return self.length * self.num_envs 127 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchaudio 3 | torchvision 4 | yacs 5 | tensorboardX 6 | tensorboard 7 | moviepy 8 | colorama 9 | einops 10 | tqdm 11 | opencv-python 12 | gymnasium[atari,accept-rom-license] -------------------------------------------------------------------------------- /sub_models/attention_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from einops import rearrange, repeat 6 | 7 | 8 | def get_subsequent_mask(seq): 9 | ''' For masking out the subsequent info. ''' 10 | batch_size, batch_length = seq.shape[:2] 11 | subsequent_mask = (1 - torch.triu( 12 | torch.ones((1, batch_length, batch_length), device=seq.device), diagonal=1)).bool() 13 | return subsequent_mask 14 | 15 | 16 | def get_subsequent_mask_with_batch_length(batch_length, device): 17 | ''' For masking out the subsequent info. ''' 18 | subsequent_mask = (1 - torch.triu(torch.ones((1, batch_length, batch_length), device=device), diagonal=1)).bool() 19 | return subsequent_mask 20 | 21 | 22 | def get_vector_mask(batch_length, device): 23 | mask = torch.ones((1, 1, batch_length), device=device).bool() 24 | # mask = torch.ones((1, batch_length, 1), device=device).bool() 25 | return mask 26 | 27 | 28 | class ScaledDotProductAttention(nn.Module): 29 | ''' Scaled Dot-Product Attention ''' 30 | 31 | def __init__(self, temperature, attn_dropout=0.1): 32 | super().__init__() 33 | self.temperature = temperature 34 | self.dropout = nn.Dropout(attn_dropout) 35 | 36 | def forward(self, q, k, v, mask=None): 37 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) 38 | 39 | if mask is not None: 40 | attn = attn.masked_fill(mask == 0, -1e9) 41 | 42 | attn = self.dropout(F.softmax(attn, dim=-1)) 43 | output = torch.matmul(attn, v) 44 | 45 | return output, attn 46 | 47 | 48 | class MultiHeadAttention(nn.Module): 49 | ''' Multi-Head Attention module ''' 50 | 51 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 52 | super().__init__() 53 | 54 | self.n_head = n_head 55 | self.d_k = d_k 56 | self.d_v = d_v 57 | 58 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) 59 | self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) 60 | self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) 61 | self.fc = nn.Linear(n_head * d_v, d_model, bias=False) 62 | 63 | self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5) 64 | 65 | self.dropout = nn.Dropout(dropout) 66 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 67 | 68 | def forward(self, q, k, v, mask=None): 69 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 70 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) 71 | 72 | residual = q 73 | 74 | # Pass through the pre-attention projection: b x lq x (n*dv) 75 | # Separate different heads: b x lq x n x dv 76 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 77 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 78 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 79 | 80 | # Transpose for attention dot product: b x n x lq x dv 81 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 82 | 83 | if mask is not None: 84 | mask = 
mask.unsqueeze(1) # For head axis broadcasting. 85 | 86 | q, attn = self.attention(q, k, v, mask=mask) 87 | 88 | # Transpose to move the head dimension back: b x lq x n x dv 89 | # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv) 90 | q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1) 91 | q = self.dropout(self.fc(q)) 92 | q += residual 93 | 94 | q = self.layer_norm(q) 95 | 96 | return q, attn 97 | 98 | 99 | class PositionwiseFeedForward(nn.Module): 100 | ''' A two-feed-forward-layer module ''' 101 | 102 | def __init__(self, d_in, d_hid, dropout=0.1): 103 | super().__init__() 104 | self.w_1 = nn.Linear(d_in, d_hid) # position-wise 105 | self.w_2 = nn.Linear(d_hid, d_in) # position-wise 106 | self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) 107 | self.dropout = nn.Dropout(dropout) 108 | 109 | def forward(self, x): 110 | 111 | residual = x 112 | 113 | x = self.w_2(F.relu(self.w_1(x))) 114 | x = self.dropout(x) 115 | x += residual 116 | 117 | x = self.layer_norm(x) 118 | 119 | return x 120 | 121 | 122 | class AttentionBlock(nn.Module): 123 | def __init__(self, feat_dim, hidden_dim, num_heads, dropout): 124 | super().__init__() 125 | self.slf_attn = MultiHeadAttention(num_heads, feat_dim, feat_dim//num_heads, feat_dim//num_heads, dropout=dropout) 126 | self.pos_ffn = PositionwiseFeedForward(feat_dim, hidden_dim, dropout=dropout) 127 | 128 | def forward(self, enc_input, slf_attn_mask=None): 129 | enc_output, enc_slf_attn = self.slf_attn( 130 | enc_input, enc_input, enc_input, mask=slf_attn_mask) 131 | enc_output = self.pos_ffn(enc_output) 132 | return enc_output, enc_slf_attn 133 | 134 | 135 | class AttentionBlockKVCache(nn.Module): 136 | def __init__(self, feat_dim, hidden_dim, num_heads, dropout): 137 | super().__init__() 138 | self.slf_attn = MultiHeadAttention(num_heads, feat_dim, feat_dim//num_heads, feat_dim//num_heads, dropout=dropout) 139 | self.pos_ffn = PositionwiseFeedForward(feat_dim, hidden_dim, dropout=dropout) 140 | 141 | def forward(self, q, k, v, slf_attn_mask=None): 142 | output, attn = self.slf_attn(q, k, v, mask=slf_attn_mask) 143 | output = self.pos_ffn(output) 144 | return output, attn 145 | 146 | 147 | class PositionalEncoding1D(nn.Module): 148 | def __init__( 149 | self, 150 | max_length: int, 151 | embed_dim: int 152 | ): 153 | super().__init__() 154 | self.max_length = max_length 155 | self.embed_dim = embed_dim 156 | 157 | self.pos_emb = nn.Embedding(self.max_length, embed_dim) 158 | 159 | def forward(self, feat): 160 | pos_emb = self.pos_emb(torch.arange(self.max_length, device=feat.device)) 161 | pos_emb = repeat(pos_emb, "L D -> B L D", B=feat.shape[0]) 162 | 163 | feat = feat + pos_emb[:, :feat.shape[1], :] 164 | return feat 165 | 166 | def forward_with_position(self, feat, position): 167 | assert feat.shape[1] == 1 168 | pos_emb = self.pos_emb(torch.arange(self.max_length, device=feat.device)) 169 | pos_emb = repeat(pos_emb, "L D -> B L D", B=feat.shape[0]) 170 | 171 | feat = feat + pos_emb[:, position:position+1, :] 172 | return feat 173 | -------------------------------------------------------------------------------- /sub_models/functions_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | @torch.no_grad() 7 | def symlog(x): 8 | return torch.sign(x) * torch.log(1 + torch.abs(x)) 9 | 10 | 11 | @torch.no_grad() 12 | def symexp(x): 13 | return torch.sign(x) * (torch.exp(torch.abs(x)) - 1) 14 | 
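# symlog compresses large magnitudes while staying close to the identity near zero, and symexp is its
# exact inverse: e.g. symlog(10.0) = log(1 + 10) ≈ 2.398 and symexp(2.398) ≈ 10.0.
# SymLogTwoHotLoss below regresses symlog-transformed targets onto 255 bins spanning [-20, 20]:
# each scalar target becomes a "two-hot" soft label whose mass is split between the two neighbouring
# bins in proportion to the distance to each, and the prediction is trained with cross-entropy;
# decode() inverts this by taking the softmax-weighted average of the bins and applying symexp.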
15 | 16 | class SymLogLoss(nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | 20 | def forward(self, output, target): 21 | target = symlog(target) 22 | return 0.5*F.mse_loss(output, target) 23 | 24 | 25 | class SymLogTwoHotLoss(nn.Module): 26 | def __init__(self, num_classes, lower_bound, upper_bound): 27 | super().__init__() 28 | self.num_classes = num_classes 29 | self.lower_bound = lower_bound 30 | self.upper_bound = upper_bound 31 | self.bin_length = (upper_bound - lower_bound) / (num_classes-1) 32 | 33 | # use register buffer so that bins move with .cuda() automatically 34 | self.bins: torch.Tensor 35 | self.register_buffer( 36 | 'bins', torch.linspace(-20, 20, num_classes), persistent=False) 37 | 38 | def forward(self, output, target): 39 | target = symlog(target) 40 | assert target.min() >= self.lower_bound and target.max() <= self.upper_bound 41 | 42 | index = torch.bucketize(target, self.bins) 43 | diff = target - self.bins[index-1] # -1 to get the lower bound 44 | weight = diff / self.bin_length 45 | weight = torch.clamp(weight, 0, 1) 46 | weight = weight.unsqueeze(-1) 47 | 48 | target_prob = (1-weight)*F.one_hot(index-1, self.num_classes) + weight*F.one_hot(index, self.num_classes) 49 | 50 | loss = -target_prob * F.log_softmax(output, dim=-1) 51 | loss = loss.sum(dim=-1) 52 | return loss.mean() 53 | 54 | def decode(self, output): 55 | return symexp(F.softmax(output, dim=-1) @ self.bins) 56 | 57 | 58 | if __name__ == "__main__": 59 | loss_func = SymLogTwoHotLoss(255, -20, 20) 60 | output = torch.randn(1, 1, 255).requires_grad_() 61 | target = torch.ones(1).reshape(1, 1).float() * 0.1 62 | print(target) 63 | loss = loss_func(output, target) 64 | print(loss) 65 | 66 | # prob = torch.ones(1, 1, 255)*0.5/255 67 | # prob[0, 0, 128] = 0.5 68 | # logits = torch.log(prob) 69 | # print(loss_func.decode(logits), loss_func.bins[128]) 70 | -------------------------------------------------------------------------------- /sub_models/transformer_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import repeat, rearrange 5 | 6 | from sub_models.attention_blocks import get_vector_mask 7 | from sub_models.attention_blocks import PositionalEncoding1D, AttentionBlock, AttentionBlockKVCache 8 | 9 | 10 | class StochasticTransformer(nn.Module): 11 | def __init__(self, stoch_dim, action_dim, feat_dim, num_layers, num_heads, max_length, dropout): 12 | super().__init__() 13 | self.action_dim = action_dim 14 | 15 | # mix image_embedding and action 16 | self.stem = nn.Sequential( 17 | nn.Linear(stoch_dim+action_dim, feat_dim, bias=False), 18 | nn.LayerNorm(feat_dim), 19 | nn.ReLU(inplace=True), 20 | nn.Linear(feat_dim, feat_dim, bias=False), 21 | nn.LayerNorm(feat_dim) 22 | ) 23 | self.position_encoding = PositionalEncoding1D(max_length=max_length, embed_dim=feat_dim) 24 | self.layer_stack = nn.ModuleList([ 25 | AttentionBlock(feat_dim=feat_dim, hidden_dim=feat_dim*2, num_heads=num_heads, dropout=dropout) for _ in range(num_layers) 26 | ]) 27 | self.layer_norm = nn.LayerNorm(feat_dim, eps=1e-6) # TODO: check if this is necessary 28 | 29 | self.head = nn.Linear(feat_dim, stoch_dim) 30 | 31 | def forward(self, samples, action, mask): 32 | action = F.one_hot(action.long(), self.action_dim).float() 33 | feats = self.stem(torch.cat([samples, action], dim=-1)) 34 | feats = self.position_encoding(feats) 35 | feats = self.layer_norm(feats) 36 | 37 | for enc_layer in 
self.layer_stack: 38 | feats, attn = enc_layer(feats, mask) 39 | 40 | feat = self.head(feats) 41 | return feat 42 | 43 | 44 | class StochasticTransformerKVCache(nn.Module): 45 | def __init__(self, stoch_dim, action_dim, feat_dim, num_layers, num_heads, max_length, dropout): 46 | super().__init__() 47 | self.action_dim = action_dim 48 | self.feat_dim = feat_dim 49 | 50 | # mix image_embedding and action 51 | self.stem = nn.Sequential( 52 | nn.Linear(stoch_dim+action_dim, feat_dim, bias=False), 53 | nn.LayerNorm(feat_dim), 54 | nn.ReLU(inplace=True), 55 | nn.Linear(feat_dim, feat_dim, bias=False), 56 | nn.LayerNorm(feat_dim) 57 | ) 58 | self.position_encoding = PositionalEncoding1D(max_length=max_length, embed_dim=feat_dim) 59 | self.layer_stack = nn.ModuleList([ 60 | AttentionBlockKVCache(feat_dim=feat_dim, hidden_dim=feat_dim*2, num_heads=num_heads, dropout=dropout) for _ in range(num_layers) 61 | ]) 62 | self.layer_norm = nn.LayerNorm(feat_dim, eps=1e-6) # TODO: check if this is necessary 63 | 64 | def forward(self, samples, action, mask): 65 | ''' 66 | Normal forward pass 67 | ''' 68 | action = F.one_hot(action.long(), self.action_dim).float() 69 | feats = self.stem(torch.cat([samples, action], dim=-1)) 70 | feats = self.position_encoding(feats) 71 | feats = self.layer_norm(feats) 72 | 73 | for layer in self.layer_stack: 74 | feats, attn = layer(feats, feats, feats, mask) 75 | 76 | return feats 77 | 78 | def reset_kv_cache_list(self, batch_size, dtype): 79 | ''' 80 | Reset self.kv_cache_list 81 | ''' 82 | self.kv_cache_list = [] 83 | for layer in self.layer_stack: 84 | self.kv_cache_list.append(torch.zeros(size=(batch_size, 0, self.feat_dim), dtype=dtype, device="cuda")) 85 | 86 | def forward_with_kv_cache(self, samples, action): 87 | ''' 88 | Forward pass with kv_cache, cache stored in self.kv_cache_list 89 | ''' 90 | assert samples.shape[1] == 1 91 | mask = get_vector_mask(self.kv_cache_list[0].shape[1]+1, samples.device) 92 | 93 | action = F.one_hot(action.long(), self.action_dim).float() 94 | feats = self.stem(torch.cat([samples, action], dim=-1)) 95 | feats = self.position_encoding.forward_with_position(feats, position=self.kv_cache_list[0].shape[1]) 96 | feats = self.layer_norm(feats) 97 | 98 | for idx, layer in enumerate(self.layer_stack): 99 | self.kv_cache_list[idx] = torch.cat([self.kv_cache_list[idx], feats], dim=1) 100 | feats, attn = layer(feats, self.kv_cache_list[idx], self.kv_cache_list[idx], mask) 101 | 102 | return feats 103 | -------------------------------------------------------------------------------- /sub_models/world_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.distributions import OneHotCategorical, Normal 5 | from einops import rearrange, repeat, reduce 6 | from einops.layers.torch import Rearrange 7 | from torch.cuda.amp import autocast 8 | 9 | from sub_models.functions_losses import SymLogTwoHotLoss 10 | from sub_models.attention_blocks import get_subsequent_mask_with_batch_length, get_subsequent_mask 11 | from sub_models.transformer_model import StochasticTransformerKVCache 12 | import agents 13 | 14 | 15 | class EncoderBN(nn.Module): 16 | def __init__(self, in_channels, stem_channels, final_feature_width) -> None: 17 | super().__init__() 18 | 19 | backbone = [] 20 | # stem 21 | backbone.append( 22 | nn.Conv2d( 23 | in_channels=in_channels, 24 | out_channels=stem_channels, 25 | kernel_size=4, 26 | stride=2, 27 | 
padding=1, 28 | bias=False 29 | ) 30 | ) 31 | feature_width = 64//2 32 | channels = stem_channels 33 | backbone.append(nn.BatchNorm2d(stem_channels)) 34 | backbone.append(nn.ReLU(inplace=True)) 35 | 36 | # layers 37 | while True: 38 | backbone.append( 39 | nn.Conv2d( 40 | in_channels=channels, 41 | out_channels=channels*2, 42 | kernel_size=4, 43 | stride=2, 44 | padding=1, 45 | bias=False 46 | ) 47 | ) 48 | channels *= 2 49 | feature_width //= 2 50 | backbone.append(nn.BatchNorm2d(channels)) 51 | backbone.append(nn.ReLU(inplace=True)) 52 | 53 | if feature_width == final_feature_width: 54 | break 55 | 56 | self.backbone = nn.Sequential(*backbone) 57 | self.last_channels = channels 58 | 59 | def forward(self, x): 60 | batch_size = x.shape[0] 61 | x = rearrange(x, "B L C H W -> (B L) C H W") 62 | x = self.backbone(x) 63 | x = rearrange(x, "(B L) C H W -> B L (C H W)", B=batch_size) 64 | return x 65 | 66 | 67 | class DecoderBN(nn.Module): 68 | def __init__(self, stoch_dim, last_channels, original_in_channels, stem_channels, final_feature_width) -> None: 69 | super().__init__() 70 | 71 | backbone = [] 72 | # stem 73 | backbone.append(nn.Linear(stoch_dim, last_channels*final_feature_width*final_feature_width, bias=False)) 74 | backbone.append(Rearrange('B L (C H W) -> (B L) C H W', C=last_channels, H=final_feature_width)) 75 | backbone.append(nn.BatchNorm2d(last_channels)) 76 | backbone.append(nn.ReLU(inplace=True)) 77 | # residual_layer 78 | # backbone.append(ResidualStack(last_channels, 1, last_channels//4)) 79 | # layers 80 | channels = last_channels 81 | feat_width = final_feature_width 82 | while True: 83 | if channels == stem_channels: 84 | break 85 | backbone.append( 86 | nn.ConvTranspose2d( 87 | in_channels=channels, 88 | out_channels=channels//2, 89 | kernel_size=4, 90 | stride=2, 91 | padding=1, 92 | bias=False 93 | ) 94 | ) 95 | channels //= 2 96 | feat_width *= 2 97 | backbone.append(nn.BatchNorm2d(channels)) 98 | backbone.append(nn.ReLU(inplace=True)) 99 | 100 | backbone.append( 101 | nn.ConvTranspose2d( 102 | in_channels=channels, 103 | out_channels=original_in_channels, 104 | kernel_size=4, 105 | stride=2, 106 | padding=1 107 | ) 108 | ) 109 | self.backbone = nn.Sequential(*backbone) 110 | 111 | def forward(self, sample): 112 | batch_size = sample.shape[0] 113 | obs_hat = self.backbone(sample) 114 | obs_hat = rearrange(obs_hat, "(B L) C H W -> B L C H W", B=batch_size) 115 | return obs_hat 116 | 117 | 118 | class DistHead(nn.Module): 119 | ''' 120 | Dist: abbreviation of distribution 121 | ''' 122 | def __init__(self, image_feat_dim, transformer_hidden_dim, stoch_dim) -> None: 123 | super().__init__() 124 | self.stoch_dim = stoch_dim 125 | self.post_head = nn.Linear(image_feat_dim, stoch_dim*stoch_dim) 126 | self.prior_head = nn.Linear(transformer_hidden_dim, stoch_dim*stoch_dim) 127 | 128 | def unimix(self, logits, mixing_ratio=0.01): 129 | # uniform noise mixing 130 | probs = F.softmax(logits, dim=-1) 131 | mixed_probs = mixing_ratio * torch.ones_like(probs) / self.stoch_dim + (1-mixing_ratio) * probs 132 | logits = torch.log(mixed_probs) 133 | return logits 134 | 135 | def forward_post(self, x): 136 | logits = self.post_head(x) 137 | logits = rearrange(logits, "B L (K C) -> B L K C", K=self.stoch_dim) 138 | logits = self.unimix(logits) 139 | return logits 140 | 141 | def forward_prior(self, x): 142 | logits = self.prior_head(x) 143 | logits = rearrange(logits, "B L (K C) -> B L K C", K=self.stoch_dim) 144 | logits = self.unimix(logits) 145 | return logits 146 | 147 | 148 | 
class RewardDecoder(nn.Module): 149 | def __init__(self, num_classes, embedding_size, transformer_hidden_dim) -> None: 150 | super().__init__() 151 | self.backbone = nn.Sequential( 152 | nn.Linear(transformer_hidden_dim, transformer_hidden_dim, bias=False), 153 | nn.LayerNorm(transformer_hidden_dim), 154 | nn.ReLU(inplace=True), 155 | nn.Linear(transformer_hidden_dim, transformer_hidden_dim, bias=False), 156 | nn.LayerNorm(transformer_hidden_dim), 157 | nn.ReLU(inplace=True), 158 | ) 159 | self.head = nn.Linear(transformer_hidden_dim, num_classes) 160 | 161 | def forward(self, feat): 162 | feat = self.backbone(feat) 163 | reward = self.head(feat) 164 | return reward 165 | 166 | 167 | class TerminationDecoder(nn.Module): 168 | def __init__(self, embedding_size, transformer_hidden_dim) -> None: 169 | super().__init__() 170 | self.backbone = nn.Sequential( 171 | nn.Linear(transformer_hidden_dim, transformer_hidden_dim, bias=False), 172 | nn.LayerNorm(transformer_hidden_dim), 173 | nn.ReLU(inplace=True), 174 | nn.Linear(transformer_hidden_dim, transformer_hidden_dim, bias=False), 175 | nn.LayerNorm(transformer_hidden_dim), 176 | nn.ReLU(inplace=True), 177 | ) 178 | self.head = nn.Sequential( 179 | nn.Linear(transformer_hidden_dim, 1), 180 | # nn.Sigmoid() 181 | ) 182 | 183 | def forward(self, feat): 184 | feat = self.backbone(feat) 185 | termination = self.head(feat) 186 | termination = termination.squeeze(-1) # remove last 1 dim 187 | return termination 188 | 189 | 190 | class MSELoss(nn.Module): 191 | def __init__(self) -> None: 192 | super().__init__() 193 | 194 | def forward(self, obs_hat, obs): 195 | loss = (obs_hat - obs)**2 196 | loss = reduce(loss, "B L C H W -> B L", "sum") 197 | return loss.mean() 198 | 199 | 200 | class CategoricalKLDivLossWithFreeBits(nn.Module): 201 | def __init__(self, free_bits) -> None: 202 | super().__init__() 203 | self.free_bits = free_bits 204 | 205 | def forward(self, p_logits, q_logits): 206 | p_dist = OneHotCategorical(logits=p_logits) 207 | q_dist = OneHotCategorical(logits=q_logits) 208 | kl_div = torch.distributions.kl.kl_divergence(p_dist, q_dist) 209 | kl_div = reduce(kl_div, "B L D -> B L", "sum") 210 | kl_div = kl_div.mean() 211 | real_kl_div = kl_div 212 | kl_div = torch.max(torch.ones_like(kl_div)*self.free_bits, kl_div) 213 | return kl_div, real_kl_div 214 | 215 | 216 | class WorldModel(nn.Module): 217 | def __init__(self, in_channels, action_dim, 218 | transformer_max_length, transformer_hidden_dim, transformer_num_layers, transformer_num_heads): 219 | super().__init__() 220 | self.transformer_hidden_dim = transformer_hidden_dim 221 | self.final_feature_width = 4 222 | self.stoch_dim = 32 223 | self.stoch_flattened_dim = self.stoch_dim*self.stoch_dim 224 | self.use_amp = True 225 | self.tensor_dtype = torch.bfloat16 if self.use_amp else torch.float32 226 | self.imagine_batch_size = -1 227 | self.imagine_batch_length = -1 228 | 229 | self.encoder = EncoderBN( 230 | in_channels=in_channels, 231 | stem_channels=32, 232 | final_feature_width=self.final_feature_width 233 | ) 234 | self.storm_transformer = StochasticTransformerKVCache( 235 | stoch_dim=self.stoch_flattened_dim, 236 | action_dim=action_dim, 237 | feat_dim=transformer_hidden_dim, 238 | num_layers=transformer_num_layers, 239 | num_heads=transformer_num_heads, 240 | max_length=transformer_max_length, 241 | dropout=0.1 242 | ) 243 | self.dist_head = DistHead( 244 | image_feat_dim=self.encoder.last_channels*self.final_feature_width*self.final_feature_width, 245 | 
transformer_hidden_dim=transformer_hidden_dim, 246 | stoch_dim=self.stoch_dim 247 | ) 248 | self.image_decoder = DecoderBN( 249 | stoch_dim=self.stoch_flattened_dim, 250 | last_channels=self.encoder.last_channels, 251 | original_in_channels=in_channels, 252 | stem_channels=32, 253 | final_feature_width=self.final_feature_width 254 | ) 255 | self.reward_decoder = RewardDecoder( 256 | num_classes=255, 257 | embedding_size=self.stoch_flattened_dim, 258 | transformer_hidden_dim=transformer_hidden_dim 259 | ) 260 | self.termination_decoder = TerminationDecoder( 261 | embedding_size=self.stoch_flattened_dim, 262 | transformer_hidden_dim=transformer_hidden_dim 263 | ) 264 | 265 | self.mse_loss_func = MSELoss() 266 | self.ce_loss = nn.CrossEntropyLoss() 267 | self.bce_with_logits_loss_func = nn.BCEWithLogitsLoss() 268 | self.symlog_twohot_loss_func = SymLogTwoHotLoss(num_classes=255, lower_bound=-20, upper_bound=20) 269 | self.categorical_kl_div_loss = CategoricalKLDivLossWithFreeBits(free_bits=1) 270 | self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-4) 271 | self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp) 272 | 273 | def encode_obs(self, obs): 274 | with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=self.use_amp): 275 | embedding = self.encoder(obs) 276 | post_logits = self.dist_head.forward_post(embedding) 277 | sample = self.stright_throught_gradient(post_logits, sample_mode="random_sample") 278 | flattened_sample = self.flatten_sample(sample) 279 | return flattened_sample 280 | 281 | def calc_last_dist_feat(self, latent, action): 282 | with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=self.use_amp): 283 | temporal_mask = get_subsequent_mask(latent) 284 | dist_feat = self.storm_transformer(latent, action, temporal_mask) 285 | last_dist_feat = dist_feat[:, -1:] 286 | prior_logits = self.dist_head.forward_prior(last_dist_feat) 287 | prior_sample = self.stright_throught_gradient(prior_logits, sample_mode="random_sample") 288 | prior_flattened_sample = self.flatten_sample(prior_sample) 289 | return prior_flattened_sample, last_dist_feat 290 | 291 | def predict_next(self, last_flattened_sample, action, log_video=True): 292 | with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=self.use_amp): 293 | dist_feat = self.storm_transformer.forward_with_kv_cache(last_flattened_sample, action) 294 | prior_logits = self.dist_head.forward_prior(dist_feat) 295 | 296 | # decoding 297 | prior_sample = self.stright_throught_gradient(prior_logits, sample_mode="random_sample") 298 | prior_flattened_sample = self.flatten_sample(prior_sample) 299 | if log_video: 300 | obs_hat = self.image_decoder(prior_flattened_sample) 301 | else: 302 | obs_hat = None 303 | reward_hat = self.reward_decoder(dist_feat) 304 | reward_hat = self.symlog_twohot_loss_func.decode(reward_hat) 305 | termination_hat = self.termination_decoder(dist_feat) 306 | termination_hat = termination_hat > 0 307 | 308 | return obs_hat, reward_hat, termination_hat, prior_flattened_sample, dist_feat 309 | 310 | def stright_throught_gradient(self, logits, sample_mode="random_sample"): 311 | dist = OneHotCategorical(logits=logits) 312 | if sample_mode == "random_sample": 313 | sample = dist.sample() + dist.probs - dist.probs.detach() 314 | elif sample_mode == "mode": 315 | sample = dist.mode 316 | elif sample_mode == "probs": 317 | sample = dist.probs 318 | return sample 319 | 320 | def flatten_sample(self, sample): 321 | return rearrange(sample, "B L K C -> B L (K C)") 322 | 323 | 
def init_imagine_buffer(self, imagine_batch_size, imagine_batch_length, dtype): 324 | ''' 325 | This can slightly improve the efficiency of imagine_data 326 | But may vary across different machines 327 | ''' 328 | if self.imagine_batch_size != imagine_batch_size or self.imagine_batch_length != imagine_batch_length: 329 | print(f"init_imagine_buffer: {imagine_batch_size}x{imagine_batch_length}@{dtype}") 330 | self.imagine_batch_size = imagine_batch_size 331 | self.imagine_batch_length = imagine_batch_length 332 | latent_size = (imagine_batch_size, imagine_batch_length+1, self.stoch_flattened_dim) 333 | hidden_size = (imagine_batch_size, imagine_batch_length+1, self.transformer_hidden_dim) 334 | scalar_size = (imagine_batch_size, imagine_batch_length) 335 | self.latent_buffer = torch.zeros(latent_size, dtype=dtype, device="cuda") 336 | self.hidden_buffer = torch.zeros(hidden_size, dtype=dtype, device="cuda") 337 | self.action_buffer = torch.zeros(scalar_size, dtype=dtype, device="cuda") 338 | self.reward_hat_buffer = torch.zeros(scalar_size, dtype=dtype, device="cuda") 339 | self.termination_hat_buffer = torch.zeros(scalar_size, dtype=dtype, device="cuda") 340 | 341 | def imagine_data(self, agent: agents.ActorCriticAgent, sample_obs, sample_action, 342 | imagine_batch_size, imagine_batch_length, log_video, logger): 343 | self.init_imagine_buffer(imagine_batch_size, imagine_batch_length, dtype=self.tensor_dtype) 344 | obs_hat_list = [] 345 | 346 | self.storm_transformer.reset_kv_cache_list(imagine_batch_size, dtype=self.tensor_dtype) 347 | # context 348 | context_latent = self.encode_obs(sample_obs) 349 | for i in range(sample_obs.shape[1]): # context_length is sample_obs.shape[1] 350 | last_obs_hat, last_reward_hat, last_termination_hat, last_latent, last_dist_feat = self.predict_next( 351 | context_latent[:, i:i+1], 352 | sample_action[:, i:i+1], 353 | log_video=log_video 354 | ) 355 | self.latent_buffer[:, 0:1] = last_latent 356 | self.hidden_buffer[:, 0:1] = last_dist_feat 357 | 358 | # imagine 359 | for i in range(imagine_batch_length): 360 | action = agent.sample(torch.cat([self.latent_buffer[:, i:i+1], self.hidden_buffer[:, i:i+1]], dim=-1)) 361 | self.action_buffer[:, i:i+1] = action 362 | 363 | last_obs_hat, last_reward_hat, last_termination_hat, last_latent, last_dist_feat = self.predict_next( 364 | self.latent_buffer[:, i:i+1], self.action_buffer[:, i:i+1], log_video=log_video) 365 | 366 | self.latent_buffer[:, i+1:i+2] = last_latent 367 | self.hidden_buffer[:, i+1:i+2] = last_dist_feat 368 | self.reward_hat_buffer[:, i:i+1] = last_reward_hat 369 | self.termination_hat_buffer[:, i:i+1] = last_termination_hat 370 | if log_video: 371 | obs_hat_list.append(last_obs_hat[::imagine_batch_size//16]) # uniform sample vec_env 372 | 373 | if log_video: 374 | logger.log("Imagine/predict_video", torch.clamp(torch.cat(obs_hat_list, dim=1), 0, 1).cpu().float().detach().numpy()) 375 | 376 | return torch.cat([self.latent_buffer, self.hidden_buffer], dim=-1), self.action_buffer, self.reward_hat_buffer, self.termination_hat_buffer 377 | 378 | def update(self, obs, action, reward, termination, logger=None): 379 | self.train() 380 | batch_size, batch_length = obs.shape[:2] 381 | 382 | with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=self.use_amp): 383 | # encoding 384 | embedding = self.encoder(obs) 385 | post_logits = self.dist_head.forward_post(embedding) 386 | sample = self.stright_throught_gradient(post_logits, sample_mode="random_sample") 387 | flattened_sample = 
self.flatten_sample(sample) 388 | 389 | # decoding image 390 | obs_hat = self.image_decoder(flattened_sample) 391 | 392 | # transformer 393 | temporal_mask = get_subsequent_mask_with_batch_length(batch_length, flattened_sample.device) 394 | dist_feat = self.storm_transformer(flattened_sample, action, temporal_mask) 395 | prior_logits = self.dist_head.forward_prior(dist_feat) 396 | # decoding reward and termination with dist_feat 397 | reward_hat = self.reward_decoder(dist_feat) 398 | termination_hat = self.termination_decoder(dist_feat) 399 | 400 | # env loss 401 | reconstruction_loss = self.mse_loss_func(obs_hat, obs) 402 | reward_loss = self.symlog_twohot_loss_func(reward_hat, reward) 403 | termination_loss = self.bce_with_logits_loss_func(termination_hat, termination) 404 | # dyn-rep loss 405 | dynamics_loss, dynamics_real_kl_div = self.categorical_kl_div_loss(post_logits[:, 1:].detach(), prior_logits[:, :-1]) 406 | representation_loss, representation_real_kl_div = self.categorical_kl_div_loss(post_logits[:, 1:], prior_logits[:, :-1].detach()) 407 | total_loss = reconstruction_loss + reward_loss + termination_loss + 0.5*dynamics_loss + 0.1*representation_loss 408 | 409 | # gradient descent 410 | self.scaler.scale(total_loss).backward() 411 | self.scaler.unscale_(self.optimizer) # for clip grad 412 | torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1000.0) 413 | self.scaler.step(self.optimizer) 414 | self.scaler.update() 415 | self.optimizer.zero_grad(set_to_none=True) 416 | 417 | if logger is not None: 418 | logger.log("WorldModel/reconstruction_loss", reconstruction_loss.item()) 419 | logger.log("WorldModel/reward_loss", reward_loss.item()) 420 | logger.log("WorldModel/termination_loss", termination_loss.item()) 421 | logger.log("WorldModel/dynamics_loss", dynamics_loss.item()) 422 | logger.log("WorldModel/dynamics_real_kl_div", dynamics_real_kl_div.item()) 423 | logger.log("WorldModel/representation_loss", representation_loss.item()) 424 | logger.log("WorldModel/representation_real_kl_div", representation_real_kl_div.item()) 425 | logger.log("WorldModel/total_loss", total_loss.item()) 426 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import gymnasium 2 | import argparse 3 | from tensorboardX import SummaryWriter 4 | import cv2 5 | import numpy as np 6 | from einops import rearrange 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from collections import deque 11 | from tqdm import tqdm 12 | import copy 13 | import colorama 14 | import random 15 | import json 16 | import shutil 17 | import pickle 18 | import os 19 | 20 | from utils import seed_np_torch, Logger, load_config 21 | from replay_buffer import ReplayBuffer 22 | import env_wrapper 23 | import agents 24 | from sub_models.functions_losses import symexp 25 | from sub_models.world_models import WorldModel, MSELoss 26 | 27 | 28 | def build_single_env(env_name, image_size, seed): 29 | env = gymnasium.make(env_name, full_action_space=False, render_mode="rgb_array", frameskip=1) 30 | env = env_wrapper.SeedEnvWrapper(env, seed=seed) 31 | env = env_wrapper.MaxLast2FrameSkipWrapper(env, skip=4) 32 | env = gymnasium.wrappers.ResizeObservation(env, shape=image_size) 33 | env = env_wrapper.LifeLossInfo(env) 34 | return env 35 | 36 | 37 | def build_vec_env(env_name, image_size, num_envs, seed): 38 | # lambda pitfall refs to: 
https://python.plainenglish.io/python-pitfalls-with-variable-capture-dcfc113f39b7 39 | def lambda_generator(env_name, image_size): 40 | return lambda: build_single_env(env_name, image_size, seed) 41 | env_fns = [] 42 | env_fns = [lambda_generator(env_name, image_size) for i in range(num_envs)] 43 | vec_env = gymnasium.vector.AsyncVectorEnv(env_fns=env_fns) 44 | return vec_env 45 | 46 | 47 | def train_world_model_step(replay_buffer: ReplayBuffer, world_model: WorldModel, batch_size, demonstration_batch_size, batch_length, logger): 48 | obs, action, reward, termination = replay_buffer.sample(batch_size, demonstration_batch_size, batch_length) 49 | world_model.update(obs, action, reward, termination, logger=logger) 50 | 51 | 52 | @torch.no_grad() 53 | def world_model_imagine_data(replay_buffer: ReplayBuffer, 54 | world_model: WorldModel, agent: agents.ActorCriticAgent, 55 | imagine_batch_size, imagine_demonstration_batch_size, 56 | imagine_context_length, imagine_batch_length, 57 | log_video, logger): 58 | ''' 59 | Sample context from replay buffer, then imagine data with world model and agent 60 | ''' 61 | world_model.eval() 62 | agent.eval() 63 | 64 | sample_obs, sample_action, sample_reward, sample_termination = replay_buffer.sample( 65 | imagine_batch_size, imagine_demonstration_batch_size, imagine_context_length) 66 | latent, action, reward_hat, termination_hat = world_model.imagine_data( 67 | agent, sample_obs, sample_action, 68 | imagine_batch_size=imagine_batch_size+imagine_demonstration_batch_size, 69 | imagine_batch_length=imagine_batch_length, 70 | log_video=log_video, 71 | logger=logger 72 | ) 73 | return latent, action, None, None, reward_hat, termination_hat 74 | 75 | 76 | def joint_train_world_model_agent(env_name, max_steps, num_envs, image_size, 77 | replay_buffer: ReplayBuffer, 78 | world_model: WorldModel, agent: agents.ActorCriticAgent, 79 | train_dynamics_every_steps, train_agent_every_steps, 80 | batch_size, demonstration_batch_size, batch_length, 81 | imagine_batch_size, imagine_demonstration_batch_size, 82 | imagine_context_length, imagine_batch_length, 83 | save_every_steps, seed, logger): 84 | # create ckpt dir 85 | os.makedirs(f"ckpt/{args.n}", exist_ok=True) 86 | 87 | # build vec env, not useful in the Atari100k setting 88 | # but when the max_steps is large, you can use parallel envs to speed up 89 | vec_env = build_vec_env(env_name, image_size, num_envs=num_envs, seed=seed) 90 | print("Current env: " + colorama.Fore.YELLOW + f"{env_name}" + colorama.Style.RESET_ALL) 91 | 92 | # reset envs and variables 93 | sum_reward = np.zeros(num_envs) 94 | current_obs, current_info = vec_env.reset() 95 | context_obs = deque(maxlen=16) 96 | context_action = deque(maxlen=16) 97 | 98 | # sample and train 99 | for total_steps in tqdm(range(max_steps//num_envs)): 100 | # sample part >>> 101 | if replay_buffer.ready(): 102 | world_model.eval() 103 | agent.eval() 104 | with torch.no_grad(): 105 | if len(context_action) == 0: 106 | action = vec_env.action_space.sample() 107 | else: 108 | context_latent = world_model.encode_obs(torch.cat(list(context_obs), dim=1)) 109 | model_context_action = np.stack(list(context_action), axis=1) 110 | model_context_action = torch.Tensor(model_context_action).cuda() 111 | prior_flattened_sample, last_dist_feat = world_model.calc_last_dist_feat(context_latent, model_context_action) 112 | action = agent.sample_as_env_action( 113 | torch.cat([prior_flattened_sample, last_dist_feat], dim=-1), 114 | greedy=False 115 | ) 116 | 117 | 
context_obs.append(rearrange(torch.Tensor(current_obs).cuda(), "B H W C -> B 1 C H W")/255) 118 | context_action.append(action) 119 | else: 120 | action = vec_env.action_space.sample() 121 | 122 | obs, reward, done, truncated, info = vec_env.step(action) 123 | replay_buffer.append(current_obs, action, reward, np.logical_or(done, info["life_loss"])) 124 | 125 | done_flag = np.logical_or(done, truncated) 126 | if done_flag.any(): 127 | for i in range(num_envs): 128 | if done_flag[i]: 129 | logger.log(f"sample/{env_name}_reward", sum_reward[i]) 130 | logger.log(f"sample/{env_name}_episode_steps", current_info["episode_frame_number"][i]//4) # frameskip=4 131 | logger.log("replay_buffer/length", len(replay_buffer)) 132 | sum_reward[i] = 0 133 | 134 | # update current_obs, current_info and sum_reward 135 | sum_reward += reward 136 | current_obs = obs 137 | current_info = info 138 | # <<< sample part 139 | 140 | # train world model part >>> 141 | if replay_buffer.ready() and total_steps % (train_dynamics_every_steps//num_envs) == 0: 142 | train_world_model_step( 143 | replay_buffer=replay_buffer, 144 | world_model=world_model, 145 | batch_size=batch_size, 146 | demonstration_batch_size=demonstration_batch_size, 147 | batch_length=batch_length, 148 | logger=logger 149 | ) 150 | # <<< train world model part 151 | 152 | # train agent part >>> 153 | if replay_buffer.ready() and total_steps % (train_agent_every_steps//num_envs) == 0 and total_steps*num_envs >= 0: 154 | if total_steps % (save_every_steps//num_envs) == 0: 155 | log_video = True 156 | else: 157 | log_video = False 158 | 159 | imagine_latent, agent_action, agent_logprob, agent_value, imagine_reward, imagine_termination = world_model_imagine_data( 160 | replay_buffer=replay_buffer, 161 | world_model=world_model, 162 | agent=agent, 163 | imagine_batch_size=imagine_batch_size, 164 | imagine_demonstration_batch_size=imagine_demonstration_batch_size, 165 | imagine_context_length=imagine_context_length, 166 | imagine_batch_length=imagine_batch_length, 167 | log_video=log_video, 168 | logger=logger 169 | ) 170 | 171 | agent.update( 172 | latent=imagine_latent, 173 | action=agent_action, 174 | old_logprob=agent_logprob, 175 | old_value=agent_value, 176 | reward=imagine_reward, 177 | termination=imagine_termination, 178 | logger=logger 179 | ) 180 | # <<< train agent part 181 | 182 | # save model every save_every_steps env steps 183 | if total_steps % (save_every_steps//num_envs) == 0: 184 | print(colorama.Fore.GREEN + f"Saving model at total steps {total_steps}" + colorama.Style.RESET_ALL) 185 | torch.save(world_model.state_dict(), f"ckpt/{args.n}/world_model_{total_steps}.pth") 186 | torch.save(agent.state_dict(), f"ckpt/{args.n}/agent_{total_steps}.pth") 187 | 188 | 189 | def build_world_model(conf, action_dim): 190 | return WorldModel( 191 | in_channels=conf.Models.WorldModel.InChannels, 192 | action_dim=action_dim, 193 | transformer_max_length=conf.Models.WorldModel.TransformerMaxLength, 194 | transformer_hidden_dim=conf.Models.WorldModel.TransformerHiddenDim, 195 | transformer_num_layers=conf.Models.WorldModel.TransformerNumLayers, 196 | transformer_num_heads=conf.Models.WorldModel.TransformerNumHeads 197 | ).cuda() 198 | 199 | 200 | def build_agent(conf, action_dim): 201 | return agents.ActorCriticAgent( 202 | feat_dim=32*32+conf.Models.WorldModel.TransformerHiddenDim, 203 | num_layers=conf.Models.Agent.NumLayers, 204 | hidden_dim=conf.Models.Agent.HiddenDim, 205 | action_dim=action_dim, 206 | gamma=conf.Models.Agent.Gamma, 207 | lambd=conf.Models.Agent.Lambda,
208 | entropy_coef=conf.Models.Agent.EntropyCoef, 209 | ).cuda() 210 | 211 | 212 | if __name__ == "__main__": 213 | # ignore warnings 214 | import warnings 215 | warnings.filterwarnings('ignore') 216 | torch.backends.cuda.matmul.allow_tf32 = True 217 | torch.backends.cudnn.allow_tf32 = True 218 | 219 | # parse arguments 220 | parser = argparse.ArgumentParser() 221 | parser.add_argument("-n", type=str, required=True) 222 | parser.add_argument("-seed", type=int, required=True) 223 | parser.add_argument("-config_path", type=str, required=True) 224 | parser.add_argument("-env_name", type=str, required=True) 225 | parser.add_argument("-trajectory_path", type=str, required=True) 226 | args = parser.parse_args() 227 | conf = load_config(args.config_path) 228 | print(colorama.Fore.RED + str(args) + colorama.Style.RESET_ALL) 229 | 230 | # set seed 231 | seed_np_torch(seed=args.seed) 232 | # tensorboard writer 233 | logger = Logger(path=f"runs/{args.n}") 234 | # copy config file 235 | shutil.copy(args.config_path, f"runs/{args.n}/config.yaml") 236 | 237 | # distinguish between tasks, other debugging options are removed for simplicity 238 | if conf.Task == "JointTrainAgent": 239 | # getting action_dim with dummy env 240 | dummy_env = build_single_env(args.env_name, conf.BasicSettings.ImageSize, seed=0) 241 | action_dim = dummy_env.action_space.n 242 | 243 | # build world model and agent 244 | world_model = build_world_model(conf, action_dim) 245 | agent = build_agent(conf, action_dim) 246 | 247 | # build replay buffer 248 | replay_buffer = ReplayBuffer( 249 | obs_shape=(conf.BasicSettings.ImageSize, conf.BasicSettings.ImageSize, 3), 250 | num_envs=conf.JointTrainAgent.NumEnvs, 251 | max_length=conf.JointTrainAgent.BufferMaxLength, 252 | warmup_length=conf.JointTrainAgent.BufferWarmUp, 253 | store_on_gpu=conf.BasicSettings.ReplayBufferOnGPU 254 | ) 255 | 256 | # judge whether to load demonstration trajectory 257 | if conf.JointTrainAgent.UseDemonstration: 258 | print(colorama.Fore.MAGENTA + f"loading demonstration trajectory from {args.trajectory_path}" + colorama.Style.RESET_ALL) 259 | replay_buffer.load_trajectory(path=args.trajectory_path) 260 | 261 | # train 262 | joint_train_world_model_agent( 263 | env_name=args.env_name, 264 | num_envs=conf.JointTrainAgent.NumEnvs, 265 | max_steps=conf.JointTrainAgent.SampleMaxSteps, 266 | image_size=conf.BasicSettings.ImageSize, 267 | replay_buffer=replay_buffer, 268 | world_model=world_model, 269 | agent=agent, 270 | train_dynamics_every_steps=conf.JointTrainAgent.TrainDynamicsEverySteps, 271 | train_agent_every_steps=conf.JointTrainAgent.TrainAgentEverySteps, 272 | batch_size=conf.JointTrainAgent.BatchSize, 273 | demonstration_batch_size=conf.JointTrainAgent.DemonstrationBatchSize if conf.JointTrainAgent.UseDemonstration else 0, 274 | batch_length=conf.JointTrainAgent.BatchLength, 275 | imagine_batch_size=conf.JointTrainAgent.ImagineBatchSize, 276 | imagine_demonstration_batch_size=conf.JointTrainAgent.ImagineDemonstrationBatchSize if conf.JointTrainAgent.UseDemonstration else 0, 277 | imagine_context_length=conf.JointTrainAgent.ImagineContextLength, 278 | imagine_batch_length=conf.JointTrainAgent.ImagineBatchLength, 279 | save_every_steps=conf.JointTrainAgent.SaveEverySteps, 280 | seed=args.seed, 281 | logger=logger 282 | ) 283 | else: 284 | raise NotImplementedError(f"Task {conf.Task} not implemented") 285 | -------------------------------------------------------------------------------- /train.sh: 
-------------------------------------------------------------------------------- 1 | env_name=MsPacman 2 | python -u train.py \ 3 | -n "${env_name}-life_done-wm_2L512D8H-100k-seed1" \ 4 | -seed 1 \ 5 | -config_path "config_files/STORM.yaml" \ 6 | -env_name "ALE/${env_name}-v5" \ 7 | -trajectory_path "D_TRAJ/${env_name}.pkl" -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import random 5 | from tensorboardX import SummaryWriter 6 | from einops import repeat 7 | from contextlib import contextmanager 8 | import time 9 | import yacs 10 | from yacs.config import CfgNode as CN 11 | 12 | 13 | def seed_np_torch(seed=20010105): 14 | random.seed(seed) 15 | os.environ['PYTHONHASHSEED'] = str(seed) 16 | np.random.seed(seed) 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed(seed) 19 | torch.cuda.manual_seed_all(seed) 20 | # some cudnn methods can be random even after fixing the seed unless you tell it to be deterministic 21 | torch.backends.cudnn.deterministic = True 22 | torch.backends.cudnn.benchmark = False 23 | 24 | 25 | class Logger(): 26 | def __init__(self, path) -> None: 27 | self.writer = SummaryWriter(logdir=path, flush_secs=1) 28 | self.tag_step = {} 29 | 30 | def log(self, tag, value): 31 | if tag not in self.tag_step: 32 | self.tag_step[tag] = 0 33 | else: 34 | self.tag_step[tag] += 1 35 | if "video" in tag: 36 | self.writer.add_video(tag, value, self.tag_step[tag], fps=15) 37 | elif "images" in tag: 38 | self.writer.add_images(tag, value, self.tag_step[tag]) 39 | elif "hist" in tag: 40 | self.writer.add_histogram(tag, value, self.tag_step[tag]) 41 | else: 42 | self.writer.add_scalar(tag, value, self.tag_step[tag]) 43 | 44 | 45 | class EMAScalar(): 46 | def __init__(self, decay) -> None: 47 | self.scalar = 0.0 48 | self.decay = decay 49 | 50 | def __call__(self, value): 51 | self.update(value) 52 | return self.get() 53 | 54 | def update(self, value): 55 | self.scalar = self.scalar * self.decay + value * (1 - self.decay) 56 | 57 | def get(self): 58 | return self.scalar 59 | 60 | 61 | def load_config(config_path): 62 | conf = CN() 63 | # Task need to be RandomSample/TrainVQVAE/TrainWorldModel 64 | conf.Task = "" 65 | 66 | conf.BasicSettings = CN() 67 | conf.BasicSettings.Seed = 0 68 | conf.BasicSettings.ImageSize = 0 69 | conf.BasicSettings.ReplayBufferOnGPU = False 70 | 71 | # Under this setting, input 128*128 -> latent 16*16*64 72 | conf.Models = CN() 73 | 74 | conf.Models.WorldModel = CN() 75 | conf.Models.WorldModel.InChannels = 0 76 | conf.Models.WorldModel.TransformerMaxLength = 0 77 | conf.Models.WorldModel.TransformerHiddenDim = 0 78 | conf.Models.WorldModel.TransformerNumLayers = 0 79 | conf.Models.WorldModel.TransformerNumHeads = 0 80 | 81 | conf.Models.Agent = CN() 82 | conf.Models.Agent.NumLayers = 0 83 | conf.Models.Agent.HiddenDim = 256 84 | conf.Models.Agent.Gamma = 1.0 85 | conf.Models.Agent.Lambda = 0.0 86 | conf.Models.Agent.EntropyCoef = 0.0 87 | 88 | conf.JointTrainAgent = CN() 89 | conf.JointTrainAgent.SampleMaxSteps = 0 90 | conf.JointTrainAgent.BufferMaxLength = 0 91 | conf.JointTrainAgent.BufferWarmUp = 0 92 | conf.JointTrainAgent.NumEnvs = 0 93 | conf.JointTrainAgent.BatchSize = 0 94 | conf.JointTrainAgent.DemonstrationBatchSize = 0 95 | conf.JointTrainAgent.BatchLength = 0 96 | conf.JointTrainAgent.ImagineBatchSize = 0 97 | conf.JointTrainAgent.ImagineDemonstrationBatchSize = 0 
98 | conf.JointTrainAgent.ImagineContextLength = 0 99 | conf.JointTrainAgent.ImagineBatchLength = 0 100 | conf.JointTrainAgent.TrainDynamicsEverySteps = 0 101 | conf.JointTrainAgent.TrainAgentEverySteps = 0 102 | conf.JointTrainAgent.SaveEverySteps = 0 103 | conf.JointTrainAgent.UseDemonstration = False 104 | 105 | conf.defrost() 106 | conf.merge_from_file(config_path) 107 | conf.freeze() 108 | 109 | return conf 110 | --------------------------------------------------------------------------------
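For reference, load_config above first builds the default schema shown and then calls conf.merge_from_file(config_path), so a config file has to reproduce the same nesting and key names (yacs rejects unknown keys and incompatible types). The sketch below only illustrates that structure: the 2-layer/512-dim/8-head transformer shape is implied by the run name in train.sh and InChannels=3 by the replay-buffer observation shape, but every other value is a placeholder and not the repository's actual STORM.yaml.

import tempfile

from utils import load_config

# Placeholder config matching the schema defined in load_config (values are illustrative).
_EXAMPLE_YAML = """\
Task: "JointTrainAgent"
BasicSettings:
  Seed: 0
  ImageSize: 64
  ReplayBufferOnGPU: True
Models:
  WorldModel:
    InChannels: 3
    TransformerMaxLength: 64
    TransformerHiddenDim: 512
    TransformerNumLayers: 2
    TransformerNumHeads: 8
  Agent:
    NumLayers: 3
    HiddenDim: 512
    Gamma: 0.985
    Lambda: 0.95
    EntropyCoef: 0.0003
JointTrainAgent:
  SampleMaxSteps: 102000
  BufferMaxLength: 100000
  BufferWarmUp: 1024
  NumEnvs: 1
  BatchSize: 16
  DemonstrationBatchSize: 4
  BatchLength: 64
  ImagineBatchSize: 1024
  ImagineDemonstrationBatchSize: 256
  ImagineContextLength: 8
  ImagineBatchLength: 16
  TrainDynamicsEverySteps: 1
  TrainAgentEverySteps: 1
  SaveEverySteps: 2500
  UseDemonstration: False
"""

# Write the example to a temporary file and load it through the same code path train.py uses.
with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write(_EXAMPLE_YAML)

conf = load_config(f.name)
print(conf.Task, conf.Models.WorldModel.TransformerHiddenDim)  # JointTrainAgent 512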