├── wmlib ├── envs │ ├── __init__.py │ ├── utils.py │ ├── metaworld.py │ └── robodesk.py ├── __init__.py ├── nets │ ├── va_net │ │ ├── __init__.py │ │ ├── action_encoder.py │ │ └── va_net.py │ ├── dynamics │ │ ├── __init__.py │ │ ├── rssm.py │ │ └── base.py │ ├── decoder │ │ ├── __init__.py │ │ ├── base.py │ │ ├── resnet.py │ │ ├── plaincnn.py │ │ └── deco_resnet.py │ ├── __init__.py │ └── encoder │ │ ├── __init__.py │ │ ├── plaincnn.py │ │ ├── resnet.py │ │ ├── va_resnet.py │ │ ├── base.py │ │ ├── ctx_resnet.py │ │ └── deco_resnet.py ├── datasets │ ├── __init__.py │ ├── video │ │ ├── __init__.py │ │ ├── somethingv2.py │ │ ├── utils.py │ │ └── somethingv2_flow.py │ └── utils.py ├── train │ ├── __init__.py │ ├── finetuner.py │ └── pretrainer.py ├── core │ ├── __init__.py │ └── driver.py ├── utils │ ├── __init__.py │ ├── other.py │ ├── seed.py │ ├── counter.py │ ├── when.py │ ├── timer.py │ ├── flags.py │ └── logger.py └── agents │ ├── __init__.py │ ├── random_agent.py │ ├── expl.py │ ├── actor_critic.py │ └── base.py ├── assets └── framework.png ├── LICENSE ├── data └── somethingv2 │ ├── extract_frames.py │ └── process_somethingv2.py ├── examples ├── train_apv_pretraining.py ├── train_prelar_wo_cl_pretraining.py ├── train_prelar_pretraining.py ├── train_dreamerv2.py ├── train_apv_finetuning.py ├── train_naive_finetune.py └── train_prelar_finetuning.py ├── .gitignore ├── configs ├── apv_pretraining.yaml ├── prelar_pretraining.yaml ├── prelar_wo_al_pretraining.yaml ├── dreamerv2.yaml ├── naive_finetuning.yaml ├── apv_finetuning.yaml └── prelar_finetuning.yaml └── environment.yaml /wmlib/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import make_env, make_async_env -------------------------------------------------------------------------------- /wmlib/__init__.py: -------------------------------------------------------------------------------- 1 | DEBUG_METRICS = True 2 | 3 | from .core import * 4 | -------------------------------------------------------------------------------- /wmlib/nets/va_net/__init__.py: -------------------------------------------------------------------------------- 1 | from .va_net import * 2 | from .action_encoder import * -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VIPL-EPP/PreLAR/HEAD/assets/framework.png -------------------------------------------------------------------------------- /wmlib/nets/dynamics/__init__.py: -------------------------------------------------------------------------------- 1 | from .rssm import * 2 | from .va_rssm import VAEnsembleRSSM -------------------------------------------------------------------------------- /wmlib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # from .video import * 2 | from .utils import make_action_free_dataset 3 | -------------------------------------------------------------------------------- /wmlib/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer 2 | from .pretrainer import Pretrainer 3 | from .finetuner import Finetuner -------------------------------------------------------------------------------- /wmlib/nets/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .base 
import * 2 | from .plaincnn import * 3 | from .resnet import * 4 | from .deco_resnet import * -------------------------------------------------------------------------------- /wmlib/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .dists import * 2 | from .driver import * 3 | from .other import * 4 | from .replay import * 5 | from .torch_utils import * -------------------------------------------------------------------------------- /wmlib/nets/__init__.py: -------------------------------------------------------------------------------- 1 | from .decoder import * 2 | from .encoder import * 3 | from .dynamics import * 4 | from .modules import * 5 | from .va_net import * -------------------------------------------------------------------------------- /wmlib/datasets/video/__init__.py: -------------------------------------------------------------------------------- 1 | from .somethingv2 import SomethingV2 2 | from .utils import Mixture, DummyReplay 3 | from .somethingv2_flow import SomethingV2Flow 4 | -------------------------------------------------------------------------------- /wmlib/nets/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .plaincnn import * 3 | from .resnet import * 4 | from .ctx_resnet import * 5 | from .deco_resnet import * 6 | from .va_resnet import * -------------------------------------------------------------------------------- /wmlib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .seed import set_seed 2 | from .flags import Flags 3 | from .counter import Counter 4 | from .timer import Timer 5 | from .when import Every, Once, Until 6 | from .config import Config 7 | from .logger import Logger, TerminalOutput,JSONLOutput, WandbOutput 8 | 9 | 10 | -------------------------------------------------------------------------------- /wmlib/utils/other.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def make_dir(dir_path): 5 | try: 6 | os.mkdir(dir_path) 7 | except OSError: 8 | pass 9 | return dir_path 10 | 11 | 12 | def snapshot_src(src, target, exclude_from): 13 | make_dir(target) 14 | os.system(f"rsync -rv --exclude-from={exclude_from} {src} {target}") -------------------------------------------------------------------------------- /wmlib/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .random_agent import RandomAgent 2 | from .base import BaseAgent 3 | from .dreamerv2 import DreamerV2 4 | from .apv_pretrain import APV_Pretrain 5 | from .apv_finetune import APV_Finetune 6 | from .naive_finetune import Naive_Finetune 7 | from .prelar_pretrain import PreLARPretrain 8 | from .prelar_finetune import PreLARFinetune 9 | from .prelar_wo_cl_pretrain import PreLARwoCLPretrain -------------------------------------------------------------------------------- /wmlib/utils/seed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | 5 | def set_seed(seed): 6 | """Set the seed for all random number generators.""" 7 | torch.manual_seed(seed) 8 | torch.cuda.manual_seed_all(seed) 9 | torch.cuda.manual_seed(seed) 10 | np.random.seed(seed) 11 | random.seed(seed) 12 | torch.backends.cudnn.deterministic = True # no apparent impact on speed 13 | torch.backends.cudnn.benchmark = True 
# faster, increases memory though.., no impact on seed 14 | return seed -------------------------------------------------------------------------------- /wmlib/utils/counter.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | 4 | @functools.total_ordering 5 | class Counter: 6 | 7 | def __init__(self, initial=0): 8 | self.value = initial 9 | 10 | def __int__(self): 11 | return int(self.value) 12 | 13 | def __str__(self): 14 | return str(int(self)) 15 | 16 | def __eq__(self, other): 17 | return int(self) == other 18 | 19 | def __ne__(self, other): 20 | return int(self) != other 21 | 22 | def __lt__(self, other): 23 | return int(self) < other 24 | 25 | def __add__(self, other): 26 | return int(self) + other 27 | 28 | def increment(self, amount=1): 29 | self.value += amount 30 | -------------------------------------------------------------------------------- /wmlib/agents/random_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributions as tdist 3 | from ..core import dists 4 | 5 | 6 | class RandomAgent: 7 | 8 | def __init__(self, act_space, logprob=False): 9 | self.act_space = act_space['action'] 10 | self.logprob = logprob 11 | if hasattr(self.act_space, 'n'): 12 | self._dist = dists.OneHotDist(torch.zeros(self.act_space.n)) 13 | else: 14 | dist = tdist.Uniform(torch.tensor(self.act_space.low), torch.tensor(self.act_space.high)) 15 | self._dist = tdist.Independent(dist, 1) 16 | 17 | def __call__(self, obs, state=None, mode=None): 18 | action = self._dist.sample((len(obs['is_first']),)) 19 | output = {'action': action} 20 | if self.logprob: 21 | output['logprob'] = self._dist.log_prob(action) 22 | return output, None 23 | -------------------------------------------------------------------------------- /wmlib/utils/when.py: -------------------------------------------------------------------------------- 1 | class Every: 2 | 3 | def __init__(self, every): 4 | self._every = every 5 | self._last = None 6 | 7 | def __call__(self, step): 8 | step = int(step) 9 | if not self._every: 10 | return False 11 | if self._last is None: 12 | self._last = step 13 | return True 14 | if step >= self._last + self._every: 15 | self._last += self._every 16 | return True 17 | return False 18 | 19 | 20 | class Once: 21 | 22 | def __init__(self): 23 | self._once = True 24 | 25 | def __call__(self): 26 | if self._once: 27 | self._once = False 28 | return True 29 | return False 30 | 31 | 32 | class Until: 33 | 34 | def __init__(self, until): 35 | self._until = until 36 | 37 | def __call__(self, step): 38 | step = int(step) 39 | if not self._until: 40 | return True 41 | return step < self._until 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 VIPL @ Institute of Computing Technology, Chinese Academy of Sciences 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this 
permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /wmlib/envs/utils.py: -------------------------------------------------------------------------------- 1 | from .metaworld import MetaWorld 2 | from .robodesk import RoboDesk 3 | from .wrappers import NormalizeAction, TimeLimit, Async 4 | import functools 5 | import os 6 | 7 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 8 | os.environ['MUJOCO_GL'] = 'egl' 9 | 10 | def make_env(config:dict, mode): 11 | suite, task = config.task.split("_", 1) 12 | 13 | if suite == "metaworld": 14 | task = "-".join(task.split("_")) 15 | env = MetaWorld( 16 | task, 17 | config.seed, 18 | config.action_repeat, 19 | config.render_size, 20 | config.camera, 21 | ) 22 | env = NormalizeAction(env) 23 | elif suite == "robodesk": 24 | env = RoboDesk(task, config.seed, config.action_repeat, config.render_size, 25 | evaluate=mode=='eval') 26 | env = NormalizeAction(env) 27 | else: 28 | raise NotImplementedError(suite) 29 | env = TimeLimit(env, config.time_limit) 30 | return env 31 | 32 | make_async_env = lambda config, mode: Async(functools.partial(make_env, config, mode), config.envs_parallel) -------------------------------------------------------------------------------- /data/somethingv2/extract_frames.py: -------------------------------------------------------------------------------- 1 | """ 2 | Credits to https://github.com/zhoubolei/TRN-pytorch/blob/master/extract_frames.py 3 | """ 4 | 5 | import os 6 | import threading 7 | 8 | NUM_THREADS = 100 9 | VIDEO_ROOT = 'dataset/Something-Something/20bn-something-something-v2' # Downloaded webm videos 10 | IMAGE_SIZE = 240 #64 11 | FRAME_ROOT = f'dataset/Something-Something/20bn-something-something-v2-frames-{IMAGE_SIZE}' # Directory for extracted frames 12 | 13 | 14 | def split(l, n): 15 | """Yield successive n-sized chunks from l.""" 16 | for i in range(0, len(l), n): 17 | yield l[i:i + n] 18 | 19 | 20 | def extract(video, tmpl='%06d.jpg'): 21 | os.system(f'ffmpeg -i {VIDEO_ROOT}/{video} -vf scale={IMAGE_SIZE}:{IMAGE_SIZE} ' 22 | f'{FRAME_ROOT}/{video[:-5]}/{tmpl}') 23 | 24 | 25 | def target(video_list): 26 | for video in video_list: 27 | targt_path = os.path.join(FRAME_ROOT, video[:-5]) 28 | if not os.path.exists(targt_path): 29 | os.makedirs(targt_path) 30 | extract(video) 31 | 32 | 33 | if not os.path.exists(VIDEO_ROOT): 34 | raise ValueError('Please download videos and set VIDEO_ROOT variable.') 35 | if not os.path.exists(FRAME_ROOT): 36 | os.makedirs(FRAME_ROOT) 37 | 38 | video_list = os.listdir(VIDEO_ROOT) 39 | splits = list(split(video_list, NUM_THREADS)) 40 | 41 | threads = [] 42 | for i, split in enumerate(splits): 43 | thread = threading.Thread(target=target, args=(split,)) 44 | thread.start() 45 | threads.append(thread) 46 | 47 | for thread in threads: 48 | thread.join() 49 | -------------------------------------------------------------------------------- /wmlib/utils/timer.py: 
-------------------------------------------------------------------------------- 1 | import collections 2 | import contextlib 3 | import time 4 | import numpy as np 5 | 6 | 7 | class Timer: 8 | 9 | def __init__(self): 10 | self._indurs = collections.defaultdict(list) 11 | self._outdurs = collections.defaultdict(list) 12 | self._start_times = {} 13 | self._end_times = {} 14 | 15 | @contextlib.contextmanager 16 | def section(self, name): 17 | self.start(name) 18 | yield 19 | self.end(name) 20 | 21 | def wrap(self, function, name): 22 | def wrapped(*args, **kwargs): 23 | with self.section(name): 24 | return function(*args, **kwargs) 25 | return wrapped 26 | 27 | def start(self, name): 28 | now = time.time() 29 | self._start_times[name] = now 30 | if name in self._end_times: 31 | last = self._end_times[name] 32 | self._outdurs[name].append(now - last) 33 | 34 | def end(self, name): 35 | now = time.time() 36 | self._end_times[name] = now 37 | self._indurs[name].append(now - self._start_times[name]) 38 | 39 | def result(self): 40 | metrics = {} 41 | for key in self._indurs: 42 | indurs = self._indurs[key] 43 | outdurs = self._outdurs[key] 44 | metrics[f'timer_count_{key}'] = len(indurs) 45 | metrics[f'timer_inside_{key}'] = np.sum(indurs) 46 | metrics[f'timer_outside_{key}'] = np.sum(outdurs) 47 | indurs.clear() 48 | outdurs.clear() 49 | return metrics 50 | -------------------------------------------------------------------------------- /wmlib/datasets/utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from .video import SomethingV2, DummyReplay, SomethingV2Flow 3 | from ..core import ReplayWithoutAction 4 | from ..core import Replay 5 | 6 | def make_action_free_dataset(dataset_type,root_paths:dict=dict(),video_index_files:dict=dict(),segment_len=50,manual_labels=False,**kwargs): 7 | video_index_files = {dataset_type:video_index_files} if isinstance(video_index_files,str) else video_index_files 8 | root_paths = {dataset_type:root_paths} if isinstance(root_paths,str) else root_paths 9 | if dataset_type == 'replay': 10 | train_replay = ReplayWithoutAction(root_paths['replay'],**kwargs) # rquire load_directory, seed 11 | elif dataset_type == 'something': 12 | somethingv2_dataset = SomethingV2( 13 | root_path=root_paths['something'], 14 | list_file=f'data/somethingv2/{video_index_files["something"]}.txt', 15 | segment_len=segment_len, 16 | manual_labels=manual_labels, 17 | ) 18 | train_replay = DummyReplay(somethingv2_dataset) 19 | elif dataset_type == 'something_flow': 20 | somethingv2_dataset = SomethingV2Flow( 21 | root_path=root_paths['something'], 22 | list_file=f'data/somethingv2/{video_index_files["something"]}.txt', 23 | segment_len=segment_len, 24 | manual_labels=manual_labels, 25 | ) 26 | train_replay = DummyReplay(somethingv2_dataset) 27 | elif dataset_type == 'rlbench': 28 | train_replay = Replay(Path(root_paths['rlbench']), **kwargs) 29 | else: 30 | raise NotImplementedError 31 | return train_replay -------------------------------------------------------------------------------- /wmlib/nets/encoder/plaincnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops.layers.torch import Rearrange 4 | 5 | from .base import BaseEncoder 6 | from ..modules import get_act_module, NormLayer 7 | 8 | class PlainCNNEncoder(BaseEncoder): 9 | 10 | def __init__( 11 | self, 12 | shapes, 13 | cnn_keys=r".*", 14 | mlp_keys=r".*", 15 | act="elu", 16 
| norm="none", 17 | cnn_depth=48, 18 | cnn_kernels=(4, 4, 4, 4), 19 | mlp_layers=[400, 400, 400, 400], 20 | **dummy_kwargs, 21 | ): 22 | super().__init__(shapes, cnn_keys, mlp_keys, mlp_layers) 23 | 24 | self._act_module = get_act_module(act) 25 | # self._act = get_act(act) 26 | self._norm = norm 27 | self._cnn_depth = cnn_depth 28 | self._cnn_kernels = cnn_kernels 29 | 30 | h, w, c = self.shapes[self.cnn_keys[0]] # raw image shape 31 | self._cnn_nn = nn.Sequential() 32 | for i, kernel in enumerate(self._cnn_kernels): 33 | depth = 2 ** i * self._cnn_depth 34 | input_channels = depth // 2 if i else c 35 | h, w = (h - kernel ) // 2 + 1, (w - kernel) // 2 + 1 # (h - k + 2p) // s + 1 36 | self._cnn_nn.add_module(f"conv{i}", nn.Conv2d(input_channels, depth, kernel, 2)) 37 | self._cnn_nn.add_module(f"convnorm{i}", NormLayer(self._norm, (depth, h, w))) 38 | self._cnn_nn.add_module(f"act{i}", self._act_module()) 39 | 40 | self._cnn_nn.add_module('flatten', Rearrange('b c h w -> b (c h w)')) 41 | 42 | def _cnn(self, data): 43 | x = torch.cat(list(data.values()), -1) 44 | x = x.to(memory_format=torch.channels_last) 45 | return self._cnn_nn(x) 46 | -------------------------------------------------------------------------------- /wmlib/nets/encoder/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops.layers.torch import Rearrange 4 | 5 | from .base import BaseEncoder 6 | from ..modules import get_act_module, ResidualStack 7 | 8 | class ResNetEncoder(BaseEncoder): 9 | 10 | def __init__( 11 | self, 12 | shapes, 13 | cnn_keys=r".*", 14 | mlp_keys=r".*", 15 | act="elu", 16 | cnn_depth=48, 17 | mlp_layers=[400, 400, 400, 400], 18 | res_layers=2, 19 | res_depth=3, 20 | res_norm='none', 21 | **dummy_kwargs, 22 | ): 23 | super().__init__(shapes, cnn_keys, mlp_keys, mlp_layers) 24 | self._act_module = get_act_module(act) 25 | self._cnn_depth = cnn_depth 26 | 27 | self._res_layers = res_layers 28 | self._res_depth = res_depth 29 | self._res_norm = res_norm 30 | 31 | h, w, c = self.shapes[self.cnn_keys[0]] # raw image shape 32 | self._cnn_net = nn.Sequential() 33 | self._cnn_net.add_module('convin', nn.Conv2d(c, self._cnn_depth, 3, 2, 1)) 34 | self._cnn_net.add_module('act', self._act_module()) 35 | for i in range(self._res_depth): 36 | depth = 2 ** i * self._cnn_depth 37 | input_channels = depth // 2 if i else self._cnn_depth 38 | self._cnn_net.add_module(f"res{i}", ResidualStack(input_channels, depth, 39 | self._res_layers, 40 | norm=self._res_norm)) 41 | self._cnn_net.add_module(f"pool{i}", nn.AvgPool2d(2, 2)) 42 | 43 | self._cnn_net.add_module('flatten', Rearrange('b c h w -> b (c h w)')) 44 | 45 | 46 | 47 | def _cnn(self, data): 48 | x = torch.cat(list(data.values()), -1) 49 | x = x.to(memory_format=torch.channels_last) 50 | return self._cnn_net(x) 51 | -------------------------------------------------------------------------------- /wmlib/train/finetuner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | from pathlib import Path 4 | 5 | from wmlib.core.replay import Replay 6 | from .trainer import Trainer 7 | from wmlib.utils import Counter, Logger 8 | from wmlib.agents import BaseAgent 9 | 10 | class Finetuner(Trainer): 11 | def __init__(self, config, agent: BaseAgent, train_replay: Replay, eval_replay: Replay, train_envs, eval_envs, step: Counter, logger: Logger) -> None: 12 | super().__init__(config, agent, train_replay, 
eval_replay, train_envs, eval_envs, step, logger) 13 | 14 | def load_agent(self,path:Path,iteration:int=0): 15 | if path.exists(): 16 | print(f'Load agent from checkpoint {path}.') 17 | self.agent.load_state_dict(torch.load(path)) 18 | else: 19 | load_logdir = Path(self.config.load_logdir).expanduser() 20 | if load_logdir != 'none': 21 | if 'af_rssm' in self.config.load_modules: 22 | print(self.agent.wm.af_rssm.load_state_dict(torch.load(load_logdir / 'rssm_variables.pt'), strict=self.config.load_strict)) 23 | print(f'Load af_rssm from checkpoint {load_logdir}/rssm_variables.pt.') 24 | if 'encoder' in self.config.load_modules: 25 | print(self.agent.wm.encoder.load_state_dict(torch.load( 26 | load_logdir / 'encoder_variables.pt'), strict=self.config.load_strict)) 27 | print(f'Load encoder from checkpoint {load_logdir}/encoder_variables.pt.') 28 | if 'decoder' in self.config.load_modules: 29 | print(self.agent.wm.heads['decoder'].load_state_dict(torch.load( 30 | load_logdir / 'decoder_variables.pt'), strict=self.config.load_strict)) 31 | print(f'Load decoder from checkpoint {load_logdir}/decoder_variables.pt.') 32 | print(f'Pretrain agent from scratch {iteration} iterations.') 33 | for _ in tqdm(range(iteration)): 34 | self.train_agent(self.next_batch(self.train_dataset)) -------------------------------------------------------------------------------- /wmlib/nets/encoder/va_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops.layers.torch import Rearrange 4 | 5 | from .base import BaseEncoder 6 | from ..modules import get_act_module, ResidualStack 7 | 8 | 9 | class VAResNetEncoder(BaseEncoder): 10 | 11 | def __init__( 12 | self, 13 | shapes, 14 | cnn_keys=r".*", 15 | mlp_keys=r".*", 16 | act="elu", 17 | cnn_depth=48, 18 | mlp_layers=[400, 400, 400, 400], 19 | res_layers=2, 20 | res_depth=3, 21 | res_norm='none', 22 | va_method='concate', 23 | **dummy_kwargs, 24 | ): 25 | super().__init__(shapes, cnn_keys, mlp_keys, mlp_layers) 26 | self._act_module = get_act_module(act) 27 | self._cnn_depth = cnn_depth 28 | 29 | self._res_layers = res_layers 30 | self._res_depth = res_depth 31 | self._res_norm = res_norm 32 | 33 | h, w, c = self.shapes[self.cnn_keys[0]] # raw image shape 34 | if va_method == 'concate': 35 | input_channels = c * 2 36 | elif va_method in ['diff', 'flow']: 37 | input_channels = c 38 | elif va_method == 'catdiff': 39 | input_channels = c * 3 40 | else: 41 | raise ValueError(f"Unknown va_method: {va_method}") 42 | self._cnn_net = nn.Sequential() 43 | self._cnn_net.add_module('convin', nn.Conv2d(input_channels, self._cnn_depth, 3, 2, 1)) 44 | self._cnn_net.add_module('act', self._act_module()) 45 | for i in range(self._res_depth): 46 | depth = 2 ** i * self._cnn_depth 47 | input_channels = depth // 2 if i else self._cnn_depth 48 | self._cnn_net.add_module(f"res{i}", ResidualStack(input_channels, depth, 49 | self._res_layers, 50 | norm=self._res_norm)) 51 | self._cnn_net.add_module(f"pool{i}", nn.AvgPool2d(2, 2)) 52 | 53 | self._cnn_net.add_module('flatten', Rearrange('b c h w -> b (c h w)')) 54 | 55 | def _cnn(self, data): 56 | x = torch.cat(list(data.values()), -1) 57 | x = x.to(memory_format=torch.channels_last) 58 | return self._cnn_net(x) 59 | -------------------------------------------------------------------------------- /wmlib/nets/encoder/base.py: -------------------------------------------------------------------------------- 1 | 2 | import re 3 | from abc import ABC, 
abstractmethod 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from ..modules import NormLayer 9 | 10 | 11 | class BaseEncoder(nn.Module, ABC): 12 | 13 | def __init__( 14 | self, 15 | shapes, 16 | cnn_keys=r".*", 17 | mlp_keys=r".*", 18 | mlp_layers=[400, 400, 400, 400], 19 | mlp_input_dim = None, 20 | ): 21 | super().__init__() 22 | self.shapes = shapes 23 | self.cnn_keys = [ 24 | k for k, v in shapes.items() if re.match(cnn_keys, k) and len(v) == 3 25 | ] 26 | self.mlp_keys = [ 27 | k for k, v in shapes.items() if re.match(mlp_keys, k) and len(v) == 1 28 | ] 29 | print("Encoder CNN inputs:", list(self.cnn_keys)) 30 | print("Encoder MLP inputs:", list(self.mlp_keys)) 31 | self._mlp_layers = mlp_layers 32 | 33 | if self.mlp_keys: 34 | assert mlp_input_dim is not None 35 | inpup_dim = mlp_input_dim 36 | self._mlp_nn = nn.Sequential() 37 | for i, width in enumerate(self._mlp_layers): 38 | self._mlp_nn.add_module(f"dense{i}", nn.Linear(inpup_dim, width)) 39 | self._mlp_nn.add_module(f"densenorm{i}", NormLayer(self._norm, width)) 40 | self._mlp_nn.add_module(f"act{i}", self._act_module()) 41 | inpup_dim = width 42 | 43 | def forward(self, data): 44 | key, shape = list(self.shapes.items())[0] 45 | batch_dims = data[key].shape[:-len(shape)] 46 | data = { 47 | k: torch.reshape(v, (-1,) + tuple(v.shape)[len(batch_dims):]) 48 | for k, v in data.items() 49 | } 50 | outputs = [] 51 | if self.cnn_keys: 52 | outputs.append(self._cnn({k: data[k] for k in self.cnn_keys})) 53 | if self.mlp_keys: 54 | outputs.append(self._mlp({k: data[k] for k in self.mlp_keys})) 55 | output = torch.cat(outputs, -1) 56 | return output.reshape(batch_dims + output.shape[1:]) 57 | 58 | @abstractmethod 59 | def _cnn(self, data): 60 | pass 61 | 62 | def _mlp(self, data): 63 | x = torch.cat(list(data.values()), -1) 64 | x = self._mlp_nn(x) 65 | return x 66 | -------------------------------------------------------------------------------- /data/somethingv2/process_somethingv2.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import os 4 | from tqdm import trange 5 | from pathlib import Path 6 | 7 | dataset_path = 'data/somethingv2/20bn-something-something-v2-frames-64' 8 | lebels_base_ori = 'dataset/Something-Something/labels' 9 | with open(Path(lebels_base_ori) / Path('labels.json'), 'r') as f: 10 | label = json.load(f) 11 | with open(Path(lebels_base_ori) / Path('train.json'), 'r') as f: 12 | train = json.load(f) 13 | with open(Path(lebels_base_ori) / Path('validation.json'), 'r') as f: 14 | val = json.load(f) 15 | with open(Path(lebels_base_ori) / Path('test.json'), 'r') as f: 16 | test = json.load(f) 17 | 18 | cnt = 0 19 | 20 | category = [k for k in label.keys()] 21 | 22 | train_txt = [] 23 | for i in trange(len(train), desc='train'): 24 | cur_index = train[i]['id'] 25 | cur_label = train[i]['template'] 26 | cur_label = cur_label.replace(']', '').replace('[', '') 27 | cur_id = label[cur_label] 28 | num_frames = len(os.listdir(os.path.join(dataset_path, cur_index))) 29 | if num_frames == 0: 30 | cnt += 1 31 | train_txt.append('%s %d %s' % (cur_index, num_frames, cur_id)) 32 | 33 | val_txt = [] 34 | for i in trange(len(val), desc='val'): 35 | cur_index = val[i]['id'] 36 | cur_label = val[i]['template'] 37 | cur_label = cur_label.replace(']', '').replace('[', '') 38 | cur_id = label[cur_label] 39 | num_frames = len(os.listdir(os.path.join(dataset_path, cur_index))) 40 | if num_frames == 0: 41 | cnt += 1 42 | val_txt.append('%s %d %s' % (cur_index, 
num_frames, cur_id)) 43 | 44 | # test_txt = [] 45 | # for i in trange(len(test), desc='test'): 46 | # cur_index = test[i]['id'] 47 | # cur_label = test[i]['template'] 48 | # cur_label = cur_label.replace(']', '').replace('[', '') 49 | # cur_id = label[cur_label] 50 | # num_frames = len(os.listdir(os.path.join(dataset_path, cur_index))) 51 | # if num_frames == 0: 52 | # cnt += 1 53 | # test_txt.append('%s %d %s' % (cur_index, num_frames, cur_id)) 54 | 55 | with open('train_video_folder.txt', 'w') as f: 56 | f.write('\n'.join(train_txt)) 57 | with open('val_video_folder.txt', 'w') as f: 58 | f.write('\n'.join(val_txt)) 59 | with open('train_val_video_folder.txt', 'w') as f: 60 | f.write('\n'.join(train_txt + val_txt)) 61 | # with open('test_folder.txt', 'w') as f: 62 | # f.write('\n'.join(test_txt)) 63 | 64 | with open('category.txt', 'w') as f: 65 | f.write('\n'.join(category)) 66 | 67 | print(cnt, 'empty') 68 | -------------------------------------------------------------------------------- /wmlib/nets/decoder/base.py: -------------------------------------------------------------------------------- 1 | import re 2 | from abc import ABC, abstractmethod 3 | 4 | import torch.nn as nn 5 | 6 | from ..modules import NormLayer, DistLayer 7 | from ... import core 8 | 9 | 10 | class BaseDecoder(nn.Module, ABC): 11 | 12 | def __init__( 13 | self, 14 | shapes, 15 | cnn_keys=r".*", 16 | mlp_keys=r".*", 17 | mlp_layers=[400, 400, 400, 400], 18 | mlp_input_dim = None, 19 | ): 20 | super().__init__() 21 | self._shapes = shapes 22 | self.cnn_keys = [ 23 | k for k, v in shapes.items() if re.match(cnn_keys, k) and len(v) == 3 24 | ] 25 | self.mlp_keys = [ 26 | k for k, v in shapes.items() if re.match(mlp_keys, k) and len(v) == 1 27 | ] 28 | print("Decoder CNN outputs:", list(self.cnn_keys)) 29 | print("Decoder MLP outputs:", list(self.mlp_keys)) 30 | 31 | self._mlp_layers = mlp_layers 32 | if self.mlp_keys: 33 | assert mlp_input_dim is not None 34 | self.init_mlp(mlp_input_dim) 35 | 36 | def forward(self, features): 37 | outputs = {} 38 | if self.cnn_keys: 39 | outputs.update(self._cnn(features)) 40 | if self.mlp_keys: 41 | outputs.update(self._mlp(features)) 42 | return outputs 43 | 44 | @abstractmethod 45 | def _cnn(self, features): 46 | pass 47 | 48 | def _mlp(self, features): 49 | shapes = {k: self._shapes[k] for k in self.mlp_keys} 50 | # x = features 51 | x = self._mlp_nn(features) 52 | dist = {} 53 | for key, shape in shapes.items(): 54 | dist[key] = self.heads[f'dense_{key}'](x) 55 | return dist 56 | # for i, width in enumerate(self._mlp_layers): 57 | # x = self.get(f"dense{i}", nn.Linear, x.shape[-1], width)(x) 58 | # x = self.get(f"densenorm{i}", NormLayer, self._norm, x.shape[-1:])(x) 59 | # x = self._act(x) 60 | # dists = {} 61 | # for key, shape in shapes.items(): 62 | # dists[key] = self.get(f"dense_{key}", DistLayer, shape)(x) 63 | # return dists 64 | 65 | def init_mlp(self,input_dim): 66 | self._mlp_nn = nn.Sequential() 67 | for i, width in enumerate(self._mlp_layers): 68 | self._mlp_nn.add_module(f"dense{i}", nn.Linear(input_dim, width)) 69 | self._mlp_nn.add_module(f"densenorm{i}", NormLayer(self._norm, width)) 70 | self._mlp_nn.add_module(f"act{i}", self._act_module()) 71 | input_dim = width 72 | self.heads = nn.ModuleList() 73 | shapes = {k: self._shapes[k] for k in self.mlp_keys} 74 | for key, shape in shapes.items(): 75 | self.heads.add_module(f"dense_{key}", DistLayer(shape)) 76 | # dists[key] = self.get(f"dense_{key}", DistLayer, shape)(x) 77 | 78 | 
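# Editor's note — a minimal consumption sketch, not part of the original file.
# The decoder heads above return distribution objects, so a world-model
# reconstruction loss is typically the negative log-likelihood of the observed
# data under them. `decoder`, `features`, and `data` are placeholder names.
#
#   dists = decoder(features)   # {key: distribution over data[key]}
#   recon_loss = sum(-dist.log_prob(data[key]).mean() for key, dist in dists.items())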
-------------------------------------------------------------------------------- /wmlib/datasets/video/somethingv2.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | from .utils import VideoRecord 9 | from pathlib import Path 10 | 11 | # NOTE: you can manually select videos with specific labels for training here, by default we use all videos 12 | maunally_selected_labels = { 13 | "93": "Pushing something from left to right", 14 | "94": "Pushing something from right to left", 15 | } 16 | 17 | 18 | class SomethingV2(Dataset): 19 | def __init__(self, root_path, list_file, segment_len=50, image_tmpl='{:06d}.jpg', manual_labels=False): 20 | self.root_path = Path(root_path).expanduser() 21 | self.list_file = list_file 22 | self.segment_len = segment_len 23 | self.image_tmpl = image_tmpl 24 | 25 | self._parse_list(self.segment_len, maunally_selected_labels if manual_labels else None) 26 | 27 | def _parse_list(self, minlen, selected_labels=None): 28 | # check the frame number is large >segment_len: 29 | # usually it is [video_id, num_frames, class_idx] 30 | tmp = [x.strip().split(' ') for x in open(self.list_file)] 31 | tmp = [item for item in tmp if int(item[1]) >= minlen and ( 32 | (selected_labels is None) or (item[2] in selected_labels.keys()))] 33 | self.video_list = [VideoRecord(item) for item in tmp] 34 | print('video number:%d' % (len(self.video_list))) 35 | 36 | @property 37 | def total_steps(self): 38 | return sum([record.num_frames for record in self.video_list]) 39 | 40 | def _load_image(self, directory, idx): 41 | # TODO: cache 42 | # image = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB') 43 | image = Image.open(self.root_path / directory / self.image_tmpl.format(idx)).convert('RGB') 44 | return np.array(image) 45 | 46 | def _sample_index(self, record): 47 | return np.random.randint(0, record.num_frames - self.segment_len + 1) +1 # start from 1 48 | 49 | def get(self, record, ind): 50 | images = [] 51 | p = ind 52 | for i in range(self.segment_len): 53 | seg_imgs = self._load_image(record.path, p) 54 | images.append(seg_imgs) 55 | if p < record.num_frames: 56 | p += 1 57 | # images = self.transform(images) 58 | return np.array(images) 59 | 60 | def __getitem__(self, index): 61 | record = self.video_list[index] 62 | # check this is a legit video folder 63 | while not (self.root_path / record.path / self.image_tmpl.format(1)).exists(): 64 | print(self.root_path / record.path / self.image_tmpl.format(1)) 65 | # while not os.path.exists(os.path.join(self.root_path, record.path, self.image_tmpl.format(1))): 66 | # print(os.path.join(self.root_path, record.path, self.image_tmpl.format(1))) 67 | index = np.random.randint(len(self.video_list)) 68 | record = self.video_list[index] 69 | 70 | segment_index = self._sample_index(record) 71 | segment = self.get(record, segment_index) 72 | return segment 73 | 74 | def __len__(self): 75 | return len(self.video_list) 76 | -------------------------------------------------------------------------------- /wmlib/datasets/video/utils.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader, IterableDataset 2 | 3 | import numpy as np 4 | from ...core import seed_worker 5 | 6 | 7 | class VideoRecord(object): 8 | def __init__(self, row): 9 | self._data = row 10 | 11 
| @property 12 | def path(self): 13 | return self._data[0] 14 | 15 | @property 16 | def num_frames(self): 17 | return int(self._data[1]) 18 | 19 | @property 20 | def label(self): 21 | return int(self._data[2]) 22 | 23 | def __str__(self): 24 | return str(self._data) 25 | 26 | 27 | # a mixture of datasets 28 | class Mixture(Dataset): 29 | 30 | def __init__(self, datasets, weights=None): 31 | self.datasets = datasets 32 | self.weights = weights 33 | self.lengths = [len(d) for d in datasets] 34 | self.total_length = sum(self.lengths) 35 | 36 | def __getitem__(self, index): 37 | if self.weights is None: 38 | dataset_index = np.random.randint(len(self.datasets)) 39 | else: 40 | dataset_index = np.random.choice(len(self.datasets), p=self.weights) 41 | return self.datasets[dataset_index][index % self.lengths[dataset_index]] 42 | 43 | def __len__(self): 44 | return self.total_length 45 | 46 | @property 47 | def total_steps(self): 48 | return sum([d.total_steps for d in self.datasets]) 49 | 50 | 51 | class DummyReplay: # A wrapper make datasets behave like replay buffers 52 | 53 | def __init__(self, video_dataset) -> None: 54 | self.video_dataset = video_dataset 55 | 56 | def _generate_chunks(self, length): 57 | while True: 58 | ind = np.random.randint(len(self.video_dataset)) 59 | image = self.video_dataset[ind] 60 | action = np.zeros((image.shape[0], 1), dtype=np.float32) 61 | is_first = np.zeros((image.shape[0]), dtype=bool) 62 | is_first[0] = True 63 | chunk = { 64 | 'image': image, 65 | 'action': action, 66 | 'is_first': is_first, 67 | } 68 | # T,H,W,C -> T,C,H,W 69 | if len(chunk['image'].shape) == 4: 70 | chunk['image'] = chunk['image'].transpose(0, 3, 1, 2) 71 | yield chunk 72 | 73 | def dataset(self, batch, length, pin_memory=True, num_workers=8, **kwargs): 74 | generator = lambda: self._generate_chunks(length) 75 | 76 | class ReplayDataset(IterableDataset): 77 | def __iter__(self): 78 | return generator() 79 | 80 | dataset = ReplayDataset() 81 | dataset = DataLoader( 82 | dataset, 83 | batch, 84 | pin_memory=pin_memory, 85 | drop_last=True, 86 | worker_init_fn=seed_worker, 87 | num_workers=num_workers, 88 | **kwargs 89 | ) 90 | return dataset 91 | 92 | @property 93 | def stats(self): 94 | return { 95 | "total_steps": 0, 96 | "total_episodes": 0, 97 | "loaded_steps": self.video_dataset.total_steps, 98 | "loaded_episodes": len(self.video_dataset), 99 | } 100 | -------------------------------------------------------------------------------- /wmlib/envs/metaworld.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import gym 4 | import numpy as np 5 | 6 | 7 | class MetaWorld: 8 | 9 | def __init__(self, name, seed=None, action_repeat=1, size=(64, 64), camera=None, use_gripper=False): 10 | import metaworld 11 | from metaworld.envs import ( 12 | ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE, 13 | ALL_V2_ENVIRONMENTS_GOAL_HIDDEN, 14 | ) 15 | 16 | os.environ["MUJOCO_GL"] = "egl" 17 | 18 | task = f"{name}-v2-goal-observable" 19 | env_cls = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[task] 20 | self._env = env_cls(seed=seed) 21 | self._env._freeze_rand_vec = False 22 | self._size = size 23 | self._action_repeat = action_repeat 24 | self._use_gripper = use_gripper 25 | 26 | self._camera = camera 27 | 28 | @property 29 | def obs_space(self): 30 | spaces = { 31 | "image": gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8), 32 | "reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), 33 | "is_first": gym.spaces.Box(0, 1, (), dtype=bool), 34 | 
"is_last": gym.spaces.Box(0, 1, (), dtype=bool), 35 | "is_terminal": gym.spaces.Box(0, 1, (), dtype=bool), 36 | "state": self._env.observation_space, 37 | "success": gym.spaces.Box(0, 1, (), dtype=bool), 38 | } 39 | if self._use_gripper: 40 | spaces["gripper_image"] = gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8) 41 | return spaces 42 | 43 | @property 44 | def act_space(self): 45 | action = self._env.action_space 46 | return {"action": action} 47 | 48 | def step(self, action): 49 | assert np.isfinite(action["action"]).all(), action["action"] 50 | reward = 0.0 51 | for _ in range(self._action_repeat): 52 | state, rew, done, info = self._env.step(action["action"]) 53 | success = float(info["success"]) 54 | reward += rew or 0.0 55 | if done or success == 1.0: 56 | break 57 | assert success in [0.0, 1.0] 58 | obs = { 59 | "reward": reward, 60 | "is_first": False, 61 | "is_last": False, # will be handled by timelimit wrapper 62 | "is_terminal": False, # will be handled by per_episode function 63 | "image": self._env.sim.render( 64 | *self._size, mode="offscreen", camera_name=self._camera 65 | ), 66 | "state": state, 67 | "success": success, 68 | } 69 | if self._use_gripper: 70 | obs["gripper_image"] = self._env.sim.render( 71 | *self._size, mode="offscreen", camera_name="behindGripper" 72 | ) 73 | return obs 74 | 75 | def reset(self): 76 | if self._camera == "corner2": 77 | self._env.model.cam_pos[2][:] = [0.75, 0.075, 0.7] 78 | state = self._env.reset() 79 | obs = { 80 | "reward": 0.0, 81 | "is_first": True, 82 | "is_last": False, 83 | "is_terminal": False, 84 | "image": self._env.sim.render( 85 | *self._size, mode="offscreen", camera_name=self._camera 86 | ), 87 | "state": state, 88 | "success": False, 89 | } 90 | if self._use_gripper: 91 | obs["gripper_image"] = self._env.sim.render( 92 | *self._size, mode="offscreen", camera_name="behindGripper" 93 | ) 94 | return obs 95 | 96 | def close(self): 97 | ... 
-------------------------------------------------------------------------------- /examples/train_apv_pretraining.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import functools 3 | import logging 4 | import os 5 | import pathlib 6 | import re 7 | import sys 8 | import warnings 9 | from pathlib import Path 10 | 11 | 12 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 13 | logging.getLogger().setLevel("ERROR") 14 | warnings.filterwarnings("ignore", ".*box bound precision lowered.*") 15 | 16 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 17 | os.environ['MUJOCO_GL'] = 'egl' 18 | 19 | sys.path.append(str(pathlib.Path(__file__).parent)) 20 | sys.path.append(str(pathlib.Path(__file__).parent.parent)) 21 | 22 | import numpy as np 23 | import ruamel.yaml as yaml_package 24 | yaml = yaml_package.YAML(typ='safe', pure=True) 25 | import torch 26 | import random 27 | 28 | import wmlib 29 | import wmlib.envs as envs 30 | import wmlib.agents as agents 31 | import wmlib.utils as utils 32 | import wmlib.datasets as datasets 33 | import wmlib.train as train 34 | 35 | 36 | def main(): 37 | 38 | configs = yaml.load( 39 | (pathlib.Path(sys.argv[0]).parent.parent / "configs" / "apv_pretraining.yaml").read_text() 40 | ) 41 | parsed, remaining = utils.Flags(configs=["defaults"]).parse(known_only=True) 42 | config = utils.Config(configs["defaults"]) 43 | for name in parsed.configs: 44 | config = config.update(configs[name]) 45 | config = utils.Flags(config).parse(remaining) 46 | 47 | logdir = pathlib.Path(config.logdir).expanduser() # expand the user's home directory, e.g. ~/logs to /home/user/logs 48 | load_logdir = pathlib.Path(config.load_logdir).expanduser() 49 | load_model_dir = pathlib.Path(config.load_model_dir).expanduser() 50 | logdir.mkdir(parents=True, exist_ok=True) 51 | config.save(logdir / "config.yaml") 52 | print(config, "\n") 53 | print("Logdir", logdir) 54 | print("Loading Logdir", load_logdir) 55 | 56 | assert torch.cuda.is_available(), 'No GPU found.' 
57 | assert config.precision in (16, 32), config.precision 58 | if config.precision == 16: 59 | print("setting fp16") 60 | 61 | device = torch.device(config.device if torch.cuda.is_available() else "cpu") 62 | 63 | if device != "cpu": 64 | torch.set_num_threads(1) 65 | 66 | # reproducibility 67 | utils.set_seed(config.seed) 68 | train_replay = datasets.make_action_free_dataset(config['dataset_type'],config['video_dirs'],config['video_lists'],config['replay']['minlen'],config['manual_labels'],seed=config.seed,**config.replay) 69 | eval_replay = None 70 | if config.eval_video_list != 'none': 71 | eval_replay = datasets.make_action_free_dataset(config['dataset_type'],config['video_dirs'],config['eval_video_list'],config['replay']['minlen'],config['manual_labels'],seed=config.seed,**config.replay) 72 | 73 | step = utils.Counter(train_replay.stats["total_steps"]) 74 | wandb_config = dict(config.wandb) 75 | wandb_config['name']= f'{wandb_config["name"]}-{config["dataset_type"]}-seed{config.seed}' 76 | step = utils.Counter(train_replay.stats["total_steps"]) 77 | outputs = [ 78 | utils.TerminalOutput(), 79 | utils.JSONLOutput(logdir), 80 | utils.WandbOutput(**wandb_config,config=dict(config)) 81 | ] 82 | logger = utils.Logger(step, outputs, multiplier=config.action_repeat) 83 | 84 | print("Create envs.") 85 | env = envs.make_env(config, 'train') 86 | act_space, obs_space = env.act_space, env.obs_space 87 | 88 | agent = agents.APV_Pretrain(config, obs_space, act_space, step) 89 | pretrainer = train.Pretrainer(config,agent,train_replay,eval_replay,step,logger) 90 | pretrainer.run(config.steps) 91 | pretrainer.save_agent(logdir) 92 | env.close() 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /wmlib/nets/decoder/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops.layers.torch import Rearrange 4 | from einops import rearrange, unpack 5 | 6 | from .base import BaseDecoder 7 | from ..modules import ResidualStack 8 | from ... 
import core 9 | 10 | 11 | class ResNetDecoder(BaseDecoder): 12 | 13 | def __init__( 14 | self, 15 | shapes, 16 | cnn_keys=r".*", 17 | mlp_keys=r".*", 18 | cnn_depth=48, 19 | cnn_input_dim = 2048, 20 | mlp_layers=[400, 400, 400, 400], 21 | res_layers=2, 22 | res_depth=3, 23 | res_norm='none', 24 | **dummy_kwargs, 25 | ): 26 | super().__init__(shapes, cnn_keys, mlp_keys, mlp_layers) 27 | 28 | self._cnn_depth = cnn_depth 29 | self._res_layers = res_layers 30 | self._res_depth = res_depth 31 | self._res_norm = res_norm 32 | self._cnn_input_dim = cnn_input_dim 33 | 34 | # L = self._res_depth 35 | hw = 64 // 2**(self._res_depth + 1) 36 | cnn_out_channels = {k: self._shapes[k][-1] for k in self.cnn_keys} 37 | self.convin = nn.Sequential(nn.Linear(self._cnn_input_dim, hw * hw * (2**(self._res_depth - 1)) * self._cnn_depth),Rearrange('b t (c h w)-> (b t) c h w',h=hw,w=hw)) 38 | self._cnn_nn = nn.Sequential() 39 | for i in range(self._res_depth): 40 | depth = depth // 2 if i else int((2**(self._res_depth - 1)) * self._cnn_depth) 41 | self._cnn_nn.add_module(f"unpool{i}", nn.UpsamplingNearest2d(scale_factor=2)) 42 | self._cnn_nn.add_module(f"res{i}", ResidualStack(depth, depth//2, 43 | self._res_layers, 44 | norm=self._res_norm, dec=True)) 45 | self.convout = nn.ConvTranspose2d(depth//2, sum(cnn_out_channels.values()), 3, 2, 1, output_padding=1) 46 | self._cnn_out_ps = [[out_channel] for out_channel in cnn_out_channels.values()] 47 | 48 | 49 | def _cnn(self, features): 50 | x = self.convin(features).to(memory_format=torch.channels_last) 51 | x = self._cnn_nn(x) 52 | x = self.convout(x) 53 | x = rearrange(x,'(b t) c h w -> b t c h w',b=features.shape[0]) 54 | # means = torch.split(x, list(self._cnn_out_channels.values()), 2) 55 | means = unpack(x, self._cnn_out_ps, 'b t * h w ') 56 | dists = { 57 | key: core.dists.Independent(core.dists.MSE(mean), 3) 58 | for key, mean in zip(self.cnn_keys, means) 59 | } 60 | return dists 61 | 62 | 63 | # channels = {k: self._shapes[k][-1] for k in self.cnn_keys} 64 | 65 | # L = self._res_depth 66 | # hw = 64 // 2**(self._res_depth + 1) 67 | # x = self.get("convin", nn.Linear, features.shape[-1], hw * hw * (2**(L - 1)) * self._cnn_depth)(features) 68 | # x = torch.reshape(x, [-1, (2**(L - 1)) * self._cnn_depth, hw, hw]).to(memory_format=torch.channels_last) 69 | # for i in range(L): 70 | # x = self.get(f"unpool{i}", nn.UpsamplingNearest2d, scale_factor=2)(x) 71 | # depth = x.shape[1] 72 | # x = self.get(f"res{i}", ResidualStack, depth, depth // 2, 73 | # self._res_layers, norm=self._res_norm, dec=True)(x) 74 | 75 | # depth = sum(channels.values()) 76 | # x = self.get(f"convout", nn.ConvTranspose2d, x.shape[1], depth, 3, 2, 1, output_padding=1)(x) 77 | 78 | # x = x.reshape(features.shape[:-1] + x.shape[1:]) 79 | # means = torch.split(x, list(channels.values()), 2) 80 | # dists = { 81 | # key: core.dists.Independent(core.dists.MSE(mean), 3) 82 | # for (key, shape), mean in zip(channels.items(), means) 83 | # } 84 | # return dists 85 | -------------------------------------------------------------------------------- /wmlib/core/driver.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class Driver: 6 | 7 | def __init__(self, envs, device, precision = 32, **kwargs): 8 | self._envs = envs 9 | self._device = device 10 | self._precision = precision 11 | self._kwargs = kwargs 12 | self._on_steps = [] 13 | self._on_resets = [] 14 | self._on_episodes = [] 15 | self._act_spaces = [env.act_space 
for env in envs] 16 | self._dtype = torch.float16 if precision == 16 else torch.float32 17 | self.reset() 18 | 19 | def on_step(self, callback): 20 | self._on_steps.append(callback) 21 | 22 | def on_reset(self, callback): 23 | self._on_resets.append(callback) 24 | 25 | def on_episode(self, callback): 26 | self._on_episodes.append(callback) 27 | 28 | def reset(self): 29 | self._obs = [None] * len(self._envs) 30 | self._eps = [None] * len(self._envs) 31 | self._state = None 32 | 33 | def __call__(self, policy, steps=0, episodes=0): 34 | step, episode = 0, 0 35 | while step < steps or episode < episodes: 36 | # 1. reset check 37 | obs = { 38 | i: self._envs[i].reset() 39 | for i, ob in enumerate(self._obs) if ob is None or ob['is_last']} 40 | for i, ob in obs.items(): 41 | assert not callable(ob) 42 | # self._obs[i] = ob() if callable(ob) else ob 43 | self._obs[i] = ob 44 | act = {k: np.zeros(v.shape) for k, v in self._act_spaces[i].items()} 45 | trans = {k: self._convert(v) for k, v in {**self._obs[i], **act}.items()} 46 | [fn(trans, worker=i, **self._kwargs) for fn in self._on_resets] 47 | self._eps[i] = [trans] 48 | 49 | # 2. observe 50 | obs = {k: torch.from_numpy(np.stack([o[k] for o in self._obs])).float() for k in self._obs[0]} # convert before sending 51 | if len(obs['image'].shape) == 4: 52 | obs['image'] = obs['image'].permute(0, 3, 1, 2) 53 | 54 | # this is a hack to make it work with the current policy 55 | obs = {k: v.to(device=self._device, dtype=self._dtype) for k, v in obs.items()} 56 | 57 | # 3. policy 58 | actions, self._state = policy(obs, self._state, **self._kwargs) 59 | 60 | # 4. step 61 | actions = [{k: np.array(actions[k][i]) for k in actions} for i in range(len(self._envs))] 62 | assert len(actions) == len(self._envs) 63 | obs = [e.step(a) for e, a in zip(self._envs, actions)] 64 | # obs = [ob() if callable(ob) else ob for ob in obs] 65 | for i, (act, ob) in enumerate(zip(actions, obs)): 66 | # ob = _ob() if callable(_ob) else _ob 67 | assert not callable(ob) 68 | self._obs[i] = ob 69 | trans = {k: self._convert(v) for k, v in {**ob, **act}.items()} 70 | [fn(trans, worker=i, **self._kwargs) for fn in self._on_steps] 71 | self._eps[i].append(trans) 72 | step += 1 73 | if ob['is_last']: 74 | ep = self._eps[i] 75 | ep = {k: self._convert([t[k] for t in ep]) for k in ep[0]} 76 | [fn(ep, **self._kwargs) for fn in self._on_episodes] 77 | episode += 1 78 | # self._obs = obs 79 | 80 | def _convert(self, value): 81 | value = np.array(value) 82 | if np.issubdtype(value.dtype, np.floating): 83 | return value.astype(np.float32) 84 | elif np.issubdtype(value.dtype, np.signedinteger): 85 | return value.astype(np.int32) 86 | elif np.issubdtype(value.dtype, np.uint8): 87 | return value.astype(np.uint8) 88 | return value 89 | -------------------------------------------------------------------------------- /examples/train_prelar_wo_cl_pretraining.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import functools 3 | import logging 4 | import os 5 | import pathlib 6 | import re 7 | import sys 8 | import warnings 9 | from pathlib import Path 10 | 11 | 12 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 13 | logging.getLogger().setLevel("ERROR") 14 | warnings.filterwarnings("ignore", ".*box bound precision lowered.*") 15 | 16 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 17 | os.environ['MUJOCO_GL'] = 'egl' 18 | 19 | sys.path.append(str(pathlib.Path(__file__).parent)) 20 | sys.path.append(str(pathlib.Path(__file__).parent.parent)) 21 
| 22 | import numpy as np 23 | import ruamel.yaml as yaml_package 24 | yaml = yaml_package.YAML(typ='safe', pure=True) 25 | import torch 26 | import random 27 | 28 | import wmlib 29 | import wmlib.envs as envs 30 | import wmlib.agents as agents 31 | import wmlib.utils as utils 32 | import wmlib.datasets as datasets 33 | import wmlib.train as train 34 | 35 | 36 | def main(): 37 | 38 | configs = yaml.load( 39 | (pathlib.Path(sys.argv[0]).parent.parent / "configs" / "prelar_wo_al_pretraining.yaml").read_text() 40 | ) 41 | parsed, remaining = utils.Flags(configs=["defaults"]).parse(known_only=True) 42 | config = utils.Config(configs["defaults"]) 43 | for name in parsed.configs: 44 | config = config.update(configs[name]) 45 | config = utils.Flags(config).parse(remaining) 46 | 47 | logdir = pathlib.Path(config.logdir).expanduser() # expand the user's home directory, e.g. ~/logs to /home/user/logs 48 | load_logdir = pathlib.Path(config.load_logdir).expanduser() 49 | load_model_dir = pathlib.Path(config.load_model_dir).expanduser() 50 | logdir.mkdir(parents=True, exist_ok=True) 51 | config.save(logdir / "config.yaml") 52 | print(config, "\n") 53 | print("Logdir", logdir) 54 | print("Loading Logdir", load_logdir) 55 | 56 | 57 | assert torch.cuda.is_available(), 'No GPU found.' 58 | assert config.precision in (16, 32), config.precision 59 | if config.precision == 16: 60 | print("setting fp16") 61 | 62 | device = torch.device(config.device if torch.cuda.is_available() else "cpu") 63 | 64 | if device != "cpu": 65 | torch.set_num_threads(1) 66 | 67 | # reproducibility 68 | utils.set_seed(config.seed) 69 | train_replay = datasets.make_action_free_dataset(config['dataset_type'],config['video_dirs'],config['video_lists'],config['replay']['minlen'],config['manual_labels'],seed=config.seed,**config.replay) 70 | eval_replay = None 71 | if config.eval_video_list != 'none': 72 | eval_replay = datasets.make_action_free_dataset(config['dataset_type'],config['video_dirs'],config['eval_video_list'],config['replay']['minlen'],config['manual_labels'],seed=config.seed,**config.replay) 73 | 74 | step = utils.Counter(train_replay.stats["total_steps"]) 75 | wandb_config = dict(config.wandb) 76 | wandb_name = f"{wandb_config['name']}(d{config.vanet.stoch}x{config.vanet.discrete})" if config.vanet.discrete else wandb_config['name'] 77 | wandb_config['name']= f'{wandb_name}-{config["dataset_type"]}-seed{config.seed}' 78 | step = utils.Counter(train_replay.stats["total_steps"]) 79 | outputs = [ 80 | utils.TerminalOutput(), 81 | utils.JSONLOutput(logdir), 82 | utils.WandbOutput(**wandb_config,config=dict(config)) 83 | ] 84 | logger = utils.Logger(step, outputs, multiplier=config.action_repeat) 85 | 86 | print("Create envs.") 87 | env = envs.make_env(config, 'train') 88 | act_space, obs_space = env.act_space, env.obs_space 89 | 90 | agent = agents.PreLARwoCLPretrain(config, obs_space, act_space, step) 91 | pretrainer = train.Pretrainer(config,agent,train_replay,eval_replay,step,logger) 92 | pretrainer.run(config.steps) 93 | pretrainer.save_agent(logdir) 94 | env.close() 95 | 96 | 97 | if __name__ == "__main__": 98 | main() 99 | -------------------------------------------------------------------------------- /wmlib/nets/decoder/plaincnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops.layers.torch import Rearrange 4 | from einops import rearrange, unpack 5 | 6 | from .base import BaseDecoder 7 | from ..modules import get_act_module, 
NormLayer 8 | from ... import core 9 | 10 | 11 | class PlainCNNDecoder(BaseDecoder): 12 | 13 | def __init__( 14 | self, 15 | shapes, 16 | cnn_keys=r".*", 17 | mlp_keys=r".*", 18 | act="elu", 19 | norm="none", 20 | cnn_depth=48, 21 | cnn_input_dim = 2048, 22 | cnn_kernels=(4, 4, 4, 4), 23 | mlp_layers=[400, 400, 400, 400], 24 | **dummy_kwargs, 25 | ): 26 | super().__init__(shapes, cnn_keys, mlp_keys, mlp_layers) 27 | 28 | # self._act = get_act(act) 29 | self._act_module = get_act_module(act) 30 | self._norm = norm 31 | self._cnn_depth = cnn_depth 32 | self._cnn_kernels = cnn_kernels 33 | self._cnn_input_dim = cnn_input_dim 34 | cnn_out_channels = {k: self._shapes[k][-1] for k in self.cnn_keys} 35 | self.convin = nn.Sequential(nn.Linear(self._cnn_input_dim, 32 * self._cnn_depth),Rearrange('b t c -> (b t) c 1 1')) 36 | self._cnn_nn = nn.Sequential() 37 | h, w = 1, 1 38 | for i, kernel in enumerate(self._cnn_kernels): 39 | depth = int(2 ** (len(self._cnn_kernels) - i - 2) * self._cnn_depth) 40 | h, w = (h - 1) * 2 + kernel, (2 - 1) * 2 + kernel # (h−1)∗s+k−2∗p 41 | input_channel = depth * 2 if i else 32 * self._cnn_depth 42 | act_module, norm = self._act_module, self._norm 43 | if i == len(self._cnn_kernels) - 1: 44 | depth, act_module, norm = sum(cnn_out_channels.values()), get_act_module('none'), 'none' 45 | self._cnn_nn.add_module(f"conv{i}", nn.ConvTranspose2d(input_channel, depth, kernel, 2)) 46 | self._cnn_nn.add_module(f"convnorm{i}", NormLayer(norm, (depth, h, w))) 47 | self._cnn_nn.add_module(f"act{i}", act_module()) 48 | self._cnn_out_ps = [[out_channel] for out_channel in cnn_out_channels.values()] 49 | 50 | 51 | 52 | 53 | def _cnn(self, features): 54 | x = self.convin(features).to(memory_format=torch.channels_last) 55 | x = self._cnn_nn(x) 56 | x = rearrange(x,'(b t) c h w -> b t c h w',b=features.shape[0]) 57 | # means = torch.split(x, list(self._cnn_out_channels.values()), 2) 58 | means = unpack(x, self._cnn_out_ps, 'b t * h w ') 59 | 60 | dists = { 61 | key: core.dists.Independent(core.dists.MSE(mean), 3) 62 | for key, mean in zip(self.cnn_keys, means) 63 | } 64 | return dists 65 | 66 | 67 | 68 | 69 | # channels = {k: self._shapes[k][-1] for k in self.cnn_keys} 70 | # ConvT = nn.ConvTranspose2d 71 | # x = self.get("convin", nn.Linear, features.shape[-1], 32 * self._cnn_depth)(features) 72 | # x = torch.reshape(x, [-1, 32 * self._cnn_depth, 1, 1]).to(memory_format=torch.channels_last) 73 | 74 | # for i, kernel in enumerate(self._cnn_kernels): 75 | # depth = 2 ** (len(self._cnn_kernels) - i - 2) * self._cnn_depth 76 | # act, norm = self._act, self._norm 77 | # if i == len(self._cnn_kernels) - 1: 78 | # depth, act, norm = sum(channels.values()), get_act("none"), "none" 79 | # x = self.get(f"conv{i}", ConvT, x.shape[1], depth, kernel, 2)(x) 80 | # x = self.get(f"convnorm{i}", NormLayer, norm, x.shape[-3:])(x) 81 | # x = act(x) 82 | 83 | # x = x.reshape(features.shape[:-1] + x.shape[1:]) 84 | # means = torch.split(x, list(channels.values()), 2) 85 | # dists = { 86 | # key: core.dists.Independent(core.dists.MSE(mean), 3) 87 | # for (key, shape), mean in zip(channels.items(), means) 88 | # } 89 | # return dists 90 | -------------------------------------------------------------------------------- /wmlib/datasets/video/somethingv2_flow.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | from .utils import VideoRecord 9 | 
from pathlib import Path 10 | 11 | # NOTE: you can manually select videos with specific labels for training here, by default we use all videos 12 | maunally_selected_labels = { 13 | "93": "Pushing something from left to right", 14 | "94": "Pushing something from right to left", 15 | } 16 | 17 | 18 | class SomethingV2Flow(Dataset): 19 | def __init__(self, root_path, list_file, segment_len=50, image_tmpl='{:06d}.jpg', manual_labels=False,flow_root=None,flow_tmpl='{:06d}.png'): 20 | self.root_path = Path(root_path).expanduser() 21 | self.list_file = list_file 22 | self.segment_len = segment_len 23 | self.image_tmpl = image_tmpl 24 | self.flow_tmpl = flow_tmpl 25 | if flow_root is None: 26 | self.flow_root = self.root_path / '../20bn-something-something-v2-frames-64-flow-rgb' 27 | else: 28 | self.flow_root = Path(flow_root).expanduser() 29 | self._parse_list(self.segment_len, maunally_selected_labels if manual_labels else None) 30 | 31 | def _parse_list(self, minlen, selected_labels=None): 32 | # check the frame number is large >segment_len: 33 | # usually it is [video_id, num_frames, class_idx] 34 | tmp = [x.strip().split(' ') for x in open(self.list_file)] 35 | tmp = [item for item in tmp if int(item[1]) >= minlen and ( 36 | (selected_labels is None) or (item[2] in selected_labels.keys()))] 37 | self.video_list = [VideoRecord(item) for item in tmp] 38 | print('video number:%d' % (len(self.video_list))) 39 | 40 | @property 41 | def total_steps(self): 42 | return sum([record.num_frames for record in self.video_list]) 43 | 44 | def _load_image(self, directory, idx): 45 | # TODO: cache 46 | # image = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB') 47 | image = Image.open(self.root_path / directory / self.image_tmpl.format(idx)).convert('RGB') 48 | return np.array(image) 49 | 50 | def _load_flow(self, directory, idx): 51 | idx = idx + 1 if idx == 1 else idx # no 1st frame flow 52 | image = Image.open(self.flow_root / directory / self.flow_tmpl.format(idx)).convert('RGB') 53 | return np.array(image) 54 | 55 | def _sample_index(self, record): 56 | return np.random.randint(0, record.num_frames - self.segment_len + 1) +1 # start from 1 57 | 58 | def get(self, record, ind): 59 | images = [] 60 | p = ind 61 | for i in range(self.segment_len): 62 | seg_imgs = self._load_image(record.path, p) 63 | seg_flows = self._load_flow(record.path, p) 64 | images.append(np.concatenate([seg_imgs,seg_flows],axis=-1)) 65 | if p < record.num_frames: 66 | p += 1 67 | # images = self.transform(images) 68 | return np.array(images) 69 | 70 | def __getitem__(self, index): 71 | record = self.video_list[index] 72 | # check this is a legit video folder 73 | while not (self.root_path / record.path / self.image_tmpl.format(1)).exists(): 74 | print(self.root_path / record.path / self.image_tmpl.format(1)) 75 | # while not os.path.exists(os.path.join(self.root_path, record.path, self.image_tmpl.format(1))): 76 | # print(os.path.join(self.root_path, record.path, self.image_tmpl.format(1))) 77 | index = np.random.randint(len(self.video_list)) 78 | record = self.video_list[index] 79 | 80 | segment_index = self._sample_index(record) 81 | segment = self.get(record, segment_index) 82 | return segment 83 | 84 | def __len__(self): 85 | return len(self.video_list) 86 | -------------------------------------------------------------------------------- /wmlib/utils/flags.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | 4 | 5 | 
class Flags: 6 | 7 | def __init__(self, *args, **kwargs): 8 | from .config import Config 9 | 10 | self._config = Config(*args, **kwargs) 11 | 12 | def parse(self, argv=None, known_only=False, help_exists=None): 13 | if help_exists is None: 14 | help_exists = not known_only 15 | if argv is None: 16 | argv = sys.argv[1:] 17 | if "--help" in argv: 18 | print("\nHelp:") 19 | lines = str(self._config).split("\n")[2:] 20 | print("\n".join("--" + re.sub(r"[:,\[\]]", "", x) for x in lines)) 21 | help_exists and sys.exit() 22 | parsed = {} 23 | remaining = [] 24 | key = None 25 | vals = None 26 | for arg in argv: 27 | if arg.startswith("--"): 28 | if key: 29 | self._submit_entry(key, vals, parsed, remaining) 30 | if "=" in arg: 31 | key, val = arg.split("=", 1) 32 | vals = [val] 33 | else: 34 | key, vals = arg, [] 35 | else: 36 | if key: 37 | vals.append(arg) 38 | else: 39 | remaining.append(arg) 40 | self._submit_entry(key, vals, parsed, remaining) 41 | parsed = self._config.update(parsed) 42 | if known_only: 43 | return parsed, remaining 44 | else: 45 | for flag in remaining: 46 | if flag.startswith("--"): 47 | raise ValueError(f"Flag '{flag}' did not match any config keys.") 48 | assert not remaining, remaining 49 | return parsed 50 | 51 | def _submit_entry(self, key, vals, parsed, remaining): 52 | if not key and not vals: 53 | return 54 | if not key: 55 | vals = ", ".join(f"'{x}'" for x in vals) 56 | raise ValueError(f"Values {vals} were not preceeded by any flag.") 57 | name = key[len("--") :] 58 | if "=" in name: 59 | remaining.extend([key] + vals) 60 | return 61 | if self._config.IS_PATTERN.match(name): 62 | pattern = re.compile(name) 63 | keys = {k for k in self._config.flat if pattern.match(k)} 64 | elif name in self._config: 65 | keys = [name] 66 | else: 67 | keys = [] 68 | if not keys: 69 | remaining.extend([key] + vals) 70 | return 71 | if not vals: 72 | raise ValueError(f"Flag '{key}' was not followed by any values.") 73 | for key in keys: 74 | parsed[key] = self._parse_flag_value(self._config[key], vals, key) 75 | 76 | def _parse_flag_value(self, default, value, key): 77 | value = value if isinstance(value, (tuple, list)) else (value,) 78 | if isinstance(default, (tuple, list)): 79 | if len(value) == 1 and "," in value[0]: 80 | value = value[0].split(",") 81 | return tuple(self._parse_flag_value(default[0], [x], key) for x in value) 82 | assert len(value) == 1, value 83 | value = str(value[0]) 84 | if default is None: 85 | return value 86 | if isinstance(default, bool): 87 | try: 88 | return bool(["False", "True"].index(value)) 89 | except ValueError: 90 | message = f"Expected bool but got '{value}' for key '{key}'." 91 | raise TypeError(message) 92 | if isinstance(default, int): 93 | value = float(value) # Allow scientific notation for integers. 94 | if float(int(value)) != value: 95 | message = f"Expected int but got float '{value}' for key '{key}'." 
96 | raise TypeError(message) 97 | return int(value) 98 | return type(default)(value) 99 | -------------------------------------------------------------------------------- /wmlib/envs/robodesk.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import gym 4 | import numpy as np 5 | 6 | 7 | class RoboDesk: 8 | 9 | def __init__(self, name, seed=None, action_repeat=1, size=(64, 64), camera=None, use_gripper=False, evaluate=False): 10 | import robodesk 11 | 12 | 13 | os.environ["MUJOCO_GL"] = "egl" 14 | 15 | task = f"{name}" 16 | reward_type = 'success' if evaluate else 'dense' 17 | self._evaluate = evaluate 18 | self._env = robodesk.RoboDesk(task=task, reward=reward_type, action_repeat=action_repeat, episode_length=500, image_size=size[0]) 19 | self._env._freeze_rand_vec = False 20 | self._size = size 21 | self._action_repeat = action_repeat 22 | self._use_gripper = use_gripper 23 | 24 | self._camera = camera 25 | 26 | @property 27 | def obs_space(self): 28 | spaces = { 29 | "image": gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8), 30 | "reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), 31 | "is_first": gym.spaces.Box(0, 1, (), dtype=bool), 32 | "is_last": gym.spaces.Box(0, 1, (), dtype=bool), 33 | "is_terminal": gym.spaces.Box(0, 1, (), dtype=bool), 34 | "state": self._env.observation_space["qpos_robot"], 35 | "success": gym.spaces.Box(0, 1, (), dtype=bool), 36 | } 37 | # if self._use_gripper: 38 | # spaces["gripper_image"] = gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8) 39 | return spaces 40 | 41 | @property 42 | def act_space(self): 43 | action = self._env.action_space 44 | return {"action": action} 45 | 46 | def step(self, action): 47 | assert np.isfinite(action["action"]).all(), action["action"] 48 | reward = 0.0 49 | for _ in range(self._action_repeat): 50 | state, rew, done, info = self._env.step(action["action"]) 51 | if self._evaluate: 52 | success = rew #float(info["success"]) 53 | else: 54 | success = 0 55 | reward += rew or 0.0 56 | if done or success == 1.0: 57 | break 58 | assert success in [0.0, 1.0] 59 | obs = { 60 | "reward": reward, 61 | "is_first": False, 62 | "is_last": False, # will be handled by timelimit wrapper 63 | "is_terminal": False, # will be handled by per_episode function 64 | "image": state["image"], 65 | "state": state["qpos_robot"], 66 | "success": success, 67 | } 68 | # if self._use_gripper: 69 | # obs["gripper_image"] = self._env.sim.render( 70 | # *self._size, mode="offscreen", camera_name="behindGripper" 71 | # ) 72 | return obs 73 | 74 | def reset(self): 75 | # if self._camera == "corner2": 76 | # self._env.model.cam_pos[2][:] = [0.75, 0.075, 0.7] 77 | state = self._env.reset() 78 | obs = { 79 | "reward": 0.0, 80 | "is_first": True, 81 | "is_last": False, 82 | "is_terminal": False, 83 | "image": state["image"], 84 | "state": state["qpos_robot"], 85 | "success": False, 86 | } 87 | # if self._use_gripper: 88 | # obs["gripper_image"] = self._env.sim.render( 89 | # *self._size, mode="offscreen", camera_name="behindGripper" 90 | # ) 91 | return obs 92 | 93 | def close(self): 94 | ... 
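# NOTE: minimal rollout sketch (not used elsewhere in the repo) showing the dict-based
# interface of this wrapper: actions go in as {"action": ...} and step() returns an
# observation dict rather than the usual (obs, reward, done, info) tuple. The cap of
# 500 steps mirrors the episode_length passed to robodesk.RoboDesk above; "is_last" and
# "is_terminal" stay False here because the TimeLimit wrapper and the per_episode
# handling set them during training (the __main__ smoke test below never updates done,
# so it relies on manual interruption).
def _random_rollout(env, max_steps=500):
    obs = env.reset()
    total_reward = 0.0
    for _ in range(max_steps):
        action = {"action": env.act_space["action"].sample()}
        obs = env.step(action)
        total_reward += obs["reward"]
        if obs["is_last"] or obs["success"]:  # only set by wrappers / evaluate mode
            break
    return total_reward
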
95 | 96 | if __name__ == '__main__': 97 | robo_desk_env = RoboDesk('open_slide',0,1,(64,64),evaluate=False) 98 | obs = robo_desk_env.reset() 99 | done = False 100 | while not done: 101 | action = robo_desk_env.act_space["action"].sample() 102 | action = {"action":action} 103 | obs = robo_desk_env.step(action) -------------------------------------------------------------------------------- /examples/train_prelar_pretraining.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import functools 3 | import logging 4 | import os 5 | import pathlib 6 | import re 7 | import sys 8 | import warnings 9 | from pathlib import Path 10 | 11 | 12 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 13 | logging.getLogger().setLevel("ERROR") 14 | warnings.filterwarnings("ignore", ".*box bound precision lowered.*") 15 | 16 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 17 | os.environ['MUJOCO_GL'] = 'egl' 18 | 19 | sys.path.append(str(pathlib.Path(__file__).parent)) 20 | sys.path.append(str(pathlib.Path(__file__).parent.parent)) 21 | 22 | import numpy as np 23 | import ruamel.yaml as yaml_package 24 | yaml = yaml_package.YAML(typ='safe', pure=True) 25 | import torch 26 | import random 27 | 28 | import wmlib 29 | import wmlib.envs as envs 30 | import wmlib.agents as agents 31 | import wmlib.utils as utils 32 | import wmlib.datasets as datasets 33 | import wmlib.train as train 34 | 35 | 36 | def main(): 37 | 38 | configs = yaml.load( 39 | (pathlib.Path(sys.argv[0]).parent.parent / "configs" / "prelar_pretraining.yaml").read_text() 40 | ) 41 | parsed, remaining = utils.Flags(configs=["defaults"]).parse(known_only=True) 42 | config = utils.Config(configs["defaults"]) 43 | for name in parsed.configs: 44 | config = config.update(configs[name]) 45 | config = utils.Flags(config).parse(remaining) 46 | 47 | logdir = pathlib.Path(config.logdir).expanduser() # expand the user's home directory, e.g. ~/logs to /home/user/logs 48 | load_logdir = pathlib.Path(config.load_logdir).expanduser() 49 | load_model_dir = pathlib.Path(config.load_model_dir).expanduser() 50 | logdir.mkdir(parents=True, exist_ok=True) 51 | config.save(logdir / "config.yaml") 52 | print(config, "\n") 53 | print("Logdir", logdir) 54 | print("Loading Logdir", load_logdir) 55 | 56 | # utils.snapshot_src(".", logdir / "src", ".gitignore") 57 | 58 | assert torch.cuda.is_available(), 'No GPU found.' 
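    # NOTE: pretraining never steps an environment; the env created further below is only
    # queried for obs_space / act_space when building the agent, while all training data
    # comes from the action-free video replay (the Pretrainer receives only the replay
    # datasets). Any config key can be overridden on the command line as --key value,
    # nested keys in dotted form. Illustrative launch from the repo root, using the SSv2
    # frame directory from configs/prelar_pretraining.yaml:
    #   python examples/train_prelar_pretraining.py --configs something_pretrain \
    #       --video_dirs.something dataset/Something-Something/20bn-something-something-v2-frames-64 \
    #       --logdir ./dev/log/prelar_pretraining --seed 0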
59 | assert config.precision in (16, 32), config.precision 60 | if config.precision == 16: 61 | print("setting fp16") 62 | 63 | device = torch.device(config.device if torch.cuda.is_available() else "cpu") 64 | 65 | if device != "cpu": 66 | torch.set_num_threads(1) 67 | 68 | # reproducibility 69 | utils.set_seed(config.seed) 70 | train_replay = datasets.make_action_free_dataset(config['dataset_type'],config['video_dirs'],config['video_lists'],config['replay']['minlen'],config['manual_labels'],seed=config.seed,**config.replay) 71 | eval_replay = None 72 | if config.eval_video_list != 'none': 73 | eval_replay = datasets.make_action_free_dataset(config['dataset_type'],config['video_dirs'],config['eval_video_list'],config['replay']['minlen'],config['manual_labels'],seed=config.seed,**config.replay) 74 | 75 | step = utils.Counter(train_replay.stats["total_steps"]) 76 | wandb_config = dict(config.wandb) 77 | #wandb_config['name'] + '-seed' + str(config.seed) 78 | wandb_name = f"{wandb_config['name']}(d{config.vanet.stoch}x{config.vanet.discrete})" if config.vanet.discrete else wandb_config['name'] 79 | wandb_config['name']= f'{wandb_name}-{config["dataset_type"]}-seed{config.seed}' 80 | step = utils.Counter(train_replay.stats["total_steps"]) 81 | outputs = [ 82 | utils.TerminalOutput(), 83 | utils.JSONLOutput(logdir), 84 | utils.WandbOutput(**wandb_config,config=dict(config)) 85 | ] 86 | logger = utils.Logger(step, outputs, multiplier=config.action_repeat) 87 | 88 | print("Create envs.") 89 | env = envs.make_env(config, 'train') 90 | act_space, obs_space = env.act_space, env.obs_space 91 | 92 | agent = agents.PreLARPretrain(config, obs_space, act_space, step) 93 | pretrainer = train.Pretrainer(config,agent,train_replay,eval_replay,step,logger) 94 | pretrainer.run(config.steps) 95 | pretrainer.save_agent(logdir) 96 | env.close() 97 | 98 | 99 | if __name__ == "__main__": 100 | main() 101 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | 164 | # compressed 165 | *.zip 166 | *.rar 167 | *.gz 168 | *.7z 169 | *.tar 170 | *.tgz 171 | *.bzip 172 | *.bzip2 173 | *.bz2 174 | *.xz 175 | 176 | # development temporary files 177 | .vscode/ 178 | dev/ 179 | log/ 180 | logs/ 181 | 182 | # larger files 183 | pre-train_models/ 184 | *.pt 185 | *.pth 186 | *.h5 187 | *.npz 188 | -------------------------------------------------------------------------------- /examples/train_dreamerv2.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import functools 3 | import logging 4 | import os 5 | import re 6 | import sys 7 | import warnings 8 | from pathlib import Path 9 | 10 | logging.getLogger().setLevel("ERROR") 11 | warnings.filterwarnings("ignore", ".*box bound precision lowered.*") 12 | 13 | sys.path.append(str(Path(__file__).parent)) 14 | sys.path.append(str(Path(__file__).parent.parent)) 15 | 16 | import numpy as np 17 | import ruamel.yaml as yaml_package 18 | yaml = yaml_package.YAML(typ='safe', pure=True) 19 | 20 | import torch 21 | import random 22 | 23 | import wmlib 24 | import wmlib.envs as envs 25 | import wmlib.agents as agents 26 | import wmlib.utils as utils 27 | import wmlib.train as train 28 | 29 | 30 | 31 | 32 | def main(): 33 | 34 | configs = yaml.load( 35 | (Path(sys.argv[0]).parent.parent / "configs" / "dreamerv2.yaml").read_text() 36 | ) 37 | parsed, remaining = utils.Flags(configs=["defaults"]).parse(known_only=True) 38 | config = utils.Config(configs["defaults"]) 39 | for name in parsed.configs: 40 | config = config.update(configs[name]) 41 | config = utils.Flags(config).parse(remaining) 42 | 43 | logdir = Path(config.logdir).expanduser() 44 | logdir.mkdir(parents=True, exist_ok=True) 45 | config.save(logdir / "config.yaml") 46 | print(config, "\n") 47 | print("Set log directory:", logdir) 48 | 49 | assert torch.cuda.is_available(), 'No GPU found.' 
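    # NOTE: this is the from-scratch DreamerV2 baseline: no load_logdir / pretrained weights
    # are involved, and the replay buffers under logdir/train_episodes and logdir/eval_episodes
    # are filled online as the Trainer steps the train/eval envs created below.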
50 | assert config.precision in (16, 32), config.precision 51 | 52 | if config.precision == 16: 53 | print("setting fp16") 54 | 55 | device = torch.device(config.device if torch.cuda.is_available() else "cpu") 56 | 57 | if device != "cpu": 58 | torch.set_num_threads(1) 59 | 60 | # reproducibility 61 | utils.set_seed(config.seed) 62 | 63 | train_replay = wmlib.Replay(logdir / "train_episodes", seed=config.seed, **config.replay) 64 | eval_replay = wmlib.Replay(logdir / "eval_episodes", seed=config.seed, **dict( 65 | capacity=config.replay.capacity // 10, 66 | minlen=config.dataset.length, 67 | maxlen=config.dataset.length)) 68 | 69 | step = utils.Counter(train_replay.stats["total_steps"]) 70 | wandb_config = dict(config.wandb) 71 | task_name = '-'.join(config.task.lower().split("_", 1)) 72 | 73 | wandb_config['name']= f'{wandb_config["name"]}-{task_name}-seed{config.seed}' 74 | outputs = [ 75 | utils.TerminalOutput(), 76 | utils.JSONLOutput(logdir), 77 | utils.WandbOutput(**wandb_config,config=dict(config)) 78 | ] 79 | logger = utils.Logger(step, outputs, multiplier=config.action_repeat) 80 | 81 | 82 | # save experiment used config 83 | with open(logdir / "used_config.yaml", "w") as f: 84 | f.write("## command line input:\n## " + " ".join(sys.argv) + "\n##########\n\n") 85 | yaml.dump(dict(config), f) 86 | 87 | print("Create envs.") 88 | is_carla = config.task.split("_", 1)[0] == 'carla' 89 | num_eval_envs = min(config.envs, config.eval_eps) 90 | 91 | # only one env for carla 92 | if is_carla: 93 | assert config.envs == 1 and num_eval_envs == 1 94 | if config.envs_parallel == "none": 95 | train_envs = [envs.make_env(config, "train") for _ in range(config.envs)] 96 | eval_envs = [envs.make_env(config,"eval") for _ in range(num_eval_envs)] 97 | else: 98 | train_envs = [envs.make_async_env(config, "train") for _ in range(config.envs)] 99 | eval_envs = [envs.make_async_env(config, "eval") for _ in range(num_eval_envs)] 100 | 101 | agent = agents.DreamerV2(config, train_envs[0].obs_space, train_envs[0].act_space, step) 102 | trainer = train.Trainer(config, agent, train_replay, eval_replay, train_envs, eval_envs, step, logger) 103 | try: 104 | trainer.run(config.steps) 105 | except KeyboardInterrupt: 106 | print("Keyboard Interrupt - saving agent") 107 | trainer.save_agent(logdir / "variables.pt") 108 | except Exception as e: 109 | print("Training Error:", e) 110 | trainer.save_agent(logdir / "variables_error.pt") 111 | raise e 112 | finally: 113 | trainer.save_agent(logdir / "variables.pt") 114 | for env in train_envs + eval_envs: 115 | env.close() 116 | 117 | 118 | if __name__ == "__main__": 119 | main() 120 | -------------------------------------------------------------------------------- /wmlib/agents/expl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.distributions as tdist 4 | 5 | from .. 
import core 6 | from ..core import dists 7 | 8 | 9 | class Random(nn.Module): 10 | 11 | def __init__(self, action_space): 12 | super(Random, self).__init__() 13 | self._action_space = action_space 14 | 15 | def actor(self, feat): 16 | shape = feat.shape[:-1] + [self._action_space.shape[-1]] 17 | if self.config.actor.dist == 'onehot': 18 | return dists.OneHotDist(torch.zeros(shape)) 19 | else: 20 | dist = tdist.Uniform(-torch.ones(shape), torch.ones(shape)) 21 | return tdist.Independent(dist, 1) 22 | 23 | def train(self, start, context, data): 24 | return None, {} 25 | 26 | 27 | # TODO support plan2explore & model loss 28 | 29 | 30 | class Plan2Explore(nn.Module): 31 | pass 32 | 33 | 34 | class ModelLoss(nn.Module): 35 | pass 36 | 37 | 38 | class VideoIntrBonus(nn.Module): 39 | def __init__( 40 | self, 41 | beta, 42 | k, 43 | intr_seq_length, 44 | feat_dim, 45 | queue_dim, 46 | queue_size, 47 | reward_norm, 48 | beta_type='abs', 49 | ) -> None: 50 | super().__init__() 51 | 52 | self.beta = beta 53 | self.k = k 54 | self.intr_seq_length = intr_seq_length 55 | self.tf_queue_step = 0 56 | self.tf_queue_size = queue_size 57 | shape = (feat_dim, queue_dim) 58 | self.random_projection_matrix = torch.nn.Parameter( 59 | torch.normal(mean=torch.zeros(shape), std=torch.ones(shape) / queue_dim), 60 | requires_grad=False, 61 | ) 62 | self.register_buffer('queue', torch.zeros(queue_size, queue_dim)) 63 | self.intr_rewnorm = core.StreamNorm(**reward_norm) 64 | 65 | self.beta_type = beta_type 66 | if self.beta_type == 'rel': 67 | self.plain_rewnorm = core.StreamNorm() 68 | 69 | def construct_queue(self, seq_feat): 70 | with torch.no_grad(): 71 | seq_size = seq_feat.shape[0] 72 | self.queue.data[seq_size:] = self.queue.data[:-seq_size].clone() 73 | self.queue.data[:seq_size] = seq_feat.data 74 | 75 | self.tf_queue_step = self.tf_queue_step + seq_size 76 | self.tf_queue_step = min(self.tf_queue_step, self.tf_queue_size) 77 | return self.queue[: self.tf_queue_step] 78 | 79 | def compute_bonus(self, data, feat): 80 | with torch.no_grad(): 81 | seq_feat = feat 82 | # NOTE: seq_feat [B, T, D], after unfold [B, T-S+1, D, S] 83 | seq_feat = seq_feat.unfold(dimension=1, size=self.intr_seq_length, step=1).mean(dim=-1) 84 | seq_feat = torch.matmul(seq_feat, self.random_projection_matrix) 85 | b, t, d = (seq_feat.shape[0], seq_feat.shape[1], seq_feat.shape[2]) 86 | seq_feat = torch.reshape(seq_feat, (b * t, d)) 87 | queue = self.construct_queue(seq_feat) 88 | dist = torch.norm(seq_feat[:, None, :] - queue[None, :, :], dim=-1) 89 | int_rew = -1.0 * torch.topk( 90 | -dist, k=min(self.k, queue.shape[0]) 91 | ).values.mean(1) 92 | int_rew = int_rew.detach() 93 | int_rew, int_rew_mets = self.intr_rewnorm(int_rew) 94 | int_rew_mets = {f"intr_{k}": v for k, v in int_rew_mets.items()} 95 | int_rew = torch.reshape(int_rew, (b, t)) 96 | 97 | plain_reward = data["reward"] 98 | 99 | if self.beta_type == 'abs': 100 | data["reward"] = data["reward"][:, :t] + self.beta * int_rew.detach() 101 | elif self.beta_type == 'rel': 102 | self.plain_rewnorm.update(data["reward"]) 103 | beta = self.beta * self.plain_rewnorm.mag.item() 104 | data["reward"] = data["reward"][:, :t] + beta * int_rew.detach() 105 | int_rew_mets["abs_beta"] = beta 106 | int_rew_mets["plain_reward_mean"] = self.plain_rewnorm.mag.item() 107 | else: 108 | raise NotImplementedError 109 | 110 | if int_rew_mets['intr_mean'] < 1e-5: 111 | print("intr_rew too small:", int_rew_mets['intr_mean']) 112 | 113 | int_rew_mets["plain_reward_mean"] = plain_reward.mean().item() 114 
| int_rew_mets["intr_mag"] = self.intr_rewnorm.mag.item() 115 | 116 | return data, t, int_rew_mets 117 | -------------------------------------------------------------------------------- /configs/apv_pretraining.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | 3 | # Train Script 4 | logdir: ./dev/log 5 | load_logdir: ./dev/null 6 | load_model_dir: ./dev/null 7 | video_dir: ./dev/null 8 | video_dirs: {something: ./dev/null, rlbench: ./dev/null} 9 | seed: 0 10 | device: cuda 11 | wandb: {project: 'world-model',name: 'apv_pretraining',mode: 'online'} 12 | task: metaworld_drawer_open 13 | render_size: [64, 64] 14 | dmc_camera: -1 15 | camera: none 16 | atari_grayscale: True 17 | time_limit: 0 18 | action_repeat: 1 19 | steps: 1e8 20 | log_every: 1e4 21 | video_every: 2000 22 | eval_every: 5000 23 | pretrain: 1 24 | train_every: 5 25 | train_steps: 1 26 | replay: {capacity: 2e6, ongoing: False, minlen: 50, maxlen: 50, prioritize_ends: True} 27 | dataset: {batch: 16, length: 50} 28 | log_keys_video: ['image'] 29 | log_keys_sum: '^$' 30 | log_keys_mean: '^$' 31 | log_keys_max: '^$' 32 | precision: 16 33 | jit: True 34 | 35 | eval_video_list: none 36 | save_all_models: False 37 | 38 | # Agent 39 | clip_rewards: tanh 40 | 41 | # World Model 42 | grad_heads: [decoder] 43 | rssm: {action_free: True, fill_action: 50, ensemble: 1,embed_dim: 3072, hidden: 1024, deter: 1024, stoch: 32, discrete: 32, act: elu, norm: none, std_act: sigmoid2, min_std: 0.1} 44 | encoder_type: resnet # ['plaincnn', 'resnet', 'ctx_resnet'] 45 | encoder: { 46 | mlp_keys: '.*', 47 | cnn_keys: '.*', 48 | act: elu, 49 | norm: none, 50 | cnn_depth: 48, 51 | cnn_kernels: [4, 4, 4, 4], 52 | mlp_layers: [400, 400, 400, 400], 53 | res_norm: 'batch', 54 | res_depth: 3, 55 | res_layers: 2, 56 | } 57 | decoder_type: resnet # ['plaincnn', 'resnet', 'ctx_resnet'] 58 | decoder: { 59 | mlp_keys: '.*', 60 | cnn_keys: '.*', 61 | act: elu, 62 | norm: none, 63 | cnn_depth: 48, 64 | cnn_kernels: [5, 5, 6, 6], 65 | mlp_layers: [400, 400, 400, 400], 66 | res_norm: 'batch', 67 | res_depth: 3, 68 | res_layers: 2, 69 | } 70 | loss_scales: { 71 | kl: 1.0, 72 | image: 1.0 73 | } 74 | kl: {free: 0.0, forward: False, balance: 0.8, free_avg: True} 75 | model_opt: {opt: adam, lr: 1e-4, eps: 1e-5, clip: 100, wd: 1e-6} 76 | 77 | dataset_type: replay 78 | video_list: none 79 | video_lists: {something: none, rlbench: none} 80 | manual_labels: False 81 | # num_workers: 8 82 | 83 | # Contextualized World Model (subset of Decoupled World Model) 84 | # Decoupled World Model 85 | encoder_deco: { 86 | deco_res_layers: 2, 87 | deco_cnn_depth: 48, 88 | deco_cond_choice: trand, 89 | ctx_aug: none, 90 | } 91 | decoder_deco: { 92 | deco_attmask: 0.75, 93 | ctx_attmaskwarmup: -1, 94 | } 95 | 96 | 97 | something_pretrain: 98 | 99 | task: metaworld_drawer_open 100 | video_dirs: {something: dataset/Something-Something/20bn-something-something-v2-frames-64} 101 | encoder: {mlp_keys: '$^', cnn_keys: 'image'} 102 | decoder: {mlp_keys: '$^', cnn_keys: 'image'} 103 | replay: {minlen: 25, maxlen: 25} 104 | dataset: {batch: 16, length: 25} 105 | action_repeat: 1 106 | steps: 5e7 107 | log_every: 100 108 | train_every: 1 109 | rssm: {hidden: 1024, deter: 1024} 110 | grad_heads: [decoder] 111 | model_opt.lr: 3e-4 112 | # loss_scales.kl: 0.1 113 | 114 | dataset_type: something 115 | video_lists: {something: train_video_folder} 116 | manual_labels: False 117 | 118 | 119 | 120 | debug: 121 | jit: False 122 | time_limit: 100 123 | 
eval_every: 300 124 | log_every: 300 125 | pretrain: 1 126 | train_steps: 1 127 | replay: {minlen: 10, maxlen: 30} 128 | dataset: {batch: 10, length: 10} 129 | 130 | 131 | small: 132 | rssm: {hidden: 200, deter: 200} 133 | 134 | 135 | plainresnet: 136 | encoder_type: resnet 137 | decoder_type: resnet 138 | 139 | 140 | contextualized: 141 | encoder_type: deco_resnet 142 | decoder_type: deco_resnet 143 | 144 | 145 | rlbench_pretrain: 146 | task: metaworld_drawer_open 147 | video_dirs: {rlbench: dataset/rlbench/train_episodes} 148 | encoder: {mlp_keys: '$^', cnn_keys: 'image'} 149 | decoder: {mlp_keys: '$^', cnn_keys: 'image'} 150 | replay: {minlen: 25, maxlen: 25} 151 | dataset: {batch: 16, length: 25} 152 | action_repeat: 1 153 | steps: 5e7 154 | log_every: 100 155 | train_every: 1 156 | rssm: {hidden: 1024, deter: 1024} 157 | grad_heads: [decoder] 158 | model_opt.lr: 3e-4 159 | 160 | dataset_type: rlbench 161 | video_lists: {rlbench: null} 162 | manual_labels: False 163 | -------------------------------------------------------------------------------- /examples/train_apv_finetuning.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import functools 3 | import logging 4 | import os 5 | import re 6 | import sys 7 | import warnings 8 | from pathlib import Path 9 | 10 | logging.getLogger().setLevel("ERROR") 11 | warnings.filterwarnings("ignore", ".*box bound precision lowered.*") 12 | 13 | sys.path.append(str(Path(__file__).parent)) 14 | sys.path.append(str(Path(__file__).parent.parent)) 15 | 16 | import numpy as np 17 | import ruamel.yaml as yaml_package 18 | yaml = yaml_package.YAML(typ='safe', pure=True) 19 | import torch 20 | import random 21 | 22 | import wmlib 23 | import wmlib.envs as envs 24 | import wmlib.agents as agents 25 | import wmlib.utils as utils 26 | import wmlib.train as train 27 | 28 | 29 | def main(): 30 | 31 | configs = yaml.load( 32 | (Path(sys.argv[0]).parent.parent / "configs" / "apv_finetuning.yaml").read_text() 33 | ) 34 | parsed, remaining = utils.Flags(configs=["defaults"]).parse(known_only=True) 35 | config = utils.Config(configs["defaults"]) 36 | for name in parsed.configs: 37 | config = config.update(configs[name]) 38 | config = utils.Flags(config).parse(remaining) 39 | 40 | logdir = Path(config.logdir).expanduser() 41 | logdir.mkdir(parents=True, exist_ok=True) 42 | config.save(logdir / "config.yaml") 43 | print(config, "\n") 44 | print("Logdir", logdir) 45 | if config.load_logdir != "none": 46 | load_logdir = Path(config.load_logdir).expanduser() 47 | print("Loading Logdir", load_logdir) 48 | prtrain_dataset_prefix = str(load_logdir).split('-',1)[1] 49 | 50 | assert torch.cuda.is_available(), 'No GPU found.' 
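    # NOTE: prtrain_dataset_prefix above keeps everything after the first '-' in load_logdir,
    # so the pretrained run directory is expected to follow a '<method>-<pretrain_dataset>...'
    # naming scheme; it is only defined when load_logdir is not 'none' and is folded into the
    # wandb run name further below.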
51 | assert config.precision in (16, 32), config.precision 52 | if config.precision == 16: 53 | print("setting fp16") 54 | 55 | device = torch.device(config.device if torch.cuda.is_available() else "cpu") 56 | 57 | if device != "cpu": 58 | torch.set_num_threads(1) 59 | 60 | # reproducibility 61 | utils.set_seed(config.seed) 62 | 63 | train_replay = wmlib.Replay(logdir / "train_episodes", seed=config.seed, **config.replay) 64 | eval_replay = wmlib.Replay(logdir / "eval_episodes", seed=config.seed, **dict( 65 | capacity=config.replay.capacity // 10, 66 | minlen=config.dataset.length, 67 | maxlen=config.dataset.length)) 68 | step = utils.Counter(train_replay.stats["total_steps"]) 69 | wandb_config = dict(config.wandb) 70 | task_name = '-'.join(config.task.lower().split("_", 1)) 71 | 72 | wandb_config['name']= f'{wandb_config["name"]}-{prtrain_dataset_prefix}-{task_name}-seed{config.seed}' 73 | outputs = [ 74 | utils.TerminalOutput(), 75 | utils.JSONLOutput(logdir), 76 | utils.WandbOutput(**wandb_config,config=dict(config)) 77 | ] 78 | logger = utils.Logger(step, outputs, multiplier=config.action_repeat) 79 | 80 | # save experiment used config 81 | with open(logdir / "used_config.yaml", "w") as f: 82 | f.write("## command line input:\n## " + " ".join(sys.argv) + "\n##########\n\n") 83 | yaml.dump(dict(config), f) 84 | 85 | is_carla = config.task.split("_", 1)[0] == 'carla' 86 | num_eval_envs = min(config.envs, config.eval_eps) 87 | # only one env for carla 88 | if is_carla: 89 | assert config.envs == 1 and num_eval_envs == 1 90 | if config.envs_parallel == "none": 91 | train_envs = [envs.make_env(config, "train") for _ in range(config.envs)] 92 | eval_envs = [envs.make_env(config,"eval") for _ in range(num_eval_envs)] 93 | else: 94 | train_envs = [envs.make_async_env(config, "train") for _ in range(config.envs)] 95 | eval_envs = [envs.make_async_env(config, "eval") for _ in range(num_eval_envs)] 96 | 97 | # the agent needs 1. init modules 2. go to device 3. 
set optimizer 98 | agent = agents.APV_Finetune(config, train_envs[0].obs_space, train_envs[0].act_space, step) 99 | finetuner = train.Finetuner(config, agent, train_replay, eval_replay, train_envs, eval_envs, step, logger) 100 | 101 | try: 102 | finetuner.run(config.steps) 103 | except KeyboardInterrupt: 104 | print("Keyboard Interrupt - saving agent") 105 | finetuner.save_agent(logdir / "variables.pt") 106 | except Exception as e: 107 | print("Training Error:", e) 108 | finetuner.save_agent(logdir / "variables_error.pt") 109 | raise e 110 | finally: 111 | finetuner.save_agent(logdir / "variables.pt") 112 | for env in train_envs + eval_envs: 113 | env.close() 114 | 115 | 116 | if __name__ == "__main__": 117 | __spec__ = "ModuleSpec(name='builtins', loader=)" 118 | main() 119 | -------------------------------------------------------------------------------- /examples/train_naive_finetune.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import functools 3 | import logging 4 | import os 5 | import re 6 | import sys 7 | import warnings 8 | from pathlib import Path 9 | 10 | logging.getLogger().setLevel("ERROR") 11 | warnings.filterwarnings("ignore", ".*box bound precision lowered.*") 12 | 13 | sys.path.append(str(Path(__file__).parent)) 14 | sys.path.append(str(Path(__file__).parent.parent)) 15 | 16 | import numpy as np 17 | import ruamel.yaml as yaml_package 18 | yaml = yaml_package.YAML(typ='safe', pure=True) 19 | import torch 20 | import random 21 | 22 | import wmlib 23 | import wmlib.envs as envs 24 | import wmlib.agents as agents 25 | import wmlib.utils as utils 26 | import wmlib.train as train 27 | 28 | 29 | def main(): 30 | 31 | configs = yaml.load( 32 | (Path(sys.argv[0]).parent.parent / "configs" / "naive_finetuning.yaml").read_text() 33 | ) 34 | parsed, remaining = utils.Flags(configs=["defaults"]).parse(known_only=True) 35 | config = utils.Config(configs["defaults"]) 36 | for name in parsed.configs: 37 | config = config.update(configs[name]) 38 | config = utils.Flags(config).parse(remaining) 39 | 40 | logdir = Path(config.logdir).expanduser() 41 | logdir.mkdir(parents=True, exist_ok=True) 42 | config.save(logdir / "config.yaml") 43 | print(config, "\n") 44 | print("Logdir", logdir) 45 | if config.load_logdir != "none": 46 | load_logdir = Path(config.load_logdir).expanduser() 47 | print("Loading Logdir", load_logdir) 48 | prtrain_dataset_prefix = str(load_logdir).split('-',1)[1] 49 | 50 | # utils.snapshot_src(".", logdir / "src", ".gitignore") 51 | 52 | assert torch.cuda.is_available(), 'No GPU found.' 
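    # NOTE: this script mirrors train_apv_finetuning.py (same Finetuner, replay layout, and
    # load_logdir naming convention); only the config file (naive_finetuning.yaml) and the
    # agent class (Naive_Finetune) differ.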
53 | assert config.precision in (16, 32), config.precision 54 | if config.precision == 16: 55 | print("setting fp16") 56 | 57 | device = torch.device(config.device if torch.cuda.is_available() else "cpu") 58 | 59 | if device != "cpu": 60 | torch.set_num_threads(1) 61 | 62 | # reproducibility 63 | utils.set_seed(config.seed) 64 | 65 | train_replay = wmlib.Replay(logdir / "train_episodes", seed=config.seed, **config.replay) 66 | eval_replay = wmlib.Replay(logdir / "eval_episodes", seed=config.seed, **dict( 67 | capacity=config.replay.capacity // 10, 68 | minlen=config.dataset.length, 69 | maxlen=config.dataset.length)) 70 | step = utils.Counter(train_replay.stats["total_steps"]) 71 | wandb_config = dict(config.wandb) 72 | # wandb_config['name']= wandb_config['name'] + '-seed' + str(config.seed) 73 | task_name = '-'.join(config.task.lower().split("_", 1)) 74 | 75 | wandb_config['name']= f'{wandb_config["name"]}-{prtrain_dataset_prefix}-{task_name}-seed{config.seed}' 76 | outputs = [ 77 | utils.TerminalOutput(), 78 | utils.JSONLOutput(logdir), 79 | utils.WandbOutput(**wandb_config,config=dict(config)) 80 | ] 81 | logger = utils.Logger(step, outputs, multiplier=config.action_repeat) 82 | 83 | # save experiment used config 84 | with open(logdir / "used_config.yaml", "w") as f: 85 | f.write("## command line input:\n## " + " ".join(sys.argv) + "\n##########\n\n") 86 | yaml.dump(dict(config), f) 87 | 88 | is_carla = config.task.split("_", 1)[0] == 'carla' 89 | num_eval_envs = min(config.envs, config.eval_eps) 90 | # only one env for carla 91 | if is_carla: 92 | assert config.envs == 1 and num_eval_envs == 1 93 | if config.envs_parallel == "none": 94 | train_envs = [envs.make_env(config, "train") for _ in range(config.envs)] 95 | eval_envs = [envs.make_env(config,"eval") for _ in range(num_eval_envs)] 96 | else: 97 | train_envs = [envs.make_async_env(config, "train") for _ in range(config.envs)] 98 | eval_envs = [envs.make_async_env(config, "eval") for _ in range(num_eval_envs)] 99 | 100 | # the agent needs 1. init modules 2. go to device 3. 
set optimizer 101 | agent = agents.Naive_Finetune(config, train_envs[0].obs_space, train_envs[0].act_space, step) 102 | finetuner = train.Finetuner(config, agent, train_replay, eval_replay, train_envs, eval_envs, step, logger) 103 | 104 | try: 105 | finetuner.run(config.steps) 106 | except KeyboardInterrupt: 107 | print("Keyboard Interrupt - saving agent") 108 | finetuner.save_agent(logdir / "variables.pt") 109 | except Exception as e: 110 | print("Training Error:", e) 111 | finetuner.save_agent(logdir / "variables_error.pt") 112 | raise e 113 | finally: 114 | finetuner.save_agent(logdir / "variables.pt") 115 | for env in train_envs + eval_envs: 116 | env.close() 117 | 118 | 119 | if __name__ == "__main__": 120 | __spec__ = "ModuleSpec(name='builtins', loader=)" 121 | main() 122 | -------------------------------------------------------------------------------- /configs/prelar_pretraining.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | 3 | # Train Script 4 | logdir: ./dev/log 5 | load_logdir: ./dev/null 6 | load_model_dir: ./dev/null 7 | video_dir: ./dev/null 8 | video_dirs: {something: ./dev/null, rlbench: ./dev/null} 9 | seed: 0 10 | device: cuda 11 | wandb: {project: 'world-model',name: 'prelar_pretraining',mode: 'online'} 12 | task: metaworld_drawer_open 13 | render_size: [64, 64] 14 | dmc_camera: -1 15 | camera: none 16 | atari_grayscale: True 17 | time_limit: 0 18 | action_repeat: 1 19 | steps: 1e8 20 | log_every: 1e4 21 | video_every: 2000 22 | eval_every: 5000 23 | pretrain: 1 24 | train_every: 5 25 | train_steps: 1 26 | replay: {capacity: 2e6, ongoing: False, minlen: 50, maxlen: 50, prioritize_ends: True} 27 | dataset: {batch: 16, length: 50} 28 | log_keys_video: ['image'] 29 | log_keys_sum: '^$' 30 | log_keys_mean: '^$' 31 | log_keys_max: '^$' 32 | precision: 16 33 | jit: True 34 | 35 | eval_video_list: none 36 | save_all_models: False 37 | 38 | # Agent 39 | clip_rewards: tanh 40 | 41 | # World Model 42 | grad_heads: [decoder] 43 | rssm: {action_free: False, ensemble: 1,embed_dim: 3072, hidden: 1024, deter: 1024, stoch: 32, discrete: 32, act: elu, norm: none, std_act: sigmoid2, min_std: 0.1} 44 | encoder_type: resnet # ['plaincnn', 'resnet', 'deco_resnet'] 45 | encoder: { 46 | mlp_keys: '.*', 47 | cnn_keys: '.*', 48 | act: elu, 49 | norm: none, 50 | cnn_depth: 48, 51 | cnn_kernels: [4, 4, 4, 4], 52 | mlp_layers: [400, 400, 400, 400], 53 | res_norm: 'batch', 54 | res_depth: 3, 55 | res_layers: 2, 56 | } 57 | decoder_type: resnet # ['plaincnn', 'resnet', 'deco_resnet'] 58 | decoder: { 59 | mlp_keys: '.*', 60 | cnn_keys: '.*', 61 | act: elu, 62 | norm: none, 63 | cnn_depth: 48, 64 | cnn_kernels: [5, 5, 6, 6], 65 | mlp_layers: [400, 400, 400, 400], 66 | res_norm: 'batch', 67 | res_depth: 3, 68 | res_layers: 2, 69 | } 70 | loss_scales: { 71 | kl: 1.0, 72 | image: 1.0 73 | } 74 | kl: {free: 0.0, forward: False, balance: 0.8, free_avg: True} 75 | model_opt: {opt: adam, lr: 1e-4, eps: 1e-5, clip: 100, wd: 1e-6} 76 | 77 | dataset_type: replay 78 | video_list: none 79 | video_lists: {something: none, rlbench: none} 80 | manual_labels: False 81 | # num_workers: 8 82 | 83 | # Contextualized World Model (subset of Decoupled World Model) 84 | # Decoupled World Model 85 | encoder_deco: { 86 | deco_res_layers: 2, 87 | deco_cnn_depth: 48, 88 | deco_cond_choice: trand, 89 | ctx_aug: none, 90 | } 91 | decoder_deco: { 92 | deco_attmask: 0.75, 93 | ctx_attmaskwarmup: -1, 94 | } 95 | 96 | vanet: { 97 | mlp_keys: '$^', 98 | cnn_keys: 'image', 99 | 
act: elu, 100 | norm: none, 101 | cnn_depth: 48, 102 | cnn_kernels: [4, 4, 4, 4], 103 | mlp_layers: [400, 400, 400, 400], 104 | res_norm: 'batch', 105 | res_depth: 3, 106 | res_layers: 2, 107 | hidden_dim: 1024, 108 | deter: 1024, 109 | stoch: 32, 110 | discrete: 32, 111 | std_act: sigmoid2, 112 | va_method: concate, 113 | type_: stoch, 114 | } 115 | rank_loss: sort_rank # sign 116 | 117 | 118 | something_pretrain: 119 | 120 | task: metaworld_drawer_open 121 | video_dirs: {something: dataset/Something-Something/20bn-something-something-v2-frames-64} 122 | encoder: {mlp_keys: '$^', cnn_keys: 'image'} 123 | decoder: {mlp_keys: '$^', cnn_keys: 'image'} 124 | replay: {minlen: 25, maxlen: 25} 125 | dataset: {batch: 16, length: 25} 126 | action_repeat: 1 127 | steps: 5e7 128 | log_every: 100 129 | train_every: 1 130 | rssm: {hidden: 1024, deter: 1024} 131 | grad_heads: [decoder] 132 | model_opt.lr: 3e-4 133 | 134 | dataset_type: something 135 | video_lists: {something: train_video_folder} 136 | manual_labels: False 137 | 138 | 139 | debug: 140 | jit: False 141 | time_limit: 100 142 | eval_every: 300 143 | log_every: 300 144 | pretrain: 1 145 | train_steps: 1 146 | replay: {minlen: 10, maxlen: 30} 147 | dataset: {batch: 10, length: 10} 148 | 149 | 150 | small: 151 | rssm: {hidden: 200, deter: 200} 152 | 153 | 154 | plainresnet: 155 | encoder_type: resnet 156 | decoder_type: resnet 157 | 158 | 159 | contextualized: 160 | encoder_type: deco_resnet 161 | decoder_type: deco_resnet 162 | 163 | 164 | rlbench_pretrain: 165 | task: metaworld_drawer_open 166 | video_dirs: {rlbench: dataset/rlbench/train_episodes} 167 | encoder: {mlp_keys: '$^', cnn_keys: 'image'} 168 | decoder: {mlp_keys: '$^', cnn_keys: 'image'} 169 | replay: {minlen: 25, maxlen: 25} 170 | dataset: {batch: 16, length: 25} 171 | action_repeat: 1 172 | steps: 5e7 173 | log_every: 100 174 | train_every: 1 175 | rssm: {hidden: 1024, deter: 1024} 176 | grad_heads: [decoder] 177 | model_opt.lr: 3e-4 178 | 179 | dataset_type: rlbench 180 | video_lists: {rlbench: null} 181 | manual_labels: False -------------------------------------------------------------------------------- /configs/prelar_wo_al_pretraining.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | 3 | # Train Script 4 | logdir: ./dev/log 5 | load_logdir: ./dev/null 6 | load_model_dir: ./dev/null 7 | video_dir: ./dev/null 8 | video_dirs: {something: ./dev/null, rlbench: ./dev/null} 9 | seed: 0 10 | device: cuda 11 | wandb: {project: 'world-model',name: 'prelar_wo_al_pretraining',mode: 'online'} 12 | task: metaworld_drawer_open 13 | render_size: [64, 64] 14 | dmc_camera: -1 15 | camera: none 16 | atari_grayscale: True 17 | time_limit: 0 18 | action_repeat: 1 19 | steps: 1e8 20 | log_every: 1e4 21 | video_every: 2000 22 | eval_every: 5000 23 | pretrain: 1 24 | train_every: 5 25 | train_steps: 1 26 | replay: {capacity: 2e6, ongoing: False, minlen: 50, maxlen: 50, prioritize_ends: True} 27 | dataset: {batch: 16, length: 50} 28 | log_keys_video: ['image'] 29 | log_keys_sum: '^$' 30 | log_keys_mean: '^$' 31 | log_keys_max: '^$' 32 | precision: 16 33 | jit: True 34 | 35 | eval_video_list: none 36 | save_all_models: False 37 | 38 | # Agent 39 | clip_rewards: tanh 40 | 41 | # World Model 42 | grad_heads: [decoder] 43 | rssm: {action_free: False, ensemble: 1,embed_dim: 3072, hidden: 1024, deter: 1024, stoch: 32, discrete: 32, act: elu, norm: none, std_act: sigmoid2, min_std: 0.1} 44 | encoder_type: resnet # ['plaincnn', 'resnet', 
'deco_resnet'] 45 | encoder: { 46 | mlp_keys: '.*', 47 | cnn_keys: '.*', 48 | act: elu, 49 | norm: none, 50 | cnn_depth: 48, 51 | cnn_kernels: [4, 4, 4, 4], 52 | mlp_layers: [400, 400, 400, 400], 53 | res_norm: 'batch', 54 | res_depth: 3, 55 | res_layers: 2, 56 | } 57 | decoder_type: resnet # ['plaincnn', 'resnet', 'deco_resnet'] 58 | decoder: { 59 | mlp_keys: '.*', 60 | cnn_keys: '.*', 61 | act: elu, 62 | norm: none, 63 | cnn_depth: 48, 64 | cnn_kernels: [5, 5, 6, 6], 65 | mlp_layers: [400, 400, 400, 400], 66 | res_norm: 'batch', 67 | res_depth: 3, 68 | res_layers: 2, 69 | } 70 | loss_scales: { 71 | kl: 1.0, 72 | image: 1.0 73 | } 74 | kl: {free: 0.0, forward: False, balance: 0.8, free_avg: True} 75 | model_opt: {opt: adam, lr: 1e-4, eps: 1e-5, clip: 100, wd: 1e-6} 76 | 77 | dataset_type: replay 78 | video_list: none 79 | video_lists: {something: none, human: none, ytb: none, rlbench: none, rlbenchrandom: none, triped_walk: none} 80 | manual_labels: False 81 | 82 | # Contextualized World Model (subset of Decoupled World Model) 83 | # Decoupled World Model 84 | encoder_deco: { 85 | deco_res_layers: 2, 86 | deco_cnn_depth: 48, 87 | deco_cond_choice: trand, 88 | ctx_aug: none, 89 | } 90 | decoder_deco: { 91 | deco_attmask: 0.75, 92 | ctx_attmaskwarmup: -1, 93 | } 94 | 95 | vanet: { 96 | mlp_keys: '$^', 97 | cnn_keys: 'image', 98 | act: elu, 99 | norm: none, 100 | cnn_depth: 48, 101 | cnn_kernels: [4, 4, 4, 4], 102 | mlp_layers: [400, 400, 400, 400], 103 | res_norm: 'batch', 104 | res_depth: 3, 105 | res_layers: 2, 106 | hidden_dim: 1024, 107 | deter: 1024, 108 | stoch: 32, 109 | discrete: 32, 110 | std_act: sigmoid2, 111 | va_method: concate, 112 | type_: stoch, 113 | } 114 | 115 | 116 | something_pretrain: 117 | 118 | task: metaworld_drawer_open 119 | video_dirs: {something: dataset/Something-Something/20bn-something-something-v2-frames-64} 120 | encoder: {mlp_keys: '$^', cnn_keys: 'image'} 121 | decoder: {mlp_keys: '$^', cnn_keys: 'image'} 122 | replay: {minlen: 25, maxlen: 25} 123 | dataset: {batch: 16, length: 25} 124 | action_repeat: 1 125 | steps: 5e7 126 | log_every: 100 127 | train_every: 1 128 | rssm: {hidden: 1024, deter: 1024} 129 | grad_heads: [decoder] 130 | model_opt.lr: 3e-4 131 | 132 | dataset_type: something 133 | video_lists: {something: train_video_folder} 134 | manual_labels: False 135 | 136 | 137 | debug: 138 | jit: False 139 | time_limit: 100 140 | eval_every: 300 141 | log_every: 300 142 | pretrain: 1 143 | train_steps: 1 144 | replay: {minlen: 10, maxlen: 30} 145 | dataset: {batch: 10, length: 10} 146 | 147 | 148 | small: 149 | rssm: {hidden: 200, deter: 200} 150 | 151 | 152 | plainresnet: 153 | encoder_type: resnet 154 | decoder_type: resnet 155 | 156 | 157 | contextualized: 158 | encoder_type: deco_resnet 159 | decoder_type: deco_resnet 160 | 161 | rlbench_pretrain: 162 | task: metaworld_drawer_open 163 | video_dirs: {rlbench: dataset/rlbench/train_episodes} 164 | encoder: {mlp_keys: '$^', cnn_keys: 'image'} 165 | decoder: {mlp_keys: '$^', cnn_keys: 'image'} 166 | replay: {minlen: 25, maxlen: 25} 167 | dataset: {batch: 16, length: 25} 168 | action_repeat: 1 169 | steps: 5e7 170 | log_every: 100 171 | train_every: 1 172 | rssm: {hidden: 1024, deter: 1024} 173 | grad_heads: [decoder] 174 | model_opt.lr: 3e-4 175 | 176 | dataset_type: rlbench 177 | video_lists: {rlbench: null} 178 | manual_labels: False -------------------------------------------------------------------------------- /examples/train_prelar_finetuning.py: 
-------------------------------------------------------------------------------- 1 | import collections 2 | import functools 3 | import logging 4 | import os 5 | import re 6 | import sys 7 | import warnings 8 | from pathlib import Path 9 | 10 | logging.getLogger().setLevel("ERROR") 11 | warnings.filterwarnings("ignore", ".*box bound precision lowered.*") 12 | 13 | sys.path.append(str(Path(__file__).parent)) 14 | sys.path.append(str(Path(__file__).parent.parent)) 15 | 16 | import numpy as np 17 | import ruamel.yaml as yaml_package 18 | yaml = yaml_package.YAML(typ='safe', pure=True) 19 | import torch 20 | import random 21 | 22 | import wmlib 23 | import wmlib.envs as envs 24 | import wmlib.agents as agents 25 | import wmlib.utils as utils 26 | import wmlib.train as train 27 | 28 | 29 | def main(): 30 | 31 | configs = yaml.load( 32 | (Path(sys.argv[0]).parent.parent / "configs" / "prelar_finetuning.yaml").read_text() 33 | ) 34 | parsed, remaining = utils.Flags(configs=["defaults"]).parse(known_only=True) 35 | config = utils.Config(configs["defaults"]) 36 | for name in parsed.configs: 37 | config = config.update(configs[name]) 38 | config = utils.Flags(config).parse(remaining) 39 | 40 | logdir = Path(config.logdir).expanduser() 41 | print("--------Logdir", config.logdir, logdir) 42 | logdir.mkdir(parents=True, exist_ok=True) 43 | config.save(logdir / "config.yaml") 44 | print(config, "\n") 45 | print("Logdir", logdir) 46 | if config.load_logdir != "none": 47 | load_logdir = Path(config.load_logdir).expanduser() 48 | print("Loading Logdir", load_logdir) 49 | prtrain_dataset_prefix = str(load_logdir).split('-',1)[1] 50 | 51 | assert torch.cuda.is_available(), 'No GPU found.' 52 | assert config.precision in (16, 32), config.precision 53 | if config.precision == 16: 54 | print("setting fp16") 55 | 56 | device = torch.device(config.device if torch.cuda.is_available() else "cpu") 57 | 58 | if device != "cpu": 59 | torch.set_num_threads(1) 60 | # reproducibility 61 | utils.set_seed(config.seed) 62 | 63 | train_replay = wmlib.Replay(logdir / "train_episodes", seed=config.seed, **config.replay) 64 | eval_replay = wmlib.Replay(logdir / "eval_episodes", seed=config.seed, **dict( 65 | capacity=config.replay.capacity // 10, 66 | minlen=config.dataset.length, 67 | maxlen=config.dataset.length)) 68 | step = utils.Counter(train_replay.stats["total_steps"]) 69 | wandb_config = dict(config.wandb) 70 | task_name = '-'.join(config.task.lower().split("_", 1)) 71 | if config.enc_lr_type == 'no_pretrain': 72 | wandb_name = wandb_config['name'] if config.finetune_rssm else wandb_config['name'] + '(nrssm)' 73 | else: 74 | wandb_name = wandb_config['name'] + '(full)' 75 | wandb_config['name']= f'{wandb_name}-{prtrain_dataset_prefix}-{task_name}-seed{config.seed}' 76 | outputs = [ 77 | utils.TerminalOutput(), 78 | utils.JSONLOutput(logdir), 79 | utils.WandbOutput(**wandb_config,config=dict(config)) 80 | ] 81 | logger = utils.Logger(step, outputs, multiplier=config.action_repeat) 82 | 83 | # save experiment used config 84 | with open(logdir / "used_config.yaml", "w") as f: 85 | f.write("## command line input:\n## " + " ".join(sys.argv) + "\n##########\n\n") 86 | yaml.dump(dict(config), f) 87 | 88 | is_carla = config.task.split("_", 1)[0] == 'carla' 89 | num_eval_envs = min(config.envs, config.eval_eps) 90 | # only one env for carla 91 | if is_carla: 92 | assert config.envs == 1 and num_eval_envs == 1 93 | if config.envs_parallel == "none": 94 | train_envs = [envs.make_env(config, "train") for _ in range(config.envs)] 95 
| eval_envs = [envs.make_env(config,"eval") for _ in range(num_eval_envs)] 96 | else: 97 | train_envs = [envs.make_async_env(config, "train") for _ in range(config.envs)] 98 | eval_envs = [envs.make_async_env(config, "eval") for _ in range(num_eval_envs)] 99 | 100 | # the agent needs 1. init modules 2. go to device 3. set optimizer 101 | agent = agents.PreLARFinetune(config, train_envs[0].obs_space, train_envs[0].act_space, step) 102 | finetuner = train.Finetuner(config, agent, train_replay, eval_replay, train_envs, eval_envs, step, logger) 103 | 104 | try: 105 | finetuner.run(config.steps) 106 | except KeyboardInterrupt: 107 | print("Keyboard Interrupt - saving agent") 108 | finetuner.save_agent(logdir / "variables.pt") 109 | except Exception as e: 110 | print("Training Error:", e) 111 | finetuner.save_agent(logdir / "variables_error.pt") 112 | raise e 113 | finally: 114 | finetuner.save_agent(logdir / "variables.pt") 115 | for env in train_envs + eval_envs: 116 | env.close() 117 | 118 | 119 | if __name__ == "__main__": 120 | __spec__ = "ModuleSpec(name='builtins', loader=)" 121 | main() 122 | -------------------------------------------------------------------------------- /wmlib/nets/va_net/action_encoder.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | # import torch.distributions as dists 6 | from ..modules import get_act_module 7 | from einops.layers.torch import Rearrange 8 | from einops import rearrange 9 | from einops.layers.torch import Rearrange 10 | from ...core import dists 11 | 12 | class ActionEncoder(nn.Module): 13 | def __init__(self,action_dim,hidden_dim,act,deter=64,discrete=False,stoch=64,type_='deter', std_act='sigmoid2',twl=1,**dummy_kwargs): 14 | super().__init__() 15 | self._action_dim = action_dim 16 | self._hidden_dim = hidden_dim 17 | self.twl = int(twl) 18 | # self.twl = int(2) 19 | 20 | # self._embed_dim = embed_dim 21 | self._act_module = get_act_module(act) 22 | # self._dist_embed = dist_embed 23 | self._deter = deter 24 | self._discrete = discrete 25 | self._stoch = stoch 26 | self._std_act = std_act 27 | self.type = type_ # deter, stoch, mix 28 | if self.type == 'deter': 29 | self.embed_dim = self._deter 30 | elif self.type == 'stoch': 31 | self.embed_dim = self._stoch * self._discrete if self._discrete else self._stoch * 2 32 | elif self.type == 'mix': 33 | self.embed_dim = self._deter + (self._stoch * self._discrete if self._discrete else self._stoch * 2) 34 | else: 35 | raise NotImplementedError 36 | # if self._dist_embed: 37 | # self._embed_dim = self._stoch * self._discrete if self._discrete else self._stoch * 2 38 | # input_action_dim = self._action_dim if self.twl<=1 else self._action_dim * self.twl 39 | if self.twl <= 1: 40 | self._encoder = nn.Sequential( 41 | nn.Linear(self._action_dim,self._hidden_dim), 42 | self._act_module(), 43 | nn.Linear(self._hidden_dim,self.embed_dim) 44 | ) 45 | else: 46 | self._encoder = nn.Sequential( 47 | Rearrange('... t d -> ... d t'), 48 | nn.Conv1d(self._action_dim,self._hidden_dim,self.twl,padding=self.twl-1), 49 | Rearrange('... d t -> ... t d'), 50 | self._act_module(), 51 | nn.Linear(self._hidden_dim,self.embed_dim) 52 | ) 53 | self.action_buffer = ActionBuffer(self.twl) 54 | # if self._discrete: 55 | # self._encoder.add_module('encoder_rerange',Rearrange('... (s d) -> ... 
s d', s=self._stoch, d=self._discrete)) 56 | 57 | self._std_fn = { 58 | "softplus": lambda std: F.softplus(std), 59 | "sigmoid": lambda std: torch.sigmoid(std), 60 | "sigmoid2": lambda std: 2 * torch.sigmoid(std / 2), 61 | }[self._std_act] 62 | 63 | def forward(self,action,sample=True): 64 | x = self._encoder(action) 65 | if self.twl > 1: 66 | x = x[:,:-self.twl+1] 67 | # x = x 68 | action_code = {} 69 | if self.type == 'deter': 70 | return {'deter':x} 71 | elif self.type == 'stoch': 72 | x_stoch = x 73 | elif self.type == 'mix': 74 | x_deter = x[..., :self._deter] 75 | x_stoch = x[..., self._deter:] 76 | action_code.update({'deter':x_deter}) 77 | if self._discrete: 78 | x_stoch = rearrange(x_stoch,'... (s d) -> ... s d', s=self._stoch, d=self._discrete) 79 | dist = self.get_dist({'logit':x_stoch}) 80 | stoch = dist.sample() if sample else dist.mode 81 | stoch = rearrange(stoch,'... s d -> ... (s d)') 82 | action_code.update({'stoch':stoch,'logit':x_stoch}) 83 | else: 84 | mean, std = torch.chunk(x_stoch, 2, dim=-1) 85 | dist = self.get_dist({'mean':mean,'std':std}) 86 | stoch = dist.sample() if sample else dist.mode 87 | action_code.update({'stoch':stoch,'mean':mean,'std':std}) 88 | return action_code 89 | 90 | def get_dist(self, state): 91 | """ 92 | gets the stochastic state distribution 93 | """ 94 | if self._discrete: 95 | logit = state["logit"] 96 | logit = logit.float() 97 | dist = dists.Independent(dists.OneHotDist(logit), 1) 98 | else: 99 | mean, std = state['mean'], state['std'] 100 | # mean = mean.float() 101 | # std = std.float() 102 | std = self._std_fn(std) 103 | dist = dists.Independent(dists.Normal(mean, std), 1) 104 | return dist 105 | 106 | class ActionBuffer: 107 | def __init__(self,twl) -> None: 108 | self.actions = deque(maxlen=twl) 109 | 110 | def append(self,action): 111 | self.actions.append(action) 112 | 113 | def get_actions(self): 114 | return torch.stack(list(self.actions),dim=1) 115 | 116 | def clear(self): 117 | self.actions.clear() -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: PreLAR 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - asttokens=2.4.1=pyhd8ed1ab_0 10 | - backcall=0.2.0=pyh9f0ad1d_0 11 | - backports=1.0=pyhd8ed1ab_3 12 | - backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0 13 | - ca-certificates=2023.08.22=h06a4308_0 14 | - comm=0.1.4=pyhd8ed1ab_0 15 | - debugpy=1.6.7=py38h6a678d5_0 16 | - entrypoints=0.4=pyhd8ed1ab_0 17 | - executing=2.0.1=pyhd8ed1ab_0 18 | - ipykernel=6.26.0=pyhf8b6a83_0 19 | - ipython=8.12.0=pyh41d4057_0 20 | - jedi=0.19.1=pyhd8ed1ab_0 21 | - jupyter_client=7.3.4=pyhd8ed1ab_0 22 | - jupyter_core=5.5.0=py38h578d9bd_0 23 | - ld_impl_linux-64=2.38=h1181459_1 24 | - libffi=3.4.4=h6a678d5_0 25 | - libgcc-ng=11.2.0=h1234567_1 26 | - libgomp=11.2.0=h1234567_1 27 | - libsodium=1.0.18=h36c2ea0_1 28 | - libstdcxx-ng=11.2.0=h1234567_1 29 | - matplotlib-inline=0.1.6=pyhd8ed1ab_0 30 | - ncurses=6.4=h6a678d5_0 31 | - nest-asyncio=1.5.8=pyhd8ed1ab_0 32 | - openssl=3.0.12=h7f8727e_0 33 | - packaging=23.2=pyhd8ed1ab_0 34 | - parso=0.8.3=pyhd8ed1ab_0 35 | - pexpect=4.8.0=pyh1a96a4e_2 36 | - pickleshare=0.7.5=py_1003 37 | - platformdirs=4.0.0=pyhd8ed1ab_0 38 | - prompt-toolkit=3.0.40=pyha770c72_0 39 | - prompt_toolkit=3.0.40=hd8ed1ab_0 40 | - ptyprocess=0.7.0=pyhd3deb0d_0 41 | - pure_eval=0.2.2=pyhd8ed1ab_0 42 | 
- pygments=2.16.1=pyhd8ed1ab_0 43 | - python=3.8.18=h955ad1f_0 44 | - python-dateutil=2.8.2=pyhd8ed1ab_0 45 | - python_abi=3.8=2_cp38 46 | - pyzmq=25.1.0=py38h6a678d5_0 47 | - readline=8.2=h5eee18b_0 48 | - setuptools=68.0.0=py38h06a4308_0 49 | - six=1.16.0=pyh6c4a22f_0 50 | - sqlite=3.41.2=h5eee18b_0 51 | - stack_data=0.6.2=pyhd8ed1ab_0 52 | - tk=8.6.12=h1ccaba5_0 53 | - tornado=6.1=py38h0a891b7_3 54 | - traitlets=5.13.0=pyhd8ed1ab_0 55 | - typing_extensions=4.8.0=pyha770c72_0 56 | - wcwidth=0.2.9=pyhd8ed1ab_0 57 | - wheel=0.41.2=py38h06a4308_0 58 | - xz=5.4.2=h5eee18b_0 59 | - zeromq=4.3.4=h2531618_0 60 | - zlib=1.2.13=h5eee18b_0 61 | - pip=23.3.1 62 | - pip: 63 | - absl-py==2.0.0 64 | - appdirs==1.4.4 65 | - arch==5.3.0 66 | - blessed==1.20.0 67 | - cachetools==5.3.2 68 | - carla==0.9.14 69 | - certifi==2023.7.22 70 | - cffi==1.16.0 71 | - charset-normalizer==3.3.2 72 | - click==8.1.7 73 | - cloudpickle==3.0.0 74 | - colorama==0.4.6 75 | - conda-pack==0.7.1 76 | - contourpy==1.1.1 77 | - cycler==0.12.1 78 | - cython==0.29.36 79 | - decorator==4.4.2 80 | - dm-control==0.0.318037100 81 | - dm-env==1.6 82 | - dm-tree==0.1.8 83 | - docker-pycreds==0.4.0 84 | - dotmap==1.3.30 85 | - einops==0.7.0 86 | - etils==1.3.0 87 | - fasteners==0.19 88 | - fonttools==4.44.0 89 | - future==0.18.3 90 | - gitdb==4.0.11 91 | - gitpython==3.1.40 92 | - glfw==2.6.2 93 | - google-auth==2.23.4 94 | - google-auth-oauthlib==1.0.0 95 | - gpustat==1.1.1 96 | - grpcio==1.59.2 97 | - gym==0.26.2 98 | - gym-notices==0.0.8 99 | - h5py==3.10.0 100 | - idna==3.4 101 | - imageio==2.31.6 102 | - imageio-ffmpeg==0.4.9 103 | - importlib-metadata==6.8.0 104 | - importlib-resources==6.1.1 105 | - joblib==1.3.2 106 | - kiwisolver==1.4.5 107 | - kornia==0.7.0 108 | - labmaze==1.0.6 109 | - lxml==4.9.3 110 | - markdown==3.5.1 111 | - markupsafe==2.1.3 112 | - matplotlib==3.7.3 113 | - mo==0.3.0 114 | - moviepy==1.0.3 115 | - mujoco==3.0.1 116 | - mujoco-py==2.1.2.14 117 | - networkx==3.1 118 | - numpy==1.24.4 119 | - nvidia-ml-py==12.535.133 120 | - oauthlib==3.2.2 121 | - opencv-python==4.8.1.78 122 | - pandas==2.0.3 123 | - patsy==0.5.3 124 | - pillow==10.0.1 125 | - pip==23.3.1 126 | - proglog==0.1.10 127 | - property-cached==1.6.4 128 | - protobuf==4.25.0 129 | - psutil==5.9.6 130 | - pyasn1==0.5.0 131 | - pyasn1-modules==0.3.0 132 | - pycparser==2.21 133 | - pygame==2.5.2 134 | - pyopengl==3.1.7 135 | - pyparsing==2.4.7 136 | - pytz==2023.3.post1 137 | - pyyaml==6.0.1 138 | - requests==2.31.0 139 | - requests-oauthlib==1.3.1 140 | - rliable==1.0.8 141 | - rsa==4.9 142 | - ruamel-yaml==0.18.5 143 | - ruamel-yaml-clib==0.2.8 144 | - scikit-learn==1.3.2 145 | - scipy==1.10.1 146 | - seaborn==0.13.0 147 | - sentry-sdk==1.34.0 148 | - setproctitle==1.3.3 149 | - shapely==2.0.2 150 | - smmap==5.0.1 151 | - statsmodels==0.14.0 152 | - tensorboard==2.14.0 153 | - tensorboard-data-server==0.7.2 154 | - threadpoolctl==3.2.0 155 | - toml==0.10.2 156 | - -f https://download.pytorch.org/whl/cu113/torch_stable.html 157 | - torch==1.12.1+cu113 158 | - torchaudio==0.12.1+cu113 159 | - torchsort==0.1.9 160 | - torchtyping==0.1.4 161 | - torchvision==0.13.1+cu113 162 | - tqdm==4.66.1 163 | - typeguard==4.1.5 164 | - tzdata==2023.3 165 | - urllib3==2.0.7 166 | - wandb==0.16.0 167 | - werkzeug==3.0.1 168 | - zipp==3.17.0 169 | -------------------------------------------------------------------------------- /wmlib/train/pretrainer.py: -------------------------------------------------------------------------------- 1 | from ..core.driver import 
Driver 2 | from ..core.replay import Replay 3 | from .. import utils, agents 4 | import numpy as np 5 | import re 6 | import collections 7 | import torch 8 | from pathlib import Path 9 | from tqdm import tqdm 10 | 11 | class Pretrainer: 12 | def __init__(self,config,agent:agents.BaseAgent,train_replay:Replay,eval_replay:Replay,step:utils.Counter,logger:utils.Logger) -> None: 13 | 14 | self.config = config 15 | 16 | 17 | self.device = config.device 18 | self._precision = config.precision 19 | self._dtype = torch.float16 if self._precision == 16 else torch.float32 # only on cuda 20 | self.metrics = collections.defaultdict(list) 21 | self.logdir = Path(config.logdir).expanduser() 22 | self.load_model_dir = Path(config.load_model_dir).expanduser() 23 | 24 | self.agent = agent 25 | self.step = step 26 | self.logger = logger 27 | 28 | self.train_replay = train_replay 29 | self.eval_replay = eval_replay 30 | 31 | self.should_log = utils.Every(config.log_every) 32 | self.should_video = utils.Every(config.video_every) 33 | self.should_save = utils.Every(config.eval_every) 34 | 35 | 36 | self.agent = self.agent.to(self.device) 37 | self.agent.init_optimizers() 38 | 39 | self.train_agent = CarryOverState(self.agent.train) 40 | 41 | self.train_dataset = iter(train_replay.dataset(**config.dataset)) 42 | self.report_dataset = iter(train_replay.dataset(**config.dataset)) 43 | if config.eval_video_list != 'none': 44 | self.eval_dataset = iter(eval_replay.dataset(**config.dataset)) 45 | 46 | self.init_agent() 47 | self.load_agent(self.logdir / 'variables.pt') 48 | self.load_agent(self.load_model_dir / 'variables.pt') 49 | 50 | def run(self,steps): 51 | for _ in tqdm(range(int(steps)), total=int(steps), initial=int(self.step)): 52 | self.train_step() 53 | 54 | def init_agent(self): 55 | print(f'Init agent from scratch. 
& Benchmark training.') 56 | self.agent.apply(self.weights_init) 57 | self.train_agent(self.next_batch(self.train_dataset)) 58 | torch.cuda.empty_cache() 59 | 60 | def load_agent(self,path:Path): 61 | if path.exists(): 62 | print(f'Load agent from checkpoint {path}.') 63 | self.agent.load_state_dict(torch.load(path)) 64 | 65 | def save_agent(self,dir:Path,suffix=''): 66 | self.agent.save_model(dir, suffix) 67 | 68 | def weights_init(self, m): 69 | if hasattr(m,'original_name'): 70 | classname = m.original_name 71 | else: 72 | classname = m.__class__.__name__ 73 | if classname.find('LayerNorm') == -1 and classname.find('BatchNorm') == -1 and hasattr(m, "weight"): 74 | torch.nn.init.xavier_uniform_(m.weight) 75 | if m.bias is not None and m.bias.data is not None: 76 | torch.nn.init.zeros_(m.bias) 77 | if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1: 78 | print("setting memory format to channels last") 79 | m.to(memory_format=torch.channels_last) 80 | 81 | def train_step(self): 82 | train_metrics = self.train_agent(self.next_batch(self.train_dataset)) 83 | [self.metrics[key].append(value) for key, value in train_metrics.items()] 84 | self.step.increment() 85 | 86 | if self.should_log(self.step): 87 | for name, values in self.metrics.items(): 88 | self.logger.scalar(name, np.array(values, np.float64).mean()) 89 | self.metrics[name].clear() 90 | # only video when log 91 | if self.should_video(self.step): 92 | report_metrics = self.agent.report(self.next_batch(self.report_dataset),recon=True) 93 | self.logger.add(report_metrics, prefix='train') 94 | self.logger.write(fps=True) 95 | 96 | if self.should_save(self.step): 97 | self.save_agent(self.logdir) 98 | if self.config.save_all_models and int(self.step) % 50000 == 1: 99 | self.save_agent(self.logdir, f'_s{int(self.step)}') 100 | 101 | if self.config.eval_video_list != 'none': 102 | eval_metrics = self.agent.eval(self.next_batch(self.eval_dataset))[1] 103 | for name, values in eval_metrics.items(): 104 | if name.endswith('loss'): 105 | self.logger.scalar('eval/' + name, np.array(values, np.float64).mean()) 106 | report_metrics =self.agent.report(self.next_batch(self.eval_dataset), recon=True) 107 | self.logger.add(report_metrics, prefix="val") 108 | self.logger.write(fps=True) 109 | 110 | def next_batch(self, iter): 111 | # casts to fp16 and cuda 112 | out = {k: v.to(device=self.device, dtype=self._dtype) for k, v in next(iter).items()} 113 | return out 114 | 115 | 116 | class CarryOverState: 117 | def __init__(self, fn): 118 | self._fn = fn 119 | self._state = None 120 | 121 | def __call__(self, *args): 122 | self._state, out = self._fn(*args, self._state) 123 | return out 124 | 125 | -------------------------------------------------------------------------------- /wmlib/utils/logger.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | import collections 4 | import numpy as np 5 | import wandb 6 | from pathlib import Path 7 | from .counter import Counter 8 | 9 | class Logger: 10 | 11 | def __init__(self, step:Counter, outputs:list, multiplier=1): 12 | self._step = step 13 | self._outputs = outputs 14 | self._multiplier = multiplier 15 | self._last_step = None 16 | self._last_time = None 17 | self._metrics = [] 18 | 19 | def add(self, mapping, prefix=None): 20 | step = int(self._step) * self._multiplier 21 | for name, value in dict(mapping).items(): 22 | name = f'{prefix}/{name}' if prefix else name 23 | value = np.array(value) 24 | assert 
len(value.shape) in (0, 2, 3, 4), f'Shape {value.shape} for {name} cannot be interpreted as scalar, image, or video.' 25 | self._metrics.append((step, name, value)) 26 | 27 | def scalar(self, name, value): 28 | self.add({name: value}) 29 | 30 | def image(self, name, value): 31 | self.add({name: value}) 32 | 33 | def video(self, name, value): 34 | self.add({name: value}) 35 | 36 | def write(self, fps=False): 37 | if fps: 38 | self.scalar('fps', self._compute_fps()) 39 | if not self._metrics: 40 | return 41 | for output in self._outputs: 42 | output(self._metrics) 43 | self._metrics.clear() 44 | 45 | def _compute_fps(self): 46 | step = int(self._step) * self._multiplier 47 | if self._last_step is None: 48 | self._last_time = time.time() 49 | self._last_step = step 50 | return 0 51 | steps = step - self._last_step 52 | duration = time.time() - self._last_time 53 | self._last_time += duration 54 | self._last_step = step 55 | return steps / duration 56 | 57 | 58 | class TerminalOutput: 59 | def __call__(self, summaries): 60 | # TODO aggregate 61 | # aggregate values in the same step 62 | scalar_summaries = collections.defaultdict(lambda: collections.defaultdict(list)) 63 | for step, name, value in summaries: 64 | name = name.replace('/', '_') 65 | if len(value.shape) == 0: 66 | scalar_summaries[step][name].append(value.item()) 67 | for step in scalar_summaries: 68 | scalars = {k: float(np.mean(v)) for k, v in scalar_summaries[step].items()} 69 | formatted = {k: self._format_value(v) for k, v in scalars.items()} 70 | print(f'[{step}]', ' / '.join(f'{k} {v}' for k, v in formatted.items())) 71 | # step = max(s for s, _, _, in summaries) 72 | # scalars = {k: float(v) for _, k, v in summaries if len(v.shape) == 0} 73 | # formatted = {k: self._format_value(v) for k, v in scalars.items()} 74 | # print(f"[{step}]", " / ".join(f"{k} {v}" for k, v in formatted.items())) 75 | 76 | def _format_value(self, value): 77 | if value == 0: 78 | return '0' 79 | elif 0.01 < abs(value) < 10000: 80 | value = f'{value:.2f}' 81 | value = value.rstrip('0') 82 | # value = value.rstrip("0") 83 | value = value.rstrip('.') 84 | return value 85 | else: 86 | value = f'{value:.1e}' 87 | value = value.replace('.0e', 'e') 88 | value = value.replace('+0', '') 89 | value = value.replace('+', '') 90 | value = value.replace('-0', '-') 91 | return value 92 | 93 | 94 | class JSONLOutput: 95 | def __init__(self, logdir): 96 | self._logdir = Path(logdir).expanduser() 97 | 98 | def __call__(self, summaries): 99 | # aggregate values in the same step 100 | scalar_summaries = collections.defaultdict(lambda: collections.defaultdict(list)) 101 | for step, name, value in summaries: 102 | # name = name.replace('/', '_') 103 | if len(value.shape) == 0: 104 | scalar_summaries[step][name].append(value.item()) 105 | for step in scalar_summaries: 106 | scalars = {k: np.mean(v) for k, v in scalar_summaries[step].items()} 107 | with (self._logdir / 'metrics.jsonl').open('a') as f: 108 | f.write(json.dumps({'step': step, **scalars}) + '\n') 109 | 110 | 111 | class WandbOutput: 112 | def __init__(self, fps=20, **kwargs): 113 | self._fps = fps 114 | wandb.init(**kwargs) 115 | 116 | def __call__(self, summaries): 117 | ''' 118 | Dataformats: 119 | - scalar: float 120 | - image: numpy.ndarray with shape (C, H, W) or (H, W) 121 | - video: numpy.ndarray with shape (T, C, H, W) 122 | ''' 123 | for step, name, value in summaries: 124 | value_dimension = len(value.shape) 125 | if value_dimension == 0: 126 | wandb.log({name: value}, step=step) 127 | elif 
value_dimension == 2 or value_dimension == 3: 128 | value = value.transpose((1, 2, 0)) if value_dimension==3 else value 129 | wandb.log({f'image/{name}': wandb.Image(value)}, step=step) # (H, W, C) 130 | elif value_dimension == 4: 131 | name = name if isinstance(name, str) else name.decode('utf-8') 132 | if np.issubdtype(value.dtype, np.floating): 133 | value = np.clip(255 * value, 0, 255).astype(np.uint8) 134 | # value = value.transpose((0, 3, 1, 2)) 135 | wandb.log({f'video/{name}': wandb.Video(value, fps=self._fps, format='mp4')}, step=step) # (T, C, H, W) 136 | -------------------------------------------------------------------------------- /wmlib/nets/encoder/ctx_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import kornia 4 | import torchvision.transforms as T 5 | 6 | from .base import BaseEncoder 7 | from ..modules import * 8 | 9 | 10 | class ContextualizedResNetEncoder(BaseEncoder): 11 | 12 | # TODO: remame args 13 | def __init__( 14 | self, 15 | shapes, 16 | cnn_keys=r".*", 17 | mlp_keys=r".*", 18 | act="elu", 19 | cnn_depth=48, 20 | mlp_layers=[400, 400, 400, 400], 21 | res_layers=2, 22 | res_depth=3, 23 | res_norm='none', 24 | ctx_res_layers=2, 25 | ctx_cnn_depth=48, 26 | ctx_cond_choice='trand', 27 | ctx_aug='none', 28 | **dummy_kwargs, 29 | ): 30 | super().__init__(shapes, cnn_keys, mlp_keys, mlp_layers) 31 | self._act = get_act(act) 32 | self._cnn_depth = cnn_depth 33 | 34 | self._res_layers = res_layers 35 | self._res_depth = res_depth 36 | self._res_norm = res_norm 37 | 38 | self._ctx_res_layers = ctx_res_layers 39 | self._ctx_cnn_depth = ctx_cnn_depth 40 | self._ctx_cond_choice = ctx_cond_choice 41 | self._ctx_aug = ctx_aug 42 | 43 | def __call__(self, data, eval=False): 44 | key, shape = list(self.shapes.items())[0] 45 | batch_dims = data[key].shape[:-len(shape)] 46 | data = { 47 | k: torch.reshape(v, (-1,) + tuple(v.shape)[len(batch_dims):]) 48 | for k, v in data.items() 49 | } 50 | 51 | output, shortcut = self._cnn({k: data[k] for k in self.cnn_keys}, batch_dims, eval) 52 | # TODO: self._mlp 53 | if eval: 54 | return output.reshape(batch_dims + output.shape[1:]) 55 | else: 56 | return { 57 | 'embed': output.reshape(batch_dims + output.shape[1:]), 58 | 'shortcut': shortcut, 59 | } 60 | 61 | def _cnn(self, data, batch_dims=None, eval=False): 62 | x = torch.cat(list(data.values()), -1) 63 | x = x.to(memory_format=torch.channels_last) 64 | 65 | shortcuts = {} 66 | if not eval: 67 | with torch.no_grad(): 68 | ctx = self.get_context(x.reshape(batch_dims + x.shape[1:])) # [B, T, C, H, W] => [B, C, H, W] 69 | 70 | module_name = f"cond_aug" 71 | if module_name not in self._modules: 72 | self._modules[module_name] = get_augmentation(self._ctx_aug, ctx.shape) 73 | ctx = self.cond_aug(ctx) 74 | 75 | x = self.get(f"convin", nn.Conv2d, x.shape[1], self._cnn_depth, 3, 2, 1)(x) 76 | x = self._act(x) 77 | 78 | if not eval: 79 | ctx = self.get(f"cond_convin", nn.Conv2d, ctx.shape[1], self._ctx_cnn_depth, 3, 2, 1)(ctx) 80 | ctx = self._act(ctx) 81 | 82 | L = self._res_depth 83 | for i in range(L): 84 | depth = 2 ** i * self._cnn_depth 85 | x = self.get(f"res{i}", ResidualStack, x.shape[1], depth, 86 | self._res_layers, 87 | norm=self._res_norm, 88 | spatial_dim=x.shape[-2:], 89 | )(x) 90 | x = self.get(f"pool{i}", nn.AvgPool2d, 2, 2)(x) 91 | 92 | if not eval: 93 | ctx_depth = 2 ** i * self._ctx_cnn_depth 94 | ctx = self.get(f"cond_res{i}", ResidualStack, ctx.shape[1], ctx_depth, 95 | 
self._ctx_res_layers, 96 | norm=self._res_norm, 97 | spatial_dim=ctx.shape[-2:], 98 | )(ctx) 99 | shortcuts[ctx.shape[2]] = ctx # [B, C, H, W] 100 | ctx = self.get(f"cond_pool{i}", nn.AvgPool2d, 2, 2)(ctx) 101 | 102 | return x.reshape(tuple(x.shape[:-3]) + (-1,)), shortcuts 103 | 104 | # TODO: clean up or rename t0 tlast trand 105 | def get_context(self, frames): 106 | """ 107 | frames: [B, T, C, H, W] 108 | """ 109 | with torch.no_grad(): 110 | if self._ctx_cond_choice == 't0': 111 | # * initial frame 112 | context = frames[:, 0] # [B, C, H, W] 113 | elif self._ctx_cond_choice == 'tlast': 114 | # * last frame 115 | context = frames[:, -1] # [B, C, H, W] 116 | elif self._ctx_cond_choice == 'trand': 117 | # * timestep randomization 118 | idx = torch.from_numpy(np.random.choice(frames.shape[1], frames.shape[0])).to(frames.device) 119 | idx = idx.reshape(-1, 1, 1, 1, 1).repeat(1, 1, *frames.shape[-3:]) # [B, 1, C, H, W] 120 | context = frames.gather(1, idx).squeeze(1) # [B, C, H, W] 121 | else: 122 | raise NotImplementedError 123 | return context 124 | 125 | 126 | def get_augmentation(aug_type, shape): 127 | if aug_type == 'none': 128 | return nn.Identity() 129 | elif aug_type == 'shift': 130 | return nn.Sequential( 131 | nn.ReplicationPad2d(padding=8), 132 | kornia.augmentation.RandomCrop(shape[-2:]) 133 | ) 134 | elif aug_type == 'shift4': 135 | return nn.Sequential( 136 | nn.ReplicationPad2d(padding=4), 137 | kornia.augmentation.RandomCrop(shape[-2:]) 138 | ) 139 | elif aug_type == 'flip': 140 | return T.RandomHorizontalFlip(p=0.5) 141 | elif aug_type == 'scale': 142 | return T.RandomResizedCrop( 143 | size=shape[-2:], scale=[0.666667, 1.0], ratio=(0.75, 1.333333)) 144 | elif aug_type == 'erasing': 145 | return kornia.augmentation.RandomErasing() 146 | else: 147 | raise NotImplementedError 148 | -------------------------------------------------------------------------------- /configs/dreamerv2.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | 3 | # Train Script 4 | logdir: /dev/null 5 | seed: 0 6 | wandb: {project: 'world-model',name: 'dreamerv2',mode: 'online'} 7 | device: cuda 8 | task: metaworld_drawer_open 9 | envs: 1 10 | envs_parallel: none 11 | render_size: [64, 64] 12 | dmc_camera: -1 13 | camera: none 14 | dmcr_vary: all 15 | atari_grayscale: True 16 | time_limit: 0 17 | action_repeat: 1 18 | steps: 1e8 19 | log_every: 1e4 20 | eval_every: 1e5 21 | eval_eps: 1 22 | prefill: 10000 23 | pretrain: 1 24 | train_every: 5 25 | train_steps: 1 26 | expl_until: 0 27 | replay: {capacity: 2e6, ongoing: False, minlen: 50, maxlen: 50, prioritize_ends: True} 28 | dataset: {batch: 16, length: 50} 29 | log_keys_video: ['image'] 30 | log_keys_sum: '^$' 31 | log_keys_mean: '^$' 32 | log_keys_max: '^$' 33 | precision: 16 34 | jit: True 35 | stop_steps: -1 36 | 37 | # CARLA 38 | carla_port: 2000 39 | carla: { 40 | collision_coeff: 1e-3, 41 | num_other_vehicles: 20, 42 | centering_reward_type: div, 43 | centering_reward_weight: 1.0, 44 | clip_collision_reward: 10.0, 45 | steer_coeff: 1.0, 46 | centering_border: 1.75, 47 | use_branch_lane_cut: True, 48 | changing_weather_speed: 0.1, 49 | } 50 | 51 | # Agent 52 | clip_rewards: tanh 53 | expl_behavior: greedy 54 | expl_noise: 0.0 55 | eval_noise: 0.0 56 | eval_state_mean: False 57 | 58 | # Intrinsic bonus parameters 59 | k: 16 60 | beta: 0.0 61 | beta_type: abs 62 | intr_seq_length: 5 63 | intr_reward_norm: {momentum: 0.99, scale: 1.0, eps: 1e-8, init: 1.0} 64 | queue_size: 4096 65 | queue_dim: 128 
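# How these defaults are consumed (illustrative note): each example train script
# loads this file, starts from `defaults`, merges any named sections selected via
# the `configs` flag (e.g. `metaworld`, `debug`) with `config.update(configs[name])`,
# and then applies the remaining command-line overrides through `utils.Flags`.
# Dotted keys such as `replay.capacity: 1e6` override nested values.
# A hypothetical invocation (script path and flag names assumed, not verified):
#   python examples/train_dreamerv2.py --configs metaworld --seed 0 --logdir ~/logs/dreamerv2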
66 | 67 | # World Model 68 | grad_heads: [decoder, reward, discount] 69 | pred_discount: True 70 | rssm: {ensemble: 1,embed_dim: 1536, hidden: 1024, deter: 1024, stoch: 32, discrete: 32, act: elu, norm: none, std_act: sigmoid2, min_std: 0.1} 71 | encoder_type: plaincnn # ['plaincnn', 'resnet'] 72 | encoder: { 73 | mlp_keys: '.*', 74 | cnn_keys: '.*', 75 | act: elu, 76 | norm: none, 77 | cnn_depth: 48, 78 | cnn_kernels: [4, 4, 4, 4], 79 | mlp_layers: [400, 400, 400, 400], 80 | res_norm: 'batch', 81 | res_depth: 3, 82 | res_layers: 2, 83 | } 84 | decoder_type: plaincnn # ['plaincnn', 'resnet'] 85 | decoder: { 86 | mlp_keys: '.*', 87 | cnn_keys: '.*', 88 | act: elu, 89 | norm: none, 90 | cnn_input_dim: 2048, 91 | cnn_depth: 48, 92 | cnn_kernels: [5, 5, 6, 6], 93 | mlp_layers: [400, 400, 400, 400], 94 | res_norm: 'batch', 95 | res_depth: 3, 96 | res_layers: 2, 97 | } 98 | reward_head: {layers: 4, input_dim: 2048, units: 400, act: elu, norm: none, dist: mse} 99 | discount_head: {layers: 4, units: 400, act: elu, norm: none, dist: binary} 100 | loss_scales: { 101 | kl: 1.0, 102 | reward: 1.0, 103 | discount: 1.0, 104 | proprio: 1.0, 105 | image: 1.0 106 | } 107 | kl: {free: 0.0, forward: False, balance: 0.8, free_avg: True} 108 | model_opt: {opt: adam, lr: 1e-4, eps: 1e-5, clip: 100, wd: 1e-6} 109 | 110 | # Actor Critic 111 | actor: {layers: 4, input_dim: 2048, units: 400, act: elu, norm: none, dist: auto, min_std: 0.1} 112 | critic: {layers: 4, input_dim: 2048, units: 400, act: elu, norm: none, dist: mse} 113 | actor_opt: {opt: adam, lr: 8e-5, eps: 1e-5, clip: 100, wd: 1e-6} 114 | critic_opt: {opt: adam, lr: 2e-4, eps: 1e-5, clip: 100, wd: 1e-6} 115 | discount: 0.99 116 | discount_lambda: 0.95 117 | imag_horizon: 15 118 | imag_batch: -1 119 | actor_grad: auto 120 | actor_grad_mix: 0.1 121 | actor_ent: 2e-3 122 | slow_target: True 123 | slow_target_update: 100 124 | slow_target_fraction: 1 125 | slow_baseline: True 126 | reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8} 127 | 128 | # Exploration 129 | expl_intr_scale: 1.0 130 | expl_extr_scale: 0.0 131 | expl_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100, wd: 1e-6} 132 | expl_head: {layers: 4, units: 400, act: elu, norm: none, dist: mse} 133 | expl_reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8} 134 | disag_target: stoch 135 | disag_log: False 136 | disag_models: 10 137 | disag_offset: 1 138 | disag_action_cond: True 139 | expl_model_loss: kl 140 | 141 | 142 | metaworld: 143 | 144 | task: metaworld_drawer_open 145 | encoder: {mlp_keys: '$^', cnn_keys: 'image'} 146 | decoder: {mlp_keys: '$^', cnn_keys: 'image'} 147 | dataset: {batch: 50, length: 50} 148 | time_limit: 500 149 | action_repeat: 1 150 | eval_eps: 10 151 | prefill: 5000 152 | camera: corner 153 | steps: 256000 154 | stop_steps: 255000 155 | 156 | replay.capacity: 1e6 157 | eval_every: 1e4 158 | pretrain: 100 159 | clip_rewards: identity 160 | grad_heads: [decoder, reward] 161 | pred_discount: False 162 | actor_ent: 1e-4 163 | critic_opt.lr: 8e-5 164 | model_opt.lr: 3e-4 165 | 166 | robodesk: 167 | 168 | task: robodesk_open_slide 169 | encoder: {mlp_keys: '$^', cnn_keys: 'image'} 170 | decoder: {mlp_keys: '$^', cnn_keys: 'image'} 171 | dataset: {batch: 50, length: 50} 172 | time_limit: 500 173 | action_repeat: 1 174 | eval_eps: 10 175 | prefill: 5000 176 | camera: corner 177 | steps: 256000 178 | stop_steps: 255000 179 | 180 | replay.capacity: 1e6 181 | eval_every: 1e4 182 | pretrain: 100 183 | clip_rewards: identity 184 | grad_heads: [decoder, reward] 185 | pred_discount: False 
186 | actor_ent: 1e-4 187 | critic_opt.lr: 8e-5 188 | model_opt.lr: 3e-4 189 | 190 | debug: 191 | 192 | jit: False 193 | time_limit: 100 194 | eval_every: 300 195 | log_every: 300 196 | prefill: 100 197 | pretrain: 1 198 | train_steps: 1 199 | replay: {minlen: 10, maxlen: 30} 200 | dataset: {batch: 10, length: 10} 201 | 202 | 203 | 204 | plaincnn: 205 | encoder_type: plaincnn 206 | decoder_type: plaincnn 207 | 208 | 209 | plainresnet: 210 | encoder_type: resnet 211 | decoder_type: resnet -------------------------------------------------------------------------------- /configs/naive_finetuning.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | 3 | # Train Script 4 | logdir: /dev/null 5 | load_logdir: none 6 | seed: 0 7 | device: cuda 8 | wandb: {project: 'world-model',name: 'naive_finetuning',mode: 'online'} 9 | task: metaworld_drawer_open 10 | envs: 1 11 | envs_parallel: none 12 | render_size: [64, 64] 13 | dmc_camera: -1 14 | camera: corner 15 | dmcr_vary: all 16 | atari_grayscale: True 17 | time_limit: 0 18 | action_repeat: 1 19 | steps: 1e8 20 | log_every: 1e4 21 | eval_every: 1e4 22 | eval_eps: 1 23 | prefill: 10000 24 | pretrain: 100 25 | train_every: 5 26 | train_steps: 1 27 | expl_until: 0 28 | replay: {capacity: 1e6, ongoing: False, minlen: 50, maxlen: 50, prioritize_ends: True} 29 | dataset: {batch: 16, length: 50} 30 | log_keys_video: ['image'] 31 | log_keys_sum: '^$' 32 | log_keys_mean: '^$' 33 | log_keys_max: '^$' 34 | precision: 16 35 | jit: True 36 | stop_steps: -1 37 | 38 | # CARLA 39 | carla_port: 2000 40 | carla: { 41 | collision_coeff: 1e-3, 42 | num_other_vehicles: 20, 43 | centering_reward_type: div, 44 | centering_reward_weight: 1.0, 45 | clip_collision_reward: 10.0, 46 | steer_coeff: 1.0, 47 | centering_border: 1.75, 48 | use_branch_lane_cut: True, 49 | changing_weather_speed: 0.1, 50 | } 51 | 52 | # Agent 53 | clip_rewards: identity 54 | expl_behavior: greedy 55 | expl_noise: 0.0 56 | eval_noise: 0.0 57 | eval_state_mean: False 58 | 59 | # Fine-tuning parameters 60 | load_modules: [encoder, decoder, af_rssm] 61 | load_strict: True 62 | enc_lr_type: no_pretrain 63 | concat_embed: False 64 | finetune_rssm: False #True 65 | 66 | # Intrinsic bonus parameters 67 | k: 16 68 | beta: 1.0 69 | beta_type: abs 70 | intr_seq_length: 5 71 | intr_reward_norm: {momentum: 0.99, scale: 1.0, eps: 1e-8, init: 1.0} 72 | queue_size: 4096 73 | queue_dim: 128 74 | 75 | # World Model 76 | grad_heads: [decoder, reward] 77 | pred_discount: False 78 | # rssm: {action_free: False, fill_action: 50, ensemble: 1, embed_dim: 2048,hidden: 1024, deter: 1024, stoch: 32, discrete: 32, act: elu, norm: none, std_act: sigmoid2, min_std: 0.1} 79 | af_rssm: {action_free: False, fill_action: 50, ensemble: 1, embed_dim: 3072,hidden: 1024, deter: 1024, stoch: 32, discrete: 32, act: elu, norm: none, std_act: sigmoid2, min_std: 0.1} 80 | encoder_type: resnet # ['plaincnn', 'resnet', 'deco_resnet'] 81 | encoder: { 82 | mlp_keys: '.*', 83 | cnn_keys: '.*', 84 | act: elu, 85 | norm: none, 86 | cnn_depth: 48, 87 | cnn_kernels: [4, 4, 4, 4], 88 | mlp_layers: [400, 400, 400, 400], 89 | res_norm: 'batch', 90 | res_depth: 3, 91 | res_layers: 2, 92 | } 93 | decoder_type: resnet # ['plaincnn', 'resnet', 'deco_resnet'] 94 | decoder: { 95 | mlp_keys: '.*', 96 | cnn_keys: '.*', 97 | act: elu, 98 | norm: none, 99 | cnn_depth: 48, 100 | cnn_kernels: [5, 5, 6, 6], 101 | mlp_layers: [400, 400, 400, 400], 102 | res_norm: 'batch', 103 | res_depth: 3, 104 | res_layers: 2, 105 | 
} 106 | reward_head: {layers: 4, input_dim: 2048, units: 400, act: elu, norm: none, dist: mse} 107 | discount_head: {layers: 4, units: 400, act: elu, norm: none, dist: binary} 108 | loss_scales: { 109 | af_kl: 0.0, 110 | kl: 1.0, 111 | reward: 1.0, 112 | action: 1.0, 113 | discount: 1.0, 114 | proprio: 1.0, 115 | aux_reward: 0.0, 116 | } 117 | kl: {free: 0.0, forward: False, balance: 0.8, free_avg: True} 118 | model_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100, wd: 1e-6} 119 | enc_model_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100, wd: 1e-6} 120 | 121 | # Actor Critic 122 | actor: {layers: 4, input_dim: 2048, units: 400, act: elu, norm: none, dist: auto, min_std: 0.1} 123 | critic: {layers: 4, input_dim: 2048, units: 400, act: elu, norm: none, dist: mse} 124 | actor_opt: {opt: adam, lr: 8e-5, eps: 1e-5, clip: 100, wd: 1e-6} 125 | critic_opt: {opt: adam, lr: 8e-5, eps: 1e-5, clip: 100, wd: 1e-6} 126 | discount: 0.99 127 | discount_lambda: 0.95 128 | imag_horizon: 15 129 | imag_batch: -1 130 | actor_grad: auto 131 | actor_grad_mix: 0.1 132 | actor_ent: 1e-4 133 | slow_target: True 134 | slow_target_update: 100 135 | slow_target_fraction: 1 136 | slow_baseline: True 137 | reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8} 138 | 139 | # Exploration 140 | expl_intr_scale: 1.0 141 | expl_extr_scale: 0.0 142 | expl_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100, wd: 1e-6} 143 | expl_head: {layers: 4, units: 400, act: elu, norm: none, dist: mse} 144 | expl_reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8} 145 | disag_target: stoch 146 | disag_log: False 147 | disag_models: 10 148 | disag_offset: 1 149 | disag_action_cond: True 150 | expl_model_loss: kl 151 | 152 | # Decoupled World Model 153 | encoder_deco: { 154 | deco_res_layers: 2, 155 | deco_cnn_depth: 48, 156 | deco_cond_choice: trand, 157 | ctx_aug: none, 158 | } 159 | decoder_deco: { 160 | deco_attmask: 0.75, 161 | ctx_attmaskwarmup: -1, 162 | } 163 | 164 | 165 | metaworld: 166 | 167 | task: metaworld_drawer_open 168 | encoder: {mlp_keys: '$^', cnn_keys: 'image'} 169 | decoder: {mlp_keys: '$^', cnn_keys: 'image'} 170 | dataset: {batch: 50, length: 50} 171 | time_limit: 500 172 | action_repeat: 1 173 | eval_eps: 10 174 | prefill: 5000 175 | camera: corner 176 | steps: 256000 177 | concat_embed: False 178 | enc_lr_type: no_pretrain 179 | beta: 1.0 180 | stop_steps: 255000 181 | 182 | robodesk: 183 | 184 | task: robodesk_open_slide 185 | encoder: {mlp_keys: '$^', cnn_keys: 'image'} 186 | decoder: {mlp_keys: '$^', cnn_keys: 'image'} 187 | dataset: {batch: 50, length: 50} 188 | time_limit: 500 189 | action_repeat: 1 190 | eval_eps: 10 191 | prefill: 5000 192 | camera: corner 193 | steps: 256000 194 | concat_embed: False 195 | enc_lr_type: no_pretrain 196 | beta: 1.0 197 | stop_steps: 255000 198 | 199 | small: 200 | rssm: {hidden: 200, deter: 200} 201 | af_rssm: {hidden: 200, deter: 200} 202 | 203 | 204 | plaincnn: 205 | encoder_type: plaincnn 206 | decoder_type: plaincnn 207 | 208 | 209 | plainresnet: 210 | encoder_type: resnet 211 | decoder_type: resnet 212 | 213 | 214 | contextualized: 215 | encoder_type: deco_resnet 216 | decoder_type: deco_resnet -------------------------------------------------------------------------------- /configs/apv_finetuning.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | 3 | # Train Script 4 | logdir: /dev/null 5 | load_logdir: none 6 | seed: 0 7 | device: cuda 8 | wandb: {project: 'world-model',name: 'apv_finetuning',mode: 'online'} 9 | task: 
metaworld_drawer_open 10 | envs: 1 11 | envs_parallel: none 12 | render_size: [64, 64] 13 | dmc_camera: -1 14 | camera: corner 15 | dmcr_vary: all 16 | atari_grayscale: True 17 | time_limit: 0 18 | action_repeat: 1 19 | steps: 1e8 20 | log_every: 1e4 21 | eval_every: 1e4 22 | eval_eps: 1 23 | prefill: 10000 24 | pretrain: 100 25 | train_every: 5 26 | train_steps: 1 27 | expl_until: 0 28 | replay: {capacity: 1e6, ongoing: False, minlen: 50, maxlen: 50, prioritize_ends: True} 29 | dataset: {batch: 16, length: 50} 30 | log_keys_video: ['image'] 31 | log_keys_sum: '^$' 32 | log_keys_mean: '^$' 33 | log_keys_max: '^$' 34 | precision: 16 35 | jit: True 36 | stop_steps: -1 37 | 38 | # CARLA 39 | carla_port: 2000 40 | carla: { 41 | collision_coeff: 1e-3, 42 | num_other_vehicles: 20, 43 | centering_reward_type: div, 44 | centering_reward_weight: 1.0, 45 | clip_collision_reward: 10.0, 46 | steer_coeff: 1.0, 47 | centering_border: 1.75, 48 | use_branch_lane_cut: True, 49 | changing_weather_speed: 0.1, 50 | } 51 | 52 | # Agent 53 | clip_rewards: identity 54 | expl_behavior: greedy 55 | expl_noise: 0.0 56 | eval_noise: 0.0 57 | eval_state_mean: False 58 | 59 | # Fine-tuning parameters 60 | load_modules: [encoder, decoder, af_rssm] 61 | load_strict: True 62 | enc_lr_type: no_pretrain 63 | concat_embed: False 64 | 65 | # Intrinsic bonus parameters 66 | k: 16 67 | beta: 1.0 68 | beta_type: abs 69 | intr_seq_length: 5 70 | intr_reward_norm: {momentum: 0.99, scale: 1.0, eps: 1e-8, init: 1.0} 71 | queue_size: 4096 72 | queue_dim: 128 73 | 74 | # World Model 75 | grad_heads: [decoder, reward] 76 | pred_discount: False 77 | rssm: {action_free: False, fill_action: 50, ensemble: 1, embed_dim: 2048,hidden: 1024, deter: 1024, stoch: 32, discrete: 32, act: elu, norm: none, std_act: sigmoid2, min_std: 0.1} 78 | af_rssm: {action_free: True, fill_action: 50, ensemble: 1, embed_dim: 3072,hidden: 1024, deter: 1024, stoch: 32, discrete: 32, act: elu, norm: none, std_act: sigmoid2, min_std: 0.1} 79 | encoder_type: resnet # ['plaincnn', 'resnet', 'deco_resnet'] 80 | encoder: { 81 | mlp_keys: '.*', 82 | cnn_keys: '.*', 83 | act: elu, 84 | norm: none, 85 | cnn_depth: 48, 86 | cnn_kernels: [4, 4, 4, 4], 87 | mlp_layers: [400, 400, 400, 400], 88 | res_norm: 'batch', 89 | res_depth: 3, 90 | res_layers: 2, 91 | } 92 | decoder_type: resnet # ['plaincnn', 'resnet', 'deco_resnet'] 93 | decoder: { 94 | mlp_keys: '.*', 95 | cnn_keys: '.*', 96 | act: elu, 97 | norm: none, 98 | cnn_depth: 48, 99 | cnn_kernels: [5, 5, 6, 6], 100 | mlp_layers: [400, 400, 400, 400], 101 | res_norm: 'batch', 102 | res_depth: 3, 103 | res_layers: 2, 104 | } 105 | reward_head: {layers: 4, input_dim: 2048, units: 400, act: elu, norm: none, dist: mse} 106 | discount_head: {layers: 4, units: 400, act: elu, norm: none, dist: binary} 107 | loss_scales: { 108 | af_kl: 0.0, 109 | kl: 1.0, 110 | reward: 1.0, 111 | action: 1.0, 112 | discount: 1.0, 113 | proprio: 1.0, 114 | aux_reward: 0.0, 115 | } 116 | kl: {free: 0.0, forward: False, balance: 0.8, free_avg: True} 117 | model_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100, wd: 1e-6} 118 | enc_model_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100, wd: 1e-6} 119 | 120 | # Actor Critic 121 | actor: {layers: 4, input_dim: 2048, units: 400, act: elu, norm: none, dist: auto, min_std: 0.1} 122 | critic: {layers: 4, input_dim: 2048, units: 400, act: elu, norm: none, dist: mse} 123 | actor_opt: {opt: adam, lr: 8e-5, eps: 1e-5, clip: 100, wd: 1e-6} 124 | critic_opt: {opt: adam, lr: 8e-5, eps: 1e-5, clip: 100, wd: 1e-6} 125 | 
discount: 0.99 126 | discount_lambda: 0.95 127 | imag_horizon: 15 128 | imag_batch: -1 129 | actor_grad: auto 130 | actor_grad_mix: 0.1 131 | actor_ent: 1e-4 132 | slow_target: True 133 | slow_target_update: 100 134 | slow_target_fraction: 1 135 | slow_baseline: True 136 | reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8} 137 | 138 | # Exploration 139 | expl_intr_scale: 1.0 140 | expl_extr_scale: 0.0 141 | expl_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100, wd: 1e-6} 142 | expl_head: {layers: 4, units: 400, act: elu, norm: none, dist: mse} 143 | expl_reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8} 144 | disag_target: stoch 145 | disag_log: False 146 | disag_models: 10 147 | disag_offset: 1 148 | disag_action_cond: True 149 | expl_model_loss: kl 150 | 151 | # Contextualized World Model (subset of Decoupled World Model) 152 | # Decoupled World Model 153 | encoder_deco: { 154 | deco_res_layers: 2, 155 | deco_cnn_depth: 48, 156 | deco_cond_choice: trand, 157 | ctx_aug: none, 158 | } 159 | decoder_deco: { 160 | deco_attmask: 0.75, 161 | ctx_attmaskwarmup: -1, 162 | } 163 | 164 | 165 | metaworld: 166 | 167 | task: metaworld_drawer_open 168 | encoder: {mlp_keys: '$^', cnn_keys: 'image'} 169 | decoder: {mlp_keys: '$^', cnn_keys: 'image'} 170 | dataset: {batch: 50, length: 50} 171 | time_limit: 500 172 | action_repeat: 1 173 | eval_eps: 10 174 | prefill: 5000 175 | camera: corner 176 | steps: 256000 177 | concat_embed: False 178 | enc_lr_type: no_pretrain 179 | beta: 1.0 180 | stop_steps: 255000 181 | 182 | robodesk: 183 | 184 | task: robodesk_open_slide 185 | encoder: {mlp_keys: '$^', cnn_keys: 'image'} 186 | decoder: {mlp_keys: '$^', cnn_keys: 'image'} 187 | dataset: {batch: 50, length: 50} 188 | time_limit: 500 189 | action_repeat: 1 190 | eval_eps: 10 191 | prefill: 5000 192 | camera: corner 193 | steps: 256000 194 | concat_embed: False 195 | enc_lr_type: no_pretrain 196 | beta: 1.0 197 | stop_steps: 255000 198 | 199 | 200 | small: 201 | rssm: {hidden: 200, deter: 200} 202 | af_rssm: {hidden: 200, deter: 200} 203 | 204 | 205 | plaincnn: 206 | encoder_type: plaincnn 207 | decoder_type: plaincnn 208 | 209 | 210 | plainresnet: 211 | encoder_type: resnet 212 | decoder_type: resnet 213 | 214 | 215 | contextualized: 216 | encoder_type: deco_resnet 217 | decoder_type: deco_resnet 218 | -------------------------------------------------------------------------------- /configs/prelar_finetuning.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | 3 | # Train Script 4 | logdir: /dev/null 5 | load_logdir: none 6 | seed: 0 7 | device: cuda 8 | wandb: {project: 'world-model',name: 'prelar_finetuning',mode: 'online'} 9 | task: metaworld_drawer_open 10 | envs: 1 11 | envs_parallel: none 12 | render_size: [64, 64] 13 | dmc_camera: -1 14 | camera: corner 15 | dmcr_vary: all 16 | atari_grayscale: True 17 | time_limit: 0 18 | action_repeat: 1 19 | steps: 1e8 20 | log_every: 1e4 21 | eval_every: 1e4 22 | eval_eps: 1 23 | prefill: 10000 24 | pretrain: 100 25 | train_every: 5 26 | train_steps: 1 27 | expl_until: 0 28 | replay: {capacity: 1e6, ongoing: False, minlen: 50, maxlen: 50, prioritize_ends: True} 29 | dataset: {batch: 16, length: 50} 30 | log_keys_video: ['image'] 31 | log_keys_sum: '^$' 32 | log_keys_mean: '^$' 33 | log_keys_max: '^$' 34 | precision: 16 35 | jit: True 36 | stop_steps: -1 37 | 38 | # CARLA 39 | carla_port: 2000 40 | carla: { 41 | collision_coeff: 1e-3, 42 | num_other_vehicles: 20, 43 | centering_reward_type: div, 44 | 
centering_reward_weight: 1.0, 45 | clip_collision_reward: 10.0, 46 | steer_coeff: 1.0, 47 | centering_border: 1.75, 48 | use_branch_lane_cut: True, 49 | changing_weather_speed: 0.1, 50 | } 51 | 52 | # Agent 53 | clip_rewards: identity 54 | expl_behavior: greedy 55 | expl_noise: 0.0 56 | eval_noise: 0.0 57 | eval_state_mean: False 58 | 59 | # Fine-tuning parameters 60 | load_modules: [encoder, decoder, af_rssm] 61 | load_strict: True 62 | enc_lr_type: no_pretrain 63 | concat_embed: False 64 | finetune_rssm: False # False # test 65 | 66 | 67 | # Intrinsic bonus parameters 68 | k: 16 69 | beta: 1.0 70 | beta_type: abs 71 | intr_seq_length: 5 72 | intr_reward_norm: {momentum: 0.99, scale: 1.0, eps: 1e-8, init: 1.0} 73 | queue_size: 4096 74 | queue_dim: 128 75 | use_feat: False 76 | 77 | # World Model 78 | grad_heads: [decoder, reward] 79 | pred_discount: False 80 | rssm: {action_free: False, ensemble: 1,embed_dim: 3072, hidden: 1024, deter: 1024, stoch: 32, discrete: 32, act: elu, norm: none, std_act: sigmoid2, min_std: 0.1} 81 | 82 | encoder_type: resnet # ['plaincnn', 'resnet', 'deco_resnet'] 83 | encoder: { 84 | mlp_keys: '.*', 85 | cnn_keys: '.*', 86 | act: elu, 87 | norm: none, 88 | cnn_depth: 48, 89 | cnn_kernels: [4, 4, 4, 4], 90 | mlp_layers: [400, 400, 400, 400], 91 | res_norm: 'batch', 92 | res_depth: 3, 93 | res_layers: 2, 94 | } 95 | decoder_type: resnet # ['plaincnn', 'resnet', 'deco_resnet'] 96 | decoder: { 97 | mlp_keys: '.*', 98 | cnn_keys: '.*', 99 | act: elu, 100 | norm: none, 101 | cnn_depth: 48, 102 | cnn_kernels: [5, 5, 6, 6], 103 | mlp_layers: [400, 400, 400, 400], 104 | res_norm: 'batch', 105 | res_depth: 3, 106 | res_layers: 2, 107 | } 108 | reward_head: {layers: 4, input_dim: 2048,units: 400, act: elu, norm: none, dist: mse} 109 | discount_head: {layers: 4, units: 400, act: elu, norm: none, dist: binary} 110 | loss_scales: { 111 | af_kl: 0.0, 112 | kl: 1.0, 113 | reward: 1.0, 114 | action: 1.0, 115 | discount: 1.0, 116 | proprio: 1.0, 117 | aux_reward: 0.0, 118 | } 119 | kl: {free: 0.0, forward: False, balance: 0.8, free_avg: True} 120 | model_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100, wd: 1e-6} 121 | enc_model_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100, wd: 1e-6} 122 | 123 | # Actor Critic 124 | actor: {layers: 4,input_dim: 2048,units: 400, act: elu, norm: none, dist: auto, min_std: 0.1} 125 | critic: {layers: 4,input_dim: 2048, units: 400, act: elu, norm: none, dist: mse} 126 | actor_opt: {opt: adam, lr: 8e-5, eps: 1e-5, clip: 100, wd: 1e-6} 127 | critic_opt: {opt: adam, lr: 8e-5, eps: 1e-5, clip: 100, wd: 1e-6} 128 | discount: 0.99 129 | discount_lambda: 0.95 130 | imag_horizon: 15 131 | imag_batch: -1 132 | actor_grad: auto 133 | actor_grad_mix: 0.1 134 | actor_ent: 1e-4 135 | slow_target: True 136 | slow_target_update: 100 137 | slow_target_fraction: 1 138 | slow_baseline: True 139 | reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8} 140 | 141 | # Exploration 142 | expl_intr_scale: 1.0 143 | expl_extr_scale: 0.0 144 | expl_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100, wd: 1e-6} 145 | expl_head: {layers: 4, units: 400, act: elu, norm: none, dist: mse} 146 | expl_reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8} 147 | disag_target: stoch 148 | disag_log: False 149 | disag_models: 10 150 | disag_offset: 1 151 | disag_action_cond: True 152 | expl_model_loss: kl 153 | 154 | # Contextualized World Model (subset of Decoupled World Model) 155 | # Decoupled World Model 156 | encoder_deco: { 157 | deco_res_layers: 2, 158 | deco_cnn_depth: 48, 159 | 
deco_cond_choice: trand, 160 | ctx_aug: none, 161 | } 162 | decoder_deco: { 163 | deco_attmask: 0.75, 164 | ctx_attmaskwarmup: -1, 165 | } 166 | 167 | action_encoder: { 168 | hidden_dim: 1024, 169 | act: elu, 170 | deter: 1024, 171 | discrete: 32, 172 | stoch: 32, 173 | std_act: sigmoid2, 174 | type_: stoch, # deter | stoch | mix 175 | twl: 1 176 | } 177 | 178 | 179 | metaworld: 180 | 181 | task: metaworld_drawer_open 182 | encoder: {mlp_keys: '$^', cnn_keys: 'image'} 183 | decoder: {mlp_keys: '$^', cnn_keys: 'image'} 184 | dataset: {batch: 50, length: 50} 185 | time_limit: 500 186 | action_repeat: 1 187 | eval_eps: 10 188 | prefill: 5000 189 | camera: corner 190 | steps: 256000 191 | concat_embed: False 192 | enc_lr_type: no_pretrain 193 | beta: 1.0 194 | stop_steps: 255000 195 | 196 | robodesk: 197 | 198 | task: robodesk_open_slide 199 | encoder: {mlp_keys: '$^', cnn_keys: 'image'} 200 | decoder: {mlp_keys: '$^', cnn_keys: 'image'} 201 | dataset: {batch: 50, length: 50} 202 | time_limit: 500 203 | action_repeat: 1 204 | eval_eps: 10 205 | prefill: 5000 206 | camera: corner 207 | steps: 256000 208 | concat_embed: False 209 | enc_lr_type: no_pretrain 210 | beta: 1.0 211 | stop_steps: 255000 212 | 213 | 214 | small: 215 | rssm: {hidden: 200, deter: 200} 216 | af_rssm: {hidden: 200, deter: 200} 217 | 218 | 219 | plaincnn: 220 | encoder_type: plaincnn 221 | decoder_type: plaincnn 222 | 223 | 224 | plainresnet: 225 | encoder_type: resnet 226 | decoder_type: resnet 227 | 228 | 229 | contextualized: 230 | encoder_type: deco_resnet 231 | decoder_type: deco_resnet 232 | -------------------------------------------------------------------------------- /wmlib/nets/va_net/va_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.distributions as tdists 5 | import numpy as np 6 | from einops import rearrange 7 | 8 | from ..encoder import VAResNetEncoder 9 | from ..modules import get_act_module 10 | from .. 
import core 11 | from ...core import dists 12 | 13 | 14 | class VANet(nn.Module): 15 | # def __init__(self,shape, config) -> None: 16 | def __init__( 17 | self, 18 | shape, 19 | cnn_keys=r".*", 20 | mlp_keys=r".*", 21 | act="elu", 22 | cnn_depth=48, 23 | mlp_layers=[400, 400, 400, 400], 24 | res_layers=2, 25 | res_depth=3, 26 | res_norm='none', 27 | hidden_dim=1024, 28 | deter=0, 29 | stoch = 64, 30 | discrete=0, 31 | std_act='sigmoid2', 32 | va_method='concate', # concate | diff | flow | cat-diff 33 | type_='stoch', # deter | stoch | mix 34 | **dummy_kwargs ) -> None: 35 | super().__init__() 36 | h,w,_ = shape['image'] 37 | self._encode_dim = cnn_depth * 2** (res_depth - 1) * h//2**(res_depth+1) * w//2**(res_depth+1) 38 | self._hidden_dim = hidden_dim 39 | self._act_module = get_act_module(act) 40 | self._stoch = stoch 41 | self._discrete = discrete 42 | self._deter = deter 43 | self._std_act = std_act 44 | self._va_method = va_method 45 | self.type = type_ 46 | self._patch_size = (h//16,w//16) 47 | self._patch_split = (16,16) 48 | self._patch_num = self._patch_split[0] * self._patch_split[1] 49 | 50 | action_dim = self._stoch * self._discrete if self._discrete else self._stoch * 2 51 | action_dim = self._deter if self.type =='deter' else action_dim 52 | action_dim = self._deter + action_dim if self.type == 'mix' else action_dim 53 | self.va_encoder = VAResNetEncoder(shapes=shape,cnn_keys=cnn_keys,mlp_keys=mlp_keys, 54 | act=act,cnn_depth=cnn_depth,mlp_layers=mlp_layers, 55 | res_layers=res_layers,res_depth=res_depth, 56 | res_norm=res_norm,va_method=self._va_method) #nets.VAEncoder(shapes, **config.va_encoder) 57 | self.action_in = nn.Sequential(nn.Linear(self._encode_dim,self._hidden_dim), 58 | self._act_module(), 59 | nn.Linear(self._hidden_dim,action_dim)) 60 | self._std_fn = { 61 | "softplus": lambda std: F.softplus(std), 62 | "sigmoid": lambda std: torch.sigmoid(std), 63 | "sigmoid2": lambda std: 2 * torch.sigmoid(std / 2), 64 | }[self._std_act] 65 | 66 | def forward(self,img0,img1,sample=True): 67 | if self._va_method in ['concate', 'mask']: 68 | x = torch.cat([img0,img1],dim=2) # b t c h w 69 | elif self._va_method == 'diff': 70 | x = img1 - img0 71 | elif self._va_method == 'flow': 72 | x = img1 73 | elif self._va_method == 'catdiff': 74 | x = torch.cat([img0,img1,img1-img0],dim=2) 75 | x = self.va_encoder({'image':x}) 76 | x = self.action_in(x) 77 | action_code = {} 78 | if self.type == 'deter': 79 | return {'deter':x} 80 | elif self.type == 'stoch': 81 | x_stoch = x 82 | elif self.type == 'mix': 83 | x_deter = x[..., :self._deter] 84 | x_stoch = x[..., self._deter:] 85 | action_code.update({'deter':x_deter}) 86 | else: 87 | raise NotImplementedError 88 | 89 | if self._discrete: 90 | x_stoch = rearrange(x_stoch,'... (s d) -> ... s d', s=self._stoch, d=self._discrete) 91 | dist = self.get_dist({'logit':x_stoch}) 92 | stoch = dist.sample() if sample else dist.mode 93 | stoch = rearrange(stoch,'... s d -> ... 
(s d)') 94 | action_code.update({'stoch':stoch,'logit':x_stoch}) 95 | else: 96 | mean, std = torch.chunk(x_stoch, 2, dim=-1) 97 | dist = self.get_dist({'mean':mean,'std':std}) 98 | stoch = dist.sample() if sample else dist.mode 99 | action_code.update({'stoch':stoch,'mean':mean,'std':std}) 100 | return action_code 101 | 102 | def get_dist(self, state): 103 | """ 104 | gets the stochastic state distribution 105 | """ 106 | if self._discrete: 107 | logit = state["logit"] 108 | logit = logit.float() 109 | dist = dists.Independent(dists.OneHotDist(logit), 1) 110 | else: 111 | mean, std = state['mean'], state['std'] 112 | std = self._std_fn(std) 113 | dist = dists.Independent(dists.Normal(mean, std), 1) 114 | return dist 115 | 116 | def append_action(self, action, ahead = False): 117 | new_action = torch.zeros_like(action[:,0,:],device=action.device).unsqueeze(dim=1) 118 | if ahead: 119 | return torch.cat([new_action, action], dim=1) 120 | return torch.cat([action, new_action], dim=1) 121 | 122 | def kl_loss(self, post, prior=None, forward: bool=False, balance: float=0.5, free: float=0.0, free_avg: bool=True): 123 | """ 124 | computes the kl loss 125 | """ 126 | if self.type == 'deter': 127 | value = torch.tensor(0,dtype=float,device=post['deter'].device) 128 | return value, value 129 | if self._discrete and prior is None: 130 | value = torch.log(torch.tensor(self._discrete,device=post['logit'].device)) 131 | return value, value 132 | kld = tdists.kl_divergence 133 | sg = core.dict_detach 134 | _device = post['stoch'].device 135 | if prior is None: 136 | prior = {'mean':torch.zeros_like(post['mean'],device=_device),'std':torch.ones_like(post['std'],device=_device)} 137 | 138 | lhs, rhs = (prior, post) if forward else (post, prior) 139 | mix = balance if forward else (1 - balance) 140 | 141 | free = torch.tensor(free) 142 | if balance == 0.5: 143 | value = kld(self.get_dist(lhs), self.get_dist(rhs)) 144 | loss = torch.maximum(value, free).mean() 145 | else: 146 | value_lhs = value = kld(self.get_dist(lhs), self.get_dist(sg(rhs))) 147 | value_rhs = kld(self.get_dist(sg(lhs)), self.get_dist(rhs)) 148 | if free_avg: 149 | loss_lhs = torch.maximum(value_lhs.mean(), free) 150 | loss_rhs = torch.maximum(value_rhs.mean(), free) 151 | else: 152 | loss_lhs = torch.maximum(value_lhs, free).mean() 153 | loss_rhs = torch.maximum(value_rhs, free).mean() 154 | loss = mix * loss_lhs + (1 - mix) * loss_rhs 155 | return loss, value 156 | 157 | 158 | 159 | -------------------------------------------------------------------------------- /wmlib/agents/actor_critic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .. 
import core, nets 4 | 5 | 6 | class ActorCritic(nn.Module): 7 | 8 | def __init__(self, config, act_space, tfstep, enable_fp16=False): 9 | super(ActorCritic, self).__init__() 10 | 11 | self.config = config 12 | self.act_space = act_space 13 | self.tfstep = tfstep 14 | self.enable_fp16 = enable_fp16 15 | discrete = hasattr(act_space, "n") 16 | if self.config.actor.dist == "auto": 17 | self.config = self.config.update({ 18 | "actor.dist": "onehot" if discrete else "trunc_normal"}) 19 | if self.config.actor_grad == "auto": 20 | self.config = self.config.update({ 21 | "actor_grad": "reinforce" if discrete else "dynamics"}) 22 | self.actor = nets.MLP(act_space.shape[0], **self.config.actor) 23 | self.critic = nets.MLP([], **self.config.critic) 24 | if self.config.slow_target: 25 | self._target_critic = nets.MLP([], **self.config.critic) 26 | self._updates = 0 27 | else: 28 | self._target_critic = self.critic 29 | self.actor_opt = core.EmptyOptimizer() 30 | self.critic_opt = core.EmptyOptimizer() 31 | self.rewnorm = core.StreamNorm(**self.config.reward_norm) 32 | 33 | def train(self, world_model, start, is_terminal, reward_fn): 34 | metrics = {} 35 | hor = self.config.imag_horizon 36 | # The weights are is_terminal flags for the imagination start states. 37 | # Technically, they should multiply the losses from the second trajectory 38 | # step onwards, which is the first imagined step. However, we are not 39 | # training the action that led into the first step anyway, so we can use 40 | # them to scale the whole sequence. 41 | with torch.cuda.amp.autocast(enabled=self.enable_fp16): 42 | # delete grads 43 | world_model.zero_grad(set_to_none=True) 44 | self.actor.zero_grad(set_to_none=True) 45 | self.critic.zero_grad(set_to_none=True) 46 | 47 | seq = world_model.imagine(self.actor, start, is_terminal, hor) 48 | if self.config.actor_grad == "reinforce": 49 | with torch.no_grad(): 50 | reward = reward_fn(seq) 51 | else: 52 | reward = reward_fn(seq) 53 | seq["reward"], mets1 = self.rewnorm(reward) 54 | mets1 = {f"reward_{k}": v for k, v in mets1.items()} 55 | target, mets2 = self.target(seq) 56 | actor_loss, mets3 = self.actor_loss(seq, target) 57 | critic_loss, mets4 = self.critic_loss(seq, target) 58 | 59 | # Backward passes under autocast are not recommended. 60 | self.actor_opt.backward(actor_loss, retain_graph=True) 61 | self.critic_opt.backward(critic_loss) 62 | 63 | metrics.update(self.actor_opt.step(actor_loss)) 64 | metrics.update(self.critic_opt.step(critic_loss)) 65 | metrics.update(**mets1, **mets2, **mets3, **mets4) 66 | self.update_slow_target() 67 | return metrics 68 | 69 | def actor_loss(self, seq, target): 70 | # Actions: 0 [a1] [a2] a3 71 | # ^ | ^ | ^ | 72 | # / v / v / v 73 | # States: [z0]->[z1]-> z2 -> z3 74 | # Targets: t0 [t1] [t2] 75 | # Baselines: [v0] [v1] v2 v3 76 | # Entropies: [e1] [e2] 77 | # Weights: [ 1] [w1] w2 w3 78 | # Loss: l1 l2 79 | metrics = {} 80 | # Two states are lost at the end of the trajectory, one for the boostrap 81 | # value prediction and one because the corresponding action does not lead 82 | # anywhere anymore. One target is lost at the start of the trajectory 83 | # because the initial state comes from the replay buffer. 
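        # Concrete shape check (illustrative, matching the slicing below): with
        # imagination horizon H, seq["feat"] and seq["action"] hold H + 1 entries,
        # while the lambda-return target from self.target holds H entries, so
        # feat[:-2], target[1:], action[1:-1], and weight[:-2] all align on the
        # same H - 1 time steps.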
84 | policy = self.actor(seq["feat"][:-2].detach()) 85 | if self.config.actor_grad == "dynamics": 86 | objective = target[1:] 87 | elif self.config.actor_grad == "reinforce": 88 | baseline = self._target_critic(seq["feat"][:-2]).mode 89 | advantage = (target[1:] - baseline).detach() 90 | action = (seq["action"][1:-1]).detach() 91 | objective = policy.log_prob(action) * advantage 92 | elif self.config.actor_grad == "both": 93 | baseline = self._target_critic(seq["feat"][:-2]).mode 94 | advantage = (target[1:] - baseline).detach() 95 | action = (seq["action"][1:-1]).detach() 96 | objective = policy.log_prob(action) * advantage 97 | mix = core.schedule(self.config.actor_grad_mix, self.tfstep) 98 | objective = mix * target[1:] + (1 - mix) * objective 99 | metrics["actor_grad_mix"] = mix 100 | else: 101 | raise NotImplementedError(self.config.actor_grad) 102 | ent = policy.entropy() 103 | ent_scale = core.schedule(self.config.actor_ent, self.tfstep) 104 | objective += ent_scale * ent 105 | weight = seq["weight"].detach() 106 | actor_loss = -(weight[:-2] * objective).mean() 107 | metrics["actor_entropy"] = ent.mean().item() 108 | metrics["actor_entropy_scale"] = ent_scale 109 | return actor_loss, metrics 110 | 111 | def critic_loss(self, seq, target): 112 | # States: [z0] [z1] [z2] z3 113 | # Rewards: [r0] [r1] [r2] r3 114 | # Values: [v0] [v1] [v2] v3 115 | # Weights: [ 1] [w1] [w2] w3 116 | # Targets: [t0] [t1] [t2] 117 | # Loss: l0 l1 l2 118 | dist = self.critic(seq["feat"][:-1].detach()) 119 | target = target.detach() 120 | weight = seq["weight"].detach() 121 | critic_loss = -(dist.log_prob(target) * weight[:-1]).mean() 122 | metrics = {"critic": dist.mode.mean().item()} 123 | return critic_loss, metrics 124 | 125 | def target(self, seq): 126 | # States: [z0] [z1] [z2] [z3] 127 | # Rewards: [r0] [r1] [r2] r3 128 | # Values: [v0] [v1] [v2] [v3] 129 | # Discount: [d0] [d1] [d2] d3 130 | # Targets: t0 t1 t2 131 | reward = seq["reward"] 132 | disc = seq["discount"] 133 | value = self._target_critic(seq["feat"]).mode 134 | # Skipping last time step because it is used for bootstrapping. 135 | target = core.lambda_return( 136 | reward[:-1], value[:-1], disc[:-1], 137 | bootstrap=value[-1], 138 | lambda_=self.config.discount_lambda, 139 | axis=0) 140 | metrics = {} 141 | metrics["critic_slow"] = value.mean().item() 142 | metrics["critic_target"] = target.mean().item() 143 | return target, metrics 144 | 145 | def update_slow_target(self): # polyak update 146 | if self.config.slow_target: 147 | if self._updates % self.config.slow_target_update == 0: 148 | mix = 1.0 if self._updates == 0 else float( 149 | self.config.slow_target_fraction) 150 | for s, d in zip(self.critic.parameters(), self._target_critic.parameters()): 151 | d.data = mix * s.data + (1 - mix) * d.data 152 | self._updates += 1 153 | -------------------------------------------------------------------------------- /wmlib/agents/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from abc import ABC, abstractmethod 4 | 5 | from . import expl 6 | from .. 
import core 7 | 8 | 9 | class BaseAgent(nn.Module, ABC): 10 | def __init__(self, config, obs_space, act_space, step): 11 | super(BaseAgent, self).__init__() 12 | 13 | self.config = config 14 | self.obs_space = obs_space 15 | self.act_space = act_space["action"] 16 | self.step = step 17 | self.precision = config.precision 18 | self.enable_fp16 = self.precision == 16 19 | 20 | def init_expl_behavior(self): 21 | if self.config.expl_behavior == "greedy": 22 | self._expl_behavior = self._task_behavior 23 | else: 24 | self._expl_behavior = getattr(expl, self.config.expl_behavior)( 25 | self.config, 26 | self.act_space, 27 | self.wm, 28 | self.step, 29 | lambda seq: self.wm.heads["reward"](seq["feat"]).mode(), 30 | ) 31 | 32 | def init_modules(self): 33 | # * Hacky: init modules without optimizers (once in opt) 34 | with torch.no_grad(): 35 | # bs, sq = 4, max(8, self.config.intr_seq_length) 36 | bs, sq = 1, max(4, self.config.intr_seq_length) 37 | if "atari" in self.config.task: 38 | channels = 1 if self.config.atari_grayscale else 3 39 | elif "dmc" in self.config.task or "metaworld" in self.config.task or "carla" in self.config.task: 40 | channels = 3 41 | else: 42 | raise NotImplementedError 43 | actions = self.act_space.shape[0] 44 | dummy_data = { 45 | "image": torch.rand(bs, sq, channels, *self.config.render_size), 46 | "action": torch.rand(bs, sq, actions), 47 | "reward": torch.rand(bs, sq), 48 | "is_first": torch.rand(bs, sq), 49 | "is_last": torch.rand(bs, sq), 50 | "is_terminal": torch.rand(bs, sq), 51 | } 52 | dummy_data["is_first"] = torch.zeros_like(dummy_data["is_first"]) 53 | dummy_data["is_first"][:, 0] = 1.0 54 | for key in self.obs_space: 55 | if key not in dummy_data: 56 | dummy_data[key] = torch.rand(bs, sq, *self.obs_space[key].shape) 57 | # TODO: we should not update the model here 58 | self.train(dummy_data) 59 | 60 | @abstractmethod 61 | def init_optimizers(self): 62 | pass 63 | 64 | def train(self, data, state=None): 65 | metrics = {} 66 | self.wm.train() 67 | state, outputs, wm_metrics = self.wm.train_iter(data, state) 68 | self.wm.eval() 69 | metrics.update({f'wm/{k}':v for k, v in wm_metrics.items()}) 70 | 71 | start = outputs["post"] 72 | start = core.dict_detach(start) 73 | reward = lambda seq: self.wm.heads["reward"](seq["feat"]).mode 74 | behavior_metrics = self._task_behavior.train(self.wm, start, data["is_terminal"], reward) 75 | metrics.update({f'behavior/{k}':v for k, v in behavior_metrics.items()}) 76 | if self.config.expl_behavior != "greedy": 77 | behavior_metrics = self._expl_behavior.train(start, outputs, data)[-1] 78 | metrics.update({"expl/" + key: value for key, value in behavior_metrics.items()}) 79 | return core.dict_detach(state), metrics 80 | 81 | def get_action(self, feat, mode): 82 | if mode == "eval": 83 | actor = self._task_behavior.actor(feat) 84 | action = actor.mode 85 | noise = self.config.eval_noise 86 | elif mode == "explore": 87 | actor = self._expl_behavior.actor(feat) 88 | action = actor.sample() 89 | noise = self.config.expl_noise 90 | elif mode == "train": 91 | actor = self._task_behavior.actor(feat) 92 | action = actor.sample() 93 | noise = self.config.expl_noise 94 | action = core.action_noise(action, noise, self.act_space) 95 | return action 96 | 97 | def save_all(self, logdir): 98 | torch.save(self.state_dict(), logdir / "variables.pt") 99 | 100 | 101 | class BaseWorldModel(nn.Module, ABC): 102 | 103 | def preprocess(self, obs): 104 | obs = obs.copy() 105 | for key, value in obs.items(): 106 | if key.startswith("log_"): 107 | 
continue 108 | if value.dtype == torch.int32: 109 | value = value.float() 110 | if value.dtype == torch.uint8: 111 | # value = value.float() / 255.0 - 0.5 112 | value = value.float() 113 | obs[key] = value 114 | 115 | obs["image"] = obs["image"] / 255.0 - 0.5 116 | if self.config.clip_rewards in ["identity", "sign", "tanh"]: 117 | obs["reward"] = { 118 | "identity": (lambda x: x), 119 | "sign": torch.sign, 120 | "tanh": torch.tanh, 121 | }[self.config.clip_rewards](obs["reward"]) 122 | else: 123 | obs["reward"] /= float(self.config.clip_rewards) 124 | obs["discount"] = 1.0 - obs["is_terminal"].float() 125 | obs["discount"] *= self.config.discount 126 | return obs 127 | 128 | def imagine(self, policy, start, is_terminal, horizon): 129 | flatten = lambda x: x.reshape([-1] + list(x.shape[2:])) 130 | start = {k: flatten(v) for k, v in start.items()} 131 | if self.config.imag_batch != -1: 132 | index = torch.randperm(len(start["deter"]), device=start["deter"].device)[:self.config.imag_batch] 133 | select = lambda x: torch.index_select(x, dim=0, index=index) 134 | start = {k: select(v) for k, v in start.items()} 135 | start["feat"] = self.rssm.get_feat(start) 136 | start["action"] = torch.zeros_like(policy(start["feat"]).mode) 137 | seq = {k: [v] for k, v in start.items()} 138 | for _ in range(horizon): 139 | action = policy(seq["feat"][-1].detach()).sample() 140 | state = self.rssm.img_step({k: v[-1] for k, v in seq.items()}, action) 141 | feat = self.rssm.get_feat(state) 142 | for key, value in {**state, "action": action, "feat": feat}.items(): 143 | seq[key].append(value) 144 | seq = {k: torch.stack(v, 0) for k, v in seq.items()} 145 | if "discount" in self.heads: 146 | disc = self.heads["discount"](seq["feat"]).mean 147 | if is_terminal is not None: 148 | # Override discount prediction for the first step with the true 149 | # discount factor from the replay buffer. 150 | true_first = 1.0 - flatten(is_terminal).to(disc.dtype) 151 | true_first *= self.config.discount 152 | disc = torch.cat([true_first[None], disc[1:]], 0) 153 | else: 154 | disc = self.config.discount * torch.ones(seq["feat"].shape[:-1]).to(seq["feat"].device) 155 | seq["discount"] = disc 156 | # Shift discount factors because they imply whether the following state 157 | # will be valid, not whether the current state is valid. 
158 | seq["weight"] = torch.cumprod( 159 | torch.cat([torch.ones_like(disc[:1]), disc[:-1]], 0), 0 160 | ) 161 | return seq 162 | -------------------------------------------------------------------------------- /wmlib/nets/encoder/deco_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import kornia 4 | import torchvision.transforms as T 5 | from einops import rearrange 6 | from einops.layers.torch import Rearrange 7 | import numpy as np 8 | 9 | from .base import BaseEncoder 10 | from ..modules import get_act_module, ResidualStack 11 | 12 | 13 | 14 | class DecoupledResNetEncoder(BaseEncoder): 15 | ''' 16 | Decoupled ResNet Encoder 17 | @Function: disentangle the visual context and the embedding 18 | ''' 19 | def __init__( 20 | self, 21 | shapes, 22 | cnn_keys=r".*", 23 | mlp_keys=r".*", 24 | act="elu", 25 | cnn_depth=48, 26 | mlp_layers=[400, 400, 400, 400], 27 | res_layers=2, 28 | res_depth=3, 29 | res_norm='none', 30 | deco_res_layers=2, 31 | deco_cnn_depth=48, 32 | deco_cond_choice='trand', 33 | deco_aug='none', 34 | **dummy_kwargs, 35 | ): 36 | super().__init__(shapes, cnn_keys, mlp_keys, mlp_layers) 37 | # self._act = get_act(act) 38 | self._act_module = get_act_module(act) 39 | self._cnn_depth = cnn_depth 40 | 41 | self._res_layers = res_layers 42 | self._res_depth = res_depth 43 | self._res_norm = res_norm 44 | 45 | self._deco_res_layers = deco_res_layers 46 | self._deco_cnn_depth = deco_cnn_depth 47 | self._deco_cond_choice = deco_cond_choice 48 | self._deco_aug = deco_aug 49 | 50 | self._cnn_net = nn.Sequential() 51 | h, w, c = self.shapes[self.cnn_keys[0]] # raw image shape 52 | self._cnn_net = nn.Sequential() 53 | self._cnn_net.add_module('convin', nn.Conv2d(c, self._cnn_depth, 3, 2, 1)) 54 | self._cnn_net.add_module('act', self._act_module()) 55 | for i in range(self._res_depth): 56 | depth = 2 ** i * self._cnn_depth 57 | input_channels = depth // 2 if i else self._cnn_depth 58 | self._cnn_net.add_module(f"res{i}", ResidualStack(input_channels, depth, 59 | self._res_layers, 60 | norm=self._res_norm)) 61 | self._cnn_net.add_module(f"pool{i}", nn.AvgPool2d(2, 2)) 62 | 63 | self._cnn_net.add_module('flatten', Rearrange('b c h w -> b (c h w)')) 64 | 65 | self._deco_net = nn.ModuleDict() 66 | self._deco_net.add_module('cond_aug',get_augmentation(self._deco_aug, (h,w))) 67 | self._deco_net.add_module('convin', nn.Sequential(nn.Conv2d(c, self._deco_cnn_depth, 3, 2, 1), 68 | self._act_module())) 69 | for i in range(self._res_depth): 70 | depth = 2 ** i * self._deco_cnn_depth 71 | input_channels = depth // 2 if i else self._deco_cnn_depth 72 | self._deco_net.add_module(f"res{i}", ResidualStack(input_channels, depth, 73 | self._deco_res_layers, 74 | norm=self._res_norm)) 75 | self._deco_net.add_module(f"pool{i}", nn.AvgPool2d(2, 2)) 76 | 77 | def forward(self, data, is_eval=False): 78 | key, shape = list(self.shapes.items())[0] 79 | batch_dims = data[key].shape[:-len(shape)] 80 | data = { 81 | k: torch.reshape(v, (-1,) + tuple(v.shape)[len(batch_dims):]) 82 | for k, v in data.items() 83 | } 84 | 85 | output, shortcut = self._cnn({k: data[k] for k in self.cnn_keys}, batch_dims, is_eval) 86 | 87 | if is_eval: 88 | return output.reshape(batch_dims + output.shape[1:]) 89 | else: 90 | return { 91 | 'embed': output.reshape(batch_dims + output.shape[1:]), 92 | 'shortcut': shortcut, 93 | } 94 | 95 | def _cnn(self, data, batch_dims=None, is_eval=False): 96 | x = torch.cat(list(data.values()), -1) 97 | x = 
x.to(memory_format=torch.channels_last) 98 | embed = self._cnn_net(x) 99 | shortcuts = {} 100 | if not is_eval: 101 | b, t = batch_dims 102 | with torch.no_grad(): 103 | ctx = self.get_context(rearrange(x, '(b t) c h w -> b t c h w', b=b))#(x.reshape(batch_dims + x.shape[1:])) # [B, T, C, H, W] => [B, C, H, W] 104 | ctx = self._deco_net['cond_aug'](ctx) # [B, C, H, W] 105 | ctx = rearrange(ctx, 'b t c h w -> (b t) c h w') 106 | ctx = self._deco_net['convin'](ctx) 107 | for i in range(self._res_depth): 108 | ctx = self._deco_net[f"res{i}"](ctx) 109 | shortcuts[ctx.shape[2]] = rearrange(ctx, '(b t) c h w -> b t c h w', b=b) 110 | ctx = self._deco_net[f"pool{i}"](ctx) 111 | return embed, shortcuts 112 | 113 | 114 | # TODO: clean up or rename t0 tlast trand 115 | def get_context(self, frames): 116 | """ 117 | frames: [B, T, C, H, W] 118 | """ 119 | with torch.no_grad(): 120 | if self._deco_cond_choice == 't0': 121 | # * initial frame 122 | context = frames[:, 0].unsqueeze(1) # [B, C, H, W] 123 | elif self._deco_cond_choice == 'tlast': 124 | # * last frame 125 | context = frames[:, -1].unsqueeze(1) # [B, C, H, W] 126 | elif self._deco_cond_choice == 'trand': 127 | # * timestep randomization 128 | idx = torch.from_numpy(np.random.choice(frames.shape[1], frames.shape[0])).to(frames.device) 129 | idx = idx.reshape(-1, 1, 1, 1, 1).repeat(1, 1, *frames.shape[-3:]) # [B, 1, C, H, W] 130 | context = frames.gather(1, idx) # .squeeze(1) # [B, 1, C, H, W]# [B, C, H, W] 131 | elif self._deco_cond_choice == 'diff': 132 | idx = np.arange(frames.shape[1]) 133 | idx[1:]-=1 134 | idx = torch.from_numpy(idx).to(frames.device) 135 | idx = idx.reshape(1,frames.shape[1], 1, 1, 1).repeat(frames.shape[0], 1, *frames.shape[-3:]) 136 | context = frames.gather(1, idx) 137 | elif self._deco_cond_choice == 'self': 138 | idx = torch.from_numpy(np.arange(frames.shape[1])).to(frames.device) 139 | idx = idx.reshape(1,frames.shape[1], 1, 1, 1).repeat(frames.shape[0], 1, *frames.shape[-3:]) 140 | context = frames.gather(1, idx) 141 | else: 142 | raise NotImplementedError 143 | return context 144 | 145 | 146 | def get_augmentation(aug_type, shape): 147 | if aug_type == 'none': 148 | return nn.Identity() 149 | elif aug_type == 'shift': 150 | return nn.Sequential( 151 | nn.ReplicationPad2d(padding=8), 152 | kornia.augmentation.RandomCrop(shape[-2:]) 153 | ) 154 | elif aug_type == 'shift4': 155 | return nn.Sequential( 156 | nn.ReplicationPad2d(padding=4), 157 | kornia.augmentation.RandomCrop(shape[-2:]) 158 | ) 159 | elif aug_type == 'flip': 160 | return T.RandomHorizontalFlip(p=0.5) 161 | elif aug_type == 'scale': 162 | return T.RandomResizedCrop( 163 | size=shape[-2:], scale=[0.666667, 1.0], ratio=(0.75, 1.333333)) 164 | elif aug_type == 'erasing': 165 | return kornia.augmentation.RandomErasing() 166 | else: 167 | raise NotImplementedError 168 | -------------------------------------------------------------------------------- /wmlib/nets/dynamics/rssm.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import jit 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.distributions as tdist 7 | from torchtyping import TensorType, patch_typeguard 8 | from typing import Tuple 9 | from typeguard import typechecked 10 | from einops import rearrange 11 | from einops.layers.torch import Rearrange 12 | from collections import OrderedDict 13 | 14 | from ... 
import core 15 | from ..modules import get_act, NormLayer 16 | from .base import BaseDynamics, State 17 | 18 | 19 | class EnsembleRSSM(BaseDynamics): 20 | r""" 21 | References: 22 | - Hafner, Danijar, et al. "Learning latent dynamics for planning from pixels." 23 | - Hafner, Danijar, et al. "Dream to control: Learning behaviors by latent imagination." 24 | - Hafner, Danijar, et al. "Mastering atari with discrete world models." 25 | 26 | """ 27 | 28 | def __init__( 29 | self, 30 | action_free=False, 31 | fill_action=None, # 50 in apv 32 | action_dim=4, 33 | embed_dim=1536, 34 | ensemble=5, 35 | stoch=30, 36 | deter=200, 37 | hidden=200, 38 | discrete=False, 39 | act="elu", 40 | norm="none", 41 | std_act="softplus", 42 | min_std=0.1, 43 | concat_embed=False, 44 | ): 45 | super().__init__( 46 | action_free, fill_action, ensemble, stoch, deter, hidden, discrete, act, norm, std_act, min_std, 47 | ) 48 | self._cell = torch.jit.script(GRUCell(self._hidden, self._deter, norm=True)) 49 | img_in_dim = self._stoch * self._discrete if self._discrete else self._stoch 50 | img_in_dim += self._fill_action if self._fill_action else action_dim 51 | self._img_in = nn.Sequential(nn.Linear(img_in_dim, self._hidden), 52 | NormLayer(self._norm, self._hidden), 53 | self._act_module()) 54 | out_in_dim = self._deter + embed_dim 55 | if concat_embed: 56 | out_in_dim *= 2 57 | self._obs_out = nn.Sequential(nn.Linear(out_in_dim, self._hidden), 58 | NormLayer(self._norm, self._hidden), 59 | self._act_module()) 60 | if self._discrete: 61 | self._obs_out.add_module('obs_out_head', nn.Linear(self._hidden, self._stoch * self._discrete)) 62 | self._obs_out.add_module('obs_out_rerange',Rearrange('... (s d) -> ... s d', s=self._stoch, d=self._discrete)) 63 | else: 64 | self._obs_out.add_module('obs_out_head', nn.Linear(self._hidden, self._stoch * 2)) 65 | 66 | def initial(self, batch_size: int, device) -> State: 67 | """ 68 | returns initial RSSM state 69 | """ 70 | 71 | if self._discrete: 72 | state = dict( 73 | logit=torch.zeros(batch_size, self._stoch, self._discrete), 74 | stoch=torch.zeros(batch_size, self._stoch, self._discrete), 75 | deter=self._cell.get_initial_state(batch_size)) 76 | else: 77 | state = dict( 78 | mean=torch.zeros(batch_size, self._stoch), 79 | std=torch.zeros(batch_size, self._stoch), 80 | stoch=torch.zeros(batch_size, self._stoch), 81 | deter=self._cell.get_initial_state(batch_size)) 82 | return core.dict_to_device(state, device) 83 | 84 | # @jit.script_method 85 | def observe( 86 | self, 87 | embed: TensorType["batch", "seq", "emb_dim"], 88 | action: TensorType["batch", "seq", "act_dim"], 89 | is_first, 90 | state: State = None 91 | ): 92 | # a permute of (batch, sequence) to (sequence, batch) 93 | swap = lambda x: rearrange(x, 'b t ... 
-> t b ...') 94 | if state is None: 95 | state = self.initial(action.shape[0], action.device) 96 | embed, action, is_first = swap(embed), swap(action), swap(is_first) 97 | post, prior = core.sequence_scan( 98 | self.obs_step, 99 | state, action, embed, is_first 100 | ) 101 | post = {k: swap(v) for k, v in post.items()} # put to (batch, sequence) again 102 | prior = {k: swap(v) for k, v in prior.items()} 103 | return post, prior 104 | 105 | def obs_step( 106 | self, 107 | prev_state: State, 108 | prev_action: TensorType["batch", "act_dim"], 109 | embed: TensorType["batch", "emb_dim"], 110 | is_first: TensorType["batch"], 111 | sample=True, 112 | ) -> Tuple[State, State]: 113 | maskout = lambda x: torch.einsum("b,b...->b...", 1.0 - is_first.to(x.dtype), x) 114 | prev_state = core.dict_apply(prev_state, maskout) 115 | prev_action = maskout(prev_action) 116 | 117 | prior = self.img_step(prev_state, prev_action, sample) 118 | x = torch.cat([prior["deter"], embed], -1) # embed is encoder conv output 119 | x = self._obs_out(x) 120 | stats = self._suff_stats_layer(x) 121 | dist = self.get_dist(stats) 122 | stoch = dist.sample() if sample else dist.mode 123 | post = {"stoch": stoch, "deter": prior["deter"], **stats} 124 | return post, prior 125 | 126 | def img_step( 127 | self, 128 | prev_state: State, 129 | prev_action: TensorType["batch", "act_dim"], 130 | sample=True, 131 | ) -> State: 132 | prev_stoch = prev_state["stoch"] 133 | if self._discrete: 134 | prev_stoch = torch.reshape(prev_stoch, (*prev_stoch.shape[:-2], self._stoch * self._discrete)) 135 | x = torch.cat([prev_stoch, self.fill_action_with_zero(prev_action)], -1) 136 | x = self._img_in(x) 137 | deter = prev_state["deter"] 138 | x, deter = self._cell(x, deter) 139 | stats = self._suff_stats_ensemble(x) 140 | index = int(tdist.Uniform(0, self._ensemble).sample().item()) 141 | stats = {k: v[index] for k, v in stats.items()} 142 | dist = self.get_dist(stats) 143 | stoch = dist.sample() if sample else dist.mode # mode: the max probability 144 | prior = {"stoch": stoch, "deter": deter, **stats} 145 | return prior 146 | 147 | 148 | class GRUCell(nn.Module): 149 | 150 | def __init__(self, input_size, size, norm=False, act=torch.tanh, update_bias=-1): 151 | super().__init__() 152 | self._size = size 153 | self._act = get_act(act) 154 | self._norm = norm 155 | self._update_bias = update_bias 156 | self._layer = nn.Linear(input_size + size, 3 * size, bias=norm is not None) 157 | if norm: 158 | self._norm = nn.LayerNorm(3 * size, eps=1e-3) # eps equal to tf 159 | 160 | @property 161 | def state_size(self): 162 | return self._size 163 | 164 | @torch.jit.export 165 | def get_initial_state(self, batch_size: int): # defined by tf.keras.layers.AbstractRNNCell 166 | return torch.zeros(batch_size, self._size) 167 | 168 | # @jit.script_method 169 | def forward(self, input, state): 170 | parts = self._layer(torch.cat([input, state], -1)) 171 | if self._norm is not False: # check if jit compatible 172 | parts = self._norm(parts) 173 | reset, cand, update = torch.chunk(parts, 3, -1) 174 | reset = torch.sigmoid(reset) 175 | cand = self._act(reset * cand) # it also multiplies the reset by the input 176 | update = torch.sigmoid(update + self._update_bias) 177 | output = update * cand + (1 - update) * state 178 | return output, output 179 | -------------------------------------------------------------------------------- /wmlib/nets/dynamics/base.py: -------------------------------------------------------------------------------- 1 | 2 | from abc import ABC, 
abstractmethod 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.distributions as tdist 7 | from torchtyping import TensorType, patch_typeguard 8 | from typing import Dict, Tuple, Union 9 | from typeguard import typechecked 10 | from collections import OrderedDict 11 | from einops.layers.torch import Rearrange 12 | 13 | from ... import core 14 | from ...core import dists 15 | from ..modules import get_act_module, NormLayer 16 | 17 | 18 | State = Dict[str, torch.Tensor] # FIXME to be more specified 19 | 20 | 21 | class BaseDynamics(nn.Module, ABC): 22 | 23 | def __init__( 24 | self, 25 | action_free=False, 26 | fill_action=None, # 50 in apv 27 | ensemble=5, 28 | stoch=30, 29 | deter=200, 30 | hidden=200, 31 | discrete=False, 32 | act="elu", 33 | norm="none", 34 | std_act="softplus", 35 | min_std=0.1, 36 | ): 37 | super().__init__() 38 | self._action_free = action_free 39 | self._fill_action = fill_action 40 | self._ensemble = ensemble 41 | self._stoch = stoch 42 | self._deter = deter 43 | self._hidden = hidden 44 | self._discrete = discrete 45 | # self._act = get_act(act) 46 | self._act_module = get_act_module(act) 47 | self._norm = norm 48 | self._std_act = std_act 49 | self._min_std = min_std 50 | self._std_fn = { 51 | "softplus": lambda std: F.softplus(std), 52 | "sigmoid": lambda std: torch.sigmoid(std), 53 | "sigmoid2": lambda std: 2 * torch.sigmoid(std / 2), 54 | }[self._std_act] 55 | self._ensemble_out_layers = nn.ModuleList() 56 | for k in range(self._ensemble): 57 | self._ensemble_out_layers.append(nn.Sequential(OrderedDict([ 58 | (f'img_out_{k}', nn.Linear(self._hidden, self._hidden)), 59 | (f'img_out_norm_{k}', NormLayer(self._norm, self._hidden)), 60 | (f'img_out_act_{k}', self._act_module()) 61 | ]))) 62 | if self._discrete: 63 | self._ensemble_out_layers[-1].add_module(f'img_out_head_{k}', nn.Linear(self._hidden, self._stoch * self._discrete)) 64 | self._ensemble_out_layers[-1].add_module(f'rerange_{k}', Rearrange('... (s d) -> ... 
s d', s=self._stoch, d=self._discrete)) 65 | else: 66 | self._ensemble_out_layers[-1].add_module(f'img_out_head_{k}', nn.Linear(self._hidden, 2 * self._stoch)) 67 | 68 | 69 | 70 | @abstractmethod 71 | def initial(self, batch_size: int, device) -> State: 72 | pass 73 | 74 | def fill_action_with_zero(self, action): 75 | # action: [*B, action] 76 | B, D = action.shape[:-1], action.shape[-1] 77 | if self._action_free: 78 | return torch.zeros([*B, self._fill_action]).to(action.device) 79 | else: 80 | if self._fill_action is not None: 81 | zeros = torch.zeros([*B, self._fill_action - D]).to(action.device) 82 | return torch.cat([action, zeros], axis=1) 83 | else: 84 | # doing nothing 85 | return action 86 | 87 | @abstractmethod 88 | def observe( 89 | self, 90 | embed: TensorType["batch", "seq", "emb_dim"], 91 | action: TensorType["batch", "seq", "act_dim"], 92 | is_first, 93 | state: State = None 94 | ): 95 | pass 96 | 97 | def imagine( 98 | self, 99 | action: TensorType["batch", "seq", "act_dim"], 100 | state: State = None 101 | ): 102 | # a permute of (batch, sequence) to (sequence, batch) 103 | swap = lambda x: torch.permute(x, [1, 0] + list(range(2, len(x.shape)))) 104 | if state is None: 105 | state = self.initial(action.shape[0], action.device) 106 | assert isinstance(state, dict), state 107 | action = swap(action) 108 | prior = core.sequence_scan(self.img_step, state, action)[0] 109 | prior = {k: swap(v) for k, v in prior.items() if k != "mems"} 110 | return prior 111 | 112 | def get_feat(self, state): 113 | """ 114 | gets stoch and deter as tensor 115 | """ 116 | 117 | # FIXME verify shapes of this function 118 | stoch = state["stoch"] 119 | if self._discrete: 120 | stoch = torch.reshape(stoch, (*stoch.shape[:-2], self._stoch * self._discrete)) 121 | return torch.cat([stoch, state["deter"]], -1) 122 | 123 | def get_dist(self, state: State): 124 | """ 125 | gets the stochastic state distribution 126 | """ 127 | if self._discrete: 128 | logit = state["logit"] 129 | logit = logit.float() 130 | dist = dists.Independent(dists.OneHotDist(logit), 1) 131 | else: 132 | mean, std = state["mean"], state["std"] 133 | mean = mean.float() 134 | std = std.float() 135 | dist = dists.Independent(dists.Normal(mean, std), 1) 136 | return dist 137 | 138 | @abstractmethod 139 | def obs_step( 140 | self, 141 | prev_state: State, 142 | prev_action: TensorType["batch", "act_dim"], 143 | embed: TensorType["batch", "emb_dim"], 144 | is_first: TensorType["batch"], 145 | sample=True, 146 | ) -> Tuple[State, State]: 147 | pass 148 | 149 | @abstractmethod 150 | def img_step( 151 | self, 152 | prev_state: State, 153 | prev_action: TensorType["batch", "act_dim"], 154 | sample=True, 155 | ) -> State: 156 | pass 157 | 158 | def kl_loss(self, post: State, prior: State, forward: bool, balance: float, free: float, free_avg: bool): 159 | kld = tdist.kl_divergence 160 | sg = core.dict_detach 161 | lhs, rhs = (prior, post) if forward else (post, prior) 162 | mix = balance if forward else (1 - balance) 163 | 164 | free = torch.tensor(free) 165 | if balance == 0.5: 166 | value = kld(self.get_dist(lhs), self.get_dist(rhs)) 167 | loss = torch.maximum(value, free).mean() 168 | else: 169 | value_lhs = value = kld(self.get_dist(lhs), self.get_dist(sg(rhs))) 170 | value_rhs = kld(self.get_dist(sg(lhs)), self.get_dist(rhs)) 171 | if free_avg: 172 | loss_lhs = torch.maximum(value_lhs.mean(), free) 173 | loss_rhs = torch.maximum(value_rhs.mean(), free) 174 | else: 175 | loss_lhs = torch.maximum(value_lhs, free).mean() 176 | loss_rhs = 
torch.maximum(value_rhs, free).mean() 177 | loss = mix * loss_lhs + (1 - mix) * loss_rhs 178 | return loss, value 179 | 180 | def _suff_stats_ensemble(self, inp: TensorType["batch", "hidden"]): 181 | bs = list(inp.shape[:-1]) 182 | assert len(bs) == 1, bs 183 | inp = inp.reshape([-1, inp.shape[-1]]) 184 | stats = [] 185 | for k in range(self._ensemble): 186 | x = self._ensemble_out_layers[k](inp) 187 | stats.append(self._suff_stats_layer(x)) 188 | stats = { 189 | k: torch.stack([x[k] for x in stats], 0) 190 | for k in stats[0].keys() 191 | } 192 | return stats 193 | 194 | def _suff_stats_layer(self, x: TensorType["batch", "hidden"]): 195 | if self._discrete: 196 | return {"logit": x} 197 | else: 198 | mean, std = torch.chunk(x, 2, -1) 199 | std = self._std_fn(std) 200 | std = std + self._min_std 201 | return {"mean": mean, "std": std} 202 | -------------------------------------------------------------------------------- /wmlib/nets/decoder/deco_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops.layers.torch import Rearrange 4 | from einops import rearrange, unpack 5 | 6 | from .base import BaseDecoder 7 | from ..modules import ResidualStack 8 | from ... import core 9 | 10 | from einops import rearrange, repeat 11 | 12 | class DecoupledResNetDecoder(BaseDecoder): 13 | 14 | # TODO: rename args 15 | def __init__( 16 | self, 17 | shapes, 18 | cnn_keys=r".*", 19 | mlp_keys=r".*", 20 | cnn_depth=48, 21 | cnn_input_dim = 2048, 22 | mlp_layers=[400, 400, 400, 400], 23 | res_layers=2, 24 | res_depth=3, 25 | res_norm='none', 26 | deco_attmask=0.75, 27 | deco_attmaskwarmup=-1, 28 | **dummy_kwargs, 29 | ): 30 | super().__init__(shapes, cnn_keys, mlp_keys, mlp_layers) 31 | 32 | self._cnn_depth = cnn_depth 33 | 34 | self._res_layers = res_layers 35 | self._res_depth = res_depth 36 | self._res_norm = res_norm 37 | self._cnn_input_dim = cnn_input_dim 38 | 39 | self._deco_attmask = deco_attmask 40 | self._deco_attmaskwarmup = None if deco_attmaskwarmup == -1 else deco_attmaskwarmup 41 | 42 | self._training_step = 0 43 | self._current_attmask = None 44 | 45 | cnn_out_channels = {k: self._shapes[k][-1] for k in self.cnn_keys} 46 | hw = 64 // 2**(self._res_depth + 1) 47 | self.convin = nn.Sequential(nn.Linear(self._cnn_input_dim, hw * hw * (2**(self._res_depth - 1)) * self._cnn_depth), 48 | Rearrange('b t (c h w)-> (b t) c h w',h=hw,w=hw)) 49 | self._cnn_nn = nn.ModuleDict() 50 | self.ctx_shape = [hw*2**(i+1) for i in range(self._res_depth)] 51 | for i in range(self._res_depth): 52 | depth = depth // 2 if i else int((2**(self._res_depth - 1)) * self._cnn_depth) 53 | spatial_dim = self.ctx_shape[i] 54 | add_dim = hw * self._cnn_depth // (2**i) 55 | self._cnn_nn.add_module(f"unpool{i}", nn.UpsamplingNearest2d(scale_factor=2)) 56 | self._cnn_nn.add_module(f"res{i}", ResidualStack(depth, depth//2, 57 | self._res_layers, 58 | norm=self._res_norm, dec=True, 59 | addin_dim=add_dim, 60 | has_addin=(lambda x: x % 2 == 0) if spatial_dim < 32 else (lambda x: False), 61 | cross_att=True, 62 | mask=self._deco_attmask, 63 | spatial_dim=(spatial_dim,spatial_dim))) 64 | self.convout = nn.ConvTranspose2d(depth//2, sum(cnn_out_channels.values()), 3, 2, 1, output_padding=1) 65 | self._cnn_out_ps = [[out_channel] for out_channel in cnn_out_channels.values()] 66 | 67 | 68 | def forward(self, features, shortcuts=None): 69 | outputs = {} 70 | if self.cnn_keys: 71 | outputs.update(self._cnn(features, shortcuts)) 72 | if self.mlp_keys: 
73 | outputs.update(self._mlp(features)) 74 | return outputs 75 | 76 | # # TODO: clean up 77 | # def _cnn_old(self, features, shortcuts=None): 78 | # if self.training: 79 | # self._training_step += 1 80 | 81 | # if self._deco_attmaskwarmup is not None: 82 | # self._current_attmask = (1 - self._deco_attmask) * \ 83 | # (1 - min(1, self._training_step / self._deco_attmaskwarmup)) + self._deco_attmask 84 | # if self._training_step % 100 == 0: 85 | # print(f"Current attention mask: {self._current_attmask} {self._training_step}") 86 | # else: 87 | # self._current_attmask = None 88 | 89 | # seq_len = features.shape[1] 90 | # channels = {k: self._shapes[k][-1] for k in self.cnn_keys} 91 | 92 | # L = self._res_depth 93 | # hw = 64 // 2**(self._res_depth + 1) 94 | # x = self.get("convin", nn.Linear, features.shape[-1], hw * hw * (2**(L - 1)) * self._cnn_depth)(features) 95 | # # x = torch.reshape(x, [-1, (2**(L - 1)) * self._cnn_depth, hw, hw]).to(memory_format=torch.channels_last) 96 | # x = rearrange(x, 'b t (c h w) -> (b t) c h w',h=hw,w=hw) 97 | # for i in range(L): 98 | # x = self.get(f"unpool{i}", nn.UpsamplingNearest2d, scale_factor=2)(x) 99 | # depth = x.shape[1] 100 | 101 | # ctx = shortcuts[x.shape[2]] 102 | # addin = ctx 103 | # # addin = rearrange(ctx, '(b t) c h w -> b t c h w',b=features.shape[0]) 104 | # # addin = ctx.reshape(features.shape[0], -1, *ctx.shape[-3:]) # [B, K, C, H, W] 105 | # addin = repeat(addin, 'b c h w -> (b repeat) c h w', repeat=x.shape[0] // addin.shape[0]) # repeat_interleave 106 | # # addin = addin.repeat_interleave(x.shape[0] // addin.shape[0], dim=0) # [BT, K, C, H, W] 107 | # # addin = addin.reshape(-1, *addin.shape[-3:]) # [BTK, C, H, W] 108 | 109 | # x = self.get(f"res{i}", ResidualStack, x.shape[1], depth // 2, 110 | # self._res_layers, norm=self._res_norm, dec=True, 111 | # addin_dim=addin.shape[1], 112 | # has_addin=(lambda x: x % 2 == 0) if ctx.shape[-1] < 32 else (lambda x: False), 113 | # cross_att=True, 114 | # mask=self._deco_attmask, 115 | # spatial_dim=x.shape[-2:], 116 | # )(x, addin, attmask=self._current_attmask) 117 | 118 | # depth = sum(channels.values()) 119 | # x = self.get(f"convout", nn.ConvTranspose2d, x.shape[1], depth, 3, 2, 1, output_padding=1)(x) 120 | 121 | # x = x.reshape(features.shape[:-1] + x.shape[1:]) 122 | # means = torch.split(x, list(channels.values()), 2) 123 | # dists = { 124 | # key: core.dists.Independent(core.dists.MSE(mean), 3) 125 | # for (key, shape), mean in zip(channels.items(), means) 126 | # } 127 | # return dists 128 | 129 | def _cnn(self, features, shortcuts=None): 130 | if self.training: 131 | self._training_step += 1 132 | 133 | if self._deco_attmaskwarmup is not None: 134 | self._current_attmask = (1 - self._deco_attmask) * \ 135 | (1 - min(1, self._training_step / self._deco_attmaskwarmup)) + self._deco_attmask 136 | if self._training_step % 100 == 0: 137 | print(f"Current attention mask: {self._current_attmask} {self._training_step}") 138 | else: 139 | self._current_attmask = None 140 | 141 | x = self.convin(features).to(memory_format=torch.channels_last) 142 | for i in range(self._res_depth): 143 | ctx = shortcuts[self.ctx_shape[i]] 144 | addin = rearrange(ctx, 'b t c h w -> (b t) c h w') 145 | addin = repeat(addin, 'b c h w -> (b repeat) c h w', repeat=x.shape[0] // addin.shape[0]) 146 | x = self._cnn_nn[f"unpool{i}"](x) 147 | x = self._cnn_nn[f"res{i}"](x, addin, attmask=self._current_attmask) 148 | x = self.convout(x) 149 | x = rearrange(x,'(b t) c h w -> b t c h w',b=features.shape[0]) 150 | means 
= unpack(x, self._cnn_out_ps, 'b t * h w ') 151 | dists = { 152 | key: core.dists.Independent(core.dists.MSE(mean), 3) 153 | for key, mean in zip(self.cnn_keys, means) 154 | } 155 | return dists 156 | # addin = ctx 157 | --------------------------------------------------------------------------------
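A note on ActorCritic.target in wmlib/agents/actor_critic.py: core.lambda_return computes TD(lambda) value targets over the imagined rollout. The following is a minimal time-major reference implementation of the assumed recursion, written against the shapes used above ([horizon, batch]); it is an illustrative sketch, not the library function itself.

import torch

def lambda_return_reference(reward, value, discount, bootstrap, lambda_):
    # Time-major tensors of shape [T, B]; bootstrap is V(s_T) of shape [B].
    # R_t = r_t + d_t * ((1 - lambda) * V(s_{t+1}) + lambda * R_{t+1}), with R_T = bootstrap.
    next_values = torch.cat([value[1:], bootstrap[None]], dim=0)
    inputs = reward + discount * next_values * (1 - lambda_)
    returns, last = [], bootstrap
    for t in reversed(range(reward.shape[0])):
        last = inputs[t] + discount[t] * lambda_ * last
        returns.append(last)
    return torch.stack(returns[::-1], dim=0)

Setting lambda_=1 gives a discounted Monte Carlo return bootstrapped at the horizon; lambda_=0 gives a one-step TD target. These targets play the role of `target` in both actor_loss and critic_loss.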
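BaseWorldModel.imagine in wmlib/agents/base.py converts predicted discounts into per-step loss weights with a shifted cumulative product, so a step is weighted only if every preceding state was still valid. A tiny worked example with hypothetical numbers:

import torch

# Five imagined steps; the model predicts termination after step 2 (discount 0).
disc = torch.tensor([0.99, 0.99, 0.0, 0.99, 0.99])
weight = torch.cumprod(torch.cat([torch.ones_like(disc[:1]), disc[:-1]], 0), 0)
# weight -> tensor([1.0000, 0.9900, 0.9801, 0.0000, 0.0000])

The first step always receives weight 1 because the start state comes from the replay buffer, and everything after a predicted termination is zeroed out.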
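BaseDynamics.fill_action_with_zero in wmlib/nets/dynamics/base.py keeps the transition model's input width fixed across action-free pretraining and action-conditioned finetuning. The sketch below restates the assumed semantics in isolation; the function and argument names are illustrative, and it pads along the last axis whereas the repository code concatenates along axis 1.

import torch

def pad_action(action, fill_action=50, action_free=False):
    # Action-free phase: feed an all-zero vector of width fill_action.
    # Finetuning phase: zero-pad the real action up to that same width.
    batch_dims, dim = action.shape[:-1], action.shape[-1]
    if action_free:
        return action.new_zeros(*batch_dims, fill_action)
    if fill_action is None:
        return action  # no padding configured
    zeros = action.new_zeros(*batch_dims, fill_action - dim)
    return torch.cat([action, zeros], dim=-1)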
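BaseDynamics.kl_loss implements KL balancing with free bits in the style of DreamerV2: the divergence is computed twice with stop-gradients on alternating sides, so the prior can be trained toward the posterior more strongly than the posterior is regularized toward the prior. Below is a scalar sketch with Gaussian states covering the common forward=False, free_avg=True path; torch.distributions is used only for illustration, since the repository wraps its own distribution classes.

import torch
import torch.distributions as tdist

def balanced_kl(post, prior, balance=0.8, free=1.0):
    # post and prior are torch.distributions.Normal; balance weights the prior-training term.
    sg = lambda d: tdist.Normal(d.mean.detach(), d.stddev.detach())  # stop-gradient copy
    free = torch.tensor(float(free))
    loss_prior = torch.maximum(tdist.kl_divergence(sg(post), prior).mean(), free)  # trains the prior
    loss_post = torch.maximum(tdist.kl_divergence(post, sg(prior)).mean(), free)   # regularizes the posterior
    return balance * loss_prior + (1 - balance) * loss_post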