├── gridworld ├── model │ ├── tiny_llama │ │ ├── __init__.py │ │ ├── config.py │ │ ├── fused_rotary_embedding.py │ │ ├── model.py │ │ └── utils.py │ ├── __init__.py │ ├── ad.py │ └── dpt.py ├── cfg │ ├── alg │ │ ├── ppo_base.yaml │ │ ├── ppo_dktd.yaml │ │ ├── ppo_dr.yaml │ │ └── ppo_dr_permuted.yaml │ ├── env │ │ ├── darkroom.yaml │ │ ├── dark_key_to_door.yaml │ │ └── darkroom_permuted.yaml │ └── model │ │ ├── dpt_base.yaml │ │ ├── ad_base.yaml │ │ ├── idt_base.yaml │ │ ├── ad_dktd.yaml │ │ ├── ad_dr.yaml │ │ ├── ad_dr_permuted.yaml │ │ ├── dpt_dr.yaml │ │ ├── dpt_dr_permuted.yaml │ │ ├── dpt_dktd.yaml │ │ ├── idt_dr.yaml │ │ ├── idt_dktd.yaml │ │ └── idt_dr_permuted.yaml ├── alg │ ├── __init__.py │ ├── ppo.py │ └── utils.py ├── env │ ├── __init__.py │ ├── dark_key_to_door.py │ └── darkroom.py ├── evaluate.py ├── utils.py ├── collect_data.py ├── dataset.py └── train.py ├── metaworld ├── model │ ├── tiny_llama │ │ ├── __init__.py │ │ ├── config.py │ │ ├── fused_rotary_embedding.py │ │ └── model.py │ ├── __init__.py │ └── ad.py ├── cfg │ ├── env │ │ └── ml1.yaml │ ├── alg │ │ ├── ppo_ml1.yaml │ │ └── sac_ml1.yaml │ └── model │ │ ├── ad_base.yaml │ │ ├── idt_base.yaml │ │ ├── ad_ml1.yaml │ │ └── idt_ml1.yaml ├── alg │ ├── __init__.py │ ├── ppo.py │ ├── sac.py │ └── utils.py ├── evaluate.py ├── utils.py ├── collect_data.py ├── dataset.py └── train.py ├── environment.yml └── README.md /gridworld/model/tiny_llama/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metaworld/model/tiny_llama/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /gridworld/cfg/alg/ppo_base.yaml: -------------------------------------------------------------------------------- 1 | alg: PPO 2 | policy: MlpPolicy -------------------------------------------------------------------------------- /metaworld/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .ad import AD 2 | from .idt import IDT 3 | 4 | MODEL = { 5 | "AD": AD, 6 | "IDT": IDT, 7 | } -------------------------------------------------------------------------------- /gridworld/alg/__init__.py: -------------------------------------------------------------------------------- 1 | from .ppo import PPOWrapper 2 | from .utils import HistoryLoggerCallback 3 | 4 | ALGORITHM = { 5 | 'PPO': PPOWrapper, 6 | } -------------------------------------------------------------------------------- /metaworld/cfg/env/ml1.yaml: -------------------------------------------------------------------------------- 1 | env: metaworld 2 | task: pick-place-v2 3 | 4 | mw_seed: 0 5 | horizon: 100 6 | max_reward: 1000 7 | 8 | dim_obs: 11 9 | dim_actions: 4 -------------------------------------------------------------------------------- /gridworld/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .ad import AD 2 | from .dpt import DPT 3 | from .idt import IDT 4 | 5 | MODEL = { 6 | "AD": AD, 7 | "DPT": DPT, 8 | "IDT": IDT, 9 | } -------------------------------------------------------------------------------- /metaworld/alg/__init__.py: -------------------------------------------------------------------------------- 1 | from .ppo import PPOWrapper 2 | from .sac import SACWrapper 3 | from .utils import HistoryLoggerCallback 4 | 5 | 
ALGORITHM = { 6 | 'PPO': PPOWrapper, 7 | 'SAC': SACWrapper, 8 | } -------------------------------------------------------------------------------- /gridworld/cfg/env/darkroom.yaml: -------------------------------------------------------------------------------- 1 | env: darkroom 2 | grid_size: 9 3 | dim_states: 2 4 | dim_actions: 1 5 | num_actions: 5 6 | 7 | horizon: 20 8 | env_split_seed: 0 9 | 10 | train_env_ratio: 0.9 11 | 12 | max_reward: 20 -------------------------------------------------------------------------------- /metaworld/cfg/alg/ppo_ml1.yaml: -------------------------------------------------------------------------------- 1 | alg: PPO 2 | policy: MlpPolicy 3 | batch_size: 200 4 | n_epochs: 20 5 | n_steps: 100 6 | alg_seed: 0 7 | total_source_timesteps: 1000000 8 | n_stream: 100 9 | source_lr: 0.0003 10 | n_process: 8 -------------------------------------------------------------------------------- /gridworld/cfg/alg/ppo_dktd.yaml: -------------------------------------------------------------------------------- 1 | include: cfg/alg/ppo_base.yaml 2 | 3 | n_steps: 50 4 | batch_size: 100 5 | n_epochs: 10 6 | alg_seed: 0 7 | total_source_timesteps: 100000 8 | n_stream: 100 9 | source_lr: 0.0003 10 | n_process: 8 -------------------------------------------------------------------------------- /gridworld/cfg/alg/ppo_dr.yaml: -------------------------------------------------------------------------------- 1 | include: cfg/alg/ppo_base.yaml 2 | 3 | n_steps: 20 4 | batch_size: 50 5 | n_epochs: 20 6 | alg_seed: 0 7 | total_source_timesteps: 100000 8 | n_stream: 100 9 | source_lr: 0.0003 10 | n_process: 8 -------------------------------------------------------------------------------- /gridworld/cfg/alg/ppo_dr_permuted.yaml: -------------------------------------------------------------------------------- 1 | include: cfg/alg/ppo_base.yaml 2 | 3 | n_steps: 50 4 | batch_size: 50 5 | n_epochs: 20 6 | alg_seed: 0 7 | total_source_timesteps: 100000 8 | n_stream: 100 9 | source_lr: 0.0003 10 | n_process: 8 -------------------------------------------------------------------------------- /gridworld/cfg/env/dark_key_to_door.yaml: -------------------------------------------------------------------------------- 1 | env: darkkeytodoor 2 | grid_size: 9 3 | dim_states: 2 4 | dim_actions: 1 5 | num_actions: 5 6 | 7 | horizon: 50 8 | env_split_seed: 4 9 | 10 | train_env_ratio: 0.95 11 | 12 | max_reward: 2 -------------------------------------------------------------------------------- /gridworld/cfg/env/darkroom_permuted.yaml: -------------------------------------------------------------------------------- 1 | env: darkroompermuted 2 | grid_size: 9 3 | dim_states: 2 4 | dim_actions: 1 5 | num_actions: 5 6 | 7 | horizon: 50 8 | env_split_seed: 4 9 | 10 | train_env_ratio: 0.9 11 | 12 | max_reward: 35 -------------------------------------------------------------------------------- /metaworld/cfg/alg/sac_ml1.yaml: -------------------------------------------------------------------------------- 1 | alg: SAC 2 | policy: MlpPolicy 3 | source_lr: 0.0003 4 | buffer_size: 1000000 5 | learning_starts: 100 6 | batch_size: 128 7 | train_freq: 1 8 | gradient_steps: 10 9 | n_stream: 100 10 | n_process: 8 11 | total_source_timesteps: 1000000 12 | alg_seed: 0 -------------------------------------------------------------------------------- /metaworld/cfg/model/ad_base.yaml: -------------------------------------------------------------------------------- 1 | dynamics: False 2 | dynamics_strength: 1.0 3 | 4 | # training 5 
| label_smoothing: 0.0 6 | flash_attn: True 7 | ## optimizer 8 | beta1: 0.9 9 | beta2: 0.99 10 | weight_decay: 0.01 11 | 12 | # eval & logging 13 | summary_interval: 100 14 | eval_interval: 1000 15 | gen_interval: 10000 16 | ckpt_interval: 10000 -------------------------------------------------------------------------------- /gridworld/cfg/model/dpt_base.yaml: -------------------------------------------------------------------------------- 1 | model: DPT 2 | 3 | # training 4 | label_smoothing: 0.0 5 | flash_attn: True 6 | ## optimizer 7 | beta1: 0.9 8 | beta2: 0.99 9 | weight_decay: 0.01 10 | 11 | # eval & logging 12 | summary_interval: 100 13 | eval_interval: 1000 14 | gen_interval: 10000 15 | ckpt_interval: 10000 16 | 17 | num_workers: 4 18 | 19 | train_n_stream: 100 -------------------------------------------------------------------------------- /metaworld/cfg/model/idt_base.yaml: -------------------------------------------------------------------------------- 1 | dynamics: False 2 | dynamics_strength: 1.0 3 | 4 | # training 5 | train_n_stream: 100 6 | label_smoothing: 0.0 7 | flash_attn: True 8 | ## optimizer 9 | beta1: 0.9 10 | beta2: 0.99 11 | weight_decay: 0.01 12 | 13 | # eval & logging 14 | summary_interval: 100 15 | eval_interval: 1000 16 | gen_interval: 10000 17 | ckpt_interval: 10000 -------------------------------------------------------------------------------- /gridworld/cfg/model/ad_base.yaml: -------------------------------------------------------------------------------- 1 | dynamics: False 2 | dynamics_strength: 1.0 3 | 4 | # training 5 | label_smoothing: 0.0 6 | flash_attn: True 7 | ## optimizer 8 | beta1: 0.9 9 | beta2: 0.99 10 | weight_decay: 0.01 11 | 12 | # eval & logging 13 | summary_interval: 100 14 | eval_interval: 1000 15 | gen_interval: 10000 16 | ckpt_interval: 10000 17 | 18 | train_n_stream: 100 -------------------------------------------------------------------------------- /gridworld/cfg/model/idt_base.yaml: -------------------------------------------------------------------------------- 1 | dynamics: False 2 | dynamics_strength: 1.0 3 | 4 | # training 5 | train_n_stream: 100 6 | label_smoothing: 0.0 7 | flash_attn: True 8 | ## optimizer 9 | beta1: 0.9 10 | beta2: 0.99 11 | weight_decay: 0.01 12 | 13 | # eval & logging 14 | summary_interval: 100 15 | eval_interval: 1000 16 | gen_interval: 10000 17 | ckpt_interval: 10000 18 | 19 | train_n_stream: 100 -------------------------------------------------------------------------------- /gridworld/cfg/model/ad_dktd.yaml: -------------------------------------------------------------------------------- 1 | include: cfg/model/ad_base.yaml 2 | 3 | model: AD 4 | 5 | # Transformer 6 | tf_n_embd: 32 7 | tf_n_layer: 4 8 | tf_n_head: 4 9 | tf_dropout: 0.1 10 | tf_attn_dropout: 0.1 11 | tf_n_inner: 128 12 | n_transit: 200 13 | 14 | # training 15 | lr: 0.01 16 | train_batch_size: 4096 17 | test_batch_size: 8192 18 | train_source_timesteps: 1000 19 | train_timesteps: 50000 20 | num_warmup_steps: 0 21 | 22 | num_workers: 2 -------------------------------------------------------------------------------- /gridworld/cfg/model/ad_dr.yaml: -------------------------------------------------------------------------------- 1 | include: cfg/model/ad_base.yaml 2 | 3 | model: AD 4 | 5 | # Transformer 6 | tf_n_embd: 32 7 | tf_n_layer: 4 8 | tf_n_head: 4 9 | tf_dropout: 0.1 10 | tf_attn_dropout: 0.1 11 | tf_n_inner: 128 12 | n_transit: 80 13 | 14 | # training 15 | lr: 0.01 16 | train_batch_size: 512 17 | test_batch_size: 2048 18 | 
train_source_timesteps: 1000 19 | train_timesteps: 50000 20 | num_warmup_steps: 0 21 | 22 | num_workers: 4 -------------------------------------------------------------------------------- /gridworld/cfg/model/ad_dr_permuted.yaml: -------------------------------------------------------------------------------- 1 | include: cfg/model/ad_base.yaml 2 | 3 | model: AD 4 | 5 | # Transformer 6 | tf_n_embd: 32 7 | tf_n_layer: 4 8 | tf_n_head: 4 9 | tf_dropout: 0.1 10 | tf_attn_dropout: 0.1 11 | tf_n_inner: 128 12 | n_transit: 200 13 | 14 | # training 15 | lr: 0.01 16 | train_batch_size: 512 17 | test_batch_size: 2048 18 | train_source_timesteps: 1000 19 | train_timesteps: 50000 20 | num_warmup_steps: 0 21 | 22 | num_workers: 4 -------------------------------------------------------------------------------- /gridworld/cfg/model/dpt_dr.yaml: -------------------------------------------------------------------------------- 1 | include: cfg/model/dpt_base.yaml 2 | 3 | dynamics: False 4 | dynamics_strength: 1.0 5 | 6 | # Transformer 7 | tf_n_embd: 32 8 | tf_n_layer: 4 9 | tf_n_head: 4 10 | tf_dropout: 0.1 11 | tf_attn_dropout: 0.1 12 | tf_n_inner: 128 13 | n_transit: 80 14 | 15 | # training 16 | lr: 0.01 17 | train_batch_size: 512 18 | test_batch_size: 2048 19 | train_source_timesteps: 2000 20 | train_timesteps: 50000 21 | num_warmup_steps: 0 -------------------------------------------------------------------------------- /gridworld/cfg/model/dpt_dr_permuted.yaml: -------------------------------------------------------------------------------- 1 | include: cfg/model/dpt_base.yaml 2 | 3 | dynamics: False 4 | dynamics_strength: 1.0 5 | 6 | # Transformer 7 | tf_n_embd: 32 8 | tf_n_layer: 4 9 | tf_n_head: 4 10 | tf_dropout: 0.1 11 | tf_attn_dropout: 0.1 12 | tf_n_inner: 128 13 | n_transit: 200 14 | 15 | # training 16 | lr: 0.01 17 | train_batch_size: 512 18 | test_batch_size: 2048 19 | train_source_timesteps: 2000 20 | train_timesteps: 50000 21 | num_warmup_steps: 0 -------------------------------------------------------------------------------- /gridworld/cfg/model/dpt_dktd.yaml: -------------------------------------------------------------------------------- 1 | include: cfg/model/dpt_base.yaml 2 | 3 | dynamics: False 4 | dynamics_strength: 1.0 5 | 6 | # Transformer 7 | tf_n_embd: 32 8 | tf_n_layer: 4 9 | tf_n_head: 4 10 | tf_dropout: 0.1 11 | tf_attn_dropout: 0.1 12 | tf_n_inner: 200 13 | n_transit: 200 14 | 15 | # training 16 | lr: 0.01 17 | train_batch_size: 4096 18 | test_batch_size: 32768 19 | train_source_timesteps: 1000 20 | train_timesteps: 50000 21 | num_warmup_steps: 0 22 | 23 | num_workers: 2 -------------------------------------------------------------------------------- /gridworld/cfg/model/idt_dr.yaml: -------------------------------------------------------------------------------- 1 | include: cfg/model/idt_base.yaml 2 | 3 | model: IDT 4 | 5 | # Transformer 6 | tf_n_embd: 32 7 | tf_n_layer: 4 8 | tf_n_head: 4 9 | tf_dropout: 0.1 10 | tf_attn_dropout: 0.1 11 | tf_n_inner: 128 12 | n_transit: 80 13 | low_per_high: 10 14 | dim_z: 8 15 | 16 | # training 17 | lr: 0.01 18 | train_batch_size: 1024 19 | test_batch_size: 2048 20 | train_source_timesteps: 1000 21 | train_timesteps: 50000 22 | num_warmup_steps: 0 23 | 24 | num_workers: 4 -------------------------------------------------------------------------------- /gridworld/cfg/model/idt_dktd.yaml: -------------------------------------------------------------------------------- 1 | include: cfg/model/idt_base.yaml 2 | 3 | model: IDT 4 | 5 | # 
Transformer 6 | tf_n_embd: 32 7 | tf_n_layer: 4 8 | tf_n_head: 4 9 | tf_dropout: 0.1 10 | tf_attn_dropout: 0.1 11 | tf_n_inner: 128 12 | n_transit: 200 13 | low_per_high: 10 14 | dim_z: 8 15 | 16 | # training 17 | lr: 0.001 18 | train_batch_size: 2048 19 | test_batch_size: 2048 20 | train_source_timesteps: 1000 21 | train_timesteps: 100000 22 | num_warmup_steps: 0 23 | 24 | num_workers: 2 -------------------------------------------------------------------------------- /gridworld/cfg/model/idt_dr_permuted.yaml: -------------------------------------------------------------------------------- 1 | include: cfg/model/idt_base.yaml 2 | 3 | model: IDT 4 | 5 | # Transformer 6 | tf_n_embd: 32 7 | tf_n_layer: 4 8 | tf_n_head: 4 9 | tf_dropout: 0.1 10 | tf_attn_dropout: 0.1 11 | tf_n_inner: 128 12 | n_transit: 200 13 | low_per_high: 10 14 | dim_z: 8 15 | 16 | # training 17 | lr: 0.001 18 | train_batch_size: 512 19 | test_batch_size: 2048 20 | train_source_timesteps: 1000 21 | train_timesteps: 50000 22 | num_warmup_steps: 0 23 | 24 | num_workers: 4 -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: dicp 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - python=3.8 7 | - pip 8 | - tensorboard 9 | - pyyaml=6.0.1 10 | - pip: 11 | - torch==2.3.0 12 | - matplotlib==3.7.5 13 | - python-dateutil==2.9.0.post0 14 | - networkx==3.1 15 | - einops==0.8.0 16 | - numpy==1.24.4 17 | - tqdm==4.66.4 18 | - transformers==4.41.1 19 | - accelerate==0.21.0 20 | - xformers==0.0.26.post1 21 | - tensorboard==2.14.0 22 | - ninja==1.11.1.1 23 | - stable-baselines3==2.3.2 24 | - h5py==3.11.0 25 | - lightning==2.1.2 -------------------------------------------------------------------------------- /metaworld/cfg/model/ad_ml1.yaml: -------------------------------------------------------------------------------- 1 | include: cfg/model/ad_base.yaml 2 | 3 | model: AD 4 | 5 | # Transformer 6 | tf_n_embd: 64 7 | tf_n_layer: 4 8 | tf_n_head: 8 9 | tf_dropout: 0.1 10 | tf_attn_dropout: 0.1 11 | tf_n_inner: 256 12 | n_transit: 1000 13 | 14 | # training 15 | lr: 0.01 16 | train_batch_size: 512 17 | test_batch_size: 2048 18 | train_source_timesteps: 10000 19 | test_source_timesteps: 2000 20 | train_timesteps: 50000 21 | num_warmup_steps: 0 22 | learn_var: True 23 | 24 | train_n_seed: 50 25 | train_n_stream: 100 26 | 27 | n_train_envs_per_task: 5 28 | n_test_envs_per_task: 10 29 | 30 | num_workers: 2 31 | 32 | learn_transition: False -------------------------------------------------------------------------------- /metaworld/cfg/model/idt_ml1.yaml: -------------------------------------------------------------------------------- 1 | include: cfg/model/idt_base.yaml 2 | 3 | model: IDT 4 | 5 | # Transformer 6 | tf_n_embd: 32 7 | tf_n_layer: 4 8 | tf_n_head: 4 9 | tf_dropout: 0.1 10 | tf_attn_dropout: 0.1 11 | tf_n_inner: 128 12 | n_transit: 1000 13 | low_per_high: 10 14 | dim_z: 8 15 | 16 | # training 17 | lr: 0.001 18 | train_batch_size: 512 19 | test_batch_size: 512 20 | train_source_timesteps: 10000 21 | test_source_timesteps: 2000 22 | train_timesteps: 50000 23 | num_warmup_steps: 0 24 | learn_var: True 25 | 26 | train_n_seed: 50 27 | train_n_stream: 100 28 | 29 | n_train_envs_per_task: 10 30 | n_test_envs_per_task: 10 31 | 32 | num_workers: 2 33 | 34 | learn_transition: False -------------------------------------------------------------------------------- /gridworld/env/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .darkroom import sample_darkroom, sample_darkroom_permuted, Darkroom, DarkroomPermuted, map_dark_states, map_dark_states_inverse 2 | from .dark_key_to_door import DarkKeyToDoor, sample_dark_key_to_door 3 | 4 | 5 | ENVIRONMENT = { 6 | 'darkroom': Darkroom, 7 | 'darkroompermuted': DarkroomPermuted, 8 | 'darkkeytodoor': DarkKeyToDoor, 9 | } 10 | 11 | 12 | SAMPLE_ENVIRONMENT = { 13 | 'darkroom': sample_darkroom, 14 | 'darkroompermuted': sample_darkroom_permuted, 15 | 'darkkeytodoor': sample_dark_key_to_door, 16 | } 17 | 18 | 19 | def make_env(config, **kwargs): 20 | def _init(): 21 | return ENVIRONMENT[config['env']](config, **kwargs) 22 | return _init -------------------------------------------------------------------------------- /gridworld/alg/ppo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from stable_baselines3 import PPO 3 | 4 | 5 | class PPOWrapper(PPO): 6 | def __init__(self, config, env, seed, log_dir): 7 | policy = config['policy'] 8 | n_steps = config['n_steps'] 9 | batch_size = config['batch_size'] 10 | n_epochs = config['n_epochs'] 11 | lr = config['source_lr'] 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | env = env 14 | 15 | super(PPOWrapper, self).__init__(policy=policy, 16 | env=env, 17 | learning_rate=lr, 18 | n_steps=n_steps, 19 | batch_size=batch_size, 20 | n_epochs=n_epochs, 21 | verbose=0, 22 | seed=seed, 23 | device=device, 24 | tensorboard_log=log_dir) -------------------------------------------------------------------------------- /metaworld/alg/ppo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from stable_baselines3 import PPO 3 | 4 | 5 | class PPOWrapper(PPO): 6 | def __init__(self, config, env, seed, log_dir): 7 | policy = config['policy'] 8 | n_steps = config['n_steps'] 9 | batch_size = config['batch_size'] 10 | n_epochs = config['n_epochs'] 11 | lr = config['source_lr'] 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | env = env 14 | 15 | super(PPOWrapper, self).__init__(policy=policy, 16 | env=env, 17 | learning_rate=lr, 18 | n_steps=n_steps, 19 | batch_size=batch_size, 20 | n_epochs=n_epochs, 21 | verbose=0, 22 | seed=seed, 23 | device=device, 24 | tensorboard_log=log_dir) -------------------------------------------------------------------------------- /metaworld/alg/sac.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from stable_baselines3 import SAC 3 | 4 | 5 | class SACWrapper(SAC): 6 | def __init__(self, config, env, seed, log_dir): 7 | policy = config['policy'] 8 | lr = config['source_lr'] 9 | buffer_size = config['buffer_size'] 10 | learning_starts = config['learning_starts'] 11 | batch_size = config['batch_size'] 12 | train_freq = config['train_freq'] 13 | gradient_steps = config['gradient_steps'] 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | env = env 16 | seed = seed 17 | 18 | super(SACWrapper, self).__init__(policy=policy, 19 | env=env, 20 | learning_rate=lr, 21 | buffer_size = buffer_size, 22 | learning_starts=learning_starts, 23 | batch_size=batch_size, 24 | train_freq=train_freq, 25 | gradient_steps=gradient_steps, 26 | verbose=0, 27 | seed=seed, 28 | device=device, 29 | tensorboard_log=log_dir) -------------------------------------------------------------------------------- 
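The PPOWrapper and SACWrapper above are thin adapters that translate a config dictionary into Stable-Baselines3 constructor arguments, and each alg/__init__.py exposes them through the ALGORITHM registry keyed by the config's `alg` field. Below is a minimal sketch of that wiring; the config keys copy cfg/alg/sac_ml1.yaml, while the Pendulum-v1 environment, seed, timestep budget, and `logs/` directory are illustrative assumptions rather than values taken from the repository.

```python
# Hedged sketch: looking up an SB3 wrapper via the ALGORITHM registry.
# Assumes it is run from the metaworld/ directory so that `alg` is importable.
import gymnasium as gym
from alg import ALGORITHM  # {'PPO': PPOWrapper, 'SAC': SACWrapper}

config = {
    'alg': 'SAC',
    'policy': 'MlpPolicy',           # values below mirror cfg/alg/sac_ml1.yaml
    'source_lr': 0.0003,
    'buffer_size': 1_000_000,
    'learning_starts': 100,
    'batch_size': 128,
    'train_freq': 1,
    'gradient_steps': 10,
}

env = gym.make('Pendulum-v1')        # stand-in for a Meta-World ML1 task env
agent = ALGORITHM[config['alg']](config, env=env, seed=0, log_dir='logs')
agent.learn(total_timesteps=10_000)  # standard Stable-Baselines3 training call
```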
/gridworld/model/tiny_llama/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal, Optional, Type 3 | 4 | import torch 5 | 6 | 7 | @dataclass 8 | class Config: 9 | block_size: int 10 | n_layer: int 11 | n_head: int 12 | n_embd: int 13 | rotary_percentage: float 14 | parallel_residual: bool 15 | bias: bool 16 | dropout: float 17 | attention_dropout: float 18 | 19 | shared_attention_norm: bool 20 | _norm_class: Literal["LayerNorm", "RMSNorm", "FusedRMSNorm"] 21 | _mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] 22 | n_query_groups: Optional[int] = None 23 | norm_eps: float = 1e-5 24 | intermediate_size: Optional[int] = None 25 | condense_ratio: int = 1 26 | flash_attn: bool = True 27 | 28 | def __post_init__(self): 29 | assert self.n_embd % self.n_head == 0 30 | if self.n_query_groups is not None: 31 | assert self.n_head % self.n_query_groups == 0 32 | else: 33 | self.n_query_groups = self.n_head 34 | if self.intermediate_size is None: 35 | self.intermediate_size = 4 * self.n_embd 36 | 37 | @property 38 | def head_size(self) -> int: 39 | return self.n_embd // self.n_head 40 | 41 | @property 42 | def norm_class(self) -> Type: 43 | if self._norm_class == "RMSNorm": 44 | from .rmsnorm import RMSNorm 45 | 46 | return RMSNorm 47 | elif self._norm_class == "FusedRMSNorm": 48 | from .rmsnorm import FusedRMSNorm 49 | 50 | return FusedRMSNorm 51 | return getattr(torch.nn, self._norm_class) 52 | -------------------------------------------------------------------------------- /metaworld/model/tiny_llama/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal, Optional, Type 3 | 4 | import torch 5 | 6 | 7 | @dataclass 8 | class Config: 9 | block_size: int 10 | n_layer: int 11 | n_head: int 12 | n_embd: int 13 | rotary_percentage: float 14 | parallel_residual: bool 15 | bias: bool 16 | dropout: float 17 | attention_dropout: float 18 | 19 | shared_attention_norm: bool 20 | _norm_class: Literal["LayerNorm", "RMSNorm", "FusedRMSNorm"] 21 | _mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] 22 | n_query_groups: Optional[int] = None 23 | norm_eps: float = 1e-5 24 | intermediate_size: Optional[int] = None 25 | condense_ratio: int = 1 26 | flash_attn: bool = True 27 | 28 | def __post_init__(self): 29 | assert self.n_embd % self.n_head == 0 30 | if self.n_query_groups is not None: 31 | assert self.n_head % self.n_query_groups == 0 32 | else: 33 | self.n_query_groups = self.n_head 34 | if self.intermediate_size is None: 35 | self.intermediate_size = 4 * self.n_embd 36 | 37 | @property 38 | def head_size(self) -> int: 39 | return self.n_embd // self.n_head 40 | 41 | @property 42 | def norm_class(self) -> Type: 43 | if self._norm_class == "RMSNorm": 44 | from .rmsnorm import RMSNorm 45 | 46 | return RMSNorm 47 | elif self._norm_class == "FusedRMSNorm": 48 | from .rmsnorm import FusedRMSNorm 49 | 50 | return FusedRMSNorm 51 | return getattr(torch.nn, self._norm_class) 52 | -------------------------------------------------------------------------------- /gridworld/alg/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from stable_baselines3.common.callbacks import BaseCallback 3 | 4 | 5 | class HistoryLoggerCallback(BaseCallback): 6 | def __init__(self, env_name, env_idx, history=None): 7 | super(HistoryLoggerCallback, self).__init__() 8 | 
self.env_name = env_name 9 | self.env_idx = env_idx 10 | 11 | self.states = [] 12 | self.actions = [] 13 | self.rewards = [] 14 | self.next_states = [] 15 | self.dones = [] 16 | 17 | self.history = history 18 | 19 | self.episode_rewards = [] 20 | self.episode_success = [] 21 | 22 | def _on_step(self) -> bool: 23 | # Capture state, action, and reward at each step 24 | self.states.append(self.locals["obs_tensor"].cpu().numpy()) 25 | self.next_states.append(self.locals["new_obs"]) 26 | self.actions.append(self.locals["actions"]) 27 | 28 | self.rewards.append(self.locals["rewards"].copy()) 29 | self.dones.append(self.locals["dones"]) 30 | 31 | self.episode_rewards.append(self.locals['rewards']) 32 | 33 | if self.locals['dones'][0]: 34 | mean_reward = np.mean(np.mean(self.episode_rewards, axis=0)) 35 | self.logger.record('rollout/mean_reward', mean_reward) 36 | self.episode_rewards = [] 37 | 38 | return True 39 | 40 | def _on_training_end(self): 41 | self.history[self.env_idx] = { 42 | 'states': np.array(self.states, dtype=np.int32), 43 | 'actions': np.array(self.actions, dtype=np.int32), 44 | 'rewards': np.array(self.rewards, dtype=np.int32), 45 | 'next_states': np.array(self.next_states, dtype=np.int32), 46 | 'dones': np.array(self.dones, dtype=np.bool_) 47 | } 48 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Distillation for In-Context Planning (DICP) 2 | 3 | This repository provides the official implementation of our ICLR 2025 paper, [Distilling Reinforcement Learning Algorithms for In-Context Model-Based Planning](https://openreview.net/forum?id=BfUugGfBE5&noteId=BfUugGfBE5). 4 | 5 | ## Requirements 6 | 7 | To set up the required environment, run: 8 | ```bash 9 | conda env create -f environment.yml 10 | ``` 11 | 12 | ### Meta-World Installation 13 | 14 | The following command installs Meta-World, adapted from [Farama-Foundation/Metaworld](https://github.com/Farama-Foundation/Metaworld): 15 | ```bash 16 | git clone https://github.com/Farama-Foundation/Metaworld.git 17 | cd Metaworld 18 | git checkout 83ac03c 19 | pip install . 20 | cd .. && rm -rf Metaworld 21 | ``` 22 | 23 | ### TinyLlama Dependencies 24 | 25 | To install the required TinyLlama dependencies, run the following commands (adapted from [TinyLlama’s PRETRAIN.md](https://github.com/jzhang38/TinyLlama/blob/main/PRETRAIN.md)): 26 | ```bash 27 | git clone https://github.com/Dao-AILab/flash-attention 28 | cd flash-attention 29 | git checkout 320fb59 30 | python setup.py install 31 | cd csrc/rotary && pip install . 32 | cd ../layer_norm && pip install . 33 | cd ../xentropy && pip install . 34 | cd ../../.. && rm -rf flash-attention 35 | ``` 36 | 37 | ## Usage 38 | 39 | The following commands demonstrate the basic usage of the code in GridWorld environments.
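For instance, using the Darkroom configurations shipped under `cfg/` (run from the `gridworld/` directory), a full pipeline might look like the example below; the trajectory and log directory names are placeholders, and each flag is explained in the sections that follow.

```bash
python collect_data.py -ac cfg/alg/ppo_dr.yaml -ec cfg/env/darkroom.yaml -t trajectories
python train.py -ac cfg/alg/ppo_dr.yaml -ec cfg/env/darkroom.yaml -mc cfg/model/ad_dr.yaml -t trajectories -l logs
python evaluate.py -c [checkpoint directory] -k 10
```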
40 | 41 | ### Data Collection 42 | 43 | To collect training data, run: 44 | ```bash 45 | python collect_data.py -ac [algorithm config] -ec [environment config] -t [trajectory directory] 46 | ``` 47 | 48 | ### Training 49 | 50 | To train the model, run: 51 | ```bash 52 | python train.py -ac [algorithm config] -ec [environment config] -mc [model config] -t [trajectory directory] -l [log directory] 53 | ``` 54 | 55 | ### Evaluation 56 | 57 | To evaluate a trained model, run: 58 | ```bash 59 | python evaluate.py -c [checkpoint directory] -k [beam size] 60 | ``` 61 | 62 | 63 | ## Citation 64 | If you find this work useful, please cite our paper: 65 | ```bibtex 66 | @inproceedings{son2025distilling, 67 | author = {Jaehyeon Son and Soochan Lee and Gunhee Kim}, 68 | title = {Distilling Reinforcement Learning Algorithms for In-Context Model-Based Planning}, 69 | booktitle = {International Conference on Learning Representations}, 70 | year = {2025}, 71 | } 72 | ``` 73 | -------------------------------------------------------------------------------- /metaworld/alg/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from stable_baselines3.common.callbacks import BaseCallback 3 | 4 | 5 | class HistoryLoggerCallback(BaseCallback): 6 | def __init__(self, env_name, env_idx, history=None): 7 | super(HistoryLoggerCallback, self).__init__() 8 | self.env_name = env_name 9 | self.env_idx = env_idx 10 | 11 | self.states = [] 12 | self.actions = [] 13 | self.rewards = [] 14 | self.next_states = [] 15 | self.dones = [] 16 | self.already_success = [] 17 | self.success = [] 18 | 19 | self.history = history 20 | 21 | self.episode_rewards = [] 22 | self.episode_success = [] 23 | 24 | def _on_step(self) -> bool: 25 | self.states.append(self.locals["new_obs"][:, list(range(18, 36))]) 26 | self.next_states.append(self.locals["new_obs"][:, list(range(18))]) 27 | 28 | success = [info['success'] for info in self.locals['infos']] 29 | self.success.append(success) 30 | self.episode_success.append(success) 31 | 32 | self.actions.append(self.locals["actions"]) 33 | self.rewards.append(self.locals["rewards"].copy()) 34 | self.dones.append(self.locals["dones"]) 35 | 36 | self.episode_rewards.append(self.locals['rewards']) 37 | 38 | if self.locals['dones'][0]: 39 | mean_reward = np.mean(np.mean(self.episode_rewards, axis=0)) 40 | self.logger.record('rollout/mean_reward', mean_reward) 41 | self.episode_rewards = [] 42 | 43 | mean_success_rate = np.mean((np.sum(self.episode_success, axis=0) > 0.0)) 44 | self.logger.record('rollout/mean_success_rate', mean_success_rate) 45 | self.episode_success = [] 46 | 47 | return True 48 | 49 | def _on_training_end(self): 50 | if self.env_name == 'metaworld': 51 | self.history[self.env_idx] = { 52 | 'states': np.array(self.states, dtype=np.float32), 53 | 'actions': np.array(self.actions, dtype=np.float32), 54 | 'rewards': np.array(self.rewards, dtype=np.float32), 55 | 'next_states': np.array(self.next_states, dtype=np.float32), 56 | 'dones': np.array(self.dones, dtype=np.bool_), 57 | 'success': np.array(self.success, dtype=np.float32) 58 | } 59 | else: 60 | raise ValueError('Invalid environment') 61 | -------------------------------------------------------------------------------- /gridworld/model/tiny_llama/fused_rotary_embedding.py: -------------------------------------------------------------------------------- 1 | import rotary_emb 2 | import torch 3 | from einops import rearrange 4 | 5 | 6 | class 
ApplyRotaryEmb(torch.autograd.Function): 7 | @staticmethod 8 | def forward(ctx, x, cos, sin, interleaved=False, inplace=False): 9 | batch, seqlen, nheads, headdim = x.shape 10 | rotary_seqlen, rotary_dim = cos.shape 11 | rotary_dim *= 2 12 | assert rotary_dim <= headdim 13 | assert seqlen <= rotary_seqlen 14 | assert sin.shape == (rotary_seqlen, rotary_dim // 2) 15 | x_ro = x[..., :rotary_dim] 16 | x1, x2 = ( 17 | x_ro.chunk(2, dim=-1) 18 | if not interleaved 19 | else (x_ro[..., ::2], x_ro[..., 1::2]) 20 | ) 21 | out = torch.empty_like(x) if not inplace else x 22 | out_ro = out[..., :rotary_dim] 23 | if inplace: 24 | o1, o2 = x1, x2 25 | else: 26 | o1, o2 = ( 27 | out_ro.chunk(2, dim=-1) 28 | if not interleaved 29 | else (out_ro[..., ::2], out_ro[..., 1::2]) 30 | ) 31 | 32 | rotary_emb.apply_rotary( 33 | x1, 34 | x2, 35 | rearrange(cos[:seqlen], "s d -> s 1 d"), 36 | rearrange(sin[:seqlen], "s d -> s 1 d"), 37 | o1, 38 | o2, 39 | False, 40 | ) 41 | if not inplace and rotary_dim < headdim: 42 | out[..., rotary_dim:].copy_(x[..., rotary_dim:]) 43 | ctx.save_for_backward(cos, sin) 44 | ctx.interleaved = interleaved 45 | ctx.inplace = inplace 46 | return out if not inplace else x 47 | 48 | @staticmethod 49 | def backward(ctx, do): 50 | cos, sin = ctx.saved_tensors 51 | _, seqlen, _, headdim = do.shape 52 | rotary_dim = cos.shape[-1] 53 | rotary_dim *= 2 54 | inplace = ctx.inplace 55 | do_ro = do[..., :rotary_dim] 56 | do1, do2 = ( 57 | do_ro.chunk(2, dim=-1) 58 | if not ctx.interleaved 59 | else (do_ro[..., ::2], do_ro[..., 1::2]) 60 | ) 61 | dx = torch.empty_like(do) if not inplace else do 62 | if inplace: 63 | dx1, dx2 = do1, do2 64 | else: 65 | dx_ro = dx[..., :rotary_dim] 66 | dx1, dx2 = ( 67 | dx_ro.chunk(2, dim=-1) 68 | if not ctx.interleaved 69 | else (dx_ro[..., ::2], dx_ro[..., 1::2]) 70 | ) 71 | rotary_emb.apply_rotary( 72 | do1, 73 | do2, 74 | rearrange(cos[:seqlen], "s d -> s 1 d"), 75 | rearrange(sin[:seqlen], "s d -> s 1 d"), 76 | dx1, 77 | dx2, 78 | True, 79 | ) 80 | if not inplace and rotary_dim < headdim: 81 | dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) 82 | return dx, None, None, None, None 83 | 84 | 85 | apply_rotary_emb_func = ApplyRotaryEmb.apply 86 | -------------------------------------------------------------------------------- /metaworld/model/tiny_llama/fused_rotary_embedding.py: -------------------------------------------------------------------------------- 1 | import rotary_emb 2 | import torch 3 | from einops import rearrange 4 | 5 | 6 | class ApplyRotaryEmb(torch.autograd.Function): 7 | @staticmethod 8 | def forward(ctx, x, cos, sin, interleaved=False, inplace=False): 9 | batch, seqlen, nheads, headdim = x.shape 10 | rotary_seqlen, rotary_dim = cos.shape 11 | rotary_dim *= 2 12 | assert rotary_dim <= headdim 13 | assert seqlen <= rotary_seqlen 14 | assert sin.shape == (rotary_seqlen, rotary_dim // 2) 15 | x_ro = x[..., :rotary_dim] 16 | x1, x2 = ( 17 | x_ro.chunk(2, dim=-1) 18 | if not interleaved 19 | else (x_ro[..., ::2], x_ro[..., 1::2]) 20 | ) 21 | out = torch.empty_like(x) if not inplace else x 22 | out_ro = out[..., :rotary_dim] 23 | if inplace: 24 | o1, o2 = x1, x2 25 | else: 26 | o1, o2 = ( 27 | out_ro.chunk(2, dim=-1) 28 | if not interleaved 29 | else (out_ro[..., ::2], out_ro[..., 1::2]) 30 | ) 31 | 32 | rotary_emb.apply_rotary( 33 | x1, 34 | x2, 35 | rearrange(cos[:seqlen], "s d -> s 1 d"), 36 | rearrange(sin[:seqlen], "s d -> s 1 d"), 37 | o1, 38 | o2, 39 | False, 40 | ) 41 | if not inplace and rotary_dim < headdim: 42 | out[..., 
rotary_dim:].copy_(x[..., rotary_dim:]) 43 | ctx.save_for_backward(cos, sin) 44 | ctx.interleaved = interleaved 45 | ctx.inplace = inplace 46 | return out if not inplace else x 47 | 48 | @staticmethod 49 | def backward(ctx, do): 50 | cos, sin = ctx.saved_tensors 51 | _, seqlen, _, headdim = do.shape 52 | rotary_dim = cos.shape[-1] 53 | rotary_dim *= 2 54 | inplace = ctx.inplace 55 | do_ro = do[..., :rotary_dim] 56 | do1, do2 = ( 57 | do_ro.chunk(2, dim=-1) 58 | if not ctx.interleaved 59 | else (do_ro[..., ::2], do_ro[..., 1::2]) 60 | ) 61 | dx = torch.empty_like(do) if not inplace else do 62 | if inplace: 63 | dx1, dx2 = do1, do2 64 | else: 65 | dx_ro = dx[..., :rotary_dim] 66 | dx1, dx2 = ( 67 | dx_ro.chunk(2, dim=-1) 68 | if not ctx.interleaved 69 | else (dx_ro[..., ::2], dx_ro[..., 1::2]) 70 | ) 71 | rotary_emb.apply_rotary( 72 | do1, 73 | do2, 74 | rearrange(cos[:seqlen], "s d -> s 1 d"), 75 | rearrange(sin[:seqlen], "s d -> s 1 d"), 76 | dx1, 77 | dx2, 78 | True, 79 | ) 80 | if not inplace and rotary_dim < headdim: 81 | dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) 82 | return dx, None, None, None, None 83 | 84 | 85 | apply_rotary_emb_func = ApplyRotaryEmb.apply 86 | -------------------------------------------------------------------------------- /gridworld/evaluate.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | import argparse 4 | from glob import glob 5 | 6 | import os 7 | import sys 8 | sys.path.append(os.path.dirname(sys.path[0])) 9 | 10 | import torch 11 | import os.path as path 12 | 13 | from env import SAMPLE_ENVIRONMENT, make_env 14 | from model import MODEL 15 | from stable_baselines3.common.vec_env import SubprocVecEnv 16 | import numpy as np 17 | 18 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 19 | torch.backends.cudnn.benchmark = False 20 | torch.backends.cudnn.deterministic = True 21 | 22 | def parse_arguments(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--ckpt-dir', '-c', required=True, help="Checkpoint directory") 25 | parser.add_argument('--beam-k', '-k', required=False, type=int, default=10, help="Beam_k for planning") 26 | parser.add_argument('--seed', '-s', required=False, type=int, default=0, help="Torch seed") 27 | 28 | args = parser.parse_args() 29 | return args 30 | 31 | 32 | if __name__ == '__main__': 33 | args = parse_arguments() 34 | 35 | ckpt_paths = sorted(glob(path.join(args.ckpt_dir, 'ckpt-*.pt'))) 36 | if len(ckpt_paths) > 0: 37 | ckpt_path = ckpt_paths[-1] 38 | ckpt = torch.load(ckpt_path) 39 | print(f'Checkpoint loaded from {ckpt_path}') 40 | config = ckpt['config'] 41 | else: 42 | raise ValueError('No checkpoint found') 43 | 44 | model_name = config['model'] 45 | 46 | model = MODEL[model_name](config).to(device) 47 | 48 | model.load_state_dict(ckpt['model']) 49 | model.eval() 50 | 51 | env_name = config['env'] 52 | _, test_env_args = SAMPLE_ENVIRONMENT[env_name](config) 53 | 54 | if env_name == "darkroom": 55 | envs = SubprocVecEnv([make_env(config, goal=arg) for arg in test_env_args]) 56 | elif env_name == "darkkeytodoor": 57 | envs = SubprocVecEnv([make_env(config, key=arg[:2], goal=arg[2:]) for arg in test_env_args]) 58 | elif env_name == 'darkroompermuted': 59 | envs = SubprocVecEnv([make_env(config, perm_idx=arg) for arg in test_env_args]) 60 | else: 61 | raise NotImplementedError(f'Environment {env_name} is not supported') 62 | 63 | torch.manual_seed(args.seed) 64 | torch.cuda.manual_seed(args.seed) 65 | 
torch.cuda.manual_seed_all(args.seed) 66 | 67 | start_time = datetime.now() 68 | print(f'Generation started at {start_time}') 69 | 70 | with torch.no_grad(): 71 | if config['dynamics']: 72 | test_rewards = model.evaluate_in_context(vec_env=envs, 73 | eval_timesteps=config['horizon'] * 100, 74 | beam_k=args.beam_k)['reward_episode'] 75 | path = path.join(args.ckpt_dir, f'eval_result_k{args.beam_k}.npy') 76 | 77 | else: 78 | test_rewards = model.evaluate_in_context(vec_env=envs, 79 | eval_timesteps=config['horizon'] * 50)['reward_episode'] 80 | path = path.join(args.ckpt_dir, 'eval_result.npy') 81 | 82 | end_time = datetime.now() 83 | print() 84 | print(f'Generation ended at {end_time}') 85 | print(f'Elapsed time: {end_time - start_time}') 86 | 87 | envs.close() 88 | 89 | with open(path, 'wb') as f: 90 | np.save(f, test_rewards) -------------------------------------------------------------------------------- /metaworld/evaluate.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | import os.path as path 4 | 5 | import argparse 6 | from glob import glob 7 | 8 | import torch 9 | 10 | from model import MODEL 11 | from stable_baselines3.common.vec_env import SubprocVecEnv 12 | import numpy as np 13 | import metaworld 14 | from gymnasium.wrappers.time_limit import TimeLimit 15 | 16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | torch.backends.cudnn.benchmark = False 18 | torch.backends.cudnn.deterministic = True 19 | 20 | 21 | def parse_arguments(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--ckpt-dir', '-c', required=True, help="Checkpoint directory") 24 | parser.add_argument('--sample-size', '-k', required=False, type=int, default=10, help="Sample size for planning") 25 | parser.add_argument('--seed', '-s', required=False, type=int, default=0, help="Torch seed") 26 | 27 | args = parser.parse_args() 28 | return args 29 | 30 | 31 | def make_env(config, env_cls, task): 32 | def _init(): 33 | env = env_cls() 34 | env.set_task(task) 35 | return TimeLimit(env, max_episode_steps=config['horizon']) 36 | return _init 37 | 38 | 39 | if __name__ == '__main__': 40 | args = parse_arguments() 41 | 42 | ckpt_paths = sorted(glob(path.join(args.ckpt_dir, 'ckpt-*.pt'))) 43 | if len(ckpt_paths) > 0: 44 | ckpt_path = ckpt_paths[-1] 45 | ckpt = torch.load(ckpt_path) 46 | print(f'Checkpoint loaded from {ckpt_path}') 47 | config = ckpt['config'] 48 | else: 49 | raise ValueError('No checkpoint found') 50 | 51 | model_name = config['model'] 52 | 53 | model = MODEL[model_name](config).to(device) 54 | 55 | model.load_state_dict(ckpt['model']) 56 | model.eval() 57 | 58 | ml1 = metaworld.ML1(env_name=config['task'], seed=config['mw_seed']) 59 | 60 | test_envs = [] 61 | 62 | for task_name, env_cls in ml1.test_classes.items(): 63 | task_instances = [task for task in ml1.test_tasks if task.env_name == task_name] 64 | for i in range(50): 65 | test_envs.append(make_env(config, env_cls, task_instances[i])) 66 | 67 | envs = SubprocVecEnv(test_envs) 68 | model.set_obs_space(envs.observation_space) 69 | 70 | torch.manual_seed(args.seed) 71 | torch.cuda.manual_seed(args.seed) 72 | torch.cuda.manual_seed_all(args.seed) 73 | 74 | if config['task'] == 'reach-v2': 75 | eval_episodes = 100 76 | elif config['task'] == 'push-v2': 77 | eval_episodes = 300 78 | elif config['task'] == 'pick-place-v2' or config['task'] == 'peg-insert-side-v2': 79 | eval_episodes = 2000 80 | else: 81 | eval_episodes = 200 82 | 83 | 
start_time = datetime.now() 84 | print(f'Generation started at {start_time}') 85 | 86 | with torch.no_grad(): 87 | if config['dynamics']: 88 | output = model.evaluate_in_context(vec_env=envs, 89 | eval_timesteps=eval_episodes * config['horizon'], 90 | sample_size=args.sample_size) 91 | 92 | else: 93 | output = model.evaluate_in_context(vec_env=envs, 94 | eval_timesteps=eval_episodes * config['horizon']) 95 | 96 | end_time = datetime.now() 97 | print() 98 | print(f'Generation ended at {end_time}') 99 | print(f'Elapsed time: {end_time - start_time}') 100 | 101 | # Clean up 102 | envs.close() 103 | 104 | if config['dynamics']: 105 | path = path.join(args.ckpt_dir, f'eval_result_k{args.sample_size}') 106 | else: 107 | path = path.join(args.ckpt_dir, 'eval_result') 108 | 109 | reward_episode = output['reward_episode'] 110 | success = output['success'] 111 | 112 | with open(f'{path}_reward.npy', 'wb') as f: 113 | np.save(f, reward_episode) 114 | with open(f'{path}_success.npy', 'wb') as f: 115 | np.save(f, success) -------------------------------------------------------------------------------- /gridworld/env/dark_key_to_door.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import random 4 | from typing import Any, Tuple 5 | 6 | import gymnasium as gym 7 | import numpy as np 8 | from gymnasium import spaces 9 | from gymnasium.core import ObsType 10 | 11 | 12 | def sample_dark_key_to_door(config, shuffle=True): 13 | keys_goals = [np.array([i, j, k, l]) 14 | for i in range(config['grid_size']) for j in range(config['grid_size']) 15 | for k in range(config['grid_size']) for l in range(config['grid_size'])] 16 | 17 | if shuffle: 18 | random.seed(config['env_split_seed']) 19 | random.shuffle(keys_goals) 20 | 21 | n_train_envs = round(config['grid_size'] ** 4 * config['train_env_ratio']) 22 | 23 | train_keys_goals = keys_goals[:n_train_envs] 24 | test_keys_goals = keys_goals[n_train_envs:] 25 | 26 | return train_keys_goals, test_keys_goals 27 | 28 | 29 | class DarkKeyToDoor(gym.Env): 30 | metadata = {'render.modes': ['human']} 31 | 32 | def __init__(self, config, **kwargs): 33 | super(DarkKeyToDoor, self).__init__() 34 | self.grid_size = config['grid_size'] 35 | if 'key' in kwargs: 36 | self.key = kwargs['key'] 37 | else: 38 | self.key = np.random.randint(0, self.grid_size, 2) 39 | if 'goal' in kwargs: 40 | self.goal = kwargs['goal'] 41 | else: 42 | self.goal = np.random.randint(0, self.grid_size, 2) 43 | self.horizon = config['horizon'] 44 | self.dim_obs = 2 45 | self.dim_action = 5 46 | self.observation_space = spaces.Box(low=0, high=self.grid_size-1, shape=(self.dim_obs,), dtype=np.int32) 47 | self.action_space = spaces.Discrete(self.dim_action) 48 | 49 | def reset( 50 | self, 51 | *, 52 | seed: int | None = None, 53 | options: dict[str, Any] | None = None, 54 | ) -> Tuple[ObsType, dict[str, Any]]: 55 | self.current_step = 0 56 | 57 | center = self.grid_size // 2 58 | self.state = np.array([center, center]) 59 | self.have_key = False 60 | self.reach_goal = False 61 | 62 | return self.state, {} 63 | 64 | def step(self, action): 65 | if self.current_step >= self.horizon: 66 | raise ValueError("Episode has already ended") 67 | 68 | s = np.array(self.state) 69 | a = action 70 | 71 | # Action handling 72 | if a == 0: 73 | s[0] += 1 74 | elif a == 1: 75 | s[0] -= 1 76 | elif a == 2: 77 | s[1] += 1 78 | elif a == 3: 79 | s[1] -= 1 80 | 81 | s = np.clip(s, 0, self.grid_size - 1) 82 | self.state = s 83 | 84 | info = {} 85 | 
info['already_success'] = self.reach_goal 86 | 87 | if not self.have_key and np.array_equal(s, self.key): 88 | self.have_key = True 89 | reward = 1 90 | elif self.have_key and not self.reach_goal and np.array_equal(s, self.goal): 91 | self.reach_goal = True 92 | reward = 1 93 | else: 94 | reward = 0 95 | 96 | self.current_step += 1 97 | 98 | done = self.current_step >= self.horizon 99 | 100 | info['success'] = self.reach_goal 101 | 102 | return s.copy(), reward, done, done, info 103 | 104 | def get_optimal_action(self, state, have_key=False): 105 | if have_key: 106 | if state[0] < self.goal[0]: 107 | a = 0 108 | elif state[0] > self.goal[0]: 109 | a = 1 110 | elif state[1] < self.goal[1]: 111 | a = 2 112 | elif state[1] > self.goal[1]: 113 | a = 3 114 | else: 115 | a = 4 116 | else: 117 | if state[0] < self.key[0]: 118 | a = 0 119 | elif state[0] > self.key[0]: 120 | a = 1 121 | elif state[1] < self.key[1]: 122 | a = 2 123 | elif state[1] > self.key[1]: 124 | a = 3 125 | else: 126 | a = 4 127 | 128 | return a 129 | 130 | def get_max_return(self): 131 | return 2 -------------------------------------------------------------------------------- /metaworld/utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import torch 3 | from torch.utils.data import DataLoader 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def get_traj_file_name(config): 9 | if config["env"] == 'metaworld': 10 | task = config['task'] 11 | else: 12 | task = config['env'] 13 | 14 | path = f'history_{task}_{config["alg"]}_alg-seed{config["alg_seed"]}' 15 | 16 | return path 17 | 18 | 19 | def get_config(config_path): 20 | with open(config_path, 'r') as f: 21 | new_config = yaml.full_load(f) 22 | config = {} 23 | if 'include' in new_config: 24 | include_config = get_config(new_config['include']) 25 | config.update(include_config) 26 | del new_config['include'] 27 | config.update(new_config) 28 | return config 29 | 30 | 31 | def ad_collate_fn(batch): 32 | res = {} 33 | res['query_states'] = torch.tensor(np.array([item['query_states'] for item in batch]), requires_grad=False, dtype=torch.float) 34 | res['target_actions'] = torch.tensor(np.array([item['target_actions'] for item in batch]), requires_grad=False, dtype=torch.float) 35 | res['states'] = torch.tensor(np.array([item['states'] for item in batch]), requires_grad=False, dtype=torch.float) 36 | res['actions'] = torch.tensor(np.array([item['actions'] for item in batch]), requires_grad=False, dtype=torch.float) 37 | res['rewards'] = torch.tensor(np.array([item['rewards'] for item in batch]), requires_grad=False, dtype=torch.float) 38 | res['next_states'] = torch.tensor(np.array([item['next_states'] for item in batch]), requires_grad=False, dtype=torch.float) 39 | 40 | if 'target_next_states' in batch[0].keys(): 41 | res['target_next_states'] = torch.tensor(np.array([item['target_next_states'] for item in batch]), dtype=torch.float, requires_grad=False) 42 | res['target_rewards'] = torch.tensor(np.array([item['target_rewards'] for item in batch]), dtype=torch.float, requires_grad=False) 43 | 44 | return res 45 | 46 | 47 | def idt_collate_fn(batch): 48 | res = {} 49 | res['states'] = torch.tensor(np.array([item['states'] for item in batch]), dtype=torch.float, requires_grad=False) 50 | res['actions'] = torch.tensor(np.array([item['actions'] for item in batch]), dtype=torch.float, requires_grad=False) 51 | res['rewards'] = torch.tensor(np.array([item['rewards'] for item in batch]), 
dtype=torch.float, requires_grad=False) 52 | res['returns_to_go'] = torch.tensor(np.array([item['returns_to_go'] for item in batch]), dtype=torch.float, requires_grad=False) 53 | res['next_states'] = torch.tensor(np.array([item['next_states'] for item in batch]), dtype=torch.float, requires_grad=False) 54 | 55 | return res 56 | 57 | 58 | def get_data_loader(dataset, batch_size, config, shuffle=True): 59 | if config['model'] == 'AD': 60 | collate_fn = ad_collate_fn 61 | elif config['model'] == 'IDT': 62 | collate_fn = idt_collate_fn 63 | else: 64 | print(config['model']) 65 | raise ValueError('Invalid model') 66 | 67 | return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, num_workers=config['num_workers'], pin_memory=True, persistent_workers=True) 68 | 69 | 70 | def log_in_context(values: np.ndarray, max_reward: int, episode_length: int, tag: str, title: str, xlabel: str, ylabel: str, step: int, success=None, writer=None) -> None: 71 | steps = np.arange(1, len(values[0])+1) * episode_length 72 | mean_value = values.mean(axis=0) 73 | 74 | plt.plot(steps, mean_value) 75 | 76 | if success is not None: 77 | success_rate = success.astype(np.float32).mean(axis=0) 78 | 79 | for i, (xi, yi) in enumerate(zip(steps, mean_value)): 80 | if (i+1) % 10 == 0: 81 | plt.annotate(f'{success_rate[i]:.2f}', (xi, yi)) 82 | 83 | plt.title(title) 84 | plt.ylabel(ylabel) 85 | plt.xlabel(xlabel) 86 | plt.ylim(-max_reward * 0.05, max_reward * 1.05) 87 | writer.add_figure(f'{tag}/mean', plt.gcf(), global_step=step) 88 | plt.close() 89 | 90 | 91 | def next_dataloader(dataloader: DataLoader): 92 | """ 93 | Makes the dataloader never end when the dataset is exhausted. 94 | This is done to remove the notion of an 'epoch' and to count only the amount 95 | of training steps. 
96 | """ 97 | while True: 98 | for batch in dataloader: 99 | yield batch -------------------------------------------------------------------------------- /gridworld/env/darkroom.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | import gymnasium as gym 5 | from gymnasium import spaces 6 | from gymnasium.core import ObsType 7 | import torch 8 | from typing import Any, Tuple 9 | import random 10 | import itertools 11 | 12 | 13 | def map_dark_states(states, grid_size): 14 | return torch.sum(states * torch.tensor((grid_size, 1), device=states.device, requires_grad=False), dim=-1) 15 | 16 | 17 | def map_dark_states_inverse(index, grid_size): 18 | return torch.stack((index // grid_size, index % grid_size), dim=-1) 19 | 20 | 21 | def sample_darkroom(config, shuffle=True): 22 | goals = [np.array([i, j]) for i in range(config['grid_size']) for j in range(config['grid_size'])] 23 | 24 | if shuffle: 25 | random.seed(config['env_split_seed']) 26 | random.shuffle(goals) 27 | 28 | n_train_envs = round(config['grid_size'] ** 2 * config['train_env_ratio']) 29 | 30 | train_goals = goals[:n_train_envs] 31 | test_goals = goals[n_train_envs:] 32 | 33 | return train_goals, test_goals 34 | 35 | 36 | def sample_darkroom_permuted(config, shuffle=True): 37 | perms = list(range(120)) 38 | 39 | if shuffle: 40 | random.seed(config['env_split_seed']) 41 | random.shuffle(perms) 42 | 43 | n_train_envs = round(120 * config['train_env_ratio']) 44 | 45 | train_perms = perms[:n_train_envs] 46 | test_perms = perms[n_train_envs:] 47 | 48 | return train_perms, test_perms 49 | 50 | 51 | class Darkroom(gym.Env): 52 | def __init__(self, config, **kwargs): 53 | super(Darkroom, self).__init__() 54 | self.grid_size = config['grid_size'] 55 | if 'goal' in kwargs: 56 | self.goal = kwargs['goal'] 57 | self.horizon = config['horizon'] 58 | self.dim_obs = 2 59 | self.dim_action = 1 60 | self.num_action = 5 61 | self.observation_space = spaces.Box(low=0, high=self.grid_size-1, shape=(self.dim_obs,), dtype=np.int32) 62 | self.action_space = spaces.Discrete(self.num_action) 63 | 64 | def reset( 65 | self, 66 | *, 67 | seed: int | None = None, 68 | options: dict[str, Any] | None = None, 69 | ) -> Tuple[ObsType, dict[str, Any]]: 70 | self.current_step = 0 71 | 72 | center = self.grid_size // 2 73 | self.state = np.array([center, center]) 74 | 75 | return self.state, {} 76 | 77 | def step(self, action): 78 | if self.current_step >= self.horizon: 79 | raise ValueError("Episode has already ended") 80 | 81 | s = np.array(self.state) 82 | a = action 83 | 84 | # Action handling 85 | if a == 0: 86 | s[0] += 1 87 | elif a == 1: 88 | s[0] -= 1 89 | elif a == 2: 90 | s[1] += 1 91 | elif a == 3: 92 | s[1] -= 1 93 | 94 | s = np.clip(s, 0, self.grid_size - 1) 95 | self.state = s 96 | 97 | reward = 1 if np.array_equal(s, self.goal) else 0 98 | self.current_step += 1 99 | done = self.current_step >= self.horizon 100 | info = {} 101 | return s.copy(), reward, done, done, info 102 | 103 | def get_optimal_action(self, state): 104 | if state[0] < self.goal[0]: 105 | a = 0 106 | elif state[0] > self.goal[0]: 107 | a = 1 108 | elif state[1] < self.goal[1]: 109 | a = 2 110 | elif state[1] > self.goal[1]: 111 | a = 3 112 | else: 113 | a = 4 114 | 115 | return a 116 | 117 | def transit(self, s, a): 118 | if a == 0: 119 | s[0] += 1 120 | elif a == 1: 121 | s[0] -= 1 122 | elif a == 2: 123 | s[1] += 1 124 | elif a == 3: 125 | s[1] -= 1 126 | elif a == 4: 127 | pass 
128 | else: 129 | raise ValueError('Invalid action') 130 | 131 | s = np.clip(s, 0, self.grid_size - 1) 132 | 133 | if np.all(s == self.goal): 134 | r = 1 135 | else: 136 | r = 0 137 | 138 | return s, r 139 | 140 | def get_max_return(self): 141 | center = self.grid_size // 2 142 | return (self.horizon + 1 - np.sum(np.absolute(self.goal - np.array([center, center])))).clip(0, self.horizon) 143 | 144 | 145 | class DarkroomPermuted(Darkroom): 146 | def __init__(self, config, **kwargs): 147 | super().__init__(config, **kwargs) 148 | 149 | self.perm_idx = kwargs['perm_idx'] 150 | self.goal = np.array([self.grid_size-1, self.grid_size-1]) 151 | 152 | assert self.perm_idx < 120 # 5! permutations in darkroom 153 | 154 | actions = np.arange(self.action_space.n) 155 | permutations = list(itertools.permutations(actions)) 156 | self.perm = permutations[self.perm_idx] 157 | 158 | def reset( 159 | self, 160 | *, 161 | seed: int | None = None, 162 | options: dict[str, Any] | None = None, 163 | ) -> Tuple[ObsType, dict[str, Any]]: 164 | self.current_step = 0 165 | 166 | self.state = np.array([0, 0]) 167 | 168 | return self.state, {} 169 | 170 | def step(self, action): 171 | return super().step(self.perm[action]) 172 | 173 | def transit(self, s, a): 174 | return super().transit(s, self.perm[a]) 175 | 176 | def get_optimal_action(self, state): 177 | action = super().get_optimal_action(state) 178 | return self.perm.index(action) 179 | 180 | def get_max_return(self): 181 | return (self.horizon + 1 - np.sum(np.absolute(self.goal - np.array([0, 0])))) -------------------------------------------------------------------------------- /gridworld/utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import torch 3 | from torch.utils.data import DataLoader 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | from env import map_dark_states 8 | from functools import partial 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | def get_traj_file_name(config): 13 | if config["env"] == 'metaworld': 14 | task = config['task'] 15 | else: 16 | task = config['env'] 17 | 18 | path = f'history_{task}_{config["alg"]}_alg-seed{config["alg_seed"]}' 19 | 20 | return path 21 | 22 | 23 | def get_config(config_path): 24 | with open(config_path, 'r') as f: 25 | new_config = yaml.full_load(f) 26 | config = {} 27 | if 'include' in new_config: 28 | include_config = get_config(new_config['include']) 29 | config.update(include_config) 30 | del new_config['include'] 31 | config.update(new_config) 32 | return config 33 | 34 | 35 | def ad_collate_fn(batch, grid_size): 36 | res = {} 37 | res['query_states'] = torch.tensor(np.array([item['query_states'] for item in batch]), requires_grad=False, dtype=torch.float) 38 | res['target_actions'] = torch.tensor(np.array([item['target_actions'] for item in batch]), requires_grad=False, dtype=torch.long) 39 | res['states'] = torch.tensor(np.array([item['states'] for item in batch]), requires_grad=False, dtype=torch.float) 40 | res['actions'] = F.one_hot(torch.tensor(np.array([item['actions'] for item in batch]), requires_grad=False, dtype=torch.long), num_classes=5) 41 | res['rewards'] = torch.tensor(np.array([item['rewards'] for item in batch]), dtype=torch.float, requires_grad=False) 42 | res['next_states'] = torch.tensor(np.array([item['next_states'] for item in batch]), requires_grad=False, dtype=torch.float) 43 | 44 | if 'target_next_states' in batch[0].keys(): 45 | res['target_next_states'] = 
map_dark_states(torch.tensor(np.array([item['target_next_states'] for item in batch]), dtype=torch.long, requires_grad=False), grid_size=grid_size) 46 | res['target_rewards'] = torch.tensor(np.array([item['target_rewards'] for item in batch]), dtype=torch.long, requires_grad=False) 47 | 48 | return res 49 | 50 | 51 | def dpt_collate_fn(batch, grid_size): 52 | res = {} 53 | res['query_states'] = torch.tensor(np.array([item['query_states'] for item in batch]), requires_grad=False, dtype=torch.float) 54 | res['target_actions'] = torch.tensor(np.array([item['target_actions'] for item in batch]), requires_grad=False, dtype=torch.long) 55 | res['states'] = torch.tensor(np.array([item['states'] for item in batch]), requires_grad=False, dtype=torch.float) 56 | res['actions'] = F.one_hot(torch.tensor(np.array([item['actions'] for item in batch]), requires_grad=False, dtype=torch.long), num_classes=5) 57 | res['rewards'] = torch.tensor(np.array([item['rewards'] for item in batch]), dtype=torch.float, requires_grad=False) 58 | res['next_states'] = torch.tensor(np.array([item['next_states'] for item in batch]), requires_grad=False, dtype=torch.float) 59 | 60 | if 'target_next_states' in batch[0].keys(): 61 | res['query_actions'] = torch.tensor(np.array([item['query_actions'] for item in batch]), requires_grad=False, dtype=torch.long) 62 | res['target_next_states'] = map_dark_states(torch.tensor(np.array([item['target_next_states'] for item in batch]), dtype=torch.long, requires_grad=False), grid_size=grid_size) 63 | res['target_rewards'] = torch.tensor(np.array([item['target_rewards'] for item in batch]), dtype=torch.long, requires_grad=False) 64 | 65 | return res 66 | 67 | 68 | def idt_collate_fn(batch): 69 | res = {} 70 | res['states'] = torch.tensor(np.array([item['states'] for item in batch]), requires_grad=False) 71 | res['actions'] = torch.tensor(np.array([item['actions'] for item in batch]), dtype=torch.long, requires_grad=False) 72 | res['rewards'] = torch.tensor(np.array([item['rewards'] for item in batch]), dtype=torch.long, requires_grad=False) 73 | res['return_to_go'] = torch.tensor(np.array([item['return_to_go'] for item in batch]), dtype=torch.float, requires_grad=False) 74 | res['next_states'] = torch.tensor(np.array([item['next_states'] for item in batch]), requires_grad=False) 75 | 76 | return res 77 | 78 | 79 | def get_data_loader(dataset, batch_size, config, shuffle=True): 80 | if config['model'] == 'AD': 81 | collate_fn = partial(ad_collate_fn, grid_size=config['grid_size']) 82 | elif config['model'] == 'DPT': 83 | collate_fn = partial(dpt_collate_fn, grid_size=config['grid_size']) 84 | elif config['model'] == 'IDT': 85 | collate_fn = idt_collate_fn 86 | 87 | return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, num_workers=config['num_workers'], pin_memory=True, persistent_workers=True) 88 | 89 | 90 | def log_in_context(values: np.ndarray, max_reward: int, episode_length: int, tag: str, title: str, xlabel: str, ylabel: str, step: int, success=None, writer=None) -> None: 91 | steps = np.arange(1, len(values[0])+1) * episode_length 92 | mean_value = values.mean(axis=0) 93 | 94 | plt.plot(steps, mean_value) 95 | 96 | if success is not None: 97 | success_rate = success.astype(np.float32).mean(axis=0) 98 | 99 | for i, (xi, yi) in enumerate(zip(steps, mean_value)): 100 | if (i+1) % 10 == 0: 101 | plt.annotate(f'{success_rate[i]:.2f}', (xi, yi)) 102 | 103 | plt.title(title) 104 | plt.ylabel(ylabel) 105 | plt.xlabel(xlabel) 106 | plt.ylim(-max_reward * 
0.05, max_reward * 1.05) 107 | writer.add_figure(f'{tag}/mean', plt.gcf(), global_step=step) 108 | plt.close() 109 | 110 | 111 | def next_dataloader(dataloader: DataLoader): 112 | """ 113 | Makes the dataloader never end when the dataset is exhausted. 114 | This is done to remove the notion of an 'epoch' and to count only the amount 115 | of training steps. 116 | """ 117 | while True: 118 | for batch in dataloader: 119 | yield batch -------------------------------------------------------------------------------- /metaworld/collect_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | 4 | import yaml 5 | 6 | import sys 7 | sys.path.append(os.path.dirname(sys.path[0])) 8 | 9 | from alg import ALGORITHM, HistoryLoggerCallback 10 | import argparse 11 | import multiprocessing 12 | from utils import get_config, get_traj_file_name 13 | import h5py 14 | 15 | from stable_baselines3.common.vec_env import DummyVecEnv 16 | from gymnasium.wrappers.time_limit import TimeLimit 17 | import metaworld 18 | 19 | 20 | def parse_arguments(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--alg-config', '-ac', required=True, help="Algorithm config") 23 | parser.add_argument('--env-config', '-ec', required=True, help="Environment config") 24 | parser.add_argument('--traj-dir', '-t', required=False, default='./datasets', help="Directory for history saving") 25 | parser.add_argument('--override', '-o', default='') 26 | args = parser.parse_args() 27 | return args 28 | 29 | 30 | def worker(config, env_cls, task_instance, traj_dir, env_idx, history, file_name): 31 | 32 | env = DummyVecEnv([make_env(config, env_cls, task_instance)] * config['n_stream']) 33 | 34 | alg_name = config['alg'] 35 | seed = config['alg_seed'] + env_idx 36 | 37 | # Initialize algorithm 38 | alg = ALGORITHM[alg_name](config, env, seed, traj_dir) 39 | 40 | callback = HistoryLoggerCallback(config['env'], env_idx, history) 41 | 42 | log_name = f'{file_name}_{env_idx}' 43 | 44 | # Execute learning algorithm 45 | alg.learn(total_timesteps=config['total_source_timesteps'], 46 | callback=callback, 47 | log_interval=1, 48 | tb_log_name=log_name, 49 | reset_num_timesteps=True, 50 | progress_bar=False) 51 | 52 | env.close() 53 | 54 | 55 | def make_env(config, env_cls, task): 56 | def _init(): 57 | env = env_cls() 58 | env.set_task(task) 59 | return TimeLimit(env, max_episode_steps=config['horizon']) 60 | return _init 61 | 62 | 63 | if __name__ == '__main__': 64 | # Initialize multiprocessing 65 | multiprocessing.set_start_method('spawn') 66 | 67 | args = parse_arguments() 68 | 69 | # Load and update config 70 | config = get_config(args.env_config) 71 | config.update(get_config(args.alg_config)) 72 | 73 | # Override options 74 | for option in args.override.split('|'): 75 | if not option: 76 | continue 77 | address, value = option.split('=') 78 | keys = address.split('.') 79 | here = config 80 | for key in keys[:-1]: 81 | if key not in here: 82 | here[key] = {} 83 | here = here[key] 84 | if keys[-1] not in here: 85 | print(f'Warning: {address} is not defined in config file.') 86 | here[keys[-1]] = yaml.load(value, Loader=yaml.FullLoader) 87 | 88 | task = config['task'] 89 | 90 | ml1 = metaworld.ML1(env_name=task, seed=config['mw_seed']) 91 | 92 | file_name = get_traj_file_name(config) 93 | 94 | # Collcet train task histories 95 | name, env_cls = list(ml1.train_classes.items())[0] 96 | task_instances = ml1.train_tasks 97 | path = 
f'{args.traj_dir}/{task}/' 98 | 99 | if not os.path.exists(path): 100 | os.makedirs(path, exist_ok=True) 101 | 102 | start_time = datetime.now() 103 | print(f'Collecting train task histories started at {start_time}') 104 | 105 | with h5py.File(os.path.join(path, f'{file_name}.hdf5'), 'a') as f: 106 | start_idx = 0 107 | 108 | while f'{start_idx}' in f.keys(): 109 | start_idx += 1 110 | 111 | with multiprocessing.Manager() as manager: 112 | 113 | while start_idx < len(task_instances): 114 | history = manager.dict() 115 | 116 | instances = task_instances[start_idx:start_idx+config['n_process']] 117 | 118 | with multiprocessing.Pool(processes=config['n_process']) as pool: 119 | pool.starmap(worker, [(config, env_cls, task_instance, path, start_idx+i, history, file_name) for i, task_instance in enumerate(instances)]) 120 | 121 | # Save the history dictionary 122 | for i in range(start_idx, start_idx+len(instances)): 123 | env_group = f.create_group(f'{i}') 124 | for key, value in history[i].items(): 125 | env_group.create_dataset(key, data=value) 126 | 127 | start_idx += len(instances) 128 | 129 | end_time = datetime.now() 130 | print() 131 | print(f'Collecting train task histories ended at {end_time}') 132 | print(f'Elapsed time: {end_time - start_time}') 133 | 134 | 135 | # Collcet test task histories 136 | name, env_cls = list(ml1.test_classes.items())[0] 137 | task_instances = ml1.test_tasks[:10] 138 | path = f'{args.traj_dir}/{task}/test/' 139 | 140 | if not os.path.exists(path): 141 | os.makedirs(path, exist_ok=True) 142 | 143 | start_time = datetime.now() 144 | print(f'Collecting test task histories started at {start_time}') 145 | 146 | print() 147 | 148 | with h5py.File(f'{path}/{file_name}.hdf5', 'a') as f: 149 | start_idx = 0 150 | 151 | while f'{start_idx}' in f.keys(): 152 | start_idx += 1 153 | 154 | with multiprocessing.Manager() as manager: 155 | 156 | while start_idx < len(task_instances): 157 | history = manager.dict() 158 | 159 | instances = task_instances[start_idx:start_idx+config['n_process']] 160 | 161 | with multiprocessing.Pool(processes=config['n_process']) as pool: 162 | pool.starmap(worker, [(config, env_cls, task_instance, path, start_idx+i, history, file_name) for i, task_instance in enumerate(instances)]) 163 | 164 | # Save the history dictionary 165 | for i in range(start_idx, start_idx+len(instances)): 166 | env_group = f.create_group(f'{i}') 167 | for key, value in history[i].items(): 168 | env_group.create_dataset(key, data=value) 169 | 170 | start_idx += len(instances) 171 | 172 | end_time = datetime.now() 173 | print() 174 | print(f'Collecting test task histories ended at {end_time}') 175 | print(f'Elapsed time: {end_time - start_time}') -------------------------------------------------------------------------------- /gridworld/collect_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | 4 | import yaml 5 | 6 | import sys 7 | sys.path.append(os.path.dirname(sys.path[0])) 8 | 9 | from env import SAMPLE_ENVIRONMENT, make_env, Darkroom, DarkroomPermuted, DarkKeyToDoor 10 | from alg import ALGORITHM, HistoryLoggerCallback 11 | import argparse 12 | import multiprocessing 13 | from utils import get_config, get_traj_file_name 14 | import h5py 15 | import numpy as np 16 | 17 | from stable_baselines3.common.vec_env import DummyVecEnv 18 | 19 | 20 | def parse_arguments(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--alg-config', '-ac', required=True, 
help="Algorithm config") 23 | parser.add_argument('--env-config', '-ec', required=True, help="Environment config") 24 | parser.add_argument('--traj-dir', '-t', required=False, default='./datasets', help="Directory for history saving") 25 | parser.add_argument('--override', '-o', default='') 26 | args = parser.parse_args() 27 | return args 28 | 29 | 30 | def worker(arg, config, traj_dir, env_idx, history, file_name): 31 | 32 | if config['env'] == 'darkroom': 33 | env = DummyVecEnv([make_env(config, goal=arg)] * config['n_stream']) 34 | elif config['env'] == 'darkroompermuted': 35 | env = DummyVecEnv([make_env(config, perm_idx=arg)] * config['n_stream']) 36 | elif config['env'] == 'darkkeytodoor': 37 | env = DummyVecEnv([make_env(config, key=arg[:2], goal=arg[2:])] * config['n_stream']) 38 | else: 39 | raise ValueError('Invalid environment') 40 | 41 | alg_name = config['alg'] 42 | seed = config['alg_seed'] + env_idx 43 | 44 | # Initialize algorithm 45 | alg = ALGORITHM[alg_name](config, env, seed, traj_dir) 46 | 47 | callback = HistoryLoggerCallback(config['env'], env_idx, history) 48 | 49 | log_name = f'{file_name}_{env_idx}' 50 | 51 | # Execute learning algorithm 52 | alg.learn(total_timesteps=config['total_source_timesteps'], 53 | callback=callback, 54 | log_interval=1, 55 | tb_log_name=log_name, 56 | reset_num_timesteps=True, 57 | progress_bar=False) 58 | 59 | env.close() 60 | 61 | 62 | if __name__ == '__main__': 63 | # Initialize multiprocessing 64 | multiprocessing.set_start_method('spawn') 65 | 66 | args = parse_arguments() 67 | 68 | # Load and update config 69 | config = get_config(args.env_config) 70 | config.update(get_config(args.alg_config)) 71 | 72 | # Ensure the log directory exists 73 | if not os.path.exists(args.traj_dir): 74 | os.makedirs(args.traj_dir, exist_ok=True) 75 | 76 | # Override options 77 | for option in args.override.split('|'): 78 | if not option: 79 | continue 80 | address, value = option.split('=') 81 | keys = address.split('.') 82 | here = config 83 | for key in keys[:-1]: 84 | if key not in here: 85 | here[key] = {} 86 | here = here[key] 87 | if keys[-1] not in here: 88 | print(f'Warning: {address} is not defined in config file.') 89 | here[keys[-1]] = yaml.load(value, Loader=yaml.FullLoader) 90 | 91 | train_args, test_args = SAMPLE_ENVIRONMENT[config['env']](config, shuffle=False) 92 | 93 | total_args = train_args + test_args 94 | n_envs = len(total_args) 95 | 96 | file_name = get_traj_file_name(config) 97 | path = f'{args.traj_dir}/{file_name}.hdf5' 98 | 99 | start_time = datetime.now() 100 | print(f'Training started at {start_time}') 101 | 102 | with multiprocessing.Manager() as manager: 103 | history = manager.dict() 104 | 105 | # Create a pool with a maximum of n_workers 106 | with multiprocessing.Pool(processes=config['n_process']) as pool: 107 | # Map the worker function to the environments with the other arguments 108 | pool.starmap(worker, [(total_args[i], config, args.traj_dir, i, history, file_name) for i in range(n_envs)]) 109 | 110 | # Save the history dictionary 111 | with h5py.File(path, 'w-') as f: 112 | for i in range(n_envs): 113 | env_group = f.create_group(f'{i}') 114 | for key, value in history[i].items(): 115 | env_group.create_dataset(key, data=value) 116 | 117 | end_time = datetime.now() 118 | print() 119 | print(f'Training ended at {end_time}') 120 | print(f'Elapsed time: {end_time - start_time}') 121 | 122 | start_time = datetime.now() 123 | print(f'Annotating optimal actions for DPT started at {start_time}') 124 | 125 | with 
h5py.File(path, 'a') as f: 126 | for i in range(n_envs): 127 | if config['env'] == 'darkroom': 128 | states = f[f'{i}']['states'][()].transpose(1, 0, 2) 129 | actions = f[f'{i}']['actions'][()].transpose(1, 0) 130 | rewards = f[f'{i}']['rewards'][()].transpose(1, 0) 131 | env = Darkroom(config, goal=total_args[i]) 132 | optimal_actions = np.zeros_like(actions) 133 | 134 | for stream_idx in range(states.shape[0]): 135 | for step_idx in range(states.shape[1]): 136 | optimal_actions[stream_idx, step_idx] = env.get_optimal_action(states[stream_idx, step_idx]) 137 | 138 | group = f[f'{i}'] 139 | group.create_dataset('optimal_actions', data=optimal_actions) 140 | 141 | elif config['env'] == 'darkroompermuted': 142 | states = f[f'{i}']['states'][()].transpose(1, 0, 2) 143 | actions = f[f'{i}']['actions'][()].transpose(1, 0) 144 | rewards = f[f'{i}']['rewards'][()].transpose(1, 0) 145 | env = DarkroomPermuted(config, perm_idx=i) 146 | optimal_actions = np.zeros_like(actions) 147 | 148 | for stream_idx in range(states.shape[0]): 149 | for step_idx in range(states.shape[1]): 150 | optimal_actions[stream_idx, step_idx] = env.get_optimal_action(states[stream_idx, step_idx]) 151 | 152 | group = f[f'{i}'] 153 | group.create_dataset('optimal_actions', data=optimal_actions) 154 | 155 | elif config['env'] == 'darkkeytodoor': 156 | states = f[f'{i}']['states'][()].transpose(1, 0, 2).reshape(100, -1, config['horizon'], 2) 157 | actions = f[f'{i}']['actions'][()].transpose(1, 0).reshape(100, -1, config['horizon']) 158 | rewards = f[f'{i}']['rewards'][()].transpose(1, 0).reshape(100, -1, config['horizon']) 159 | env = DarkKeyToDoor(config, key=total_args[i][:2], goal=total_args[i][2:]) 160 | optimal_actions = np.zeros_like(actions) 161 | 162 | for stream_idx in range(states.shape[0]): 163 | for episode_idx in range(states.shape[1]): 164 | have_key=False 165 | for step_idx in range(states.shape[2]): 166 | optimal_actions[stream_idx, episode_idx, step_idx] = env.get_optimal_action(states[stream_idx, episode_idx, step_idx], have_key) 167 | if not have_key and rewards[stream_idx, episode_idx, step_idx] > 0: 168 | have_key = True 169 | 170 | group = f[f'{i}'] 171 | group.create_dataset('optimal_actions', data=optimal_actions.reshape(100, -1)) 172 | 173 | else: 174 | raise ValueError('Invalid environment') 175 | 176 | end_time = datetime.now() 177 | print() 178 | print(f'Annotating ended at {end_time}') 179 | print(f'Elapsed time: {end_time - start_time}') -------------------------------------------------------------------------------- /metaworld/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | from utils import get_traj_file_name 4 | import h5py 5 | from einops import rearrange, repeat 6 | 7 | 8 | 9 | class ADDataset(Dataset): 10 | def __init__(self, config, traj_dir, mode='train', n_seed=None, n_stream=None, source_timesteps=None): 11 | self.env = config['env'] 12 | self.n_transit = config['n_transit'] 13 | self.config = config 14 | 15 | states = [] 16 | actions = [] 17 | rewards = [] 18 | next_states = [] 19 | 20 | if mode == 'train': 21 | file_path = f'{traj_dir}/{get_traj_file_name(config)}.hdf5' 22 | elif mode == 'test': 23 | file_path = f'{traj_dir}/test/{get_traj_file_name(config)}.hdf5' 24 | else: 25 | raise ValueError('Invalid mode') 26 | 27 | with h5py.File(file_path, 'r') as f: 28 | for i in range(n_seed): 29 | states.append(f[f'{i}']['states'][()].transpose(1, 0, 2)[:n_stream, 
:source_timesteps, :11]) 30 | actions.append(f[f'{i}']['actions'][()].transpose(1, 0, 2)[:n_stream, :source_timesteps]) 31 | rewards.append(f[f'{i}']['rewards'][()].transpose(1, 0)[:n_stream, :source_timesteps]) 32 | next_states.append(f[f'{i}']['next_states'][()].transpose(1, 0, 2)[:n_stream, :source_timesteps, :11]) 33 | 34 | self.states = np.concatenate(states, axis=0) 35 | self.actions = np.concatenate(actions, axis=0) 36 | self.rewards = np.concatenate(rewards, axis=0) 37 | self.next_states = np.concatenate(next_states, axis=0) 38 | 39 | def __len__(self): 40 | return (len(self.states[0]) - self.n_transit + 1) * len(self.states) 41 | 42 | def __getitem__(self, i): 43 | history_idx = i // (len(self.states[0]) - self.n_transit + 1) 44 | transition_idx = i % (len(self.states[0]) - self.n_transit + 1) 45 | 46 | traj = { 47 | 'query_states': self.states[history_idx, transition_idx + self.n_transit - 1], 48 | 'target_actions': self.actions[history_idx, transition_idx + self.n_transit - 1], 49 | 'states': self.states[history_idx, transition_idx:transition_idx + self.n_transit - 1], 50 | 'actions': self.actions[history_idx, transition_idx:transition_idx + self.n_transit - 1], 51 | 'rewards': self.rewards[history_idx, transition_idx:transition_idx + self.n_transit - 1], 52 | 'next_states': self.next_states[history_idx, transition_idx:transition_idx + self.n_transit - 1], 53 | } 54 | 55 | if self.config['dynamics']: 56 | traj.update({ 57 | 'target_next_states': self.next_states[history_idx, transition_idx + self.n_transit - 1], 58 | 'target_rewards': self.rewards[history_idx, transition_idx + self.n_transit - 1] 59 | }) 60 | 61 | return traj 62 | 63 | 64 | class IDTDataset(Dataset): 65 | def __init__(self, config, traj_dir, mode='train', n_seed=50, n_stream=None, source_timesteps=None): 66 | self.config = config 67 | self.env = config['env'] 68 | self.n_transit = config['n_transit'] 69 | 70 | states = [] 71 | actions = [] 72 | rewards = [] 73 | next_states = [] 74 | 75 | if mode == 'train': 76 | file_path = f'{traj_dir}/{get_traj_file_name(config)}.hdf5' 77 | elif mode == 'test': 78 | file_path = f'{traj_dir}/test/{get_traj_file_name(config)}.hdf5' 79 | else: 80 | raise ValueError('Invalid mode') 81 | 82 | with h5py.File(file_path, 'r') as f: 83 | for i in range(n_seed): 84 | states.append(f[f'{i}']['states'][()].transpose(1, 0, 2)[:n_stream, :source_timesteps, :config['dim_obs']]) 85 | actions.append(f[f'{i}']['actions'][()].transpose(1, 0, 2)[:n_stream, :source_timesteps]) 86 | rewards.append(f[f'{i}']['rewards'][()].transpose(1, 0)[:n_stream, :source_timesteps]) 87 | next_states.append(f[f'{i}']['next_states'][()].transpose(1, 0, 2)[:n_stream, :source_timesteps, :config['dim_obs']]) 88 | 89 | self.states = np.concatenate(states, axis=0) 90 | self.actions = np.concatenate(actions, axis=0) 91 | self.rewards = np.concatenate(rewards, axis=0) 92 | self.next_states = np.concatenate(next_states, axis=0) 93 | self.returns_to_go = self.get_returns_to_go(self.rewards) 94 | 95 | self.sort_episodes() 96 | 97 | self.returns_to_go = self.relabel_returns_to_go(self.returns_to_go) 98 | 99 | def __len__(self): 100 | return (len(self.states[0]) - self.n_transit + 1) * len(self.states) 101 | 102 | def __getitem__(self, i): 103 | history_idx = i // (len(self.states[0]) - self.n_transit + 1) 104 | transition_idx = i % (len(self.states[0]) - self.n_transit + 1) 105 | 106 | traj = { 107 | 'states': self.states[history_idx, transition_idx:transition_idx + self.n_transit], 108 | 'actions': self.actions[history_idx, 
transition_idx:transition_idx + self.n_transit], 109 | 'rewards': self.rewards[history_idx, transition_idx:transition_idx + self.n_transit], 110 | 'returns_to_go': self.returns_to_go[history_idx, transition_idx:transition_idx + self.n_transit], 111 | 'next_states': self.next_states[history_idx, transition_idx:transition_idx + self.n_transit], 112 | } 113 | 114 | return traj 115 | 116 | def get_returns_to_go(self, rewards): 117 | episode_rewards = rewards.reshape(-1, rewards.shape[1] // self.config['horizon'], self.config['horizon']) 118 | return np.flip(np.flip(episode_rewards, axis=-1).cumsum(axis=-1), axis=-1).reshape(-1, rewards.shape[1]) 119 | 120 | def sort_episodes(self): 121 | returns_to_go = rearrange(self.returns_to_go, 'traj (epi time) -> traj epi time', time=self.config['horizon']) 122 | sorted_episode_idx = np.argsort(returns_to_go[:, :, 0]) 123 | sorted_episode_idx = repeat(sorted_episode_idx, 'traj epi -> traj epi time', time=self.config['horizon']) 124 | 125 | returns_to_go = np.take_along_axis(returns_to_go, sorted_episode_idx, axis=1) 126 | self.returns_to_go = rearrange(returns_to_go, 'traj epi time -> traj (epi time)') 127 | 128 | rewards = rearrange(self.rewards, 'traj (epi time) -> traj epi time', time=self.config['horizon']) 129 | rewards = np.take_along_axis(rewards, sorted_episode_idx, axis=1) 130 | self.rewards = rearrange(rewards, 'traj epi time -> traj (epi time)') 131 | 132 | actions = rearrange(self.actions, 'traj (epi time) dim -> traj epi time dim', time=self.config['horizon']) 133 | actions = np.take_along_axis(actions, 134 | repeat(sorted_episode_idx, 'traj epi time -> traj epi time dim', dim=self.actions.shape[-1]), 135 | axis=1) 136 | self.actions = rearrange(actions, 'traj epi time dim -> traj (epi time) dim') 137 | 138 | sorted_episode_idx = repeat(sorted_episode_idx, 'traj epi time -> traj epi time dim', dim=self.states.shape[-1]) 139 | 140 | states = rearrange(self.states, 'traj (epi time) dim -> traj epi time dim', time=self.config['horizon']) 141 | states = np.take_along_axis(states, sorted_episode_idx, axis=1) 142 | self.states = rearrange(states, 'traj epi time dim -> traj (epi time) dim') 143 | 144 | next_states = rearrange(self.next_states, 'traj (epi time) dim -> traj epi time dim', time=self.config['horizon']) 145 | next_states = np.take_along_axis(next_states, sorted_episode_idx, axis=1) 146 | self.next_states = rearrange(next_states, 'traj epi time dim -> traj (epi time) dim') 147 | 148 | def relabel_returns_to_go(self, rtg): 149 | max_episode_rtg = rtg.max(axis=-1) # (num_traj, ) 150 | max_episode_rtg = repeat(max_episode_rtg, 'traj -> traj epi', epi=rtg.shape[1] // self.config['horizon']) 151 | 152 | episode_rtg = rtg.reshape(-1, rtg.shape[1] // self.config['horizon'], self.config['horizon']) 153 | 154 | episode_offset = max_episode_rtg - episode_rtg[:, :, 0] 155 | offset = repeat(episode_offset, 'traj epi -> traj epi time', time=self.config['horizon']) 156 | 157 | return (episode_rtg + offset).reshape(-1, rtg.shape[1]) -------------------------------------------------------------------------------- /gridworld/model/tiny_llama/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | from typing import Optional, Tuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | # from flash_attn import flash_attn_func 8 | from lightning_utilities.core.imports import RequirementCache 9 | from xformers.ops import SwiGLU 10 | 11 | from .config import Config 12 | from 
.fused_rotary_embedding import apply_rotary_emb_func 13 | 14 | from torch.nn.attention import sdpa_kernel, SDPBackend 15 | 16 | RoPECache = Tuple[torch.Tensor, torch.Tensor] 17 | KVCache = Tuple[torch.Tensor, torch.Tensor] 18 | FlashAttention2Available = RequirementCache("flash-attn>=2.0.0.post1") 19 | 20 | 21 | class Transformer(nn.Module): 22 | def __init__(self, config) -> None: 23 | super().__init__() 24 | self.config = Config( 25 | block_size=3*config['n_transit'], 26 | n_layer=config['tf_n_layer'], 27 | n_head=config['tf_n_head'], 28 | n_embd=config['tf_n_embd'], 29 | bias=True, 30 | rotary_percentage=1.0, 31 | parallel_residual=False, 32 | shared_attention_norm=False, 33 | _norm_class="FusedRMSNorm", 34 | _mlp_class="LLaMAMLP", 35 | dropout=config['tf_dropout'], 36 | attention_dropout=config['tf_attn_dropout'], 37 | intermediate_size=config['tf_n_inner'], 38 | flash_attn=config['flash_attn'], 39 | ) 40 | self.device = config['device'] 41 | self.blocks = nn.ModuleList(Block(self.config) for _ in range(config['tf_n_layer'])) 42 | self.rope_cache_fp16 = self.build_rope_cache(device=self.device, dtype=torch.float16) 43 | self.rope_cache_bf16 = self.build_rope_cache(device=self.device, dtype=torch.bfloat16) 44 | self.rope_cache_fp32 = self.build_rope_cache(device=self.device, dtype=torch.float32) 45 | 46 | def forward(self, 47 | x: torch.Tensor, 48 | max_seq_length: int, 49 | mask: Optional[torch.Tensor] = None, 50 | dtype="bf16") -> Tuple[torch.Tensor, Optional[KVCache]]: 51 | 52 | if dtype == "bf16": 53 | cos, sin = self.rope_cache_bf16 54 | elif dtype == "fp16": 55 | cos, sin = self.rope_cache_fp16 56 | elif dtype == "fp32": 57 | cos, sin = self.rope_cache_fp32 58 | else: 59 | raise ValueError(f"Unsupported dtype: {dtype}") 60 | 61 | for block in self.blocks: 62 | x, *_ = block(x, 63 | (cos[:x.size(1)], sin[:x.size(1)]), 64 | max_seq_length, 65 | mask) 66 | return x 67 | 68 | def build_rope_cache(self, device, dtype) -> RoPECache: 69 | return build_rope_cache( 70 | seq_len=self.config.block_size, 71 | n_elem=int(self.config.rotary_percentage * self.config.head_size), 72 | dtype=dtype, 73 | device=device, 74 | condense_ratio=self.config.condense_ratio, 75 | ) 76 | 77 | 78 | 79 | class Block(nn.Module): 80 | def __init__(self, config: Config) -> None: 81 | super().__init__() 82 | self.norm_1 = config.norm_class( 83 | config.n_embd, eps=config.norm_eps, dropout=config.dropout 84 | ) 85 | self.attn = CausalSelfAttention(config) 86 | if not config.shared_attention_norm: 87 | self.norm_2 = config.norm_class( 88 | config.n_embd, eps=config.norm_eps, dropout=config.dropout 89 | ) 90 | self.mlp = getattr(sys.modules[__name__], config._mlp_class)(config) 91 | self.config = config 92 | 93 | def forward( 94 | self, 95 | x: torch.Tensor, 96 | rope: RoPECache, 97 | max_seq_length: int, 98 | mask: Optional[torch.Tensor] = None, 99 | input_pos: Optional[torch.Tensor] = None, 100 | kv_cache: Optional[KVCache] = None, 101 | ) -> Tuple[torch.Tensor, Optional[KVCache]]: 102 | n_1 = self.norm_1(x) 103 | h, new_kv_cache = self.attn( 104 | n_1, rope, max_seq_length, mask, input_pos, kv_cache 105 | ) 106 | if self.config.parallel_residual: 107 | n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x) 108 | x = x + h + self.mlp(n_2) 109 | else: 110 | if self.config.shared_attention_norm: 111 | raise NotImplementedError( 112 | "No checkpoint amongst the ones we support uses this configuration" 113 | " (non-parallel residual and shared attention norm)." 
114 | ) 115 | 116 | x = x + h 117 | x = x + self.mlp(self.norm_2(x)) 118 | return x, new_kv_cache 119 | 120 | 121 | class CausalSelfAttention(nn.Module): 122 | def __init__(self, config: Config) -> None: 123 | super().__init__() 124 | shape = (config.n_head + 2 * config.n_query_groups) * config.head_size 125 | self.attn = nn.Linear(config.n_embd, shape, bias=config.bias) 126 | self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 127 | 128 | self.config = config 129 | 130 | def forward( 131 | self, 132 | x: torch.Tensor, 133 | rope: RoPECache, 134 | max_seq_length: int, 135 | mask: Optional[torch.Tensor] = None, 136 | input_pos: Optional[torch.Tensor] = None, 137 | kv_cache: Optional[KVCache] = None, 138 | ) -> Tuple[torch.Tensor, Optional[KVCache]]: 139 | ( 140 | B, 141 | T, 142 | C, 143 | ) = x.size() 144 | 145 | qkv = self.attn(x) 146 | 147 | q_per_kv = self.config.n_head // self.config.n_query_groups 148 | total_qkv = q_per_kv + 2 149 | qkv = qkv.view( 150 | B, T, self.config.n_query_groups, total_qkv, self.config.head_size 151 | ) 152 | 153 | q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2) 154 | 155 | q = q.reshape(B, T, -1, self.config.head_size) # (B, T, nh_q, hs) 156 | k = k.reshape(B, T, -1, self.config.head_size) 157 | v = v.reshape(B, T, -1, self.config.head_size) 158 | 159 | cos, sin = rope 160 | q = apply_rotary_emb_func(q, cos, sin, False, True) 161 | k = apply_rotary_emb_func(k, cos, sin, False, True) 162 | 163 | if kv_cache is not None: 164 | cache_k, cache_v = kv_cache 165 | cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype) 166 | if input_pos[-1] >= max_seq_length: 167 | input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device) 168 | cache_k = torch.roll(cache_k, -1, dims=2) 169 | cache_v = torch.roll(cache_v, -1, dims=2) 170 | k = cache_k.index_copy_(2, input_pos, k) 171 | v = cache_v.index_copy_(2, input_pos, v) 172 | kv_cache = k, v 173 | 174 | y = self.scaled_dot_product_attention(q, k, v, mask=mask) 175 | 176 | y = y.reshape(B, T, C) 177 | 178 | y = self.proj(y) 179 | 180 | return y, kv_cache 181 | 182 | def scaled_dot_product_attention( 183 | self, 184 | q: torch.Tensor, 185 | k: torch.Tensor, 186 | v: torch.Tensor, 187 | mask: Optional[torch.Tensor] = None, 188 | ): 189 | scale = 1.0 / math.sqrt(self.config.head_size) 190 | q = q.transpose(1, 2) 191 | k = k.transpose(1, 2) 192 | v = v.transpose(1, 2) 193 | if q.size() != k.size(): 194 | k = k.repeat_interleave(q.shape[1]//k.shape[1], dim=1) 195 | v = v.repeat_interleave(q.shape[1]//v.shape[1], dim=1) 196 | 197 | if ( 198 | FlashAttention2Available 199 | and mask is None 200 | and q.device.type == "cuda" 201 | and q.dtype in (torch.float16, torch.bfloat16) 202 | and self.config.flash_attn 203 | ): 204 | attn_type = SDPBackend.FLASH_ATTENTION 205 | else: 206 | attn_type = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH] 207 | 208 | with sdpa_kernel(attn_type): 209 | y = torch.nn.functional.scaled_dot_product_attention( 210 | q, k, v, 211 | attn_mask=mask, 212 | dropout_p=self.config.attention_dropout if self.training else 0.0, 213 | scale=scale, 214 | is_causal=mask is None 215 | ) 216 | 217 | 218 | return y.transpose(1, 2) 219 | 220 | 221 | class GptNeoxMLP(nn.Module): 222 | def __init__(self, config: Config) -> None: 223 | super().__init__() 224 | self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) 225 | self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) 226 | 227 | def forward(self, x: torch.Tensor) -> 
torch.Tensor: 228 | x = self.fc(x) 229 | x = torch.nn.functional.gelu(x) 230 | return self.proj(x) 231 | 232 | 233 | class LLaMAMLP(nn.Module): 234 | def __init__(self, config: Config) -> None: 235 | super().__init__() 236 | self.swiglu = SwiGLU( 237 | config.n_embd, config.intermediate_size, bias=False, _pack_weights=False 238 | ) 239 | 240 | def forward(self, x: torch.Tensor) -> torch.Tensor: 241 | return self.swiglu(x) 242 | 243 | 244 | def build_rope_cache( 245 | seq_len: int, 246 | n_elem: int, 247 | dtype: torch.dtype, 248 | device: torch.device, 249 | base: int = 10000, 250 | condense_ratio: int = 1, 251 | ) -> RoPECache: 252 | 253 | theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device) / n_elem)) 254 | 255 | seq_idx = torch.arange(seq_len, device=device) / condense_ratio 256 | 257 | idx_theta = torch.outer(seq_idx, theta) 258 | 259 | cos, sin = torch.cos(idx_theta), torch.sin(idx_theta) 260 | 261 | if dtype == torch.bfloat16: 262 | return cos.bfloat16(), sin.bfloat16() 263 | if dtype in (torch.float16, torch.bfloat16, torch.int8): 264 | return cos.half(), sin.half() 265 | return cos, sin 266 | 267 | 268 | def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: 269 | head_size = x.size(-1) 270 | x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) 271 | x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) 272 | rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) 273 | roped = (x * cos) + (rotated * sin) 274 | return roped.type_as(x) 275 | -------------------------------------------------------------------------------- /metaworld/model/tiny_llama/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | from typing import Optional, Tuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | # from flash_attn import flash_attn_func 8 | from lightning_utilities.core.imports import RequirementCache 9 | from xformers.ops import SwiGLU 10 | 11 | from .config import Config 12 | from .fused_rotary_embedding import apply_rotary_emb_func 13 | 14 | from torch.nn.attention import sdpa_kernel, SDPBackend 15 | 16 | RoPECache = Tuple[torch.Tensor, torch.Tensor] 17 | KVCache = Tuple[torch.Tensor, torch.Tensor] 18 | FlashAttention2Available = RequirementCache("flash-attn>=2.0.0.post1") 19 | 20 | 21 | class Transformer(nn.Module): 22 | def __init__(self, config) -> None: 23 | super().__init__() 24 | self.config = Config( 25 | block_size=3*config['n_transit'], 26 | n_layer=config['tf_n_layer'], 27 | n_head=config['tf_n_head'], 28 | n_embd=config['tf_n_embd'], 29 | bias=True, 30 | rotary_percentage=1.0, 31 | parallel_residual=False, 32 | shared_attention_norm=False, 33 | _norm_class="FusedRMSNorm", 34 | _mlp_class="LLaMAMLP", 35 | dropout=config['tf_dropout'], 36 | attention_dropout=config['tf_attn_dropout'], 37 | intermediate_size=config['tf_n_inner'], 38 | flash_attn=config['flash_attn'], 39 | ) 40 | self.device = config['device'] 41 | self.blocks = nn.ModuleList(Block(self.config) for _ in range(config['tf_n_layer'])) 42 | self.rope_cache_fp16 = self.build_rope_cache(device=self.device, dtype=torch.float16) 43 | self.rope_cache_bf16 = self.build_rope_cache(device=self.device, dtype=torch.bfloat16) 44 | self.rope_cache_fp32 = self.build_rope_cache(device=self.device, dtype=torch.float32) 45 | 46 | def forward(self, 47 | x: torch.Tensor, 48 | max_seq_length: int, 49 | mask: Optional[torch.Tensor] = None, 50 | dtype="bf16") -> Tuple[torch.Tensor, Optional[KVCache]]: 51 | 52 | if dtype == "bf16": 53 | 
cos, sin = self.rope_cache_bf16 54 | elif dtype == "fp16": 55 | cos, sin = self.rope_cache_fp16 56 | elif dtype == "fp32": 57 | cos, sin = self.rope_cache_fp32 58 | else: 59 | raise ValueError(f"Unsupported dtype: {dtype}") 60 | 61 | for block in self.blocks: 62 | x, *_ = block(x, 63 | (cos[:x.size(1)], sin[:x.size(1)]), 64 | max_seq_length, 65 | mask) 66 | return x 67 | 68 | def build_rope_cache(self, device, dtype) -> RoPECache: 69 | return build_rope_cache( 70 | seq_len=self.config.block_size, 71 | n_elem=int(self.config.rotary_percentage * self.config.head_size), 72 | dtype=dtype, 73 | device=device, 74 | condense_ratio=self.config.condense_ratio, 75 | ) 76 | 77 | 78 | 79 | class Block(nn.Module): 80 | def __init__(self, config: Config) -> None: 81 | super().__init__() 82 | self.norm_1 = config.norm_class( 83 | config.n_embd, eps=config.norm_eps, dropout=config.dropout 84 | ) 85 | self.attn = CausalSelfAttention(config) 86 | if not config.shared_attention_norm: 87 | self.norm_2 = config.norm_class( 88 | config.n_embd, eps=config.norm_eps, dropout=config.dropout 89 | ) 90 | self.mlp = getattr(sys.modules[__name__], config._mlp_class)(config) 91 | self.config = config 92 | 93 | def forward( 94 | self, 95 | x: torch.Tensor, 96 | rope: RoPECache, 97 | max_seq_length: int, 98 | mask: Optional[torch.Tensor] = None, 99 | input_pos: Optional[torch.Tensor] = None, 100 | kv_cache: Optional[KVCache] = None, 101 | ) -> Tuple[torch.Tensor, Optional[KVCache]]: 102 | n_1 = self.norm_1(x) 103 | h, new_kv_cache = self.attn( 104 | n_1, rope, max_seq_length, mask, input_pos, kv_cache 105 | ) 106 | if self.config.parallel_residual: 107 | n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x) 108 | x = x + h + self.mlp(n_2) 109 | else: 110 | if self.config.shared_attention_norm: 111 | raise NotImplementedError( 112 | "No checkpoint amongst the ones we support uses this configuration" 113 | " (non-parallel residual and shared attention norm)." 
114 | ) 115 | 116 | x = x + h 117 | x = x + self.mlp(self.norm_2(x)) 118 | return x, new_kv_cache 119 | 120 | 121 | class CausalSelfAttention(nn.Module): 122 | def __init__(self, config: Config) -> None: 123 | super().__init__() 124 | shape = (config.n_head + 2 * config.n_query_groups) * config.head_size 125 | self.attn = nn.Linear(config.n_embd, shape, bias=config.bias) 126 | self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 127 | 128 | self.config = config 129 | 130 | def forward( 131 | self, 132 | x: torch.Tensor, 133 | rope: RoPECache, 134 | max_seq_length: int, 135 | mask: Optional[torch.Tensor] = None, 136 | input_pos: Optional[torch.Tensor] = None, 137 | kv_cache: Optional[KVCache] = None, 138 | ) -> Tuple[torch.Tensor, Optional[KVCache]]: 139 | ( 140 | B, 141 | T, 142 | C, 143 | ) = x.size() 144 | 145 | qkv = self.attn(x) 146 | 147 | q_per_kv = self.config.n_head // self.config.n_query_groups 148 | total_qkv = q_per_kv + 2 149 | qkv = qkv.view( 150 | B, T, self.config.n_query_groups, total_qkv, self.config.head_size 151 | ) 152 | 153 | q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2) 154 | 155 | q = q.reshape(B, T, -1, self.config.head_size) # (B, T, nh_q, hs) 156 | k = k.reshape(B, T, -1, self.config.head_size) 157 | v = v.reshape(B, T, -1, self.config.head_size) 158 | 159 | cos, sin = rope 160 | q = apply_rotary_emb_func(q, cos, sin, False, True) 161 | k = apply_rotary_emb_func(k, cos, sin, False, True) 162 | 163 | if kv_cache is not None: 164 | cache_k, cache_v = kv_cache 165 | cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype) 166 | if input_pos[-1] >= max_seq_length: 167 | input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device) 168 | cache_k = torch.roll(cache_k, -1, dims=2) 169 | cache_v = torch.roll(cache_v, -1, dims=2) 170 | k = cache_k.index_copy_(2, input_pos, k) 171 | v = cache_v.index_copy_(2, input_pos, v) 172 | kv_cache = k, v 173 | 174 | y = self.scaled_dot_product_attention(q, k, v, mask=mask) 175 | 176 | y = y.reshape(B, T, C) 177 | 178 | y = self.proj(y) 179 | 180 | return y, kv_cache 181 | 182 | def scaled_dot_product_attention( 183 | self, 184 | q: torch.Tensor, 185 | k: torch.Tensor, 186 | v: torch.Tensor, 187 | mask: Optional[torch.Tensor] = None, 188 | ): 189 | scale = 1.0 / math.sqrt(self.config.head_size) 190 | q = q.transpose(1, 2) 191 | k = k.transpose(1, 2) 192 | v = v.transpose(1, 2) 193 | if q.size() != k.size(): 194 | k = k.repeat_interleave(q.shape[1]//k.shape[1], dim=1) 195 | v = v.repeat_interleave(q.shape[1]//v.shape[1], dim=1) 196 | 197 | if ( 198 | FlashAttention2Available 199 | and mask is None 200 | and q.device.type == "cuda" 201 | and q.dtype in (torch.float16, torch.bfloat16) 202 | and self.config.flash_attn 203 | ): 204 | attn_type = SDPBackend.FLASH_ATTENTION 205 | else: 206 | attn_type = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH] 207 | 208 | with sdpa_kernel(attn_type): 209 | y = torch.nn.functional.scaled_dot_product_attention( 210 | q, k, v, 211 | attn_mask=mask, 212 | dropout_p=self.config.attention_dropout if self.training else 0.0, 213 | scale=scale, 214 | is_causal=mask is None 215 | ) 216 | 217 | 218 | return y.transpose(1, 2) 219 | 220 | 221 | class GptNeoxMLP(nn.Module): 222 | def __init__(self, config: Config) -> None: 223 | super().__init__() 224 | self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) 225 | self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) 226 | 227 | def forward(self, x: torch.Tensor) -> 
torch.Tensor: 228 | x = self.fc(x) 229 | x = torch.nn.functional.gelu(x) 230 | return self.proj(x) 231 | 232 | 233 | class LLaMAMLP(nn.Module): 234 | def __init__(self, config: Config) -> None: 235 | super().__init__() 236 | self.swiglu = SwiGLU( 237 | config.n_embd, config.intermediate_size, bias=False, _pack_weights=False 238 | ) 239 | 240 | def forward(self, x: torch.Tensor) -> torch.Tensor: 241 | return self.swiglu(x) 242 | 243 | 244 | def build_rope_cache( 245 | seq_len: int, 246 | n_elem: int, 247 | dtype: torch.dtype, 248 | device: torch.device, 249 | base: int = 10000, 250 | condense_ratio: int = 1, 251 | ) -> RoPECache: 252 | 253 | theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device) / n_elem)) 254 | 255 | seq_idx = torch.arange(seq_len, device=device) / condense_ratio 256 | 257 | idx_theta = torch.outer(seq_idx, theta) 258 | 259 | cos, sin = torch.cos(idx_theta), torch.sin(idx_theta) 260 | 261 | if dtype == torch.bfloat16: 262 | return cos.bfloat16(), sin.bfloat16() 263 | if dtype in (torch.float16, torch.bfloat16, torch.int8): 264 | return cos.half(), sin.half() 265 | return cos, sin 266 | 267 | 268 | def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: 269 | head_size = x.size(-1) 270 | x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) 271 | x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) 272 | rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) 273 | roped = (x * cos) + (rotated * sin) 274 | return roped.type_as(x) 275 | -------------------------------------------------------------------------------- /metaworld/model/ad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .tiny_llama.model import Transformer 4 | from einops import pack, rearrange, repeat 5 | import numpy as np 6 | 7 | 8 | class AD(torch.nn.Module): 9 | def __init__(self, config): 10 | super(AD, self).__init__() 11 | 12 | self.config = config 13 | self.device = config['device'] 14 | self.n_transit = config['n_transit'] 15 | self.max_seq_length = config['n_transit'] 16 | self.mixed_precision = config['mixed_precision'] 17 | 18 | if 'task' in config and (config['task'] == "hammer-v2" or config['task'] == "stick-push-v2" or config['task'] == "stick-pull-v2"): 19 | assert config['dim_obs'] == 18 20 | self.obs_dim_idx = list(range(18)) 21 | else: 22 | assert config['dim_obs'] == 11 23 | self.obs_dim_idx = list(range(11)) 24 | 25 | self.transformer = Transformer(config) 26 | 27 | self.embed_context = nn.Linear(config['dim_obs'] * 2 + config['dim_actions'] + 1, config['tf_n_embd']) 28 | self.embed_query_state = nn.Linear(config['dim_obs'], config['tf_n_embd']) 29 | 30 | self.loss_fn = nn.MSELoss(reduction='mean') 31 | 32 | if config['learn_var']: 33 | self.pred_actions = nn.Linear(config['tf_n_embd'], 2 * config['dim_actions']) 34 | self.loss_fn_gaussian = nn.GaussianNLLLoss(full=True, reduction='mean') 35 | else: 36 | self.pred_actions = nn.Linear(config['tf_n_embd'], config['dim_actions']) 37 | 38 | if config['dynamics']: 39 | self.embed_query_action = nn.Linear(config['dim_actions'], config['tf_n_embd']) 40 | self.pred_rewards = nn.Linear(config['tf_n_embd'], 1) 41 | 42 | if config['learn_transition']: 43 | self.pred_next_states = nn.Linear(config['tf_n_embd'], config['dim_obs']) 44 | 45 | def forward(self, x): 46 | query_states = x['query_states'].to(self.device) # (batch_size, dim_obs) 47 | target_actions = x['target_actions'].to(self.device) # (batch_size, dim_actions) 48 | 
states = x['states'].to(self.device) # (batch_size, n_transit-1, dim_obs) 49 | actions = x['actions'].to(self.device) # (batch_size, n_transit-1, dim_actions) 50 | next_states = x['next_states'].to(self.device) # (batch_size, n_transit-1, dim_obs) 51 | rewards = x['rewards'].to(self.device) # (batch_size, n_transit-1) 52 | rewards = rearrange(rewards, 'b n -> b n 1') 53 | 54 | query_states_embed = self.embed_query_state(query_states) 55 | query_states_embed = rearrange(query_states_embed, 'b d -> b 1 d') 56 | 57 | context, _ = pack([states, actions, rewards, next_states], 'b n *') 58 | 59 | context_embed = self.embed_context(context) 60 | context_embed, _ = pack([context_embed, query_states_embed], 'b * d') 61 | 62 | if self.config['dynamics']: 63 | query_actions_embed = self.embed_query_action(target_actions) 64 | query_actions_embed = rearrange(query_actions_embed, 'b d -> b 1 d') 65 | context_embed, _ = pack([context_embed, query_actions_embed], 'b * d') 66 | 67 | transformer_output = self.transformer(context_embed, 68 | max_seq_length=self.max_seq_length, 69 | dtype=self.mixed_precision) 70 | 71 | result = {} 72 | 73 | if self.config['learn_var']: 74 | dist = self.pred_actions(transformer_output[:, self.n_transit-1]) 75 | mean = dist[:, :self.config['dim_actions']] 76 | var = torch.exp(dist[:, self.config['dim_actions']:]) 77 | result['loss_action'] = self.loss_fn_gaussian(mean, target_actions, var) 78 | else: 79 | predicted_actions = self.pred_actions(transformer_output[:, self.n_transit-1]) 80 | result['loss_action'] = self.loss_fn(predicted_actions, target_actions) 81 | 82 | if self.config['dynamics']: 83 | predicted_rewards = self.pred_rewards(transformer_output[:, -1])[:, 0].clip(0, 10) 84 | target_rewards = x['target_rewards'].to(self.device) # (batch_size, ) 85 | result['loss_reward'] = self.loss_fn(predicted_rewards, target_rewards) 86 | 87 | if self.config['learn_transition']: 88 | predicted_states = self.pred_next_states(transformer_output[:, -1]).clip(self.obs_low, self.obs_high) 89 | target_states = x['target_next_states'].to(self.device) # (batch_size, dim_obs) 90 | result['loss_next_state'] = self.loss_fn(predicted_states, target_states) 91 | 92 | return result 93 | 94 | def evaluate_in_context(self, vec_env, eval_timesteps, sample_size=1, beam_start=50, sample=True): 95 | outputs = {} 96 | outputs['reward_episode'] = [] 97 | outputs['success'] = [] 98 | 99 | reward_episode = np.zeros(vec_env.num_envs) 100 | success = np.zeros(vec_env.num_envs) 101 | 102 | # Get inital states embeddings 103 | query_states = vec_env.reset()[:, self.obs_dim_idx] # (n_envs, obs_dim) 104 | query_states = torch.tensor(query_states, device=self.device, requires_grad=False, dtype=torch.float) 105 | query_states = rearrange(query_states, 'e d -> e 1 d') 106 | query_states_embed = self.embed_query_state(query_states) 107 | transformer_input = query_states_embed 108 | 109 | for step in range(eval_timesteps): 110 | query_states_prev = query_states.clone().detach() 111 | 112 | position=step % self.config['horizon'] 113 | if self.config['dynamics'] and sample_size > 1 and step >= self.n_transit and position > beam_start: 114 | actions = self.greedy_search(x=transformer_input.clone().detach(), 115 | sample_size=sample_size) 116 | else: 117 | output = self.transformer(transformer_input, 118 | max_seq_length=self.max_seq_length, 119 | dtype='fp32') 120 | 121 | if self.config['learn_var']: 122 | dist = self.pred_actions(output[:, -1]) 123 | mean = dist[:, :self.config['dim_actions']] 124 | std = 
torch.exp(dist[:, self.config['dim_actions']:] / 2) 125 | actions = (std * torch.randn_like(mean) + mean) 126 | elif sample: 127 | mean = self.pred_actions(output[:, -1]) 128 | std = torch.ones_like(mean) 129 | actions = (std * torch.randn_like(mean) + mean) 130 | else: 131 | actions = self.pred_actions(output[:, -1]) 132 | 133 | query_states, rewards, dones, infos = vec_env.step(actions.cpu().numpy()) 134 | 135 | actions = rearrange(actions, 'e d -> e 1 d') 136 | 137 | reward_episode += rewards 138 | rewards = torch.tensor(rewards, device=self.device, requires_grad=False, dtype=torch.float) 139 | rewards = rearrange(rewards, 'e -> e 1 1') 140 | 141 | query_states = torch.tensor(query_states[:, self.obs_dim_idx], device=self.device, requires_grad=False, dtype=torch.float) 142 | query_states = rearrange(query_states, 'e d -> e 1 d') 143 | 144 | success += np.array([info['success'] for info in infos]) 145 | 146 | if dones[0]: 147 | outputs['reward_episode'].append(reward_episode) 148 | reward_episode = np.zeros(vec_env.num_envs) 149 | outputs['success'].append(success > 0.0) 150 | success = np.zeros(vec_env.num_envs) 151 | 152 | states_next = torch.tensor(np.stack([info['terminal_observation'][self.obs_dim_idx] for info in infos]), device=self.device, dtype=torch.float) 153 | states_next = rearrange(states_next, 'e d -> e 1 d') 154 | else: 155 | states_next = query_states.clone().detach() 156 | 157 | query_states_embed = self.embed_query_state(query_states) 158 | 159 | context, _ = pack([query_states_prev, actions, rewards, states_next], 'e i *') 160 | context_embed = self.embed_context(context) 161 | 162 | if transformer_input.size(1) > 1: 163 | context_embed, _ = pack([transformer_input[:, :-1], context_embed], 'e * h') 164 | context_embed = context_embed[:, -(self.n_transit-1):] 165 | 166 | transformer_input, _ = pack([context_embed, query_states_embed], 'e * h') 167 | 168 | outputs['reward_episode'] = np.stack(outputs['reward_episode'], axis=1) 169 | outputs['success'] = np.maximum.accumulate(np.stack(outputs['success'], axis=1), axis=-1) 170 | 171 | return outputs 172 | 173 | def greedy_search(self, x, sample_size=5): 174 | batch_size = x.size(0) 175 | 176 | output = self.transformer(x, 177 | max_seq_length=self.max_seq_length, 178 | dtype="fp32") 179 | 180 | if self.config['learn_var']: 181 | dist_actions = self.pred_actions(output[:, -1]) 182 | mean_actions = dist_actions[:, :self.config['dim_actions']] 183 | std_actions = torch.exp(dist_actions[:, self.config['dim_actions']:] / 2) 184 | else: 185 | mean_actions = self.pred_actions(output[:, -1]) 186 | std_actions = torch.ones_like(mean_actions) 187 | 188 | mean_actions = rearrange(mean_actions, 'b a -> b 1 a') 189 | std_actions = rearrange(std_actions, 'b a -> b 1 a') 190 | 191 | sampled_actions = torch.randn((batch_size, sample_size-1, self.config['dim_actions']), device=self.device) * std_actions + mean_actions 192 | sampled_actions, _ = pack([mean_actions, sampled_actions], 'b * a') # (batch_size, sample_size, dim_actions), Use mean as one action sample 193 | 194 | # Query sampled actions 195 | embed_actions = self.embed_query_action(sampled_actions) # (batch_size, sample_size, hidden) 196 | embed_actions = rearrange(embed_actions, 'b k h -> b k 1 h') 197 | x = repeat(x, 'b n h -> b k n h', k=sample_size) 198 | x, _ = pack([x, embed_actions], 'b k * h') 199 | 200 | output = self.transformer(rearrange(x, 'b k n h -> (b k) n h'), 201 | max_seq_length=self.max_seq_length, 202 | dtype="fp32") 203 | 204 | output = rearrange(output, '(b 
k) n h -> b k n h', k=sample_size) 205 | 206 | rewards = self.pred_rewards(output[:, :, -1]).clip(0, 10) 207 | rewards = rearrange(rewards, 'b k 1 -> b k') 208 | rewards, indicies = rewards.sort(dim=-1, descending=True) 209 | beam = torch.gather(sampled_actions, 1, repeat(indicies, 'b k -> b k a', a=sampled_actions.size(2))) 210 | beam = rearrange(beam, 'b k a -> b k 1 a') 211 | 212 | return beam[:, 0, 0] 213 | 214 | def set_obs_space(self, obs_space): 215 | self.obs_low = torch.tensor(obs_space.low[:self.config['dim_obs']], device=self.device, requires_grad=False, dtype=torch.float) 216 | self.obs_high = torch.tensor(obs_space.high[:self.config['dim_obs']], device=self.device, requires_grad=False, dtype=torch.float) 217 | 218 | def set_action_space(self, action_space): 219 | self.action_low = torch.tensor(action_space.low, device=self.device, requires_grad=False, dtype=torch.float) 220 | self.action_high = torch.tensor(action_space.high, device=self.device, requires_grad=False, dtype=torch.float) -------------------------------------------------------------------------------- /gridworld/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | from utils import get_traj_file_name 4 | import h5py 5 | import random 6 | from einops import rearrange, repeat 7 | 8 | 9 | class ADDataset(Dataset): 10 | def __init__(self, config, traj_dir, mode='train', n_stream=None, source_timesteps=None): 11 | self.config = config 12 | self.env = config['env'] 13 | self.n_transit = config['n_transit'] 14 | self.dynamics = config['dynamics'] 15 | 16 | if self.env == 'darkroom': 17 | n_total_envs = config['grid_size'] ** 2 18 | elif self.env == 'darkroompermuted': 19 | n_total_envs = 120 20 | elif self.env == 'darkkeytodoor': 21 | n_total_envs = config['grid_size'] ** 4 22 | else: 23 | raise ValueError('Invalid env') 24 | 25 | total_env_idx = list(range(n_total_envs)) 26 | random.seed(config['env_split_seed']) 27 | random.shuffle(total_env_idx) 28 | 29 | n_train_envs = round(n_total_envs * config['train_env_ratio']) 30 | 31 | if mode == 'train': 32 | env_idx = total_env_idx[:n_train_envs] 33 | elif mode == 'test': 34 | env_idx = total_env_idx[n_train_envs:] 35 | elif mode == 'all': 36 | env_idx = total_env_idx 37 | else: 38 | raise ValueError('Invalid mode') 39 | 40 | states = [] 41 | actions = [] 42 | rewards = [] 43 | next_states = [] 44 | 45 | with h5py.File(f'{traj_dir}/{get_traj_file_name(config)}.hdf5', 'r') as f: 46 | for i in env_idx: 47 | states.append(f[f'{i}']['states'][()].transpose(1, 0, 2)[:n_stream, :source_timesteps]) 48 | actions.append(f[f'{i}']['actions'][()].transpose(1, 0)[:n_stream, :source_timesteps]) 49 | rewards.append(f[f'{i}']['rewards'][()].transpose(1, 0)[:n_stream, :source_timesteps]) 50 | next_states.append(f[f'{i}']['next_states'][()].transpose(1, 0, 2)[:n_stream, :source_timesteps]) 51 | 52 | self.states = np.concatenate(states, axis=0) 53 | self.actions = np.concatenate(actions, axis=0) 54 | self.rewards = np.concatenate(rewards, axis=0) 55 | self.next_states = np.concatenate(next_states, axis=0) 56 | 57 | def __len__(self): 58 | return (len(self.states[0]) - self.n_transit + 1) * len(self.states) 59 | 60 | def __getitem__(self, i): 61 | history_idx = i // (len(self.states[0]) - self.n_transit + 1) 62 | transition_idx = i % (len(self.states[0]) - self.n_transit + 1) 63 | 64 | traj = { 65 | 'query_states': self.states[history_idx, transition_idx + self.n_transit - 1], 66 | 
'target_actions': self.actions[history_idx, transition_idx + self.n_transit - 1], 67 | 'states': self.states[history_idx, transition_idx:transition_idx + self.n_transit - 1], 68 | 'actions': self.actions[history_idx, transition_idx:transition_idx + self.n_transit - 1], 69 | 'rewards': self.rewards[history_idx, transition_idx:transition_idx + self.n_transit - 1], 70 | 'next_states': self.next_states[history_idx, transition_idx:transition_idx + self.n_transit - 1], 71 | } 72 | 73 | if self.dynamics: 74 | traj.update({ 75 | 'target_next_states': self.next_states[history_idx, transition_idx + self.n_transit - 1], 76 | 'target_rewards': self.rewards[history_idx, transition_idx + self.n_transit - 1], 77 | }) 78 | 79 | return traj 80 | 81 | 82 | class DPTDataset(Dataset): 83 | def __init__(self, config, traj_dir, mode='train', n_stream=None, source_timesteps=None): 84 | self.env = config['env'] 85 | self.n_transit = config['n_transit'] 86 | self.dynamics = config['dynamics'] 87 | 88 | if self.env == 'darkroom': 89 | n_total_envs = config['grid_size'] ** 2 90 | elif self.env == 'darkroompermuted': 91 | n_total_envs = 120 92 | elif self.env == 'darkkeytodoor': 93 | n_total_envs = config['grid_size'] ** 4 94 | else: 95 | raise ValueError('Invalid env') 96 | 97 | total_env_idx = list(range(n_total_envs)) 98 | random.seed(config['env_split_seed']) 99 | random.shuffle(total_env_idx) 100 | 101 | n_train_envs = round(n_total_envs * config['train_env_ratio']) 102 | 103 | if mode == 'train': 104 | env_idx = total_env_idx[:n_train_envs] 105 | elif mode == 'test': 106 | env_idx = total_env_idx[n_train_envs:] 107 | elif mode == 'all': 108 | env_idx = total_env_idx 109 | else: 110 | raise ValueError('Invalid mode') 111 | 112 | states = [] 113 | actions = [] 114 | rewards = [] 115 | next_states = [] 116 | optimal_actions = [] 117 | 118 | with h5py.File(f'{traj_dir}/{get_traj_file_name(config)}.hdf5', 'r') as f: 119 | for i in env_idx: 120 | states.append(f[f'{i}']['states'][()].transpose(1, 0, 2)[:n_stream, :source_timesteps]) 121 | actions.append(f[f'{i}']['actions'][()].transpose(1, 0)[:n_stream, :source_timesteps]) 122 | rewards.append(f[f'{i}']['rewards'][()].transpose(1, 0)[:n_stream, :source_timesteps]) 123 | next_states.append(f[f'{i}']['next_states'][()].transpose(1, 0, 2)[:n_stream, :source_timesteps]) 124 | optimal_actions.append(f[f'{i}']['optimal_actions'][()][:n_stream, :source_timesteps]) 125 | 126 | self.states = np.concatenate(states, axis=0) 127 | self.actions = np.concatenate(actions, axis=0) 128 | self.rewards = np.concatenate(rewards, axis=0) 129 | self.next_states = np.concatenate(next_states, axis=0) 130 | self.optimal_actions = np.concatenate(optimal_actions, axis=0) 131 | 132 | def __len__(self): 133 | return (len(self.states[0]) - self.n_transit + 1) * len(self.states) 134 | 135 | def __getitem__(self, i): 136 | history_idx = i // (len(self.states[0]) - self.n_transit + 1) 137 | transition_idx = i % (len(self.states[0]) - self.n_transit + 1) 138 | 139 | traj = { 140 | 'query_states': self.states[history_idx, transition_idx + self.n_transit - 1], 141 | 'target_actions': self.optimal_actions[history_idx, transition_idx + self.n_transit - 1], 142 | 'states': self.states[history_idx, transition_idx:transition_idx + self.n_transit - 1], 143 | 'actions': self.actions[history_idx, transition_idx:transition_idx + self.n_transit - 1], 144 | 'rewards': self.rewards[history_idx, transition_idx:transition_idx + self.n_transit - 1], 145 | 'next_states': self.next_states[history_idx, 
transition_idx:transition_idx + self.n_transit - 1], 146 | } 147 | 148 | if self.dynamics: 149 | traj.update({ 150 | 'query_actions': self.actions[history_idx, transition_idx + self.n_transit - 1], 151 | 'target_next_states': self.next_states[history_idx, transition_idx + self.n_transit - 1], 152 | 'target_rewards': self.rewards[history_idx, transition_idx + self.n_transit - 1] 153 | }) 154 | 155 | return traj 156 | 157 | 158 | class IDTDataset(Dataset): 159 | def __init__(self, config, traj_dir, mode='train', n_stream=None, source_timesteps=None): 160 | self.config = config 161 | self.env = config['env'] 162 | self.n_transit = config['n_transit'] 163 | self.dynamics = config['dynamics'] 164 | 165 | if self.env == 'darkroom': 166 | n_total_envs = config['grid_size'] ** 2 167 | elif self.env == 'darkroompermuted': 168 | n_total_envs = 120 169 | elif self.env == 'darkkeytodoor': 170 | n_total_envs = config['grid_size'] ** 4 171 | else: 172 | raise ValueError('Invalid env') 173 | 174 | total_env_idx = list(range(n_total_envs)) 175 | random.seed(config['env_split_seed']) 176 | random.shuffle(total_env_idx) 177 | 178 | n_train_envs = round(n_total_envs * config['train_env_ratio']) 179 | 180 | if mode == 'train': 181 | env_idx = total_env_idx[:n_train_envs] 182 | elif mode == 'test': 183 | env_idx = total_env_idx[n_train_envs:] 184 | elif mode == 'all': 185 | env_idx = total_env_idx 186 | else: 187 | raise ValueError('Invalid mode') 188 | 189 | states = [] 190 | actions = [] 191 | rewards = [] 192 | next_states = [] 193 | 194 | with h5py.File(f'{traj_dir}/{get_traj_file_name(config)}.hdf5', 'r') as f: 195 | for i in env_idx: 196 | states.append(f[f'{i}']['states'][()].transpose(1, 0, 2)[:n_stream, :source_timesteps]) 197 | actions.append(f[f'{i}']['actions'][()].transpose(1, 0)[:n_stream, :source_timesteps]) 198 | rewards.append(f[f'{i}']['rewards'][()].transpose(1, 0)[:n_stream, :source_timesteps]) 199 | next_states.append(f[f'{i}']['next_states'][()].transpose(1, 0, 2)[:n_stream, :source_timesteps]) 200 | 201 | self.states = np.concatenate(states, axis=0) 202 | self.actions = np.concatenate(actions, axis=0) 203 | self.rewards = np.concatenate(rewards, axis=0) 204 | self.return_to_go = self.get_return_to_go(self.rewards) 205 | self.next_states = np.concatenate(next_states, axis=0) 206 | 207 | self.sort_episodes() 208 | 209 | self.return_to_go = self.relabel_return_to_go(self.return_to_go) 210 | 211 | def __len__(self): 212 | return (len(self.states[0]) - self.n_transit + 1) * len(self.states) 213 | 214 | def __getitem__(self, i): 215 | history_idx = i // (len(self.states[0]) - self.n_transit + 1) 216 | transition_idx = i % (len(self.states[0]) - self.n_transit + 1) 217 | 218 | traj = { 219 | 'states': self.states[history_idx, transition_idx:transition_idx + self.n_transit], 220 | 'actions': self.actions[history_idx, transition_idx:transition_idx + self.n_transit], 221 | 'rewards': self.rewards[history_idx, transition_idx:transition_idx + self.n_transit], 222 | 'return_to_go': self.return_to_go[history_idx, transition_idx:transition_idx + self.n_transit], 223 | 'next_states': self.next_states[history_idx, transition_idx:transition_idx + self.n_transit], 224 | } 225 | 226 | return traj 227 | 228 | def get_return_to_go(self, rewards): 229 | episode_rewards = rewards.reshape(-1, rewards.shape[1] // self.config['horizon'], self.config['horizon']) 230 | return np.flip(np.flip(episode_rewards, axis=-1).cumsum(axis=-1), axis=-1).reshape(-1, rewards.shape[1]) 231 | 232 | def sort_episodes(self): 233 | 
return_to_go = rearrange(self.return_to_go, 'traj (epi time) -> traj epi time', time=self.config['horizon']) 234 | sorted_episode_idx = np.argsort(return_to_go[:, :, 0]) 235 | sorted_episode_idx = repeat(sorted_episode_idx, 'traj epi -> traj epi time', time=self.config['horizon']) 236 | 237 | return_to_go = np.take_along_axis(return_to_go, sorted_episode_idx, axis=1) 238 | self.return_to_go = rearrange(return_to_go, 'traj epi time -> traj (epi time)') 239 | 240 | actions = rearrange(self.actions, 'traj (epi time) -> traj epi time', time=self.config['horizon']) 241 | actions = np.take_along_axis(actions, sorted_episode_idx, axis=1) 242 | self.actions = rearrange(actions, 'traj epi time -> traj (epi time)') 243 | 244 | rewards = rearrange(self.rewards, 'traj (epi time) -> traj epi time', time=self.config['horizon']) 245 | rewards = np.take_along_axis(rewards, sorted_episode_idx, axis=1) 246 | self.rewards = rearrange(rewards, 'traj epi time -> traj (epi time)') 247 | 248 | sorted_episode_idx = repeat(sorted_episode_idx, 'traj epi time -> traj epi time dim', dim=self.states.shape[-1]) 249 | 250 | states = rearrange(self.states, 'traj (epi time) dim -> traj epi time dim', time=self.config['horizon']) 251 | states = np.take_along_axis(states, sorted_episode_idx, axis=1) 252 | self.states = rearrange(states, 'traj epi time dim -> traj (epi time) dim') 253 | 254 | next_states = rearrange(self.next_states, 'traj (epi time) dim -> traj epi time dim', time=self.config['horizon']) 255 | next_states = np.take_along_axis(next_states, sorted_episode_idx, axis=1) 256 | self.next_states = rearrange(next_states, 'traj epi time dim -> traj (epi time) dim') 257 | 258 | def relabel_return_to_go(self, rtg): 259 | max_episode_rtg = rtg.max(axis=-1) # (num_traj, ) 260 | max_episode_rtg = repeat(max_episode_rtg, 'traj -> traj epi', epi=rtg.shape[1] // self.config['horizon']) 261 | 262 | episode_rtg = rtg.reshape(-1, rtg.shape[1] // self.config['horizon'], self.config['horizon']) 263 | 264 | episode_offset = max_episode_rtg - episode_rtg[:, :, 0] 265 | offset = repeat(episode_offset, 'traj epi -> traj epi time', time=self.config['horizon']) 266 | 267 | return (episode_rtg + offset).reshape(-1, rtg.shape[1]) -------------------------------------------------------------------------------- /gridworld/model/ad.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import pack, rearrange, repeat 6 | 7 | from env import map_dark_states, map_dark_states_inverse 8 | 9 | from .tiny_llama.model import Transformer 10 | 11 | 12 | class AD(torch.nn.Module): 13 | def __init__(self, config): 14 | super(AD, self).__init__() 15 | 16 | self.config = config 17 | self.device = config['device'] 18 | self.n_transit = config['n_transit'] 19 | self.max_seq_length = config['n_transit'] 20 | self.mixed_precision = config['mixed_precision'] 21 | self.grid_size = config['grid_size'] 22 | self.dynamics = config['dynamics'] 23 | 24 | self.transformer = Transformer(config) 25 | 26 | self.embed_context = nn.Linear(config['dim_states'] * 2 + config['num_actions'] + 1, config['tf_n_embd']) 27 | self.embed_query_state = nn.Embedding(config['grid_size'] * config['grid_size'], config['tf_n_embd']) 28 | self.pred_action = nn.Linear(config['tf_n_embd'], config['num_actions']) 29 | 30 | self.loss_fn = nn.CrossEntropyLoss(reduction='mean', label_smoothing=config['label_smoothing']) 31 | 32 | if self.dynamics: 
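            # Dynamics heads, enabled by config['dynamics']: the taken query action is embedded as one
            # extra input token, and two linear heads predict the step reward (two classes) and the next
            # state as one of grid_size * grid_size discrete cells. Their losses are added to the action
            # loss in train.py, scaled by dynamics_strength.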
33 | self.embed_query_action = nn.Embedding(config['num_actions'], config['tf_n_embd']) 34 | self.pred_reward = nn.Linear(config['tf_n_embd'], 2) 35 | self.pred_next_state = nn.Linear(config['tf_n_embd'], self.grid_size * self.grid_size) 36 | 37 | def forward(self, x): 38 | query_states = x['query_states'].to(self.device) # (batch_size, dim_state) 39 | target_actions = x['target_actions'].to(self.device) # (batch_size,) 40 | states = x['states'].to(self.device) # (batch_size, num_transit, dim_state) 41 | actions = x['actions'].to(self.device) # (batch_size, num_transit, num_actions) 42 | next_states = x['next_states'].to(self.device) # (batch_size, num_transit, dim_state) 43 | rewards = x['rewards'].to(self.device) # (batch_size, num_transit) 44 | rewards = rearrange(rewards, 'b n -> b n 1') 45 | 46 | query_states_embed = self.embed_query_state(map_dark_states(query_states, self.grid_size).to(torch.long)) 47 | query_states_embed = rearrange(query_states_embed, 'b d -> b 1 d') 48 | 49 | context, _ = pack([states, actions, rewards, next_states], 'b n *') 50 | context_embed = self.embed_context(context) 51 | context_embed, _ = pack([context_embed, query_states_embed], 'b * d') 52 | 53 | if self.dynamics: 54 | query_actions = x['target_actions'].to(self.device) # (batch_size, ) 55 | query_actions_embed = self.embed_query_action(query_actions) 56 | context_embed, _ = pack([context_embed, query_actions_embed], 'b * d') 57 | 58 | transformer_output = self.transformer(context_embed, 59 | max_seq_length=self.max_seq_length, 60 | dtype=self.mixed_precision) 61 | 62 | result = {} 63 | 64 | logits_actions = self.pred_action(transformer_output[:, self.n_transit-1]) # (batch_size, dim_action) 65 | 66 | loss_full_action = self.loss_fn(logits_actions, target_actions) 67 | acc_full_action = (logits_actions.argmax(dim=-1) == target_actions).float().mean() 68 | 69 | result['loss_action'] = loss_full_action 70 | result['acc_action'] = acc_full_action 71 | 72 | if self.dynamics: 73 | logit_rewards = self.pred_reward(transformer_output[:, -1]) 74 | target_rewards = x['target_rewards'].to(self.device) # (batch_size, ) 75 | result['loss_reward'] = self.loss_fn(logit_rewards, target_rewards) 76 | result['acc_reward'] = (logit_rewards.argmax(dim=-1) == target_rewards).float().mean() 77 | 78 | logits_states = self.pred_next_state(transformer_output[:, -1]) 79 | target_states = x['target_next_states'].to(self.device) # (batch_size, ) 80 | result['loss_next_state'] = self.loss_fn(logits_states, target_states) 81 | result['acc_next_state'] = (logits_states.argmax(dim=-1) == target_states).float().mean() 82 | 83 | return result 84 | 85 | def evaluate_in_context(self, vec_env, eval_timesteps, beam_k=0, sample=True): 86 | outputs = {} 87 | outputs['reward_episode'] = [] 88 | 89 | reward_episode = np.zeros(vec_env.num_envs) 90 | 91 | # Get inital states embeddings 92 | query_states = vec_env.reset() 93 | query_states = torch.tensor(query_states, device=self.device, requires_grad=False, dtype=torch.long) 94 | query_states = rearrange(query_states, 'e d -> e 1 d') 95 | query_states_embed = self.embed_query_state(map_dark_states(query_states, self.grid_size)) 96 | transformer_input = query_states_embed 97 | 98 | for step in range(eval_timesteps): 99 | query_states_prev = query_states.clone().detach().to(torch.float) 100 | 101 | # position = step % self.config['horizon'] 102 | if self.dynamics and beam_k > 0 and step >= self.n_transit: 103 | actions = self.beam_search(x=transformer_input.clone().detach(), 104 | 
query_states=query_states_prev.clone().detach(), 105 | position=step % self.config['horizon'], 106 | beam_k=beam_k, 107 | sample=sample) 108 | else: 109 | output = self.transformer(transformer_input, 110 | max_seq_length=self.max_seq_length, 111 | dtype='fp32') 112 | 113 | logits = self.pred_action(output[:, -1]) 114 | 115 | if sample: 116 | log_probs = F.log_softmax(logits, dim=-1) 117 | actions = torch.multinomial(log_probs.exp(), num_samples=1) 118 | actions = rearrange(actions, 'e 1 -> e') 119 | else: 120 | actions = logits.argmax(dim=-1) 121 | 122 | query_states, rewards, dones, infos = vec_env.step(actions.cpu().numpy()) 123 | 124 | actions = rearrange(actions, 'e -> e 1 1') 125 | actions = F.one_hot(actions, num_classes=self.config['num_actions']) 126 | 127 | reward_episode += rewards 128 | rewards = torch.tensor(rewards, device=self.device, requires_grad=False, dtype=torch.float) 129 | rewards = rearrange(rewards, 'e -> e 1 1') 130 | 131 | query_states = torch.tensor(query_states, device=self.device, requires_grad=False, dtype=torch.long) 132 | query_states = rearrange(query_states, 'e d -> e 1 d') 133 | 134 | if dones[0]: 135 | outputs['reward_episode'].append(reward_episode) 136 | reward_episode = np.zeros(vec_env.num_envs) 137 | 138 | states_next = torch.tensor(np.stack([info['terminal_observation'] for info in infos]), device=self.device, dtype=torch.float) 139 | 140 | states_next = rearrange(states_next, 'e d -> e 1 d') 141 | else: 142 | states_next = query_states.clone().detach().to(torch.float) 143 | 144 | query_states_embed = self.embed_query_state(map_dark_states(query_states, self.grid_size)) 145 | 146 | context, _ = pack([query_states_prev, actions, rewards, states_next], 'e i *') 147 | context_embed = self.embed_context(context) 148 | 149 | if transformer_input.size(1) > 1: 150 | context_embed, _ = pack([transformer_input[:, :-1], context_embed], 'e * h') 151 | context_embed = context_embed[:, -(self.n_transit-1):] 152 | 153 | transformer_input, _ = pack([context_embed, query_states_embed], 'e * h') 154 | 155 | outputs['reward_episode'] = np.stack(outputs['reward_episode'], axis=1) 156 | 157 | return outputs 158 | 159 | def beam_search(self, x, query_states, position, beam_k=5, sample=True): 160 | batch_size = x.size(0) 161 | 162 | output = self.transformer(x, 163 | max_seq_length=self.max_seq_length, 164 | dtype="fp32") 165 | 166 | logit_actions = self.pred_action(output[:, -1]) 167 | 168 | if sample: 169 | log_probs = F.log_softmax(logit_actions, dim=-1) 170 | all_actions = torch.multinomial(log_probs.exp(), num_samples=self.config['num_actions']) 171 | else: 172 | all_actions = logit_actions.argsort(dim=-1, descending=True) # (batch_size, num_actions) 173 | 174 | # Query all actions 175 | all_actions_embed = self.embed_query_action(all_actions) 176 | all_actions_embed = rearrange(all_actions_embed, 'b a h -> b a 1 h') 177 | 178 | x = repeat(x, 'b n h -> b a n h', a=self.config['num_actions']) 179 | x, _ = pack([x, all_actions_embed], 'b a * h') 180 | 181 | output = self.transformer(rearrange(x, 'b a n h -> (b a) n h'), 182 | max_seq_length=self.max_seq_length, 183 | dtype="fp32") 184 | 185 | output = rearrange(output, '(b a) n h -> b a n h', a=self.config['num_actions']) 186 | 187 | # Get rewards 188 | logits_rewards = self.pred_reward(output[:, :, -1]) 189 | rewards = logits_rewards.argmax(dim=-1) # (batch_size, num_actions) 190 | 191 | # Get next states 192 | logit_next_states = self.pred_next_state(output[:, :, -1]) 193 | next_states = 
logit_next_states.argmax(dim=-1) # (batch_size, num_actions) 194 | 195 | # Initialize cumulative rewards 196 | cum_rewards = rewards.clone().detach() 197 | 198 | # Sort actions according to rewards 199 | rewards_sort = cum_rewards.sort(dim=-1, descending=True, stable=True) 200 | cum_rewards = rewards_sort.values[:, :beam_k] 201 | indices_k = rewards_sort.indices[:, :beam_k] 202 | 203 | # Update cumulative rewards 204 | beam = torch.gather(all_actions, 1, indices_k) 205 | beam = rearrange(beam, 'b k -> b k 1') 206 | 207 | if self.config['env'] == 'darkroom': 208 | max_beam_steps = self.grid_size - 1 209 | elif self.config['env'] == 'darkkeytodoor' or self.config['env'] == 'darkroompermuted': 210 | max_beam_steps = (self.grid_size - 1) * 2 211 | else: 212 | raise ValueError('Invalid environment') 213 | 214 | position += 1 215 | beam_step = 1 216 | 217 | while position < self.config['horizon'] and beam_step < max_beam_steps: 218 | # Sort and cutoff variables 219 | x = torch.gather(x, 1, repeat(indices_k, 'b k -> b k n h', n=x.size(2), h=x.size(3))) 220 | actions_onehot = F.one_hot(beam[:, :, -1], num_classes=self.config['num_actions']) 221 | rewards = torch.gather(rewards, 1, indices_k) 222 | rewards = rearrange(rewards, 'b k -> b k 1') 223 | next_states = torch.gather(next_states, 1, indices_k) 224 | next_states_coord = map_dark_states_inverse(next_states, self.config['grid_size']) 225 | query_states = repeat(query_states, 'b k d -> b (k a) d', a=self.config['num_actions']) 226 | query_states = torch.gather(query_states, 1, repeat(indices_k, 'b k -> b k d', d=query_states.size(2))) 227 | 228 | # Make new context transition 229 | new_context, _ = pack([query_states, actions_onehot, rewards, next_states_coord], 'b k *') 230 | new_context_embed = self.embed_context(new_context.float()) 231 | new_context_embed = repeat(new_context_embed, 'b k h -> b (k a) 1 h', a=self.config['num_actions']) 232 | 233 | # Make new query states 234 | query_states_embed = self.embed_query_state(next_states) 235 | query_states_embed = repeat(query_states_embed, 'b k h -> b (k a) 1 h', a=self.config['num_actions']) 236 | 237 | query_states = next_states_coord # (batch_size, beam_k, dim_state) 238 | 239 | # Make transformer input 240 | x = repeat(x, 'b k n h -> b (k a) n h', a=self.config['num_actions']) 241 | 242 | all_actions = torch.arange(self.config['num_actions'], device=self.device) 243 | all_actions_embed = self.embed_query_action(all_actions) 244 | all_actions_embed = repeat(all_actions_embed, 'a h -> b (k a) 1 h', b=batch_size, k=rewards.size(1)) 245 | 246 | x, _ = pack([x[:, :, 1:self.config['n_transit']-1], new_context_embed, query_states_embed, all_actions_embed], 'b ka * h') 247 | 248 | assert x.size(2) == self.config['n_transit'] + 1 249 | 250 | # query states & actions 251 | output = self.transformer(rearrange(x, 'b ka n h -> (b ka) n h'), 252 | max_seq_length=self.max_seq_length, 253 | dtype="fp32") 254 | 255 | output = rearrange(output, '(b ka) n h -> b ka n h', b=batch_size) 256 | 257 | # Get rewards 258 | logit_rewards = self.pred_reward(output[:, :, -1]) 259 | rewards = logit_rewards.argmax(dim=-1) # (batch_size, beam_k * num_actions) 260 | 261 | # Get next states 262 | logit_next_states = self.pred_next_state(output[:, :, -1]) 263 | next_states = logit_next_states.argmax(dim=-1) # (batch_size, beam_k * num_actions) 264 | 265 | # Update cumulative rewards 266 | cum_rewards = repeat(cum_rewards, 'b k -> b (k a)', a=self.config['num_actions']) 267 | cum_rewards = cum_rewards + rewards 268 | 
rewards_sort = cum_rewards.sort(dim=-1, descending=True, stable=True) 269 | cum_rewards = rewards_sort.values[:, :beam_k] 270 | indices_k = rewards_sort.indices[:, :beam_k] 271 | 272 | new_actions = repeat(all_actions, 'a -> b (k a) 1', b=batch_size, k=beam.size(1)) 273 | beam = repeat(beam, 'b k s -> b (k a) s', a=self.config['num_actions']) 274 | beam, _ = pack([beam, new_actions], 'b ka *') 275 | beam = torch.gather(beam, 1, repeat(indices_k, 'b k -> b k s', s=beam.size(2))) 276 | 277 | position += 1 278 | beam_step += 1 279 | 280 | return beam[:, 0, 0] -------------------------------------------------------------------------------- /gridworld/model/dpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .tiny_llama.model import Transformer, LLaMAMLP 5 | from einops import pack, rearrange, repeat 6 | import numpy as np 7 | from env import map_dark_states, map_dark_states_inverse 8 | 9 | 10 | class DPT(nn.Module): 11 | def __init__(self, config): 12 | super(DPT, self).__init__() 13 | 14 | self.config = config 15 | self.device = config['device'] 16 | self.n_transit = config['n_transit'] 17 | self.max_seq_length = config['n_transit'] 18 | self.mixed_precision = config['mixed_precision'] 19 | self.grid_size = config['grid_size'] 20 | self.dynamics = config['dynamics'] 21 | 22 | self.transformer = Transformer(config) 23 | 24 | self.embed_context = nn.Linear(config['dim_states'] * 2 + config['num_actions'] + 1, config['tf_n_embd']) 25 | self.embed_query_state = nn.Linear(config['dim_states'], config['tf_n_embd']) 26 | self.pred_actions = nn.Linear(config['tf_n_embd'], config['num_actions']) 27 | 28 | self.embed_query_action = nn.Embedding(config['num_actions'], config['tf_n_embd']) 29 | 30 | self.loss_fn = nn.CrossEntropyLoss(reduction='mean', label_smoothing=config['label_smoothing']) 31 | 32 | if self.dynamics: 33 | self.pred_rewards = nn.Linear(config['tf_n_embd'], 2) 34 | self.pred_next_states = nn.Linear(config['tf_n_embd'], self.grid_size * self.grid_size) 35 | 36 | def forward(self, x): 37 | query_states = x['query_states'].to(self.device) # (batch_size, dim_state) 38 | target_actions = x['target_actions'].to(self.device) # (batch_size,) 39 | states = x['states'].to(self.device) # (batch_size, num_transit, dim_state) 40 | actions = x['actions'].to(self.device) # (batch_size, num_transit, num_actions) 41 | next_states = x['next_states'].to(self.device) # (batch_size, num_transit, dim_state) 42 | rewards = x['rewards'].to(self.device) # (batch_size, num_transit) 43 | rewards = rearrange(rewards, 'b n -> b n 1') 44 | 45 | query_states = F.pad(query_states, (0, actions.size(2) + rewards.size(2) + next_states.size(2))) 46 | query_states = rearrange(query_states, 'b d -> b 1 d') 47 | context, _ = pack([states, actions, rewards, next_states], 'b n *') 48 | 49 | context, _ = pack([query_states, context], 'b * d') 50 | context_embed = self.embed_context(context) 51 | 52 | if self.dynamics: 53 | query_actions = x['query_actions'].to(self.device) # (batch_size, num_actions) 54 | query_actions_embed = self.embed_query_action(query_actions) 55 | context_embed, _ = pack([context_embed, query_actions_embed], 'b * d') 56 | 57 | transformer_output = self.transformer(context_embed, 58 | max_seq_length=self.max_seq_length, 59 | dtype=self.mixed_precision) 60 | 61 | result = {} 62 | 63 | logits_actions = self.pred_actions(transformer_output[:, 1:self.n_transit]) # (batch_size, 
num_transit-1 , dim_action) 64 | target_actions_repeated = repeat(target_actions, 'b -> b n', n=logits_actions.size(1)) 65 | 66 | result['loss_action'] = self.loss_fn(rearrange(logits_actions, 'b n a -> (b n) a'), 67 | rearrange(target_actions_repeated, 'b n -> (b n)')) 68 | result['acc_action'] = (logits_actions.argmax(dim=-1) == target_actions_repeated).float().mean() 69 | 70 | if self.dynamics: 71 | logits_rewards = self.pred_rewards(transformer_output[:, -1]) 72 | target_rewards = x['target_rewards'].to(self.device) # (batch_size, ) 73 | 74 | result['loss_reward'] = self.loss_fn(logits_rewards, target_rewards) 75 | result['acc_reward'] = (logits_rewards.argmax(dim=-1) == target_rewards).float().mean() 76 | 77 | logits_states = self.pred_next_states(transformer_output[:, -1]) 78 | target_states = x['target_next_states'].to(self.device) # (batch_size, ) 79 | 80 | result['loss_next_state'] = self.loss_fn(logits_states, target_states) 81 | result['acc_next_state'] = (logits_states.argmax(dim=-1) == target_states).float().mean() 82 | 83 | return result 84 | 85 | def evaluate_in_context(self, vec_env, eval_timesteps, beam_k=0, sample=True): 86 | 87 | outputs = {} 88 | outputs['reward_episode'] = [] 89 | 90 | reward_episode = np.zeros(vec_env.num_envs) 91 | 92 | # Get inital states embeddings 93 | query_states = vec_env.reset() 94 | query_states = torch.tensor(query_states, device=self.device, requires_grad=False, dtype=torch.float) 95 | query_states_padded = F.pad(query_states, (0, self.config['dim_states'] + self.config['num_actions'] + 1)) 96 | query_states_padded = rearrange(query_states_padded, 'e d -> e 1 d') 97 | # query_states = rearrange(query_states, 'e d -> e 1 d') 98 | 99 | transformer_input = self.embed_context(query_states_padded) 100 | # transformer_input = self.embed_query_state(query_states_padded) 101 | 102 | 103 | for step in range(eval_timesteps): 104 | query_states_prev = query_states_padded[:,:,:self.config['dim_states']].clone().detach() 105 | # query_states_prev = query_states.clone().detach() 106 | 107 | # position = step % self.config['horizon'] 108 | if self.dynamics and beam_k > 0 and step >= self.n_transit: 109 | actions = self.beam_search(x=transformer_input.clone().detach(), 110 | query_states=query_states.clone().detach(), 111 | position=step % self.config['horizon'], 112 | beam_k=beam_k, 113 | sample=sample) 114 | else: 115 | output = self.transformer(transformer_input, 116 | max_seq_length=self.max_seq_length, 117 | dtype='fp32') 118 | 119 | logits = self.pred_actions(output[:, -1]) 120 | 121 | if sample: 122 | log_probs = F.log_softmax(logits, dim=-1) 123 | actions = torch.multinomial(log_probs.exp(), num_samples=1) 124 | actions = rearrange(actions, 'e 1 -> e') 125 | else: 126 | actions = logits.argmax(dim=-1) 127 | 128 | query_states, rewards, dones, infos = vec_env.step(actions.cpu().numpy()) 129 | 130 | actions = rearrange(actions, 'e -> e 1 1') 131 | actions = F.one_hot(actions, num_classes=self.config['num_actions']) 132 | 133 | reward_episode += rewards 134 | rewards = torch.tensor(rewards, device=self.device, requires_grad=False, dtype=torch.float) 135 | rewards = rearrange(rewards, 'e -> e 1 1') 136 | 137 | query_states = torch.tensor(query_states, device=self.device, requires_grad=False, dtype=torch.float) 138 | query_states = rearrange(query_states, 'e d -> e 1 d') 139 | 140 | if dones[0]: 141 | outputs['reward_episode'].append(reward_episode) 142 | reward_episode = np.zeros(vec_env.num_envs) 143 | 144 | states_next = 
torch.tensor(np.stack([info['terminal_observation'] for info in infos]), device=self.device, dtype=torch.float) 145 | 146 | states_next = rearrange(states_next, 'e d -> e 1 d') 147 | else: 148 | states_next = query_states.clone().detach() 149 | 150 | query_states_padded = F.pad(query_states, (0, self.config['dim_states'] + self.config['num_actions'] + 1)) 151 | query_states_embed = self.embed_context(query_states_padded) 152 | # query_states_embed = self.embed_query_state(query_states) 153 | 154 | context, _ = pack([query_states_prev, actions, rewards, states_next], 'e n *') 155 | context_embed = self.embed_context(context) 156 | 157 | if transformer_input.size(1) > 1: 158 | context_embed, _ = pack([transformer_input[:, 1:], context_embed], 'e * h') 159 | context_embed = context_embed[:, -(self.n_transit-1):] 160 | 161 | transformer_input, _ = pack([query_states_embed, context_embed], 'e * h') 162 | 163 | outputs['reward_episode'] = np.stack(outputs['reward_episode'], axis=1) 164 | 165 | return outputs 166 | 167 | def beam_search(self, x, query_states, position, beam_k=5, sample=True): 168 | batch_size = x.size(0) 169 | 170 | output = self.transformer(x, 171 | max_seq_length=self.max_seq_length, 172 | dtype="fp32") 173 | 174 | logit_actions = self.pred_actions(output[:, -1]) 175 | 176 | if sample: 177 | log_probs = F.log_softmax(logit_actions, dim=-1) 178 | all_actions = torch.multinomial(log_probs.exp(), num_samples=self.config['num_actions']) 179 | else: 180 | all_actions = logit_actions.argsort(dim=-1, descending=True) # (batch_size, num_actions) 181 | 182 | # Query all actions 183 | # all_actions_onehot = F.one_hot(all_actions, num_classes=self.config['num_actions']) # (batch_size, num_actions, num_actions) 184 | # all_actions_onehot = rearrange(all_actions_onehot, 'b a d -> b a 1 d') 185 | all_actions_embed = self.embed_query_action(all_actions) 186 | all_actions_embed = rearrange(all_actions_embed, 'b a h -> b a 1 h') 187 | 188 | # query_states = repeat(query_states, 'b 1 d -> b a 1 d', a=self.config['num_actions']) 189 | # query_states_actions, _ = pack([query_states, all_actions_onehot], 'b a i *') 190 | # query_states_actions = F.pad(query_states_actions, (0, 1+self.config['dim_states'])) 191 | # query_states_actions_embed = self.embed_context(query_states_actions) 192 | 193 | x = repeat(x, 'b n h -> b a n h', a=self.config['num_actions']) 194 | x, _ = pack([x, all_actions_embed], 'b a * h') 195 | 196 | output = self.transformer(rearrange(x, 'b a n h -> (b a) n h'), 197 | max_seq_length=self.max_seq_length, 198 | dtype="fp32") 199 | 200 | output = rearrange(output, '(b a) n h -> b a n h', a=self.config['num_actions']) 201 | 202 | # Get rewards 203 | logit_rewards = self.pred_rewards(output[:, :, -1]) 204 | rewards = logit_rewards.argmax(dim=-1) # (batch_size, num_actions) 205 | 206 | # # Get next states 207 | logit_next_states = self.pred_next_states(output[:, :, -1]) 208 | next_states = logit_next_states.argmax(dim=-1) # (batch_size, num_actions) 209 | 210 | # Initialize cumulative rewards 211 | cum_rewards = rewards.clone().detach() 212 | 213 | # Sort actions according to rewards 214 | rewards_sort = cum_rewards.sort(dim=-1, descending=True, stable=True) 215 | cum_rewards = rewards_sort.values[:, :beam_k] 216 | indices_k = rewards_sort.indices[:, :beam_k] 217 | 218 | beam = torch.gather(all_actions, 1, indices_k) 219 | beam = rearrange(beam, 'b k -> b k 1') 220 | 221 | position += 1 222 | if self.config['env'] == 'darkroom': 223 | max_beam_steps = self.grid_size - 1 224 | elif 
self.config['env'] == 'darkkeytodoor' or self.config['env'] == 'darkroompermuted': 225 | max_beam_steps = (self.grid_size - 1) * 2 226 | else: 227 | raise ValueError('Invalid environment') 228 | 229 | beam_step = 1 230 | 231 | while position < self.config['horizon'] and beam_step < max_beam_steps: 232 | # Sort and cutoff variables 233 | x = torch.gather(x, 1, repeat(indices_k, 'b k -> b k n h', n=x.size(2), h=x.size(3))) 234 | actions_onehot = F.one_hot(beam[:, :, -1], num_classes=self.config['num_actions']) 235 | rewards = torch.gather(rewards, 1, indices_k) 236 | rewards = rearrange(rewards, 'b k -> b k 1') 237 | next_states = torch.gather(next_states, 1, indices_k) 238 | next_states_coord = map_dark_states_inverse(next_states, self.config['grid_size']) 239 | query_states = repeat(query_states, 'b k d -> b (k a) d', a=self.config['num_actions']) 240 | query_states = torch.gather(query_states, 1, repeat(indices_k, 'b k -> b k d', d=query_states.size(2))) 241 | 242 | # Make new context transition 243 | new_context, _ = pack([query_states, actions_onehot, rewards, next_states_coord], 'b k *') 244 | new_context_embed = self.embed_context(new_context.float()) 245 | new_context_embed = repeat(new_context_embed, 'b k h -> b (k a) 1 h', a=self.config['num_actions']) 246 | 247 | # Make new query states 248 | next_states_padded = F.pad(next_states_coord, (0, self.config['dim_states'] + self.config['num_actions'] + 1)) 249 | query_states_embed = self.embed_context(next_states_padded.to(torch.float)) 250 | query_states_embed = repeat(query_states_embed, 'b k h -> b (k a) 1 h', a=self.config['num_actions']) 251 | 252 | query_states = next_states_coord # (batch_size, beam_k, dim_state) 253 | 254 | # Make transformer input 255 | x = repeat(x, 'b k n h -> b (k a) n h', a=self.config['num_actions']) 256 | 257 | all_actions = torch.arange(self.config['num_actions'], device=self.device) 258 | all_actions_embed = self.embed_query_action(all_actions) 259 | all_actions_embed = repeat(all_actions_embed, 'a h -> b (k a) 1 h', b=batch_size, k=rewards.size(1)) 260 | 261 | x, _ = pack([query_states_embed, x[:, :, 2:self.config['n_transit']], new_context_embed, all_actions_embed], 'b ka * h') 262 | 263 | assert x.size(2) == self.config['n_transit'] + 1 264 | 265 | # query (states, actions) 266 | output = self.transformer(rearrange(x, 'b ka n h -> (b ka) n h'), 267 | max_seq_length=self.max_seq_length, 268 | dtype="fp32") 269 | 270 | output = rearrange(output, '(b ka) n h -> b ka n h', b=batch_size) 271 | 272 | # Get rewards 273 | logit_rewards = self.pred_rewards(output[:, :, -1]) 274 | rewards = logit_rewards.argmax(dim=-1) # (batch_size, beam_k * num_actions) 275 | 276 | # Get next states 277 | logit_next_states = self.pred_next_states(output[:, :, -1]) 278 | next_states = logit_next_states.argmax(dim=-1) # (batch_size, beam_k * num_actions) 279 | 280 | # Update cumulative rewards 281 | cum_rewards = repeat(cum_rewards, 'b k -> b (k a)', a=self.config['num_actions']) 282 | cum_rewards = cum_rewards + rewards 283 | rewards_sort = cum_rewards.sort(dim=-1, descending=True, stable=True) 284 | cum_rewards = rewards_sort.values[:, :beam_k] 285 | indices_k = rewards_sort.indices[:, :beam_k] 286 | 287 | new_actions = repeat(all_actions, 'a -> b (k a) 1', b=batch_size, k=beam.size(1)) 288 | beam = repeat(beam, 'b k s -> b (k a) s', a=self.config['num_actions']) 289 | beam, _ = pack([beam, new_actions], 'b ka *') 290 | beam = torch.gather(beam, 1, repeat(indices_k, 'b k -> b k s', s=beam.size(2))) 291 | 292 | position += 1 
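            # Advance the simulated position and the beam depth: each pass of this loop expands every kept
            # sequence by all num_actions actions, re-ranks the beam_k * num_actions candidates by predicted
            # cumulative reward, and keeps only the top beam_k.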
293 | beam_step += 1 294 | 295 | return beam[:, 0, 0] -------------------------------------------------------------------------------- /gridworld/train.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from datetime import datetime 3 | from glob import glob 4 | import os 5 | import os.path as path 6 | from modulefinder import ModuleFinder 7 | 8 | import yaml 9 | import argparse 10 | import torch 11 | from torch.optim import AdamW 12 | from torch.utils.tensorboard import SummaryWriter 13 | 14 | from dataset import ADDataset, DPTDataset, IDTDataset 15 | from env import SAMPLE_ENVIRONMENT 16 | from model import MODEL 17 | from utils import get_config, get_data_loader, log_in_context, next_dataloader 18 | from transformers import get_cosine_schedule_with_warmup 19 | 20 | import multiprocessing 21 | from tqdm import tqdm 22 | from accelerate import Accelerator 23 | from stable_baselines3.common.vec_env import SubprocVecEnv 24 | 25 | from env import make_env 26 | 27 | 28 | def parse_arguments(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--alg-config', '-ac', required=False, default='./cfg/alg/ppo_dr.yaml', help="Algorithm config") 31 | parser.add_argument('--env-config', '-ec', required=False, default='./cfg/env/darkroom.yaml', help="Environment config") 32 | parser.add_argument('--model-config', '-mc', required=False, default='./cfg/model/ad_dr.yaml', help="Model config") 33 | parser.add_argument('--log-dir', '-l', required=False, default='./runs', help="Log directory") 34 | parser.add_argument('--traj-dir', '-t', required=False, default='./datasets', help="Trajectory directory") 35 | parser.add_argument('--no-backup', '-nb', required=False, default=False, help="Save code", action='store_true') 36 | parser.add_argument('--override', '-o', default='') 37 | parser.add_argument('--resume', required=False, default=False, help="Resume train", action='store_true') 38 | parser.add_argument('--mixed-precision', '-m', required=False, default='fp32', help="fp32 or fp16 or bf16") 39 | parser.add_argument('--disable-tqdm', '-d', required=False, default=False, action='store_true') 40 | 41 | args = parser.parse_args() 42 | return args 43 | 44 | 45 | if __name__ == '__main__': 46 | multiprocessing.set_start_method('spawn', force=True) 47 | args = parse_arguments() 48 | 49 | # Load and update config 50 | config = get_config(args.env_config) 51 | config.update(get_config(args.alg_config)) 52 | config.update(get_config(args.model_config)) 53 | 54 | # Override options 55 | for option in args.override.split('|'): 56 | if not option: 57 | continue 58 | address, value = option.split('=') 59 | keys = address.split('.') 60 | here = config 61 | for key in keys[:-1]: 62 | if key not in here: 63 | here[key] = {} 64 | here = here[key] 65 | if keys[-1] not in here: 66 | print(f'Warning: {address} is not defined in config file.') 67 | here[keys[-1]] = yaml.load(value, Loader=yaml.FullLoader) 68 | 69 | if config['dynamics']: 70 | log_dir = path.join(args.log_dir, f"{config['model']}-{config['env']}-dynamics{config['dynamics']}-str{config['dynamics_strength']}-seed{config['env_split_seed']}") 71 | else: 72 | log_dir = path.join(args.log_dir, f"{config['model']}-{config['env']}-dynamics{config['dynamics']}-seed{config['env_split_seed']}") 73 | 74 | writer = SummaryWriter(log_dir, flush_secs=15) 75 | 76 | # Prevent overwriting 77 | config['log_dir'] = log_dir 78 | config_save_path = path.join(config['log_dir'], 'config.yaml') 79 | try: 80 | # Try 
to open config file to bypass NFS cache 81 | with open(config_save_path, 'r') as f: 82 | f.read(1) 83 | config_exists = True 84 | except FileNotFoundError: 85 | config_exists = False 86 | 87 | if config_exists and not args.resume: 88 | print(f'WARNING: {log_dir} already exists. Skipping...') 89 | exit(0) 90 | 91 | config['traj_dir'] = args.traj_dir 92 | config['mixed_precision'] = args.mixed_precision 93 | 94 | # Save config 95 | os.makedirs(config['log_dir'], mode=0o755, exist_ok=True) 96 | with open(config_save_path, 'w') as f: 97 | yaml.dump(config, f) 98 | print(f'Config saved to {config_save_path}') 99 | 100 | # Save code 101 | if not args.no_backup: 102 | code_dir = path.join(config['log_dir'], 'code_' + datetime.now().strftime('%Y%m%d_%H%M%S')) 103 | mf = ModuleFinder([os.getcwd()]) 104 | mf.run_script(__file__) 105 | for name, module in mf.modules.items(): 106 | if module.__file__ is None: 107 | continue 108 | rel_path = path.relpath(module.__file__) 109 | new_path = path.join(code_dir, rel_path) 110 | new_dirname = path.dirname(new_path) 111 | os.makedirs(new_dirname, mode=0o750, exist_ok=True) 112 | shutil.copy2(rel_path, new_path) 113 | print(f'Code saved to {code_dir}') 114 | 115 | # Define accelerator 116 | if args.mixed_precision == 'bf16' or args.mixed_precision == 'fp16': 117 | accelerator = Accelerator(mixed_precision=args.mixed_precision) 118 | elif args.mixed_precision == 'fp32': 119 | accelerator = Accelerator(mixed_precision='no') 120 | else: 121 | raise ValueError(f'Unsupported mixed precision: {args.mixed_precision}') 122 | 123 | config['device'] = accelerator.device 124 | 125 | # Define model 126 | model_name = config['model'] 127 | model = MODEL[model_name](config) 128 | 129 | # Get datasets and dataloaders 130 | load_start_time = datetime.now() 131 | print(f'Data loading started at {load_start_time}') 132 | 133 | if config['model'] == 'DPT': 134 | train_dataset = DPTDataset(config, args.traj_dir, 'train', config['train_n_stream'], config['train_source_timesteps']) 135 | test_dataset = DPTDataset(config, args.traj_dir, 'test', 1, config['train_source_timesteps']) 136 | elif config['model'] == 'AD': 137 | train_dataset = ADDataset(config, args.traj_dir, 'train', config['train_n_stream'], config['train_source_timesteps']) 138 | test_dataset = ADDataset(config, args.traj_dir, 'test', 1, config['train_source_timesteps']) 139 | elif config['model'] == 'IDT': 140 | train_dataset = IDTDataset(config, args.traj_dir, 'train', config['train_n_stream'], config['train_source_timesteps']) 141 | test_dataset = IDTDataset(config, args.traj_dir, 'test', 1, config['train_source_timesteps']) 142 | else: 143 | raise ValueError(f'Unsupported model: {config["model"]}') 144 | 145 | train_dataloader = get_data_loader(train_dataset, batch_size=config['train_batch_size'], config=config, shuffle=True) 146 | train_dataloader = next_dataloader(train_dataloader) 147 | 148 | test_dataloader = get_data_loader(test_dataset, batch_size=config['test_batch_size'], config=config, shuffle=False) 149 | 150 | load_end_time = datetime.now() 151 | print() 152 | print(f'Data loading ended at {load_end_time}') 153 | print(f'Elapsed time: {load_end_time - load_start_time}') 154 | 155 | # Define optimizer and scheduler 156 | optimizer = AdamW(model.parameters(), lr=config['lr'], betas=(config['beta1'], config['beta2']), weight_decay=config['weight_decay']) 157 | lr_sched = get_cosine_schedule_with_warmup(optimizer, config['num_warmup_steps'], config['train_timesteps']) 158 | step = 0 159 | 160 | # Resume 
checkpoint 161 | if args.resume: 162 | ckpt_paths = sorted(glob(path.join(config['log_dir'], 'ckpt-*.pt'))) 163 | if len(ckpt_paths) > 0: 164 | ckpt_path = ckpt_paths[-1] 165 | ckpt = torch.load(ckpt_path) 166 | model.load_state_dict(ckpt['model']) 167 | optimizer.load_state_dict(ckpt['optimizer']) 168 | lr_sched.load_state_dict(ckpt['lr_sched']) 169 | step = ckpt['step'] 170 | print(f'Checkpoint loaded from {ckpt_path}') 171 | 172 | # Define environments for evaluation 173 | env_name = config['env'] 174 | train_env_args, test_env_args = SAMPLE_ENVIRONMENT[env_name](config) 175 | train_env_args = train_env_args[:10] 176 | test_env_args = test_env_args[:10] 177 | env_args = train_env_args + test_env_args 178 | 179 | if env_name == "darkroom": 180 | envs = SubprocVecEnv([make_env(config, goal=arg) for arg in env_args]) 181 | elif env_name == 'darkroompermuted': 182 | envs = SubprocVecEnv([make_env(config, perm_idx=arg) for arg in env_args]) 183 | elif env_name == "darkkeytodoor": 184 | envs = SubprocVecEnv([make_env(config, key=arg[:2], goal=arg[2:]) for arg in env_args]) 185 | else: 186 | raise NotImplementedError(f'Environment {env_name} is not supported') 187 | 188 | model, optimizer, train_dataloader, lr_sched = accelerator.prepare( 189 | model, optimizer, train_dataloader, lr_sched 190 | ) 191 | 192 | # Main training loop 193 | start_time = datetime.now() 194 | print(f'Training started at {start_time}') 195 | 196 | with tqdm(total=config['train_timesteps'], position=0, leave=True, disable=args.disable_tqdm) as pbar: 197 | pbar.update(step) 198 | 199 | while True: 200 | batch = next(train_dataloader) 201 | 202 | step += 1 203 | 204 | with accelerator.autocast(): 205 | output = model(batch) 206 | 207 | if config['dynamics']: 208 | loss = output['loss_action'] + (output['loss_reward'] + output['loss_next_state']) * config['dynamics_strength'] 209 | else: 210 | loss = output['loss_action'] 211 | 212 | optimizer.zero_grad() 213 | accelerator.backward(loss) 214 | accelerator.clip_grad_norm_(model.parameters(), 1) 215 | optimizer.step() 216 | if not accelerator.optimizer_step_was_skipped: 217 | lr_sched.step() 218 | 219 | pbar.set_postfix(loss=loss.item()) 220 | 221 | if step % config['summary_interval'] == 0: 222 | 223 | writer.add_scalar('train/loss', loss.item(), step) 224 | writer.add_scalar('train/loss_action', output['loss_action'].item(), step) 225 | writer.add_scalar('train/lr', lr_sched.get_last_lr()[0], step) 226 | writer.add_scalar('train/acc_action', output['acc_action'].item(), step) 227 | 228 | if config['dynamics']: 229 | writer.add_scalar('train/loss_reward', output['loss_reward'].item(), step) 230 | writer.add_scalar('train/acc_reward', output['acc_reward'].item(), step) 231 | writer.add_scalar('train/loss_next_state', output['loss_next_state'].item(), step) 232 | writer.add_scalar('train/acc_next_state', output['acc_next_state'].item(), step) 233 | 234 | ############ Evaluation ############ 235 | if step % config['eval_interval'] == 0: 236 | torch.cuda.empty_cache() 237 | model.eval() 238 | eval_start_time = datetime.now() 239 | print(f'Evaluating started at {eval_start_time}') 240 | 241 | with torch.no_grad(): 242 | test_loss_action = 0.0 243 | test_acc_action = 0.0 244 | test_loss_reward = 0.0 245 | test_acc_reward = 0.0 246 | test_loss_next_state = 0.0 247 | test_acc_next_state = 0.0 248 | test_cnt = 0 249 | 250 | for j, batch in enumerate(test_dataloader): 251 | output = model(batch) 252 | cnt = len(batch['states']) 253 | test_loss_action += output['loss_action'].item() * 
cnt 254 | test_acc_action += output['acc_action'].item() * cnt 255 | 256 | if config['dynamics']: 257 | test_loss_reward += output['loss_reward'].item() * cnt 258 | test_acc_reward += output['acc_reward'].item() * cnt 259 | test_loss_next_state += output['loss_next_state'].item() * cnt 260 | test_acc_next_state += output['acc_next_state'].item() * cnt 261 | 262 | test_cnt += cnt 263 | 264 | writer.add_scalar('test/loss_action', test_loss_action / test_cnt, step) 265 | writer.add_scalar('test/acc_action', test_acc_action / test_cnt, step) 266 | 267 | if config['dynamics']: 268 | writer.add_scalar('test/loss_reward', test_loss_reward / test_cnt, step) 269 | writer.add_scalar('test/acc_reward', test_acc_reward / test_cnt, step) 270 | writer.add_scalar('test/loss_next_state', test_loss_next_state / test_cnt, step) 271 | writer.add_scalar('test/acc_next_state', test_acc_next_state / test_cnt, step) 272 | 273 | eval_end_time = datetime.now() 274 | print() 275 | print(f'Evaluating ended at {eval_end_time}') 276 | print(f'Elapsed time: {eval_end_time - eval_start_time}') 277 | model.train() 278 | torch.cuda.empty_cache() 279 | #################################### 280 | 281 | ############ Generation ############ 282 | if step % config['gen_interval'] == 0: 283 | model.eval() 284 | gen_start_time = datetime.now() 285 | print(f'Generation started at {gen_start_time}') 286 | 287 | with torch.no_grad(): 288 | output = model.evaluate_in_context(envs, config['train_source_timesteps']) 289 | 290 | train_rewards = output['reward_episode'][:len(train_env_args)] 291 | test_rewards = output['reward_episode'][len(train_env_args):] 292 | 293 | log_in_context(values=train_rewards, 294 | max_reward=config['max_reward'], 295 | success=None, 296 | episode_length = config['horizon'], 297 | tag='train_gen/reward_episode', 298 | title='', 299 | xlabel='In-context steps', 300 | ylabel='Reward', 301 | step=step, 302 | writer=writer) 303 | 304 | log_in_context(values=test_rewards, 305 | max_reward=config['max_reward'], 306 | success=None, 307 | episode_length = config['horizon'], 308 | tag='test_gen/reward_episode', 309 | title='', 310 | xlabel='In-context steps', 311 | ylabel='Reward', 312 | step=step, 313 | writer=writer) 314 | 315 | gen_end_time = datetime.now() 316 | print() 317 | print(f'Generation ended at {gen_end_time}') 318 | print(f'Elapsed time: {gen_end_time - gen_start_time}') 319 | model.train() 320 | torch.cuda.empty_cache() 321 | #################################### 322 | 323 | pbar.update(1) 324 | 325 | # LOGGING 326 | if step % config['ckpt_interval'] == 0: 327 | # Remove old checkpoints 328 | ckpt_paths = sorted(glob(path.join(config['log_dir'], 'ckpt-*.pt'))) 329 | for ckpt_path in ckpt_paths: 330 | os.remove(ckpt_path) 331 | 332 | new_ckpt_path = path.join(config['log_dir'], f'ckpt-{step}.pt') 333 | 334 | torch.save({ 335 | 'step': step, 336 | 'config': config, 337 | 'model': model.state_dict(), 338 | 'optimizer': optimizer.state_dict(), 339 | 'lr_sched': lr_sched.state_dict(), 340 | }, new_ckpt_path) 341 | print(f'\nCheckpoint saved to {new_ckpt_path}') 342 | 343 | if step >= config['train_timesteps']: 344 | break 345 | 346 | writer.flush() 347 | envs.close() 348 | 349 | end_time = datetime.now() 350 | print() 351 | print(f'Training ended at {end_time}') 352 | print(f'Elapsed time: {end_time - start_time}') -------------------------------------------------------------------------------- /metaworld/train.py: -------------------------------------------------------------------------------- 1 | import 
shutil 2 | from datetime import datetime 3 | from glob import glob 4 | import os 5 | import os.path as path 6 | import sys 7 | sys.path.append(path.dirname(sys.path[0])) 8 | from modulefinder import ModuleFinder 9 | 10 | import yaml 11 | import argparse 12 | import torch 13 | from torch.optim import AdamW 14 | from torch.utils.tensorboard import SummaryWriter 15 | 16 | from dataset import ADDataset, IDTDataset 17 | from model import MODEL 18 | from utils import get_config, get_data_loader, log_in_context, next_dataloader 19 | from transformers import get_cosine_schedule_with_warmup 20 | 21 | import multiprocessing 22 | from tqdm import tqdm 23 | from accelerate import Accelerator 24 | from stable_baselines3.common.vec_env import SubprocVecEnv 25 | from gymnasium.wrappers.time_limit import TimeLimit 26 | import metaworld 27 | 28 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 29 | 30 | 31 | def parse_arguments(): 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--alg-config', '-ac', required=False, default='./cfg/alg/ppo_ml1.yaml', help="Algorithm config") 34 | parser.add_argument('--env-config', '-ec', required=False, default='./cfg/env/ml1.yaml', help="Environment config") 35 | parser.add_argument('--model-config', '-mc', required=False, default='./cfg/model/idt_ml1.yaml', help="Model config") 36 | parser.add_argument('--log-dir', '-l', required=False, default='./runs', help="Log directory") 37 | parser.add_argument('--traj-dir', '-t', required=False, default='./datasets', help="Trajectory directory") 38 | parser.add_argument('--no-backup', '-nb', required=False, default=False, help="Save code", action='store_true') 39 | parser.add_argument('--override', '-o', default='') 40 | parser.add_argument('--resume', required=False, default=False, help="Resume train", action='store_true') 41 | parser.add_argument('--mixed-precision', '-m', required=False, default='fp32') 42 | parser.add_argument('--disable-tqdm', '-d', required=False, default=False, action='store_true') 43 | 44 | args = parser.parse_args() 45 | return args 46 | 47 | 48 | def make_env(config, env_cls, task): 49 | def _init(): 50 | env = env_cls() 51 | env.set_task(task) 52 | return TimeLimit(env, max_episode_steps=config['horizon']) 53 | return _init 54 | 55 | 56 | if __name__ == '__main__': 57 | multiprocessing.set_start_method('spawn', force=True) 58 | args = parse_arguments() 59 | 60 | # Load and update config 61 | config = get_config(args.env_config) 62 | config.update(get_config(args.alg_config)) 63 | config.update(get_config(args.model_config)) 64 | 65 | # Override options 66 | for option in args.override.split('|'): 67 | if not option: 68 | continue 69 | address, value = option.split('=') 70 | keys = address.split('.') 71 | here = config 72 | for key in keys[:-1]: 73 | if key not in here: 74 | here[key] = {} 75 | here = here[key] 76 | if keys[-1] not in here: 77 | print(f'Warning: {address} is not defined in config file.') 78 | here[keys[-1]] = yaml.load(value, Loader=yaml.FullLoader) 79 | 80 | log_dir = path.join(args.log_dir, f"{config['model']}-ml1-{config['task']}-dynamics{config['dynamics']}-var{config['learn_var']}") 81 | writer = SummaryWriter(log_dir, flush_secs=15) 82 | 83 | # Prevent overwriting 84 | config['log_dir'] = log_dir 85 | config_save_path = path.join(config['log_dir'], 'config.yaml') 86 | try: 87 | # Try to open config file to bypass NFS cache 88 | with open(config_save_path, 'r') as f: 89 | f.read(1) 90 | config_exists = True 91 | except FileNotFoundError: 92 | 
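        # config.yaml is absent, so this log directory is treated as fresh; if it already exists and
        # --resume is not set, the run is skipped below to avoid overwriting earlier results.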
config_exists = False 93 | 94 | if config_exists and not args.resume: 95 | print(f'WARNING: {log_dir} already exists. Skipping...') 96 | exit(0) 97 | 98 | traj_dir = path.join(args.traj_dir, config['task']) 99 | config['traj_dir'] = traj_dir 100 | config['device'] = device 101 | config['mixed_precision'] = args.mixed_precision 102 | 103 | # Save config 104 | os.makedirs(config['log_dir'], mode=0o755, exist_ok=True) 105 | with open(config_save_path, 'w') as f: 106 | yaml.dump(config, f) 107 | print(f'Config saved to {config_save_path}') 108 | 109 | # Save code 110 | if not args.no_backup: 111 | code_dir = path.join(config['log_dir'], 'code_' + datetime.now().strftime('%Y%m%d_%H%M%S')) 112 | mf = ModuleFinder([os.getcwd()]) 113 | mf.run_script(__file__) 114 | for name, module in mf.modules.items(): 115 | if module.__file__ is None: 116 | continue 117 | rel_path = path.relpath(module.__file__) 118 | new_path = path.join(code_dir, rel_path) 119 | new_dirname = path.dirname(new_path) 120 | os.makedirs(new_dirname, mode=0o750, exist_ok=True) 121 | shutil.copy2(rel_path, new_path) 122 | print(f'Code saved to {code_dir}') 123 | 124 | # Define model 125 | model_name = config['model'] 126 | model = MODEL[model_name](config).to(device) 127 | 128 | # Get datasets and dataloaders 129 | load_start_time = datetime.now() 130 | print(f'Data loading started at {load_start_time}') 131 | 132 | if config['model'] == 'AD': 133 | train_dataset = ADDataset(config, traj_dir, 'train', config['train_n_seed'], config['train_n_stream'], config['train_source_timesteps']) 134 | test_dataset = ADDataset(config, traj_dir, 'test', 1, 1, config['train_source_timesteps']) 135 | elif config['model'] == 'IDT': 136 | train_dataset = IDTDataset(config, traj_dir, 'train', config['train_n_seed'], config['train_n_stream'], config['train_source_timesteps']) 137 | test_dataset = IDTDataset(config, traj_dir, 'test', 1, 1, config['train_source_timesteps']) 138 | else: 139 | raise ValueError(f'Unsupported model: {config["model"]}') 140 | 141 | train_dataloader = get_data_loader(train_dataset, batch_size=config['train_batch_size'], config=config, shuffle=True) 142 | train_dataloader = next_dataloader(train_dataloader) 143 | test_dataloader = get_data_loader(test_dataset, batch_size=config['test_batch_size'], config=config, shuffle=False) 144 | 145 | load_end_time = datetime.now() 146 | print() 147 | print(f'Data loading ended at {load_end_time}') 148 | print(f'Elapsed time: {load_end_time - load_start_time}') 149 | 150 | # Define optimizer and scheduler 151 | optimizer = AdamW(model.parameters(), lr=config['lr'], betas=(config['beta1'], config['beta2']), weight_decay=config['weight_decay']) 152 | lr_sched = get_cosine_schedule_with_warmup(optimizer, config['num_warmup_steps'], config['train_timesteps']) 153 | step = 0 154 | 155 | # Resume checkpoint 156 | if args.resume: 157 | ckpt_paths = sorted(glob(path.join(config['log_dir'], 'ckpt-*.pt'))) 158 | if len(ckpt_paths) > 0: 159 | ckpt_path = ckpt_paths[-1] 160 | ckpt = torch.load(ckpt_path) 161 | model.load_state_dict(ckpt['model']) 162 | optimizer.load_state_dict(ckpt['optimizer']) 163 | lr_sched.load_state_dict(ckpt['lr_sched']) 164 | step = ckpt['step'] 165 | print(f'Checkpoint loaded from {ckpt_path}') 166 | 167 | # Define environments for evaluation 168 | ml1 = metaworld.ML1(env_name=config['task'], seed=config['mw_seed']) 169 | 170 | train_envs = [] 171 | test_envs = [] 172 | 173 | for task_name, env_cls in ml1.train_classes.items(): 174 | task_instances = [task for task in 
ml1.train_tasks if task.env_name == task_name] 175 | for i in range(config['n_train_envs_per_task']): 176 | train_envs.append(make_env(config, env_cls, task_instances[i])) 177 | 178 | for task_name, env_cls in ml1.test_classes.items(): 179 | task_instances = [task for task in ml1.test_tasks if task.env_name == task_name] 180 | for i in range(config['n_test_envs_per_task']): 181 | test_envs.append(make_env(config, env_cls, task_instances[i])) 182 | 183 | envs = train_envs + test_envs 184 | 185 | envs = SubprocVecEnv(envs) 186 | model.set_obs_space(envs.observation_space) 187 | model.set_action_space(envs.action_space) 188 | 189 | # Wrap everything into an accelerator 190 | if args.mixed_precision == 'bf16' or args.mixed_precision == 'fp16': 191 | accelerator = Accelerator(mixed_precision=args.mixed_precision) 192 | elif args.mixed_precision == 'fp32': 193 | accelerator = Accelerator(mixed_precision='no') 194 | else: 195 | raise ValueError(f'Unsupported mixed precision: {args.mixed_precision}') 196 | 197 | # Main training loop 198 | start_time = datetime.now() 199 | print(f'Training started at {start_time}') 200 | 201 | with tqdm(total=config['train_timesteps'], position=0, leave=True, disable=args.disable_tqdm) as pbar: 202 | pbar.update(step) 203 | 204 | while True: 205 | with accelerator.autocast(): 206 | batch = next(train_dataloader) 207 | 208 | step += 1 209 | 210 | with accelerator.autocast(): 211 | output = model(batch) 212 | 213 | if config['dynamics']: 214 | if config['learn_transition']: 215 | loss = output['loss_action'] + (output['loss_reward'] + output['loss_next_state']) * config['dynamics_strength'] 216 | else: 217 | loss = output['loss_action'] + output['loss_reward'] * config['dynamics_strength'] 218 | else: 219 | loss = output['loss_action'] 220 | 221 | optimizer.zero_grad() 222 | accelerator.backward(loss) 223 | accelerator.clip_grad_norm_(model.parameters(), 1) 224 | optimizer.step() 225 | if not accelerator.optimizer_step_was_skipped: 226 | lr_sched.step() 227 | 228 | pbar.set_postfix(loss=loss.item()) 229 | 230 | if step % config['summary_interval'] == 0: 231 | 232 | writer.add_scalar('train/loss', loss.item(), step) 233 | writer.add_scalar('train/loss_action', output['loss_action'].item(), step) 234 | writer.add_scalar('train/lr', lr_sched.get_last_lr()[0], step) 235 | 236 | if config['dynamics']: 237 | writer.add_scalar('train/loss_reward', output['loss_reward'].item(), step) 238 | if config['learn_transition']: 239 | writer.add_scalar('train/loss_next_state', output['loss_next_state'].item(), step) 240 | 241 | ############ Evaluation ############ 242 | if step % config['eval_interval'] == 0: 243 | torch.cuda.empty_cache() 244 | model.eval() 245 | eval_start_time = datetime.now() 246 | print(f'Evaluating started at {eval_start_time}') 247 | 248 | with torch.no_grad(): 249 | test_loss_action = 0.0 250 | test_loss_reward = 0.0 251 | test_loss_next_state = 0.0 252 | test_cnt = 0 253 | 254 | for j, batch in enumerate(test_dataloader): 255 | with accelerator.autocast(): 256 | output = model(batch) 257 | cnt = len(batch['states']) 258 | test_loss_action += output['loss_action'].item() * cnt 259 | 260 | if config['dynamics']: 261 | test_loss_reward += output['loss_reward'].item() * cnt 262 | if config['learn_transition']: 263 | test_loss_next_state += output['loss_next_state'].item() * cnt 264 | 265 | test_cnt += cnt 266 | 267 | writer.add_scalar('test/loss_action', test_loss_action / test_cnt, step) 268 | 269 | if config['dynamics']: 270 | 
writer.add_scalar('test/loss_reward', test_loss_reward / test_cnt, step) 271 | if config['learn_transition']: 272 | writer.add_scalar('test/loss_next_state', test_loss_next_state / test_cnt, step) 273 | 274 | 275 | eval_end_time = datetime.now() 276 | print() 277 | print(f'Evaluating ended at {eval_end_time}') 278 | print(f'Elapsed time: {eval_end_time - eval_start_time}') 279 | model.train() 280 | torch.cuda.empty_cache() 281 | #################################### 282 | 283 | ############ Generation ############ 284 | if step % config['gen_interval'] == 0: 285 | model.eval() 286 | gen_start_time = datetime.now() 287 | print(f'Generation started at {gen_start_time}') 288 | 289 | with torch.no_grad(): 290 | output = model.evaluate_in_context(envs, config['test_source_timesteps']) 291 | 292 | train_rewards = output['reward_episode'][:len(train_envs)] 293 | test_rewards = output['reward_episode'][len(train_envs):] 294 | 295 | if 'success' in output.keys(): 296 | train_success = output['success'][:len(train_envs)] 297 | test_success = output['success'][len(train_envs):] 298 | 299 | writer.add_scalar('train/success_rate', train_success.max(axis=1).mean(), step) 300 | writer.add_scalar('test/success_rate', test_success.max(axis=1).mean(), step) 301 | 302 | else: 303 | train_success = None 304 | test_success = None 305 | 306 | 307 | log_in_context(values=train_rewards, 308 | max_reward=config['max_reward'], 309 | success=train_success, 310 | episode_length = config['horizon'], 311 | tag='train_gen/reward_episode', 312 | title='', 313 | xlabel='In-context steps', 314 | ylabel='Reward', 315 | step=step, 316 | writer=writer) 317 | 318 | log_in_context(values=test_rewards, 319 | max_reward=config['max_reward'], 320 | success=test_success, 321 | episode_length = config['horizon'], 322 | tag='test_gen/reward_episode', 323 | title='', 324 | xlabel='In-context steps', 325 | ylabel='Reward', 326 | step=step, 327 | writer=writer) 328 | 329 | gen_end_time = datetime.now() 330 | print() 331 | print(f'Generation ended at {gen_end_time}') 332 | print(f'Elapsed time: {gen_end_time - gen_start_time}') 333 | model.train() 334 | torch.cuda.empty_cache() 335 | #################################### 336 | 337 | pbar.update(1) 338 | 339 | # LOGGING 340 | if step % config['ckpt_interval'] == 0: 341 | # Remove old checkpoints 342 | ckpt_paths = sorted(glob(path.join(config['log_dir'], 'ckpt-*.pt'))) 343 | for ckpt_path in ckpt_paths: 344 | os.remove(ckpt_path) 345 | 346 | new_ckpt_path = path.join(config['log_dir'], f'ckpt-{step}.pt') 347 | 348 | torch.save({ 349 | 'step': step, 350 | 'config': config, 351 | 'model': model.state_dict(), 352 | 'optimizer': optimizer.state_dict(), 353 | 'lr_sched': lr_sched.state_dict(), 354 | }, new_ckpt_path) 355 | print(f'\nCheckpoint saved to {new_ckpt_path}') 356 | 357 | 358 | if step >= config['train_timesteps']: 359 | break 360 | 361 | writer.flush() 362 | envs.close() 363 | 364 | end_time = datetime.now() 365 | print() 366 | print(f'Training ended at {end_time}') 367 | print(f'Elapsed time: {end_time - start_time}') -------------------------------------------------------------------------------- /gridworld/model/tiny_llama/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for training and inference.""" 2 | 3 | import pickle 4 | import sys 5 | import warnings 6 | from contextlib import contextmanager 7 | from functools import partial 8 | from io import BytesIO 9 | from pathlib import Path 10 | from types import 
MethodType 11 | from typing import Any, Dict, List, Mapping, Optional, Type, TypeVar, Union 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.utils._device 16 | from lightning.fabric.loggers import CSVLogger 17 | from torch.serialization import normalize_storage_type 18 | 19 | 20 | def find_multiple(n: int, k: int) -> int: 21 | assert k > 0 22 | if n % k == 0: 23 | return n 24 | return n + k - (n % k) 25 | 26 | 27 | def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int: 28 | return sum(p.numel() for p in module.parameters() if requires_grad is None or p.requires_grad == requires_grad) 29 | 30 | 31 | @contextmanager 32 | def quantization(mode: Optional[str] = None): 33 | if mode is None: 34 | yield 35 | return 36 | 37 | if mode == "bnb.int8": 38 | from quantize.bnb import InferenceLinear8bitLt 39 | 40 | quantized_linear_cls = InferenceLinear8bitLt 41 | elif mode == "bnb.fp4": 42 | from quantize.bnb import Linear4bit 43 | 44 | class QuantizedLinear(Linear4bit): 45 | def __init__(self, *args, **kwargs): 46 | super().__init__(*args, quant_type="fp4", compress_statistics=False, **kwargs) 47 | 48 | quantized_linear_cls = QuantizedLinear 49 | elif mode == "bnb.fp4-dq": 50 | from quantize.bnb import Linear4bit 51 | 52 | class QuantizedLinear(Linear4bit): 53 | def __init__(self, *args, **kwargs): 54 | super().__init__(*args, quant_type="fp4", compress_statistics=True, **kwargs) 55 | 56 | quantized_linear_cls = QuantizedLinear 57 | elif mode == "bnb.nf4": 58 | from quantize.bnb import Linear4bit 59 | 60 | class QuantizedLinear(Linear4bit): 61 | def __init__(self, *args, **kwargs): 62 | super().__init__(*args, quant_type="nf4", compress_statistics=False, **kwargs) 63 | 64 | quantized_linear_cls = QuantizedLinear 65 | elif mode == "bnb.nf4-dq": 66 | from quantize.bnb import Linear4bit 67 | 68 | class QuantizedLinear(Linear4bit): 69 | def __init__(self, *args, **kwargs): 70 | super().__init__(*args, quant_type="nf4", compress_statistics=True, **kwargs) 71 | 72 | quantized_linear_cls = QuantizedLinear 73 | elif mode == "gptq.int4": 74 | from quantize.gptq import ColBlockQuantizedLinear 75 | 76 | class QuantizedLinear(ColBlockQuantizedLinear): 77 | def __init__(self, *args, **kwargs): 78 | super().__init__(*args, bits=4, tile_cols=-1, **kwargs) 79 | 80 | quantized_linear_cls = QuantizedLinear 81 | else: 82 | raise ValueError(f"Unknown quantization mode: {mode}") 83 | 84 | torch_linear_cls = torch.nn.Linear 85 | torch.nn.Linear = quantized_linear_cls 86 | yield 87 | torch.nn.Linear = torch_linear_cls 88 | 89 | 90 | class NotYetLoadedTensor: 91 | def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args): 92 | self.metatensor = metatensor 93 | self.archiveinfo = archiveinfo 94 | self.storageinfo = storageinfo 95 | self.rebuild_args = rebuild_args 96 | 97 | @classmethod 98 | def rebuild_from_type_v2(cls, func, new_type, args, state, *, archiveinfo=None): 99 | ret = func(*args) 100 | if isinstance(ret, NotYetLoadedTensor): 101 | old_lt = ret._load_tensor 102 | 103 | def _load_tensor(): 104 | t = old_lt() 105 | return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state) 106 | 107 | ret._load_tensor = _load_tensor 108 | return ret 109 | return torch._tensor._rebuild_from_type_v2(func, new_type, args, state) 110 | 111 | @classmethod 112 | def rebuild_parameter(cls, data, requires_grad, backward_hooks, *, archiveinfo=None): 113 | if isinstance(data, NotYetLoadedTensor): 114 | old_lt = data._load_tensor 115 | 116 | def _load_tensor(): 117 | 
t = old_lt() 118 | return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks) 119 | 120 | data._load_tensor = _load_tensor 121 | return data 122 | return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks) 123 | 124 | @classmethod 125 | def rebuild_tensor_v2( 126 | cls, storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None, *, archiveinfo=None 127 | ): 128 | rebuild_args = (storage_offset, size, stride, requires_grad, backward_hooks, metadata) 129 | metatensor = torch._utils._rebuild_tensor_v2( 130 | storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata 131 | ) 132 | storageinfo = storage.archiveinfo 133 | return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args) 134 | 135 | def _load_tensor(self): 136 | name, storage_cls, fn, device, size = self.storageinfo 137 | dtype = self.metatensor.dtype 138 | 139 | uts = ( 140 | self.archiveinfo.zipfile_context.zf.get_storage_from_record( 141 | f"data/{fn}", size * torch._utils._element_size(dtype), torch.UntypedStorage 142 | ) 143 | ._typed_storage() 144 | ._untyped_storage 145 | ) 146 | with warnings.catch_warnings(): 147 | warnings.simplefilter("ignore") 148 | storage = torch.storage.TypedStorage(wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True) 149 | return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args) 150 | 151 | @classmethod 152 | def __torch_function__(cls, func, types, args=(), kwargs=None): 153 | if kwargs is None: 154 | kwargs = {} 155 | loaded_args = [(a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args] 156 | return func(*loaded_args, **kwargs) 157 | 158 | def __getattr__(self, name): 159 | if name in { 160 | "dtype", 161 | "grad", 162 | "grad_fn", 163 | "layout", 164 | "names", 165 | "ndim", 166 | "output_nr", 167 | "requires_grad", 168 | "retains_grad", 169 | "shape", 170 | "volatile", 171 | }: 172 | return getattr(self.metatensor, name) 173 | if name in {"size"}: 174 | return getattr(self.metatensor, name) 175 | if name in {"contiguous"}: 176 | return getattr(self._load_tensor(), name) 177 | 178 | raise AttributeError(f"{type(self)} does not have {name}") 179 | 180 | def __repr__(self): 181 | return f"NotYetLoadedTensor({repr(self.metatensor)})" 182 | 183 | 184 | class LazyLoadingUnpickler(pickle.Unpickler): 185 | def __init__(self, file, zipfile_context): 186 | super().__init__(file) 187 | self.zipfile_context = zipfile_context 188 | 189 | def find_class(self, module, name): 190 | res = super().find_class(module, name) 191 | if module == "torch._utils" and name == "_rebuild_tensor_v2": 192 | return partial(NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self) 193 | if module == "torch._tensor" and name == "_rebuild_from_type_v2": 194 | return partial(NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self) 195 | if module == "torch._utils" and name == "_rebuild_parameter": 196 | return partial(NotYetLoadedTensor.rebuild_parameter, archiveinfo=self) 197 | return res 198 | 199 | def persistent_load(self, pid): 200 | name, cls, fn, device, size = pid 201 | with warnings.catch_warnings(): 202 | warnings.simplefilter("ignore") 203 | s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta") 204 | s.archiveinfo = pid 205 | return s 206 | 207 | 208 | class lazy_load: 209 | def __init__(self, fn): 210 | self.zf = torch._C.PyTorchFileReader(str(fn)) 211 | with BytesIO(self.zf.get_record("data.pkl")) as pkl: 212 | mup = LazyLoadingUnpickler(pkl, self) 213 | self.sd = mup.load() 
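    # Hypothetical usage sketch for this class: lazy_load defers materializing
    # checkpoint tensors until they are first consumed, e.g.
    #   with lazy_load("path/to/ckpt.pt") as state_dict:   # placeholder path
    #       model.load_state_dict(state_dict)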
214 | 215 | def __enter__(self): 216 | return self.sd 217 | 218 | def __exit__(self, exc_type, exc_val, exc_tb): 219 | del self.zf 220 | self.zf = None 221 | 222 | 223 | def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None: 224 | files = { 225 | "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(), 226 | "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(), 227 | "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or ( 228 | checkpoint_dir / "tokenizer.model" 229 | ).is_file(), 230 | "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(), 231 | } 232 | if checkpoint_dir.is_dir(): 233 | if all(files.values()): 234 | return 235 | problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}" 236 | else: 237 | problem = " is not a checkpoint directory" 238 | 239 | available = list(Path("checkpoints").glob("*/*")) 240 | if available: 241 | options = "\n --checkpoint_dir ".join([""] + [repr(str(p.resolve())) for p in available]) 242 | extra = f"\nYou have downloaded locally:{options}\n" 243 | else: 244 | extra = "" 245 | 246 | error_message = ( 247 | f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}." 248 | "\nFind download instructions at ." 249 | f"{extra}\nSee all download options by running:\n python scripts/download.py" 250 | ) 251 | print(error_message, file=sys.stderr) 252 | raise SystemExit(1) 253 | 254 | 255 | class SavingProxyForStorage: 256 | def __init__(self, obj, saver, protocol_version=5): 257 | self.protocol_version = protocol_version 258 | self.saver = saver 259 | if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)): 260 | raise TypeError(f"expected storage, not {type(obj)}") 261 | 262 | if isinstance(obj, torch.storage.TypedStorage): 263 | storage = obj._untyped_storage 264 | storage_type_str = obj._pickle_storage_type() 265 | storage_type = getattr(torch, storage_type_str) 266 | storage_numel = obj._size() 267 | else: 268 | storage = obj 269 | storage_type = normalize_storage_type(type(obj)) 270 | storage_numel = storage.nbytes() 271 | 272 | storage_key = saver._write_storage_and_return_key(storage) 273 | location = torch.serialization.location_tag(storage) 274 | 275 | self.storage_info = ("storage", storage_type, storage_key, location, storage_numel) 276 | 277 | def __reduce_ex__(self, protocol_version): 278 | assert False, "this should be handled with out of band" 279 | 280 | 281 | class SavingProxyForTensor: 282 | def __init__(self, tensor, saver, protocol_version=5): 283 | self.protocol_version = protocol_version 284 | self.reduce_ret_fn, (storage, *other_reduce_args) = tensor.__reduce_ex__(protocol_version) 285 | assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates" 286 | storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version) 287 | self.reduce_args = (storage_proxy, *other_reduce_args) 288 | 289 | def __reduce_ex__(self, protocol_version): 290 | if protocol_version != self.protocol_version: 291 | raise RuntimeError(f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}") 292 | return self.reduce_ret_fn, self.reduce_args 293 | 294 | 295 | class IncrementalPyTorchPickler(pickle.Pickler): 296 | def __init__(self, saver, *args, **kwargs): 297 | super().__init__(*args, **kwargs) 298 | self.storage_dtypes = {} 299 | self.saver = saver 300 | self.id_map = {} 301 | 302 | def persistent_id(self, obj): 303 | if isinstance(obj, 
SavingProxyForStorage): 304 | return obj.storage_info 305 | 306 | if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): 307 | if isinstance(obj, torch.storage.TypedStorage): 308 | storage = obj._untyped_storage 309 | storage_dtype = obj.dtype 310 | storage_type_str = obj._pickle_storage_type() 311 | storage_type = getattr(torch, storage_type_str) 312 | storage_numel = obj._size() 313 | 314 | else: 315 | storage = obj 316 | storage_dtype = torch.uint8 317 | storage_type = normalize_storage_type(type(obj)) 318 | storage_numel = storage.nbytes() 319 | 320 | if storage.data_ptr() != 0: 321 | if storage.data_ptr() in self.storage_dtypes: 322 | if storage_dtype != self.storage_dtypes[storage.data_ptr()]: 323 | raise RuntimeError( 324 | "Cannot save multiple tensors or storages that view the same data as different types" 325 | ) 326 | else: 327 | self.storage_dtypes[storage.data_ptr()] = storage_dtype 328 | 329 | storage_key = self.id_map.get(storage._cdata) 330 | if storage_key is None: 331 | storage_key = self.saver._write_storage_and_return_key(storage) 332 | self.id_map[storage._cdata] = storage_key 333 | location = torch.serialization.location_tag(storage) 334 | 335 | return ("storage", storage_type, storage_key, location, storage_numel) 336 | 337 | return None 338 | 339 | 340 | class incremental_save: 341 | def __init__(self, name): 342 | self.name = name 343 | self.zipfile = torch._C.PyTorchFileWriter(str(name)) 344 | self.has_saved = False 345 | self.next_key = 0 346 | 347 | def __enter__(self): 348 | return self 349 | 350 | def store_early(self, tensor): 351 | if isinstance(tensor, torch.Tensor): 352 | return SavingProxyForTensor(tensor, self) 353 | raise TypeError(f"can only store tensors early, not {type(tensor)}") 354 | 355 | def save(self, obj): 356 | if self.has_saved: 357 | raise RuntimeError("have already saved") 358 | data_buf = BytesIO() 359 | pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5) 360 | pickler.dump(obj) 361 | data_value = data_buf.getvalue() 362 | self.zipfile.write_record("data.pkl", data_value, len(data_value)) 363 | self.has_saved = True 364 | 365 | def _write_storage_and_return_key(self, storage): 366 | if self.has_saved: 367 | raise RuntimeError("have already saved") 368 | key = self.next_key 369 | self.next_key += 1 370 | name = f"data/{key}" 371 | if storage.device.type != "cpu": 372 | storage = storage.cpu() 373 | num_bytes = storage.nbytes() 374 | self.zipfile.write_record(name, storage.data_ptr(), num_bytes) 375 | return key 376 | 377 | def __exit__(self, type, value, traceback): 378 | self.zipfile.write_end_of_file() 379 | 380 | 381 | T = TypeVar("T") 382 | 383 | 384 | def step_csv_logger(*args: Any, cls: Type[T] = CSVLogger, **kwargs: Any) -> T: 385 | logger = cls(*args, **kwargs) 386 | 387 | def merge_by(dicts, key): 388 | from collections import defaultdict 389 | 390 | out = defaultdict(dict) 391 | for d in dicts: 392 | if key in d: 393 | out[d[key]].update(d) 394 | return [v for _, v in sorted(out.items())] 395 | 396 | def save(self) -> None: 397 | import csv 398 | 399 | if not self.metrics: 400 | return 401 | metrics = merge_by(self.metrics, "step") 402 | keys = sorted({k for m in metrics for k in m}) 403 | with self._fs.open(self.metrics_file_path, "w", newline="") as f: 404 | writer = csv.DictWriter(f, fieldnames=keys) 405 | writer.writeheader() 406 | writer.writerows(metrics) 407 | 408 | logger.experiment.save = MethodType(save, logger.experiment) 409 | 410 | return logger 411 | 412 | 413 | def 
chunked_cross_entropy( 414 | logits: Union[torch.Tensor, List[torch.Tensor]], targets: torch.Tensor, chunk_size: int = 128 415 | ) -> torch.Tensor: 416 | 417 | if isinstance(logits, list): 418 | if chunk_size == 0: 419 | logits = torch.cat(logits, dim=1) 420 | logits = logits.reshape(-1, logits.size(-1)) 421 | targets = targets.reshape(-1) 422 | return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1) 423 | 424 | logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits] 425 | target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)] 426 | loss_chunks = [ 427 | torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none") 428 | for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) 429 | ] 430 | return torch.cat(loss_chunks).mean() 431 | 432 | logits = logits.reshape(-1, logits.size(-1)) 433 | targets = targets.reshape(-1) 434 | if chunk_size == 0: 435 | return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1) 436 | 437 | logit_chunks = logits.split(chunk_size) 438 | target_chunks = targets.split(chunk_size) 439 | loss_chunks = [ 440 | torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none") 441 | for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) 442 | ] 443 | return torch.cat(loss_chunks).mean() 444 | 445 | 446 | def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict: 447 | for checkpoint_name, attribute_name in mapping.items(): 448 | full_checkpoint_name = prefix + checkpoint_name 449 | if full_checkpoint_name in state_dict: 450 | full_attribute_name = prefix + attribute_name 451 | state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name) 452 | return state_dict 453 | 454 | 455 | def get_default_supported_precision(training: bool, tpu: bool = False) -> str: 456 | if tpu: 457 | return "32-true" 458 | if not torch.cuda.is_available() or torch.cuda.is_bf16_supported(): 459 | return "bf16-mixed" if training else "bf16-true" 460 | return "16-mixed" if training else "16-true" --------------------------------------------------------------------------------
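A minimal sanity check for two of the helpers defined in gridworld/model/tiny_llama/utils.py above, given only as a hypothetical sketch: it assumes the repository root is on PYTHONPATH and uses illustrative tensor shapes. With no ignore_index (-1) targets present, the chunked loss should match a single full-size cross-entropy call, and find_multiple rounds n up to the next multiple of k.

import torch

from gridworld.model.tiny_llama.utils import chunked_cross_entropy, find_multiple

# find_multiple(n, k) returns the smallest multiple of k that is >= n.
assert find_multiple(16, 8) == 16
assert find_multiple(17, 8) == 24

# chunked_cross_entropy splits the flattened (tokens, vocab) logits into
# chunks of `chunk_size` rows to bound peak memory; with no ignored targets
# the mean over all chunks equals the ordinary cross-entropy mean.
batch, seq_len, vocab = 2, 8, 32          # illustrative sizes only
logits = torch.randn(batch, seq_len, vocab)
targets = torch.randint(0, vocab, (batch, seq_len))

chunked = chunked_cross_entropy(logits, targets, chunk_size=3)
full = torch.nn.functional.cross_entropy(
    logits.reshape(-1, vocab), targets.reshape(-1), ignore_index=-1
)
assert torch.allclose(chunked, full, atol=1e-5)

As written, the chunked path averages per-token losses over all positions, so its result can drift slightly from the ignore-aware mean of torch.nn.functional.cross_entropy when ignored (-1) targets are present; the sketch above deliberately avoids that case.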