├── .gitignore ├── configs ├── offline │ ├── hopper-medium.yml │ ├── walker2d-medium.yml │ ├── halfcheetah-medium.yml │ ├── hopper-medium-expert.yml │ ├── hopper-medium-replay.yml │ ├── walker2d-medium-expert.yml │ ├── walker2d-medium-replay.yml │ ├── halfcheetah-medium-expert.yml │ ├── halfcheetah-medium-replay.yml │ ├── antmaze-umaze.yml │ ├── antmaze-large-play.yml │ ├── antmaze-medium-play.yml │ ├── antmaze-large-diverse.yml │ ├── antmaze-umaze-diverse.yml │ └── antmaze-medium-diverse.yml └── online_finetune │ ├── antmaze-umaze.yml │ ├── antmaze-medium-play.yml │ ├── antmaze-umaze-diverse.yml │ ├── antmaze-medium-diverse.yml │ ├── antmaze-large-play.yml │ └── antmaze-large-diverse.yml ├── conda_env.yml ├── LICENSE ├── eval.py ├── README.md ├── utils.py ├── vae.py ├── log.py ├── train_vae.py ├── main.py ├── main_finetune.py └── SPOT.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .git 3 | .vscode 4 | models 5 | spot-models.zip 6 | runs 7 | .DS_Store -------------------------------------------------------------------------------- /configs/offline/hopper-medium.yml: -------------------------------------------------------------------------------- 1 | env: hopper-medium-v2 2 | vae_model_path: models/vae_trained_models/vae_model_hopper_medium.pt 3 | lambd: 0.1 4 | work_dir: train_offline 5 | -------------------------------------------------------------------------------- /configs/offline/walker2d-medium.yml: -------------------------------------------------------------------------------- 1 | env: walker2d-medium-v2 2 | vae_model_path: models/vae_trained_models/vae_model_walker2d_medium.pt 3 | lambd: 0.2 4 | work_dir: train_offline 5 | -------------------------------------------------------------------------------- /configs/offline/halfcheetah-medium.yml: -------------------------------------------------------------------------------- 1 | env: halfcheetah-medium-v2 2 | vae_model_path: models/vae_trained_models/vae_model_halfcheetah_medium.pt 3 | lambd: 0.05 4 | work_dir: train_offline 5 | -------------------------------------------------------------------------------- /configs/offline/hopper-medium-expert.yml: -------------------------------------------------------------------------------- 1 | env: hopper-medium-expert-v2 2 | vae_model_path: models/vae_trained_models/vae_model_hopper_medium-expert.pt 3 | lambd: 0.2 4 | work_dir: train_offline 5 | -------------------------------------------------------------------------------- /configs/offline/hopper-medium-replay.yml: -------------------------------------------------------------------------------- 1 | env: hopper-medium-replay-v2 2 | vae_model_path: models/vae_trained_models/vae_model_hopper_medium-replay.pt 3 | lambd: 0.1 4 | work_dir: train_offline 5 | -------------------------------------------------------------------------------- /configs/offline/walker2d-medium-expert.yml: -------------------------------------------------------------------------------- 1 | env: walker2d-medium-expert-v2 2 | vae_model_path: models/vae_trained_models/vae_model_walker2d_medium-expert.pt 3 | lambd: 0.5 4 | work_dir: train_offline 5 | -------------------------------------------------------------------------------- /configs/offline/walker2d-medium-replay.yml: -------------------------------------------------------------------------------- 1 | env: walker2d-medium-replay-v2 2 | vae_model_path: models/vae_trained_models/vae_model_walker2d_medium-replay.pt 3 | lambd: 0.2 4 | work_dir: 
train_offline 5 | -------------------------------------------------------------------------------- /configs/offline/halfcheetah-medium-expert.yml: -------------------------------------------------------------------------------- 1 | env: halfcheetah-medium-expert-v2 2 | vae_model_path: models/vae_trained_models/vae_model_halfcheetah_medium-expert.pt 3 | lambd: 1.0 4 | work_dir: train_offline 5 | -------------------------------------------------------------------------------- /configs/offline/halfcheetah-medium-replay.yml: -------------------------------------------------------------------------------- 1 | env: halfcheetah-medium-replay-v2 2 | vae_model_path: models/vae_trained_models/vae_model_halfcheetah_medium-replay.pt 3 | lambd: 0.05 4 | work_dir: train_offline 5 | -------------------------------------------------------------------------------- /configs/offline/antmaze-umaze.yml: -------------------------------------------------------------------------------- 1 | env: antmaze-umaze-v2 2 | vae_model_path: models/vae_trained_models/vae_model_antmaze_umaze.pt 3 | lambd: 0.25 4 | actor_lr: 0.0001 5 | actor_dropout: 0.0 6 | actor_init_w: 0.001 7 | critic_init_w: 0.003 8 | antmaze_no_normalize: true 9 | eval_episodes: 100 10 | eval_freq: 50000 11 | work_dir: train_offline 12 | -------------------------------------------------------------------------------- /configs/offline/antmaze-large-play.yml: -------------------------------------------------------------------------------- 1 | env: antmaze-large-play-v2 2 | vae_model_path: models/vae_trained_models/vae_model_antmaze_large-play.pt 3 | lambd: 0.025 4 | actor_lr: 0.0001 5 | actor_dropout: 0.0 6 | actor_init_w: 0.001 7 | critic_init_w: 0.003 8 | antmaze_no_normalize: true 9 | eval_episodes: 100 10 | eval_freq: 50000 11 | work_dir: train_offline 12 | -------------------------------------------------------------------------------- /configs/offline/antmaze-medium-play.yml: -------------------------------------------------------------------------------- 1 | env: antmaze-medium-play-v2 2 | vae_model_path: models/vae_trained_models/vae_model_antmaze_medium-play.pt 3 | lambd: 0.05 4 | actor_lr: 0.0001 5 | actor_dropout: 0.0 6 | actor_init_w: 0.001 7 | critic_init_w: 0.003 8 | antmaze_no_normalize: true 9 | eval_episodes: 100 10 | eval_freq: 50000 11 | work_dir: train_offline 12 | -------------------------------------------------------------------------------- /configs/offline/antmaze-large-diverse.yml: -------------------------------------------------------------------------------- 1 | env: antmaze-large-diverse-v2 2 | vae_model_path: models/vae_trained_models/vae_model_antmaze_large-diverse.pt 3 | lambd: 0.025 4 | actor_lr: 0.0001 5 | actor_dropout: 0.0 6 | actor_init_w: 0.001 7 | critic_init_w: 0.003 8 | antmaze_no_normalize: true 9 | eval_episodes: 100 10 | eval_freq: 50000 11 | work_dir: train_offline 12 | -------------------------------------------------------------------------------- /configs/offline/antmaze-umaze-diverse.yml: -------------------------------------------------------------------------------- 1 | env: antmaze-umaze-diverse-v2 2 | vae_model_path: models/vae_trained_models/vae_model_antmaze_umaze-diverse.pt 3 | lambd: 0.25 4 | actor_lr: 0.0001 5 | actor_dropout: 0.0 6 | actor_init_w: 0.001 7 | critic_init_w: 0.003 8 | antmaze_no_normalize: true 9 | eval_episodes: 100 10 | eval_freq: 50000 11 | work_dir: train_offline 12 | -------------------------------------------------------------------------------- 
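Each YAML config above lists only the settings that differ from the argparse defaults in main.py / main_finetune.py; main.py applies a config by loading the YAML and calling parser.set_defaults(...) before re-parsing the command line. Below is a minimal, self-contained sketch of that merge; it is not part of the repo, and the argument subset is illustrative rather than the full CLI:

```
import argparse

import yaml

# Sketch of the config handling in main.py: every key in the YAML overrides the
# matching argparse default, and explicit command-line flags still win because
# parse_args() is called again after set_defaults(). Keys without a declared
# argument (e.g. work_dir here) are simply added to the resulting namespace.
parser = argparse.ArgumentParser()
parser.add_argument("--env", default="hopper-medium-v0")
parser.add_argument("--lambd", default=1.0, type=float)
parser.add_argument("--vae_model_path", default=None, type=str)
parser.add_argument("--config", default=None, type=str)
args = parser.parse_args()

if args.config is not None:
    with open(args.config, "r") as f:
        parser.set_defaults(**yaml.load(f.read(), Loader=yaml.FullLoader))
    args = parser.parse_args()  # re-parse so CLI flags override config values

print(args.env, args.lambd, args.vae_model_path)
```

For example, running this sketch with `--config configs/offline/hopper-medium.yml --lambd 0.2` would pick up `env` and `vae_model_path` from the YAML while keeping the explicitly passed `lambd`.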
/configs/offline/antmaze-medium-diverse.yml: -------------------------------------------------------------------------------- 1 | env: antmaze-medium-diverse-v2 2 | vae_model_path: models/vae_trained_models/vae_model_antmaze_medium-diverse.pt 3 | lambd: 0.025 4 | actor_lr: 0.0001 5 | actor_dropout: 0.0 6 | actor_init_w: 0.001 7 | critic_init_w: 0.003 8 | antmaze_no_normalize: true 9 | eval_episodes: 100 10 | eval_freq: 50000 11 | work_dir: train_offline 12 | -------------------------------------------------------------------------------- /configs/online_finetune/antmaze-umaze.yml: -------------------------------------------------------------------------------- 1 | env: antmaze-umaze-v2 2 | seed: 1 3 | vae_model_path: models/vae_trained_models/vae_model_antmaze_umaze.pt 4 | pretrain_model: models/offline_trained_models/antmaze_umaze 5 | lambd: 0.25 6 | lambd_cool: true 7 | actor_lr: 0.0001 8 | actor_dropout: 0.0 9 | antmaze_no_normalize: true 10 | eval_episodes: 100 11 | eval_freq: 50000 12 | work_dir: train_finetune 13 | -------------------------------------------------------------------------------- /conda_env.yml: -------------------------------------------------------------------------------- 1 | name: spot 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - anaconda 8 | - cudatoolkit=10. 9 | - pytorch==1.10.0 10 | - tqdm==4.61.2 11 | - numpy 12 | - pip 13 | - pip: 14 | - gym==0.18.3 15 | - mujoco-py==2.0.2.13 16 | - numpy==1.20.3 17 | - pyyaml==5.4.1 18 | - coolname==1.1.0 19 | - termcolor==1.1.0 20 | - tensorboard==2.7.0 21 | - protobuf==3.17.3 -------------------------------------------------------------------------------- /configs/online_finetune/antmaze-medium-play.yml: -------------------------------------------------------------------------------- 1 | env: antmaze-medium-play-v2 2 | seed: 1 3 | vae_model_path: models/vae_trained_models/vae_model_antmaze_medium-play.pt 4 | pretrain_model: models/offline_trained_models/antmaze_medium-play 5 | lambd: 0.05 6 | lambd_cool: true 7 | actor_lr: 0.0001 8 | actor_dropout: 0.0 9 | antmaze_no_normalize: true 10 | eval_episodes: 100 11 | eval_freq: 50000 12 | work_dir: train_finetune 13 | -------------------------------------------------------------------------------- /configs/online_finetune/antmaze-umaze-diverse.yml: -------------------------------------------------------------------------------- 1 | env: antmaze-umaze-diverse-v2 2 | seed: 1 3 | vae_model_path: models/vae_trained_models/vae_model_antmaze_umaze-diverse.pt 4 | pretrain_model: models/offline_trained_models/antmaze_umaze-diverse 5 | lambd: 0.25 6 | lambd_cool: true 7 | actor_lr: 0.0001 8 | actor_dropout: 0.0 9 | antmaze_no_normalize: true 10 | eval_episodes: 100 11 | eval_freq: 50000 12 | work_dir: train_finetune 13 | -------------------------------------------------------------------------------- /configs/online_finetune/antmaze-medium-diverse.yml: -------------------------------------------------------------------------------- 1 | env: antmaze-medium-diverse-v2 2 | seed: 1 3 | vae_model_path: models/vae_trained_models/vae_model_antmaze_medium-diverse.pt 4 | pretrain_model: models/offline_trained_models/antmaze_medium-diverse 5 | lambd: 0.025 6 | lambd_cool: true 7 | actor_lr: 0.0001 8 | actor_dropout: 0.0 9 | antmaze_no_normalize: true 10 | eval_episodes: 100 11 | eval_freq: 50000 12 | work_dir: train_finetune 13 | -------------------------------------------------------------------------------- /configs/online_finetune/antmaze-large-play.yml: 
-------------------------------------------------------------------------------- 1 | env: antmaze-large-play-v2 2 | seed: 1 3 | vae_model_path: models/vae_trained_models/vae_model_antmaze_large-play.pt 4 | pretrain_model: models/offline_trained_models/antmaze_large-play 5 | lambd: 0.025 6 | lambd_cool: true 7 | actor_lr: 0.0001 8 | actor_dropout: 0.0 9 | antmaze_no_normalize: true 10 | discount: 0.995 11 | eval_episodes: 100 12 | eval_freq: 50000 13 | work_dir: train_finetune 14 | -------------------------------------------------------------------------------- /configs/online_finetune/antmaze-large-diverse.yml: -------------------------------------------------------------------------------- 1 | env: antmaze-large-diverse-v2 2 | seed: 1 3 | vae_model_path: models/vae_trained_models/vae_model_antmaze_large-diverse.pt 4 | pretrain_model: models/offline_trained_models/antmaze_large-diverse 5 | lambd: 0.025 6 | lambd_cool: true 7 | actor_lr: 0.0001 8 | actor_dropout: 0.0 9 | antmaze_no_normalize: true 10 | discount: 0.995 11 | eval_episodes: 100 12 | eval_freq: 50000 13 | work_dir: train_finetune 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Scott Fujimoto (https://github.com/sfujim/TD3_BC) 4 | Copyright (c) 2022 THUML @ Tsinghua University 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | from log import Logger 4 | from tqdm import trange 5 | 6 | from utils import VideoRecorder 7 | 8 | # Runs policy for X episodes and returns average reward 9 | # A fixed seed is used for the eval environment 10 | 11 | 12 | def eval_policy(args, iter, video: VideoRecorder, logger: Logger, policy, env_name, seed, mean, std, seed_offset=100, eval_episodes=10): 13 | eval_env = gym.make(env_name) 14 | eval_env.seed(seed + seed_offset) 15 | 16 | lengths = [] 17 | returns = [] 18 | last_rewards = [] 19 | avg_reward = 0. 
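    # The loop below runs `eval_episodes` rollouts in a separately seeded copy of the
    # environment, normalizing each state with the dataset mean/std before querying the
    # policy. For AntMaze, success is read off the episode length: reaching the goal
    # terminates the episode early, so `steps != eval_env._max_episode_steps` counts as
    # a success (this is also how eval/success_rate is logged further down).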
20 | for episode in trange(eval_episodes): 21 | video.init(enabled=(args.save_video and _ == 0)) 22 | state, done = eval_env.reset(), False 23 | video.record(eval_env) 24 | steps = 0 25 | episode_return = 0 26 | while not done: 27 | state = (np.array(state).reshape(1, -1) - mean) / std 28 | action = policy.select_action(state) 29 | state, reward, done, _ = eval_env.step(action) 30 | video.record(eval_env) 31 | avg_reward += reward 32 | episode_return += reward 33 | steps += 1 34 | lengths.append(steps) 35 | returns.append(episode_return) 36 | last_rewards.append(reward) 37 | video.save(f'eval_s{iter}_e{episode}_r{str(episode_return)}.mp4') 38 | if 'antmaze' in args.env: 39 | print("\tsuccess", float(steps != eval_env._max_episode_steps), "\tlast reward", reward) 40 | 41 | avg_reward /= eval_episodes 42 | d4rl_score = eval_env.get_normalized_score(avg_reward) 43 | 44 | logger.log('eval/lengths_mean', np.mean(lengths), iter) 45 | logger.log('eval/lengths_std', np.std(lengths), iter) 46 | logger.log('eval/returns_mean', np.mean(returns), iter) 47 | logger.log('eval/returns_std', np.std(returns), iter) 48 | logger.log('eval/d4rl_score', d4rl_score, iter) 49 | if 'antmaze' in args.env: 50 | logger.log('eval/success_rate', 1 - np.mean(np.array(lengths) == eval_env._max_episode_steps), iter) 51 | if 'dense' in args.env: 52 | logger.log('eval/last_reward_mean', np.mean(last_rewards), iter) 53 | logger.log('eval/last_reward_std', np.std(last_rewards), iter) 54 | 55 | print("---------------------------------------") 56 | print(f"Evaluation over {eval_episodes} episodes: {d4rl_score:.3f}") 57 | print("\tepisode returns:", *['%.2f' % x for x in returns]) 58 | print("\tepisode lengths", lengths) 59 | if 'antmaze' in args.env: 60 | print("\tsuccess rate", 1 - np.mean(np.array(lengths) == eval_env._max_episode_steps)) 61 | if 'dense' in args.env: 62 | print("\tlast reward", *['%.2f' % x for x in last_rewards]) 63 | print("---------------------------------------") 64 | return d4rl_score 65 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Supported Policy Optimization 2 | 3 | Official implementation for NeurIPS 2022 paper [Supported Policy Optimization for Offline Reinforcement Learning](https://arxiv.org/abs/2202.06239). 4 | 5 | 🚩 **News**: 6 | 7 | - June, 2023: SPOT has been included in [Clean Offline Reinforcement Learning (CORL)](https://github.com/tinkoff-ai/CORL) library as a strong baseline for Offline-to-Online RL. Thanks [Tinkoff AI](https://github.com/tinkoff-ai) and [Denis Tarasov](https://github.com/DT6A) for the implementation! 8 | 9 | ## Environment 10 | 11 | 1. Install [MuJoCo version 2.0](https://www.roboti.us/download.html) at ~/.mujoco/mujoco200 and copy license key to ~/.mujoco/mjkey.txt 12 | 2. Create a conda environment 13 | ``` 14 | conda env create -f conda_env.yml 15 | conda activate spot 16 | ``` 17 | 3. Install [D4RL](https://github.com/Farama-Foundation/D4RL/tree/4aff6f8c46f62f9a57f79caa9287efefa45b6688) 18 | 19 | ## Usage 20 | 21 | ### Pretrained Models 22 | 23 | We have uploaded pretrained VAE models and offline models to facilitate experiment reproduction. Download from this [link](https://drive.google.com/file/d/1_v6yPpwYw6T7CcBs1u_UJizf9wZmV1PW/view?usp=sharing) and unzip: 24 | 25 | ``` 26 | unzip spot-models.zip -d . 27 | ``` 28 | 29 | ### Offline RL 30 | 31 | Run the following command to train VAE. 
32 | 33 | ``` 34 | python train_vae.py --env halfcheetah --dataset medium-replay 35 | python train_vae.py --env antmaze --dataset medium-diverse --no_normalize 36 | ``` 37 | 38 | Run the following command to train offline RL on D4RL with pretrained VAE models. 39 | 40 | ``` 41 | python main.py --config configs/offline/halfcheetah-medium-replay.yml 42 | python main.py --config configs/offline/antmaze-medium-diverse.yml 43 | ``` 44 | 45 | You can also specify the random seed and VAE model: 46 | 47 | ``` 48 | python main.py --config configs/offline/halfcheetah-medium-replay.yml --seed --vae_model_path 49 | ``` 50 | 51 | #### Logging 52 | 53 | This codebase uses tensorboard. You can view saved runs with: 54 | 55 | ``` 56 | tensorboard --logdir 57 | ``` 58 | 59 | ### Online Fine-tuning 60 | 61 | Run the following command to online fine-tune on AntMaze with pretrained VAE models and offline models. 62 | 63 | ``` 64 | python main_finetune.py --config configs/online_finetune/antmaze-medium-diverse.yml 65 | ``` 66 | 67 | You can also specify the random seed, VAE model and offline models: 68 | 69 | ``` 70 | python main_finetune.py --config configs/online_finetune/antmaze-medium-diverse.yml --seed --vae_model_path --pretrain_model 71 | ``` 72 | 73 | ## Citation 74 | 75 | If you find this code useful for your research, please cite our paper as: 76 | 77 | ``` 78 | @inproceedings{wu2022supported, 79 | title={Supported Policy Optimization for Offline Reinforcement Learning}, 80 | author={Jialong Wu and Haixu Wu and Zihan Qiu and Jianmin Wang and Mingsheng Long}, 81 | booktitle={Advances in Neural Information Processing Systems}, 82 | year={2022} 83 | } 84 | ``` 85 | 86 | ## Contact 87 | 88 | If you have any question, please contact wujialong0229@gmail.com . 89 | 90 | ## Acknowledgement 91 | 92 | This repo borrows heavily from [sfujim/TD3_BC](https://github.com/sfujim/TD3_BC) and [sfujim/BCQ](https://github.com/sfujim/BCQ). -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import os 5 | import random 6 | import imageio 7 | import gym 8 | from tqdm import trange 9 | import pickle 10 | 11 | 12 | class ReplayBuffer(object): 13 | def __init__(self, state_dim, action_dim, max_size=int(1e6)): 14 | self.max_size = max_size 15 | self.ptr = 0 16 | self.size = 0 17 | 18 | self.state = np.zeros((max_size, state_dim)) 19 | self.action = np.zeros((max_size, action_dim)) 20 | self.next_state = np.zeros((max_size, state_dim)) 21 | self.reward = np.zeros((max_size, 1)) 22 | self.not_done = np.zeros((max_size, 1)) 23 | 24 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | 26 | def add(self, state, action, next_state, reward, done): 27 | self.state[self.ptr] = state 28 | self.action[self.ptr] = action 29 | self.next_state[self.ptr] = next_state 30 | self.reward[self.ptr] = reward 31 | self.not_done[self.ptr] = 1. 
- done 32 | 33 | self.ptr = (self.ptr + 1) % self.max_size 34 | self.size = min(self.size + 1, self.max_size) 35 | 36 | def sample(self, batch_size): 37 | ind = np.random.randint(0, self.size, size=batch_size) 38 | return ( 39 | torch.FloatTensor(self.state[ind]).to(self.device), 40 | torch.FloatTensor(self.action[ind]).to(self.device), 41 | torch.FloatTensor(self.next_state[ind]).to(self.device), 42 | torch.FloatTensor(self.reward[ind]).to(self.device), 43 | torch.FloatTensor(self.not_done[ind]).to(self.device) 44 | ) 45 | 46 | def convert_D4RL(self, dataset): 47 | self.state = dataset['observations'] 48 | self.action = dataset['actions'] 49 | self.next_state = dataset['next_observations'] 50 | self.reward = dataset['rewards'].reshape(-1, 1) 51 | self.not_done = 1. - dataset['terminals'].reshape(-1, 1) 52 | self.size = self.state.shape[0] 53 | 54 | def convert_D4RL_finetune(self, dataset): 55 | self.ptr = dataset['observations'].shape[0] 56 | self.size = dataset['observations'].shape[0] 57 | self.state[:self.ptr] = dataset['observations'] 58 | self.action[:self.ptr] = dataset['actions'] 59 | self.next_state[:self.ptr] = dataset['next_observations'] 60 | self.reward[:self.ptr] = dataset['rewards'].reshape(-1, 1) 61 | self.not_done[:self.ptr] = 1. - dataset['terminals'].reshape(-1, 1) 62 | 63 | def normalize_states(self, eps=1e-3): 64 | mean = self.state.mean(0, keepdims=True) 65 | std = self.state.std(0, keepdims=True) + eps 66 | self.state = (self.state - mean) / std 67 | self.next_state = (self.next_state - mean) / std 68 | return mean, std 69 | 70 | def clip_to_eps(self, eps=1e-5): 71 | lim = 1 - eps 72 | self.action = np.clip(self.action, -lim, lim) 73 | 74 | 75 | def make_dir(dir_path): 76 | try: 77 | os.mkdir(dir_path) 78 | except OSError: 79 | pass 80 | return dir_path 81 | 82 | 83 | def set_seed_everywhere(seed): 84 | torch.manual_seed(seed) 85 | if torch.cuda.is_available(): 86 | torch.cuda.manual_seed_all(seed) 87 | np.random.seed(seed) 88 | random.seed(seed) 89 | 90 | 91 | def get_lr(optimizer): 92 | for param_group in optimizer.param_groups: 93 | return param_group['lr'] 94 | 95 | 96 | def snapshot_src(src, target, exclude_from): 97 | make_dir(target) 98 | os.system(f"rsync -rv --exclude-from={exclude_from} {src} {target}") 99 | 100 | 101 | class VideoRecorder(object): 102 | def __init__(self, dir_name, height=512, width=512, camera_id=0, fps=30): 103 | self.dir_name = dir_name 104 | self.height = height 105 | self.width = width 106 | self.camera_id = camera_id 107 | self.fps = fps 108 | self.frames = [] 109 | 110 | def init(self, enabled=True): 111 | self.frames = [] 112 | self.enabled = self.dir_name is not None and enabled 113 | 114 | def record(self, env): 115 | if self.enabled: 116 | frame = env.render( 117 | mode='rgb_array', 118 | height=self.height, 119 | width=self.width, 120 | # camera_id=self.camera_id 121 | ) 122 | self.frames.append(frame) 123 | 124 | def save(self, file_name): 125 | if self.enabled: 126 | path = os.path.join(self.dir_name, file_name) 127 | imageio.mimsave(path, self.frames, fps=self.fps) 128 | 129 | 130 | def grad_norm(model): 131 | total_norm = 0. 132 | for p in model.parameters(): 133 | param_norm = p.grad.data.norm(2) 134 | total_norm += param_norm.item() ** 2 135 | total_norm = total_norm ** (1. 
/ 2) 136 | return total_norm 137 | -------------------------------------------------------------------------------- /vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | import math 5 | import torch.distributions as td 6 | 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | class VAE(nn.Module): 12 | # Vanilla Variational Auto-Encoder 13 | 14 | def __init__(self, state_dim, action_dim, latent_dim, max_action, hidden_dim=750, dropout=0.0): 15 | super(VAE, self).__init__() 16 | self.e1 = nn.Linear(state_dim + action_dim, hidden_dim) 17 | self.e2 = nn.Linear(hidden_dim, hidden_dim) 18 | 19 | self.mean = nn.Linear(hidden_dim, latent_dim) 20 | self.log_std = nn.Linear(hidden_dim, latent_dim) 21 | 22 | self.d1 = nn.Linear(state_dim + latent_dim, hidden_dim) 23 | self.d2 = nn.Linear(hidden_dim, hidden_dim) 24 | self.d3 = nn.Linear(hidden_dim, action_dim) 25 | 26 | self.max_action = max_action 27 | self.latent_dim = latent_dim 28 | self.device = device 29 | 30 | def forward(self, state, action): 31 | mean, std = self.encode(state, action) 32 | z = mean + std * torch.randn_like(std) 33 | u = self.decode(state, z) 34 | return u, mean, std 35 | 36 | def elbo_loss(self, state, action, beta, num_samples=1): 37 | """ 38 | Note: elbo_loss one is proportional to elbo_estimator 39 | i.e. there exist a>0 and b, elbo_loss = a * (-elbo_estimator) + b 40 | """ 41 | mean, std = self.encode(state, action) 42 | 43 | mean_s = mean.repeat(num_samples, 1, 1).permute(1, 0, 2) # [B x S x D] 44 | std_s = std.repeat(num_samples, 1, 1).permute(1, 0, 2) # [B x S x D] 45 | z = mean_s + std_s * torch.randn_like(std_s) 46 | 47 | state = state.repeat(num_samples, 1, 1).permute(1, 0, 2) # [B x S x C] 48 | action = action.repeat(num_samples, 1, 1).permute(1, 0, 2) # [B x S x C] 49 | u = self.decode(state, z) 50 | recon_loss = ((u - action) ** 2).mean(dim=(1, 2)) 51 | 52 | KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean(-1) 53 | vae_loss = recon_loss + beta * KL_loss 54 | return vae_loss 55 | 56 | def iwae_loss(self, state, action, beta, num_samples=10): 57 | ll = self.importance_sampling_estimator(state, action, beta, num_samples) 58 | return -ll 59 | 60 | def elbo_estimator(self, state, action, beta, num_samples=1): 61 | mean, std = self.encode(state, action) 62 | 63 | mean_s = mean.repeat(num_samples, 1, 1).permute(1, 0, 2) # [B x S x D] 64 | std_s = std.repeat(num_samples, 1, 1).permute(1, 0, 2) # [B x S x D] 65 | z = mean_s + std_s * torch.randn_like(std_s) 66 | 67 | state = state.repeat(num_samples, 1, 1).permute(1, 0, 2) # [B x S x C] 68 | action = action.repeat(num_samples, 1, 1).permute(1, 0, 2) # [B x S x C] 69 | mean_dec = self.decode(state, z) 70 | std_dec = math.sqrt(beta / 4) 71 | 72 | # Find p(x|z) 73 | std_dec = torch.ones_like(mean_dec).to(self.device) * std_dec 74 | log_pxz = td.Normal(loc=mean_dec, scale=std_dec).log_prob(action) 75 | 76 | KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).sum(-1) 77 | elbo = log_pxz.sum(-1).mean(-1) - KL_loss 78 | return elbo 79 | 80 | def importance_sampling_estimator(self, state, action, beta, num_samples=500): 81 | # * num_samples correspond to num of samples L in the paper 82 | # * note that for exact value for \hat \log \pi_\beta in the paper, we also need **an expection over L samples** 83 | mean, std = self.encode(state, action) 84 | 85 | mean_enc = 
mean.repeat(num_samples, 1, 1).permute(1, 0, 2) # [B x S x D] 86 | std_enc = std.repeat(num_samples, 1, 1).permute(1, 0, 2) # [B x S x D] 87 | z = mean_enc + std_enc * torch.randn_like(std_enc) # [B x S x D] 88 | 89 | state = state.repeat(num_samples, 1, 1).permute(1, 0, 2) # [B x S x C] 90 | action = action.repeat(num_samples, 1, 1).permute(1, 0, 2) # [B x S x C] 91 | mean_dec = self.decode(state, z) 92 | std_dec = math.sqrt(beta / 4) 93 | 94 | # Find q(z|x) 95 | log_qzx = td.Normal(loc=mean_enc, scale=std_enc).log_prob(z) 96 | # Find p(z) 97 | mu_prior = torch.zeros_like(z).to(self.device) 98 | std_prior = torch.ones_like(z).to(self.device) 99 | log_pz = td.Normal(loc=mu_prior, scale=std_prior).log_prob(z) 100 | # Find p(x|z) 101 | std_dec = torch.ones_like(mean_dec).to(self.device) * std_dec 102 | log_pxz = td.Normal(loc=mean_dec, scale=std_dec).log_prob(action) 103 | 104 | w = log_pxz.sum(-1) + log_pz.sum(-1) - log_qzx.sum(-1) 105 | ll = w.logsumexp(dim=-1) - math.log(num_samples) 106 | return ll 107 | 108 | def encode(self, state, action): 109 | z = F.relu(self.e1(torch.cat([state, action], -1))) 110 | z = F.relu(self.e2(z)) 111 | 112 | mean = self.mean(z) 113 | # Clamped for numerical stability 114 | log_std = self.log_std(z).clamp(-4, 15) 115 | std = torch.exp(log_std) 116 | return mean, std 117 | 118 | def decode(self, state, z=None): 119 | # When sampling from the VAE, the latent vector is clipped to [-0.5, 0.5] 120 | if z is None: 121 | z = torch.randn((state.shape[0], self.latent_dim)).to(self.device).clamp(-0.5, 0.5) 122 | 123 | a = F.relu(self.d1(torch.cat([state, z], -1))) 124 | a = F.relu(self.d2(a)) 125 | if self.max_action is not None: 126 | return self.max_action * torch.tanh(self.d3(a)) 127 | else: 128 | return self.d3(a) 129 | -------------------------------------------------------------------------------- /log.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | from collections import defaultdict 3 | import json 4 | import os 5 | import shutil 6 | import torch 7 | import numpy as np 8 | from termcolor import colored 9 | 10 | # Borrows from https://github.com/MishaLaskin/rad/blob/master/logger.py 11 | 12 | FORMAT_CONFIG = { 13 | 'rl': { 14 | 'train': [ 15 | ('episode', 'E', 'int'), ('step', 'S', 'int'), 16 | ('duration', 'D', 'time'), ('episode_reward', 'R', 'float'), 17 | ('batch_reward', 'BR', 'float'), ('actor_loss', 'A_LOSS', 'float'), 18 | ('critic_loss', 'CR_LOSS', 'float') 19 | ], 20 | 'eval': [('step', 'S', 'int'), ('episode_reward', 'ER', 'float'), ('episode_reward_test_env', 'ERTEST', 'float')] 21 | } 22 | } 23 | 24 | 25 | class AverageMeter(object): 26 | def __init__(self): 27 | self._sum = 0 28 | self._count = 0 29 | 30 | def update(self, value, n=1): 31 | self._sum += value 32 | self._count += n 33 | 34 | def value(self): 35 | return self._sum / max(1, self._count) 36 | 37 | 38 | class MetersGroup(object): 39 | def __init__(self, file_name, formating): 40 | self._file_name = file_name 41 | if os.path.exists(file_name): 42 | os.remove(file_name) 43 | self._formating = formating 44 | self._meters = defaultdict(AverageMeter) 45 | 46 | def log(self, key, value, n=1): 47 | self._meters[key].update(value, n) 48 | 49 | def _prime_meters(self): 50 | data = dict() 51 | for key, meter in self._meters.items(): 52 | if key.startswith('train'): 53 | key = key[len('train') + 1:] 54 | else: 55 | key = key[len('eval') + 1:] 56 | key = key.replace('/', '_') 57 | data[key] = 
meter.value() 58 | return data 59 | 60 | def _dump_to_file(self, data): 61 | with open(self._file_name, 'a') as f: 62 | f.write(json.dumps(data) + '\n') 63 | 64 | def _format(self, key, value, ty): 65 | template = '%s: ' 66 | if ty == 'int': 67 | template += '%d' 68 | elif ty == 'float': 69 | template += '%.04f' 70 | elif ty == 'time': 71 | template += '%.01f s' 72 | else: 73 | raise 'invalid format type: %s' % ty 74 | return template % (key, value) 75 | 76 | def _dump_to_console(self, data, prefix): 77 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green') 78 | pieces = ['{:5}'.format(prefix)] 79 | for key, disp_key, ty in self._formating: 80 | value = data.get(key, 0) 81 | pieces.append(self._format(disp_key, value, ty)) 82 | print('| %s' % (' | '.join(pieces))) 83 | 84 | def dump(self, step, prefix): 85 | if len(self._meters) == 0: 86 | return 87 | data = self._prime_meters() 88 | data['step'] = step 89 | self._dump_to_file(data) 90 | self._dump_to_console(data, prefix) 91 | self._meters.clear() 92 | 93 | 94 | class Logger(object): 95 | def __init__(self, log_dir, use_tb=True, config='rl', train_log_interval=100): 96 | self._log_dir = log_dir 97 | if use_tb: 98 | tb_dir = os.path.join(log_dir, 'tb') 99 | if os.path.exists(tb_dir): 100 | shutil.rmtree(tb_dir) 101 | self._sw = SummaryWriter(tb_dir) 102 | else: 103 | self._sw = None 104 | self._train_mg = MetersGroup( 105 | os.path.join(log_dir, 'train.log'), 106 | formating=FORMAT_CONFIG[config]['train'] 107 | ) 108 | self._eval_mg = MetersGroup( 109 | os.path.join(log_dir, 'eval.log'), 110 | formating=FORMAT_CONFIG[config]['eval'] 111 | ) 112 | self._train_log_interval = train_log_interval 113 | 114 | def _try_sw_log(self, key, value, step): 115 | if self._sw is not None: 116 | self._sw.add_scalar(key, value, step) 117 | 118 | def _try_sw_log_image(self, key, image, step): 119 | if self._sw is not None: 120 | assert image.dim() == 3 121 | # grid = torchvision.utils.make_grid(image.unsqueeze(1)) 122 | self._sw.add_image(key, image, step) 123 | 124 | def _try_sw_log_video(self, key, frames, step): 125 | if self._sw is not None: 126 | frames = torch.from_numpy(np.array(frames)) 127 | frames = frames.unsqueeze(0) 128 | self._sw.add_video(key, frames, step, fps=30) 129 | 130 | def _try_sw_log_histogram(self, key, histogram, step): 131 | if self._sw is not None: 132 | self._sw.add_histogram(key, histogram, step) 133 | 134 | def log(self, key, value, step, n=1): 135 | assert key.startswith('train') or key.startswith('eval') 136 | if key.startswith('train') and step % self._train_log_interval: 137 | return 138 | if type(value) == torch.Tensor: 139 | value = value.item() 140 | self._try_sw_log(key, value / n, step) 141 | mg = self._train_mg if key.startswith('train') else self._eval_mg 142 | mg.log(key, value, n) 143 | 144 | def log_param(self, key, param, step): 145 | self.log_histogram(key + '_w', param.weight.data, step) 146 | if hasattr(param.weight, 'grad') and param.weight.grad is not None: 147 | self.log_histogram(key + '_w_g', param.weight.grad.data, step) 148 | if hasattr(param, 'bias'): 149 | self.log_histogram(key + '_b', param.bias.data, step) 150 | if hasattr(param.bias, 'grad') and param.bias.grad is not None: 151 | self.log_histogram(key + '_b_g', param.bias.grad.data, step) 152 | 153 | def log_image(self, key, image, step): 154 | assert key.startswith('train') or key.startswith('eval') 155 | self._try_sw_log_image(key, image, step) 156 | 157 | def log_video(self, key, frames, step): 158 | assert 
key.startswith('train') or key.startswith('eval') 159 | self._try_sw_log_video(key, frames, step) 160 | 161 | def log_histogram(self, key, histogram, step): 162 | assert key.startswith('train') or key.startswith('eval') 163 | self._try_sw_log_histogram(key, histogram, step) 164 | 165 | def dump(self, step): 166 | self._train_mg.dump(step, 'train') 167 | self._eval_mg.dump(step, 'eval') 168 | -------------------------------------------------------------------------------- /train_vae.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import gym 7 | from tqdm import tqdm 8 | import os 9 | 10 | from vae import VAE 11 | import time 12 | from coolname import generate_slug 13 | import utils 14 | import json 15 | from log import Logger 16 | import d4rl 17 | from utils import get_lr 18 | 19 | 20 | parser = argparse.ArgumentParser() 21 | 22 | parser.add_argument('--seed', type=int, default=0) 23 | # dataset 24 | parser.add_argument('--env', type=str, default='hopper') 25 | parser.add_argument('--dataset', type=str, default='medium') # medium, medium-replay, medium-expert, expert 26 | parser.add_argument('--version', type=str, default='v2') 27 | # model 28 | parser.add_argument('--model', default='VAE', type=str) 29 | parser.add_argument('--hidden_dim', type=int, default=750) 30 | parser.add_argument('--beta', type=float, default=0.5) 31 | # train 32 | parser.add_argument('--num_iters', type=int, default=int(1e5)) 33 | parser.add_argument('--batch_size', type=int, default=256) 34 | parser.add_argument('--lr', type=float, default=1e-3) 35 | parser.add_argument('--weight_decay', default=0, type=float) 36 | parser.add_argument('--scheduler', default=False, action='store_true') 37 | parser.add_argument('--gamma', default=0.95, type=float) 38 | parser.add_argument('--no_max_action', default=False, action='store_true') 39 | parser.add_argument('--clip_to_eps', default=False, action='store_true') 40 | parser.add_argument('--eps', default=1e-4, type=float) 41 | parser.add_argument('--latent_dim', default=None, type=int, help="default: action_dim * 2") 42 | parser.add_argument('--no_normalize', default=False, action='store_true', help="do not normalize states") 43 | 44 | parser.add_argument('--eval_data', default=0.0, type=float, help="proportion of data used for validation, e.g. 
0.05") 45 | # work dir 46 | parser.add_argument('--work_dir', type=str, default='train_vae') 47 | parser.add_argument('--notes', default=None, type=str) 48 | 49 | args = parser.parse_args() 50 | 51 | # make directory 52 | base_dir = 'runs' 53 | utils.make_dir(base_dir) 54 | base_dir = os.path.join(base_dir, args.work_dir) 55 | utils.make_dir(base_dir) 56 | args.work_dir = os.path.join(base_dir, args.env + '_' + args.dataset) 57 | utils.make_dir(args.work_dir) 58 | 59 | ts = time.gmtime() 60 | ts = time.strftime("%m-%d-%H:%M", ts) 61 | exp_name = str(args.env) + '-' + str(args.dataset) + '-' + ts + '-bs' \ 62 | + str(args.batch_size) + '-s' + str(args.seed) + '-b' + str(args.beta) + \ 63 | '-h' + str(args.hidden_dim) + '-lr' + str(args.lr) + '-wd' + str(args.weight_decay) 64 | exp_name += '-' + generate_slug(2) 65 | if args.notes is not None: 66 | exp_name = args.notes + '_' + exp_name 67 | args.work_dir = args.work_dir + '/' + exp_name 68 | utils.make_dir(args.work_dir) 69 | 70 | args.model_dir = os.path.join(args.work_dir, 'model') 71 | utils.make_dir(args.model_dir) 72 | 73 | with open(os.path.join(args.work_dir, 'args.json'), 'w') as f: 74 | json.dump(vars(args), f, sort_keys=True, indent=4) 75 | 76 | utils.snapshot_src('.', os.path.join(args.work_dir, 'src'), '.gitignore') 77 | logger = Logger(args.work_dir, use_tb=True) 78 | 79 | utils.set_seed_everywhere(args.seed) 80 | 81 | device = 'cuda' 82 | 83 | # load data 84 | env_name = f"{args.env}-{args.dataset}-{args.version}" 85 | env = gym.make(env_name) 86 | 87 | state_dim = env.observation_space.shape[0] 88 | action_dim = env.action_space.shape[0] 89 | max_action = float(env.action_space.high[0]) 90 | if args.no_max_action: 91 | max_action = None 92 | print(state_dim, action_dim, max_action) 93 | latent_dim = action_dim * 2 94 | if args.latent_dim is not None: 95 | latent_dim = args.latent_dim 96 | 97 | replay_buffer = utils.ReplayBuffer(state_dim, action_dim) 98 | replay_buffer.convert_D4RL(d4rl.qlearning_dataset(env)) 99 | if not args.no_normalize: 100 | mean, std = replay_buffer.normalize_states() 101 | else: 102 | print("No normalize") 103 | if args.clip_to_eps: 104 | replay_buffer.clip_to_eps(args.eps) 105 | states = replay_buffer.state 106 | actions = replay_buffer.action 107 | 108 | if args.eval_data: 109 | eval_size = int(states.shape[0] * args.eval_data) 110 | eval_idx = np.random.choice(states.shape[0], eval_size, replace=False) 111 | train_idx = np.setdiff1d(np.arange(states.shape[0]), eval_idx) 112 | eval_states = states[eval_idx] 113 | eval_actions = actions[eval_idx] 114 | states = states[train_idx] 115 | actions = actions[train_idx] 116 | else: 117 | eval_states = None 118 | eval_actions = None 119 | 120 | # train 121 | if args.model == 'VAE': 122 | vae = VAE(state_dim, action_dim, latent_dim, max_action, hidden_dim=args.hidden_dim).to(device) 123 | else: 124 | raise NotImplementedError 125 | optimizer = torch.optim.Adam(vae.parameters(), lr=args.lr, weight_decay=args.weight_decay) 126 | if args.scheduler: 127 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=args.gamma) 128 | 129 | total_size = states.shape[0] 130 | batch_size = args.batch_size 131 | 132 | for step in tqdm(range(args.num_iters + 1), desc='train'): 133 | idx = np.random.choice(total_size, batch_size) 134 | train_states = torch.from_numpy(states[idx]).to(device) 135 | train_actions = torch.from_numpy(actions[idx]).to(device) 136 | 137 | # Variational Auto-Encoder Training 138 | recon, mean, std = vae(train_states, 
train_actions) 139 | 140 | recon_loss = F.mse_loss(recon, train_actions) 141 | KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean() 142 | vae_loss = recon_loss + args.beta * KL_loss 143 | 144 | logger.log('train/recon_loss', recon_loss, step=step) 145 | logger.log('train/KL_loss', KL_loss, step=step) 146 | logger.log('train/vae_loss', vae_loss, step=step) 147 | 148 | optimizer.zero_grad() 149 | vae_loss.backward() 150 | optimizer.step() 151 | 152 | if step % 5000 == 0: 153 | logger.dump(step) 154 | torch.save(vae.state_dict(), '%s/vae_model_%s_%s_b%s_%s.pt' % 155 | (args.model_dir, args.env, args.dataset, str(args.beta), step)) 156 | 157 | if eval_states is not None and eval_actions is not None: 158 | vae.eval() 159 | with torch.no_grad(): 160 | eval_states_tensor = torch.from_numpy(eval_states).to(device) 161 | eval_actions_tensor = torch.from_numpy(eval_actions).to(device) 162 | 163 | # Variational Auto-Encoder Evaluation 164 | recon, mean, std = vae(eval_states_tensor, eval_actions_tensor) 165 | 166 | recon_loss = F.mse_loss(recon, eval_actions_tensor) 167 | KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean() 168 | vae_loss = recon_loss + args.beta * KL_loss 169 | 170 | logger.log('eval/recon_loss', recon_loss, step=step) 171 | logger.log('eval/KL_loss', KL_loss, step=step) 172 | logger.log('eval/vae_loss', vae_loss, step=step) 173 | vae.train() 174 | 175 | if args.scheduler and (step + 1) % 10000 == 0: 176 | logger.log('train/lr', get_lr(optimizer), step=step) 177 | scheduler.step() 178 | 179 | logger._sw.close() 180 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import gym 4 | import argparse 5 | import os 6 | import d4rl 7 | from tqdm import trange 8 | from coolname import generate_slug 9 | import time 10 | import json 11 | import yaml 12 | from log import Logger 13 | 14 | import utils 15 | from utils import VideoRecorder 16 | import SPOT 17 | from vae import VAE 18 | from eval import eval_policy 19 | 20 | 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | 23 | 24 | if __name__ == "__main__": 25 | 26 | parser = argparse.ArgumentParser() 27 | # Experiment 28 | parser.add_argument("--policy", default="SPOT_TD3") # Policy name 29 | parser.add_argument("--env", default="hopper-medium-v0") # OpenAI gym environment name 30 | parser.add_argument("--seed", default=0, type=int) # Sets Gym, PyTorch and Numpy seeds 31 | parser.add_argument("--eval_freq", default=5e3, type=int) # How often (time steps) we evaluate 32 | parser.add_argument("--max_timesteps", default=1e6, type=int) # Max time steps to run environment 33 | parser.add_argument("--save_model", default=False, action="store_true") # Save model and optimizer parameters 34 | parser.add_argument('--save_model_final', default=True, action='store_true') 35 | parser.add_argument('--eval_episodes', default=10, type=int) 36 | parser.add_argument('--save_video', default=False, action='store_true') 37 | parser.add_argument('--clip_to_eps', default=False, action='store_true') 38 | # TD3 39 | parser.add_argument("--expl_noise", default=0.1, type=float) # Std of Gaussian exploration noise 40 | parser.add_argument("--batch_size", default=256, type=int) # Batch size for both actor and critic 41 | parser.add_argument("--discount", default=0.99, type=float) # Discount factor 42 | 
parser.add_argument("--tau", default=0.005) # Target network update rate 43 | parser.add_argument("--policy_noise", default=0.2, type=float) # Noise added to target policy during critic update 44 | parser.add_argument("--noise_clip", default=0.5, type=float) # Range to clip target policy noise 45 | parser.add_argument("--policy_freq", default=2, type=int) # Frequency of delayed policy updates 46 | parser.add_argument('--lr', default=3e-4, type=float) 47 | parser.add_argument('--actor_lr', default=None, type=float) 48 | # TD3 actor-critic 49 | parser.add_argument('--actor_hidden_dim', default=256, type=int) 50 | parser.add_argument('--critic_hidden_dim', default=256, type=int) 51 | parser.add_argument('--actor_init_w', default=None, type=float) 52 | parser.add_argument('--critic_init_w', default=None, type=float) 53 | parser.add_argument('--actor_dropout', default=0.1, type=float) 54 | # TD3 + BC 55 | parser.add_argument("--alpha", default=0.4, type=float) 56 | parser.add_argument("--normalize", default=True) 57 | # VAE 58 | parser.add_argument('--vae_model_path', default=None, type=str) 59 | parser.add_argument('--beta', default=0.5, type=float) 60 | parser.add_argument('--latent_dim', default=None, type=int) 61 | parser.add_argument('--iwae', default=False, action='store_true') 62 | parser.add_argument('--num_samples', default=1, type=int) 63 | # SPOT 64 | parser.add_argument('--lambd', default=1.0, type=float) 65 | parser.add_argument('--without_Q_norm', default=False, action='store_true') 66 | parser.add_argument('--lambd_cool', default=False, action='store_true') 67 | parser.add_argument('--lambd_end', default=0.2, type=float) 68 | # Antmaze 69 | parser.add_argument('--antmaze_center_reward', default=0.0, type=float) 70 | parser.add_argument('--antmaze_no_normalize', default=False, action='store_true') 71 | # Work dir 72 | parser.add_argument('--notes', default=None, type=str) 73 | parser.add_argument('--work_dir', default='tmp', type=str) 74 | # Config 75 | parser.add_argument('--config', default=None, type=str) 76 | 77 | args = parser.parse_args() 78 | # log config 79 | if args.config is not None: 80 | with open(args.config, 'r') as f: 81 | parser.set_defaults(**yaml.load(f.read(), Loader=yaml.FullLoader)) 82 | args = parser.parse_args() 83 | 84 | args.cooldir = generate_slug(2) 85 | 86 | # Build work dir 87 | base_dir = 'runs' 88 | utils.make_dir(base_dir) 89 | base_dir = os.path.join(base_dir, args.work_dir) 90 | utils.make_dir(base_dir) 91 | args.work_dir = os.path.join(base_dir, args.env) 92 | utils.make_dir(args.work_dir) 93 | 94 | # make directory 95 | ts = time.gmtime() 96 | ts = time.strftime("%m-%d-%H:%M", ts) 97 | exp_name = str(args.env) + '-' + ts + '-bs' + str(args.batch_size) + '-s' + str(args.seed) 98 | if args.policy == 'SPOT_TD3': 99 | exp_name += '-lamb' + str(args.lambd) + '-b' + \ 100 | str(args.beta) + '-a' + str(args.antmaze_center_reward) + '-lr' + str(args.lr) 101 | else: 102 | raise NotImplementedError 103 | exp_name += '-' + args.cooldir 104 | if args.notes is not None: 105 | exp_name = args.notes + '_' + exp_name 106 | args.work_dir = args.work_dir + '/' + exp_name 107 | utils.make_dir(args.work_dir) 108 | 109 | args.model_dir = os.path.join(args.work_dir, 'model') 110 | utils.make_dir(args.model_dir) 111 | args.video_dir = os.path.join(args.work_dir, 'video') 112 | utils.make_dir(args.video_dir) 113 | 114 | with open(os.path.join(args.work_dir, 'args.json'), 'w') as f: 115 | json.dump(vars(args), f, sort_keys=True, indent=4) 116 | 117 | 
utils.snapshot_src('.', os.path.join(args.work_dir, 'src'), '.gitignore') 118 | 119 | print("---------------------------------------") 120 | print(f"Policy: {args.policy}, Env: {args.env}, Seed: {args.seed}") 121 | print("---------------------------------------") 122 | 123 | env = gym.make(args.env) 124 | 125 | # Set seeds 126 | env.seed(args.seed) 127 | env.action_space.seed(args.seed) 128 | torch.manual_seed(args.seed) 129 | np.random.seed(args.seed) 130 | 131 | state_dim = env.observation_space.shape[0] 132 | action_dim = env.action_space.shape[0] 133 | max_action = float(env.action_space.high[0]) 134 | 135 | kwargs = { 136 | "state_dim": state_dim, 137 | "action_dim": action_dim, 138 | "max_action": max_action, 139 | "discount": args.discount, 140 | "tau": args.tau, 141 | # TD3 142 | "policy_noise": args.policy_noise * max_action, 143 | "noise_clip": args.noise_clip * max_action, 144 | "policy_freq": args.policy_freq, 145 | # SPOT 146 | "lambd": args.lambd, 147 | "lr": args.lr, 148 | "actor_lr": args.actor_lr, 149 | "without_Q_norm": args.without_Q_norm, 150 | "num_samples": args.num_samples, 151 | "iwae": args.iwae, 152 | "actor_hidden_dim": args.actor_hidden_dim, 153 | "critic_hidden_dim": args.critic_hidden_dim, 154 | "actor_dropout": args.actor_dropout, 155 | "actor_init_w": args.actor_init_w, 156 | "critic_init_w": args.critic_init_w, 157 | # finetune 158 | # "lambd_cool": args.lambd_cool, 159 | # "lambd_end": args.lambd_end, 160 | } 161 | 162 | # Initialize policy 163 | if args.policy == 'SPOT_TD3': 164 | vae = VAE(state_dim, action_dim, args.latent_dim if args.latent_dim else 2 * action_dim, max_action).to(device) 165 | vae.load_state_dict(torch.load(args.vae_model_path)) 166 | vae.eval() 167 | 168 | kwargs['vae'] = vae 169 | kwargs['beta'] = args.beta 170 | policy = SPOT.SPOT_TD3(**kwargs) 171 | else: 172 | raise NotImplementedError 173 | 174 | replay_buffer = utils.ReplayBuffer(state_dim, action_dim) 175 | replay_buffer.convert_D4RL(d4rl.qlearning_dataset(env)) 176 | print("Dataset size:", replay_buffer.reward.shape[0]) 177 | if 'antmaze' in args.env and args.antmaze_center_reward is not None: 178 | # Center reward for Ant-Maze 179 | # See https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22 180 | replay_buffer.reward = np.where(replay_buffer.reward == 1.0, args.antmaze_center_reward, -1.0) 181 | if args.normalize and not ('antmaze' in args.env and args.antmaze_no_normalize): 182 | mean, std = replay_buffer.normalize_states() 183 | else: 184 | print("No normalize") 185 | mean, std = 0, 1 186 | if args.clip_to_eps: 187 | replay_buffer.clip_to_eps() 188 | 189 | logger = Logger(args.work_dir, use_tb=True) 190 | video = VideoRecorder(dir_name=args.video_dir) 191 | for t in trange(int(args.max_timesteps)): 192 | policy.train(replay_buffer, args.batch_size, logger=logger) 193 | 194 | # Evaluate episode 195 | if (t + 1) % args.eval_freq == 0: 196 | eval_episodes = 100 if t + 1 == int(args.max_timesteps) and 'antmaze' in args.env else args.eval_episodes 197 | d4rl_score = eval_policy(args, t + 1, video, logger, policy, args.env, 198 | args.seed, mean, std, eval_episodes=eval_episodes) 199 | if args.save_model: 200 | policy.save(args.model_dir) 201 | 202 | if args.save_model_final: 203 | policy.save(args.model_dir) 204 | 205 | logger._sw.close() 206 | -------------------------------------------------------------------------------- /main_finetune.py: -------------------------------------------------------------------------------- 1 | import numpy as np 
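# main_finetune.py: online fine-tuning of a SPOT agent that was pretrained offline.
# It loads the pretrained VAE (the behavior-density estimator) from --vae_model_path
# and the offline-trained SPOT_TD3 weights from --pretrain_model, pre-fills the replay
# buffer with the D4RL dataset, and then continues training from online interaction
# (see the "Online Fine-tuning" section of the README above).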
2 | import torch 3 | import gym 4 | import argparse 5 | import os 6 | import d4rl 7 | from tqdm import trange 8 | from coolname import generate_slug 9 | import time 10 | import json 11 | import yaml 12 | from log import Logger 13 | 14 | import utils 15 | from utils import VideoRecorder 16 | import SPOT 17 | from vae import VAE 18 | from eval import eval_policy 19 | 20 | 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | 23 | 24 | if __name__ == "__main__": 25 | 26 | parser = argparse.ArgumentParser() 27 | # Experiment 28 | parser.add_argument("--policy", default="SPOT_TD3") # Policy name 29 | parser.add_argument("--env", default="hopper-medium-v0") # OpenAI gym environment name 30 | parser.add_argument("--seed", default=0, type=int) # Sets Gym, PyTorch and Numpy seeds 31 | parser.add_argument("--eval_freq", default=5e3, type=int) # How often (time steps) we evaluate 32 | parser.add_argument("--max_timesteps", default=1e6, type=int) # Max time steps to run environment 33 | parser.add_argument("--save_model", default=False, action="store_true") # Save model and optimizer parameters 34 | parser.add_argument('--save_model_final', default=True, action='store_true') 35 | parser.add_argument('--eval_episodes', default=10, type=int) 36 | parser.add_argument('--save_video', default=False, action='store_true') 37 | parser.add_argument('--clip_to_eps', default=False, action='store_true') 38 | # TD3 39 | parser.add_argument("--expl_noise", default=0.1, type=float) # Std of Gaussian exploration noise 40 | parser.add_argument("--batch_size", default=256, type=int) # Batch size for both actor and critic 41 | parser.add_argument("--discount", default=0.99, type=float) # Discount factor 42 | parser.add_argument("--tau", default=0.005) # Target network update rate 43 | parser.add_argument("--policy_noise", default=0.2, type=float) # Noise added to target policy during critic update 44 | parser.add_argument("--noise_clip", default=0.5, type=float) # Range to clip target policy noise 45 | parser.add_argument("--policy_freq", default=2, type=int) # Frequency of delayed policy updates 46 | parser.add_argument('--lr', default=3e-4, type=float) 47 | parser.add_argument('--actor_lr', default=None, type=float) 48 | # TD3 actor-critic 49 | parser.add_argument('--actor_hidden_dim', default=256, type=int) 50 | parser.add_argument('--critic_hidden_dim', default=256, type=int) 51 | parser.add_argument('--actor_init_w', default=None, type=float) 52 | parser.add_argument('--critic_init_w', default=None, type=float) 53 | parser.add_argument('--actor_dropout', default=0.1, type=float) 54 | # TD3 + BC 55 | parser.add_argument("--alpha", default=0.4, type=float) 56 | parser.add_argument("--normalize", default=True) 57 | # VAE 58 | parser.add_argument('--vae_model_path', default=None, type=str) 59 | parser.add_argument('--beta', default=0.5, type=float) 60 | parser.add_argument('--latent_dim', default=None, type=int) 61 | parser.add_argument('--iwae', default=False, action='store_true') 62 | parser.add_argument('--num_samples', default=1, type=int) 63 | # SPOT 64 | parser.add_argument('--lambd', default=1.0, type=float) 65 | parser.add_argument('--without_Q_norm', default=False, action='store_true') 66 | parser.add_argument('--lambd_cool', default=False, action='store_true') 67 | parser.add_argument('--lambd_end', default=0.2, type=float) 68 | # Antmaze 69 | parser.add_argument('--antmaze_center_reward', default=0.0, type=float) 70 | parser.add_argument('--antmaze_no_normalize', default=False, 
action='store_true') 71 | # Work dir 72 | parser.add_argument('--notes', default=None, type=str) 73 | parser.add_argument('--work_dir', default='tmp', type=str) 74 | # Config 75 | parser.add_argument('--config', default=None, type=str) 76 | 77 | # Finetune 78 | parser.add_argument('--pretrain_model', default=None, type=str) 79 | parser.add_argument('--pretrain_step', default=1000000, type=int) 80 | parser.add_argument('--buffer_size', default=2000000, type=int) 81 | parser.add_argument('--start_times', default=0, type=int) 82 | 83 | args = parser.parse_args() 84 | # log config 85 | if args.config is not None: 86 | with open(args.config, 'r') as f: 87 | parser.set_defaults(**yaml.load(f.read(), Loader=yaml.FullLoader)) 88 | args = parser.parse_args() 89 | 90 | args.cooldir = generate_slug(2) 91 | 92 | # Build work dir 93 | base_dir = 'runs' 94 | utils.make_dir(base_dir) 95 | base_dir = os.path.join(base_dir, args.work_dir) 96 | utils.make_dir(base_dir) 97 | args.work_dir = os.path.join(base_dir, args.env) 98 | utils.make_dir(args.work_dir) 99 | 100 | # make directory 101 | ts = time.gmtime() 102 | ts = time.strftime("%m-%d-%H:%M", ts) 103 | exp_name = str(args.env) + '-' + ts + '-bs' + str(args.batch_size) + '-s' + str(args.seed) 104 | if args.policy == 'SPOT_TD3': 105 | exp_name += '-lamb' + str(args.lambd) + '-lamb_end' + str(args.lambd_end) + '-b' + \ 106 | str(args.beta) + '-a' + str(args.antmaze_center_reward) + '-lr' + str(args.lr) 107 | else: 108 | raise NotImplementedError 109 | exp_name += '-' + args.cooldir 110 | if args.notes is not None: 111 | exp_name = args.notes + '_' + exp_name 112 | args.work_dir = args.work_dir + '/' + exp_name 113 | utils.make_dir(args.work_dir) 114 | 115 | args.model_dir = os.path.join(args.work_dir, 'model') 116 | utils.make_dir(args.model_dir) 117 | args.video_dir = os.path.join(args.work_dir, 'video') 118 | utils.make_dir(args.video_dir) 119 | 120 | with open(os.path.join(args.work_dir, 'args.json'), 'w') as f: 121 | json.dump(vars(args), f, sort_keys=True, indent=4) 122 | 123 | utils.snapshot_src('.', os.path.join(args.work_dir, 'src'), '.gitignore') 124 | 125 | print("---------------------------------------") 126 | print(f"Policy: {args.policy}, Env: {args.env}, Seed: {args.seed}") 127 | print("---------------------------------------") 128 | 129 | env = gym.make(args.env) 130 | 131 | # Set seeds 132 | env.seed(args.seed) 133 | env.action_space.seed(args.seed) 134 | torch.manual_seed(args.seed) 135 | np.random.seed(args.seed) 136 | 137 | state_dim = env.observation_space.shape[0] 138 | action_dim = env.action_space.shape[0] 139 | max_action = float(env.action_space.high[0]) 140 | 141 | kwargs = { 142 | "state_dim": state_dim, 143 | "action_dim": action_dim, 144 | "max_action": max_action, 145 | "discount": args.discount, 146 | "tau": args.tau, 147 | # TD3 148 | "policy_noise": args.policy_noise * max_action, 149 | "noise_clip": args.noise_clip * max_action, 150 | "policy_freq": args.policy_freq, 151 | # SPOT 152 | "lambd": args.lambd, 153 | "lr": args.lr, 154 | "actor_lr": args.actor_lr, 155 | "without_Q_norm": args.without_Q_norm, 156 | "num_samples": args.num_samples, 157 | "iwae": args.iwae, 158 | "actor_hidden_dim": args.actor_hidden_dim, 159 | "critic_hidden_dim": args.critic_hidden_dim, 160 | "actor_dropout": args.actor_dropout, 161 | "actor_init_w": args.actor_init_w, 162 | "critic_init_w": args.critic_init_w, 163 | # finetune 164 | "lambd_cool": args.lambd_cool, 165 | "lambd_end": args.lambd_end, 166 | } 167 | 168 | # Initialize policy 169 | 
if args.policy == 'SPOT_TD3': 170 | vae = VAE(state_dim, action_dim, args.latent_dim if args.latent_dim else 2 * action_dim, max_action).to(device) 171 | vae.load_state_dict(torch.load(args.vae_model_path)) 172 | vae.eval() 173 | 174 | kwargs['vae'] = vae 175 | kwargs['beta'] = args.beta 176 | policy = SPOT.SPOT_TD3(**kwargs) 177 | else: 178 | raise NotImplementedError 179 | 180 | replay_buffer = utils.ReplayBuffer(state_dim, action_dim, max_size=args.buffer_size) 181 | replay_buffer.convert_D4RL_finetune(d4rl.qlearning_dataset(env)) 182 | assert replay_buffer.size + args.max_timesteps <= replay_buffer.max_size 183 | if 'antmaze' in args.env and args.antmaze_center_reward is not None: 184 | # Center reward for Ant-Maze 185 | # See https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22 186 | replay_buffer.reward = np.where(replay_buffer.reward == 1.0, args.antmaze_center_reward, -1.0) 187 | if args.normalize and not ('antmaze' in args.env and args.antmaze_no_normalize): 188 | assert False 189 | # TODO: normalize to self.state[:self.ptr] 190 | mean, std = replay_buffer.normalize_states() 191 | else: 192 | print("No normalize") 193 | mean, std = 0, 1 194 | if args.clip_to_eps: 195 | replay_buffer.clip_to_eps() 196 | 197 | logger = Logger(args.work_dir, use_tb=True) 198 | video = VideoRecorder(dir_name=args.video_dir) 199 | 200 | state, done = env.reset(), False 201 | episode_reward = 0 202 | episode_timesteps = 0 203 | episode_num = 0 204 | episode_state = [state] 205 | 206 | # load offline pretrained model 207 | if args.pretrain_model is None: 208 | print("No pretrained model") 209 | exit(0) 210 | else: 211 | policy.load(args.pretrain_model, args.pretrain_step) 212 | 213 | for t in trange(int(args.max_timesteps)): 214 | episode_timesteps += 1 215 | 216 | # Select action according to policy with exploration noise 217 | action = ( 218 | policy.select_action(state) + np.random.normal(0, max_action * args.expl_noise, size=action_dim) 219 | ).clip(-max_action, max_action) 220 | 221 | # Perform action 222 | next_state, reward, done, _ = env.step(action) 223 | done_bool = float(done) if episode_timesteps < env._max_episode_steps else 0 224 | 225 | reward_original = reward  # keep the raw environment reward for episode logging 226 | if 'antmaze' in args.env and args.antmaze_center_reward is not None: 227 | reward = args.antmaze_center_reward if done_bool else -1.0 228 | 229 | # Store data in replay buffer 230 | replay_buffer.add(state, action, next_state, reward, done_bool) 231 | 232 | state = next_state 233 | episode_reward += reward_original 234 | episode_state.append(state) 235 | 236 | if t >= args.start_times: 237 | policy.train_online(replay_buffer, args.batch_size, logger=logger) 238 | # policy.train(replay_buffer, args.batch_size, logger=logger) 239 | 240 | if done: 241 | # +1 to account for 0 indexing.
+0 on ep_timesteps since it will increment +1 even if done=True 242 | print( 243 | f"Total T: {t+1} Episode Num: {episode_num+1} Episode T: {episode_timesteps} Reward: {episode_reward:.3f} Last Reward: {reward_original:.3f}") 244 | # Reset environment 245 | state, done = env.reset(), False 246 | episode_reward = 0 247 | episode_timesteps = 0 248 | episode_num += 1 249 | episode_state = [state] 250 | 251 | # Evaluate episode 252 | if t == 0 or (t + 1) % args.eval_freq == 0: 253 | eval_episodes = 100 if t + 1 == int(args.max_timesteps) and 'antmaze' in args.env else args.eval_episodes 254 | d4rl_score = eval_policy(args, t + 1, video, logger, policy, args.env, 255 | args.seed, mean, std, eval_episodes=eval_episodes) 256 | if args.save_model: 257 | policy.save(args.model_dir) 258 | 259 | if args.save_model_final: 260 | policy.save(args.model_dir) 261 | 262 | logger._sw.close() 263 | -------------------------------------------------------------------------------- /SPOT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import copy 6 | import os 7 | from vae import VAE 8 | from utils import ReplayBuffer 9 | 10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 11 | 12 | 13 | def weights_init_(m, init_w=3e-3): 14 | if isinstance(m, nn.Linear): 15 | m.weight.data.uniform_(-init_w, init_w) 16 | m.bias.data.uniform_(-init_w, init_w) 17 | 18 | 19 | class Actor(nn.Module): 20 | def __init__(self, state_dim, action_dim, max_action, dropout=None, hidden_dim=256, init_w=None): 21 | super(Actor, self).__init__() 22 | 23 | if dropout: 24 | self.l1 = nn.Sequential(nn.Linear(state_dim, hidden_dim), nn.Dropout(dropout)) 25 | self.l2 = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.Dropout(dropout)) 26 | else: 27 | self.l1 = nn.Linear(state_dim, hidden_dim) 28 | self.l2 = nn.Linear(hidden_dim, hidden_dim) 29 | self.l3 = nn.Linear(hidden_dim, action_dim) 30 | 31 | self.max_action = max_action 32 | 33 | if init_w: 34 | weights_init_(self.l3, init_w=init_w) 35 | 36 | def forward(self, state): 37 | a = F.relu(self.l1(state)) 38 | a = F.relu(self.l2(a)) 39 | a = self.l3(a) 40 | if self.max_action is not None: 41 | action = self.max_action * torch.tanh(a) 42 | else: 43 | action = a 44 | return action 45 | 46 | 47 | class Critic(nn.Module): 48 | def __init__(self, state_dim, action_dim, hidden_dim=256, init_w=None): 49 | super(Critic, self).__init__() 50 | 51 | # Q1 architecture 52 | self.l1 = nn.Linear(state_dim + action_dim, hidden_dim) 53 | self.l2 = nn.Linear(hidden_dim, hidden_dim) 54 | self.l3 = nn.Linear(hidden_dim, 1) 55 | 56 | # Q2 architecture 57 | self.l4 = nn.Linear(state_dim + action_dim, hidden_dim) 58 | self.l5 = nn.Linear(hidden_dim, hidden_dim) 59 | self.l6 = nn.Linear(hidden_dim, 1) 60 | 61 | if init_w: 62 | weights_init_(self.l3, init_w=init_w) 63 | weights_init_(self.l6, init_w=init_w) 64 | 65 | def forward(self, state, action): 66 | sa = torch.cat([state, action], 1) 67 | 68 | q1 = F.relu(self.l1(sa)) 69 | q1 = F.relu(self.l2(q1)) 70 | q1 = self.l3(q1) 71 | 72 | q2 = F.relu(self.l4(sa)) 73 | q2 = F.relu(self.l5(q2)) 74 | q2 = self.l6(q2) 75 | return q1, q2 76 | 77 | def Q1(self, state, action): 78 | sa = torch.cat([state, action], 1) 79 | 80 | q1 = F.relu(self.l1(sa)) 81 | q1 = F.relu(self.l2(q1)) 82 | q1 = self.l3(q1) 83 | return q1 84 | 85 | 86 | class SPOT_TD3(object): 87 | def __init__( 88 | self, 89 | vae: VAE, 90 | state_dim, 91 | action_dim, 92 
| max_action, 93 | discount=0.99, 94 | tau=0.005, 95 | policy_noise=0.2, 96 | noise_clip=0.5, 97 | policy_freq=2, 98 | beta=0.5, 99 | lambd=1.0, 100 | lr=3e-4, 101 | actor_lr=None, 102 | without_Q_norm=False, 103 | # density estimation 104 | num_samples=1, 105 | iwae=False, 106 | # actor-critic 107 | actor_hidden_dim=256, 108 | critic_hidden_dim=256, 109 | actor_dropout=0.1, 110 | actor_init_w=None, 111 | critic_init_w=None, 112 | # finetune 113 | lambd_cool=False, 114 | lambd_end=0.2, 115 | ): 116 | self.total_it = 0 117 | 118 | # Actor 119 | self.actor = Actor(state_dim, action_dim, max_action, dropout=actor_dropout, 120 | hidden_dim=actor_hidden_dim, init_w=actor_init_w).to(device) 121 | self.actor_target = copy.deepcopy(self.actor) 122 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr or lr) 123 | 124 | # Critic 125 | self.critic = Critic(state_dim, action_dim, hidden_dim=critic_hidden_dim, init_w=critic_init_w).to(device) 126 | self.critic_target = copy.deepcopy(self.critic) 127 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr) 128 | 129 | # TD3 130 | self.state_dim = state_dim 131 | self.action_dim = action_dim 132 | self.max_action = max_action 133 | self.discount = discount 134 | self.tau = tau 135 | self.policy_noise = policy_noise 136 | self.noise_clip = noise_clip 137 | self.policy_freq = policy_freq 138 | 139 | # density estimation 140 | self.vae = vae 141 | self.beta = beta 142 | self.num_samples = num_samples 143 | self.iwae = iwae 144 | self.without_Q_norm = without_Q_norm 145 | 146 | # support constraint 147 | self.lambd = lambd 148 | 149 | # fine-tuning 150 | self.lambd_cool = lambd_cool 151 | self.lambd_end = lambd_end 152 | 153 | def select_action(self, state): 154 | with torch.no_grad(): 155 | self.actor.eval() 156 | state = torch.FloatTensor(state.reshape(1, -1)).to(device) 157 | action = self.actor(state).cpu().data.numpy().flatten() 158 | self.actor.train() 159 | return action 160 | 161 | def train(self, replay_buffer: ReplayBuffer, batch_size=256, logger=None): 162 | self.total_it += 1 163 | 164 | # Sample replay buffer 165 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 166 | 167 | with torch.no_grad(): 168 | # Select action according to policy and add clipped noise 169 | noise = ( 170 | torch.randn_like(action) * self.policy_noise 171 | ).clamp(-self.noise_clip, self.noise_clip) 172 | 173 | next_action = ( 174 | self.actor_target(next_state) + noise 175 | ).clamp(-self.max_action, self.max_action) 176 | 177 | # Compute the target Q value 178 | target_Q1, target_Q2 = self.critic_target(next_state, next_action) 179 | target_Q = torch.min(target_Q1, target_Q2) 180 | 181 | target_Q = reward + not_done * self.discount * target_Q 182 | 183 | # Get current Q estimates 184 | current_Q1, current_Q2 = self.critic(state, action) 185 | 186 | # Compute critic loss 187 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) 188 | 189 | # Optimize the critic 190 | self.critic_optimizer.zero_grad() 191 | critic_loss.backward() 192 | self.critic_optimizer.step() 193 | 194 | # Log 195 | logger.log('train/critic_loss', critic_loss, self.total_it) 196 | 197 | # Delayed policy updates 198 | if self.total_it % self.policy_freq == 0: 199 | 200 | # Compute actor loss 201 | pi = self.actor(state) 202 | Q = self.critic.Q1(state, pi) 203 | 204 | if self.iwae: 205 | neg_log_beta = self.vae.iwae_loss(state, pi, self.beta, self.num_samples) 206 | else: 207 | neg_log_beta = 
self.vae.elbo_loss(state, pi, self.beta, self.num_samples) 208 | 209 | if self.without_Q_norm: 210 | actor_loss = - Q.mean() + self.lambd * neg_log_beta.mean() 211 | else: 212 | actor_loss = - Q.mean() / Q.abs().mean().detach() + self.lambd * neg_log_beta.mean() 213 | 214 | # Optimize the actor 215 | self.actor_optimizer.zero_grad() 216 | actor_loss.backward() 217 | self.actor_optimizer.step() 218 | 219 | # Log 220 | logger.log('train/Q', Q.mean(), self.total_it) 221 | logger.log('train/actor_loss', actor_loss, self.total_it) 222 | logger.log('train/neg_log_beta', neg_log_beta.mean(), self.total_it) 223 | logger.log('train/neg_log_beta_max', neg_log_beta.max(), self.total_it) 224 | 225 | # kill for diverging 226 | if Q.mean().item() > 1e4: 227 | exit(0) 228 | 229 | # Update the frozen target models 230 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 231 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 232 | 233 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 234 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 235 | 236 | def train_online(self, replay_buffer: ReplayBuffer, batch_size=256, logger=None): 237 | self.total_it += 1 238 | 239 | # Sample replay buffer 240 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 241 | 242 | with torch.no_grad(): 243 | # Select action according to policy and add clipped noise 244 | noise = ( 245 | torch.randn_like(action) * self.policy_noise 246 | ).clamp(-self.noise_clip, self.noise_clip) 247 | 248 | next_action = ( 249 | self.actor_target(next_state) + noise 250 | ).clamp(-self.max_action, self.max_action) 251 | 252 | # Compute the target Q value 253 | target_Q1, target_Q2 = self.critic_target(next_state, next_action) 254 | target_Q = torch.min(target_Q1, target_Q2) 255 | 256 | target_Q = reward + not_done * self.discount * target_Q 257 | 258 | # Get current Q estimates 259 | current_Q1, current_Q2 = self.critic(state, action) 260 | 261 | # Compute critic loss 262 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) 263 | 264 | # Optimize the critic 265 | self.critic_optimizer.zero_grad() 266 | critic_loss.backward() 267 | self.critic_optimizer.step() 268 | 269 | # Log 270 | logger.log('train/critic_loss', critic_loss, self.total_it) 271 | 272 | # Delayed policy updates 273 | if self.total_it % self.policy_freq == 0: 274 | 275 | # Compute actor loss 276 | pi = self.actor(state) 277 | Q = self.critic.Q1(state, pi) 278 | 279 | if self.iwae: 280 | neg_log_beta = self.vae.iwae_loss(state, pi, self.beta, self.num_samples) 281 | else: 282 | neg_log_beta = self.vae.elbo_loss(state, pi, self.beta, self.num_samples) 283 | 284 | if self.lambd_cool: 285 | lambd = self.lambd * max(self.lambd_end, (1.0 - self.total_it / 1000000)) 286 | logger.log('train/lambd', lambd, self.total_it) 287 | else: 288 | lambd = self.lambd 289 | 290 | if self.without_Q_norm: 291 | actor_loss = - Q.mean() + lambd * neg_log_beta.mean() 292 | else: 293 | actor_loss = - Q.mean() / Q.abs().mean().detach() + lambd * neg_log_beta.mean() 294 | 295 | # Optimize the actor 296 | self.actor_optimizer.zero_grad() 297 | actor_loss.backward() 298 | self.actor_optimizer.step() 299 | 300 | # Log 301 | logger.log('train/Q', Q.mean(), self.total_it) 302 | logger.log('train/actor_loss', actor_loss, self.total_it) 303 | logger.log('train/neg_log_beta', neg_log_beta.mean(), self.total_it) 
304 | logger.log('train/neg_log_beta_max', neg_log_beta.max(), self.total_it) 305 | 306 | # Update the frozen target models 307 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 308 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 309 | 310 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 311 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 312 | 313 | def save(self, model_dir): 314 | torch.save(self.critic.state_dict(), os.path.join(model_dir, f"critic_s{str(self.total_it)}.pth")) 315 | torch.save(self.critic_target.state_dict(), os.path.join(model_dir, f"critic_target_s{str(self.total_it)}.pth")) 316 | torch.save(self.critic_optimizer.state_dict(), os.path.join( 317 | model_dir, f"critic_optimizer_s{str(self.total_it)}.pth")) 318 | 319 | torch.save(self.actor.state_dict(), os.path.join(model_dir, f"actor_s{str(self.total_it)}.pth")) 320 | torch.save(self.actor_target.state_dict(), os.path.join(model_dir, f"actor_target_s{str(self.total_it)}.pth")) 321 | torch.save(self.actor_optimizer.state_dict(), os.path.join( 322 | model_dir, f"actor_optimizer_s{str(self.total_it)}.pth")) 323 | 324 | def load(self, model_dir, step=1000000): 325 | self.critic.load_state_dict(torch.load(os.path.join(model_dir, f"critic_s{str(step)}.pth"))) 326 | self.critic_target.load_state_dict(torch.load(os.path.join(model_dir, f"critic_target_s{str(step)}.pth"))) 327 | self.critic_optimizer.load_state_dict(torch.load(os.path.join(model_dir, f"critic_optimizer_s{str(step)}.pth"))) 328 | 329 | self.actor.load_state_dict(torch.load(os.path.join(model_dir, f"actor_s{str(step)}.pth"))) 330 | self.actor_target.load_state_dict(torch.load(os.path.join(model_dir, f"actor_target_s{str(step)}.pth"))) 331 | self.actor_optimizer.load_state_dict(torch.load(os.path.join(model_dir, f"actor_optimizer_s{str(step)}.pth"))) 332 | --------------------------------------------------------------------------------
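A note on the lambd_cool branch in SPOT_TD3.train_online above: when enabled, the support-constraint weight is annealed linearly from lambd down to a floor of lambd * lambd_end over a hardcoded horizon of 1,000,000 training iterations, and since load() does not restore total_it, that counter restarts at 0 when online fine-tuning begins. Apart from this schedule and the divergence guard in train (exit(0) once Q.mean() exceeds 1e4), train and train_online perform the same update. Below is a minimal sketch of the schedule with example values (lambd=0.25 and the constructor default lambd_end=0.2); the helper cooled_lambd is illustrative and not part of the repository.

# Illustrative sketch of the cooling rule in SPOT_TD3.train_online (not repository code).
def cooled_lambd(lambd: float, lambd_end: float, total_it: int, horizon: int = 1_000_000) -> float:
    # Linear decay of the constraint weight, clipped at the fraction lambd_end of its start value.
    return lambd * max(lambd_end, 1.0 - total_it / horizon)


if __name__ == "__main__":
    for it in (0, 200_000, 600_000, 800_000, 1_000_000):
        print(f"total_it={it:>9,d}  lambd={cooled_lambd(0.25, 0.2, it):.4f}")
    # total_it=        0  lambd=0.2500
    # total_it=  200,000  lambd=0.2000
    # total_it=  600,000  lambd=0.1000
    # total_it=  800,000  lambd=0.0500
    # total_it=1,000,000  lambd=0.0500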
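For orientation, here is an end-to-end usage sketch of the classes above: it loads the frozen VAE, builds SPOT_TD3, pretrains offline on the D4RL buffer, then fine-tunes online, mirroring main.py and main_finetune.py in a single process. The environment name, paths, step counts, exploration noise, and lambd value are placeholders, `from log import Logger` is an assumed import from log.py, and the real scripts' episode bookkeeping (timeout handling, AntMaze reward centering, evaluation, seeding) is omitted.

# Illustrative end-to-end sketch, not repository code; paths and step counts are placeholders.
import os

import d4rl
import gym
import numpy as np
import torch

import SPOT
import utils
from log import Logger  # assumed import; main_finetune.py calls Logger(work_dir, use_tb=True)
from vae import VAE

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env = gym.make("antmaze-umaze-v2")  # placeholder task
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

# Behavior-density VAE pretrained by train_vae.py, loaded frozen as in main_finetune.py.
vae = VAE(state_dim, action_dim, 2 * action_dim, max_action).to(device)
vae.load_state_dict(torch.load("models/vae_trained_models/vae_model_antmaze_umaze.pt"))
vae.eval()

policy = SPOT.SPOT_TD3(vae=vae, state_dim=state_dim, action_dim=action_dim,
                       max_action=max_action, lambd=0.25, lambd_cool=True)

# Buffer holds the offline D4RL dataset plus room for online transitions.
replay_buffer = utils.ReplayBuffer(state_dim, action_dim, max_size=2_000_000)
replay_buffer.convert_D4RL_finetune(d4rl.qlearning_dataset(env))

utils.make_dir("runs")
work_dir = os.path.join("runs", "sketch")  # placeholder run directory
utils.make_dir(work_dir)
model_dir = os.path.join(work_dir, "model")
utils.make_dir(model_dir)
logger = Logger(work_dir, use_tb=True)

# Offline pretraining: gradient steps on the static buffer (main.py does this in its own process).
for _ in range(1_000_000):  # placeholder step count
    policy.train(replay_buffer, batch_size=256, logger=logger)

# The real pipeline restarts in main_finetune.py with a fresh SPOT_TD3 and policy.load(),
# which leaves total_it at 0; reset it here so the lambd cooling starts from lambd.
policy.total_it = 0

# Online fine-tuning: act with exploration noise, store, and update with the cooled constraint.
state, done = env.reset(), False
for _ in range(250_000):  # placeholder step count
    action = (policy.select_action(state)
              + np.random.normal(0, max_action * 0.1, size=action_dim)  # placeholder expl_noise
              ).clip(-max_action, max_action)
    next_state, reward, done, _ = env.step(action)
    replay_buffer.add(state, action, next_state, reward, float(done))  # timeout handling omitted
    policy.train_online(replay_buffer, batch_size=256, logger=logger)
    state = env.reset() if done else next_state

policy.save(model_dir)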