├── .gitignore ├── LICENSE ├── README.md ├── agent.py ├── model.py ├── requirements.txt ├── test.py ├── train.py ├── utils.py ├── video_prediction.py ├── viewer.py └── wrappers.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Kaito Suzuki 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dreamer_PyTorch
2 | Unofficial re-implementation of "Dream to Control: Learning Behaviors by Latent Imagination" (https://arxiv.org/abs/1912.01603). Work in progress.
3 | 
4 | ## Instructions
5 | For training, install the requirements (see below) and run
6 | ```python
7 | python3 train.py
8 | ```
9 | 
10 | ## Requirements
11 | * Python 3
12 | * MuJoCo (for the DeepMind Control Suite)
13 | 
14 | See requirements.txt for the required Python libraries.
15 | 
16 | ## References
17 | * [Dream to Control: Learning Behaviors by Latent Imagination](https://arxiv.org/abs/1912.01603)
18 | * [Official Implementation](https://github.com/google-research/dreamer)
19 | * [Official Implementation 2](https://github.com/danijar/dreamer)
20 | * [My Implementation of PlaNet](https://github.com/cross32768/PlaNet_PyTorch)
21 | 
22 | 
23 | ## TODO
24 | * Speed up training
25 | * Add comments and improve readability
26 | * Add results of experiments and video prediction
27 | * Generalize the code to other environments
28 | 
--------------------------------------------------------------------------------
/agent.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils import preprocess_obs
3 | 
4 | 
5 | class Agent:
6 |     """
7 |     Agent class that selects actions with the action model
8 |     and maintains rnn_hidden as the recurrent context for the action model
9 |     """
10 |     def __init__(self, encoder, rssm, action_model):
11 |         self.encoder = encoder
12 |         self.rssm = rssm
13 |         self.action_model = action_model
14 | 
15 |         self.device = next(self.action_model.parameters()).device
16 |         self.rnn_hidden = torch.zeros(1, rssm.rnn_hidden_dim, device=self.device)
17 | 
18 |     def __call__(self, obs, training=True):
19 |         """
20 |         If training=False, the returned action is the mean of the
21 |         action_model's distribution instead of a sample from it
22 |         """
23 |         # preprocess observation and transpose for torch style (channel-first)
24 |         obs = preprocess_obs(obs)
25 |         obs = torch.as_tensor(obs, device=self.device)
26 |         obs = obs.transpose(1, 2).transpose(0, 1).unsqueeze(0)
27 | 
28 |         with torch.no_grad():
29 |             # embed observation, compute state posterior, sample from state posterior
30 |             # and get action using sampled state and rnn_hidden as input
31 |             embedded_obs = self.encoder(obs)
32 |             state_posterior = self.rssm.posterior(self.rnn_hidden, embedded_obs)
33 |             state = state_posterior.sample()
34 |             action = self.action_model(state, self.rnn_hidden, training=training)
35 | 
36 |             # update rnn_hidden for next step
37 |             _, self.rnn_hidden = self.rssm.prior(state, action, self.rnn_hidden)
38 | 
39 |         return action.squeeze().cpu().numpy()
40 | 
41 |     def reset(self):
42 |         self.rnn_hidden = torch.zeros(1, self.rssm.rnn_hidden_dim, device=self.device)
43 | 
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from torch.distributions import Normal
6 | 
7 | 
8 | class Encoder(nn.Module):
9 |     """
10 |     Encoder to embed image observation (3, 64, 64) to vector (1024,)
11 |     """
12 |     def __init__(self):
13 |         super(Encoder, self).__init__()
14 |         self.cv1 = nn.Conv2d(3, 32, kernel_size=4, stride=2)
15 |         self.cv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
16 |         self.cv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2)
17 |         self.cv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2)
18 | 
19 |     def forward(self, obs):
20 |         hidden = F.relu(self.cv1(obs))
21 |         hidden = F.relu(self.cv2(hidden))
22 |         hidden = F.relu(self.cv3(hidden))
23 |         embedded_obs = F.relu(self.cv4(hidden)).reshape(hidden.size(0), -1)
24 |         return embedded_obs
25 | 
26 | 
27 | class RecurrentStateSpaceModel(nn.Module):
28 |     """
29 |     This class includes multiple components:
30 |     Deterministic state model: h_t+1 = f(h_t, s_t, a_t)
31 |     Stochastic state model (prior): p(s_t+1 | h_t+1)
32 |     State posterior: q(s_t | h_t, o_t)
33 |     NOTE: this class takes observations already embedded by the Encoder class
34 |     min_stddev is added to stddev, same as in the original implementation
35 |     The activation function for this class is F.elu, same as in the original implementation
36 |     """
37 |     def __init__(self, state_dim, action_dim, rnn_hidden_dim,
38 |                  hidden_dim=200, min_stddev=0.1, act=F.elu):
39 |         super(RecurrentStateSpaceModel, self).__init__()
40 |         self.state_dim = state_dim
41 |         self.action_dim = action_dim
42 |         self.rnn_hidden_dim = rnn_hidden_dim
43 |         self.fc_state_action = nn.Linear(state_dim + action_dim, hidden_dim)
44 |         self.fc_rnn_hidden = nn.Linear(rnn_hidden_dim, hidden_dim)
45 |         self.fc_state_mean_prior = nn.Linear(hidden_dim, state_dim)
46 |         self.fc_state_stddev_prior = nn.Linear(hidden_dim, state_dim)
47 |         self.fc_rnn_hidden_embedded_obs = nn.Linear(rnn_hidden_dim + 1024, hidden_dim)
48 |         self.fc_state_mean_posterior = nn.Linear(hidden_dim, state_dim)
49 |         self.fc_state_stddev_posterior = nn.Linear(hidden_dim, state_dim)
50 |         self.rnn = nn.GRUCell(hidden_dim, rnn_hidden_dim)
51 |         self._min_stddev = min_stddev
52 |         self.act = act
53 | 
54 |     def forward(self, state, action, rnn_hidden, embedded_next_obs):
55 |         """
56 |         h_t+1 = f(h_t, s_t, a_t)
57 |         Return prior p(s_t+1 | h_t+1) and posterior q(s_t+1 | h_t+1, o_t+1)
58 |         for model training
59 |         """
60 |         next_state_prior, rnn_hidden = self.prior(state, action, rnn_hidden)
61 |         next_state_posterior = self.posterior(rnn_hidden, embedded_next_obs)
62 |         return next_state_prior, next_state_posterior, rnn_hidden
63 | 
64 |     def prior(self, state, action, rnn_hidden):
65 |         """
66 |         h_t+1 = f(h_t, s_t, a_t)
67 |         Compute prior p(s_t+1 | h_t+1)
68 |         """
69 |         hidden = self.act(self.fc_state_action(torch.cat([state, action], dim=1)))
70 |         rnn_hidden = self.rnn(hidden, rnn_hidden)
71 |         hidden = self.act(self.fc_rnn_hidden(rnn_hidden))
72 | 
73 |         mean = self.fc_state_mean_prior(hidden)
74 |         stddev = F.softplus(self.fc_state_stddev_prior(hidden)) + self._min_stddev
75 |         return Normal(mean, stddev), rnn_hidden
76 | 
77 |     def posterior(self, rnn_hidden, embedded_obs):
78 |         """
79 |         Compute posterior q(s_t | h_t, o_t)
80 |         """
81 |         hidden = self.act(self.fc_rnn_hidden_embedded_obs(
82 |             torch.cat([rnn_hidden, embedded_obs], dim=1)))
83 |         mean = self.fc_state_mean_posterior(hidden)
84 |         stddev = F.softplus(self.fc_state_stddev_posterior(hidden)) +
self._min_stddev 85 | return Normal(mean, stddev) 86 | 87 | 88 | class ObservationModel(nn.Module): 89 | """ 90 | p(o_t | s_t, h_t) 91 | Observation model to reconstruct image observation (3, 64, 64) 92 | from state and rnn hidden state 93 | """ 94 | def __init__(self, state_dim, rnn_hidden_dim): 95 | super(ObservationModel, self).__init__() 96 | self.fc = nn.Linear(state_dim + rnn_hidden_dim, 1024) 97 | self.dc1 = nn.ConvTranspose2d(1024, 128, kernel_size=5, stride=2) 98 | self.dc2 = nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2) 99 | self.dc3 = nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2) 100 | self.dc4 = nn.ConvTranspose2d(32, 3, kernel_size=6, stride=2) 101 | 102 | def forward(self, state, rnn_hidden): 103 | hidden = self.fc(torch.cat([state, rnn_hidden], dim=1)) 104 | hidden = hidden.view(hidden.size(0), 1024, 1, 1) 105 | hidden = F.relu(self.dc1(hidden)) 106 | hidden = F.relu(self.dc2(hidden)) 107 | hidden = F.relu(self.dc3(hidden)) 108 | obs = self.dc4(hidden) 109 | return obs 110 | 111 | 112 | class RewardModel(nn.Module): 113 | """ 114 | p(r_t | s_t, h_t) 115 | Reward model to predict reward from state and rnn hidden state 116 | """ 117 | def __init__(self, state_dim, rnn_hidden_dim, hidden_dim=400, act=F.elu): 118 | super(RewardModel, self).__init__() 119 | self.fc1 = nn.Linear(state_dim + rnn_hidden_dim, hidden_dim) 120 | self.fc2 = nn.Linear(hidden_dim, hidden_dim) 121 | self.fc3 = nn.Linear(hidden_dim, hidden_dim) 122 | self.fc4 = nn.Linear(hidden_dim, 1) 123 | self.act = act 124 | 125 | def forward(self, state, rnn_hidden): 126 | hidden = self.act(self.fc1(torch.cat([state, rnn_hidden], dim=1))) 127 | hidden = self.act(self.fc2(hidden)) 128 | hidden = self.act(self.fc3(hidden)) 129 | reward = self.fc4(hidden) 130 | return reward 131 | 132 | 133 | class ValueModel(nn.Module): 134 | """ 135 | Value model to predict state-value of current policy (action_model) 136 | from state and rnn_hidden 137 | """ 138 | def __init__(self, state_dim, rnn_hidden_dim, hidden_dim=400, act=F.elu): 139 | super(ValueModel, self).__init__() 140 | self.fc1 = nn.Linear(state_dim + rnn_hidden_dim, hidden_dim) 141 | self.fc2 = nn.Linear(hidden_dim, hidden_dim) 142 | self.fc3 = nn.Linear(hidden_dim, hidden_dim) 143 | self.fc4 = nn.Linear(hidden_dim, 1) 144 | self.act = act 145 | 146 | def forward(self, state, rnn_hidden): 147 | hidden = self.act(self.fc1(torch.cat([state, rnn_hidden], dim=1))) 148 | hidden = self.act(self.fc2(hidden)) 149 | hidden = self.act(self.fc3(hidden)) 150 | state_value = self.fc4(hidden) 151 | return state_value 152 | 153 | 154 | class ActionModel(nn.Module): 155 | """ 156 | Action model to compute action from state and rnn_hidden 157 | """ 158 | def __init__(self, state_dim, rnn_hidden_dim, action_dim, 159 | hidden_dim=400, act=F.elu, min_stddev=1e-4, init_stddev=5.0): 160 | super(ActionModel, self).__init__() 161 | self.fc1 = nn.Linear(state_dim + rnn_hidden_dim, hidden_dim) 162 | self.fc2 = nn.Linear(hidden_dim, hidden_dim) 163 | self.fc3 = nn.Linear(hidden_dim, hidden_dim) 164 | self.fc4 = nn.Linear(hidden_dim, hidden_dim) 165 | self.fc_mean = nn.Linear(hidden_dim, action_dim) 166 | self.fc_stddev = nn.Linear(hidden_dim, action_dim) 167 | self.act = act 168 | self.min_stddev = min_stddev 169 | self.init_stddev = np.log(np.exp(init_stddev) - 1) 170 | 171 | def forward(self, state, rnn_hidden, training=True): 172 | """ 173 | if training=True, returned action is reparametrized sample 174 | if training=False, returned action is mean of action distribution 175 | """ 
176 | hidden = self.act(self.fc1(torch.cat([state, rnn_hidden], dim=1))) 177 | hidden = self.act(self.fc2(hidden)) 178 | hidden = self.act(self.fc3(hidden)) 179 | hidden = self.act(self.fc4(hidden)) 180 | 181 | # action-mean is divided by 5.0 and applied tanh 182 | # and multiplied by 5.0 same as Dreamer's paper 183 | mean = self.fc_mean(hidden) 184 | mean = 5.0 * torch.tanh(mean / 5.0) 185 | 186 | # stddev is computed with some hyperparameter 187 | # (init_stddev, min_stddev) same as original implementation 188 | stddev = self.fc_stddev(hidden) 189 | stddev = F.softplus(stddev + self.init_stddev) + self.min_stddev 190 | 191 | if training: 192 | action = torch.tanh(Normal(mean, stddev).rsample()) 193 | else: 194 | action = torch.tanh(mean) 195 | return action 196 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dm_control 2 | gym 3 | opencv-python 4 | matplotlib 5 | numpy 6 | tensorboard 7 | torch -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import torch 5 | from dm_control import suite 6 | from dm_control.suite.wrappers import pixels 7 | from agent import Agent 8 | from model import Encoder, RecurrentStateSpaceModel, ActionModel 9 | from wrappers import GymWrapper, RepeatAction 10 | 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser(description='Test learned model') 14 | parser.add_argument('dir', type=str, help='log directory to load learned model') 15 | parser.add_argument('--render', action='store_true') 16 | parser.add_argument('--domain-name', type=str, default='cheetah') 17 | parser.add_argument('--task-name', type=str, default='run') 18 | parser.add_argument('-R', '--action-repeat', type=int, default=2) 19 | parser.add_argument('--episodes', type=int, default=1) 20 | args = parser.parse_args() 21 | 22 | # define environment and apply wrapper 23 | env = suite.load(args.domain_name, args.task_name) 24 | env = pixels.Wrapper(env, render_kwargs={'height': 64, 25 | 'width': 64, 26 | 'camera_id': 0}) 27 | env = GymWrapper(env) 28 | env = RepeatAction(env, skip=args.action_repeat) 29 | 30 | # define models 31 | with open(os.path.join(args.dir, 'args.json'), 'r') as f: 32 | train_args = json.load(f) 33 | 34 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 35 | encoder = Encoder().to(device) 36 | rssm = RecurrentStateSpaceModel(train_args['state_dim'], 37 | env.action_space.shape[0], 38 | train_args['rnn_hidden_dim']).to(device) 39 | action_model = ActionModel(train_args['state_dim'], train_args['rnn_hidden_dim'], 40 | env.action_space.shape[0]).to(device) 41 | 42 | # load learned parameters 43 | encoder.load_state_dict(torch.load(os.path.join(args.dir, 'encoder.pth'))) 44 | rssm.load_state_dict(torch.load(os.path.join(args.dir, 'rssm.pth'))) 45 | action_model.load_state_dict(torch.load(os.path.join(args.dir, 'action_model.pth'))) 46 | 47 | # define agent 48 | policy = Agent(encoder, rssm, action_model) 49 | 50 | # test learnged model in the environment 51 | for episode in range(args.episodes): 52 | policy.reset() 53 | obs = env.reset() 54 | done = False 55 | total_reward = 0 56 | while not done: 57 | action = policy(obs) 58 | obs, reward, done, _ = env.step(action) 59 | total_reward += reward 60 | if args.render: 61 | env.render(height=256, 
width=256, camera_id=0) 62 | 63 | print('Total test reward at episode [%4d/%4d] is %f' % 64 | (episode+1, args.episodes, total_reward)) 65 | 66 | 67 | if __name__ == '__main__': 68 | main() 69 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime 3 | import json 4 | import os 5 | from pprint import pprint 6 | import time 7 | import numpy as np 8 | import torch 9 | from torch.distributions.kl import kl_divergence 10 | from torch.nn.functional import mse_loss 11 | from torch.nn.utils import clip_grad_norm_ 12 | from torch.optim import Adam 13 | from torch.utils.tensorboard import SummaryWriter 14 | from dm_control import suite 15 | from dm_control.suite.wrappers import pixels 16 | from agent import Agent 17 | from model import (Encoder, RecurrentStateSpaceModel, ObservationModel, RewardModel, 18 | ValueModel, ActionModel) 19 | from utils import ReplayBuffer, preprocess_obs, lambda_target 20 | from wrappers import GymWrapper, RepeatAction 21 | 22 | 23 | def main(): 24 | parser = argparse.ArgumentParser(description='Dreamer for DM control') 25 | parser.add_argument('--seed', type=int, default=0) 26 | parser.add_argument('--log-dir', type=str, default='log') 27 | parser.add_argument('--test-interval', type=int, default=10) 28 | parser.add_argument('--domain-name', type=str, default='cheetah') 29 | parser.add_argument('--task-name', type=str, default='run') 30 | parser.add_argument('-R', '--action-repeat', type=int, default=2) 31 | parser.add_argument('--state-dim', type=int, default=30) 32 | parser.add_argument('--rnn-hidden-dim', type=int, default=200) 33 | parser.add_argument('--buffer-capacity', type=int, default=1000000) 34 | parser.add_argument('--all-episodes', type=int, default=1000) 35 | parser.add_argument('-S', '--seed-episodes', type=int, default=5) 36 | parser.add_argument('-C', '--collect-interval', type=int, default=100) 37 | parser.add_argument('-B', '--batch-size', type=int, default=50) 38 | parser.add_argument('-L', '--chunk-length', type=int, default=50) 39 | parser.add_argument('-H', '--imagination-horizon', type=int, default=15) 40 | parser.add_argument('--gamma', type=float, default=0.99) 41 | parser.add_argument('--lambda_', type=float, default=0.95) 42 | parser.add_argument('--model_lr', type=float, default=6e-4) 43 | parser.add_argument('--value_lr', type=float, default=8e-5) 44 | parser.add_argument('--action_lr', type=float, default=8e-5) 45 | parser.add_argument('--eps', type=float, default=1e-4) 46 | parser.add_argument('--clip-grad-norm', type=int, default=100) 47 | parser.add_argument('--free-nats', type=int, default=3) 48 | parser.add_argument('--action-noise-var', type=float, default=0.3) 49 | args = parser.parse_args() 50 | 51 | # Prepare logging 52 | log_dir = os.path.join(args.log_dir, args.domain_name + '_' + args.task_name) 53 | log_dir = os.path.join(log_dir, datetime.now().strftime('%Y%m%d_%H%M')) 54 | os.makedirs(log_dir) 55 | with open(os.path.join(log_dir, 'args.json'), 'w') as f: 56 | json.dump(vars(args), f) 57 | pprint(vars(args)) 58 | writer = SummaryWriter(log_dir=log_dir) 59 | 60 | # set seed (NOTE: some randomness is still remaining (e.g. 
cuDNN's randomness)) 61 | np.random.seed(args.seed) 62 | torch.manual_seed(args.seed) 63 | if torch.cuda.is_available(): 64 | torch.cuda.manual_seed(args.seed) 65 | 66 | # define env and apply wrappers 67 | env = suite.load(args.domain_name, args.task_name, task_kwargs={'random': args.seed}) 68 | env = pixels.Wrapper(env, render_kwargs={'height': 64, 69 | 'width': 64, 70 | 'camera_id': 0}) 71 | env = GymWrapper(env) 72 | env = RepeatAction(env, skip=args.action_repeat) 73 | 74 | # define replay buffer 75 | replay_buffer = ReplayBuffer(capacity=args.buffer_capacity, 76 | observation_shape=env.observation_space.shape, 77 | action_dim=env.action_space.shape[0]) 78 | 79 | # define models and optimizer 80 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 81 | encoder = Encoder().to(device) 82 | rssm = RecurrentStateSpaceModel(args.state_dim, 83 | env.action_space.shape[0], 84 | args.rnn_hidden_dim).to(device) 85 | obs_model = ObservationModel(args.state_dim, args.rnn_hidden_dim).to(device) 86 | reward_model = RewardModel(args.state_dim, args.rnn_hidden_dim).to(device) 87 | model_params = (list(encoder.parameters()) + 88 | list(rssm.parameters()) + 89 | list(obs_model.parameters()) + 90 | list(reward_model.parameters())) 91 | model_optimizer = Adam(model_params, lr=args.model_lr, eps=args.eps) 92 | 93 | # define value model and action model and optimizer 94 | value_model = ValueModel(args.state_dim, args.rnn_hidden_dim).to(device) 95 | action_model = ActionModel(args.state_dim, args.rnn_hidden_dim, 96 | env.action_space.shape[0]).to(device) 97 | value_optimizer = Adam(value_model.parameters(), lr=args.value_lr, eps=args.eps) 98 | action_optimizer = Adam(action_model.parameters(), lr=args.action_lr, eps=args.eps) 99 | 100 | # collect seed episodes with random action 101 | for episode in range(args.seed_episodes): 102 | obs = env.reset() 103 | done = False 104 | while not done: 105 | action = env.action_space.sample() 106 | next_obs, reward, done, _ = env.step(action) 107 | replay_buffer.push(obs, action, reward, done) 108 | obs = next_obs 109 | 110 | # main training loop 111 | for episode in range(args.seed_episodes, args.all_episodes): 112 | # ----------------------------- 113 | # collect experiences 114 | # ----------------------------- 115 | start = time.time() 116 | policy = Agent(encoder, rssm, action_model) 117 | 118 | obs = env.reset() 119 | done = False 120 | total_reward = 0 121 | while not done: 122 | action = policy(obs) 123 | action += np.random.normal(0, np.sqrt(args.action_noise_var), 124 | env.action_space.shape[0]) 125 | next_obs, reward, done, _ = env.step(action) 126 | replay_buffer.push(obs, action, reward, done) 127 | obs = next_obs 128 | total_reward += reward 129 | 130 | writer.add_scalar('total reward at train', total_reward, episode) 131 | print('episode [%4d/%4d] is collected. 
Total reward is %f' % 132 | (episode+1, args.all_episodes, total_reward)) 133 | print('elasped time for interaction: %.2fs' % (time.time() - start)) 134 | 135 | # update parameters of model, value model, action model 136 | start = time.time() 137 | for update_step in range(args.collect_interval): 138 | # --------------------------------------------------------------- 139 | # update model (encoder, rssm, obs_model, reward_model) 140 | # --------------------------------------------------------------- 141 | observations, actions, rewards, _ = \ 142 | replay_buffer.sample(args.batch_size, args.chunk_length) 143 | 144 | # preprocess observations and transpose tensor for RNN training 145 | observations = preprocess_obs(observations) 146 | observations = torch.as_tensor(observations, device=device) 147 | observations = observations.transpose(3, 4).transpose(2, 3) 148 | observations = observations.transpose(0, 1) 149 | actions = torch.as_tensor(actions, device=device).transpose(0, 1) 150 | rewards = torch.as_tensor(rewards, device=device).transpose(0, 1) 151 | 152 | # embed observations with CNN 153 | embedded_observations = encoder( 154 | observations.reshape(-1, 3, 64, 64)).view(args.chunk_length, args.batch_size, -1) 155 | 156 | # prepare Tensor to maintain states sequence and rnn hidden states sequence 157 | states = torch.zeros( 158 | args.chunk_length, args.batch_size, args.state_dim, device=device) 159 | rnn_hiddens = torch.zeros( 160 | args.chunk_length, args.batch_size, args.rnn_hidden_dim, device=device) 161 | 162 | # initialize state and rnn hidden state with 0 vector 163 | state = torch.zeros(args.batch_size, args.state_dim, device=device) 164 | rnn_hidden = torch.zeros(args.batch_size, args.rnn_hidden_dim, device=device) 165 | 166 | # compute state and rnn hidden sequences and kl loss 167 | kl_loss = 0 168 | for l in range(args.chunk_length-1): 169 | next_state_prior, next_state_posterior, rnn_hidden = \ 170 | rssm(state, actions[l], rnn_hidden, embedded_observations[l+1]) 171 | state = next_state_posterior.rsample() 172 | states[l+1] = state 173 | rnn_hiddens[l+1] = rnn_hidden 174 | kl = kl_divergence(next_state_prior, next_state_posterior).sum(dim=1) 175 | kl_loss += kl.clamp(min=args.free_nats).mean() 176 | kl_loss /= (args.chunk_length - 1) 177 | 178 | # states[0] and rnn_hiddens[0] are always 0 and have no information 179 | states = states[1:] 180 | rnn_hiddens = rnn_hiddens[1:] 181 | 182 | # compute reconstructed observations and predicted rewards 183 | flatten_states = states.view(-1, args.state_dim) 184 | flatten_rnn_hiddens = rnn_hiddens.view(-1, args.rnn_hidden_dim) 185 | recon_observations = obs_model(flatten_states, flatten_rnn_hiddens).view( 186 | args.chunk_length-1, args.batch_size, 3, 64, 64) 187 | predicted_rewards = reward_model(flatten_states, flatten_rnn_hiddens).view( 188 | args.chunk_length-1, args.batch_size, 1) 189 | 190 | # compute loss for observation and reward 191 | obs_loss = 0.5 * mse_loss( 192 | recon_observations, observations[1:], reduction='none').mean([0, 1]).sum() 193 | reward_loss = 0.5 * mse_loss(predicted_rewards, rewards[:-1]) 194 | 195 | # add all losses and update model parameters with gradient descent 196 | model_loss = kl_loss + obs_loss + reward_loss 197 | model_optimizer.zero_grad() 198 | model_loss.backward() 199 | clip_grad_norm_(model_params, args.clip_grad_norm) 200 | model_optimizer.step() 201 | 202 | # ---------------------------------------------- 203 | # update value_model and action_model 204 | # 
---------------------------------------------- 205 | # detach gradient because Dreamer doesn't update model with actor-critic loss 206 | flatten_states = flatten_states.detach() 207 | flatten_rnn_hiddens = flatten_rnn_hiddens.detach() 208 | 209 | # prepare tensor to maintain imaginated trajectory's states and rnn_hiddens 210 | imaginated_states = torch.zeros(args.imagination_horizon + 1, 211 | *flatten_states.shape, 212 | device=flatten_states.device) 213 | imaginated_rnn_hiddens = torch.zeros(args.imagination_horizon + 1, 214 | *flatten_rnn_hiddens.shape, 215 | device=flatten_rnn_hiddens.device) 216 | imaginated_states[0] = flatten_states 217 | imaginated_rnn_hiddens[0] = flatten_rnn_hiddens 218 | 219 | # compute imaginated trajectory using action from action_model 220 | for h in range(1, args.imagination_horizon + 1): 221 | actions = action_model(flatten_states, flatten_rnn_hiddens) 222 | flatten_states_prior, flatten_rnn_hiddens = rssm.prior(flatten_states, 223 | actions, 224 | flatten_rnn_hiddens) 225 | flatten_states = flatten_states_prior.rsample() 226 | imaginated_states[h] = flatten_states 227 | imaginated_rnn_hiddens[h] = flatten_rnn_hiddens 228 | 229 | # compute rewards and values corresponding to imaginated states and rnn_hiddens 230 | flatten_imaginated_states = imaginated_states.view(-1, args.state_dim) 231 | flatten_imaginated_rnn_hiddens = imaginated_rnn_hiddens.view(-1, args.rnn_hidden_dim) 232 | imaginated_rewards = \ 233 | reward_model(flatten_imaginated_states, 234 | flatten_imaginated_rnn_hiddens).view(args.imagination_horizon + 1, -1) 235 | imaginated_values = \ 236 | value_model(flatten_imaginated_states, 237 | flatten_imaginated_rnn_hiddens).view(args.imagination_horizon + 1, -1) 238 | # compute lambda target 239 | lambda_target_values = lambda_target(imaginated_rewards, imaginated_values, 240 | args.gamma, args.lambda_) 241 | 242 | # update_value model 243 | value_loss = 0.5 * mse_loss(imaginated_values, lambda_target_values.detach()) 244 | value_optimizer.zero_grad() 245 | value_loss.backward(retain_graph=True) 246 | clip_grad_norm_(value_model.parameters(), args.clip_grad_norm) 247 | value_optimizer.step() 248 | 249 | # update action model (multiply -1 for gradient ascent) 250 | action_loss = -1 * (lambda_target_values.mean()) 251 | action_optimizer.zero_grad() 252 | action_loss.backward() 253 | clip_grad_norm_(action_model.parameters(), args.clip_grad_norm) 254 | action_optimizer.step() 255 | 256 | # print losses and add to tensorboard 257 | print('update_step: %3d model loss: %.5f, kl_loss: %.5f, ' 258 | 'obs_loss: %.5f, reward_loss: %.5f, ' 259 | 'value_loss: %.5f action_loss: %.5f' 260 | % (update_step + 1, model_loss.item(), kl_loss.item(), 261 | obs_loss.item(), reward_loss.item(), 262 | value_loss.item(), action_loss.item())) 263 | total_update_step = episode * args.collect_interval + update_step 264 | writer.add_scalar('model loss', model_loss.item(), total_update_step) 265 | writer.add_scalar('kl loss', kl_loss.item(), total_update_step) 266 | writer.add_scalar('obs loss', obs_loss.item(), total_update_step) 267 | writer.add_scalar('reward loss', reward_loss.item(), total_update_step) 268 | writer.add_scalar('value loss', value_loss.item(), total_update_step) 269 | writer.add_scalar('action loss', action_loss.item(), total_update_step) 270 | 271 | print('elasped time for update: %.2fs' % (time.time() - start)) 272 | 273 | # ---------------------------------------------- 274 | # evaluation without exploration noise 275 | # 
----------------------------------------------
276 |         if (episode + 1) % args.test_interval == 0:
277 |             policy = Agent(encoder, rssm, action_model)
278 |             start = time.time()
279 |             obs = env.reset()
280 |             done = False
281 |             total_reward = 0
282 |             while not done:
283 |                 action = policy(obs, training=False)
284 |                 obs, reward, done, _ = env.step(action)
285 |                 total_reward += reward
286 | 
287 |             writer.add_scalar('total reward at test', total_reward, episode)
288 |             print('Total test reward at episode [%4d/%4d] is %f' %
289 |                   (episode+1, args.all_episodes, total_reward))
290 |             print('elapsed time for test: %.2fs' % (time.time() - start))
291 | 
292 |     # save learned model parameters
293 |     torch.save(encoder.state_dict(), os.path.join(log_dir, 'encoder.pth'))
294 |     torch.save(rssm.state_dict(), os.path.join(log_dir, 'rssm.pth'))
295 |     torch.save(obs_model.state_dict(), os.path.join(log_dir, 'obs_model.pth'))
296 |     torch.save(reward_model.state_dict(), os.path.join(log_dir, 'reward_model.pth'))
297 |     torch.save(value_model.state_dict(), os.path.join(log_dir, 'value_model.pth'))
298 |     torch.save(action_model.state_dict(), os.path.join(log_dir, 'action_model.pth'))
299 |     writer.close()
300 | 
301 | if __name__ == '__main__':
302 |     main()
303 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | 
5 | class ReplayBuffer(object):
6 |     """
7 |     Replay buffer for training with RNN
8 |     """
9 |     def __init__(self, capacity, observation_shape, action_dim):
10 |         self.capacity = capacity
11 | 
12 |         self.observations = np.zeros((capacity, *observation_shape), dtype=np.uint8)
13 |         self.actions = np.zeros((capacity, action_dim), dtype=np.float32)
14 |         self.rewards = np.zeros((capacity, 1), dtype=np.float32)
15 |         self.done = np.zeros((capacity, 1), dtype=bool)
16 | 
17 |         self.index = 0
18 |         self.is_filled = False
19 | 
20 |     def push(self, observation, action, reward, done):
21 |         """
22 |         Add an experience to the replay buffer
23 |         NOTE: the observation should be converted to np.uint8 before being pushed
24 |         """
25 |         self.observations[self.index] = observation
26 |         self.actions[self.index] = action
27 |         self.rewards[self.index] = reward
28 |         self.done[self.index] = done
29 | 
30 |         if self.index == self.capacity - 1:
31 |             self.is_filled = True
32 |         self.index = (self.index + 1) % self.capacity
33 | 
34 |     def sample(self, batch_size, chunk_length):
35 |         """
36 |         Sample experiences from the replay buffer (almost) uniformly
37 |         The resulting arrays have shape (batch_size, chunk_length, ...)
38 |         and each sampled chunk is a consecutive sequence
39 |         NOTE: a chunk_length that is too long for the episode length will cause problems
40 |         """
41 |         episode_borders = np.where(self.done)[0]
42 |         sampled_indexes = []
43 |         for _ in range(batch_size):
44 |             cross_border = True
45 |             while cross_border:
46 |                 initial_index = np.random.randint(len(self) - chunk_length + 1)
47 |                 final_index = initial_index + chunk_length - 1
48 |                 cross_border = np.logical_and(initial_index <= episode_borders,
49 |                                               episode_borders < final_index).any()
50 |             sampled_indexes += list(range(initial_index, final_index + 1))
51 | 
52 |         sampled_observations = self.observations[sampled_indexes].reshape(
53 |             batch_size, chunk_length, *self.observations.shape[1:])
54 |         sampled_actions = self.actions[sampled_indexes].reshape(
55 |             batch_size, chunk_length, self.actions.shape[1])
56 |         sampled_rewards = self.rewards[sampled_indexes].reshape(
57
| batch_size, chunk_length, 1) 58 | sampled_done = self.done[sampled_indexes].reshape( 59 | batch_size, chunk_length, 1) 60 | return sampled_observations, sampled_actions, sampled_rewards, sampled_done 61 | 62 | def __len__(self): 63 | return self.capacity if self.is_filled else self.index 64 | 65 | 66 | def preprocess_obs(obs): 67 | """ 68 | conbert image from [0, 255] to [-0.5, 0.5] 69 | """ 70 | obs = obs.astype(np.float32) 71 | normalized_obs = obs / 255.0 - 0.5 72 | return normalized_obs 73 | 74 | 75 | def lambda_target(rewards, values, gamma, lambda_): 76 | """ 77 | Compute lambda target of value function 78 | rewards and values should be 2D-tensor and same size, 79 | and first-dimension means time step 80 | gamma is discount factor and lambda_ is weight to compute lambda target 81 | """ 82 | V_lambda = torch.zeros_like(rewards, device=rewards.device) 83 | 84 | H = rewards.shape[0] - 1 85 | V_n = torch.zeros_like(rewards, device=rewards.device) 86 | V_n[H] = values[H] 87 | for n in range(1, H+1): 88 | # compute n-step target 89 | # NOTE: If it hits the end, compromise with the largest possible n-step return 90 | V_n[:-n] = (gamma ** n) * values[n:] 91 | for k in range(1, n+1): 92 | if k == n: 93 | V_n[:-n] += (gamma ** (n-1)) * rewards[k:] 94 | else: 95 | V_n[:-n] += (gamma ** (k-1)) * rewards[k:-n+k] 96 | 97 | # add lambda_ weighted n-step target to compute lambda target 98 | if n == H: 99 | V_lambda += (lambda_ ** (H-1)) * V_n 100 | else: 101 | V_lambda += (1 - lambda_) * (lambda_ ** (n-1)) * V_n 102 | 103 | return V_lambda 104 | -------------------------------------------------------------------------------- /video_prediction.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import matplotlib.pyplot as plt 5 | from matplotlib import animation 6 | import numpy as np 7 | import torch 8 | from dm_control import suite 9 | from dm_control.suite.wrappers import pixels 10 | from agent import Agent 11 | from model import Encoder, RecurrentStateSpaceModel, ObservationModel, ActionModel 12 | from utils import preprocess_obs 13 | from wrappers import GymWrapper, RepeatAction 14 | 15 | 16 | def save_video_as_gif(frames): 17 | """ 18 | make video with given frames and save as "video_prediction.gif" 19 | """ 20 | plt.figure() 21 | patch = plt.imshow(frames[0]) 22 | plt.axis('off') 23 | 24 | def animate(i): 25 | patch.set_data(frames[i]) 26 | plt.title('Left: GT frame' + ' '*20 + 'Right: predicted frame \n Step %d' % (i)) 27 | 28 | anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=150) 29 | anim.save('video_prediction.gif', writer='imagemagick') 30 | 31 | 32 | def main(): 33 | parser = argparse.ArgumentParser(description='Open-loop video prediction with learned model') 34 | parser.add_argument('dir', type=str, help='log directory to load learned model') 35 | parser.add_argument('--length', type=int, default=50, 36 | help='the length of video prediction') 37 | parser.add_argument('--domain-name', type=str, default='cheetah') 38 | parser.add_argument('--task-name', type=str, default='run') 39 | parser.add_argument('-R', '--action-repeat', type=int, default=2) 40 | args = parser.parse_args() 41 | 42 | # define environment and apply wrapper 43 | env = suite.load(args.domain_name, args.task_name) 44 | env = pixels.Wrapper(env, render_kwargs={'height': 64, 45 | 'width': 64, 46 | 'camera_id': 0}) 47 | env = GymWrapper(env) 48 | env = RepeatAction(env, skip=args.action_repeat) 49 | 50 | 
# define models 51 | with open(os.path.join(args.dir, 'args.json'), 'r') as f: 52 | train_args = json.load(f) 53 | 54 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 55 | encoder = Encoder().to(device) 56 | rssm = RecurrentStateSpaceModel(train_args['state_dim'], 57 | env.action_space.shape[0], 58 | train_args['rnn_hidden_dim']).to(device) 59 | obs_model = ObservationModel(train_args['state_dim'], 60 | train_args['rnn_hidden_dim']).to(device) 61 | action_model = ActionModel(train_args['state_dim'], train_args['rnn_hidden_dim'], 62 | env.action_space.shape[0]).to(device) 63 | 64 | # load learned parameters 65 | encoder.load_state_dict(torch.load(os.path.join(args.dir, 'encoder.pth'))) 66 | rssm.load_state_dict(torch.load(os.path.join(args.dir, 'rssm.pth'))) 67 | obs_model.load_state_dict(torch.load(os.path.join(args.dir, 'obs_model.pth'))) 68 | action_model.load_state_dict(torch.load(os.path.join(args.dir, 'action_model.pth'))) 69 | 70 | # define agent 71 | policy = Agent(encoder, rssm, action_model) 72 | 73 | # open-loop video prediction 74 | # select starting point of open-loop prediction randomly 75 | starting_point = torch.randint(1000 // args.action_repeat - args.length, (1,)).item() 76 | # interact in environment until starting point and charge context in policy.rnn_hidden 77 | obs = env.reset() 78 | for _ in range(starting_point): 79 | action = policy(obs) 80 | obs, _, _, _ = env.step(action) 81 | 82 | # preprocess observatin and embed by encoder 83 | preprocessed_obs = preprocess_obs(obs) 84 | preprocessed_obs = torch.as_tensor(preprocessed_obs, device=device) 85 | preprocessed_obs = preprocessed_obs.transpose(1, 2).transpose(0, 1).unsqueeze(0) 86 | with torch.no_grad(): 87 | embedded_obs = encoder(preprocessed_obs) 88 | 89 | # compute state using embedded observation 90 | # NOTE: after this, state is updated only using prior, 91 | # it means model doesn't see observation 92 | rnn_hidden = policy.rnn_hidden 93 | state = rssm.posterior(rnn_hidden, embedded_obs).sample() 94 | frame = np.zeros((64, 128, 3)) 95 | frames = [] 96 | for _ in range(args.length): 97 | # action is selected same as training time (closed-loop) 98 | action = policy(obs) 99 | obs, _, _, _ = env.step(action) 100 | 101 | # update state and reconstruct observation with same action 102 | action = torch.as_tensor(action, device=device).unsqueeze(0) 103 | with torch.no_grad(): 104 | state_prior, rnn_hidden = rssm.prior(state, action, rnn_hidden) 105 | state = state_prior.sample() 106 | predicted_obs = obs_model(state, rnn_hidden) 107 | 108 | # arrange GT frame and predicted frame in parallel 109 | frame[:, :64, :] = preprocess_obs(obs) 110 | frame[:, 64:, :] = predicted_obs.squeeze().transpose(0, 1).transpose(1, 2).cpu().numpy() 111 | frames.append((frame + 0.5).clip(0.0, 1.0)) 112 | 113 | save_video_as_gif(frames) 114 | 115 | if __name__ == '__main__': 116 | main() 117 | -------------------------------------------------------------------------------- /viewer.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple OpenCV based viewer for rendering 3 | from https://github.com/zuoxingdong/dm2gym 4 | """ 5 | import uuid 6 | import cv2 7 | 8 | 9 | class OpenCVImageViewer: 10 | """ 11 | A simple OpenCV highgui based dm_control image viewer 12 | This class is meant to be a drop-in replacement for 13 | `gym.envs.classic_control.rendering.SimpleImageViewer` 14 | """ 15 | def __init__(self, *, escape_to_exit=False): 16 | """ 17 | Construct the viewing 
window
18 |         """
19 |         self._escape_to_exit = escape_to_exit
20 |         self._window_name = str(uuid.uuid4())
21 |         cv2.namedWindow(self._window_name, cv2.WINDOW_AUTOSIZE)
22 |         self._isopen = True
23 | 
24 |     def __del__(self):
25 |         """
26 |         Close the window
27 |         """
28 |         cv2.destroyWindow(self._window_name)
29 |         self._isopen = False
30 | 
31 |     def imshow(self, img):
32 |         """
33 |         Show an image
34 |         """
35 |         cv2.imshow(self._window_name, img[:, :, [2, 1, 0]])
36 |         if cv2.waitKey(1) in [27] and self._escape_to_exit:
37 |             exit()
38 | 
39 |     @property
40 |     def isopen(self):
41 |         """
42 |         Is the window open?
43 |         """
44 |         return self._isopen
45 | 
46 |     def close(self):
47 |         pass
48 | 
--------------------------------------------------------------------------------
/wrappers.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | from viewer import OpenCVImageViewer
4 | 
5 | 
6 | class GymWrapper(object):
7 |     """
8 |     Gym-style interface wrapper for a dm_control env wrapped by pixels.Wrapper
9 |     """
10 |     metadata = {'render.modes': ['human', 'rgb_array']}
11 |     reward_range = (-np.inf, np.inf)
12 | 
13 |     def __init__(self, env):
14 |         self._env = env
15 |         self._viewer = None
16 | 
17 |     def __getattr__(self, name):
18 |         return getattr(self._env, name)
19 | 
20 |     @property
21 |     def observation_space(self):
22 |         obs_spec = self._env.observation_spec()
23 |         return gym.spaces.Box(0, 255, obs_spec['pixels'].shape, dtype=np.uint8)
24 | 
25 |     @property
26 |     def action_space(self):
27 |         action_spec = self._env.action_spec()
28 |         return gym.spaces.Box(action_spec.minimum, action_spec.maximum, dtype=np.float32)
29 | 
30 |     def step(self, action):
31 |         time_step = self._env.step(action)
32 |         obs = time_step.observation['pixels']
33 |         reward = time_step.reward or 0
34 |         done = time_step.last()
35 |         info = {'discount': time_step.discount}
36 |         return obs, reward, done, info
37 | 
38 |     def reset(self):
39 |         time_step = self._env.reset()
40 |         obs = time_step.observation['pixels']
41 |         return obs
42 | 
43 |     def render(self, mode='human', **kwargs):
44 |         if not kwargs:
45 |             kwargs = self._env._render_kwargs
46 | 
47 |         img = self._env.physics.render(**kwargs)
48 |         if mode == 'rgb_array':
49 |             return img
50 |         elif mode == 'human':
51 |             if self._viewer is None:
52 |                 self._viewer = OpenCVImageViewer()
53 |             self._viewer.imshow(img)
54 |             return self._viewer.isopen
55 |         else:
56 |             raise NotImplementedError
57 | 
58 | 
59 | class RepeatAction(gym.Wrapper):
60 |     """
61 |     Action repeat wrapper that repeats the same action for `skip` steps
62 |     """
63 |     def __init__(self, env, skip=4):
64 |         gym.Wrapper.__init__(self, env)
65 |         self._skip = skip
66 | 
67 |     def reset(self):
68 |         return self.env.reset()
69 | 
70 |     def step(self, action):
71 |         total_reward = 0.0
72 |         for _ in range(self._skip):
73 |             obs, reward, done, info = self.env.step(action)
74 |             total_reward += reward
75 |             if done:
76 |                 break
77 |         return obs, total_reward, done, info
78 | 
--------------------------------------------------------------------------------
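The listings above fix a shape contract that is easy to break when modifying the architecture: the Encoder maps (3, 64, 64) observations to 1024-dimensional embeddings, the ObservationModel maps a state and rnn_hidden back to (3, 64, 64), and lambda_target expects (horizon + 1, batch) reward and value tensors. The sketch below is not part of the repository; it is a minimal sanity check under the assumption that the requirements are installed and that it is run from the repository root so that model.py and utils.py are importable. The state and RNN dimensions mirror the defaults in train.py, while the file name, batch size, and action dimension are illustrative.

```python
# sanity_check.py -- hypothetical helper script, not part of the original repository.
# Quick shape checks for the models in model.py and the lambda_target helper in utils.py.
import torch

from model import (Encoder, RecurrentStateSpaceModel, ObservationModel,
                   RewardModel, ValueModel, ActionModel)
from utils import lambda_target

B, STATE_DIM, RNN_HIDDEN_DIM, ACTION_DIM = 4, 30, 200, 6  # illustrative sizes

encoder = Encoder()
rssm = RecurrentStateSpaceModel(STATE_DIM, ACTION_DIM, RNN_HIDDEN_DIM)
obs_model = ObservationModel(STATE_DIM, RNN_HIDDEN_DIM)
reward_model = RewardModel(STATE_DIM, RNN_HIDDEN_DIM)
value_model = ValueModel(STATE_DIM, RNN_HIDDEN_DIM)
action_model = ActionModel(STATE_DIM, RNN_HIDDEN_DIM, ACTION_DIM)

with torch.no_grad():
    # observations are (B, 3, 64, 64) after preprocess_obs and channel-first transposing
    obs = torch.rand(B, 3, 64, 64) - 0.5
    embedded = encoder(obs)
    assert embedded.shape == (B, 1024)      # 256 x 2 x 2 feature maps, flattened

    rnn_hidden = torch.zeros(B, RNN_HIDDEN_DIM)
    state = rssm.posterior(rnn_hidden, embedded).sample()
    assert state.shape == (B, STATE_DIM)

    action = action_model(state, rnn_hidden, training=False)
    assert action.shape == (B, ACTION_DIM)

    prior, rnn_hidden = rssm.prior(state, action, rnn_hidden)
    assert prior.mean.shape == (B, STATE_DIM)
    assert rnn_hidden.shape == (B, RNN_HIDDEN_DIM)

    # the decoder mirrors the encoder back to a 64 x 64 reconstruction
    assert obs_model(state, rnn_hidden).shape == (B, 3, 64, 64)
    assert reward_model(state, rnn_hidden).shape == (B, 1)
    assert value_model(state, rnn_hidden).shape == (B, 1)

    # lambda_target takes (horizon + 1, batch) rewards/values and returns the same shape
    horizon = 15
    rewards = torch.rand(horizon + 1, B)
    values = torch.rand(horizon + 1, B)
    assert lambda_target(rewards, values, gamma=0.99, lambda_=0.95).shape == (horizon + 1, B)

print('all shape checks passed')
```

Running a check like this after changing the convolutional layers or the latent sizes is a cheap way to catch shape regressions before launching a full training run.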