├── imgs ├── fig_1.png └── fig_2.png ├── run.sh ├── requirements.txt ├── args_yml ├── model_base_IRL │ ├── halfcheetah_v2_medexp.yml │ ├── hopper_v2_medrep.yml │ ├── hopper_v2_medium.yml │ ├── hopper_v2_medexp.yml │ ├── halfcheetah_v2_medium.yml │ ├── halfcheetah_v2_medrep.yml │ ├── walker2d_v2_medium.yml │ ├── walker2d_v2_medexp.yml │ └── walker2d_v2_medrep.yml └── transfer │ ├── hopper_v2_transfer.yml │ └── halfcheetah_v2_transfer.yml ├── done_funcs.py ├── README.md ├── offline_evaluation.py ├── train.py ├── sac.py ├── train_funcs.py ├── utils.py ├── trainer.py └── model.py /imgs/fig_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cloud0723/Offline-MLIRL/HEAD/imgs/fig_1.png -------------------------------------------------------------------------------- /imgs/fig_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cloud0723/Offline-MLIRL/HEAD/imgs/fig_2.png -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | python train.py --yaml_file args_yml/model_base_IRL/hopper_v2_medexp.yml --seed 3 --uuid hopper_test1 2 | python train.py --yaml_file args_yml/model_base_IRL/halfcheetah_v2_medexp.yml --seed 3 --uuid halfcheetah_test1 3 | python train.py --yaml_file args_yml/model_base_IRL/walker2d_v2_medexp.yml --seed 3 --uuid walker_test1 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | atari-py==0.2.6 2 | box2d-py==2.3.8 3 | cffi==1.13.2 4 | cloudpickle==1.2.2 5 | Cython==0.29.14 6 | fasteners==0.15 7 | future==0.18.2 8 | glfw==1.8.5 9 | gym==0.15.4 10 | imageio==2.6.1 11 | lockfile==0.12.2 12 | monotonic==1.5 13 | numpy==1.17.4 14 | opencv-python==4.1.2.30 15 | pandas==0.25.3 16 | Pillow==6.2.1 17 | pycparser==2.19 18 | pyglet==1.3.2 19 | python-dateutil==2.8.1 20 | pytz==2019.3 21 | PyYAML==5.1.2 22 | scipy==1.3.3 23 | six==1.13.0 24 | torchvision==0.4.2 25 | -------------------------------------------------------------------------------- /args_yml/model_base_IRL/halfcheetah_v2_medexp.yml: -------------------------------------------------------------------------------- 1 | # halfcheetah experiments debug 2 | args: 3 | env_name: halfcheetah-medium-expert-v2 4 | reward_head: True 5 | logvar_head: True 6 | states: 'uniform' 7 | steps_k: 5 8 | reward_steps: 200 9 | num_rollouts_per_step: 50 10 | policy_update_steps: 1000 11 | train_policy_every: 100 12 | train_val_ratio: 0.05 13 | real_sample_ratio: 0.05 14 | model_train_freq: 1000 15 | max_timesteps: 10000000 16 | n_eval_rollouts: 10 17 | num_models: 7 18 | num_elites: 5 19 | d4rl: True 20 | model_retain_epochs: 5 21 | mopo: True 22 | mopo_lam: 4.56 23 | # tune_mopo_lam: True 24 | offline_epochs: 1000 25 | save_model: True 26 | save_policy: True 27 | 28 | load_model_dir: world_model/model_halfcheetah/halfcheetah_medexp_v2_seed0/checkpoints/model_saved_weights/Model_halfcheetah-medium-expert-v2_seed0_2022_11_20_00-14-44 29 | train_memory: 2000000 30 | val_memory: 500000 31 | -------------------------------------------------------------------------------- /args_yml/model_base_IRL/hopper_v2_medrep.yml: -------------------------------------------------------------------------------- 1 | # halfcheetah experiments debug 2 | #N λ h: 12 39.08 43 3 | args: 4 | env_name: hopper-medium-replay-v2 # d4rl mixed, supposedly the best for MOPO 5 | reward_head: True 6 | logvar_head: True 7 | states: 'uniform' 8 | steps_k: 43 9 | reward_steps: 200 10 | num_rollouts_per_step: 50 11 | policy_update_steps: 1000 12 | train_policy_every: 100 13 | train_val_ratio: 0.2 14 | real_sample_ratio: 0.05 15 | model_train_freq: 1000 16 | max_timesteps: 10000000 17 | n_eval_rollouts: 10 18 | num_models: 7 19 | num_elites: 5 20 | d4rl: True 21 | model_retain_epochs: 5 22 | mopo: True 23 | mopo_lam: 5.90 24 | offline_epochs: 1000 25 | save_model: True 26 | save_policy: True 27 | load_model_dir: /world_model/model_hopper/hopper-v2-medreplay/checkpoints/model_saved_weights/Model_hopper-medium-replay-v2_seed0_2023_01_05_03-33-00/ 28 | train_memory: 2000000 29 | val_memory: 500000 -------------------------------------------------------------------------------- /args_yml/model_base_IRL/hopper_v2_medium.yml: -------------------------------------------------------------------------------- 1 | # halfcheetah experiments debug 2 | #N λ h: 12 39.08 43 3 | args: 4 | env_name: hopper-medium-v2 # d4rl mixed, supposedly the best for MOPO 5 | reward_head: True 6 | logvar_head: True 7 | states: 'uniform' 8 | steps_k: 43 9 | reward_steps: 200 10 | num_rollouts_per_step: 50 11 | policy_update_steps: 1000 12 | train_policy_every: 100 13 | train_val_ratio: 0.2 14 | real_sample_ratio: 0.05 15 | model_train_freq: 1000 16 | max_timesteps: 10000000 17 | n_eval_rollouts: 10 18 | num_models: 7 19 | num_elites: 5 20 | d4rl: True 21 | model_retain_epochs: 5 22 | mopo: True 23 | mopo_lam: 37.28 24 | mopo_penalty_type: ensemble_std 25 | offline_epochs: 1000 26 | save_model: True 27 | save_policy: True 28 | load_model_dir: world_model/model_hopper/hopper-v2-med/checkpoints/model_saved_weights/Model_hopper-medium-v2_seed0_2023_01_05_03-33-00/ 29 | train_memory: 2000000 30 | val_memory: 500000 -------------------------------------------------------------------------------- /args_yml/transfer/hopper_v2_transfer.yml: -------------------------------------------------------------------------------- 1 | # halfcheetah experiments debug 2 | #N λ h: 12 39.08 43 3 | args: 4 | env_name: hopper-medium-replay-v2 # d4rl mixed, supposedly the best for MOPO 5 | reward_head: True 6 | logvar_head: True 7 | states: 'uniform' 8 | steps_k: 43 9 | reward_steps: 200 10 | num_rollouts_per_step: 50 11 | policy_update_steps: 1000 12 | train_policy_every: 100 13 | train_val_ratio: 0.2 14 | real_sample_ratio: 0.05 15 | model_train_freq: 1000 16 | max_timesteps: 10000000 17 | n_eval_rollouts: 10 18 | num_models: 7 19 | num_elites: 5 20 | d4rl: True 21 | model_retain_epochs: 5 22 | mopo: True 23 | mopo_lam: 5.90 24 | offline_epochs: 1000 25 | save_model: True 26 | save_policy: True 27 | load_model_dir: /home/luoqijun/code/IRL_Code/MBIRL/rethinking-code-supp/data/hopper-model/hopper-v2-medreplay/checkpoints/model_saved_weights/Model_hopper-medium-replay-v2_seed0_2023_01_04_05-02-49 28 | train_memory: 2000000 29 | val_memory: 500000 30 | transfer: True -------------------------------------------------------------------------------- /args_yml/model_base_IRL/hopper_v2_medexp.yml: -------------------------------------------------------------------------------- 1 | # halfcheetah experiments debug 2 | #N λ h: 12 39.08 43 3 | args: 4 | env_name: hopper-medium-expert-v2 # d4rl mixed, supposedly the best for MOPO 5 | reward_head: True 6 | logvar_head: True 7 | states: 'uniform' 8 | steps_k: 43 9 | reward_steps: 200 10 | num_rollouts_per_step: 50 11 | policy_update_steps: 1000 12 | train_policy_every: 100 13 | train_val_ratio: 0.2 14 | real_sample_ratio: 0.05 15 | model_train_freq: 1000 16 | max_timesteps: 10000000 17 | n_eval_rollouts: 10 18 | num_models: 12 19 | num_elites: 5 20 | d4rl: True 21 | model_retain_epochs: 5 22 | mopo: True 23 | #mopo_lam: 39.08 24 | mopo_lam: 5.90 25 | offline_epochs: 1000 26 | save_model: True 27 | save_policy: True 28 | load_model_dir: world_model/model_hopper/hopper-v2-medexp/checkpoints/model_saved_weights/Model_hopper-medium-expert-v2_seed0_2023_01_04_05-02-50/torch_model_weights.pt 29 | train_memory: 2000000 30 | val_memory: 500000 31 | transfer: False -------------------------------------------------------------------------------- /args_yml/model_base_IRL/halfcheetah_v2_medium.yml: -------------------------------------------------------------------------------- 1 | # halfcheetah experiments debug 2 | # medium-v0 N, lam, h: 12 5.92 6 3 | args: 4 | env_name: halfcheetah-medium-v2 # d4rl mixed, supposedly the best for MOPO 5 | reward_head: True 6 | logvar_head: True 7 | states: 'uniform' 8 | steps_k: 6 9 | reward_steps: 200 10 | num_rollouts_per_step: 50 11 | policy_update_steps: 1000 12 | train_policy_every: 100 13 | train_val_ratio: 0.2 14 | real_sample_ratio: 0.05 15 | model_train_freq: 1000 16 | max_timesteps: 10000000 17 | n_eval_rollouts: 10 18 | num_models: 7 19 | num_elites: 5 20 | d4rl: True 21 | model_retain_epochs: 5 22 | mopo: True 23 | #mopo_lam: 5.92 24 | mopo_lam: 5.92 25 | offline_epochs: 1000 26 | mopo_penalty_type: ensemble_var 27 | load_model_dir: world_model/model_halfcheetah/halfcheetah_medium_v2_seed0/checkpoints/model_saved_weights/Model_halfcheetah-medium-v2_seed0_2022_11_20_01-53-21 28 | save_model: True 29 | save_policy: True 30 | train_memory: 2000000 31 | val_memory: 500000 -------------------------------------------------------------------------------- /args_yml/model_base_IRL/halfcheetah_v2_medrep.yml: -------------------------------------------------------------------------------- 1 | # halfcheetah experiments debug 2 | # medium: 11 0.96 37 3 | args: 4 | env_name: halfcheetah-medium-replay-v2 # d4rl mixed, supposedly the best for MOPO 5 | reward_head: True 6 | logvar_head: True 7 | states: 'uniform' 8 | steps_k: 37 9 | reward_steps: 200 10 | num_rollouts_per_step: 50 11 | policy_update_steps: 1000 12 | train_policy_every: 100 13 | train_val_ratio: 0.2 14 | real_sample_ratio: 0.05 15 | model_train_freq: 1000 16 | max_timesteps: 10000000 17 | n_eval_rollouts: 10 18 | num_models: 11 19 | num_elites: 5 20 | d4rl: True 21 | model_retain_epochs: 5 22 | mopo: True 23 | mopo_lam: 0.96 24 | offline_epochs: 1000 25 | augment_offline_data: False 26 | mopo_penalty_type: ensemble_var 27 | load_model_dir: world_model/model_halfcheetah/halfcheetah_mixed_v2_seed0/checkpoints/model_saved_weights/Model_halfcheetah-medium-replay-v2_seed0_2022_11_20_01-52-59/ 28 | save_policy: True 29 | save_model: True 30 | train_memory: 2000000 31 | val_memory: 500000 -------------------------------------------------------------------------------- /done_funcs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def hopper_is_done_func(next_obs): 4 | 5 | if len(next_obs.shape) == 1: 6 | next_obs = next_obs.unsqueeze(0) 7 | 8 | height = next_obs[:, 0] 9 | angle = next_obs[:, 1] 10 | not_done = torch.isfinite(next_obs).all(axis=-1) \ 11 | * (torch.abs(next_obs[:,1:]) < 100).all(axis=-1) \ 12 | * (height > .7) \ 13 | * (torch.abs(angle) < .2) 14 | 15 | done = ~not_done 16 | return done 17 | 18 | def walker2d_is_done_func(next_obs): 19 | 20 | if len(next_obs.shape) == 1: 21 | next_obs.unsqueeze(0) 22 | 23 | 24 | height = next_obs[:, 0] 25 | ang = next_obs[:, 1] 26 | done = ~((height > 0.8) & (height < 2.0) & 27 | (ang > -1.0) & (ang < 1.0)) 28 | 29 | return done 30 | 31 | def ant_is_done_func(next_obs): 32 | 33 | if len(next_obs.shape) == 1: 34 | next_obs.unsqueeze(0) 35 | 36 | height = next_obs[:, 0] 37 | not_done = torch.isfinite(next_obs).all(axis=-1) \ 38 | * (height >= 0.2) \ 39 | * (height <= 1.0) 40 | 41 | done = ~not_done 42 | return done -------------------------------------------------------------------------------- /args_yml/model_base_IRL/walker2d_v2_medium.yml: -------------------------------------------------------------------------------- 1 | # N, lambda, h = 12 0.99 37 penalty_teype: Ensemble Std 2 | # environment: walker2d-medium-expert-v0 3 | 4 | args: 5 | env_name: walker2d-medium-v2 # d4rl mixed, supposedly the best for MOPO 6 | reward_head: True 7 | logvar_head: True 8 | states: 'uniform' 9 | steps_k: 37 10 | reward_steps: 200 11 | num_rollouts_per_step: 50 12 | policy_update_steps: 1000 13 | train_policy_every: 100 14 | train_val_ratio: 0.2 15 | real_sample_ratio: 0.05 16 | model_train_freq: 1000 17 | max_timesteps: 10000000 18 | n_eval_rollouts: 10 19 | num_models: 8 20 | num_elites: 5 21 | d4rl: True 22 | model_retain_epochs: 5 23 | mopo: True 24 | mopo_lam: 0.99 25 | mopo_penalty_type: ensemble_std 26 | #tune_mopo_lam: False 27 | min_model_epochs: 350 28 | offline_epochs: 2000 29 | save_model: True 30 | save_policy: True 31 | load_model_dir: world_model/model_walker2d/walker2d-med/checkpoints/model_saved_weights/Model_walker2d-medium-v2_seed0_2022_11_28_14-27-54 32 | train_memory: 1000000 33 | val_memory: 500000 34 | transfer: True -------------------------------------------------------------------------------- /args_yml/transfer/halfcheetah_v2_transfer.yml: -------------------------------------------------------------------------------- 1 | # halfcheetah experiments debug 2 | # medium: 11 0.96 37 3 | args: 4 | env_name: halfcheetah-medium-replay-v2 # d4rl mixed, supposedly the best for MOPO 5 | reward_head: True 6 | logvar_head: True 7 | states: 'uniform' 8 | steps_k: 37 9 | reward_steps: 200 10 | num_rollouts_per_step: 50 11 | policy_update_steps: 1000 12 | train_policy_every: 100 13 | train_val_ratio: 0.2 14 | real_sample_ratio: 0.05 15 | model_train_freq: 1000 16 | max_timesteps: 10000000 17 | n_eval_rollouts: 10 18 | num_models: 11 19 | num_elites: 5 20 | d4rl: True 21 | model_retain_epochs: 5 22 | mopo: True 23 | mopo_lam: 0.96 24 | offline_epochs: 1000 25 | augment_offline_data: False 26 | mopo_penalty_type: ensemble_var 27 | load_model_dir: /home/luoqijun/code/IRL_Code/MBIRL/rethinking-code-supp/data/model_halfcheetah/halfcheetah_mixed_v2_seed0/checkpoints/model_saved_weights/Model_halfcheetah-medium-replay-v2_seed0_2022_11_20_01-52-59 28 | save_policy: True 29 | save_model: True 30 | train_memory: 2000000 31 | val_memory: 500000 32 | transfer: True -------------------------------------------------------------------------------- /args_yml/model_base_IRL/walker2d_v2_medexp.yml: -------------------------------------------------------------------------------- 1 | # N, lambda, h = 12 0.99 37 penalty_teype: Ensemble Std 2 | # environment: walker2d-medium-expert-v0 3 | 4 | args: 5 | env_name: walker2d-medium-expert-v2 # d4rl mixed, supposedly the best for MOPO 6 | reward_head: True 7 | logvar_head: True 8 | states: 'uniform' 9 | steps_k: 37 10 | reward_steps: 200 11 | num_rollouts_per_step: 50 12 | policy_update_steps: 1000 13 | train_policy_every: 100 14 | train_val_ratio: 0.2 15 | real_sample_ratio: 0.05 16 | model_train_freq: 1000 17 | max_timesteps: 10000000 18 | n_eval_rollouts: 10 19 | num_models: 12 20 | num_elites: 5 21 | d4rl: True 22 | model_retain_epochs: 5 23 | mopo: True 24 | mopo_lam: 0.99 25 | mopo_penalty_type: ensemble_std 26 | #tune_mopo_lam: False 27 | min_model_epochs: 350 28 | offline_epochs: 2000 29 | save_model: True 30 | save_policy: True 31 | load_model_dir: world_model/model_walker2d/walker2d-medexp/checkpoints/model_saved_weights/Model_walker2d-medium-expert-v2_seed0_2022_11_28_14-27-31 32 | train_memory: 1000000 33 | val_memory: 500000 34 | transfer: False -------------------------------------------------------------------------------- /args_yml/model_base_IRL/walker2d_v2_medrep.yml: -------------------------------------------------------------------------------- 1 | # N, lambda, h = 12 0.99 37 penalty_teype: Ensemble Std 2 | # environment: walker2d-medium-expert-v0 3 | 4 | args: 5 | env_name: walker2d-medium-replay-v2 # d4rl mixed, supposedly the best for MOPO 6 | reward_head: True 7 | logvar_head: True 8 | states: 'uniform' 9 | steps_k: 37 10 | reward_steps: 200 11 | num_rollouts_per_step: 50 12 | policy_update_steps: 1000 13 | train_policy_every: 100 14 | train_val_ratio: 0.2 15 | real_sample_ratio: 0.05 16 | model_train_freq: 1000 17 | max_timesteps: 10000000 18 | n_eval_rollouts: 10 19 | num_models: 13 20 | num_elites: 5 21 | d4rl: True 22 | model_retain_epochs: 5 23 | mopo: True 24 | mopo_lam: 0.99 25 | mopo_penalty_type: ensemble_std 26 | #tune_mopo_lam: False 27 | min_model_epochs: 350 28 | offline_epochs: 2000 29 | save_model: True 30 | save_policy: True 31 | load_model_dir: world_model/model_walker2d/walker2d-medreplay/checkpoints/model_saved_weights/Model_walker2d-medium-replay-v2_seed0_2022_11_28_14-28-24 32 | train_memory: 1000000 33 | val_memory: 500000 34 | transfer: True -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # When Demonstrations meet Generative World Models: A Maximum Likelihood Framework for Offline Inverse Reinforcement Learning 2 | Offline ML-IRL is an algorithm for offline inverse reinforcement learning that is discussed in the article [arxiv link](http://arxiv.org/abs/2302.07457) 3 | 4 | Here is the [link](https://github.com/Cloud0723/ML-IRL) to our online version. 5 | ## Installation 6 | - PyTorch 1.13.1 7 | - MuJoCo 2.1.0 8 | - pip install -r requirements.txt 9 | 10 | 11 | ## File Structure 12 | - Experiment result :`data/` 13 | - Configurations: `args_yml/` 14 | - Expert Demonstrations: `expert_data/` 15 | 16 | ## Instructions 17 | - All the experiments are to be run under the root folder. 18 | - After running, you will see the training logs in `data/` folder. 19 | 20 | ## Experiments 21 | All the commands below are also provided in `run.sh`. 22 | 23 | ### Offline-IRL benchmark (MuJoCo) 24 | Before experiment, you can download our expert demonstrations and our trained world model [here](https://drive.google.com/drive/folders/1BbEZLEKP6HAijeRBXG0V3JLSrB0FIQg6?usp=drive_link). 25 | 26 | ```bash 27 | python train.py --yaml_file args_yml/model_base_IRL/halfcheetah_v2_medium.yml --seed 0 --uuid halfcheetah_result 28 | ``` 29 | also you can use: 30 | ```bash 31 | ./run.sh 32 | ``` 33 | 34 | ## Performances 35 | ![Graph](imgs/fig_1.png) 36 | 37 | ----- 38 | 39 | ![Graph](imgs/fig_2.png) 40 | 41 | ---- 42 | -------------------------------------------------------------------------------- /offline_evaluation.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import os 3 | from random import shuffle 4 | import glob 5 | 6 | import numpy as np 7 | import torch 8 | import yaml 9 | import gym 10 | from gym.wrappers import TimeLimit 11 | import d4rl 12 | from numpy.random import default_rng 13 | 14 | from modified_envs import AntMOPOEnv 15 | from model import EnsembleGymEnv 16 | from trainer import Trainer 17 | from done_funcs import * 18 | from sac import SAC_Agent 19 | 20 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 21 | 22 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 23 | 24 | 25 | class Assessor: 26 | """ 27 | An assessor that will tell us how well a policy performs agains: 28 | a) each member of our world model collection 29 | b) our actual real world env 30 | """ 31 | 32 | def __init__(self, trainer: Trainer, d4rl_env: str): 33 | """ 34 | Constructor 35 | 36 | Args: 37 | trainer: A trainer instance that contains the WM used to train a policy, and "population" members 38 | that we can assess against, as well as real envs 39 | d4rl_env: String of the d4rl offline env we've trained on so we can load start states 40 | """ 41 | self._trainer = trainer 42 | env = gym.make(d4rl_env) 43 | self._dataset = d4rl.qlearning_dataset(env) 44 | self._rng = default_rng() 45 | 46 | def evaluate_policy(self, policy, wm_start_states, wm_init_states, real_evals=20, truncated_steps=100): 47 | """ 48 | Method to evaluate a policy in our WMs, using both full traj and truncated/myopic 49 | 50 | Args: 51 | 52 | policy: The policy under evaluation 53 | 54 | wm_start_states: A NxM matrix of starting states for the WM 55 | (could be anywhere in the trajectory) where 56 | N is num states (i.e., evals), and M is state 57 | dimension 58 | 59 | wm_init_states: A NxM matrix of initial states (s_0) for the 60 | WM (sampled from initial state dist.) where N 61 | is num states (i.e., evals), and M is state 62 | dimension 63 | 64 | real_evals: How many returns to gather under true model (i.e., iterations) 65 | 66 | truncated_steps: How many steps to rollout for with the wm_start_states 67 | """ 68 | 69 | stats = { 70 | "True Perf": None, 71 | "WM Perf": None, 72 | "WM Myopic Perf": None, 73 | } 74 | 75 | # Save old policy just in case... 76 | old_policy = self._trainer.agent 77 | 78 | # Load the new policy 79 | self._trainer.agent = policy 80 | 81 | # Step 1: Get actual performance on real env 82 | stats["True Perf"] = self._trainer.test_agent(use_model=False, n_evals=real_evals) 83 | 84 | # Step 2: Get full trajectory perf in the WMs assuming it's T=1000 (MuJoCo) 85 | stats["WM Perf"] = self._test_agent_wm(1000, wm_init_states) 86 | 87 | # Step 3: Get myopic perf in the WMs 88 | stats["WM Myopic Perf"] = self._test_agent_wm(truncated_steps, wm_start_states) 89 | 90 | # Let's reload that old policy 91 | self._trainer.agent = old_policy 92 | 93 | return stats 94 | 95 | def _test_agent_wm(self, num_steps, start_states): 96 | """ 97 | Internal method to test our policy we loaded in the world models. 98 | 99 | Args: 100 | num_steps: Horizon to rollout for 101 | start_states: start states tensor 102 | """ 103 | results = [] 104 | 105 | for wm_idx in self._trainer.population_models: 106 | results.append( 107 | self._trainer.test_agent_myopic(start_states=start_states, num_steps=num_steps, population_idx=wm_idx)) 108 | 109 | return np.array(results) 110 | 111 | def get_pool_start_states(self, n_samples: int, reverse_reward: bool = False, unique: bool = True): 112 | """ 113 | Method to return some random states from the D4RL pool 114 | 115 | Args: 116 | n_samples: Int of number of samples 117 | reverse_reward: Whether subset reversed reward 118 | unique: Whether to have unique samples 119 | """ 120 | if reverse_reward: 121 | assert unique is True, "Reversed reward is by definition unique" 122 | all_states, all_rewards = np.array(self._dataset['observations']), np.array(self._dataset['rewards']) 123 | worst_states_idx = all_rewards.argsort()[:n_samples] 124 | start_states = all_states[worst_states_idx] 125 | else: 126 | len_dataset = self._dataset['observations'].shape[0] 127 | idx = np.random.randint(0, len_dataset, n_samples) if not unique else self._rng.choice(len_dataset, 128 | size=n_samples, 129 | replace=False) 130 | start_states = self._dataset['observations'][idx] 131 | return torch.Tensor(start_states).to(device) 132 | 133 | 134 | def get_assessor_from_yaml(yaml_file_path: str, wm_dirs: List[str], d4rl_env: str): 135 | with open(yaml_file_path, 'r') as f: 136 | params = yaml.load(f, Loader=yaml.FullLoader) 137 | 138 | params['population_model_dirs'] = wm_dirs 139 | 140 | if params['env_name'] == 'AntMOPOEnv': 141 | env = TimeLimit(AntMOPOEnv(), 1000) 142 | eval_env = TimeLimit(AntMOPOEnv(), 1000) 143 | else: 144 | env = gym.make(params['env_name']) 145 | eval_env = gym.make(params['env_name']) 146 | 147 | env = EnsembleGymEnv(params, env, eval_env) 148 | 149 | env_name_lower = params['env_name'].lower() 150 | 151 | if isinstance(params['steps_k'], list): 152 | init_steps_k = params['steps_k'][0] 153 | else: 154 | init_steps_k = params['steps_k'] 155 | 156 | steps_per_epoch = params['epoch_steps'] if params['epoch_steps'] else env.real_env.env.spec.max_episode_steps 157 | 158 | if 'hopper' in env_name_lower: 159 | params['is_done_func'] = hopper_is_done_func 160 | elif 'walker' in env_name_lower: 161 | params['is_done_func'] = walker2d_is_done_func 162 | elif 'ant' in env_name_lower: 163 | params['is_done_func'] = ant_is_done_func 164 | else: 165 | params['is_done_func'] = None 166 | 167 | init_buffer_size = init_steps_k * params['num_rollouts_per_step'] * steps_per_epoch * params[ 168 | 'model_retain_epochs'] 169 | 170 | state_dim = env.observation_space.shape[0] 171 | action_dim = env.action_space.shape[0] 172 | 173 | agent = SAC_Agent(params['seed'], state_dim, action_dim, gamma=params['gamma'], buffer_size=init_buffer_size, 174 | target_entropy=params['target_entropy'], augment_sac=params['augment_sac'], 175 | rad_rollout=params['rad_rollout'], context_type=params['context_type']) 176 | 177 | trainer = Trainer(params, env, agent, device=device) 178 | 179 | return Assessor(trainer, d4rl_env) 180 | 181 | 182 | def real_start_states(env, n_states, device='cuda'): 183 | """ 184 | This just gets some real start states from the true gym environments 185 | """ 186 | start_states = [env.reset() for _ in range(n_states)] 187 | return torch.Tensor(start_states).to(device) 188 | 189 | 190 | def get_d4rl_start_states(d4rl_name: str, n_states: int, device='cuda'): 191 | """ 192 | This is in fact broken, because D4RL doesn't save initial states in slot 0 193 | """ 194 | d4rl_sequence = d4rl.sequence_dataset(gym.make(d4rl_name)) 195 | starts = [t['observations'][0] for t in d4rl_sequence] 196 | shuffle(starts) 197 | return torch.Tensor(starts).to(device)[:n_states] 198 | 199 | 200 | def get_policy_from_seed(seed: int, directory: str, env: gym.Env): 201 | """ 202 | Returns a SAC policy from seed and policies directory 203 | """ 204 | fname_match = directory + '/' + 'torch_policy_weights_{}*.pt'.format(seed) 205 | for f in glob.glob(fname_match): 206 | sac = SAC_Agent(1, env.observation_space.shape[0], env.action_space.shape[0]) 207 | saved_policy = torch.load(f) 208 | sac.policy.load_state_dict(saved_policy['policy_state_dict']) 209 | return sac 210 | else: 211 | print("Policy not found!") 212 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import time 5 | import pdb 6 | 7 | import gym 8 | from gym.wrappers import TimeLimit 9 | import numpy as np 10 | import pandas as pd 11 | import yaml 12 | import torch 13 | import d4rl 14 | 15 | from model import EnsembleGymEnv 16 | from done_funcs import hopper_is_done_func, walker2d_is_done_func, ant_is_done_func 17 | from sac import SAC_Agent 18 | from train_funcs import (collect_data, test_agent, train_agent, 19 | train_agent_model_free, train_agent_model_free_debug) 20 | from utils import MeanStdevFilter, reward_func 21 | from trainer import Trainer 22 | 23 | import sys 24 | 25 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 26 | 27 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 28 | 29 | 30 | def between_0_1(x): 31 | x = float(x) 32 | if x > 1 or x < 0: 33 | raise argparse.ArgumentTypeError("This should be between 0 and 1 (inclusive)") 34 | return x 35 | 36 | 37 | def train_agent_new(params, online_yaml_config=None): 38 | params['zeros'] = False 39 | 40 | if params['reward_head']: 41 | env = gym.make(params['env_name']) 42 | eval_env = gym.make(params['env_name']) 43 | 44 | else: 45 | raise Exception('Environment Not Supported!') 46 | 47 | env_name_lower = params['env_name'].lower() 48 | 49 | if 'hopper' in env_name_lower: 50 | params['is_done_func'] = hopper_is_done_func 51 | elif 'walker' in env_name_lower: 52 | params['is_done_func'] = walker2d_is_done_func 53 | elif 'ant' in env_name_lower: 54 | params['is_done_func'] = ant_is_done_func 55 | else: 56 | params['is_done_func'] = None 57 | 58 | params['ob_dim'] = env.observation_space.shape[0] 59 | params['ac_dim'] = env.action_space.shape[0] 60 | 61 | env=EnsembleGymEnv(params, env, eval_env) 62 | 63 | state_dim = env.observation_space.shape[0] 64 | action_dim = env.action_space.shape[0] 65 | 66 | 67 | env.real_env.seed(params['seed']) 68 | env.eval_env.seed(params['seed'] + 1) 69 | env.real_env.action_space.seed(params['seed']) 70 | env.eval_env.action_space.seed(params['seed'] + 1) 71 | np.random.seed(params['seed']) 72 | random.seed(params['seed']) 73 | 74 | if isinstance(params['steps_k'], list): 75 | init_steps_k = params['steps_k'][0] 76 | else: 77 | init_steps_k = params['steps_k'] 78 | 79 | steps_per_epoch = params['epoch_steps'] if params['epoch_steps'] else env.real_env.env.spec.max_episode_steps 80 | 81 | if params['d4rl']: 82 | # Trying this option out right now. 83 | init_buffer_size = init_steps_k * params['num_rollouts_per_step'] * steps_per_epoch * params[ 84 | 'model_retain_epochs'] 85 | print('Initial Buffer Size: {} using model_retain_epochs={}'.format(init_buffer_size, 86 | params['model_retain_epochs'])) 87 | else: 88 | init_buffer_size = init_steps_k * params['num_rollouts_per_step'] * steps_per_epoch 89 | print('Initial Buffer Size: {}'.format(init_buffer_size)) 90 | 91 | agent = SAC_Agent(params['seed'], state_dim, action_dim, gamma=params['gamma'], buffer_size=init_buffer_size, 92 | target_entropy=params['target_entropy'], augment_sac=params['augment_sac'], 93 | rad_rollout=params['rad_rollout'], context_type=params['context_type']) 94 | 95 | 96 | 97 | trainer = Trainer(params, env, agent, device=device) 98 | 99 | total_timesteps = 0 100 | rewards, rewards_m, lambdas, steps_used, k_used, errors, varmean, samples = [], [], [], [], [], [], [], [] 101 | if params['d4rl']: 102 | print("\nRunning initial training with offline data...") 103 | timesteps, error, model_steps, rewards = trainer.train_offline(params['offline_epochs'], 104 | save_model=params['save_model'], 105 | save_policy=params['save_policy'], 106 | load_model_dir=params['load_model_dir'], 107 | ) 108 | total_timesteps += timesteps 109 | varmean.append(trainer.var_mean) 110 | else: 111 | print("\nCollecting random rollouts...") 112 | timesteps, error, model_steps = trainer.train_epoch(init=True) 113 | total_timesteps += timesteps 114 | varmean.append(trainer.var_mean) 115 | 116 | return rewards[-10:] 117 | 118 | 119 | 120 | def main(): 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument('--env_name', type=str, 123 | default='HalfCheetah-v2') ## only works properly for HalfCheetah and Ant 124 | parser.add_argument('--seed', '-se', type=int, default=0) 125 | parser.add_argument('--num_models', '-nm', type=int, default=7) 126 | parser.add_argument('--adapt', '-ad', type=int, default=0) ## set to 1 for adaptive 127 | # parser.add_argument('--steps', '-s', type=int, default=100) ## maximum time we step through an env per episode 128 | parser.add_argument('--steps_k', '-sk', type=int, # nargs='+', 129 | default=1) ## maximum time we step through an env to make artificial rollouts 130 | parser.add_argument('--reward_steps', '-rs', type=int, # nargs='+', 131 | default=10) ## maximum time we step through an env to make artificial rollouts to update reward estimator 132 | parser.add_argument('--outer_steps', '-in', type=int, 133 | default=3000) ## how many time steps/samples we collect each outer loop (including initially) 134 | parser.add_argument('--max_timesteps', '-maxt', type=int, default=6000) ## total number of timesteps 135 | parser.add_argument('--model_epochs', '-me', type=int, default=2000) ## max number of times we improve model 136 | parser.add_argument('--update_timestep', '-ut', type=int, 137 | default=50000) ## for PPO only; how many steps to accumulate before training on them 138 | parser.add_argument('--policy_iters', '-it', type=int, default=1000) ## max number of times we improve policy 139 | parser.add_argument('--learning_rate', '-lr', type=float, default=0.1) 140 | parser.add_argument('--gamma', '-gm', type=float, default=0.99) 141 | parser.add_argument('--lam', '-la', type=float, default=0) 142 | parser.add_argument('--pca', '-pc', type=float, default=0) ## threshold for residual to stop, try [1e-4,2-e4] 143 | parser.add_argument('--sigma', '-si', type=float, default=0.01) 144 | parser.add_argument('--filename', '-f', type=str, default='ModelBased') 145 | parser.add_argument('--dir', '-d', type=str, default='data') 146 | parser.add_argument('--yaml_file', '-yml', type=str, default=None) 147 | parser.add_argument('--uuid', '-id', type=str, default=None) 148 | parser.add_argument('--fix_std', dest='fix_std', action='store_true') 149 | parser.add_argument('--var_type', type=str, default='reward', choices=('reward', 'state')) 150 | parser.add_argument('--states', type=str, default='uniform', choices=('uniform', 'start', 'entropy')) 151 | parser.add_argument('--reward_head', '-rh', type=int, default=1) # 1 or 0 152 | parser.add_argument('--model_free', dest='model_free', action='store_true') 153 | parser.add_argument('--var_max', dest='var_max', action='store_true') 154 | parser.add_argument('--no_logvar_head', dest='logvar_head', action='store_false') 155 | parser.add_argument('--comment', '-c', type=str, default=None) 156 | parser.add_argument('--policy_update_steps', type=int, default=40) 157 | parser.add_argument('--init_collect', type=int, default=5000) 158 | parser.add_argument('--train_policy_every', type=int, default=1) 159 | parser.add_argument('--num_rollouts_per_step', type=int, default=400) 160 | parser.add_argument('--n_eval_rollouts', type=int, default=10) 161 | parser.add_argument('--train_val_ratio', type=float, default=0.2) 162 | parser.add_argument('--real_sample_ratio', type=float, default=0.05) 163 | parser.add_argument('--model_train_freq', type=int, default=250) 164 | parser.add_argument('--rollout_model_freq', type=int, default=250) 165 | parser.add_argument('--oac', type=bool, default=False) 166 | parser.add_argument('--espi', type=bool, default=False) 167 | parser.add_argument('--num_elites', type=int, default=5) 168 | parser.add_argument('--var_thresh', type=float, default=100) 169 | parser.add_argument('--epoch_steps', type=int, default=None) 170 | parser.add_argument('--target_entropy', type=float, default=None) 171 | parser.add_argument('--log_interval', type=int, default=100) 172 | parser.add_argument('--d4rl', dest='d4rl', action='store_true') 173 | parser.add_argument('--train_memory', type=int, default=800000) 174 | parser.add_argument('--val_memory', type=int, default=200000) 175 | parser.add_argument('--mopo', dest='mopo', action='store_true') 176 | parser.add_argument('--morel', dest='morel', action='store_true') 177 | # MOPO/MOReL tuning parameters 178 | parser.add_argument('--mopo_lam', type=float, default=1) 179 | parser.add_argument('--morel_thresh', type=between_0_1, default=0.3) 180 | parser.add_argument('--morel_halt_reward', type=float, default=-10) 181 | # This basically says to not truncate rollouts, but to keep going (like M2AC Non-Stop mode) 182 | parser.add_argument('--morel_non_stop', type=bool, default=False) 183 | parser.add_argument('--tune_mopo_lam', dest='tune_mopo_lam', action='store_true') 184 | parser.add_argument('--mopo_penalty_type', type=str, default='mopo_default', choices=( 185 | 'mopo_default', 'ensemble_var', 'ensemble_std', 'ensemble_var_rew', 'ensemble_var_comb', 'mopo_paper', 'lompo', 'm2ac', 'morel')) 186 | parser.add_argument('--mopo_uncertainty_target', type=float, default=1.5) 187 | 188 | parser.add_argument('--offline_epochs', type=int, default=1000) 189 | parser.add_argument('--model_retain_epochs', type=int, default=100) 190 | parser.add_argument('--save_model', type=bool, default=False) 191 | parser.add_argument('--transfer', type=bool, default=False) 192 | parser.add_argument('--load_model_dir', type=str, default=None) 193 | parser.add_argument('--deterministic_rollouts', type=bool, default=False) 194 | # Needed as some models seem to early terminate (this happens in author's code too, so not a PyTorch issue) 195 | parser.add_argument('--min_model_epochs', type=int, default=None) 196 | parser.add_argument('--augment_offline_data', type=bool, default=False) 197 | parser.add_argument('--augment_sac', type=bool, default=False) 198 | parser.add_argument('--context_type', type=str, default='rad_augmentation') 199 | parser.add_argument('--rad_rollout', type=bool, default=False) 200 | parser.add_argument('--save_policy', type=bool, default=False) 201 | parser.add_argument('--population_model_dirs', type=str, default=[], nargs="*") 202 | parser.add_argument('--ensemble_replace_model_dirs', type=str, default=[], nargs="*") 203 | parser.add_argument('--l2_reg_multiplier', type=float, default=1.) 204 | parser.add_argument('--model_lr', type=float, default=0.001) 205 | parser.set_defaults(fix_std=False) 206 | parser.set_defaults(model_free=False) 207 | parser.set_defaults(var_max=False) 208 | parser.set_defaults(logvar_head=True) 209 | 210 | args = parser.parse_args() 211 | params = vars(args) 212 | 213 | if params['yaml_file']: 214 | with open(args.yaml_file, 'r') as f: 215 | yaml_config = yaml.load(f, Loader=yaml.FullLoader) 216 | for config in yaml_config['args']: 217 | if config in params: 218 | params[config] = yaml_config['args'][config] 219 | 220 | online_yaml_config = None 221 | 222 | assert isinstance(params['steps_k'], (int, list)), "must be either a single input or a collection" 223 | 224 | if isinstance(params['steps_k'], list): 225 | assert len(params[ 226 | 'steps_k']) == 4, "if a list of inputs, must have 4 inputs (start steps, end steps, start epoch, end epoch)" 227 | 228 | time.sleep(random.random()) 229 | if not (os.path.exists(params['dir'])): 230 | os.makedirs(params['dir']) 231 | os.chdir(params['dir']) 232 | 233 | if params['uuid']: 234 | if not (os.path.exists(params['uuid'])): 235 | os.makedirs(params['uuid']) 236 | os.chdir(params['uuid']) 237 | 238 | rewards = train_agent_new(params, online_yaml_config) 239 | rewards = np.array(rewards) 240 | print(rewards) 241 | sys.stderr.write(str(np.mean(rewards))) 242 | 243 | return np.mean(rewards) 244 | 245 | 246 | if __name__ == '__main__': 247 | main() 248 | -------------------------------------------------------------------------------- /sac.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import numpy as np 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.distributions import Normal, TransformedDistribution 10 | from tqdm import tqdm 11 | 12 | from utils import ReplayPoolCtxt, ReplayPool, FasterReplayPool, FasterReplayPoolCtxt, TanhTransform, Transition, TransitionContext, filter_torch 13 | 14 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 15 | 16 | 17 | class MLPNetwork(nn.Module): 18 | 19 | def __init__(self, input_dim, output_dim, hidden_size=256): 20 | super(MLPNetwork, self).__init__() 21 | self.network = nn.Sequential( 22 | nn.Linear(input_dim, hidden_size), 23 | nn.ReLU(), 24 | nn.Linear(hidden_size, hidden_size), 25 | nn.ReLU(), 26 | nn.Linear(hidden_size, hidden_size), 27 | nn.ReLU(), 28 | nn.Linear(hidden_size, output_dim), 29 | ) 30 | 31 | def forward(self, x): 32 | return self.network(x) 33 | 34 | 35 | class Policy(nn.Module): 36 | 37 | def __init__(self, state_dim, action_dim, hidden_size=256): 38 | super(Policy, self).__init__() 39 | self.action_dim = action_dim 40 | self.network = MLPNetwork(state_dim, action_dim * 2, hidden_size) 41 | 42 | def stable_network_forward(self, x): 43 | mu_logstd = self.network(x) 44 | mu, logstd = mu_logstd.chunk(2, dim=1) 45 | logstd = torch.clamp(logstd, -20, 2) 46 | return mu, logstd 47 | 48 | def compute_action(self, mu, std, get_logprob=False): 49 | dist = Normal(mu, std) 50 | transforms = [TanhTransform(cache_size=1)] 51 | dist = TransformedDistribution(dist, transforms) 52 | action = dist.rsample() 53 | if get_logprob: 54 | logprob = dist.log_prob(action).sum(axis=-1, keepdim=True) 55 | else: 56 | logprob = None 57 | mean = torch.tanh(mu) 58 | return action, logprob, mean 59 | 60 | def forward(self, x, get_logprob=False): 61 | mu, logstd = self.stable_network_forward(x) 62 | std = logstd.exp() 63 | return self.compute_action(mu, std, get_logprob) 64 | 65 | 66 | class DoubleQFunc(nn.Module): 67 | 68 | def __init__(self, state_dim, action_dim, hidden_size=256): 69 | super(DoubleQFunc, self).__init__() 70 | self.network1 = MLPNetwork(state_dim + action_dim, 1, hidden_size) 71 | self.network2 = MLPNetwork(state_dim + action_dim, 1, hidden_size) 72 | 73 | def forward(self, state, action): 74 | x = torch.cat((state, action), dim=1) 75 | return self.network1(x), self.network2(x) 76 | 77 | 78 | class SAC_Agent: 79 | 80 | def __init__(self, seed, state_dim, action_dim, lr=3e-4, gamma=0.99, tau=5e-3, batchsize=256, hidden_size=256, 81 | update_interval=1, buffer_size=1e6, target_entropy=None, augment_sac=False, rad_rollout=False, 82 | context_type='rad_augmentation'): 83 | self.gamma = gamma 84 | self.tau = tau 85 | self.target_entropy = target_entropy if target_entropy else -action_dim / 2 86 | self.batchsize = batchsize 87 | self.update_interval = update_interval 88 | 89 | torch.manual_seed(seed) 90 | 91 | # context-sac 92 | self.augment_sac = augment_sac 93 | self.rad_rollout = rad_rollout 94 | self.context_type = context_type 95 | 96 | original_state_dim = state_dim 97 | 98 | if self.augment_sac: 99 | if context_type == 'rad_augmentation': 100 | print('Augmenting state vector with context_type={}.'.format(context_type)) 101 | state_dim *= 2 102 | elif context_type == 'rad_magnitude': 103 | state_dim += 1 104 | 105 | # aka critic 106 | self.q_funcs = DoubleQFunc(state_dim, action_dim, hidden_size=hidden_size).to(device) 107 | self.target_q_funcs = copy.deepcopy(self.q_funcs) 108 | self.target_q_funcs.eval() 109 | for p in self.target_q_funcs.parameters(): 110 | p.requires_grad = False 111 | 112 | # aka actor 113 | self.policy = Policy(state_dim, action_dim, hidden_size=hidden_size).to(device) 114 | 115 | # aka temperature 116 | self.log_alpha = torch.zeros(1, requires_grad=True, device=device) 117 | self.alpha = self.log_alpha.exp() 118 | 119 | self.q_optimizer = torch.optim.Adam(self.q_funcs.parameters(), lr=lr) 120 | self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr) 121 | self.temp_optimizer = torch.optim.Adam([self.log_alpha], lr=lr) 122 | 123 | if augment_sac and rad_rollout: 124 | # self.replay_pool = ReplayPoolCtxt(capacity=int(buffer_size)) 125 | self.replay_pool = FasterReplayPoolCtxt(action_dim=action_dim, state_dim=original_state_dim, capacity=int(buffer_size)) 126 | else: 127 | # self.replay_pool = ReplayPool(capacity=int(buffer_size)) 128 | self.replay_pool = FasterReplayPool(action_dim=action_dim, state_dim=original_state_dim, capacity=int(buffer_size)) 129 | 130 | def reallocate_replay_pool(self, new_size: int): 131 | assert new_size != self.replay_pool.capacity, "Error, you've tried to allocate a new pool which has the same length" 132 | new_replay_pool = FasterReplayPoolCtxt(self.replay_pool._action_dim, self.replay_pool._action_dim, capacity=new_size) 133 | new_replay_pool.initialise(self.replay_pool) 134 | self.replay_pool = new_replay_pool 135 | 136 | def get_action(self, state, state_filter=None, deterministic=False, oac=False): 137 | if state_filter: 138 | state = state_filter(state) 139 | state = torch.Tensor(state).view(1, -1).to(device) 140 | if oac: 141 | action, _, mean = self._get_optimistic_action(state) 142 | else: 143 | with torch.no_grad(): 144 | action, _, mean = self.policy(state) 145 | if deterministic: 146 | return np.atleast_1d(mean.squeeze().cpu().numpy()) 147 | return np.atleast_1d(action.squeeze().cpu().numpy()) 148 | 149 | def _get_optimistic_action(self, state, get_logprob=False): 150 | 151 | beta_UB = 4.66 # Table 1: https://arxiv.org/pdf/1910.12807.pdf 152 | delta = 23.53 # Table 1: https://arxiv.org/pdf/1910.12807.pdf 153 | 154 | mu, logvar = self.policy.stable_network_forward(state) 155 | mu.requires_grad_() 156 | std = logvar.exp() 157 | 158 | action = torch.tanh(mu) 159 | q_1, q_2 = self.q_funcs(state, action) 160 | 161 | mu_Q = (q_1 + q_2) / 2.0 162 | 163 | sigma_Q = torch.abs(q_1 - q_2) / 2.0 164 | 165 | Q_UB = mu_Q + beta_UB * sigma_Q 166 | 167 | grad = torch.autograd.grad(Q_UB, mu) 168 | grad = grad[0] 169 | 170 | grad = grad.detach() 171 | mu = mu.detach() 172 | std = std.detach() 173 | 174 | Sigma_T = torch.pow(std.detach(), 2) 175 | denom = torch.sqrt( 176 | torch.sum(torch.mul(torch.pow(grad, 2), Sigma_T))) + 10e-6 177 | 178 | # Obtain the change in mu 179 | mu_C = math.sqrt(2.0 * delta) * torch.mul(Sigma_T, grad) / denom 180 | 181 | mu_E = mu + mu_C 182 | 183 | assert mu_E.shape == std.shape 184 | 185 | # dist = TanhNormal(mu_E, std) 186 | # action = dist.sample() 187 | 188 | return self.policy.compute_action(mu_E, std, get_logprob=get_logprob) 189 | 190 | def update_target(self): 191 | """moving average update of target networks""" 192 | with torch.no_grad(): 193 | for target_q_param, q_param in zip(self.target_q_funcs.parameters(), self.q_funcs.parameters()): 194 | target_q_param.data.copy_(self.tau * q_param.data + (1.0 - self.tau) * target_q_param.data) 195 | 196 | def update_q_functions(self, state_batch, action_batch, reward_batch, nextstate_batch, done_batch): 197 | with torch.no_grad(): 198 | nextaction_batch, logprobs_batch, _ = self.policy(nextstate_batch, get_logprob=True) 199 | q_t1, q_t2 = self.target_q_funcs(nextstate_batch, nextaction_batch) 200 | # take min to mitigate positive bias in q-function training 201 | q_target = torch.min(q_t1, q_t2) 202 | value_target = reward_batch + (1.0 - done_batch) * self.gamma * (q_target - self.alpha * logprobs_batch) 203 | q_1, q_2 = self.q_funcs(state_batch, action_batch) 204 | loss_1 = F.mse_loss(q_1, value_target) 205 | loss_2 = F.mse_loss(q_2, value_target) 206 | return loss_1, loss_2 207 | 208 | def update_policy_and_temp(self, state_batch): 209 | action_batch, logprobs_batch, _ = self.policy(state_batch, get_logprob=True) 210 | q_b1, q_b2 = self.q_funcs(state_batch, action_batch) 211 | qval_batch = torch.min(q_b1, q_b2) 212 | policy_loss = (self.alpha * logprobs_batch - qval_batch).mean() 213 | temp_loss = -self.log_alpha.exp() * (logprobs_batch.detach() + self.target_entropy).mean() 214 | return policy_loss, temp_loss 215 | 216 | def optimize(self, n_updates, state_filter=None, env_pool=None, env_ratio=0.05, augment_data=False,reward_function=None): 217 | q1_loss, q2_loss, pi_loss, a_loss = 0, 0, 0, 0 218 | 219 | hide_progress = True if n_updates < 50 else False 220 | 221 | for i in tqdm(range(n_updates), disable=hide_progress, ncols=100): 222 | if env_pool and env_ratio != 0: 223 | n_env_samples = int(env_ratio * self.batchsize) 224 | n_model_samples = self.batchsize - n_env_samples 225 | env_samples = env_pool.sample(n_env_samples)._asdict() 226 | model_samples = self.replay_pool.sample(n_model_samples)._asdict() 227 | if self.augment_sac and self.rad_rollout: 228 | samples = TransitionContext(*[env_samples[key] + model_samples[key] for key in env_samples]) 229 | else: 230 | samples = Transition(*[env_samples[key] + model_samples[key] for key in env_samples]) 231 | else: 232 | samples = self.replay_pool.sample(self.batchsize) 233 | #print(len(samples),samples) 234 | if state_filter: 235 | state_batch = torch.FloatTensor(state_filter(samples.state)).to(device) 236 | nextstate_batch = torch.FloatTensor(state_filter(samples.nextstate)).to(device) 237 | else: 238 | state_batch = torch.FloatTensor(samples.state).to(device) 239 | nextstate_batch = torch.FloatTensor(samples.nextstate).to(device) 240 | 241 | if self.augment_sac and self.rad_rollout: 242 | # Concatenate the context with the state after filtering, this is done on model before 243 | rad_batch = torch.FloatTensor(samples.rad_context).to(device) 244 | state_batch = torch.cat((state_batch, rad_batch), 1) 245 | nextstate_batch = torch.cat((nextstate_batch, rad_batch), 1) 246 | action_batch = torch.FloatTensor(samples.action).to(device) 247 | reward_batch = torch.FloatTensor(samples.reward).to(device).unsqueeze(1) 248 | if reward_function: 249 | #print('before:',reward_batch) 250 | reward_batch += reward_function(torch.cat((state_batch,action_batch),1)) 251 | #print('after:',reward_batch) 252 | done_batch = torch.FloatTensor(samples.real_done).to(device).unsqueeze(1) 253 | 254 | if augment_data: 255 | # Delta context 256 | magnitude = 0.5 257 | high = 1 + magnitude 258 | low = 1 - magnitude 259 | scale = high - low 260 | 261 | # Direct nextstate augmentation 262 | # # magnitude = np.random.uniform(0, 0.5) 263 | # random_amplitude_scaling = (torch.rand(state_batch.shape) * scale + low).to(device) 264 | # # state_batch *= random_amplitude_scaling 265 | # nextstate_batch *= random_amplitude_scaling 266 | 267 | # random_amplitude_scaling = (torch.rand(state_batch.shape[0]) * scale + low).unsqueeze(1).to(device) 268 | random_amplitude_scaling = (torch.rand(state_batch.shape) * scale + low).to(device) 269 | delta_batch = nextstate_batch - state_batch 270 | delta_batch *= random_amplitude_scaling 271 | nextstate_batch = state_batch + delta_batch 272 | 273 | # Additive Noise 274 | # random_amplitude_scaling = torch.randn_like(state_batch) * 0.1 275 | # nextstate_batch += random_amplitude_scaling 276 | 277 | if self.augment_sac and not self.rad_rollout and self.context_type == 'rad_augmentation': 278 | state_batch = torch.cat((state_batch, random_amplitude_scaling), 1) 279 | nextstate_batch = torch.cat((nextstate_batch, random_amplitude_scaling), 1) 280 | elif self.augment_sac and not self.rad_rollout and self.context_type == 'rad_magnitude': 281 | state_batch = torch.cat((state_batch, magnitude * torch.ones(state_batch.shape[0], 1).to(device)), 1) 282 | nextstate_batch = torch.cat((nextstate_batch, magnitude * torch.ones(state_batch.shape[0], 1).to(device)), 1) 283 | 284 | # update q-funcs 285 | q1_loss_step, q2_loss_step = self.update_q_functions(state_batch, action_batch, reward_batch, 286 | nextstate_batch, done_batch) 287 | q_loss_step = q1_loss_step + q2_loss_step 288 | self.q_optimizer.zero_grad() 289 | q_loss_step.backward() 290 | self.q_optimizer.step() 291 | 292 | # update policy and temperature parameter 293 | for p in self.q_funcs.parameters(): 294 | p.requires_grad = False 295 | pi_loss_step, a_loss_step = self.update_policy_and_temp(state_batch) 296 | self.policy_optimizer.zero_grad() 297 | pi_loss_step.backward() 298 | self.policy_optimizer.step() 299 | self.temp_optimizer.zero_grad() 300 | a_loss_step.backward() 301 | self.temp_optimizer.step() 302 | for p in self.q_funcs.parameters(): 303 | p.requires_grad = True 304 | 305 | self.alpha = self.log_alpha.exp() 306 | 307 | q1_loss += q1_loss_step.detach().item() 308 | q2_loss += q2_loss_step.detach().item() 309 | pi_loss += pi_loss_step.detach().item() 310 | a_loss += a_loss_step.detach().item() 311 | if i % self.update_interval == 0: 312 | self.update_target() 313 | return q1_loss, q2_loss, pi_loss, a_loss 314 | 315 | def save_policy(self, save_path, num_epochs, rew=None): 316 | q_funcs, target_q_funcs, policy, log_alpha = self.q_funcs, self.target_q_funcs, self.policy, self.log_alpha 317 | 318 | if rew is None: 319 | save_path = os.path.join(save_path, "torch_policy_weights_{}_epochs.pt".format(num_epochs)) 320 | else: 321 | save_path = os.path.join(save_path, "torch_policy_weights_{}_epochs_{}.pt".format(num_epochs, rew)) 322 | 323 | torch.save({ 324 | 'double_q_state_dict': q_funcs.state_dict(), 325 | 'target_double_q_state_dict': target_q_funcs.state_dict(), 326 | 'policy_state_dict': policy.state_dict(), 327 | 'log_alpha_state_dict': log_alpha 328 | }, save_path) 329 | -------------------------------------------------------------------------------- /train_funcs.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | 7 | from sac import SAC_Agent 8 | from utils import (filter_torch, filter_torch_invert, get_residual, get_stats, 9 | random_env_forward, torch_reward, Transition) 10 | 11 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 12 | 13 | 14 | def collect_data(params, agent: SAC_Agent, ensemble_env, init=False): 15 | rollouts = [] 16 | timesteps = 0 17 | env = ensemble_env.real_env 18 | collection_timesteps = params['init_collect'] if init else params['outer_steps'] 19 | pca_data = [] 20 | # Standard RL interaction loop with the real env 21 | residual = 1 22 | while timesteps < collection_timesteps: 23 | rollout = [] 24 | done = False 25 | env_ts = 0 26 | state = env.reset() 27 | ensemble_env.state_filter.update(state) 28 | newdata = [] 29 | while (not done): 30 | #NB: No state filtering 31 | if init: 32 | action = env.action_space.sample() 33 | else: 34 | action = agent.get_action(state) 35 | newdata.append(np.concatenate((state, action))) 36 | nextstate, reward, done, _ = env.step(action) 37 | rollout.append(Transition(state, action, reward, nextstate, False)) 38 | state = nextstate 39 | ensemble_env.state_filter.update(state) 40 | ensemble_env.action_filter.update(action) 41 | timesteps += 1 42 | env_ts += 1 43 | 44 | if residual < params['pca']: 45 | collection_timesteps = 0 46 | 47 | if (timesteps) % 100 == 0: 48 | print("Collected Timesteps: %s" %(timesteps)) 49 | 50 | if len(pca_data) > 0: 51 | residual, train_resid = get_residual(newdata, pca_data, 0.99) 52 | print("Residual = {}, Train Residual = {}".format(str(residual), str(train_resid))) 53 | pca_data += newdata 54 | rollouts.append(rollout) 55 | 56 | num_valid = int(np.floor(ensemble_env.model.train_val_ratio * len(rollouts))) 57 | train = rollouts[(num_valid):] 58 | valid = rollouts[:num_valid] 59 | for rollout in train: 60 | ensemble_env.model.add_data(rollout) 61 | for rollout in valid: 62 | ensemble_env.model.add_data_validation(rollout) 63 | print("\nAdded {} samples to the model, {} for valid".format(str(len(train)), str(len(valid)))) 64 | 65 | ensemble_env.update_diff_filter() 66 | errors = [ensemble_env.model.models[i].get_acquisition(rollouts, ensemble_env.state_filter, ensemble_env.action_filter, ensemble_env.diff_filter) for i in range(params['num_models'])] 67 | error = np.sqrt(np.mean(np.array(errors)**2)) 68 | print("\nMSE Loss on new rollouts: %s" % error) 69 | return(timesteps, error) 70 | 71 | 72 | def train_agent(params, agent: SAC_Agent, env, policy_iters, update_timestep, env_resets, log_interval, lam=0, n_parallel=500): 73 | running_reward = 0 74 | avg_length = 0 75 | time_step = 0 76 | n_updates = 0 77 | i_episode = 0 78 | prev_performance = np.array([-np.inf for _ in range(len(env.model.models))]) 79 | rewards_history = deque(maxlen=6) 80 | best_weights = None 81 | is_done_func = env.model.is_done_func 82 | if params['var_type'] == 'reward': 83 | state_dynamics = False 84 | elif params['var_type'] == 'state': 85 | state_dynamics = True 86 | else: 87 | raise Exception("Variance must either be 'reward' or 'state'") 88 | 89 | for model in env.model.models.values(): 90 | model.to(device) 91 | 92 | env.state_filter.update_torch() 93 | env.action_filter.update_torch() 94 | env.diff_filter.update_torch() 95 | 96 | done_true = [True for _ in range(n_parallel)] 97 | done_false = [False for _ in range(n_parallel)] 98 | 99 | start_states_validate = torch.FloatTensor(env_resets).to(device) 100 | 101 | grad_per_timesteps = params['grad_per_timesteps'] 102 | collection_timesteps = params['outer_steps'] 103 | 104 | while n_updates < policy_iters: 105 | 106 | if params['states'] == 'uniform': 107 | env_resets = np.array(env.model.memory.sample(n_parallel)[0]) 108 | elif params['states'] == 'entropy': 109 | s = torch.FloatTensor(env.model.memory.get_all()[0]).to(device) 110 | s_f = env.state_filter.filter_torch(s) 111 | a = ppo.policy.actor(s_f) #deterministic 112 | u = get_stats(env, s, a, env.state_filter, env.action_filter, env.diff_filter, False, state_dynamics, params['reward_head']) 113 | neg_ent = -np.log(u * np.sqrt(2 * np.pi * np.e)) 114 | probs = np.exp(neg_ent) / (1 + np.exp(neg_ent)) 115 | dist = probs / np.sum(probs) 116 | sample = np.random.choice(s.shape[0], n_parallel, p=dist.flatten()) 117 | env_resets = np.take(np.array(env.model.memory.get_all()[0]), sample, axis=0) 118 | 119 | start_states = torch.FloatTensor(env_resets).to(device) 120 | 121 | i_episode += n_parallel 122 | state = start_states.clone() 123 | prev_done = done_false 124 | var = 0 125 | t = 0 126 | while t < params['steps_k']: 127 | # state_f = env.state_filter.filter_torch(state) 128 | time_step += n_parallel 129 | t += 1 130 | with torch.no_grad(): 131 | # TODO: Random steps intially? 132 | action, _, _ = agent.policy(state) 133 | # TODO: FIX THIS! Filters should be a member of Ensemble, not the wrapper 134 | nextstate, reward = env.model.random_env_step(state, action, env.state_filter, env.action_filter, env.diff_filter) 135 | if is_done_func: 136 | done = is_done_func(nextstate).cpu().numpy() 137 | done[prev_done] = True 138 | prev_done = done 139 | else: 140 | if t >= params['steps_k']: 141 | done = done_false 142 | else: 143 | done = done_false 144 | uncert = get_stats(env, state, action, env.state_filter, env.action_filter, env.diff_filter, done, state_dynamics, params['reward_head']) 145 | uncert = 0 146 | if params['reward_head']: 147 | reward = reward.cpu().detach().numpy() 148 | else: 149 | reward = torch_reward(env.name, nextstate, action, done) 150 | reward = (1-lam) * reward + lam * uncert 151 | for s, a, r, s_n, d in zip(state, action, reward, nextstate, done): 152 | s, a, s_n = s.detach().cpu().numpy(), a.detach().cpu().numpy(), s_n.detach().cpu().numpy() 153 | r = r[0] 154 | agent.replay_pool.push(Transition(s, a, r, s_n, d)) 155 | state = nextstate 156 | running_reward += reward 157 | var += uncert**2 158 | # update if it's time 159 | if time_step % update_timestep == 0: 160 | agent.optimize(n_updates=int(update_timestep / 10)) 161 | time_step = 0 162 | n_updates += 1 163 | if n_updates > 0: 164 | improved, prev_performance = validate_agent_with_ensemble(agent, env, start_states_validate, env.state_filter, env.action_filter, env.diff_filter, prev_performance, 0.7, params['steps_k'], params['reward_head']) 165 | if improved: 166 | best_weights = agent.q_funcs.state_dict(), agent.target_q_funcs.state_dict(), agent.policy.state_dict(), agent.log_alpha 167 | best_update = n_updates 168 | rewards_history.append(improved) 169 | if len(rewards_history) > 5: 170 | if rewards_history[0] > max(np.array(rewards_history)[1:]): 171 | print('Policy Stopped Improving after {} updates'.format(best_update)) 172 | agent.q_funcs.load_state_dict(best_weights[0]) 173 | agent.target_q_funcs.load_state_dict(best_weights[1]) 174 | agent.policy.load_state_dict(best_weights[2]) 175 | agent.log_alpha = best_weights[3] 176 | return 177 | avg_length += t * n_parallel 178 | if i_episode % log_interval == 0: 179 | avg_length = int(avg_length/log_interval) 180 | running_reward = int((running_reward.sum()/log_interval)) 181 | print('Episode {} \t Avg length: {} \t Avg reward: {} \t Number of Policy Updates: {}'.format(i_episode, avg_length, running_reward, n_updates)) 182 | running_reward = 0 183 | avg_length = 0 184 | 185 | 186 | def validate_agent_with_ensemble(agent, env, start_states, state_filter, action_filter, diff_filter, best_performance, threshold, ep_steps, reward_head): 187 | 188 | n_parallel = start_states.shape[0] 189 | 190 | performance = np.zeros(len(env.model.models)) 191 | is_done_func = env.model.is_done_func 192 | 193 | done_true = [False for _ in range(n_parallel)] 194 | done_false = [False for _ in range(n_parallel)] 195 | 196 | for i in env.model.models: 197 | total_reward = 0 198 | state = start_states.clone() 199 | prev_done = done_false 200 | t = 0 201 | while t < ep_steps: 202 | state_f = state_filter.filter_torch(state) 203 | t += 1 204 | with torch.no_grad(): 205 | _, _, action = agent.policy(state_f) 206 | action = torch.clamp(action, env.action_bounds.lowerbound[0], env.action_bounds.upperbound[0]) 207 | nextstate, reward = env.model.models[i].get_next_state_reward(state, action, state_filter, action_filter, diff_filter) 208 | if is_done_func: 209 | done = is_done_func(nextstate).cpu().numpy() 210 | done[prev_done] = True 211 | prev_done = done 212 | else: 213 | if t >= ep_steps: 214 | done = done_true 215 | else: 216 | done = done_false 217 | if reward_head: 218 | reward = reward.cpu().detach().numpy() 219 | else: 220 | reward = torch_reward(env.name, nextstate, action, done) 221 | state = nextstate 222 | total_reward += np.mean(reward) 223 | performance[i] = total_reward 224 | if (np.mean(performance > best_performance) > threshold): 225 | new_best_performance = np.maximum(performance, best_performance) 226 | return True, new_best_performance 227 | else: 228 | new_best_performance = best_performance 229 | return False, new_best_performance 230 | 231 | 232 | def test_agent(agent: SAC_Agent, env, ep_steps, subset_resets, subset_real_resets, use_model): 233 | num_rollouts = len(subset_resets) 234 | if use_model: 235 | test_env = env 236 | else: 237 | test_env = env.real_env 238 | half = int(np.ceil(len(subset_real_resets[0]) / 2)) 239 | total_reward = 0 240 | for reset, real_reset in zip(subset_resets, subset_real_resets): 241 | time_step = 0 242 | done = False 243 | test_env.reset() 244 | state = reset 245 | if use_model: 246 | test_env.current_state = state 247 | else: 248 | test_env.env.unwrapped.set_state(real_reset[:half], real_reset[half:]) 249 | while (not done) and (time_step < ep_steps): 250 | time_step += 1 251 | action = agent.get_action(state, deterministic=True) 252 | state, reward, done, _ = test_env.step(action) 253 | total_reward += reward 254 | return total_reward / num_rollouts 255 | 256 | 257 | def train_agent_model_free(ppo, ensemble_env, memory, update_timestep, seed, log_interval, ep_steps, start_states, start_real_states): 258 | # logging variables 259 | running_reward = 0 260 | running_reward_real = 0 261 | avg_length = 0 262 | time_step = 0 263 | cumulative_update_timestep = 0 264 | cumulative_log_timestep = 0 265 | n_updates = 0 266 | i_episode = 0 267 | log_episode = 0 268 | samples_number = 0 269 | samples = [] 270 | rewards = [] 271 | n_starts = len(start_states) 272 | 273 | env_name = ensemble_env.unwrapped.spec.id 274 | 275 | state_filter = ensemble_env.state_filter 276 | 277 | half = int(np.ceil(len(start_real_states[0]) / 2)) 278 | 279 | env = ensemble_env.real_env 280 | 281 | memory.clear_memory() 282 | 283 | while samples_number < 3e7: 284 | for reset, real_reset in zip(start_states, start_real_states): 285 | time_step = 0 286 | done = False 287 | env.reset() 288 | state = reset 289 | env.unwrapped.set_state(real_reset[:half], real_reset[half:]) 290 | i_episode += 1 291 | log_episode += 1 292 | state = env.reset() 293 | state_filter.update(state) 294 | state = state_filter.filter(state) 295 | done = False 296 | 297 | while (not done): 298 | cumulative_log_timestep += 1 299 | cumulative_update_timestep += 1 300 | time_step += 1 301 | samples_number += 1 302 | action = ppo.select_action(state_filter.filter(state), memory) 303 | nextstate, reward, done, _ = env.step(action) 304 | state = nextstate 305 | state_filter.update(state) 306 | 307 | memory.rewards.append(np.array([reward])) 308 | memory.is_terminals.append(np.array([done])) 309 | 310 | running_reward += reward 311 | 312 | # update if it's time 313 | if cumulative_update_timestep % update_timestep == 0: 314 | ppo.update(memory) 315 | memory.clear_memory() 316 | cumulative_update_timestep = 0 317 | n_updates += 1 318 | 319 | # logging 320 | if i_episode % log_interval == 0: 321 | subset_resets_idx = np.random.randint(0, n_starts, 10) 322 | subset_resets = start_states[subset_resets_idx] 323 | subset_resets_real = start_real_states[subset_resets_idx] 324 | avg_length = int(cumulative_log_timestep/log_episode) 325 | running_reward = int((running_reward_real/log_episode)) 326 | actual_reward = test_agent(ppo, ensemble_env, memory, ep_steps, subset_resets, subset_resets_real, use_model=False) 327 | samples.append(samples_number) 328 | rewards.append(actual_reward) 329 | print('Episode {} \t Samples {} \t Avg length: {} \t Avg reward: {} \t Actual reward: {} \t Number of Policy Updates: {}'.format(i_episode, samples_number, avg_length, running_reward, actual_reward, n_updates)) 330 | df = pd.DataFrame({'Samples': samples, 'Reward': rewards}) 331 | df.to_csv("{}.csv".format(env_name + '-ModelFree-Seed-' + str(seed))) 332 | cumulative_log_timestep = 0 333 | log_episode = 0 334 | running_reward = 0 335 | 336 | 337 | def train_agent_model_free_debug(ppo, ensemble_env, memory, update_timestep, log_interval, reward_func=None): 338 | # logging variables 339 | running_reward = 0 340 | running_reward_no_filter = 0 341 | running_reward_real = 0 342 | avg_length = 0 343 | time_step = 0 344 | cumulative_update_timestep = 0 345 | cumulative_log_timestep = 0 346 | n_updates = 0 347 | i_episode = 0 348 | log_episode = 0 349 | samples_number = 0 350 | samples = [] 351 | rewards = [] 352 | rewards_real = [] 353 | 354 | seed = 0 355 | env_name = ensemble_env.unwrapped.spec.id 356 | 357 | state_filter = ensemble_env.state_filter 358 | 359 | env = ensemble_env.real_env 360 | 361 | if hasattr(env, 'is_done_func'): 362 | is_done_func = env.is_done_func 363 | else: 364 | is_done_func = None 365 | 366 | memory.clear_memory() 367 | 368 | while samples_number < 2e7: 369 | i_episode += 1 370 | log_episode += 1 371 | state = env.reset() 372 | state_filter.update(state) 373 | state = state_filter.torch_filter(state) 374 | done = False 375 | while (not done): 376 | cumulative_log_timestep += 1 377 | cumulative_update_timestep += 1 378 | time_step += 1 379 | samples_number += 1 380 | action = ppo.select_action(state_filter.filter(state), memory) 381 | nextstate, reward, done, _ = env.step(action) 382 | running_reward_no_filter += reward_func(state, nextstate, action, env_name, is_done_func=is_done_func) 383 | running_reward_real += reward 384 | reward = reward_func(state_filter.filter(state), state_filter.filter(nextstate), action, env_name, state_filter, is_done_func=is_done_func) 385 | state = nextstate 386 | state_filter.update(state) 387 | 388 | memory.rewards.append(np.array([reward])) 389 | memory.is_terminals.append(np.array([done])) 390 | 391 | running_reward += reward 392 | 393 | # update if it's time 394 | if cumulative_update_timestep % update_timestep == 0: 395 | ppo.update(memory) 396 | memory.clear_memory() 397 | cumulative_update_timestep = 0 398 | n_updates += 1 399 | 400 | # logging 401 | if i_episode % log_interval == 0: 402 | avg_length = int(cumulative_log_timestep/log_episode) 403 | running_reward = int((running_reward/log_episode)) 404 | running_reward_no_filter = int((running_reward_no_filter/log_episode)) 405 | running_reward_real = int((running_reward_real/log_episode)) 406 | samples.append(samples_number) 407 | rewards.append(running_reward) 408 | rewards_real.append(running_reward_real) 409 | print('Episode {} \t Samples {} \t Avg length: {} \t Avg reward: {} \t Avg reward no filter: {} \t Avg real reward: {} \t Number of Policy Updates: {}'.format(i_episode, samples_number, avg_length, running_reward, running_reward_no_filter, running_reward_real, n_updates)) 410 | df = pd.DataFrame({'Samples': samples, 'Reward': rewards, 'Reward_Real': rewards_real}) 411 | df.to_csv("{}.csv".format(env_name + '-ModelFree-Seed-' + str(seed))) 412 | cumulative_log_timestep = 0 413 | log_episode = 0 414 | running_reward = 0 415 | running_reward_no_filter = 0 416 | running_reward_real = 0 417 | 418 | time_step = 0 419 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import math 3 | import os 4 | import random 5 | from collections import deque, namedtuple 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | from numpy.random import default_rng 10 | import torch 11 | import torch.nn as nn 12 | from torch.distributions import constraints 13 | from torch.distributions.transforms import Transform 14 | from torch.nn.functional import softplus 15 | from torch.nn.init import _calculate_correct_fan 16 | 17 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 18 | 19 | 20 | TransitionContext = namedtuple('Transition', ('state', 'action', 'reward', 'nextstate', 'real_done', 'rad_context')) 21 | Transition = namedtuple('Transition', ('state', 'action', 'reward', 'nextstate', 'real_done')) 22 | 23 | 24 | def parameterized_truncated_normal(uniform, mu, sigma, a, b): 25 | normal = torch.distributions.normal.Normal(0, 1) 26 | 27 | alpha = (a - mu) / sigma 28 | beta = (b - mu) / sigma 29 | 30 | alpha_normal_cdf = normal.cdf(alpha) 31 | p = alpha_normal_cdf + (normal.cdf(beta) - alpha_normal_cdf) * uniform 32 | 33 | p = p.cpu().numpy() 34 | one = np.array(1, dtype=p.dtype) 35 | epsilon = np.array(np.finfo(p.dtype).eps, dtype=p.dtype) 36 | v = np.clip(2 * p - 1, -one + epsilon, one - epsilon) 37 | x = mu + sigma * np.sqrt(2) * torch.erfinv(torch.from_numpy(v)) 38 | x = torch.clamp(x, a, b) 39 | 40 | return x.to(device) 41 | 42 | 43 | def truncated_normal(uniform): 44 | return parameterized_truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2) 45 | 46 | 47 | def truncated_normal_replace(m, mode='fan_in'): 48 | fan = _calculate_correct_fan(m, mode) 49 | std = 1/(2*np.sqrt(fan)) 50 | with torch.no_grad(): 51 | weight = (truncated_normal(m.uniform_())) 52 | return weight * std 53 | 54 | 55 | def truncated_normal_init(layer): 56 | if type(layer) in [nn.Linear]: 57 | layer.weight.data = truncated_normal_replace(layer.weight.data) 58 | 59 | 60 | def reward_func(s1, s2, a, env_name, state_filter=None, is_done_func=None): 61 | if state_filter: 62 | s1_real = s1 * state_filter.stdev + state_filter.mean 63 | s2_real = s2 * state_filter.stdev + state_filter.mean 64 | else: 65 | s1_real = s1 66 | s2_real = s2 67 | if env_name == "HalfCheetah-v2": 68 | return np.squeeze(s2_real)[-1] - 0.1 * np.square(a).sum() 69 | if env_name == "Ant-v2": 70 | if is_done_func: 71 | if is_done_func(torch.Tensor(s2_real).reshape(1,-1)): 72 | return 0.0 73 | return np.squeeze(s2_real)[-1] - 0.5 * np.square(a).sum() + 1.0 74 | if env_name == "Swimmer-v2": 75 | return np.squeeze(s2_real)[-1] - 0.0001 * np.square(a).sum() 76 | if env_name == "Hopper-v2": 77 | if is_done_func: 78 | if is_done_func(torch.Tensor(s2_real).reshape(1,-1)): 79 | return 0.0 80 | return np.squeeze(s2_real)[-1] - 0.1 * np.square(a).sum() - 3.0 * np.square(s2_real[0] - 1.3) + 1.0 81 | 82 | 83 | class MeanStdevFilter(): 84 | def __init__(self, shape, clip=10.0): 85 | self.eps = 1e-12 86 | self.shape = shape 87 | self.clip = clip 88 | self._count = 0 89 | self._running_sum = np.zeros(shape) 90 | self._running_sum_sq = np.zeros(shape) + self.eps 91 | self.mean = 0 92 | self.stdev = 1 93 | 94 | def update(self, x): 95 | if len(x.shape) == 1: 96 | x = x.reshape(1,-1) 97 | self._running_sum += np.sum(x, axis=0) 98 | self._running_sum_sq += np.sum(np.square(x), axis=0) 99 | # assume 2D data 100 | self._count += x.shape[0] 101 | self.mean = self._running_sum / self._count 102 | self.stdev = np.sqrt( 103 | np.maximum( 104 | self._running_sum_sq / self._count - self.mean**2, 105 | self.eps 106 | )) 107 | self.stdev[self.stdev <= self.eps] = 1.0 108 | 109 | def reset(self): 110 | self.__init__(self.shape, self.clip) 111 | 112 | def update_torch(self): 113 | self.torch_mean = torch.FloatTensor(self.mean).to(device) 114 | self.torch_stdev = torch.FloatTensor(self.stdev).to(device) 115 | 116 | def filter(self, x): 117 | return np.clip(((x - self.mean) / self.stdev), -self.clip, self.clip) 118 | 119 | def filter_torch(self, x: torch.Tensor): 120 | self.update_torch() 121 | return torch.clamp(((x - self.torch_mean) / self.torch_stdev), -self.clip, self.clip) 122 | 123 | def invert(self, x): 124 | return (x * self.stdev) + self.mean 125 | 126 | def invert_torch(self, x: torch.Tensor): 127 | return (x * self.torch_stdev) + self.torch_mean 128 | 129 | 130 | def tidy_up_weight_dir(guids=None): 131 | if guids == None: 132 | guids = [] 133 | files = [i for i in os.listdir("./data/") if i.endswith("pth")] 134 | for weight_full in files: 135 | weight = weight_full.split('_')[1] 136 | if weight.split('.')[0] not in guids: 137 | os.remove("./data/" + weight_full) 138 | 139 | 140 | def prepare_data(state, action, nextstate, state_filter, action_filter): 141 | state_filtered = state_filter.filter(state) 142 | action_filtered = action_filter.filter(action) 143 | state_action_filtered = np.concatenate((state_filtered, action_filtered), axis=1) 144 | delta = np.array(nextstate) - np.array(state) 145 | return state_action_filtered, delta 146 | 147 | 148 | def get_residual(newdata, pca_data, pct=0.99): 149 | X_pca = np.array(pca_data) 150 | # standardize 151 | X_pca = (X_pca - np.mean(X_pca)) / (np.std(X_pca) + 1e-8) 152 | 153 | Q, Sigma, _ = np.linalg.svd(X_pca.T) 154 | # proportion 155 | weight = np.cumsum(Sigma / np.sum(Sigma)) 156 | index = np.sum((weight > pct) == 0) 157 | train_resid = 1-weight[index] 158 | V = Q[:,:index+1] 159 | 160 | basis = V.dot(V.T) 161 | 162 | X = np.array(newdata) 163 | # standardize with respect to old data 164 | X = (X - np.mean(X_pca)) / (np.std(X_pca) + 1e-8) 165 | orig = X.T.dot(X) 166 | projected = np.matmul(np.matmul(basis, orig), basis) 167 | residual = (np.trace(orig) - np.trace(projected))/np.trace(orig) 168 | return(residual, train_resid) 169 | 170 | 171 | def get_stats(env, state, action, state_filter, action_filter, done, dynamics=False, reward_head=0): 172 | with torch.no_grad(): 173 | stats_mean = [] 174 | stats_var = [] 175 | for model in env.model.models.values(): 176 | if model.model.is_probabilistic: 177 | nextstate, reward = model.get_next_state_reward(state, action, state_filter, action_filter, True) 178 | if dynamics: 179 | raise Exception('Not Implemented') 180 | if reward_head: 181 | stats_mean.append(reward[0]) 182 | stats_var.append(reward[1].exp()) 183 | else: 184 | # TODO: make this more efficient 185 | reward = torch.tensor(torch_reward(env.name, nextstate[0], action, done), device=device) 186 | stats_mean.append(reward) 187 | stats_var.append(nextstate[1][:,-1].exp()) 188 | else: 189 | nextstate, reward = model.get_next_state_reward(state, action, state_filter, action_filter, False) 190 | if dynamics: 191 | stats_mean.append(nextstate) 192 | stats_var.append(torch.zeros(nextstate.shape, device=device)) 193 | if reward_head: 194 | stats_mean.append(reward) 195 | stats_var.append(torch.zeros(reward.shape, device=device)) 196 | else: 197 | # TODO: make this more efficient 198 | reward = torch.tensor(torch_reward(env.name, nextstate, action, done), device=device) 199 | stats_mean.append(reward) 200 | stats_var.append(torch.zeros(reward.shape, device=device)) 201 | if dynamics: 202 | return (torch.stack(stats_mean) - torch.stack(stats_mean).mean((0))).pow(2).sum(2).mean(0).detach().cpu().numpy() 203 | else: 204 | # equivalent to the Lakshminarayanan paper 205 | return torch.sqrt(torch.var(torch.stack(stats_mean), axis=0) + torch.mean(torch.stack(stats_var), axis=0)).detach().cpu().numpy() 206 | 207 | 208 | def random_env_forward(data, env, reward_head): 209 | """Randomly allocate the data through the different dynamics models""" 210 | y = torch.zeros((data.shape[0], env.observation_space.shape[0]+reward_head), device=device) 211 | allocation = torch.randint(0, len(env.model.models), (data.shape[0],)) 212 | for i in env.model.models: 213 | data_i = data[allocation == i] 214 | y_i, _ = env.model.models[i].forward(data_i) 215 | y[allocation == i] = y_i 216 | return y 217 | 218 | 219 | def filter_torch(x, mean, stddev): 220 | x_f = (x - mean) / stddev 221 | return torch.clamp(x_f, -3, 3) 222 | 223 | 224 | def filter_torch_invert(x_f, mean, stddev): 225 | x = (x_f * stddev) + mean 226 | return x 227 | 228 | 229 | def halfcheetah_reward(nextstate, action): 230 | return (nextstate[:,-1] - 0.1 * torch.sum(torch.pow(action, 2), 1)).detach().cpu().numpy() 231 | 232 | 233 | def ant_reward(nextstate, action, dones): 234 | reward = (nextstate[:,-1] - 0.5 * torch.sum(torch.pow(action, 2), 1) + 1.0).detach().cpu().numpy() 235 | reward[dones] = 0.0 236 | return reward 237 | 238 | 239 | def swimmer_reward(nextstate, action): 240 | reward = (nextstate[:,-1] - 0.0001 * torch.sum(torch.pow(action, 2), 1)).detach().cpu().numpy() 241 | return reward 242 | 243 | 244 | def hopper_reward(nextstate, action, dones): 245 | reward = (nextstate[:,-1] - 0.1 * torch.sum(torch.pow(action, 2), 1) - 3.0 * (nextstate[:,0] - 1.3).pow(2) + 1.0).detach().cpu().numpy() 246 | reward[dones] = 0.0 247 | return reward 248 | 249 | 250 | def torch_reward(env_name, nextstate, action, dones=None): 251 | if env_name == "HalfCheetah-v2": 252 | return halfcheetah_reward(nextstate, action) 253 | elif env_name == "Ant-v2": 254 | return ant_reward(nextstate, action, dones) 255 | elif env_name == "Hopper-v2": 256 | return hopper_reward(nextstate, action, dones) 257 | elif env_name == "Swimmer-v2": 258 | return swimmer_reward(nextstate, action) 259 | else: 260 | raise Exception('Environment not supported') 261 | 262 | 263 | class GaussianMSELoss(nn.Module): 264 | 265 | def __init__(self): 266 | super(GaussianMSELoss, self).__init__() 267 | 268 | def forward(self, mu_logvar, target, logvar_loss = True): 269 | mu, logvar = mu_logvar.chunk(2, dim=1) 270 | inv_var = (-logvar).exp() 271 | if logvar_loss: 272 | return (logvar + (target - mu)**2 * inv_var).mean() 273 | else: 274 | return ((target - mu)**2).mean() 275 | 276 | 277 | class FasterReplayPool: 278 | 279 | def __init__(self, action_dim, state_dim, capacity=1e6): 280 | self.capacity = int(capacity) 281 | self._action_dim = action_dim 282 | self._state_dim = state_dim 283 | self._pointer = 0 284 | self._size = 0 285 | self._init_memory() 286 | self._rng = default_rng() 287 | 288 | def _init_memory(self): 289 | self._memory = { 290 | 'state': np.zeros((self.capacity, self._state_dim), dtype='float32'), 291 | 'action': np.zeros((self.capacity, self._action_dim), dtype='float32'), 292 | 'reward': np.zeros((self.capacity), dtype='float32'), 293 | 'nextstate': np.zeros((self.capacity, self._state_dim), dtype='float32'), 294 | 'real_done': np.zeros((self.capacity), dtype='bool') 295 | } 296 | 297 | def push(self, transition: Transition): 298 | 299 | # Handle 1-D Data 300 | num_samples = transition.state.shape[0] if len(transition.state.shape) > 1 else 1 301 | idx = np.arange(self._pointer, self._pointer + num_samples) % self.capacity 302 | 303 | for key, value in transition._asdict().items(): 304 | self._memory[key][idx] = value 305 | 306 | self._pointer = (self._pointer + num_samples) % self.capacity 307 | self._size = min(self._size + num_samples, self.capacity) 308 | 309 | def _return_from_idx(self, idx): 310 | sample = {k: tuple(v[idx]) for k,v in self._memory.items()} 311 | return Transition(**sample) 312 | 313 | def sample(self, batch_size: int, unique: bool = True): 314 | idx = np.random.randint(0, self._size, batch_size) if not unique else self._rng.choice(self._size, size=batch_size, replace=False) 315 | return self._return_from_idx(idx) 316 | 317 | def sample_all(self): 318 | return self._return_from_idx(np.arange(0, self._size)) 319 | 320 | def get(self, start_idx, end_idx): 321 | raise NotImplementedError 322 | 323 | def get_all(self): 324 | raise NotImplementedError 325 | 326 | def _get_from_idx(self, idx): 327 | raise NotImplementedError 328 | 329 | def __len__(self): 330 | return self._size 331 | 332 | def clear_pool(self): 333 | self._init_memory() 334 | 335 | def initialise(self, old_pool): 336 | # Not Tested 337 | old_memory = old_pool.sample_all() 338 | for key in self._memory: 339 | self._memory[key] = np.append(self._memory[key], old_memory[key], 0) 340 | 341 | class FasterReplayPoolCtxt: 342 | 343 | def __init__(self, action_dim, state_dim, capacity=1e6): 344 | self.capacity = int(capacity) 345 | self._action_dim = action_dim 346 | self._state_dim = state_dim 347 | self._pointer = 0 348 | self._size = 0 349 | self._init_memory() 350 | self._rng = default_rng() 351 | 352 | def _init_memory(self): 353 | self._memory = { 354 | 'state': np.zeros((self.capacity, self._state_dim), dtype='float32'), 355 | 'action': np.zeros((self.capacity, self._action_dim), dtype='float32'), 356 | 'reward': np.zeros((self.capacity), dtype='float32'), 357 | 'nextstate': np.zeros((self.capacity, self._state_dim), dtype='float32'), 358 | 'real_done': np.zeros((self.capacity), dtype='bool'), 359 | 'rad_context': np.zeros((self.capacity, self._state_dim), dtype='float32') 360 | } 361 | 362 | def push(self, transition: Transition): 363 | 364 | # Handle 1-D Data 365 | num_samples = transition.state.shape[0] if len(transition.state.shape) > 1 else 1 366 | idx = np.arange(self._pointer, self._pointer + num_samples) % self.capacity 367 | 368 | for key, value in transition._asdict().items(): 369 | self._memory[key][idx] = value 370 | 371 | self._pointer = (self._pointer + num_samples) % self.capacity 372 | self._size = min(self._size + num_samples, self.capacity) 373 | 374 | def _return_from_idx(self, idx): 375 | sample = {k: tuple(v[idx]) for k,v in self._memory.items()} 376 | return TransitionContext(**sample) 377 | 378 | def sample(self, batch_size: int, unique: bool = True): 379 | idx = np.random.randint(0, self._size, batch_size) if not unique else self._rng.choice(self._size, size=batch_size, replace=False) 380 | return self._return_from_idx(idx) 381 | 382 | def sample_all(self): 383 | return self._return_from_idx(np.arange(0, self._size)) 384 | 385 | def get(self, start_idx, end_idx): 386 | raise NotImplementedError 387 | 388 | def get_all(self): 389 | raise NotImplementedError 390 | 391 | def _get_from_idx(self, idx): 392 | raise NotImplementedError 393 | 394 | def __len__(self): 395 | return self._size 396 | 397 | def clear_pool(self): 398 | self._init_memory() 399 | 400 | def initialise(self, old_pool): 401 | # Not Tested 402 | old_memory = old_pool.sample_all() 403 | for key in self._memory: 404 | self._memory[key] = np.append(self._memory[key], old_memory[key], 0) 405 | 406 | class ReplayPool: 407 | 408 | def __init__(self, capacity=1e6): 409 | self.capacity = int(capacity) 410 | self._memory = deque(maxlen=int(capacity)) 411 | 412 | def push(self, transition: Transition): 413 | """ Saves a transition """ 414 | self._memory.append(transition) 415 | 416 | def sample(self, batch_size: int, unique: bool = True, dist=None) -> Transition: 417 | transitions = random.sample(self._memory, batch_size) if unique else random.choices(self._memory, k=batch_size) 418 | return Transition(*zip(*transitions)) 419 | 420 | def sample_traj(self, truncate_length = 300): 421 | traj_num = len(self._memory)//1000 # number of trajectories 422 | init_state=[self._memory[i * 1000].state for i in range(traj_num)] 423 | 424 | traj_s,traj_a = [], [] 425 | for i in range(traj_num): 426 | s,a=self.get2(i*1000, i*1000+truncate_length) 427 | traj_s.append(s) 428 | traj_a.append(a) 429 | traj_s=np.array(traj_s) 430 | traj_a=np.array(traj_a) 431 | return init_state,traj_s,traj_a 432 | 433 | def get(self, start_idx: int, end_idx: int) -> Transition: 434 | transitions = list(itertools.islice(self._memory, start_idx, end_idx)) 435 | return transitions 436 | 437 | def get2(self, start_idx: int, end_idx: int) -> Transition: 438 | transitions = list(itertools.islice(self._memory, start_idx, end_idx)) 439 | states=np.array([i.state for i in transitions]) 440 | actions=np.array([i.action for i in transitions]) 441 | return states,actions 442 | 443 | def get_all(self) -> Transition: 444 | return self.get(0, len(self._memory)) 445 | 446 | def sample_all(self) -> Transition: 447 | return Transition(*zip(*(self.get_all()))) 448 | 449 | def __len__(self) -> int: 450 | return len(self._memory) 451 | 452 | def clear_pool(self): 453 | self._memory.clear() 454 | 455 | def initialise(self, old_pool: 'ReplayPool'): 456 | old_memory = old_pool.get_all() 457 | self._memory.extend(old_memory) 458 | 459 | 460 | class ReplayPoolCtxt: 461 | 462 | def __init__(self, capacity=1e6): 463 | self.capacity = int(capacity) 464 | self._memory = deque(maxlen=int(capacity)) 465 | 466 | def push(self, transition: TransitionContext): 467 | """ Saves a transition """ 468 | self._memory.append(transition) 469 | 470 | def sample(self, batch_size: int, unique: bool = True, dist=None) -> TransitionContext: 471 | transitions = random.sample(self._memory, batch_size) if unique else random.choices(self._memory, k=batch_size) 472 | return TransitionContext(*zip(*transitions)) 473 | 474 | def get(self, start_idx: int, end_idx: int) -> TransitionContext: 475 | transitions = list(itertools.islice(self._memory, start_idx, end_idx)) 476 | return transitions 477 | 478 | def get_all(self) -> TransitionContext: 479 | return self.get(0, len(self._memory)) 480 | 481 | def sample_all(self) -> TransitionContext: 482 | return TransitionContext(*zip(*(self.get_all()))) 483 | 484 | def __len__(self) -> int: 485 | return len(self._memory) 486 | 487 | def clear_pool(self): 488 | self._memory.clear() 489 | 490 | def initialise(self, old_pool: 'ReplayPoolCtxt'): 491 | old_memory = old_pool.get_all() 492 | self._memory.extend(old_memory) 493 | 494 | 495 | # Taken from: https://github.com/pytorch/pytorch/pull/19785/files 496 | # The composition of affine + sigmoid + affine transforms is unstable numerically 497 | # tanh transform is (2 * sigmoid(2x) - 1) 498 | # Old Code Below: 499 | # transforms = [AffineTransform(loc=0, scale=2), SigmoidTransform(), AffineTransform(loc=-1, scale=2)] 500 | class TanhTransform(Transform): 501 | r""" 502 | Transform via the mapping :math:`y = \tanh(x)`. 503 | It is equivalent to 504 | ``` 505 | ComposeTransform([AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)]) 506 | ``` 507 | However this might not be numerically stable, thus it is recommended to use `TanhTransform` 508 | instead. 509 | Note that one should use `cache_size=1` when it comes to `NaN/Inf` values. 510 | """ 511 | domain = constraints.real 512 | codomain = constraints.interval(-1.0, 1.0) 513 | bijective = True 514 | sign = +1 515 | 516 | @staticmethod 517 | def atanh(x): 518 | return 0.5 * (x.log1p() - (-x).log1p()) 519 | 520 | def __eq__(self, other): 521 | return isinstance(other, TanhTransform) 522 | 523 | def _call(self, x): 524 | return x.tanh() 525 | 526 | def _inverse(self, y): 527 | # We do not clamp to the boundary here as it may degrade the performance of certain algorithms. 528 | # one should use `cache_size=1` instead 529 | return self.atanh(y) 530 | 531 | def log_abs_det_jacobian(self, x, y): 532 | # We use a formula that is more numerically stable, see details in the following link 533 | # https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L69-L80 534 | return 2. * (math.log(2.) - x - softplus(-2. * x)) 535 | 536 | 537 | def check_or_make_folder(folder_path): 538 | """ 539 | Helper function that (safely) checks if a dir exists; if not, it creates it 540 | """ 541 | 542 | folder_path = Path(folder_path) 543 | 544 | try: 545 | folder_path.resolve(strict=True) 546 | except FileNotFoundError: 547 | print("{} dir not found, creating it".format(folder_path)) 548 | os.mkdir(folder_path) -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import random 3 | import pdb 4 | import torch 5 | import numpy as np 6 | import pandas as pd 7 | from scipy.special import softmax 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from model import EnsembleGymEnv 14 | from sac import SAC_Agent 15 | from utils import (filter_torch, filter_torch_invert, get_residual, get_stats, 16 | random_env_forward, torch_reward, Transition, TransitionContext,ReplayPool, check_or_make_folder) 17 | from utils import ReplayPoolCtxt, ReplayPool, FasterReplayPool, FasterReplayPoolCtxt, TanhTransform, Transition, TransitionContext, filter_torch 18 | import d4rl 19 | import pickle 20 | 21 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 22 | class reward_estimator(nn.Module): 23 | def __init__( 24 | self, 25 | input_dim, 26 | hidden_sizes=(256,256), 27 | hid_act='relu', 28 | use_bn=False, 29 | residual=False, 30 | clamp_magnitude=10.0, 31 | device=device, 32 | **kwargs 33 | ): 34 | super().__init__() 35 | 36 | if hid_act == 'relu': 37 | hid_act_class = nn.ReLU 38 | elif hid_act == 'tanh': 39 | hid_act_class = nn.Tanh 40 | else: 41 | raise NotImplementedError() 42 | 43 | self.clamp_magnitude = clamp_magnitude 44 | self.input_dim = input_dim 45 | self.device = device 46 | self.residual = residual 47 | 48 | self.first_fc = nn.Linear(input_dim, hidden_sizes[0]) 49 | self.blocks_list = nn.ModuleList() 50 | 51 | for i in range(len(hidden_sizes) - 1): 52 | block = nn.ModuleList() 53 | block.append(nn.Linear(hidden_sizes[i], hidden_sizes[i+1])) 54 | if use_bn: block.append(nn.BatchNorm1d(hidden_sizes[i+1])) 55 | block.append(hid_act_class()) 56 | self.blocks_list.append(nn.Sequential(*block)) 57 | 58 | self.last_fc = nn.Linear(hidden_sizes[-1], 1) 59 | 60 | def forward(self, batch): 61 | x = self.first_fc(batch) 62 | for block in self.blocks_list: 63 | if self.residual: 64 | x = x + block(x) 65 | else: 66 | x = block(x) 67 | output = self.last_fc(x) 68 | output = torch.clamp(output, min=-1.0*self.clamp_magnitude, max=self.clamp_magnitude) 69 | return output 70 | 71 | def r(self, batch): 72 | return self.forward(batch) 73 | 74 | def get_scalar_reward(self, obs): 75 | self.eval() 76 | with torch.no_grad(): 77 | if not torch.is_tensor(obs): 78 | obs = torch.FloatTensor(obs.reshape(-1,self.input_dim)) 79 | obs = obs.to(self.device) 80 | reward = self.forward(obs).cpu().detach().numpy().flatten() 81 | self.train() 82 | return reward 83 | 84 | 85 | class Trainer(object): 86 | 87 | def __init__(self, params, model: EnsembleGymEnv, agent: SAC_Agent, device): 88 | self.agent = agent 89 | self.model = model 90 | 91 | self.expert_pool = ReplayPool(capacity=1e6) 92 | 93 | self.envname=params['env_name'].split('-')[0] 94 | self.load_expert_dataset() 95 | 96 | if params['env_name'].split('-')[0]=='hopper': 97 | self.reward_function=reward_estimator(14).to(device) 98 | else: 99 | self.reward_function=reward_estimator(23).to(device) 100 | self.transfer=params['transfer'] 101 | self.reward_opti= torch.optim.Adam(self.reward_function.parameters(), lr=1e-4,weight_decay=1e-3, betas=(0.9, 0.999)) 102 | if self.transfer: 103 | self.reward_function.load_state_dict(torch.load('./reward_walker/medexp_walker_reward_5603_.pt')) 104 | 105 | self.reward_function.eval() 106 | self._init_collect = params['init_collect'] 107 | self._max_model_epochs = params['model_epochs'] 108 | self._var_type = params['var_type'] 109 | self._num_rollouts = params['num_rollouts_per_step'] * params['model_train_freq'] 110 | self._model_retain_epochs = params['model_retain_epochs'] 111 | self._device = device 112 | self._train_policy_every = params['train_policy_every'] 113 | self._reward_head = params['reward_head'] 114 | self._policy_update_steps = params['policy_update_steps'] 115 | self._steps_k = params['steps_k'] 116 | self._reward_step = params['reward_steps'] 117 | if isinstance(self._steps_k, list): 118 | self._cur_steps_k = self._steps_k[0] 119 | else: 120 | self._cur_steps_k = self._steps_k 121 | self._n_eval_rollouts = params['n_eval_rollouts'] 122 | self._real_sample_ratio = params['real_sample_ratio'] 123 | self._model_train_freq = params['model_train_freq'] 124 | self._rollout_model_freq = params['rollout_model_freq'] 125 | self._oac = params['oac'] 126 | self._sample_states = params['states'] 127 | self._done = True 128 | self._state = None 129 | self._n_epochs = 0 130 | self._is_done_func = params['is_done_func'] 131 | self._var_thresh = params['var_thresh'] 132 | self._keep_logvar = True if self._var_thresh is not None else False 133 | self.k_used = [0] 134 | self._espi = params['espi'] 135 | self._max_steps = params['epoch_steps'] if params[ 136 | 'epoch_steps'] else self.model.real_env.env.spec.max_episode_steps 137 | self._env_step = 0 138 | self._curr_rollout = [] 139 | self._deterministic = params['deterministic_rollouts'] 140 | self._seed = params['seed'] 141 | self._min_model_epochs = params['min_model_epochs'] 142 | if self._min_model_epochs: 143 | assert self._min_model_epochs < self._max_model_epochs, "Can't have a min epochs that is less than the max" 144 | self._augment_offline_data = params['augment_offline_data'] 145 | 146 | if params['population_model_dirs']: 147 | self._load_population_models(params) 148 | 149 | self._morel_halt_reward = params['morel_halt_reward'] 150 | 151 | # Remove in a sec, just for testing 152 | self._params = params 153 | 154 | if self._params['mopo'] and self._params['morel']: 155 | raise Exception('Do not use MOReL and MOPO together please') 156 | 157 | def load_expert_dataset(self): 158 | import os 159 | current_path=os.path.dirname(os.path.abspath(__file__))+'/expert_data/'+self.envname+'/' 160 | state=np.load(current_path+'states.npy')[:50] 161 | action=np.load(current_path+'actions.npy')[:50] 162 | dones=np.load(current_path+'dones.npy')[:50] 163 | 164 | for i in range(state.shape[0]): 165 | for j in range(state.shape[1]): 166 | self.expert_pool.push(Transition(state[i][j], action[i][j], 0 , 0, dones[i][j])) 167 | 168 | def _train_model(self, d4rl_init=False, save_model=False): 169 | print("\nTraining Model...") 170 | self.model.train_model(self._max_model_epochs, d4rl_init=d4rl_init, save_model=save_model, 171 | min_model_epochs=self._min_model_epochs) 172 | 173 | 174 | #TODO: We have written this code like 3 times now, maybe incorporate it with myopic thing 175 | def rollout_with_ground_truth(self, policy: SAC_Agent, num_parallel = 100, num_steps=25): 176 | """ 177 | Rolls out 'policy' in the WM for 'num_steps' with 'num_parallel' starting points 178 | Also returns the uncertainty penalties of interest along the way 179 | Also returns the MSE w.r.t. ground truth for our true D_TV 180 | """ 181 | start_states = torch.FloatTensor( 182 | np.array(self.model.model.memory.sample(num_parallel, unique=False)[0])).to(self._device) 183 | done_false = [False for _ in range(start_states.shape[0])] 184 | t = 0 185 | state = start_states 186 | idxs_remaining = np.arange(start_states.shape[0]) 187 | stats_dic = { 188 | 'groundtruth_mses': np.zeros((num_parallel, num_steps)), 189 | 'mopo_paper': np.zeros((num_parallel, num_steps)), 190 | 'morel': np.zeros((num_parallel, num_steps)), 191 | 'lompo': np.zeros((num_parallel, num_steps)), 192 | 'm2ac': np.zeros((num_parallel, num_steps)), 193 | 'ensemble_var': np.zeros((num_parallel, num_steps)), 194 | 'ensemble_std': np.zeros((num_parallel, num_steps)) 195 | } 196 | while t < num_steps: 197 | t += 1 198 | with torch.no_grad(): 199 | # Get deterministic action 200 | _, _, action_det = policy.policy(state) 201 | penalties, nextstate = self.model.model.get_all_penalties_state_action(state, action_det) 202 | true_nextstate = self._get_ground_truth_nextstate(state, action_det) 203 | mses = ((nextstate.cpu().numpy() - true_nextstate)**2).mean(1) 204 | if self._is_done_func: 205 | done = self._is_done_func(nextstate).cpu().numpy().tolist() 206 | else: 207 | done = done_false[:nextstate.shape[0]] 208 | not_done = ~np.array(done) 209 | for k,v in penalties.items(): 210 | stats_dic[k][idxs_remaining, t-1] = v 211 | stats_dic['groundtruth_mses'][idxs_remaining, t-1] = mses 212 | idxs_remaining = idxs_remaining[not_done] 213 | if len(nextstate.shape) == 1: 214 | nextstate.unsqueeze_(0) 215 | state = nextstate[not_done] 216 | if len(state.shape) == 1: 217 | state.unsqueeze_(0) 218 | if not_done.sum() == 0: 219 | break 220 | return stats_dic 221 | 222 | def _get_ground_truth_nextstate(self, state, action): 223 | env = self.model.eval_env 224 | env.reset() 225 | state = state.cpu().numpy() 226 | action = action.cpu().numpy() 227 | qpos0 = np.array([0]) 228 | ground_truth_ns = [] 229 | if 'hopper' in env.spec.id.lower(): 230 | for s, a in zip(state, action): 231 | qpos = np.concatenate([qpos0, s[:5]]) 232 | qvel = s[5:] 233 | env.set_state(qpos, qvel) 234 | true_ns = env.step(a) 235 | ground_truth_ns.append(true_ns[0]) 236 | else: 237 | raise NotImplementedError 238 | return np.stack(ground_truth_ns) 239 | 240 | def _rollout_model_expert(self,start_states,sample_pool=None,IRL=True,Penalty_only=False): 241 | _num_rollouts=start_states.shape[0] 242 | 243 | self.model.convert_filter_to_torch() 244 | self.k_used = [0] 245 | self.var_mean = [] 246 | 247 | done_true = [True for _ in range(_num_rollouts)] 248 | done_false = [False for _ in range(_num_rollouts)] 249 | 250 | state = start_states.clone() 251 | prev_done = done_false 252 | var = 0 253 | t = 0 254 | transition_count = 0 255 | 256 | while t < self._reward_step: 257 | t += 1 258 | if self._params['rad_rollout']: 259 | high = 1.2 260 | low = 0.8 261 | scale = high - low 262 | random_amplitude_scaling = (torch.rand(start_states.shape) * scale + low).to(device) 263 | 264 | with torch.no_grad(): 265 | # TODO: Random steps intially? 266 | if self.agent.augment_sac and self._params['rad_rollout']: 267 | # Normalise this here 268 | # filter_torch(random_amplitude_scaling, 1, 0.11547) 269 | augmented_state = torch.cat((state, random_amplitude_scaling), 1) 270 | action, _, _ = self.agent.policy(augmented_state) 271 | elif self.agent.augment_sac and self._params['context_type'] == 'rad_augmentation': 272 | # Assume a base context of 1s just as a filler. 273 | augmented_state = torch.cat((state, torch.ones_like(state)), 1) 274 | action, _, _ = self.agent.policy(augmented_state) 275 | elif self.agent.augment_sac and self._params['context_type'] == 'rad_magnitude': 276 | # Assume a base context of 1s just as a filler. 277 | augmented_state = torch.cat((state, torch.ones(state.shape[0], 1).to(device)), 1) 278 | action, _, _ = self.agent.policy(augmented_state) 279 | else: 280 | action, _, _ = self.agent.policy(state) 281 | 282 | nextstate, reward, penalties = self.model.model.random_env_step(state, 283 | action, 284 | get_var=self._keep_logvar, 285 | deterministic=self._deterministic, 286 | IRL=Penalty_only, 287 | ) 288 | 289 | if self._keep_logvar: 290 | nextstate, nextstate_var, reward, reward_var = nextstate[0], nextstate[1], reward[0], reward[1] 291 | 292 | nextstate_copy = nextstate.clone() 293 | if self._params['rad_rollout']: 294 | nextstate *= random_amplitude_scaling 295 | 296 | if self._is_done_func: 297 | done = self._is_done_func(nextstate).cpu().numpy().tolist() 298 | else: 299 | done = done_false[:nextstate.shape[0]] 300 | if self._params['morel'] and not self._params['morel_non_stop']: 301 | # we're going to go to the absorbing HALT state here 302 | # TODO: fix this to be more efficient? 303 | done = torch.tensor(done).to(device) | (penalties > self._morel_threshold) 304 | done = done.cpu().numpy().tolist() 305 | 306 | not_done = ~np.array(done) 307 | if self._keep_logvar: 308 | print("Reward Var: mean = {}, max = {}, min = {}".format(np.mean(reward_var.cpu().numpy()), 309 | np.max(reward_var.cpu().numpy()), 310 | np.min(reward_var.cpu().numpy()))) 311 | self.var_mean.append(np.mean(reward_var.cpu().numpy())) 312 | var_low = reward_var < self._var_thresh 313 | done_k = np.array(done) + ~var_low.cpu().numpy().squeeze() 314 | not_done = not_done & var_low.cpu().numpy().squeeze() 315 | self.k_used += [t for _ in done_k if _ is True] 316 | 317 | uncert = 0 318 | if self._reward_head: 319 | reward = reward.cpu().detach().numpy() 320 | else: 321 | reward = torch_reward(self.model.name, nextstate, action, done) 322 | state_np, action_np, nextstate_np ,penalties= state.detach().cpu().numpy(), action.detach().cpu().numpy(), nextstate.detach().cpu().numpy(), penalties.detach().cpu().numpy() 323 | 324 | if self._params['rad_rollout'] and self.agent.augment_sac: 325 | rad_np = random_amplitude_scaling.detach().cpu().numpy() 326 | # for s, a, r, s_n, d, ctxt in zip(state_np, action_np, reward, nextstate_np, done, rad_np): 327 | # r = r.item() 328 | # self.agent.replay_pool.push(TransitionContext(s, a, r, s_n, d, ctxt)) 329 | sample_pool.push(TransitionContext(state_np, action_np, penalties, nextstate_np, np.array(done), rad_np)) 330 | 331 | else: 332 | # for s, a, r, s_n, d in zip(state_np, action_np, reward, nextstate_np, done): 333 | # r = r.item() 334 | # self.agent.replay_pool.push(Transition(s, a, r, s_n, d)) 335 | sample_pool.push(Transition(state_np, action_np, penalties, nextstate_np, np.array(done))) 336 | 337 | if not_done.sum() == 0: 338 | print("Finished rollouts early: all terminated after %s timesteps" % (t)) 339 | break 340 | transition_count += len(nextstate) 341 | # Initialize state clean to be augmented next step. 342 | if len(nextstate_copy.shape) == 1: 343 | nextstate_copy.unsqueeze_(0) 344 | state = nextstate_copy[not_done] 345 | if len(state.shape) == 1: 346 | state.unsqueeze_(0) 347 | print("Remaining = {}".format(np.round(state.shape[0]) / start_states.shape[0], 2)) 348 | var += uncert ** 2 349 | print("Finished rollouts: all terminated after %s timesteps" % (t)) 350 | print("Added {} transitions to agent replay pool".format(transition_count)) 351 | 352 | return 353 | 354 | 355 | def _rollout_model(self,sample_pool=None,IRL=False,Penalty_only=False): 356 | print("\nRolling out Policy in Model...") 357 | if self._params['mopo']: 358 | print("\nUsing MOPO Penalty") 359 | elif self._params['morel']: 360 | print("\nUsing MOReL Penalty") 361 | if self._var_type == 'reward': 362 | state_dynamics = False 363 | elif self._var_type == 'state': 364 | state_dynamics = True 365 | else: 366 | raise Exception("Variance must either be 'reward' or 'state'") 367 | 368 | for model in self.model.model.models.values(): 369 | model.to(self._device) 370 | 371 | self.model.convert_filter_to_torch() 372 | self.k_used = [0] 373 | self.var_mean = [] 374 | 375 | done_true = [True for _ in range(self._num_rollouts)] 376 | done_false = [False for _ in range(self._num_rollouts)] 377 | 378 | if self._sample_states == 'uniform': 379 | start_states = torch.FloatTensor( 380 | np.array(self.model.model.memory.sample(self._num_rollouts, unique=False)[0])).to(self._device) 381 | elif self._sample_states == 'entropy': 382 | all_states = [transition[0] for transition in self.model.model.memory.get_all()] 383 | all_states = torch.FloatTensor(all_states).to(self._device) 384 | _, _, all_actions = self.agent.policy(all_states) 385 | u = get_stats(self.model, all_states, all_actions, self.model.model.state_filter, 386 | self.model.model.action_filter, False, 387 | False, self._reward_head) 388 | scaled_neg_ent = -np.log(u) 389 | dist = softmax(scaled_neg_ent) 390 | sample = np.random.choice(all_states.shape[0], self._num_rollouts, p=dist.flatten()) 391 | start_states = all_states[sample] 392 | 393 | state = start_states.clone() 394 | prev_done = done_false 395 | var = 0 396 | t = 0 397 | transition_count = 0 398 | 399 | while t < self._cur_steps_k: 400 | t += 1 401 | if self._params['rad_rollout']: 402 | high = 1.2 403 | low = 0.8 404 | scale = high - low 405 | random_amplitude_scaling = (torch.rand(start_states.shape) * scale + low).to(device) 406 | # Just scale nextstate for now 407 | # state *= random_amplitude_scaling 408 | 409 | with torch.no_grad(): 410 | # TODO: Random steps intially? 411 | if self.agent.augment_sac and self._params['rad_rollout']: 412 | # Normalise this here 413 | # filter_torch(random_amplitude_scaling, 1, 0.11547) 414 | augmented_state = torch.cat((state, random_amplitude_scaling), 1) 415 | action, _, _ = self.agent.policy(augmented_state) 416 | elif self.agent.augment_sac and self._params['context_type'] == 'rad_augmentation': 417 | # Assume a base context of 1s just as a filler. 418 | augmented_state = torch.cat((state, torch.ones_like(state)), 1) 419 | action, _, _ = self.agent.policy(augmented_state) 420 | elif self.agent.augment_sac and self._params['context_type'] == 'rad_magnitude': 421 | # Assume a base context of 1s just as a filler. 422 | augmented_state = torch.cat((state, torch.ones(state.shape[0], 1).to(device)), 1) 423 | action, _, _ = self.agent.policy(augmented_state) 424 | else: 425 | action, _, _ = self.agent.policy(state) 426 | # to do 427 | 428 | nextstate, reward, penalties = self.model.model.random_env_step(state, 429 | action, 430 | get_var=self._keep_logvar, 431 | deterministic=self._deterministic, 432 | IRL=Penalty_only, 433 | ) 434 | 435 | if self._keep_logvar: 436 | nextstate, nextstate_var, reward, reward_var = nextstate[0], nextstate[1], reward[0], reward[1] 437 | 438 | nextstate_copy = nextstate.clone() 439 | if self._params['rad_rollout']: 440 | nextstate *= random_amplitude_scaling 441 | 442 | if self._params['tune_mopo_lam']: 443 | self.model.model.update_lambda() 444 | 445 | if self._is_done_func: 446 | done = self._is_done_func(nextstate).cpu().numpy().tolist() 447 | else: 448 | done = done_false[:nextstate.shape[0]] 449 | if self._params['morel'] and not self._params['morel_non_stop']: 450 | # we're going to go to the absorbing HALT state here 451 | # TODO: fix this to be more efficient? 452 | done = torch.tensor(done).to(device) | (penalties > self._morel_threshold) 453 | done = done.cpu().numpy().tolist() 454 | 455 | not_done = ~np.array(done) 456 | if self._keep_logvar: 457 | print("Reward Var: mean = {}, max = {}, min = {}".format(np.mean(reward_var.cpu().numpy()), 458 | np.max(reward_var.cpu().numpy()), 459 | np.min(reward_var.cpu().numpy()))) 460 | self.var_mean.append(np.mean(reward_var.cpu().numpy())) 461 | var_low = reward_var < self._var_thresh 462 | done_k = np.array(done) + ~var_low.cpu().numpy().squeeze() 463 | not_done = not_done & var_low.cpu().numpy().squeeze() 464 | self.k_used += [t for _ in done_k if _ is True] 465 | uncert = 0 466 | if self._reward_head: 467 | reward = reward.cpu().detach().numpy() 468 | else: 469 | reward = torch_reward(self.model.name, nextstate, action, done) 470 | state_np, action_np, nextstate_np ,penalties= state.detach().cpu().numpy(), action.detach().cpu().numpy(), nextstate.detach().cpu().numpy(), penalties.detach().cpu().numpy() 471 | if self._params['rad_rollout'] and self.agent.augment_sac: 472 | rad_np = random_amplitude_scaling.detach().cpu().numpy() 473 | # for s, a, r, s_n, d, ctxt in zip(state_np, action_np, reward, nextstate_np, done, rad_np): 474 | # r = r.item() 475 | # self.agent.replay_pool.push(TransitionContext(s, a, r, s_n, d, ctxt)) 476 | if IRL: 477 | sample_pool.push(TransitionContext(state_np, action_np, penalties, nextstate_np, np.array(done), rad_np)) 478 | else: 479 | self.agent.replay_pool.push(TransitionContext(state_np, action_np, penalties, nextstate_np, np.array(done), rad_np)) 480 | else: 481 | # for s, a, r, s_n, d in zip(state_np, action_np, reward, nextstate_np, done): 482 | # r = r.item() 483 | # self.agent.replay_pool.push(Transition(s, a, r, s_n, d)) 484 | if IRL: 485 | sample_pool.push(Transition(state_np, action_np, penalties, nextstate_np, np.array(done))) 486 | else: 487 | self.agent.replay_pool.push(Transition(state_np, action_np, penalties, nextstate_np, np.array(done))) 488 | if not_done.sum() == 0: 489 | print("Finished rollouts early: all terminated after %s timesteps" % (t)) 490 | break 491 | transition_count += len(nextstate) 492 | # Initialize state clean to be augmented next step. 493 | if len(nextstate_copy.shape) == 1: 494 | nextstate_copy.unsqueeze_(0) 495 | state = nextstate_copy[not_done] 496 | if len(state.shape) == 1: 497 | state.unsqueeze_(0) 498 | print("Remaining = {}".format(np.round(state.shape[0]) / start_states.shape[0], 2)) 499 | var += uncert ** 2 500 | print("Finished rollouts: all terminated after %s timesteps" % (t)) 501 | print("Added {} transitions to agent replay pool".format(transition_count)) 502 | print("Agent replay pool: {}/{}".format(len(self.agent.replay_pool), self.agent.replay_pool.capacity)) 503 | 504 | def _train_agent(self,IRL=False): 505 | if self._augment_offline_data: 506 | print("Augmenting model data with RAD") 507 | if IRL: 508 | self.agent.optimize(n_updates=self._policy_update_steps, env_pool=self.model.model.memory, 509 | env_ratio=self._real_sample_ratio, augment_data=self._augment_offline_data,reward_function=self.reward_function) 510 | else: 511 | self.agent.optimize(n_updates=self._policy_update_steps, env_pool=self.model.model.memory, 512 | env_ratio=self._real_sample_ratio, augment_data=self._augment_offline_data) 513 | 514 | 515 | def _train_reward(self): 516 | 517 | sample_pool = ReplayPool(capacity=1e6) 518 | init_state,traj_state,traj_action=self.expert_pool.sample_traj(self._reward_step) 519 | #init_state=np.array(init_state) 520 | init_state=torch.FloatTensor(np.array(init_state)).to(self._device) 521 | 522 | self._rollout_model_expert(init_state,sample_pool) 523 | samples = sample_pool.sample_all() 524 | 525 | state_batch = np.array([i for arr in samples.state for i in arr]) 526 | action_batch = np.array([i for arr in samples.action for i in arr]) 527 | 528 | agent_state_batch = torch.FloatTensor(state_batch).to(self._device) 529 | agent_action_batch = torch.FloatTensor(action_batch).to(self._device) 530 | 531 | 532 | traj_state=traj_state.reshape(-1,traj_state.shape[2]) 533 | traj_action=traj_action.reshape(-1,traj_action.shape[2]) 534 | expert_state_batch = torch.FloatTensor(traj_state).to(self._device) 535 | expert_action_batch = torch.FloatTensor(traj_action).to(self._device) 536 | 537 | agent_r=self.reward_function(torch.cat((agent_state_batch,agent_action_batch),1)) 538 | expert_r=self.reward_function(torch.cat((expert_state_batch,expert_action_batch),1)) 539 | #print(expert_r.shape) 540 | #avg_reward_norm = torch.mean(torch.squeeze(agent_r ** 2)) 541 | 542 | #regularizer = avg_reward_norm 543 | loss=agent_r.mean()-expert_r.mean() 544 | #loss=agent_r.sum()/2000-expert_r.sum()/2000 545 | self.reward_opti.zero_grad() 546 | loss.backward() 547 | self.reward_opti.step() 548 | 549 | def _load_population_models(self, params): 550 | model_dirs = params['population_model_dirs'] 551 | print("Loading population on {} models".format(len(model_dirs))) 552 | self.population_models = {} 553 | for i, m in enumerate(model_dirs): 554 | pop_env = EnsembleGymEnv(params, self.model.real_env, self.model.eval_env) 555 | pop_env.model.load_model(m) 556 | # takes up RAM, let's remove 557 | del pop_env.model.memory_val 558 | del pop_env.model.memory 559 | self.population_models['model_{}'.format(i)] = pop_env 560 | print("Loaded model {}".format(m.split('/')[-1])) 561 | print("Finished loading {} population models".format(len(model_dirs))) 562 | 563 | def train_epoch(self, init=False): 564 | timesteps = 0 565 | error = None 566 | env = self.model.real_env 567 | collect_steps = self._init_collect if init else self._max_steps 568 | while timesteps < collect_steps: 569 | done = False 570 | # check if we were actually mid-rollout at the end of the last epoch 571 | if self._done: 572 | state = env.reset() 573 | self._curr_rollout = [] 574 | self._env_step = 0 575 | else: 576 | state = self._curr_rollout[-1].nextstate 577 | while (not done) and (timesteps < collect_steps): 578 | if init: 579 | action = env.action_space.sample() 580 | else: 581 | action = self.agent.get_action(state, oac=self._oac) 582 | nextstate, reward, done, _ = env.step(action) 583 | self._env_step += 1 584 | # Check if environment actually terminated or just ran out of time 585 | if done and self._env_step != env.spec.max_episode_steps: 586 | real_done = True 587 | else: 588 | real_done = False 589 | t = Transition(state, action, reward, nextstate, real_done) 590 | self._curr_rollout.append(t) 591 | timesteps += 1 592 | if (timesteps) % 100 == 0: 593 | print("Collected Timesteps: %s" % (timesteps)) 594 | if done: 595 | self._push_trajectory() 596 | state = nextstate 597 | self._done = done 598 | if not init: 599 | if timesteps % self._model_train_freq == 0: 600 | self._train_model() 601 | if timesteps % self._rollout_model_freq == 0: 602 | self._rollout_model() 603 | if timesteps % self._train_policy_every == 0: 604 | self._train_agent() 605 | 606 | 607 | if init: 608 | self._train_model() 609 | self._rollout_model() 610 | else: 611 | self._n_epochs += 1 612 | 613 | errors = [self.model.model.models[i].get_acquisition(self._curr_rollout, self.model.model.state_filter, 614 | self.model.model.action_filter) 615 | for i in range(len(self.model.model.models))] 616 | error = np.sqrt(np.mean(np.array(errors) ** 2)) 617 | print("\nMSE Loss on latest rollout: %s" % error) 618 | steps_k_used = self._cur_steps_k 619 | self._steps_k_update() 620 | return timesteps, error, steps_k_used 621 | 622 | def _push_trajectory(self): 623 | collect_steps = len(self._curr_rollout) 624 | # randomly allocate the data to train and validation 625 | train_val_ind = random.sample(range(collect_steps), collect_steps) 626 | num_valid = int(np.floor(self.model.model.train_val_ratio * collect_steps)) 627 | train_ind = train_val_ind[num_valid:] 628 | for i, t in enumerate(self._curr_rollout): 629 | self.model.update_state_filter(t.state) 630 | self.model.update_action_filter(t.action) 631 | if i in train_ind: 632 | self.model.model.add_data(t) 633 | else: 634 | self.model.model.add_data_validation(t) 635 | print("\nAdded {} samples for train, {} for valid".format(str(len(train_ind)), 636 | str(len(train_val_ind) - len(train_ind)))) 637 | 638 | def train_offline(self, num_epochs, save_model=False, save_policy=False, load_model_dir=None): 639 | timesteps = 0 640 | val_size = 0 641 | train_size = 0 642 | 643 | # d4rl stuff - load all the offline data and train 644 | env = self.model.real_env 645 | # dataset = d4rl.qlearning_dataset(env, limit=5000) 646 | if self._params['env_name'] != 'AntMOPOEnv': 647 | dataset = d4rl.qlearning_dataset(env) 648 | else: 649 | with open('/Meta-Offline-RL/ant_mopo_1m_dataset.pkl', 'rb') as f: 650 | dataset = pickle.load(f) 651 | 652 | N = dataset['rewards'].shape[0] 653 | rollout = [] 654 | 655 | if load_model_dir or self._params['ensemble_replace_model_dirs']: 656 | # Load pretrained model, overrride this if population loading flag is on 657 | if not self._params['ensemble_replace_model_dirs']: 658 | errors = self.model.model.load_model(load_model_dir) 659 | else: 660 | print(self._params['ensemble_replace_model_dirs']) 661 | errors = self.model.model.load_model_from_population(self._params['ensemble_replace_model_dirs']) 662 | else: 663 | self.model.update_state_filter(dataset['observations'][0]) 664 | 665 | for i in range(N): 666 | state = dataset['observations'][i] 667 | action = dataset['actions'][i] 668 | nextstate = dataset['next_observations'][i] 669 | reward = dataset['rewards'][i] 670 | done = bool(dataset['terminals'][i]) 671 | 672 | t = Transition(state, action, reward, nextstate, done) 673 | rollout.append(t) 674 | 675 | self.model.update_state_filter(nextstate) 676 | self.model.update_action_filter(action) 677 | 678 | # Do this probabilistically to avoid maintaining a huge array of indices 679 | if random.uniform(0, 1) < self.model.model.train_val_ratio: 680 | self.model.model.add_data_validation(t) 681 | val_size += 1 682 | else: 683 | self.model.model.add_data(t) 684 | train_size += 1 685 | timesteps += 1 686 | 687 | self._done = True 688 | 689 | print("\nAdded {} samples for train, {} for valid".format(str(train_size), str(val_size))) 690 | 691 | if save_model: 692 | print('Saving model!') 693 | 694 | self._train_model(d4rl_init=True, save_model=save_model) 695 | 696 | if self._params['morel']: 697 | self._morel_threshold = self._get_morel_threshold() 698 | self.model.model.set_morel_hparams(self._morel_threshold, self._morel_halt_reward) 699 | else: 700 | self._morel_threshold = None 701 | 702 | rewards, rewards_m, k_used, mopo_lam, myopic_wm, myopic_pop, rewards_pop, myopic_pop_worst, myopic_wm_worst = [], [], [], [], [], [], [], [], [] 703 | for i in range(num_epochs): 704 | self._rollout_model(Penalty_only=True) 705 | self._train_agent(IRL=True) 706 | if i % 20 == 0 and not self.transfer: 707 | self._train_reward() 708 | 709 | reward_model = self.test_agent(use_model=True, n_evals=10) 710 | reward_actual_stats = self.test_agent(use_model=False) 711 | print("------------------------") 712 | stats_fmt = "{:<20}{:>30}" 713 | stats_str = ["Epoch", 714 | "WM Reward Mean", 715 | "WM Reward Max", 716 | "WM Reward Min", 717 | "WM Reward StdDev", 718 | "True Reward Mean", 719 | "True Reward Max", 720 | "True Reward Min", 721 | "True Reward StdDev"] 722 | stats_num = [i, 723 | reward_model.mean().round(2), 724 | reward_model.max().round(2), 725 | reward_model.min().round(2), 726 | reward_model.std().round(2), 727 | reward_actual_stats.mean().round(2), 728 | reward_actual_stats.max().round(2), 729 | reward_actual_stats.min().round(2), 730 | reward_actual_stats.std().round(2)] 731 | if hasattr(self, "population_models"): 732 | reward_pop = self.test_agent_population(full_trajectories=True, n_evals=10) 733 | reward_model_myopic, reward_pop_myopic = self.test_agent_population(full_trajectories=False) 734 | reward_model_myopic_worst, reward_pop_myopic_worst = self.test_agent_population(full_trajectories=False, 735 | bad_states=True) 736 | pop_str = ["WM Myopic Mean", "WM Myopic Mean Worst"] 737 | pop_num = [reward_model_myopic.mean().round(2), reward_model_myopic_worst.mean().round(2)] 738 | for j, (stat, stat_myopic, stat_myopic_worst) in enumerate( 739 | zip(reward_pop, reward_pop_myopic, reward_pop_myopic_worst)): 740 | pop_str += ["Pop WM {} Mean".format(j), 741 | "Pop WM {} Max".format(j), 742 | "Pop WM {} Min".format(j), 743 | "Pop WM {} StdDev".format(j), 744 | "Pop WM {} Myopic Mean".format(j), 745 | "Pop WM {} Myopic Mean Worst".format(j)] 746 | pop_num += [stat.mean().round(2), 747 | stat.max().round(2), 748 | stat.min().round(2), 749 | stat.std().round(2), 750 | stat_myopic.mean().round(2), 751 | stat_myopic_worst.mean().round(2)] 752 | stats_str.extend(pop_str) 753 | stats_num.extend(pop_num) 754 | myopic_wm.append(reward_model_myopic.mean()) 755 | myopic_pop.append([s.mean() for s in reward_pop_myopic]) 756 | rewards_pop.append([s.mean() for s in reward_pop]) 757 | myopic_wm_worst.append([s.mean() for s in reward_model_myopic_worst]) 758 | myopic_pop_worst.append([s.mean() for s in reward_pop_myopic_worst]) 759 | for s, n in zip(stats_str, stats_num): 760 | print(stats_fmt.format(s, n)) 761 | print("------------------------") 762 | # Log to csv (offline) 763 | rewards.append(reward_actual_stats.mean()) 764 | rewards_m.append(reward_model.mean()) 765 | k_used.append(self._cur_steps_k) 766 | if self._params['tune_mopo_lam']: 767 | ml = self.model.model.log_mopo_lam.exp().item() 768 | else: 769 | ml = self.model.model.mopo_lam 770 | mopo_lam.append(ml) 771 | save_stats = {'Reward': rewards, 'Reward_WM': rewards_m, 'k_used': k_used, 'mopo_lam': mopo_lam} 772 | if hasattr(self, "population_models"): 773 | save_stats['Myopic WM'] = myopic_wm 774 | save_stats['Myopic Population'] = myopic_pop 775 | save_stats['Rewards Population'] = rewards_pop 776 | save_stats['Myopic WM Worst'] = myopic_wm_worst 777 | save_stats['Myopic Population Worst'] = myopic_pop_worst 778 | df = pd.DataFrame(save_stats) 779 | lam = ['Adaptive' if self._params['adapt'] == 1 else 'fixed{}'.format(str(self._params['lam']))][0] 780 | save_name = "{}_{}_resid{}_{}_{}_offline".format(self._params['env_name'], lam, str(self._params['pca']), 781 | self._params['filename'], 782 | str(self._params['seed'])) 783 | if self._params['comment']: 784 | save_name = save_name + '_' + self._params['comment'] 785 | save_name += '.csv' 786 | df.to_csv(save_name) 787 | 788 | if save_policy and i % 20 == 0: 789 | save_path = './model_saved_weights_seed{}'.format(self._params['seed']) 790 | check_or_make_folder(save_path) 791 | print("Saving policy trained offline") 792 | self.agent.save_policy( 793 | # "{}".format(self.model.model._model_id), 794 | save_path, 795 | num_epochs=i, 796 | rew=int(reward_actual_stats.mean()) 797 | ) 798 | 799 | if not load_model_dir: 800 | errors = [self.model.model.models[i].get_acquisition(rollout[:1000], self.model.model.state_filter, 801 | self.model.model.action_filter) 802 | for i in range(len(self.model.model.models))] 803 | error = np.sqrt(np.mean(np.array(errors) ** 2)) 804 | print("\nMSE Loss on offline rollouts: %s" % error) 805 | steps_k_used = self._cur_steps_k 806 | self._steps_k_update() 807 | 808 | return timesteps, error, steps_k_used, rewards 809 | 810 | def _steps_k_update(self): 811 | if isinstance(self._steps_k, int): 812 | return 813 | else: 814 | steps_min, steps_max, start_epoch, end_epoch = self._steps_k 815 | m = (steps_max - steps_min) / (end_epoch - start_epoch) 816 | c = steps_min - m * start_epoch 817 | new_steps_k = m * self._n_epochs + c 818 | new_steps_k = int(min(steps_max, max(new_steps_k, steps_min))) 819 | if new_steps_k == self._cur_steps_k: 820 | return 821 | else: 822 | print("\nChanging model step size, going from %s to %s" % (self._cur_steps_k, new_steps_k)) 823 | self._cur_steps_k = new_steps_k 824 | new_pool_size = int( 825 | self._cur_steps_k * self._num_rollouts * ( 826 | self._max_steps / self._model_train_freq) * self._model_retain_epochs) 827 | print("\nReallocating agent pool, going from %s to %s" % (self.agent.replay_pool.capacity, new_pool_size)) 828 | self.agent.reallocate_replay_pool(new_pool_size) 829 | 830 | def get_pessimistic_states(self): 831 | all_states = self.model.model.memory.sample_all() 832 | all_states, _, all_rewards, _, _ = all_states 833 | all_states, all_rewards = np.array(all_states), np.array(all_rewards) 834 | worst_states_idx = all_rewards.argsort()[:5000] 835 | worst_states = all_states[worst_states_idx] 836 | return torch.FloatTensor(worst_states).to(device) 837 | 838 | def test_agent_population(self, full_trajectories=True, n_evals=5, bad_states=False): 839 | if not hasattr(self, "pessimistic_states"): 840 | self.pessimistic_states = self.get_pessimistic_states() 841 | if full_trajectories: 842 | return [self.test_agent(use_model=True, n_evals=n_evals, population_idx=model) for model in 843 | self.population_models] 844 | else: 845 | # Do random sampling of 5000 states and rolling out 25 steps 846 | if not bad_states: 847 | start_states = torch.FloatTensor( 848 | np.array(self.model.model.memory.sample(5000, unique=False)[0])).to(self._device) 849 | else: 850 | start_states = self.pessimistic_states 851 | # Need to test on self as we don't know how well we perform here under normal conditions 852 | own_WM_rewards = self.test_agent_myopic(start_states) 853 | # Now test on population 854 | pop_WM_rewards = [self.test_agent_myopic(start_states, population_idx=m) for m in self.population_models] 855 | return own_WM_rewards, pop_WM_rewards 856 | 857 | def test_agent_myopic(self, start_states, num_steps=100, population_idx=None): 858 | if population_idx: 859 | print("Getting myopic returns on populations models") 860 | test_env = self.population_models[population_idx] 861 | else: 862 | print("Getting myopic returns on World Model we trained in") 863 | test_env = self.model 864 | state = start_states 865 | sum_rewards = np.zeros(start_states.shape[0]) 866 | done_false = [False for _ in range(start_states.shape[0])] 867 | # needed to subset the rewards properly 868 | idxs_remaining = np.arange(start_states.shape[0]) 869 | t = 0 870 | test_env.convert_filter_to_torch() 871 | while t < num_steps: 872 | t += 1 873 | with torch.no_grad(): 874 | # Get deterministic action 875 | _, _, action_det = self.agent.policy(state) 876 | nextstate, reward = test_env.model.random_env_step(state, 877 | action_det, 878 | get_var=self._keep_logvar, 879 | deterministic=self._deterministic, 880 | disable_mopo=True 881 | ) 882 | if self._keep_logvar: 883 | nextstate, nextstate_var, reward, reward_var = nextstate[0], nextstate[1], reward[0], reward[1] 884 | if self._is_done_func: 885 | done = self._is_done_func(nextstate).cpu().numpy().tolist() 886 | else: 887 | done = done_false[:nextstate.shape[0]] 888 | not_done = ~np.array(done) 889 | if self._reward_head: 890 | reward = reward.cpu().detach().numpy() 891 | else: 892 | reward = torch_reward(self.model.name, nextstate, action_det, done) 893 | sum_rewards[idxs_remaining] += reward 894 | idxs_remaining = idxs_remaining[not_done] 895 | if len(nextstate.shape) == 1: 896 | nextstate.unsqueeze_(0) 897 | state = nextstate[not_done] 898 | if len(state.shape) == 1: 899 | state.unsqueeze_(0) 900 | if not_done.sum() == 0: 901 | break 902 | return sum_rewards 903 | 904 | def test_agent(self, use_model=False, n_evals=None, population_idx=None): 905 | if not use_model: 906 | assert population_idx is None, "You are evaluating performance on a real environment, why are you specifying population index?" 907 | rollout_rewards = [] 908 | n_evals = n_evals if n_evals else self._n_eval_rollouts 909 | if use_model: 910 | if population_idx: 911 | test_env = self.population_models[population_idx] 912 | else: 913 | test_env = self.model 914 | else: 915 | test_env = self.model.eval_env 916 | for _ in range(n_evals): 917 | total_reward = 0 918 | time_step = 0 919 | done = False 920 | state = test_env.reset() 921 | while not done: 922 | time_step += 1 923 | if self.agent.augment_sac and self.agent.context_type == 'rad_augmentation': 924 | state = np.concatenate((state, np.ones_like(state))) 925 | elif self.agent.augment_sac and self.agent.context_type == 'rad_magnitude': 926 | if len(state) == 1: 927 | state = state[0] 928 | state = np.concatenate((state, np.ones((1,)))) 929 | action = self.agent.get_action(state, deterministic=True) 930 | state, reward, done, info = test_env.step(action) 931 | if (self._params['env_name'] == 'AntMOPOEnv') and (not use_model): 932 | reward = info['reward_angle'] + info['reward_ctrl'] + 1 933 | reward = 0 if reward is None else reward 934 | total_reward += reward 935 | rollout_rewards.append(total_reward) 936 | rollout_rewards = np.array(rollout_rewards) 937 | return rollout_rewards 938 | 939 | def modify_online_training_params(self, online_params): 940 | """ 941 | Method to reassign the important training hyperparams to online, as offline hyperparams are different 942 | """ 943 | self._num_rollouts = online_params['num_rollouts_per_step'] * online_params['model_train_freq'] 944 | self._model_retain_epochs = online_params['model_retain_epochs'] 945 | self._steps_k = online_params['steps_k'] 946 | self._policy_update_steps = online_params['policy_update_steps'] 947 | self._model_train_freq = online_params['model_train_freq'] 948 | self._rollout_model_freq = online_params['rollout_model_freq'] 949 | self._train_policy_every = online_params['train_policy_every'] 950 | 951 | def _load_model_buffer_into_policy(self, new_buffer_size=None): 952 | """ 953 | Method to load model replay buffer into the policy (seed for model-free training) 954 | """ 955 | memory, memory_val = self.model.model.memory, self.model.model.memory_val 956 | if not new_buffer_size: 957 | new_buffer_size = int(len(memory) + len(memory_val)) 958 | new_pool = ReplayPool(capacity=new_buffer_size) 959 | train_transitions = memory.get_all() 960 | val_transitions = memory_val.get_all() 961 | all_transitions = train_transitions + val_transitions 962 | for t in all_transitions: 963 | new_pool.push(t) 964 | print("Reallocating policy replay buffer as world model memory") 965 | self.agent.replay_pool = new_pool 966 | 967 | def train_policy_model_free(self, n_random_actions=0, update_timestep=1, n_collect_steps=0, log_interval=1000, 968 | use_model_buffer=True, total_steps=3e6, policy_buffer_size=None, clear_buffer=False, 969 | use_modified_env=False, horizon=None): 970 | """ 971 | Method to train the internal policy in a model-free setting 972 | """ 973 | 974 | if not use_modified_env: 975 | env = self.model.eval_env 976 | else: 977 | # Using modified environment! TODO: make this more general, this is just half-cheetah right now. 978 | from modified_envs import HalfCheetahEnv 979 | print('Using modified environments!') 980 | env = HalfCheetahEnv() 981 | horizon = 1000 982 | agent = self.agent 983 | 984 | if use_model_buffer and not clear_buffer: 985 | self._load_model_buffer_into_policy(new_buffer_size=policy_buffer_size) 986 | 987 | if clear_buffer: 988 | self.agent.replay_pool = ReplayPool(capacity=1e6) 989 | # Recollect 5000 new transitions. 990 | n_collect_steps = 5000 991 | 992 | avg_length = 0 993 | time_step = 0 994 | cumulative_timestep = 0 995 | cumulative_log_timestep = 0 996 | n_updates = 0 997 | i_episode = 0 998 | log_episode = 0 999 | samples_number = 0 1000 | episode_rewards = [] 1001 | episode_steps = [] 1002 | all_rewards = [] 1003 | all_timesteps = [] 1004 | all_lengths = [] 1005 | 1006 | random.seed(self._seed) 1007 | torch.manual_seed(self._seed) 1008 | np.random.seed(self._seed) 1009 | env.seed(self._seed) 1010 | env.action_space.np_random.seed(self._seed) 1011 | 1012 | max_steps = horizon if horizon is not None else env.spec.max_episode_steps 1013 | 1014 | while samples_number < total_steps: 1015 | time_step = 0 1016 | episode_reward = 0 1017 | i_episode += 1 1018 | log_episode += 1 1019 | state = env.reset() 1020 | done = False 1021 | while (not done): 1022 | cumulative_log_timestep += 1 1023 | cumulative_timestep += 1 1024 | time_step += 1 1025 | samples_number += 1 1026 | if samples_number < n_random_actions: 1027 | action = env.action_space.sample() 1028 | else: 1029 | action = agent.get_action(state) 1030 | nextstate, reward, done, _ = env.step(action) 1031 | # Terminate if over horizon 1032 | if horizon is not None and time_step == horizon: 1033 | done = True 1034 | # if we hit the time-limit, it's not a 'real' done; we don't want to assign low value to those states 1035 | real_done = False if time_step == max_steps else done 1036 | agent.replay_pool.push(Transition(state, action, reward, nextstate, real_done)) 1037 | state = nextstate 1038 | episode_reward += reward 1039 | # update if it's time 1040 | if cumulative_timestep % update_timestep == 0 and cumulative_timestep > n_collect_steps: 1041 | q1_loss, q2_loss, pi_loss, a_loss = agent.optimize(update_timestep, 1042 | augment_data=self._augment_offline_data) 1043 | n_updates += 1 1044 | # logging 1045 | if cumulative_timestep % log_interval == 0 and cumulative_timestep > n_collect_steps: 1046 | avg_length = np.mean(episode_steps) 1047 | running_reward = np.mean(episode_rewards) 1048 | eval_reward = self.test_agent(n_evals=1).mean() 1049 | print( 1050 | 'Episode {} \t Samples {} \t Avg length: {} \t Test reward: {} \t Train reward: {} \t Number of Policy Updates: {}'.format( 1051 | i_episode, samples_number, avg_length, eval_reward, running_reward, n_updates)) 1052 | episode_steps = [] 1053 | episode_rewards = [] 1054 | all_timesteps.append(cumulative_timestep) 1055 | all_rewards.append(eval_reward) 1056 | all_lengths.append(avg_length) 1057 | df = pd.DataFrame( 1058 | {'Timesteps': all_timesteps, 'Reward': all_rewards, 'Average_Length': all_lengths}) 1059 | save_name = "model_free_{}_seed{}".format(self._params['env_name'], str(self._params['seed'])) 1060 | save_name += '.csv' 1061 | df.to_csv(save_name) 1062 | 1063 | episode_steps.append(time_step) 1064 | episode_rewards.append(episode_reward) 1065 | 1066 | def _get_morel_threshold(self): 1067 | """ Uses a UCB heuristic similar to author's paper to calculate the threshold except with robust statistics """ 1068 | # Interpolates between the median and the 99th percentile penalty values of the offline data 1069 | preds = self.model.model.get_replay_buffer_predictions(only_validation=True, return_sample=True) 1070 | sample, mu, logvar = preds.chunk(3, dim=2) 1071 | 1072 | sample_nextstates, sample_rewards, states_mu, rewards_mu, logvar_states, logvar_rewards = sample[:,:,:-1], sample[:,:,-1], mu[:,:,:-1], mu[:,:,-1], logvar[:,:,:-1], logvar[:,:,-1] 1073 | 1074 | allocation = torch.randint(0, self.model.model.num_elites, (preds.shape[1],), device=device) 1075 | allocation = torch.tensor([self.model.model._elites_idx[idx] for idx in allocation]).to(device) 1076 | allocation_states = allocation.repeat(sample_nextstates.shape[2], 1).T.view(1, -1, sample_nextstates.shape[2]) 1077 | allocation_rewards = allocation.view(1, -1, 1) 1078 | 1079 | nextstates = sample_nextstates.gather(0, allocation_states).squeeze() 1080 | rewards = sample_rewards.unsqueeze(2).gather(0, allocation_rewards).squeeze() 1081 | 1082 | penalties = self.model.model.get_penalty(self._params['mopo_penalty_type'], nextstates, sample_nextstates, logvar_states, rewards, sample_rewards, logvar_rewards, states_mu, rewards_mu, allocation) 1083 | 1084 | penalties_median, penalties_p99 = penalties.median(), penalties.quantile(.99) 1085 | 1086 | penalties_mad = (penalties - penalties_median).abs().median() 1087 | 1088 | beta_max = (penalties_p99 - penalties_median) / penalties_mad 1089 | 1090 | beta = beta_max * self._params['morel_thresh'] 1091 | 1092 | return (penalties_median + beta * penalties_mad).item() 1093 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import random 3 | import math 4 | from collections import deque, namedtuple 5 | from typing import List 6 | import time 7 | import sys 8 | from copy import deepcopy 9 | import pickle 10 | import datetime 11 | import pdb 12 | import time 13 | 14 | import gym 15 | from gym.wrappers import TimeLimit 16 | import numpy as np 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | from torch.utils.data import DataLoader, Dataset 21 | from torch.utils.data.sampler import SequentialSampler 22 | import math 23 | 24 | from utils import MeanStdevFilter, prepare_data, reward_func, GaussianMSELoss, truncated_normal_init, Transition, \ 25 | ReplayPool, check_or_make_folder, FasterReplayPool 26 | 27 | ### Model 28 | 29 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 30 | 31 | 32 | class EnsembleGymEnv(gym.Env): 33 | """Wraps the Ensemble with a gym API, Outputs Normal states, and contains a copy of the true environment""" 34 | 35 | def __init__(self, params, env, eval_env, due_override=None): 36 | super(EnsembleGymEnv, self).__init__() 37 | self.name = params['env_name'] 38 | self.real_env = env 39 | self.eval_env = eval_env 40 | self.observation_space = self.real_env.observation_space 41 | self.action_space = self.real_env.action_space 42 | params['action_space'] = self.real_env.action_space.shape[0] 43 | params['observation_space'] = self.real_env.observation_space.shape[0] 44 | self.model = Ensemble(params, due_override=due_override) 45 | self.current_state = self.reset() 46 | self.reward_head = params['reward_head'] 47 | self.reward_func = reward_func 48 | self.action_bounds = self.get_action_bounds() 49 | self.spec = self.real_env.spec 50 | self._elapsed_steps = 0 51 | self._max_timesteps = self.real_env.env.spec.max_episode_steps 52 | torch.manual_seed(params['seed']) 53 | 54 | def get_action_bounds(self): 55 | Bounds = namedtuple('Bounds', ('lowerbound', 'upperbound')) 56 | lb = self.real_env.action_space.low 57 | ub = self.real_env.action_space.high 58 | return Bounds(lowerbound=lb, upperbound=ub) 59 | 60 | def seed(self, seed=None): 61 | return self.real_env.seed(seed) 62 | 63 | def train_model(self, max_epochs, n_samples: int = 200000, d4rl_init=False, save_model=False, 64 | min_model_epochs=None): 65 | self.model.train_model( 66 | max_epochs=max_epochs, 67 | n_samples=n_samples, 68 | d4rl_init=d4rl_init, 69 | save_model=save_model, 70 | min_model_epochs=min_model_epochs) 71 | 72 | def step(self, action): 73 | action = np.clip(action, self.action_bounds.lowerbound, self.action_bounds.upperbound) 74 | next_state, reward = self.model.predict_state( 75 | self.current_state.reshape(1, -1), 76 | action.reshape(1, -1)) 77 | if not reward: 78 | reward = self.reward_func( 79 | self.current_state, 80 | next_state, 81 | action, 82 | self.name, 83 | is_done_func=self.model.is_done_func) 84 | if self._elapsed_steps > self._max_timesteps: 85 | done = True 86 | elif self.model.is_done_func: 87 | done = self.model.is_done_func(torch.Tensor(next_state).reshape(1, -1)).item() 88 | else: 89 | done = False 90 | self.current_state = next_state 91 | self._elapsed_steps += 1 92 | return next_state, reward, done, {} 93 | 94 | def reset(self): 95 | self.current_state = self.eval_env.reset() 96 | self._elapsed_steps = 0 97 | return self.current_state 98 | 99 | def render(self, mode='human'): 100 | raise NotImplementedError 101 | 102 | def update_state_filter(self, new_state): 103 | self.model.state_filter.update(new_state) 104 | 105 | def update_action_filter(self, new_action): 106 | self.model.action_filter.update(new_action) 107 | 108 | def convert_filter_to_torch(self): 109 | self.model.state_filter.update_torch() 110 | self.model.action_filter.update_torch() 111 | 112 | 113 | class TransitionDataset(Dataset): 114 | ## Dataset wrapper for sampled transitions 115 | def __init__(self, batch: Transition, state_filter, action_filter): 116 | state_action_filtered, delta_filtered = prepare_data( 117 | batch.state, 118 | batch.action, 119 | batch.nextstate, 120 | state_filter, 121 | action_filter) 122 | self.data_X = torch.Tensor(state_action_filtered) 123 | self.data_y = torch.Tensor(delta_filtered) 124 | self.data_r = torch.Tensor(np.array(batch.reward)) 125 | 126 | def __len__(self): 127 | return len(self.data_X) 128 | 129 | def __getitem__(self, index): 130 | return self.data_X[index], self.data_y[index], self.data_r[index] 131 | 132 | 133 | class EnsembleTransitionDataset(Dataset): 134 | ## Dataset wrapper for sampled transitions 135 | def __init__(self, batch: Transition, state_filter, action_filter, n_models=1): 136 | state_action_filtered, delta_filtered = prepare_data( 137 | batch.state, 138 | batch.action, 139 | batch.nextstate, 140 | state_filter, 141 | action_filter) 142 | data_count = state_action_filtered.shape[0] 143 | # idxs = np.arange(0,data_count)[None,:].repeat(n_models, axis=0) 144 | # [np.random.shuffle(row) for row in idxs] 145 | idxs = np.random.randint(data_count, size=[n_models, data_count]) 146 | self._n_models = n_models 147 | self.data_X = torch.Tensor(state_action_filtered[idxs]) 148 | self.data_y = torch.Tensor(delta_filtered[idxs]) 149 | self.data_r = torch.Tensor(np.array(batch.reward)[idxs]) 150 | 151 | def __len__(self): 152 | return self.data_X.shape[1] 153 | 154 | def __getitem__(self, index): 155 | return self.data_X[:, index], self.data_y[:, index], self.data_r[:, index] 156 | 157 | 158 | class Ensemble(object): 159 | def __init__(self, params, due_override=None): 160 | 161 | self.params = params 162 | self.models = {i: Model(input_dim=params['ob_dim'] + params['ac_dim'], 163 | output_dim=params['ob_dim'] + params['reward_head'], 164 | is_probabilistic=params['logvar_head'], 165 | is_done_func=params['is_done_func'], 166 | reward_head=params['reward_head'], 167 | seed=params['seed'] + i, 168 | l2_reg_multiplier=params['l2_reg_multiplier'], 169 | num=i) 170 | for i in range(params['num_models'])} 171 | self.elites = {i: self.models[i] for i in range(params['num_elites'])} 172 | self._elites_idx = list(range(params['num_elites'])) 173 | self.num_models = params['num_models'] 174 | self.num_elites = params['num_elites'] 175 | self.output_dim = params['ob_dim'] + params['reward_head'] 176 | self.ob_dim = params['ob_dim'] 177 | self.memory = FasterReplayPool(action_dim=params['ac_dim'], state_dim=params['ob_dim'], 178 | capacity=params['train_memory']) 179 | self.memory_val = FasterReplayPool(action_dim=params['ac_dim'], state_dim=params['ob_dim'], 180 | capacity=params['val_memory']) 181 | self.train_val_ratio = params['train_val_ratio'] 182 | self.is_done_func = params['is_done_func'] 183 | self.is_probabilistic = params['logvar_head'] 184 | self._model_lr = params['model_lr'] if 'model_lr' in params else 0.001 185 | weights = [weight for model in self.models.values() for weight in model.weights] 186 | if self.is_probabilistic: 187 | self.max_logvar = torch.full((self.output_dim,), 0.5, requires_grad=True, device=device) 188 | self.min_logvar = torch.full((self.output_dim,), -10.0, requires_grad=True, device=device) 189 | weights.append({'params': [self.max_logvar]}) 190 | weights.append({'params': [self.min_logvar]}) 191 | self.set_model_logvar_limits() 192 | self.optimizer = torch.optim.Adam(weights, lr=self._model_lr) 193 | self._lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.3, verbose=False) 194 | self._model_id = "Model_{}_seed{}_{}".format(params['env_name'], params['seed'], 195 | datetime.datetime.now().strftime('%Y_%m_%d_%H-%M-%S')) 196 | self._env_name = params['env_name'] 197 | self.state_filter = MeanStdevFilter(params['observation_space']) 198 | self.action_filter = MeanStdevFilter(params['action_space']) 199 | self.use_automatic_lam_tuning = params['tune_mopo_lam'] 200 | 201 | self._removed_models = [] 202 | 203 | if params['mopo']: 204 | # this value is ignored if tuning is used 205 | self.mopo_lam = params['mopo_lam'] 206 | 207 | if self.use_automatic_lam_tuning: 208 | self.target_uncertainty = params['mopo_uncertainty_target'] 209 | self.log_mopo_lam = torch.full((1,), 1.0, requires_grad=True, device=device) 210 | self.lam_optimizer = torch.optim.Adam([self.log_mopo_lam], lr=0.01) 211 | else: 212 | self.mopo_lam = 0 213 | self.due_override = due_override 214 | 215 | self._morel_thresh = None 216 | self._morel_halt_reward = 0 217 | 218 | # def reset_models(self): 219 | # params = self.params 220 | # self.models = {i:Model(input_dim=params['ob_dim'] + params['ac_dim'], 221 | # output_dim=params['ob_dim'] + params['reward_head'], 222 | # is_probabilistic = params['logvar_head'], 223 | # is_done_func = params['is_done_func'], 224 | # reward_head = params['reward_head'], 225 | # seed = params['seed'] + i, 226 | # num=i) 227 | # for i in range(params['num_models'])} 228 | 229 | def forward(self, x: torch.Tensor): 230 | model_index = int(np.random.uniform() * len(self.models.keys())) 231 | return self.models[model_index].forward(x) 232 | 233 | def predict_state(self, state: np.array, action: np.array) -> (np.array, float): 234 | model_index = int(np.random.uniform() * len(self.models.keys())) 235 | return self.models[model_index].predict_state(state, action, self.state_filter, self.action_filter) 236 | 237 | def predict_state_at(self, state: np.array, action: np.array, model_index: int) -> (np.array, float): 238 | return self.models[model_index].predict_state(state, action, self.state_filter, self.action_filter) 239 | 240 | def add_data(self, step: Transition): 241 | # for step in rollout: 242 | self.memory.push(step) 243 | 244 | def add_data_validation(self, step: Transition): 245 | # for step in rollout: 246 | self.memory_val.push(step) 247 | 248 | def get_next_states_rewards(self, state, action, get_var=False, deterministic=False, return_mean=False): 249 | if self.mopo_lam == 0: 250 | nextstates_rewards = [ 251 | elite.get_next_state_reward(state, action, self.state_filter, self.action_filter, get_var, 252 | deterministic, return_mean) for elite 253 | in self.elites.values()] 254 | else: 255 | nextstates_rewards = [ 256 | model.get_next_state_reward(state, action, self.state_filter, self.action_filter, get_var, 257 | deterministic, return_mean) for model 258 | in self.models.values()] 259 | nextstates_all = torch.stack([torch.cat(nsr[0], dim=1) if get_var else nsr[0] for nsr in nextstates_rewards]) 260 | rewards_all = torch.stack([torch.cat(nsr[1], dim=1) if get_var else nsr[1] for nsr in nextstates_rewards]) 261 | return nextstates_all, rewards_all 262 | 263 | def get_max_state_action_uncertainty(self, state, action): 264 | state_action_uncertainties = [ 265 | elite.get_state_action_uncertainty(state, action, self.state_filter, self.action_filter) for 266 | elite in self.elites.values()] 267 | return torch.max(torch.stack([torch.norm(logvar.exp(), p='fro') for logvar in state_action_uncertainties])) 268 | 269 | #TODO: Use this method in the rollout code 270 | def get_elites_prediction(self, state, action): 271 | """ Gets a next state prediction according to the Elites """ 272 | 273 | allocation = torch.randint(0, self.num_elites, (state.shape[0],), device=device) 274 | 275 | return_var = True 276 | 277 | allocation = torch.tensor([self._elites_idx[idx] for idx in allocation]).to(device) 278 | # needs to get logvar for MOPO penalty 279 | get_var = True 280 | 281 | allocation_states = allocation.repeat(self.ob_dim, 1).T.view(1, -1, self.ob_dim) 282 | allocation_rewards = allocation.view(1, -1, 1) 283 | 284 | # need the parametric mean of the next states for the LOMPO and M2AC uncertainty metrics 285 | return_mean = True 286 | 287 | with torch.no_grad(): 288 | 289 | nextstates_all, rewards_all = self.get_next_states_rewards(state, action, get_var, False, return_mean) 290 | 291 | nextstates_all, nextstates_all_mu, nextstates_logvar = nextstates_all.chunk(3, dim=2) 292 | rewards_all, rewards_all_mu, rewards_logvar = rewards_all.chunk(3, dim=2) 293 | nextstates = nextstates_all.gather(0, allocation_states).squeeze() 294 | rewards = rewards_all.gather(0, allocation_rewards).squeeze() 295 | 296 | return {'nextstates': nextstates, 'nextstates_all': nextstates_all, 'nextstates_logvar': nextstates_logvar, 'rewards': rewards, 'rewards_all': rewards_all, 'rewards_logvar': rewards_logvar, 'nextstates_mu': nextstates_all_mu, 'rewards_mu': rewards_all_mu, 'allocation': allocation} 297 | 298 | def get_all_penalties_state_action(self, state, action, nextstates_info:dict=None): 299 | """Gets all the penalties (and predicted next state) from a given state-action tuple""" 300 | penalty_list = ['mopo_paper', 'morel', 'lompo', 'm2ac', 'ensemble_var', 'ensemble_std'] 301 | 302 | if not nextstates_info: 303 | nextstates_info = self.get_elites_prediction(state, action) 304 | 305 | with torch.no_grad(): 306 | penalties = {} 307 | 308 | for p in penalty_list: 309 | penalties[p] = self.get_penalty(p, **nextstates_info).cpu().numpy() 310 | 311 | return penalties, nextstates_info['nextstates'] 312 | 313 | def set_morel_hparams(self, morel_threshold, morel_halt_reward): 314 | self._morel_thresh, self._morel_halt_reward = morel_threshold, morel_halt_reward 315 | print("Set MOReL penalty threshold to {}, and negative HALT reward to {}".format(self._morel_thresh, self._morel_halt_reward)) 316 | 317 | def get_penalty(self, penalty_name: str, nextstates, nextstates_all, nextstates_logvar, rewards, rewards_all, rewards_logvar, nextstates_mu, rewards_mu, allocation): 318 | """Gets the penalty depending on name""" 319 | 320 | if penalty_name == 'mopo_default': 321 | nextstates_std = nextstates_logvar.exp().sqrt() 322 | mopo_pen = nextstates_std.norm(2,2).amax(0) 323 | elif penalty_name == 'mopo_paper': 324 | all_std = torch.cat((nextstates_logvar, rewards_logvar), dim=2).exp().sqrt() 325 | mopo_pen = all_std.norm(2,2).amax(0) 326 | elif penalty_name == 'ensemble_var': 327 | nextstates_var = nextstates_logvar.exp() 328 | mean_of_vars = torch.mean(nextstates_var, dim=0) 329 | var_of_means = torch.var(nextstates_all, dim=0) 330 | vr = mean_of_vars + var_of_means 331 | mopo_pen = torch.mean(vr, dim=1) 332 | elif penalty_name == 'ensemble_std': 333 | nextstates_var = nextstates_logvar.exp() 334 | mean_of_vars = torch.mean(nextstates_var, dim=0) 335 | var_of_means = torch.var(nextstates_all, dim=0) 336 | std = (mean_of_vars + var_of_means).sqrt() 337 | mopo_pen = torch.mean(std, dim=1) 338 | elif penalty_name == 'ensemble_var_rew': 339 | rewards_var = rewards_logvar.exp() 340 | mean_of_vars = torch.mean(rewards_var, dim=0) 341 | var_of_means = torch.var(rewards_all, dim=0) 342 | vr = mean_of_vars + var_of_means 343 | mopo_pen = torch.mean(vr, dim=1) 344 | elif penalty_name == 'ensemble_var_comb': 345 | nextstates_var = nextstates_logvar.exp() 346 | mean_of_vars = torch.mean(nextstates_var, dim=0) 347 | var_of_means = torch.var(nextstates_all, dim=0) 348 | vr = mean_of_vars + var_of_means 349 | mopo_pen1 = torch.mean(vr, dim=1) 350 | rewards_var = rewards_logvar.exp() 351 | mean_of_vars = torch.mean(rewards_var, dim=0) 352 | var_of_means = torch.var(rewards_all, dim=0) 353 | vr = mean_of_vars + var_of_means 354 | mopo_pen2 = torch.mean(vr, dim=1) 355 | mopo_pen = mopo_pen1 + mopo_pen2 356 | elif penalty_name == 'morel': 357 | # first try naive nested for loop 358 | # mopo_pen = torch.zeros_like(rewards) 359 | # for i, ns_i in enumerate(nextstates_all): 360 | # for j, ns_j in enumerate(nextstates_all): 361 | # # only need upper right triangle 362 | # if j > i: 363 | # mopo_pen = torch.max(mopo_pen, (ns_i - ns_j).norm(2,1)) 364 | 365 | # this parallelises the above code, runs 10x faster, and is 10x less readable 366 | mopo_pen = torch.cdist(nextstates_all.swapaxes(0,1), nextstates_all.swapaxes(0,1), 2).amax((1,2)) 367 | elif penalty_name == 'lompo': 368 | # Several steps: 369 | # 1) Get the next states and rewards from models 370 | # 2) Now sample a random next state from the uniform 371 | # 3) For the next state, pass it through the Gaussian induced by each NN to get a log likelihood 372 | # 4) Now take the variance of that likelihood over the models for the sampled next state 373 | # 5) Average over each reward, i.e.,: \hat{r} = 1/K * \sum_{i=1}^K [r_i(s_t,a_t)] - \lambda * u_i(s_t,a_t) 374 | # Let's not average reward for now to keep it consistent; we just care about uncertainty metric 375 | # rewards = rewards_all.mean(0) 376 | mus, stds = nextstates_mu, nextstates_logvar.exp().sqrt() 377 | dist = torch.distributions.Normal(mus, stds) 378 | ll = dist.log_prob(nextstates).sum(2) 379 | # the penalty is then just the variance of the log likelihoods, averaged across each next state prediction 380 | mopo_pen = ll.var(0) 381 | elif penalty_name == 'm2ac': 382 | # Several steps: 383 | # Let's assume they only did this on the Elites, but no implementation so can't be sure 384 | # 1) Figure out which model was allocated, and what elites remain 385 | # 2) Merge the remaining model Gaussians, leaving out the model that was allocated 386 | # 3) Calculate the KL divergence between the model that was allocated, and the merged Gaussian from step 2 387 | # We did 1) already, now for... 388 | # Merging the remaining Gaussians 389 | allocation_states = allocation.repeat(self.ob_dim, 1).T.view(1, -1, self.ob_dim) 390 | allocation_rewards = allocation.view(1, -1, 1) 391 | # Need to come up with allocations that: 392 | # 1) Exclude non-elites # actually scrap this 393 | # 2) Exclude the allocated model so we can do OvR metric 394 | # TODO: Maybe this is inefficient, is there some way to vectorize? 395 | # exclude_dic = {a: torch.tensor([idx for idx in self._elites_idx if idx != a]).to(device) for a in self._elites_idx} 396 | exclude_dic = {a: torch.tensor([idx for idx in range(self.num_models) if idx != a]).to(device) for a in range(self.num_models)} 397 | allocation_ovr = torch.stack([exclude_dic[i] for i in allocation.cpu().numpy()]) 398 | allocation_ovr_states = allocation_ovr.unsqueeze(2).repeat(1,1,self.ob_dim*2).swapaxes(1,0) 399 | allocation_ovr_rewards = allocation_ovr.unsqueeze(2).repeat(1,1,2).swapaxes(1,0) 400 | nextstates_mu_logvar = torch.cat((nextstates_mu, nextstates_logvar), dim=2) 401 | rewards_mu_logvar = torch.cat((rewards_mu, rewards_logvar), dim=2) 402 | nextstates_mu_alloc, nextstates_logvar_alloc = nextstates_mu_logvar.gather(0, allocation_states.repeat(1,1,2)).squeeze(0).chunk(2, dim=1) 403 | nextstates_mu_ovr, nextstates_logvar_ovr = nextstates_mu_logvar.gather(0, allocation_ovr_states).squeeze(0).chunk(2, dim=2) 404 | rewards_mu_alloc, rewards_logvar_alloc = rewards_mu_logvar.gather(0, allocation_rewards.repeat(1,1,2)).squeeze(0).chunk(2, dim=1) 405 | rewards_mu_ovr, rewards_logvar_ovr = rewards_mu_logvar.gather(0, allocation_ovr_rewards).squeeze(0).chunk(2, dim=2) 406 | nsr_mu_alloc = torch.cat((nextstates_mu_alloc, rewards_mu_alloc), dim=1) 407 | nsr_logvar_alloc = torch.cat((nextstates_logvar_alloc, rewards_logvar_alloc), dim=1) 408 | nsr_mu_ovr = torch.cat((nextstates_mu_ovr, rewards_mu_ovr), dim=2) 409 | nsr_logvar_ovr = torch.cat((nextstates_logvar_ovr, rewards_logvar_ovr), dim=2) 410 | merge_mu = nsr_mu_ovr.mean(0) 411 | merge_std = ((nsr_logvar_ovr.exp() + nsr_mu_ovr**2).mean(0) - merge_mu**2).clamp(min=1e-8).sqrt() 412 | alloc_gaussian = torch.distributions.Normal(nsr_mu_alloc, nsr_logvar_alloc.exp().sqrt()) 413 | merge_gaussian = torch.distributions.Normal(merge_mu, merge_std) 414 | mopo_pen = torch.distributions.kl_divergence(alloc_gaussian, merge_gaussian).sum(1) 415 | else: 416 | raise NotImplementedError 417 | 418 | return mopo_pen 419 | 420 | def random_env_step(self, state, action, get_var=False, deterministic=False, disable_mopo=False,IRL = False): 421 | """Randomly allocate the data through the different dynamics models""" 422 | 423 | allocation = torch.randint(0, self.num_elites, (state.shape[0],), device=device) 424 | 425 | return_var = get_var 426 | 427 | if self.mopo_lam != 0 or self._morel_thresh: 428 | # converts elite index into all-models index 429 | allocation = torch.tensor([self._elites_idx[idx] for idx in allocation]).to(device) 430 | # needs to get logvar for MOPO penalty 431 | get_var = True 432 | 433 | allocation_states = allocation.repeat(self.ob_dim, 1).T.view(1, -1, self.ob_dim) 434 | allocation_rewards = allocation.view(1, -1, 1) 435 | 436 | # need the parametric mean of the next states for the LOMPO and M2AC uncertainty metrics 437 | return_mean = True if self.params['mopo_penalty_type'] in ['lompo', 'm2ac'] else False 438 | 439 | nextstates_all, rewards_all = self.get_next_states_rewards(state, action, get_var, deterministic, return_mean) 440 | 441 | if get_var: 442 | if return_mean: 443 | nextstates_all, nextstates_all_mu, nextstates_logvar = nextstates_all.chunk(3, dim=2) 444 | rewards_all, rewards_all_mu, rewards_logvar = rewards_all.chunk(3, dim=2) 445 | else: 446 | nextstates_all, nextstates_logvar = nextstates_all.chunk(2, dim=2) 447 | rewards_all, rewards_logvar = rewards_all.chunk(2, dim=2) 448 | nextstates_all_mu = None 449 | rewards_all_mu = None 450 | nextstates = nextstates_all.gather(0, allocation_states).squeeze(0) 451 | rewards = rewards_all.gather(0, allocation_rewards).squeeze(0).squeeze(1) 452 | 453 | if disable_mopo: 454 | mopo_lam = 0 455 | morel_thresh = 0 456 | else: 457 | mopo_lam = self.mopo_lam 458 | morel_thresh = self._morel_thresh 459 | 460 | # TODO: Handle 1 dimensional states 461 | if mopo_lam != 0 or morel_thresh: 462 | 463 | mopo_pen = self.get_penalty(self.params['mopo_penalty_type'], nextstates, nextstates_all, nextstates_logvar, rewards, rewards_all, rewards_logvar, nextstates_all_mu, rewards_all_mu, allocation) 464 | 465 | if self.use_automatic_lam_tuning: 466 | mopo_lam = self.log_mopo_lam.exp() 467 | # save this for now because this method is called under no_grad 468 | self._uncertainty_diff_vector = (mopo_lam * mopo_pen - self.target_uncertainty).detach() 469 | else: 470 | mopo_lam = self.mopo_lam 471 | 472 | if mopo_lam != 0: 473 | rewards = rewards - mopo_lam * mopo_pen 474 | elif self._morel_thresh: 475 | rewards[mopo_pen > self._morel_thresh] += self._morel_halt_reward 476 | 477 | if self.params['env_name'] == 'AntMOPOEnv': 478 | rewards += 1 479 | if IRL: 480 | if return_var: 481 | nextstates_var = self.get_total_variance(nextstates_all, nextstates_logvar) 482 | rewards_var = self.get_total_variance(rewards_all, rewards_logvar) 483 | return (nextstates, nextstates_var), (rewards, rewards_var), -mopo_lam*mopo_pen 484 | return nextstates, rewards, -mopo_lam*mopo_pen 485 | 486 | if return_var: 487 | nextstates_var = self.get_total_variance(nextstates_all, nextstates_logvar) 488 | rewards_var = self.get_total_variance(rewards_all, rewards_logvar) 489 | return (nextstates, nextstates_var), (rewards, rewards_var), mopo_pen 490 | else: 491 | return nextstates, rewards, mopo_pen 492 | 493 | def get_mopo_pen(self, state, action, get_var=True, deterministic=True, mopo_penalty_type=None): 494 | """Randomly allocate the data through the different dynamics models""" 495 | nextstates_all, rewards_all = self.get_next_states_rewards(state, action, get_var, deterministic) 496 | 497 | if get_var: 498 | nextstates_all, nextstates_logvar = nextstates_all.chunk(2, dim=2) 499 | rewards_all, rewards_logvar = rewards_all.chunk(2, dim=2) 500 | 501 | if get_var: 502 | nextstates_all, nextstates_logvar = nextstates_all.chunk(2, dim=2) 503 | if self.mopo_lam != 0: 504 | if mopo_penalty_type == 'mopo_default': 505 | nextstates_std = nextstates_logvar.exp().sqrt() 506 | mopo_pen = nextstates_std.norm(2, 2).amax(0) 507 | elif mopo_penalty_type == 'ensemble_var': 508 | nextstates_var = nextstates_logvar.exp() 509 | mean_of_vars = torch.mean(nextstates_var, dim=0) 510 | var_of_means = torch.var(nextstates_all, dim=0) 511 | print(mean_of_vars.shape) 512 | print(var_of_means.shape) 513 | vr = mean_of_vars + var_of_means 514 | mopo_pen = torch.mean(vr, dim=1) 515 | elif mopo_penalty_type == 'ensemble_std': 516 | nextstates_var = nextstates_logvar.exp() 517 | mean_of_vars = torch.mean(nextstates_var, dim=0) 518 | var_of_means = torch.var(nextstates_all, dim=0) 519 | std = (mean_of_vars + var_of_means).sqrt() 520 | mopo_pen = torch.mean(std, dim=1) 521 | elif self.params['mopo_penalty_type'] == 'ensemble_var_rew': 522 | rewards_var = rewards_logvar.exp() 523 | mean_of_vars = torch.mean(rewards_var, dim=0) 524 | var_of_means = torch.var(rewards_all, dim=0) 525 | vr = mean_of_vars + var_of_means 526 | mopo_pen = torch.mean(vr, dim=1) 527 | elif mopo_penalty_type == 'due' or mopo_penalty_type == 'sparse_gp': 528 | combined = torch.cat((state, action), dim=1) 529 | combined = torch.tensor_split(combined, 50) 530 | mopo_pens = [] 531 | for c in combined: 532 | mopo_pens.append(self.due_override.predict_var(c)) 533 | mopo_pen = torch.cat(mopo_pens, dim=0) 534 | # mopo_pen = torch.mean(mopo_pen, dim=1) 535 | else: 536 | raise NotImplementedError 537 | 538 | return mopo_pen 539 | 540 | def update_lambda(self): 541 | lam_loss = (self.log_mopo_lam * self._uncertainty_diff_vector).mean() 542 | self.lam_optimizer.zero_grad() 543 | lam_loss.backward() 544 | self.lam_optimizer.step() 545 | 546 | @staticmethod 547 | def get_total_variance(mean_values, logvar_values): 548 | return (torch.var(mean_values, dim=0) + torch.mean(logvar_values.exp(), dim=0)).squeeze() 549 | 550 | def _get_validation_losses(self, validation_loader, get_weights=True): 551 | best_losses = [] 552 | best_weights = [] 553 | for model in self.models.values(): 554 | best_losses.append(model.get_validation_loss(validation_loader)) 555 | if get_weights: 556 | best_weights.append(deepcopy(model.state_dict())) 557 | best_losses = np.array(best_losses) 558 | return best_losses, best_weights 559 | 560 | def check_validation_losses(self, validation_loader): 561 | improved_any = False 562 | current_losses, current_weights = self._get_validation_losses(validation_loader, get_weights=True) 563 | improvements = ((self.current_best_losses - current_losses) / self.current_best_losses) > 0.01 564 | for i, improved in enumerate(improvements): 565 | if improved: 566 | self.current_best_losses[i] = current_losses[i] 567 | self.current_best_weights[i] = current_weights[i] 568 | improved_any = True 569 | return improved_any, current_losses 570 | 571 | def train_model(self, max_epochs: int = 100, n_samples: int = 200000, d4rl_init=False, save_model=False, 572 | min_model_epochs=None): 573 | self.current_best_losses = np.zeros( 574 | self.params['num_models']) + sys.maxsize # weird hack (YLTSI), there's almost surely a better way... 575 | self.current_best_weights = [None] * self.params['num_models'] 576 | val_improve = deque(maxlen=6) 577 | lr_lower = False 578 | min_model_epochs = 0 if not min_model_epochs else min_model_epochs 579 | if d4rl_init: 580 | # Train on the full buffer until convergence, should be under 5k epochs 581 | n_samples = len(self.memory) 582 | n_samples_val = len(self.memory_val) 583 | max_epochs = 1000 584 | elif len(self.memory) < n_samples: 585 | n_samples = len(self.memory) 586 | n_samples_val = len(self.memory_val) 587 | else: 588 | n_samples_val = int(np.floor((n_samples / (1 - self.train_val_ratio)) * (self.train_val_ratio))) 589 | 590 | samples_train = self.memory.sample(n_samples) 591 | samples_validate = self.memory_val.sample(n_samples_val) 592 | 593 | # TODO: shift the training and val dataset using the fn 'get_max_state_action_uncertainty' 594 | 595 | batch_size = 256 596 | if n_samples_val == len(self.memory_val): 597 | samples_validate = self.memory_val.sample_all() 598 | else: 599 | samples_validate = self.memory_val.sample(n_samples_val) 600 | ########## MIX VALDIATION AND TRAINING ########## 601 | new_samples_train_dict = dict.fromkeys(samples_train._fields) 602 | new_samples_validate_dict = dict.fromkeys(samples_validate._fields) 603 | randperm = np.random.permutation(n_samples + n_samples_val) 604 | 605 | train_idx, valid_idx = randperm[:n_samples], randperm[n_samples:] 606 | assert len(valid_idx) == n_samples_val 607 | 608 | for i, key in enumerate(samples_train._fields): 609 | train_vals = samples_train[i] 610 | valid_vals = samples_validate[i] 611 | all_vals = np.array(list(train_vals) + list(valid_vals)) 612 | train_vals = all_vals[train_idx] 613 | valid_vals = all_vals[valid_idx] 614 | new_samples_train_dict[key] = tuple(train_vals) 615 | new_samples_validate_dict[key] = tuple(valid_vals) 616 | 617 | samples_train = Transition(**new_samples_train_dict) 618 | samples_validate = Transition(**new_samples_validate_dict) 619 | ########## MIX VALDIATION AND TRAINING ########## 620 | transition_loader = DataLoader( 621 | EnsembleTransitionDataset(samples_train, self.state_filter, self.action_filter, n_models=self.num_models), 622 | shuffle=True, 623 | batch_size=batch_size, 624 | pin_memory=True 625 | ) 626 | validate_dataset = TransitionDataset(samples_validate, self.state_filter, self.action_filter) 627 | sampler = SequentialSampler(validate_dataset) 628 | validation_loader = DataLoader( 629 | validate_dataset, 630 | sampler=sampler, 631 | batch_size=batch_size, 632 | pin_memory=True 633 | ) 634 | 635 | ### check validation before first training epoch 636 | improved_any, iter_best_loss = self.check_validation_losses(validation_loader) 637 | val_improve.append(improved_any) 638 | best_epoch = 0 639 | model_idx = 0 640 | print('Epoch: %s, Total Loss: N/A' % (0)) 641 | print('Validation Losses:') 642 | print('\t'.join('M{}: {}'.format(i, loss) for i, loss in enumerate(iter_best_loss))) 643 | for i in range(max_epochs): 644 | t0 = time.time() 645 | total_loss = 0 646 | loss = 0 647 | step = 0 648 | # value to shuffle dataloader rows by so each epoch each model sees different data 649 | perm = np.random.choice(self.num_models, size=self.num_models, replace=False) 650 | for x_batch, diff_batch, r_batch in transition_loader: 651 | x_batch = x_batch[:, perm] 652 | diff_batch = diff_batch[:, perm] 653 | r_batch = r_batch[:, perm] 654 | step += 1 655 | for idx in range(self.num_models): 656 | loss += self.models[idx].train_model_forward(x_batch[:, idx], diff_batch[:, idx], r_batch[:, idx]) 657 | total_loss = loss.item() 658 | if self.is_probabilistic: 659 | loss += 0.01 * self.max_logvar.sum() - 0.01 * self.min_logvar.sum() 660 | self.optimizer.zero_grad() 661 | loss.backward() 662 | self.optimizer.step() 663 | loss = 0 664 | t1 = time.time() 665 | print("Epoch training took {} seconds".format(t1 - t0)) 666 | if (i + 1) % 1 == 0: 667 | improved_any, iter_best_loss = self.check_validation_losses(validation_loader) 668 | print('Epoch: {}, Total Loss: {}'.format(int(i + 1), float(total_loss))) 669 | print('Validation Losses:') 670 | print('\t'.join('M{}: {}'.format(i, loss) for i, loss in enumerate(iter_best_loss))) 671 | print('Best Validation Losses So Far:') 672 | print('\t'.join('M{}: {}'.format(i, loss) for i, loss in enumerate(self.current_best_losses))) 673 | val_improve.append(improved_any) 674 | if improved_any: 675 | best_epoch = (i + 1) 676 | print('Improvement detected this epoch.') 677 | else: 678 | epoch_diff = i + 1 - best_epoch 679 | plural = 's' if epoch_diff > 1 else '' 680 | print('No improvement detected this epoch: {} Epoch{} since last improvement.'.format(epoch_diff, 681 | plural)) 682 | if len(val_improve) > 5: 683 | if not any(np.array(val_improve)[1:]): 684 | # assert val_improve[0] 685 | if (i >= min_model_epochs): 686 | print('Validation loss stopped improving at %s epochs' % (best_epoch)) 687 | for model_index in self.models: 688 | self.models[model_index].load_state_dict(self.current_best_weights[model_index]) 689 | self._select_elites(validation_loader) 690 | if save_model: 691 | self._save_model() 692 | return 693 | elif not lr_lower: 694 | self._lr_scheduler.step() 695 | lr_lower = True 696 | val_improve = deque(maxlen=6) 697 | val_improve.append(True) 698 | print("Lowering Adam Learning for fine-tuning") 699 | t2 = time.time() 700 | print("Validation took {} seconds".format(t2 - t1)) 701 | self._select_elites(validation_loader) 702 | if save_model: 703 | self._save_model() 704 | 705 | def set_model_logvar_limits(self): 706 | if isinstance(self.max_logvar, dict): 707 | for i, model in enumerate(self.models.values()): 708 | model.model.update_logvar_limits(self.max_logvar[self._model_groups[i]], self.min_logvar[self._model_groups[i]]) 709 | else: 710 | for model in self.models.values(): 711 | model.model.update_logvar_limits(self.max_logvar, self.min_logvar) 712 | 713 | def _select_elites(self, validation_loader): 714 | val_losses, _ = self._get_validation_losses(validation_loader, get_weights=False) 715 | print('Sorting Models from most to least accurate...') 716 | models_val_rank = val_losses.argsort() 717 | val_losses.sort() 718 | print('\nModel validation losses: {}'.format(val_losses)) 719 | self.models = {i: self.models[idx] for i, idx in enumerate(models_val_rank)} 720 | self._elites_idx = list(range(self.num_elites)) 721 | self.elites = {i: self.models[j] for i, j in enumerate(self._elites_idx)} 722 | self.elite_errors = {i: val_losses[j] for i, j in enumerate(self._elites_idx)} 723 | print('\nSelected the following models as elites: {}'.format(self._elites_idx)) 724 | return val_losses 725 | 726 | def _save_model(self): 727 | """ 728 | Method to save model after training is completed 729 | """ 730 | print("Saving model checkpoint...") 731 | check_or_make_folder("./checkpoints") 732 | check_or_make_folder("./checkpoints/model_saved_weights") 733 | save_dir = "./checkpoints/model_saved_weights/{}".format(self._model_id) 734 | check_or_make_folder(save_dir) 735 | # Create a dictionary with pytorch objects we need to save, starting with models 736 | torch_state_dict = {'model_{}_state_dict'.format(i): w for i, w in enumerate(self.current_best_weights)} 737 | # Then add logvariance limit terms 738 | torch_state_dict['logvar_min'] = self.min_logvar 739 | torch_state_dict['logvar_max'] = self.max_logvar 740 | # Save Torch files 741 | torch.save(torch_state_dict, save_dir + "/torch_model_weights.pt") 742 | # Create a dict containing training and validation datasets 743 | data_state_dict = {'train_buffer': self.memory, 'valid_buffer': self.memory_val, 744 | 'state_filter': self.state_filter, 'action_filter': self.action_filter} 745 | # Then add validation performance for checking purposes during loading (i.e., make sure we got the same performance) 746 | data_state_dict['validation_performance'] = self.current_best_losses 747 | # Pickle the data dict 748 | pickle.dump(data_state_dict, open(save_dir + '/model_data.pkl', 'wb')) 749 | print("Saved model snapshot trained on {} datapoints".format(len(self.memory))) 750 | 751 | def load_model(self, model_dir): 752 | """ 753 | Method to load model from checkpoint folder 754 | """ 755 | # Check that the environment matches the dir name 756 | assert self._env_name.split('-')[ 757 | 0].lower() in model_dir.lower(), "Model loaded was not trained on this environment" 758 | 759 | print("Loading model from checkpoint...") 760 | import os 761 | model_dir = os.path.join(os.path.dirname(__file__), model_dir) 762 | print(model_dir) 763 | torch_state_dict = torch.load(model_dir + '/torch_model_weights.pt', map_location=device) 764 | for i in range(self.num_models): 765 | self.models[i].load_state_dict(torch_state_dict['model_{}_state_dict'.format(i)]) 766 | self.min_logvar = torch_state_dict['logvar_min'] 767 | self.max_logvar = torch_state_dict['logvar_max'] 768 | 769 | data_state_dict = pickle.load(open(model_dir + '/model_data.pkl', 'rb')) 770 | # Backwards Compatability 771 | if isinstance(data_state_dict['train_buffer'], ReplayPool): 772 | assert self.memory.capacity > len(data_state_dict['train_buffer']) 773 | assert self.memory_val.capacity > len(data_state_dict['valid_buffer']) 774 | all_train = data_state_dict['train_buffer'].sample_all()._asdict() 775 | all_train = Transition(**{k: np.stack(v) for k, v in all_train.items()}) 776 | all_valid = data_state_dict['valid_buffer'].sample_all()._asdict() 777 | all_valid = Transition(**{k: np.stack(v) for k, v in all_valid.items()}) 778 | self.memory.push(all_train) 779 | self.memory_val.push(all_valid) 780 | else: 781 | self.memory, self.memory_val = data_state_dict['train_buffer'], data_state_dict['valid_buffer'] 782 | self.state_filter, self.action_filter = data_state_dict['state_filter'], data_state_dict['action_filter'] 783 | 784 | # Confirm that we retrieve the checkpointed validation performance 785 | all_valid = self.memory_val.sample_all() 786 | validate_dataset = TransitionDataset(all_valid, self.state_filter, self.action_filter) 787 | sampler = SequentialSampler(validate_dataset) 788 | validation_loader = DataLoader( 789 | validate_dataset, 790 | sampler=sampler, 791 | batch_size=256, 792 | pin_memory=True 793 | ) 794 | 795 | val_losses = self._select_elites(validation_loader) 796 | self.set_model_logvar_limits() 797 | 798 | model_id = model_dir.split('/')[-1] 799 | self._model_id = model_id 800 | 801 | return val_losses 802 | 803 | # Doesn't work because we mix up val and train data 804 | # assert np.isclose(val_losses, data_state_dict['validation_performance']).all() 805 | 806 | def load_model_from_population(self, model_dirs): 807 | """ 808 | Method to load model from aggregated population 809 | """ 810 | 811 | # Aggregate the state dictionaries first and then relabel them 812 | aggregate_torch_state_dict = {} 813 | aggregated_logvar_min = {} 814 | aggregated_logvar_max = {} 815 | self._model_groups = {} 816 | for model_idx, model_dir in enumerate(model_dirs): 817 | # Check that the environment matches the dir name 818 | assert self._env_name.split('-')[ 819 | 0].lower() in model_dir.lower(), "Model loaded was not trained on this environment" 820 | torch_state_dict = torch.load(model_dir + '/torch_model_weights.pt', map_location=device) 821 | for key, value in torch_state_dict.items(): 822 | if 'model' in key: 823 | cur_idx = int(key.split('_')[1]) 824 | self._model_groups[model_idx * 7 + cur_idx] = model_idx 825 | relabelled_key = 'model_{}_state_dict'.format(model_idx * 7 + cur_idx) 826 | aggregate_torch_state_dict[relabelled_key] = value 827 | else: 828 | if key == 'logvar_min': 829 | aggregated_logvar_min[model_idx] = value 830 | elif key == 'logvar_max': 831 | aggregated_logvar_max[model_idx] = value 832 | 833 | for i in range(self.num_models): 834 | self.models[i].load_state_dict(aggregate_torch_state_dict['model_{}_state_dict'.format(i)]) 835 | self.min_logvar = aggregated_logvar_min 836 | self.max_logvar = aggregated_logvar_max 837 | 838 | # Takes the data from the first element of the population, shouldn't matter? 839 | data_state_dict = pickle.load(open(model_dirs[0] + '/model_data.pkl', 'rb')) 840 | # Backwards Compatability 841 | if isinstance(data_state_dict['train_buffer'], ReplayPool): 842 | assert self.memory.capacity > len(data_state_dict['train_buffer']) 843 | assert self.memory_val.capacity > len(data_state_dict['valid_buffer']) 844 | all_train = data_state_dict['train_buffer'].sample_all()._asdict() 845 | all_train = Transition(**{k: np.stack(v) for k, v in all_train.items()}) 846 | all_valid = data_state_dict['valid_buffer'].sample_all()._asdict() 847 | all_valid = Transition(**{k: np.stack(v) for k, v in all_valid.items()}) 848 | self.memory.push(all_train) 849 | self.memory_val.push(all_valid) 850 | else: 851 | self.memory, self.memory_val = data_state_dict['train_buffer'], data_state_dict['valid_buffer'] 852 | self.state_filter, self.action_filter = data_state_dict['state_filter'], data_state_dict['action_filter'] 853 | 854 | # Confirm that we retrieve the checkpointed validation performance 855 | all_valid = self.memory_val.sample_all() 856 | validate_dataset = TransitionDataset(all_valid, self.state_filter, self.action_filter) 857 | sampler = SequentialSampler(validate_dataset) 858 | validation_loader = DataLoader( 859 | validate_dataset, 860 | sampler=sampler, 861 | batch_size=1024, 862 | pin_memory=True 863 | ) 864 | 865 | self.set_model_logvar_limits() 866 | val_losses = self._select_elites(validation_loader) 867 | 868 | model_id = model_dirs[0].split('/')[-1] 869 | self._model_id = model_id 870 | 871 | return val_losses 872 | 873 | def remove_model(self): 874 | """ removes the least accurate model """ 875 | 876 | if self.num_models - 1 < self.num_elites: 877 | print("Can't remove any more models, otherwise we'll start removing elites") 878 | return 879 | 880 | self.num_models -= 1 881 | self._removed_models.append(self.models[self.num_models]) 882 | del self.models[self.num_models] 883 | print("Removed a model; you have {} models".format(self.num_models)) 884 | 885 | def add_models(self): 886 | """ adds back the most accurate removed model """ 887 | 888 | if len(self._removed_models) < 1: 889 | print("You haven't removed any models!") 890 | return 891 | 892 | model = self._removed_models.pop() 893 | self.models[self.num_models] = model 894 | self.num_models += 1 895 | print("Re-added a model; you have {} models".format(self.num_models)) 896 | 897 | def get_replay_buffer_predictions(self, only_validation=False, return_sample=False): 898 | """ Gets the predictions of all ensemble members on the data currently in the buffer """ 899 | buffer_data = self.memory_val.sample_all() 900 | if not only_validation: 901 | pass 902 | dataset = TransitionDataset(buffer_data, self.state_filter, self.action_filter) 903 | sampler = SequentialSampler(dataset) 904 | dataloader = DataLoader( 905 | dataset, 906 | sampler=sampler, 907 | batch_size=1024, 908 | pin_memory=True 909 | ) 910 | 911 | preds = torch.stack([m.get_predictions_from_loader(dataloader, return_sample=return_sample) for m in self.models.values()], 0) 912 | 913 | return preds 914 | 915 | class Model(nn.Module): 916 | def __init__(self, input_dim: int, 917 | output_dim: int, 918 | h: int = 1024, 919 | is_probabilistic=True, 920 | is_done_func=None, 921 | reward_head=1, 922 | seed=0, 923 | l2_reg_multiplier=1., 924 | num=0): 925 | 926 | super(Model, self).__init__() 927 | torch.manual_seed(seed) 928 | if is_probabilistic: 929 | self.model = BayesianNeuralNetwork(input_dim, output_dim, 200, is_done_func, reward_head, l2_reg_multiplier, 930 | seed) 931 | else: 932 | self.model = VanillaNeuralNetwork(input_dim, output_dim, h, is_done_func, reward_head, seed) 933 | self.is_probabilistic = self.model.is_probabilistic 934 | self.weights = self.model.weights 935 | self.reward_head = reward_head 936 | 937 | def forward(self, x: torch.Tensor): 938 | return self.model(x) 939 | 940 | def get_next_state_reward(self, state: torch.Tensor, action: torch.Tensor, state_filter, action_filter, 941 | keep_logvar=False, deterministic=False, return_mean=False): 942 | return self.model.get_next_state_reward(state, action, state_filter, action_filter, keep_logvar, 943 | deterministic, return_mean) 944 | 945 | def predict_state(self, state: np.array, action: np.array, state_filter, action_filter) -> (np.array, float): 946 | state, action = torch.Tensor(state).to(device), torch.Tensor(action).to(device) 947 | nextstate, reward = self.get_next_state_reward(state, action, state_filter, action_filter) 948 | nextstate = nextstate.detach().cpu().numpy() 949 | if self.reward_head: 950 | reward = reward.detach().cpu().item() 951 | return nextstate, reward 952 | 953 | def get_state_action_uncertainty(self, state: torch.Tensor, action: torch.Tensor, state_filter, action_filter): 954 | return self.model.get_state_action_uncertainty(state, action, state_filter, action_filter) 955 | 956 | def _train_model_forward(self, x_batch): 957 | self.model.train() 958 | self.model.zero_grad() 959 | x_batch = x_batch.to(device, non_blocking=True) 960 | y_pred = self.forward(x_batch) 961 | return y_pred 962 | 963 | def train_model_forward(self, x_batch, delta_batch, r_batch): 964 | delta_batch, r_batch = delta_batch.to(device, non_blocking=True), r_batch.to(device, non_blocking=True) 965 | y_pred = self._train_model_forward(x_batch) 966 | y_batch = torch.cat([delta_batch, r_batch.unsqueeze(dim=1)], dim=1) if self.reward_head else delta_batch 967 | loss = self.model.loss(y_pred, y_batch) 968 | return loss 969 | 970 | def get_predictions_from_loader(self, data_loader, return_targets = False, return_sample=False): 971 | self.model.eval() 972 | preds, targets = [], [] 973 | with torch.no_grad(): 974 | for x_batch_val, delta_batch_val, r_batch_val in data_loader: 975 | x_batch_val, delta_batch_val, r_batch_val = x_batch_val.to(device, 976 | non_blocking=True), delta_batch_val.to( 977 | device, non_blocking=True), r_batch_val.to(device, non_blocking=True) 978 | y_pred_val = self.forward(x_batch_val) 979 | preds.append(y_pred_val) 980 | if return_targets: 981 | y_batch_val = torch.cat([delta_batch_val, r_batch_val.unsqueeze(dim=1)], 982 | dim=1) if self.reward_head else delta_batch_val 983 | targets.append(y_batch_val) 984 | 985 | preds = torch.vstack(preds) 986 | 987 | if return_sample: 988 | mu, logvar = preds.chunk(2, dim=1) 989 | dist = torch.distributions.Normal(mu, logvar.exp().sqrt()) 990 | sample = dist.sample() 991 | preds = torch.cat((sample, preds), dim=1) 992 | 993 | if return_targets: 994 | targets = torch.vstack(targets) 995 | return preds, targets 996 | else: 997 | return preds 998 | 999 | def get_validation_loss(self, validation_loader): 1000 | self.model.eval() 1001 | preds, targets = self.get_predictions_from_loader(validation_loader, return_targets=True) 1002 | if self.is_probabilistic: 1003 | return self.model.loss(preds, targets, logvar_loss=False).item() 1004 | else: 1005 | return self.model.loss(preds, targets).item() 1006 | 1007 | def get_acquisition(self, rollout: List[Transition], state_filter, action_filter): 1008 | self.model.eval() 1009 | state = [] 1010 | action = [] 1011 | nextstate = [] 1012 | reward = [] 1013 | # for rollout in rollouts: 1014 | for step in rollout: 1015 | state.append(step.state) 1016 | action.append(step.action) 1017 | nextstate.append(step.nextstate) 1018 | reward.append(step.reward) 1019 | state = np.array(state) 1020 | action = np.array(action) 1021 | reward = np.array(reward).reshape(-1, 1) 1022 | state_action_filtered, delta_filtered = prepare_data( 1023 | state, 1024 | action, 1025 | nextstate, 1026 | state_filter, 1027 | action_filter) 1028 | y_true = np.concatenate((delta_filtered, reward), axis=1) if self.reward_head else delta_filtered 1029 | y_true, state_action_filtered = torch.Tensor(y_true).to(device), torch.Tensor(state_action_filtered).to(device) 1030 | y_pred = self.forward(state_action_filtered) 1031 | 1032 | if self.is_probabilistic: 1033 | loss = self.model.loss(y_pred, y_true, logvar_loss=False) 1034 | else: 1035 | loss = self.model.loss(y_pred, y_true) 1036 | 1037 | return float(loss.item()) 1038 | 1039 | 1040 | class VanillaNeuralNetwork(nn.Module): 1041 | def __init__(self, input_dim: int, 1042 | output_dim: int, 1043 | h: int = 1024, 1044 | is_done_func=None, 1045 | reward_head=True, 1046 | seed=0): 1047 | 1048 | super().__init__() 1049 | torch.manual_seed(seed) 1050 | self.network = nn.Sequential( 1051 | nn.Linear(input_dim, h), 1052 | nn.ReLU(), 1053 | nn.Linear(h, h), 1054 | nn.ReLU() 1055 | ) 1056 | self.delta = nn.Linear(h, output_dim) 1057 | params = list(self.network.parameters()) + list(self.delta.parameters()) 1058 | self.weights = params 1059 | self.to(device) 1060 | self.loss = nn.MSELoss() 1061 | self.is_done_func = is_done_func 1062 | self.reward_head = reward_head 1063 | 1064 | @property 1065 | def is_probabilistic(self): 1066 | return False 1067 | 1068 | def forward(self, x: torch.Tensor): 1069 | hidden = self.network(x) 1070 | delta = self.delta(hidden) 1071 | return delta 1072 | 1073 | @staticmethod 1074 | def filter_inputs(state, action, state_filter, action_filter): 1075 | state_f = state_filter.filter_torch(state) 1076 | action_f = action_filter.filter_torch(action) 1077 | state_action_f = torch.cat((state_f, action_f), dim=1) 1078 | return state_action_f 1079 | 1080 | def get_next_state_reward(self, state: torch.Tensor, action: torch.Tensor, state_filter, action_filter, 1081 | keep_logvar=False): 1082 | if keep_logvar: 1083 | raise Exception("This is a deterministic network, there is no logvariance prediction") 1084 | state_action_f = self.filter_inputs(state, action, state_filter, action_filter) 1085 | y = self.forward(state_action_f) 1086 | if self.reward_head: 1087 | diff_f = y[:, :-1] 1088 | reward = y[:, -1].unsqueeze(1) 1089 | else: 1090 | diff_f = y 1091 | reward = 0 1092 | diff = diff_f 1093 | nextstate = state + diff 1094 | return nextstate, reward 1095 | 1096 | def get_state_action_uncertainty(self, state: torch.Tensor, action: torch.Tensor, state_filter, action_filter): 1097 | raise Exception("This is a deterministic network, there is no logvariance prediction") 1098 | 1099 | 1100 | class BayesianNeuralNetwork(VanillaNeuralNetwork): 1101 | def __init__(self, input_dim: int, 1102 | output_dim: int, 1103 | h: int = 200, 1104 | is_done_func=None, 1105 | reward_head=True, 1106 | l2_reg_multiplier=1., 1107 | seed=0): 1108 | super().__init__(input_dim, 1109 | output_dim, 1110 | h, 1111 | is_done_func, 1112 | reward_head, 1113 | seed) 1114 | torch.manual_seed(seed) 1115 | del self.network 1116 | self.fc1 = nn.Linear(input_dim, h) 1117 | reinitialize_fc_layer_(self.fc1) 1118 | self.fc2 = nn.Linear(h, h) 1119 | reinitialize_fc_layer_(self.fc2) 1120 | self.fc3 = nn.Linear(h, h) 1121 | reinitialize_fc_layer_(self.fc3) 1122 | self.fc4 = nn.Linear(h, h) 1123 | reinitialize_fc_layer_(self.fc4) 1124 | self.use_blr = False 1125 | self.delta = nn.Linear(h, output_dim) 1126 | reinitialize_fc_layer_(self.delta) 1127 | self.logvar = nn.Linear(h, output_dim) 1128 | reinitialize_fc_layer_(self.logvar) 1129 | self.loss = GaussianMSELoss() 1130 | self.activation = nn.SiLU() 1131 | self.lambda_prec = 1.0 1132 | self.max_logvar = None 1133 | self.min_logvar = None 1134 | params = [] 1135 | self.layers = [self.fc1, self.fc2, self.fc3, self.fc4, self.delta, self.logvar] 1136 | self.decays = np.array([0.000025, 0.00005, 0.000075, 0.000075, 0.0001, 0.0001]) * l2_reg_multiplier 1137 | for layer, decay in zip(self.layers, self.decays): 1138 | params.extend(get_weight_bias_parameters_with_decays(layer, decay)) 1139 | self.weights = params 1140 | self.to(device) 1141 | 1142 | def get_l2_reg_loss(self): 1143 | l2_loss = 0 1144 | for layer, decay in zip(self.layers, self.decays): 1145 | for name, parameter in layer.named_parameters(): 1146 | if 'weight' in name: 1147 | l2_loss += parameter.pow(2).sum() / 2 * decay 1148 | return l2_loss 1149 | 1150 | def update_logvar_limits(self, max_logvar, min_logvar): 1151 | self.max_logvar, self.min_logvar = max_logvar, min_logvar 1152 | 1153 | @property 1154 | def is_probabilistic(self): 1155 | return True 1156 | 1157 | def forward(self, x: torch.Tensor): 1158 | x = self.activation(self.fc1(x)) 1159 | x = self.activation(self.fc2(x)) 1160 | x = self.activation(self.fc3(x)) 1161 | x = self.activation(self.fc4(x)) 1162 | delta = self.delta(x) 1163 | logvar = self.logvar(x) 1164 | # Taken from the PETS code to stabilise training 1165 | logvar = self.max_logvar - F.softplus(self.max_logvar - logvar) 1166 | logvar = self.min_logvar + F.softplus(logvar - self.min_logvar) 1167 | return torch.cat((delta, logvar), dim=1) 1168 | 1169 | def get_next_state_reward(self, state: torch.Tensor, action: torch.Tensor, state_filter, action_filter, 1170 | keep_logvar=False, deterministic=False, return_mean=False): 1171 | state_action_f = self.filter_inputs(state, action, state_filter, action_filter) 1172 | mu, logvar = self.forward(state_action_f).chunk(2, dim=1) 1173 | mu_orig = mu 1174 | if not deterministic: 1175 | dist = torch.distributions.Normal(mu, logvar.exp().sqrt()) 1176 | mu = dist.sample() 1177 | if self.reward_head: 1178 | mu_diff_f = mu[:, :-1] 1179 | logvar_diff_f = logvar[:, :-1] 1180 | mu_reward = mu[:, -1].unsqueeze(1) 1181 | logvar_reward = logvar[:, -1].unsqueeze(1) 1182 | mu_diff_f_orig = mu_orig[:, :-1] 1183 | mu_reward_orig = mu_orig[:, -1].unsqueeze(1) 1184 | else: 1185 | mu_diff_f = mu 1186 | logvar_diff_f = logvar 1187 | mu_reward = torch.zeros_like(mu[:, -1].unsqueeze(1)) 1188 | logvar_reward = torch.zeros_like(logvar[:, -1].unsqueeze(1)) 1189 | mu_diff_f_orig = mu_orig 1190 | mu_reward_orig = mu_reward 1191 | mu_diff = mu_diff_f 1192 | mu_nextstate = state + mu_diff 1193 | logvar_nextstate = logvar_diff_f 1194 | if return_mean: 1195 | mu_nextstate = torch.cat((mu_nextstate, mu_diff_f_orig + state), dim=1) 1196 | mu_reward = torch.cat((mu_reward, mu_reward_orig), dim=1) 1197 | if keep_logvar: 1198 | return (mu_nextstate, logvar_nextstate), (mu_reward, logvar_reward) 1199 | else: 1200 | return mu_nextstate, mu_reward 1201 | 1202 | def get_state_action_uncertainty(self, state: torch.Tensor, action: torch.Tensor, state_filter, action_filter): 1203 | state_action_f = self.filter_inputs(state, action, state_filter, action_filter) 1204 | _, logvar = self.forward(state_action_f).chunk(2, dim=1) 1205 | # TODO: See above, which is the correct logvar? 1206 | return logvar 1207 | 1208 | 1209 | def reinitialize_fc_layer_(fc_layer): 1210 | """ 1211 | Helper function to initialize a fc layer to have a truncated normal over the weights, and zero over the biases 1212 | """ 1213 | input_dim = fc_layer.weight.shape[1] 1214 | std = get_trunc_normal_std(input_dim) 1215 | torch.nn.init.trunc_normal_(fc_layer.weight, std=std, a=-2 * std, b=2 * std) 1216 | torch.nn.init.zeros_(fc_layer.bias) 1217 | 1218 | 1219 | def get_trunc_normal_std(input_dim): 1220 | """ 1221 | Returns the truncated normal standard deviation required for weight initialization 1222 | """ 1223 | return 1 / (2 * np.sqrt(input_dim)) 1224 | 1225 | 1226 | def get_weight_bias_parameters_with_decays(fc_layer, decay): 1227 | """ 1228 | For the fc_layer, extract only the weight from the .parameters() method so we don't regularize the bias terms 1229 | """ 1230 | decay_params = [] 1231 | non_decay_params = [] 1232 | for name, parameter in fc_layer.named_parameters(): 1233 | if 'weight' in name: 1234 | decay_params.append(parameter) 1235 | elif 'bias' in name: 1236 | non_decay_params.append(parameter) 1237 | 1238 | decay_dicts = [{'params': decay_params, 'weight_decay': decay}, {'params': non_decay_params, 'weight_decay': 0.}] 1239 | 1240 | return decay_dicts 1241 | --------------------------------------------------------------------------------