├── .gitignore ├── README.md ├── configs ├── offline_IL.yaml └── offline_RL.yaml ├── dataset_init.py ├── main_odice_il.py ├── main_odice_rl.py ├── odice.py ├── policy.py ├── requirements.txt ├── run_offline_il.sh ├── util.py └── value_functions.py /.gitignore: -------------------------------------------------------------------------------- 1 | results 2 | wandb 3 | __pycache__ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ODICE: Revealing the Mystery of Distribution Correction Estimation via Orthogonal-gradient Update 2 | 3 | This is the official implementation of the paper **ODICE: Revealing the Mystery of Distribution Correction Estimation via Orthogonal-gradient Update**, accepted as a **Spotlight** at ICLR 2024. 4 | 5 | ### Usage 6 | To reproduce the experiments in the offline RL part, please run: 7 | ``` Bash 8 | python main_odice_rl.py --env_name your_env_name --Lambda your_lambda --eta your_eta --type orthogonal_true_g 9 | ``` 10 | 11 | To reproduce the experiments in the offline IL part, please run: 12 | ``` Bash 13 | python main_odice_il.py --env_name your_env_name --Lambda your_lambda --eta your_eta --type orthogonal_true_g 14 | ``` 15 | 16 | Note that although we set "--type" to "orthogonal_true_g" for ODICE, you can also check the results of the other gradient types ("true_g" and "semi_g") if you like. The choices of the other hyper-parameters are listed in Appendix D of the paper. 17 | 18 | ### BibTeX 19 | ``` 20 | @inproceedings{mao2024odice, 21 | title = {ODICE: Revealing the Mystery of Distribution Correction Estimation via Orthogonal-gradient Update}, 22 | author = {Liyuan Mao and Haoran Xu and Weinan Zhang and Xianyuan Zhan}, 23 | year = {2024}, 24 | booktitle = {International Conference on Learning Representations}, 25 | } 26 | ``` 27 | -------------------------------------------------------------------------------- /configs/offline_IL.yaml: -------------------------------------------------------------------------------- 1 | # experiment parameters 2 | log_dir: ./results/ 3 | model_dir: ./models/ 4 | #seed: 0 5 | train_steps: 500000 6 | eval_period: 5000 7 | n_eval_episodes: 10 8 | max_episode_steps: 1000 9 | load_step: 0 # step of model to load (if needed) 10 | 11 | # network parameters 12 | hidden_dim: 256 13 | n_hidden: 2 14 | batch_size: 256 15 | value_lr: 1.0e-4 16 | policy_lr: 1.0e-4 17 | #weight_decay: 1.0e-5 # 1.0e-5 18 | layer_norm: False 19 | use_tanh: True # use tanh for policy output 20 | 21 | # RL parameters 22 | f_name: Pearson_chi_square 23 | discount: 0.99 24 | normalize: True # use data normalization 25 | use_twin_v: True -------------------------------------------------------------------------------- /configs/offline_RL.yaml: -------------------------------------------------------------------------------- 1 | # experiment parameters 2 | log_dir: ./results/ 3 | model_dir: ./models/ 4 | seed: 0 5 | train_steps: 1000000 6 | eval_period: 5000 7 | n_eval_episodes: 10 8 | max_episode_steps: 1000 9 | load_step: 0 # step of model to load (if needed) 10 | 11 | # network parameters 12 | hidden_dim: 256 13 | n_hidden: 2 14 | batch_size: 256 15 | value_lr: 2.0e-4 16 | policy_lr: 2.0e-4 17 | weight_decay: 1.0e-4 18 | layer_norm: True 19 | use_tanh: False # use tanh for policy output 20 | 21 | # RL parameters 22 | f_name: Pearson_chi_square 23 | discount: 0.99 24 | normalize: False # use data normalization 25 | use_twin_v: True 
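The two YAML files above hold the experiment defaults, while the training entry points only expose a handful of command-line flags and fill everything else in from the config. Below is a minimal sketch (not a file in this repository) of that merge, following the pattern at the bottom of `main_odice_rl.py`; the final `print` is only for illustration.

```python
import argparse, yaml

# Load the experiment defaults shipped in the repository (offline RL settings here).
with open("configs/offline_RL.yaml", "r") as f:
    config = yaml.safe_load(f)

# Only a few flags are exposed on the command line; everything else comes from the YAML file.
parser = argparse.ArgumentParser()
parser.add_argument("--env_name", type=str, default="hopper-medium-replay-v2")
parser.add_argument("--Lambda", type=float, default=0.6)

# Parsing into a namespace pre-filled with the YAML dict merges the two sources:
# values already present (e.g. batch_size, value_lr) are kept unless the flag is
# passed explicitly, and argparse defaults only fill keys the YAML does not define.
args = parser.parse_args(namespace=argparse.Namespace(**config))
print(args.env_name, args.Lambda, args.batch_size, args.value_lr)
```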
-------------------------------------------------------------------------------- /dataset_init.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import collections 4 | import numpy as np 5 | 6 | import d4rl.infos 7 | from d4rl.offline_env import set_dataset_path, get_keys 8 | 9 | SUPPRESS_MESSAGES = bool(os.environ.get('D4RL_SUPPRESS_IMPORT_ERROR', 0)) 10 | 11 | _ERROR_MESSAGE = 'Warning: %s failed to import. Set the environment variable D4RL_SUPPRESS_IMPORT_ERROR=1 to suppress this message.' 12 | 13 | try: 14 | import d4rl.locomotion 15 | import d4rl.hand_manipulation_suite 16 | import d4rl.pointmaze 17 | import d4rl.gym_minigrid 18 | import d4rl.gym_mujoco 19 | except ImportError as e: 20 | if not SUPPRESS_MESSAGES: 21 | print(_ERROR_MESSAGE % 'Mujoco-based envs', file=sys.stderr) 22 | print(e, file=sys.stderr) 23 | 24 | try: 25 | import d4rl.flow 26 | except ImportError as e: 27 | if not SUPPRESS_MESSAGES: 28 | print(_ERROR_MESSAGE % 'Flow', file=sys.stderr) 29 | print(e, file=sys.stderr) 30 | 31 | try: 32 | import d4rl.kitchen 33 | except ImportError as e: 34 | if not SUPPRESS_MESSAGES: 35 | print(_ERROR_MESSAGE % 'FrankaKitchen', file=sys.stderr) 36 | print(e, file=sys.stderr) 37 | 38 | try: 39 | import d4rl.carla 40 | except ImportError as e: 41 | if not SUPPRESS_MESSAGES: 42 | print(_ERROR_MESSAGE % 'CARLA', file=sys.stderr) 43 | print(e, file=sys.stderr) 44 | 45 | try: 46 | import d4rl.gym_bullet 47 | import d4rl.pointmaze_bullet 48 | except ImportError as e: 49 | if not SUPPRESS_MESSAGES: 50 | print(_ERROR_MESSAGE % 'GymBullet', file=sys.stderr) 51 | print(e, file=sys.stderr) 52 | 53 | def reverse_normalized_score(env_name, score): 54 | ref_min_score = d4rl.infos.REF_MIN_SCORE[env_name] 55 | ref_max_score = d4rl.infos.REF_MAX_SCORE[env_name] 56 | return (score * (ref_max_score - ref_min_score)) + ref_min_score 57 | 58 | def get_normalized_score(env_name, score): 59 | ref_min_score = d4rl.infos.REF_MIN_SCORE[env_name] 60 | ref_max_score = d4rl.infos.REF_MAX_SCORE[env_name] 61 | return (score - ref_min_score) / (ref_max_score - ref_min_score) 62 | 63 | def qlearning_dataset(env, dataset=None, terminate_on_end=False, **kwargs): 64 | """ 65 | Returns datasets formatted for use by standard Q-learning algorithms, 66 | with observations, actions, next_observations, rewards, and a terminal 67 | flag. 68 | 69 | Args: 70 | env: An OfflineEnv object. 71 | dataset: An optional dataset to pass in for processing. If None, 72 | the dataset will default to env.get_dataset() 73 | terminate_on_end (bool): Set done=True on the last timestep 74 | in a trajectory. Default is False, and will discard the 75 | last timestep in each trajectory. 76 | **kwargs: Arguments to pass to env.get_dataset(). 77 | 78 | Returns: 79 | A dictionary containing keys: 80 | observations: An N x dim_obs array of observations. 81 | actions: An N x dim_action array of actions. 82 | next_observations: An N x dim_obs array of next observations. 83 | rewards: An N-dim float array of rewards. 84 | terminals: An N-dim boolean array of "done" or episode termination flags. 85 | """ 86 | if dataset is None: 87 | dataset = env.get_dataset(**kwargs) 88 | 89 | N = dataset['rewards'].shape[0] 90 | obs_ = [] 91 | next_obs_ = [] 92 | action_ = [] 93 | reward_ = [] 94 | done_ = [] 95 | trajectory_done_ = [] 96 | 97 | # The newer version of the dataset adds an explicit 98 | # timeouts field. Keep old method for backwards compatability. 
99 | use_timeouts = False 100 | if 'timeouts' in dataset: 101 | use_timeouts = True 102 | 103 | episode_step = 0 104 | for i in range(N-1): 105 | obs = dataset['observations'][i].astype(np.float32) 106 | new_obs = dataset['observations'][i+1].astype(np.float32) 107 | action = dataset['actions'][i].astype(np.float32) 108 | reward = dataset['rewards'][i].astype(np.float32) 109 | done_bool = bool(dataset['terminals'][i]) 110 | 111 | if use_timeouts: 112 | final_timestep = dataset['timeouts'][i] 113 | else: 114 | final_timestep = (episode_step == env._max_episode_steps - 1) 115 | if (not terminate_on_end) and final_timestep: 116 | # Skip this transition and don't apply terminals on the last step of an episode 117 | episode_step = 0 118 | trajectory_done_[-1] = True 119 | continue 120 | if done_bool or final_timestep: 121 | episode_step = 0 122 | 123 | obs_.append(obs) 124 | next_obs_.append(new_obs) 125 | action_.append(action) 126 | reward_.append(reward) 127 | done_.append(done_bool) 128 | trajectory_done_.append(final_timestep) 129 | episode_step += 1 130 | 131 | return { 132 | 'observations': np.array(obs_), 133 | 'actions': np.array(action_), 134 | 'next_observations': np.array(next_obs_), 135 | 'rewards': np.array(reward_), 136 | 'terminals': np.array(done_), 137 | 'trajectory_terminals': np.array(trajectory_done_), 138 | } 139 | 140 | 141 | def sequence_dataset(env, dataset=None, **kwargs): 142 | """ 143 | Returns an iterator through trajectories. 144 | 145 | Args: 146 | env: An OfflineEnv object. 147 | dataset: An optional dataset to pass in for processing. If None, 148 | the dataset will default to env.get_dataset() 149 | **kwargs: Arguments to pass to env.get_dataset(). 150 | 151 | Returns: 152 | An iterator through dictionaries with keys: 153 | observations 154 | actions 155 | rewards 156 | terminals 157 | """ 158 | if dataset is None: 159 | dataset = env.get_dataset(**kwargs) 160 | 161 | N = dataset['rewards'].shape[0] 162 | data_ = collections.defaultdict(list) 163 | 164 | # The newer version of the dataset adds an explicit 165 | # timeouts field. Keep old method for backwards compatability. 
166 | use_timeouts = False 167 | if 'timeouts' in dataset: 168 | use_timeouts = True 169 | 170 | episode_step = 0 171 | for i in range(N): 172 | done_bool = bool(dataset['terminals'][i]) 173 | if use_timeouts: 174 | final_timestep = dataset['timeouts'][i] 175 | else: 176 | final_timestep = (episode_step == env._max_episode_steps - 1) 177 | 178 | for k in dataset: 179 | data_[k].append(dataset[k][i]) 180 | 181 | if done_bool or final_timestep: 182 | episode_step = 0 183 | episode_data = {} 184 | for k in data_: 185 | episode_data[k] = np.array(data_[k]) 186 | yield episode_data 187 | data_ = collections.defaultdict(list) 188 | 189 | episode_step += 1 190 | 191 | -------------------------------------------------------------------------------- /main_odice_il.py: -------------------------------------------------------------------------------- 1 | # path 2 | import os, sys 3 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | sys.path.append(BASE_DIR) 5 | import argparse, yaml 6 | import gym 7 | import os 8 | import d4rl 9 | import numpy as np 10 | import torch 11 | from tqdm import trange 12 | from odice import ODICE 13 | from policy import GaussianPolicy 14 | from value_functions import ValueFunction, TwinV 15 | from util import return_range, set_seed, sample_batch, torchify, evaluate 16 | import wandb 17 | import time 18 | 19 | 20 | def dataset_T_trajs(dataset, T, terminate_on_end=False): 21 | """ 22 | Returns Tth trajs from dataset. 23 | """ 24 | N = dataset['rewards'].shape[0] 25 | return_traj = [] 26 | obs_traj = [[]] 27 | next_obs_traj = [[]] 28 | action_traj = [[]] 29 | reward_traj = [[]] 30 | done_traj = [[]] 31 | 32 | for i in range(N-1): 33 | obs_traj[-1].append(dataset['observations'][i].astype(np.float32)) 34 | next_obs_traj[-1].append(dataset['observations'][i+1].astype(np.float32)) 35 | action_traj[-1].append(dataset['actions'][i].astype(np.float32)) 36 | reward_traj[-1].append(np.zeros_like(dataset['rewards'][i]).astype(np.float32)) 37 | done_traj[-1].append(bool(dataset['terminals'][i])) 38 | 39 | final_timestep = dataset['timeouts'][i] | dataset['terminals'][i] 40 | if (not terminate_on_end) and final_timestep: 41 | # Skip this transition and don't apply terminals on the last step of an episode 42 | return_traj.append(np.sum(reward_traj[-1])) 43 | obs_traj.append([]) 44 | next_obs_traj.append([]) 45 | action_traj.append([]) 46 | reward_traj.append([]) 47 | done_traj.append([]) 48 | 49 | # select Tth trajectories 50 | inds_all = list(range(len(obs_traj))) 51 | assert T < len(inds_all) 52 | inds = inds_all[T:T+1] 53 | inds = list(inds) 54 | 55 | print('# select {}th trajs in the dataset'.format(T)) 56 | 57 | obs_traj = [obs_traj[i] for i in inds] 58 | next_obs_traj = [next_obs_traj[i] for i in inds] 59 | action_traj = [action_traj[i] for i in inds] 60 | reward_traj = [reward_traj[i] for i in inds] 61 | done_traj = [done_traj[i] for i in inds] 62 | 63 | def concat_trajectories(trajectories): 64 | return np.concatenate(trajectories, 0) 65 | 66 | return { 67 | 'observations': concat_trajectories(obs_traj), 68 | 'actions': concat_trajectories(action_traj), 69 | 'next_observations': concat_trajectories(next_obs_traj), 70 | 'rewards': concat_trajectories(reward_traj), 71 | 'terminals': concat_trajectories(done_traj), 72 | } 73 | 74 | 75 | def get_env_and_dataset(env_name, max_episode_steps, normalize, T): 76 | env = gym.make(env_name) 77 | dataset = env.get_dataset() 78 | dataset = dataset_T_trajs(dataset, T) 79 | 80 | dataset_length = 
len(dataset['terminals']) 81 | if any(s in env_name for s in ('halfcheetah', 'hopper', 'walker2d')): 82 | min_ret, max_ret = return_range(dataset, max_episode_steps) 83 | print(f'Dataset returns have range [{min_ret}, {max_ret}]') 84 | elif 'antmaze' in env_name: 85 | dataset['rewards'] = np.where(dataset['rewards'] == 0., -3.0, 0) 86 | 87 | 88 | print("***********************************************************************") 89 | print(f"Normalize for the state: {normalize}") 90 | print("***********************************************************************") 91 | if normalize: 92 | mean = dataset['observations'].mean(0) 93 | std = dataset['observations'].std(0) + 1e-3 94 | dataset['observations'] = (dataset['observations'] - mean)/std 95 | dataset['next_observations'] = (dataset['next_observations'] - mean)/std 96 | else: 97 | obs_dim = dataset['observations'].shape[1] 98 | mean, std = np.zeros(obs_dim), np.ones(obs_dim) 99 | 100 | for k, v in dataset.items(): 101 | dataset[k] = torchify(v) 102 | for k, v in list(dataset.items()): 103 | assert len(v) == dataset_length, 'Dataset values must have same length' 104 | 105 | return env, dataset, mean, std 106 | 107 | 108 | def main(args): 109 | # args.log_dir = '/'.join(__file__.split('/')[: -1]) + '/' + args.log_dir 110 | # args.model_dir = '/'.join(__file__.split('/')[: -1]) + '/' + args.model_dir 111 | if 'antmaze' in args.env_name: 112 | args.eval_period = 20000 if args.eval_period < 20000 else args.eval_period 113 | args.n_eval_episodes = 50 114 | 115 | wandb.init(project=f"odice_offline_IL", 116 | entity="your name", 117 | name=f"{args.env_name}_ODICE", 118 | config={ 119 | "env_name": args.env_name, 120 | "type": args.type, 121 | "seed": args.seed, 122 | "normalize": args.normalize, 123 | "Lambda": args.Lambda, 124 | "eta": args.eta, 125 | "use_twin_v": args.use_twin_v, 126 | "use_tanh": args.use_tanh, 127 | "f_name": args.f_name, 128 | "weight_decay": args.weight_decay, 129 | "gamma": args.discount, 130 | "T": args.T, 131 | }) 132 | torch.set_num_threads(1) 133 | 134 | env, dataset, mean, std = get_env_and_dataset(args.env_name, 135 | args.max_episode_steps, 136 | args.normalize, 137 | args.T, 138 | ) 139 | obs_dim = dataset['observations'].shape[1] 140 | act_dim = dataset['actions'].shape[1] # this assume continuous actions 141 | set_seed(args.seed, env=env) 142 | 143 | policy = GaussianPolicy(obs_dim, act_dim, hidden_dim=1024, n_hidden=2, use_tanh=args.use_tanh) 144 | vf = TwinV(obs_dim, layer_norm=args.layer_norm, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden) if args.use_twin_v else ValueFunction(obs_dim, layer_norm=args.layer_norm, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden) 145 | 146 | odice = ODICE( 147 | vf=vf, 148 | policy=policy, 149 | max_steps=args.train_steps, 150 | f_name=args.f_name, 151 | Lambda=args.Lambda, 152 | eta=args.eta, 153 | discount=args.discount, 154 | value_lr=args.value_lr, 155 | policy_lr=args.policy_lr, 156 | weight_decay=args.weight_decay, 157 | use_twin_v=args.use_twin_v, 158 | ) 159 | 160 | def eval(step): 161 | eval_returns = np.array([evaluate(env, policy, mean, std) \ 162 | for _ in range(args.n_eval_episodes)]) 163 | normalized_returns = d4rl.get_normalized_score(args.env_name, eval_returns) * 100.0 164 | return_info = {} 165 | return_info["normalized return mean"] = normalized_returns.mean() 166 | return_info["normalized return std"] = normalized_returns.std() 167 | return_info["percent difference 10"] = (normalized_returns[: 10].min() - normalized_returns[: 10].mean()) / 
normalized_returns[: 10].mean() 168 | wandb.log(return_info, step=step) 169 | 170 | print("---------------------------------------") 171 | print(f"Env: {args.env_name}, Evaluation over {args.n_eval_episodes} episodes: D4RL score: {normalized_returns.mean():.3f}") 172 | print("---------------------------------------") 173 | 174 | return normalized_returns.mean() 175 | 176 | algo_name = f"{args.type}_lambda-{args.Lambda}_gamma-{args.discount}_eta-{args.eta}_f_name-{args.f_name}_use_tanh-{args.use_tanh}_normalize-{args.normalize}_use_twin_v-{args.use_twin_v}" 177 | os.makedirs(f"{args.log_dir}/{args.env_name}/{algo_name}", exist_ok=True) 178 | eval_log = open(f"{args.log_dir}/{args.env_name}/{algo_name}/seed-{args.seed}.txt", 'w') 179 | for step in trange(args.train_steps): 180 | if args.type == 'orthogonal_true_g': 181 | odice.orthogonal_true_g_update(**sample_batch(dataset, args.batch_size)) 182 | elif args.type == 'true_g': 183 | odice.true_g_update(**sample_batch(dataset, args.batch_size)) 184 | elif args.type == 'semi_g': 185 | odice.semi_g_update(**sample_batch(dataset, args.batch_size)) 186 | 187 | if (step+1) % args.eval_period == 0: 188 | average_returns = eval(odice.step) 189 | eval_log.write(f'{step + 1}\tavg return: {average_returns}\t\n') 190 | eval_log.flush() 191 | eval_log.close() 192 | os.makedirs(f"{args.model_dir}/{args.env_name}", exist_ok=True) 193 | odice.save(f"{args.model_dir}/{args.env_name}") 194 | 195 | 196 | if __name__ == '__main__': 197 | from argparse import ArgumentParser 198 | parser = ArgumentParser() 199 | parser.add_argument('--env_name', type=str, default="hopper-expert-v2") 200 | parser.add_argument('--seed', type=int, default=1) 201 | parser.add_argument('--Lambda', type=float, default=0.4) 202 | parser.add_argument('--eta', type=float, default=1.0) 203 | parser.add_argument('--T', type=int, default=1) 204 | parser.add_argument('--weight_decay', type=float, default=1e-5) 205 | parser.add_argument("--type", type=str, choices=['orthogonal_true_g', 'true_g', 'semi_g'], default='orthogonal_true_g') 206 | with open("configs/offline_IL.yaml", "r") as file: 207 | config = yaml.safe_load(file) 208 | now = time.strftime("%Y%m%d_%H%M%S", time.localtime()) 209 | 210 | args = parser.parse_args(namespace=argparse.Namespace(**config)) 211 | 212 | main(args) -------------------------------------------------------------------------------- /main_odice_rl.py: -------------------------------------------------------------------------------- 1 | import argparse, yaml 2 | import gym 3 | import os 4 | import d4rl 5 | import numpy as np 6 | import torch 7 | from tqdm import trange 8 | from collections import defaultdict 9 | from odice import ODICE 10 | from policy import GaussianPolicy 11 | from value_functions import ValueFunction, TwinV 12 | from util import return_range, set_seed, sample_batch, torchify, evaluate 13 | import wandb 14 | import time 15 | 16 | 17 | 18 | def get_env_and_dataset(env_name, max_episode_steps, normalize): 19 | env = gym.make(env_name) 20 | dataset = d4rl.qlearning_dataset(env) 21 | dataset_length = len(dataset['terminals']) 22 | if any(s in env_name for s in ('halfcheetah', 'hopper', 'walker2d')): 23 | min_ret, max_ret = return_range(dataset, max_episode_steps) 24 | print(f'Dataset returns have range [{min_ret}, {max_ret}]') 25 | dataset['rewards'] /= (max_ret - min_ret) 26 | dataset['rewards'] *= max_episode_steps 27 | elif 'antmaze' in env_name: 28 | dataset['rewards'] = np.where(dataset['rewards'] == 0., -3.0, 0) 29 | 30 | 
print("***********************************************************************") 31 | print(f"Normalize for the state: {normalize}") 32 | print("***********************************************************************") 33 | if normalize: 34 | mean = dataset['observations'].mean(0) 35 | std = dataset['observations'].std(0) + 1e-3 36 | dataset['observations'] = (dataset['observations'] - mean)/std 37 | dataset['next_observations'] = (dataset['next_observations'] - mean)/std 38 | else: 39 | obs_dim = dataset['observations'].shape[1] 40 | mean, std = np.zeros(obs_dim), np.ones(obs_dim) 41 | 42 | for k, v in dataset.items(): 43 | dataset[k] = torchify(v) 44 | for k, v in list(dataset.items()): 45 | assert len(v) == dataset_length, 'Dataset values must have same length' 46 | 47 | return env, dataset, mean, std 48 | 49 | 50 | def main(args): 51 | args.log_dir = '/'.join(__file__.split('/')[: -1]) + '/' + args.log_dir 52 | args.model_dir = '/'.join(__file__.split('/')[: -1]) + '/' + args.model_dir 53 | if 'antmaze' in args.env_name: 54 | args.eval_period = 20000 if args.eval_period < 20000 else args.eval_period 55 | args.n_eval_episodes = 50 56 | args.layer_norm = False 57 | if 'large' in args.env_name or 'umaze-diverse' in args.env_name: 58 | args.use_twin_v = False 59 | 60 | wandb.init(project=f"odice_offline_RL", 61 | entity="your name", 62 | name=f"{args.env_name}_ODICE", 63 | config={ 64 | "env_name": args.env_name, 65 | "type": args.type, 66 | "seed": args.seed, 67 | "normalize": args.normalize, 68 | "Lambda": args.Lambda, 69 | "eta": args.eta, 70 | "use_twin_v": args.use_twin_v, 71 | "use_tanh": args.use_tanh, 72 | "f_name": args.f_name, 73 | "weight_decay": args.weight_decay, 74 | "gamma": args.discount, 75 | }) 76 | torch.set_num_threads(1) 77 | 78 | env, dataset, mean, std = get_env_and_dataset(args.env_name, 79 | args.max_episode_steps, 80 | args.normalize) 81 | obs_dim = dataset['observations'].shape[1] 82 | act_dim = dataset['actions'].shape[1] # this assume continuous actions 83 | set_seed(args.seed, env=env) 84 | 85 | policy = GaussianPolicy(obs_dim, act_dim, hidden_dim=1024, n_hidden=2, use_tanh=args.use_tanh) 86 | vf = TwinV(obs_dim, layer_norm=args.layer_norm, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden) if args.use_twin_v else ValueFunction(obs_dim, layer_norm=args.layer_norm, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden) 87 | 88 | odice = ODICE( 89 | vf=vf, 90 | policy=policy, 91 | max_steps=args.train_steps, 92 | f_name=args.f_name, 93 | Lambda=args.Lambda, 94 | eta=args.eta, 95 | discount=args.discount, 96 | value_lr=args.value_lr, 97 | policy_lr=args.policy_lr, 98 | weight_decay=args.weight_decay, 99 | use_twin_v = args.use_twin_v, 100 | ) 101 | if os.path.exists(f"{args.model_dir}/{args.env_name}" + f"/eta_{args.eta}_Lambda_{args.Lambda}_checkpoint_{args.load_step}.pth"): 102 | odice.load(f"{args.model_dir}/{args.env_name}", args.load_step) 103 | 104 | def eval(step): 105 | eval_returns = np.array([evaluate(env, policy, mean, std) \ 106 | for _ in range(args.n_eval_episodes)]) 107 | normalized_returns = d4rl.get_normalized_score(args.env_name, eval_returns) * 100.0 108 | return_info = {} 109 | return_info["normalized return mean"] = normalized_returns.mean() 110 | return_info["normalized return std"] = normalized_returns.std() 111 | return_info["percent difference 10"] = (normalized_returns[: 10].min() - normalized_returns[: 10].mean()) / normalized_returns[: 10].mean() 112 | wandb.log(return_info, step=step) 113 | 114 | 
print("---------------------------------------") 115 | print(f"Env: {args.env_name}, Evaluation over {args.n_eval_episodes} episodes: D4RL score: {normalized_returns.mean():.3f}") 116 | print("---------------------------------------") 117 | 118 | return normalized_returns.mean() 119 | 120 | algo_name = f"{args.type}_lambda-{args.Lambda}_gamma-{args.discount}_eta-{args.eta}_f_name-{args.f_name}_use_tanh-{args.use_tanh}_normalize-{args.normalize}_use_twin_v-{args.use_twin_v}" 121 | os.makedirs(f"{args.log_dir}/{args.env_name}/{algo_name}", exist_ok=True) 122 | eval_log = open(f"{args.log_dir}/{args.env_name}/{algo_name}/seed-{args.seed}.txt", 'w') 123 | for step in trange(args.train_steps): 124 | if args.type == 'orthogonal_true_g': 125 | odice.orthogonal_true_g_update(**sample_batch(dataset, args.batch_size)) 126 | elif args.type == 'true_g': 127 | odice.true_g_update(**sample_batch(dataset, args.batch_size)) 128 | elif args.type == 'semi_g': 129 | odice.semi_g_update(**sample_batch(dataset, args.batch_size)) 130 | 131 | if (step+1) % args.eval_period == 0: 132 | average_returns = eval(odice.step) 133 | eval_log.write(f'{step + 1}\tavg return: {average_returns}\t\n') 134 | eval_log.flush() 135 | eval_log.close() 136 | os.makedirs(f"{args.model_dir}/{args.env_name}", exist_ok=True) 137 | odice.save(f"{args.model_dir}/{args.env_name}") 138 | 139 | 140 | if __name__ == '__main__': 141 | from argparse import ArgumentParser 142 | parser = ArgumentParser() 143 | parser.add_argument('--env_name', type=str, default="hopper-medium-replay-v2") 144 | parser.add_argument('--Lambda', type=float, default=0.6) 145 | parser.add_argument('--eta', type=float, default=1.0) 146 | parser.add_argument("--type", type=str, choices=['orthogonal_true_g', 'true_g', 'semi_g'], default='orthogonal_true_g') 147 | with open("configs/offline_RL.yaml", "r") as file: 148 | config = yaml.safe_load(file) 149 | now = time.strftime("%Y%m%d_%H%M%S", time.localtime()) 150 | 151 | args = parser.parse_args(namespace=argparse.Namespace(**config)) 152 | 153 | main(args) -------------------------------------------------------------------------------- /odice.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.optim.lr_scheduler import CosineAnnealingLR 6 | import wandb 7 | import numpy as np 8 | from util import DEFAULT_DEVICE, update_exponential_moving_average 9 | 10 | 11 | EXP_ADV_MAX = 100. 
12 | 13 | 14 | def f_star(residual, name="Pearson_chi_square"): 15 | if name == "Reverse_KL": 16 | return torch.exp(residual - 1) 17 | elif name == "Pearson_chi_square": 18 | omega_star = torch.max(residual / 2 + 1, torch.zeros_like(residual)) 19 | return residual * omega_star - (omega_star - 1)**2 20 | 21 | 22 | def f_prime_inverse(residual, name='Pearson_chi_square'): 23 | if name == "Reverse_KL": 24 | return torch.exp(residual - 1) 25 | elif name == "Pearson_chi_square": 26 | return torch.max(residual, torch.zeros_like(residual)) 27 | 28 | 29 | class ODICE(nn.Module): 30 | def __init__(self, vf, policy, max_steps, f_name="Pearson_chi_square", Lambda=0.8, eta=1.0, 31 | use_twin_v=False, value_lr=1e-4, policy_lr=1e-4, weight_decay=1e-5, discount=0.99, beta=0.005): 32 | super().__init__() 33 | self.vf = vf.to(DEFAULT_DEVICE) 34 | self.vf_target = copy.deepcopy(vf).requires_grad_(False).to(DEFAULT_DEVICE) 35 | self.policy = policy.to(DEFAULT_DEVICE) 36 | self.v_optimizer = torch.optim.Adam(self.vf.parameters(), lr=value_lr) 37 | self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=policy_lr, weight_decay=weight_decay) 38 | self.state_feature = [] 39 | self.Lambda = Lambda 40 | self.eta = eta 41 | self.f_name = f_name 42 | self.use_twin_v = use_twin_v 43 | self.discount = discount 44 | self.beta = beta 45 | self.step = 0 46 | 47 | def orthogonal_true_g_update(self, observations, actions, next_observations, rewards, terminals): 48 | # the network will NOT update 49 | with torch.no_grad(): 50 | target_v = self.vf_target(observations) 51 | target_v_next = self.vf_target(next_observations) 52 | 53 | v = self.vf.both(observations) if self.use_twin_v else self.vf(observations) 54 | v_next = self.vf.both(next_observations) if self.use_twin_v else self.vf(next_observations) 55 | 56 | forward_residual = rewards + (1. - terminals.float()) * self.discount * target_v_next - v 57 | backward_residual = rewards + (1. 
- terminals.float()) * self.discount * v_next - target_v 58 | forward_dual_loss = torch.mean(self.Lambda * f_star(forward_residual, self.f_name)) 59 | backward_dual_loss = torch.mean(self.Lambda * self.eta * f_star(backward_residual, self.f_name)) 60 | pi_residual = forward_residual.clone().detach() 61 | td_mean, td_min, td_max = torch.mean(forward_residual), torch.min(forward_residual), torch.max(forward_residual) 62 | 63 | self.v_optimizer.zero_grad(set_to_none=True) 64 | forward_grad_list, backward_grad_list = [], [] 65 | forward_dual_loss.backward(retain_graph=True) 66 | for param in list(self.vf.parameters()): 67 | forward_grad_list.append(param.grad.clone().detach().reshape(-1)) 68 | backward_dual_loss.backward() 69 | for i, param in enumerate(list(self.vf.parameters())): 70 | backward_grad_list.append(param.grad.clone().detach().reshape(-1) - forward_grad_list[i]) 71 | forward_grad, backward_grad = torch.cat(forward_grad_list), torch.cat(backward_grad_list) 72 | parallel_coef = (torch.dot(forward_grad, backward_grad) / max(torch.dot(forward_grad, forward_grad), 1e-10)).item() # avoid zero grad caused by f* 73 | forward_grad = (1 - parallel_coef) * forward_grad + backward_grad 74 | 75 | param_idx = 0 76 | for i, grad in enumerate(forward_grad_list): 77 | forward_grad_list[i] = forward_grad[param_idx: param_idx + grad.shape[0]] 78 | param_idx += grad.shape[0] 79 | 80 | self.v_optimizer.zero_grad(set_to_none=True) 81 | torch.mean((1 - self.Lambda) * v).backward() 82 | for i, param in enumerate(list(self.vf.parameters())): 83 | param.grad += forward_grad_list[i].reshape(param.grad.shape) 84 | 85 | self.v_optimizer.step() 86 | 87 | # Update target V network 88 | update_exponential_moving_average(self.vf_target, self.vf, self.beta) 89 | 90 | # Update policy 91 | weight = f_prime_inverse(pi_residual, self.f_name) 92 | weight = torch.clamp_max(weight, EXP_ADV_MAX).detach() 93 | policy_out = self.policy(observations) 94 | bc_losses = -policy_out.log_prob(actions) 95 | policy_loss = torch.mean(weight * bc_losses) 96 | self.policy_optimizer.zero_grad(set_to_none=True) 97 | policy_loss.backward() 98 | self.policy_optimizer.step() 99 | 100 | # wandb 101 | if (self.step + 1) % 10000 == 0: 102 | wandb.log({"v_value": v.mean(), "weight_max": weight.max(), "weight_min": weight.min(), 103 | "td_mean": td_mean, "td_min": td_min, "td_max": td_max, }, step=self.step) 104 | 105 | self.step += 1 106 | 107 | def true_g_update(self, observations, actions, next_observations, rewards, terminals): 108 | v = self.vf.both(observations) if self.use_twin_v else self.vf(observations) 109 | v_next = self.vf.both(next_observations) if self.use_twin_v else self.vf(next_observations) 110 | 111 | residual = rewards + (1. 
- terminals.float()) * self.discount * v_next - v 112 | dual_loss = f_star(residual, self.f_name) 113 | pi_residual = residual.clone().detach() 114 | td_mean, td_min, td_max = torch.mean(residual), torch.min(residual), torch.max(residual) 115 | 116 | v_loss = torch.mean(((1 - self.Lambda) * v + self.Lambda * dual_loss)) 117 | self.v_optimizer.zero_grad(set_to_none=True) 118 | v_loss.backward() 119 | self.v_optimizer.step() 120 | 121 | # Update target V network 122 | update_exponential_moving_average(self.vf_target, self.vf, self.beta) 123 | 124 | # Update policy 125 | weight = f_prime_inverse(pi_residual, self.f_name) 126 | weight = torch.clamp_max(weight, EXP_ADV_MAX).detach() 127 | policy_out = self.policy(observations) 128 | bc_losses = -policy_out.log_prob(actions) 129 | policy_loss = torch.mean(weight * bc_losses) 130 | self.policy_optimizer.zero_grad(set_to_none=True) 131 | policy_loss.backward() 132 | self.policy_optimizer.step() 133 | 134 | # wandb 135 | if (self.step + 1) % 10000 == 0: 136 | wandb.log({"v_value": v.mean(), "weight_max": weight.max(), "weight_min": weight.min(), 137 | "td_mean": td_mean, "td_min": td_min, "td_max": td_max, }, step=self.step) 138 | 139 | self.step += 1 140 | 141 | def semi_g_update(self, observations, actions, next_observations, rewards, terminals): 142 | # the network will NOT update 143 | with torch.no_grad(): 144 | target_v_next = self.vf_target(next_observations) 145 | 146 | v = self.vf.both(observations) if self.use_twin_v else self.vf(observations) 147 | 148 | TD_error = rewards + (1. - terminals.float()) * self.discount * target_v_next - v 149 | dual_loss = f_star(TD_error, self.f_name) 150 | pi_residual = TD_error.clone().detach() 151 | td_mean, td_min, td_max = torch.mean(TD_error), torch.min(TD_error), torch.max(TD_error) 152 | 153 | v_loss = torch.mean(((1 - self.Lambda) * v + self.Lambda * dual_loss)) 154 | self.v_optimizer.zero_grad(set_to_none=True) 155 | v_loss.backward() 156 | self.v_optimizer.step() 157 | 158 | # Update target V network 159 | update_exponential_moving_average(self.vf_target, self.vf, self.beta) 160 | 161 | # Update policy 162 | weight = f_prime_inverse(pi_residual, self.f_name) 163 | weight = torch.clamp_max(weight, EXP_ADV_MAX).detach() 164 | policy_out = self.policy(observations) 165 | bc_losses = -policy_out.log_prob(actions) 166 | policy_loss = torch.mean(weight * bc_losses) 167 | self.policy_optimizer.zero_grad(set_to_none=True) 168 | policy_loss.backward() 169 | self.policy_optimizer.step() 170 | 171 | # wandb 172 | if (self.step + 1) % 10000 == 0: 173 | wandb.log({"v_value": v.mean(), "weight_max": weight.max(), "weight_min": weight.min(), 174 | "td_mean": td_mean, "td_min": td_min, "td_max": td_max, }, step=self.step) 175 | 176 | self.step += 1 177 | 178 | def get_activation(self): 179 | def hook(model, input, output): 180 | self.state_feature.append(output.detach()) 181 | return hook 182 | 183 | def save(self, model_dir): 184 | checkpoint = { 185 | 'step': self.step, 186 | 'vf': self.vf.state_dict(), 187 | 'vf_target': self.vf_target.state_dict(), 188 | 'policy': self.policy.state_dict(), 189 | 'v_optimizer': self.v_optimizer.state_dict(), 190 | 'policy_optimizer': self.policy_optimizer.state_dict(), 191 | } 192 | torch.save(checkpoint, model_dir + f"/eta_{self.eta}_Lambda_{self.Lambda}_checkpoint_{self.step}.pth") 193 | print(f"***save models to {model_dir}***") 194 | 195 | def load(self, model_dir, step): 196 | checkpoint = torch.load(model_dir + 
f"/eta_{self.eta}_Lambda_{self.Lambda}_checkpoint_{step}.pth") 197 | self.step = checkpoint['step'] 198 | self.vf.load_state_dict(checkpoint['vf']) 199 | self.vf_target.load_state_dict(checkpoint['vf_target']) 200 | self.policy.load_state_dict(checkpoint['policy']) 201 | self.v_optimizer.load_state_dict(checkpoint['v_optimizer']) 202 | self.policy_optimizer.load_state_dict(checkpoint['policy_optimizer']) 203 | print(f"***load the model from {model_dir}***") 204 | -------------------------------------------------------------------------------- /policy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.distributions import MultivariateNormal, Normal 4 | 5 | from util import mlp 6 | 7 | 8 | LOG_STD_MIN = -5.0 9 | LOG_STD_MAX = 2.0 10 | epsilon = 1e-6 11 | 12 | class SquashedGaussianPolicy(nn.Module): 13 | """Squashed Gaussian Actor, which maps the given obs to a parameterized Gaussian Distribution, 14 | followed by a Tanh transformation to squash the action sample to [-1, 1]. """ 15 | def __init__(self, obs_dim, act_dim, hidden_dim=256, n_hidden=2, conditioned_logstd=True): 16 | super().__init__() 17 | self.conditioned_logstd = conditioned_logstd 18 | if self.conditioned_logstd is True: 19 | self.net = mlp([obs_dim, *([hidden_dim] * n_hidden), 2 * act_dim]) 20 | else: 21 | self.net = mlp([obs_dim, *([hidden_dim] * n_hidden), act_dim]) 22 | self.log_std = nn.Parameter(torch.zeros(act_dim, dtype=torch.float32)) 23 | 24 | def forward(self, obs): 25 | out = self.net(obs) 26 | if self.conditioned_logstd is True: 27 | mean, self.log_std = torch.split(out, out.shape[-1] // 2, dim=-1) 28 | else: 29 | mean = out 30 | std = torch.exp(self.log_std.clamp(LOG_STD_MIN, LOG_STD_MAX)) 31 | return Normal(mean, std) 32 | 33 | def sample(self, obs, deterministic=False): 34 | """For training and evaluation.""" 35 | dist = self(obs) 36 | if deterministic: 37 | raw_action, log_prob = dist.mean, None 38 | action = torch.tanh(raw_action) 39 | else: 40 | raw_action = dist.rsample() 41 | log_prob = dist.log_prob(raw_action) 42 | action = torch.tanh(raw_action) 43 | # Enforcing Action Bound 44 | log_prob -= torch.log((1 - action.pow(2)) + epsilon) 45 | log_prob = log_prob.sum(-1) 46 | return action, log_prob 47 | 48 | def act(self, obs, deterministic=False, enable_grad=False): 49 | """For training and evaluation.""" 50 | with torch.set_grad_enabled(enable_grad): 51 | dist = self(obs) 52 | if deterministic: 53 | action = torch.tanh(dist.mean) 54 | else: 55 | action = torch.tanh(dist.rsample()) 56 | return action 57 | 58 | def evaluate(self, obs, action): 59 | """For Behavior Cloning.""" 60 | dist = self(obs) 61 | action = torch.clip(action, -1.0, 1.0) 62 | raw_action = 0.5 * (action.log1p() - (-action).log1p()) 63 | # Enforcing Action Bound 64 | log_prob = dist.log_prob(raw_action) 65 | log_prob -= torch.log((1 - action.pow(2)) + epsilon) 66 | log_prob = log_prob.sum(-1) 67 | return log_prob 68 | 69 | 70 | # class GaussianPolicy(nn.Module): 71 | # def __init__(self, obs_dim, act_dim, hidden_dim=256, n_hidden=2, use_tanh="False"): 72 | # super().__init__() 73 | # self.use_tanh = use_tanh 74 | # self.net = mlp([obs_dim, *([hidden_dim] * n_hidden), act_dim]) 75 | # self.log_std = nn.Parameter(torch.zeros(act_dim, dtype=torch.float32)) 76 | # 77 | # def forward(self, obs): 78 | # mean = self.net(obs) 79 | # std = torch.exp(self.log_std.clamp(LOG_STD_MIN, LOG_STD_MAX)) 80 | # scale_tril = torch.diag(std) 81 | # return 
MultivariateNormal(mean, scale_tril=scale_tril) 82 | # 83 | # def act(self, obs, deterministic=False, enable_grad=False): 84 | # with torch.set_grad_enabled(enable_grad): 85 | # dist = self(obs) 86 | # action = dist.mean if deterministic else dist.rsample() 87 | # action = torch.tanh(action) if self.use_tanh else torch.clip(action, min=-1.0, max=1.0) 88 | # return action 89 | 90 | 91 | class GaussianPolicy(nn.Module): 92 | def __init__(self, obs_dim, act_dim, hidden_dim=256, n_hidden=2, use_tanh=False): 93 | super().__init__() 94 | self.use_tanh = use_tanh 95 | self.net = mlp([obs_dim, *([hidden_dim] * n_hidden), act_dim]) 96 | self.log_std = nn.Parameter(torch.zeros(act_dim, dtype=torch.float32)) 97 | 98 | def forward(self, obs): 99 | mean = self.net(obs) 100 | if self.use_tanh: 101 | mean = torch.tanh(mean) 102 | std = torch.exp(self.log_std.clamp(LOG_STD_MIN, LOG_STD_MAX)) 103 | scale_tril = torch.diag(std) 104 | return MultivariateNormal(mean, scale_tril=scale_tril) 105 | 106 | def act(self, obs, deterministic=False, enable_grad=False): 107 | with torch.set_grad_enabled(enable_grad): 108 | dist = self(obs) 109 | action = dist.mean if deterministic else dist.rsample() 110 | action = torch.clip(action, min=-1.0, max=1.0) 111 | return action 112 | 113 | 114 | class DeterministicPolicy(nn.Module): 115 | def __init__(self, obs_dim, act_dim, hidden_dim=256, n_hidden=2): 116 | super().__init__() 117 | self.net = mlp([obs_dim, *([hidden_dim] * n_hidden), act_dim], 118 | output_activation=nn.Tanh) 119 | 120 | def forward(self, obs): 121 | return self.net(obs) 122 | 123 | def act(self, obs, deterministic=False, enable_grad=False): 124 | with torch.set_grad_enabled(enable_grad): 125 | action = self(obs) 126 | action = torch.clip(action, min=-1.0, max=1.0) 127 | return action -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | pandas 4 | tqdm 5 | gym[mujoco] >= 0.18.0 6 | torch>=1.7.0 7 | git+https://github.com/rail-berkeley/d4rl@master#egg=d4rl 8 | wandb 9 | pyyaml -------------------------------------------------------------------------------- /run_offline_il.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Script to reproduce offline IL results 4 | 5 | GPU_LIST=(0 1 2 3) 6 | 7 | for T in 1 10 20 30; do 8 | for seed in 0 10 100; do 9 | 10 | GPU_DEVICE=${GPU_LIST[task%${#GPU_LIST[@]}]} 11 | CUDA_VISIBLE_DEVICES=$GPU_DEVICE python main_odice_il.py \ 12 | --env_name "hopper-expert-v2" \ 13 | --weight_decay 0.00001 \ 14 | --T $T \ 15 | --seed $seed & 16 | 17 | sleep 2 18 | let "task=$task+1" 19 | 20 | GPU_DEVICE=${GPU_LIST[task%${#GPU_LIST[@]}]} 21 | CUDA_VISIBLE_DEVICES=$GPU_DEVICE python main_odice_il.py \ 22 | --env_name "walker2d-expert-v2" \ 23 | --weight_decay 0.001 \ 24 | --T $T \ 25 | --seed $seed & 26 | 27 | sleep 2 28 | let "task=$task+1" 29 | 30 | done 31 | done 32 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from datetime import datetime 3 | import json 4 | from pathlib import Path 5 | import random 6 | import string 7 | import sys 8 | import os 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | 13 | 14 | DEFAULT_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | 16 | class Squeeze(nn.Module): 17 | def 
__init__(self, dim=None): 18 | super().__init__() 19 | self.dim = dim 20 | 21 | def forward(self, x): 22 | return x.squeeze(dim=self.dim) 23 | 24 | 25 | def mlp(dims, activation=nn.ReLU, output_activation=None, layer_norm=False, squeeze_output=False, use_orthogonal=False): 26 | n_dims = len(dims) 27 | assert n_dims >= 2, 'MLP requires at least two dims (input and output)' 28 | 29 | layers = [] 30 | for i in range(n_dims - 2): 31 | if use_orthogonal: 32 | fc = nn.Linear(dims[i], dims[i+1]) 33 | nn.init.orthogonal_(fc.weight) 34 | layers.append(fc) 35 | else: 36 | layers.append(nn.Linear(dims[i], dims[i+1])) 37 | if layer_norm: 38 | layers.append(nn.LayerNorm(dims[i+1])) 39 | layers.append(activation()) 40 | if use_orthogonal: 41 | fc = nn.Linear(dims[-2], dims[-1]) 42 | nn.init.orthogonal_(fc.weight) 43 | layers.append(fc) 44 | else: 45 | layers.append(nn.Linear(dims[-2], dims[-1])) 46 | if output_activation is not None: 47 | layers.append(output_activation()) 48 | if squeeze_output: 49 | # assert dims[-1] == 1 50 | layers.append(Squeeze(-1)) 51 | net = nn.Sequential(*layers) 52 | net.to(dtype=torch.float32) 53 | return net 54 | 55 | 56 | def compute_batched(f, xs): 57 | return f(torch.cat(xs, dim=0)).split([len(x) for x in xs]) 58 | 59 | 60 | def update_exponential_moving_average(target, source, alpha): 61 | for target_param, source_param in zip(target.parameters(), source.parameters()): 62 | target_param.data.mul_(1. - alpha).add_(source_param.data, alpha=alpha) 63 | 64 | 65 | def torchify(x): 66 | x = torch.from_numpy(x) 67 | if x.dtype is torch.float64: 68 | x = x.float() 69 | x = x.to(device=DEFAULT_DEVICE) 70 | return x 71 | 72 | 73 | def return_range(dataset, max_episode_steps): 74 | returns, lengths = [], [] 75 | ep_ret, ep_len = 0., 0 76 | for r, d in zip(dataset['rewards'], dataset['terminals']): 77 | ep_ret += float(r) 78 | ep_len += 1 79 | if d or ep_len == max_episode_steps: 80 | returns.append(ep_ret) 81 | lengths.append(ep_len) 82 | ep_ret, ep_len = 0., 0 83 | # returns.append(ep_ret) # incomplete trajectory 84 | lengths.append(ep_len) # but still keep track of number of steps 85 | assert sum(lengths) == len(dataset['rewards']) 86 | return min(returns), max(returns) 87 | 88 | 89 | def extract_done_makers(dones): 90 | (ends, ) = np.where(dones) 91 | starts = np.concatenate(([0], ends[:-1] + 1)) 92 | length = ends - starts + 1 93 | return starts, ends, length 94 | 95 | 96 | def _sample_indces(dataset, batch_size): 97 | try: 98 | dones = dataset["timeouts"].cpu().numpy() 99 | except: 100 | dones = dataset["terminals"].cpu().numpy() 101 | starts, ends, lengths = extract_done_makers(dones) 102 | # credit to Dibya Ghosh's GCSL codebase 103 | trajectory_indces = np.random.choice(len(starts), batch_size) 104 | proportional_indices_1 = np.random.rand(batch_size) 105 | proportional_indices_2 = np.random.rand(batch_size) 106 | # proportional_indices_2 = 1 107 | time_dinces_1 = np.floor( 108 | proportional_indices_1 * (lengths[trajectory_indces] - 1) 109 | ).astype(int) 110 | time_dinces_2 = np.floor( 111 | proportional_indices_2 * (lengths[trajectory_indces]) 112 | ).astype(int) 113 | start_indices = starts[trajectory_indces] + np.minimum( 114 | time_dinces_1, 115 | time_dinces_2 116 | ) 117 | goal_indices = starts[trajectory_indces] + np.maximum( 118 | time_dinces_1, 119 | time_dinces_2 120 | ) 121 | 122 | return start_indices, goal_indices 123 | 124 | 125 | # dataset is a dict, values of which are tensors of same first dimension 126 | def sample_batch(dataset, batch_size): 127 | n, 
device = len(dataset['observations']), dataset['observations'].device 128 | batch = {} 129 | # indices_0 = torch.randint(low=0, high=n, size=(batch_size,), device=device) 130 | indices = torch.randint(low=0, high=n, size=(batch_size,), device=device) 131 | for k, v in dataset.items(): 132 | if k == "trajectory_terminals": 133 | continue 134 | else: 135 | batch[k] = v[indices] 136 | return batch 137 | 138 | 139 | def rvs_sample_batch(dataset, batch_size): 140 | start_indices, goal_indices = _sample_indces(dataset, batch_size) 141 | dict = {} 142 | for k, v in dataset.items(): 143 | if (k == "observations") or (k == "actions"): 144 | dict[k] = v[start_indices] 145 | dict["next_observations"] = dataset["observations"][goal_indices] 146 | dict["rewards"] = 0 147 | dict["terminals"] = 0 148 | return dict 149 | 150 | 151 | def evaluate(env, policy, mean, std, deterministic=True): 152 | obs = env.reset() 153 | total_reward = 0. 154 | done, i = False, 0 155 | while not done: 156 | obs = (obs - mean)/std 157 | with torch.no_grad(): 158 | action = policy.act(torchify(obs), deterministic=deterministic).cpu().numpy() 159 | obs, reward, done, info = env.step(action) 160 | total_reward += reward 161 | i += 1 162 | return total_reward 163 | 164 | 165 | def evaluate_por(env, policy, goal_policy, mean, std, deterministic=True): 166 | obs = env.reset() 167 | total_reward = 0. 168 | done, i = False, 0 169 | while not done: 170 | obs = (obs - mean)/std 171 | with torch.no_grad(): 172 | g = goal_policy.act(torchify(obs), deterministic=deterministic).cpu().numpy() 173 | action = policy.act(torchify(np.concatenate([obs, g])), deterministic=deterministic).cpu().numpy() 174 | obs, reward, done, info = env.step(action) 175 | total_reward += reward 176 | i += 1 177 | return total_reward 178 | 179 | 180 | def evaluate_rvs(env, policy, mean, std, deterministic=True): 181 | obs = env.reset() 182 | goal = np.array(env.target_goal) 183 | goal = (goal - mean[:2])/std[:2] 184 | total_reward = 0. 
185 | done, i = False, 0 186 | while not done: 187 | obs = (obs - mean)/std 188 | with torch.no_grad(): 189 | if i % 100 == 0: 190 | print('current location:', obs[:2]) 191 | action = policy.act(torchify(np.concatenate([obs, goal])), deterministic=deterministic).cpu().numpy() 192 | obs, reward, done, info = env.step(action) 193 | total_reward += reward 194 | i += 1 195 | return total_reward 196 | 197 | 198 | def set_seed(seed, env=None): 199 | torch.manual_seed(seed) 200 | if torch.cuda.is_available(): 201 | torch.cuda.manual_seed_all(seed) 202 | np.random.seed(seed) 203 | random.seed(seed) 204 | if env is not None: 205 | env.seed(seed) 206 | 207 | 208 | def save(dir ,filename, env_name, network_model): 209 | if not os.path.exists(dir): 210 | os.mkdir(dir) 211 | file = dir + env_name + "-" + filename 212 | torch.save(network_model.state_dict(), file) 213 | print(f"***save the {network_model} model to {file}***") 214 | 215 | 216 | def load(dir, filename, env_name, network_model): 217 | file = dir + env_name + "-" + filename 218 | if not os.path.exists(file): 219 | raise FileExistsError("Doesn't exist the model") 220 | network_model.load_state_dict(torch.load(file, map_location=torch.device('cpu'))) 221 | print(f"***load the model from {file}***") 222 | 223 | 224 | def _gen_dir_name(): 225 | now_str = datetime.now().strftime('%m-%d-%y_%H.%M.%S') 226 | rand_str = ''.join(random.choices(string.ascii_lowercase, k=4)) 227 | return f'{now_str}_{rand_str}' 228 | 229 | class Log: 230 | def __init__(self, root_log_dir, cfg_dict, 231 | txt_filename='log.txt', 232 | csv_filename='progress.csv', 233 | cfg_filename='config.json', 234 | flush=True): 235 | self.dir = Path(root_log_dir)/_gen_dir_name() 236 | self.dir.mkdir(parents=True) 237 | self.txt_file = open(self.dir/txt_filename, 'w') 238 | self.csv_file = None 239 | (self.dir/cfg_filename).write_text(json.dumps(cfg_dict)) 240 | self.txt_filename = txt_filename 241 | self.csv_filename = csv_filename 242 | self.cfg_filename = cfg_filename 243 | self.flush = flush 244 | 245 | def write(self, message, end='\n'): 246 | now_str = datetime.now().strftime('%H:%M:%S') 247 | message = f'[{now_str}] ' + message 248 | for f in [sys.stdout, self.txt_file]: 249 | print(message, end=end, file=f, flush=self.flush) 250 | 251 | def __call__(self, *args, **kwargs): 252 | self.write(*args, **kwargs) 253 | 254 | def row(self, dict): 255 | if self.csv_file is None: 256 | self.csv_file = open(self.dir/self.csv_filename, 'w', newline='') 257 | self.csv_writer = csv.DictWriter(self.csv_file, list(dict.keys())) 258 | self.csv_writer.writeheader() 259 | 260 | self(str(dict)) 261 | self.csv_writer.writerow(dict) 262 | if self.flush: 263 | self.csv_file.flush() 264 | 265 | def close(self): 266 | self.txt_file.close() 267 | if self.csv_file is not None: 268 | self.csv_file.close() 269 | 270 | def dataset_T_trajs(dataset, T, terminate_on_end=False): 271 | """ 272 | Returns T trajs from dataset. 
273 | """ 274 | N = dataset['rewards'].shape[0] 275 | return_traj = [] 276 | obs_traj = [[]] 277 | next_obs_traj = [[]] 278 | action_traj = [[]] 279 | reward_traj = [[]] 280 | done_traj = [[]] 281 | 282 | for i in range(N-1): 283 | obs_traj[-1].append(dataset['observations'][i].astype(np.float32)) 284 | next_obs_traj[-1].append(dataset['observations'][i+1].astype(np.float32)) 285 | action_traj[-1].append(dataset['actions'][i].astype(np.float32)) 286 | reward_traj[-1].append(dataset['rewards'][i].astype(np.float32)) 287 | done_traj[-1].append(bool(dataset['terminals'][i])) 288 | 289 | final_timestep = dataset['timeouts'][i] | dataset['terminals'][i] 290 | if (not terminate_on_end) and final_timestep: 291 | # Skip this transition and don't apply terminals on the last step of an episode 292 | return_traj.append(np.sum(reward_traj[-1])) 293 | obs_traj.append([]) 294 | next_obs_traj.append([]) 295 | action_traj.append([]) 296 | reward_traj.append([]) 297 | done_traj.append([]) 298 | 299 | # select T trajectories 300 | inds_all = list(range(len(obs_traj))) 301 | inds = inds_all[:T] 302 | inds = list(inds) 303 | 304 | print('# select {} trajs in the dataset'.format(T)) 305 | 306 | obs_traj = [obs_traj[i] for i in inds] 307 | next_obs_traj = [next_obs_traj[i] for i in inds] 308 | action_traj = [action_traj[i] for i in inds] 309 | reward_traj = [reward_traj[i] for i in inds] 310 | done_traj = [done_traj[i] for i in inds] 311 | 312 | def concat_trajectories(trajectories): 313 | return np.concatenate(trajectories, 0) 314 | 315 | return { 316 | 'observations': concat_trajectories(obs_traj), 317 | 'actions': concat_trajectories(action_traj), 318 | 'next_observations': concat_trajectories(next_obs_traj), 319 | 'rewards': concat_trajectories(reward_traj), 320 | 'terminals': concat_trajectories(done_traj), 321 | } 322 | 323 | def dataset_split_expert(dataset, split_x, exp_num, terminate_on_end=False): 324 | """ 325 | Returns D_e and expert data in D_o of setting 1 in the paper. 
326 | """ 327 | N = dataset['rewards'].shape[0] 328 | return_traj = [] 329 | obs_traj = [[]] 330 | next_obs_traj = [[]] 331 | action_traj = [[]] 332 | reward_traj = [[]] 333 | done_traj = [[]] 334 | 335 | for i in range(N-1): 336 | obs_traj[-1].append(dataset['observations'][i].astype(np.float32)) 337 | next_obs_traj[-1].append(dataset['observations'][i+1].astype(np.float32)) 338 | action_traj[-1].append(dataset['actions'][i].astype(np.float32)) 339 | reward_traj[-1].append(dataset['rewards'][i].astype(np.float32)) 340 | done_traj[-1].append(bool(dataset['terminals'][i])) 341 | 342 | final_timestep = dataset['timeouts'][i] | dataset['terminals'][i] 343 | if (not terminate_on_end) and final_timestep: 344 | # Skip this transition and don't apply terminals on the last step of an episode 345 | return_traj.append(np.sum(reward_traj[-1])) 346 | obs_traj.append([]) 347 | next_obs_traj.append([]) 348 | action_traj.append([]) 349 | reward_traj.append([]) 350 | done_traj.append([]) 351 | 352 | # select 10 trajectories 353 | inds_all = list(range(len(obs_traj))) 354 | inds_succ = inds_all[:exp_num] 355 | inds_o = inds_succ[-split_x:] 356 | inds_o = list(inds_o) 357 | inds_succ = list(inds_succ) 358 | inds_e = set(inds_succ) - set(inds_o) 359 | inds_e = list(inds_e) 360 | 361 | print('# select {} trajs in expert dataset as D_e'.format(len(inds_e))) 362 | print('# select {} trajs in expert dataset as expert data in D_o'.format(len(inds_o))) 363 | 364 | obs_traj_e = [obs_traj[i] for i in inds_e] 365 | next_obs_traj_e = [next_obs_traj[i] for i in inds_e] 366 | action_traj_e = [action_traj[i] for i in inds_e] 367 | reward_traj_e = [reward_traj[i] for i in inds_e] 368 | done_traj_e = [done_traj[i] for i in inds_e] 369 | 370 | obs_traj_o = [obs_traj[i] for i in inds_o] 371 | next_obs_traj_o = [next_obs_traj[i] for i in inds_o] 372 | action_traj_o = [action_traj[i] for i in inds_o] 373 | reward_traj_o = [reward_traj[i] for i in inds_o] 374 | done_traj_o = [done_traj[i] for i in inds_o] 375 | 376 | def concat_trajectories(trajectories): 377 | return np.concatenate(trajectories, 0) 378 | 379 | dataset_e = { 380 | 'observations': concat_trajectories(obs_traj_e), 381 | 'actions': concat_trajectories(action_traj_e), 382 | 'next_observations': concat_trajectories(next_obs_traj_e), 383 | 'rewards': concat_trajectories(reward_traj_e), 384 | 'terminals': concat_trajectories(done_traj_e), 385 | } 386 | 387 | dataset_o = { 388 | 'observations': concat_trajectories(obs_traj_o), 389 | 'actions': concat_trajectories(action_traj_o), 390 | 'next_observations': concat_trajectories(next_obs_traj_o), 391 | 'rewards': concat_trajectories(reward_traj_o), 392 | 'terminals': concat_trajectories(done_traj_o), 393 | } 394 | 395 | return dataset_e, dataset_o 396 | 397 | def dataset_mix_trajs(expert_dataset, random_dataset, split_num, exp_num): 398 | dataset_o = dataset_T_trajs(random_dataset, 1000) 399 | dataset_o['flags'] = np.zeros_like(dataset_o['terminals']).astype(np.float32) 400 | dataset_e, dataset_o_extra = dataset_split_expert(expert_dataset, split_num, exp_num) 401 | dataset_e['flags'] = np.ones_like(dataset_e['terminals']).astype(np.float32) 402 | dataset_o_extra['flags'] = np.ones_like(dataset_o_extra['terminals']).astype(np.float32) 403 | for key in dataset_o.keys(): 404 | dataset_o[key] = np.concatenate([dataset_o[key], dataset_o_extra[key]], 0) 405 | return dataset_e, dataset_o -------------------------------------------------------------------------------- /value_functions.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from util import mlp 5 | # All networks with name {Net}Hook are used for monitoring representation of state when forwarding 6 | # Use self.vf.fc2.register_forward_hook(self.get_activation()) to record state representation and then calculate cosine similarity 7 | # Please check https://pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_forward_hook.html for more details 8 | 9 | class ValueFunction(nn.Module): 10 | def __init__(self, state_dim, layer_norm=False, hidden_dim=256, n_hidden=2): 11 | super().__init__() 12 | dims = [state_dim, *([hidden_dim] * n_hidden), 1] 13 | self.v = mlp(dims, layer_norm=layer_norm, squeeze_output=True) 14 | 15 | def forward(self, state): 16 | return self.v(state) 17 | 18 | class ValueFunctionHook(nn.Module): 19 | def __init__(self, state_dim, layer_norm=False, hidden_dim=256, squeeze_output=True, use_orthogonal=False): 20 | super().__init__() 21 | self.use_layer_norm = layer_norm 22 | self.squeeze_output = squeeze_output 23 | self.fc1 = nn.Linear(state_dim, hidden_dim) 24 | self.fc2 = nn.Linear(hidden_dim, hidden_dim) 25 | self.fc3 = nn.Linear(hidden_dim, 1) 26 | if use_orthogonal: 27 | nn.init.orthogonal_(self.fc1.weight) 28 | nn.init.orthogonal_(self.fc2.weight) 29 | nn.init.orthogonal_(self.fc3.weight) 30 | self.activation = nn.ReLU() 31 | if layer_norm: 32 | self.layer_norm1 = nn.LayerNorm(hidden_dim) 33 | self.layer_norm2 = nn.LayerNorm(hidden_dim) 34 | 35 | def forward(self, state): 36 | x = self.activation(self.layer_norm1(self.fc1(state))) if self.use_layer_norm else self.activation(self.fc1(state)) 37 | x = self.activation(self.layer_norm2(self.fc2(x))) if self.use_layer_norm else self.activation(self.fc2(x)) 38 | value = self.fc3(x).squeeze(-1) if self.squeeze_output else self.fc3(x) 39 | return value 40 | 41 | class TwinV(nn.Module): 42 | def __init__(self, state_dim, layer_norm=False, hidden_dim=256, n_hidden=2): 43 | super().__init__() 44 | dims = [state_dim, *([hidden_dim] * n_hidden), 1] 45 | self.v1 = mlp(dims, layer_norm=layer_norm, squeeze_output=True) 46 | self.v2 = mlp(dims, layer_norm=layer_norm, squeeze_output=True) 47 | 48 | def both(self, state): 49 | return torch.stack([self.v1(state), self.v2(state)], dim=0) 50 | 51 | def forward(self, state): 52 | return torch.min(self.both(state), dim=0)[0] 53 | 54 | class TwinVHook(nn.Module): 55 | def __init__(self, state_dim, layer_norm=False, hidden_dim=256, squeeze_output=True, use_orthogonal=False): 56 | super().__init__() 57 | self.v1 = ValueFunctionHook(state_dim, layer_norm, hidden_dim, squeeze_output, use_orthogonal) 58 | self.v2 = ValueFunctionHook(state_dim, layer_norm, hidden_dim, squeeze_output, use_orthogonal) 59 | 60 | def both(self, state): 61 | return torch.stack([self.v1(state), self.v2(state)], dim=0) 62 | 63 | def forward(self, state): 64 | return torch.min(self.both(state), dim=0)[0] 65 | 66 | class Discriminator(nn.Module): 67 | def __init__(self, state_dim, layer_norm=False, hidden_dim=256, n_hidden=2): 68 | super().__init__() 69 | dims = [state_dim, *([hidden_dim] * n_hidden), 1] 70 | self.d = mlp(dims, layer_norm=layer_norm, squeeze_output=True, output_activation=nn.Sigmoid) 71 | 72 | def forward(self, state): 73 | return self.d(state) 74 | 75 | class RepNet(nn.Module): 76 | def __init__(self, state_dim, out_dim, layer_norm=False, hidden_dim=256, n_hidden=2): 77 | super().__init__() 78 | 
dims = [state_dim, *([hidden_dim] * n_hidden), out_dim] 79 | self.rep = mlp(dims, layer_norm=layer_norm, squeeze_output=True) 80 | 81 | def forward(self, state): 82 | return self.rep(state) 83 | 84 | # Auto-Encoder 85 | class AutoEncoder(nn.Module): 86 | def __init__(self, state_dim, action_dim, latent_dim, max_action): 87 | super(AutoEncoder, self).__init__() 88 | self.e1 = nn.Linear(state_dim + action_dim, 750) 89 | self.e2 = nn.Linear(750, 750) 90 | self.mean = nn.Linear(750, latent_dim) 91 | 92 | self.d1 = nn.Linear(state_dim + latent_dim, 750) 93 | self.d2 = nn.Linear(750, 750) 94 | self.d3 = nn.Linear(750, action_dim) 95 | 96 | self.max_action = max_action 97 | self.latent_dim = latent_dim 98 | 99 | def forward(self, state, action): 100 | z = F.relu(self.e1(torch.cat([state, action], 1))) 101 | z = F.relu(self.e2(z)) 102 | z = self.mean(z) 103 | 104 | u = self.decode(state, z) 105 | 106 | return u, z 107 | 108 | def decode(self, state, z): 109 | a = F.relu(self.d1(torch.cat([state, z], 1))) 110 | a = F.relu(self.d2(a)) 111 | return self.max_action * torch.tanh(self.d3(a)) --------------------------------------------------------------------------------
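For orientation, here is a minimal end-to-end sketch (not a file in this repository) of how the value networks above, the Gaussian policy, and the ODICE trainer fit together, mirroring the wiring in `main_odice_rl.py`. The batch is synthetic random data purely to illustrate the call signature of `orthogonal_true_g_update`; the real scripts sample batches from a D4RL dataset via `util.sample_batch`.

```python
import torch

from odice import ODICE
from policy import GaussianPolicy
from value_functions import TwinV
from util import DEFAULT_DEVICE

obs_dim, act_dim, batch_size = 17, 6, 256  # illustrative sizes (roughly walker2d-shaped spaces)

policy = GaussianPolicy(obs_dim, act_dim, hidden_dim=1024, n_hidden=2, use_tanh=False)
vf = TwinV(obs_dim, layer_norm=True, hidden_dim=256, n_hidden=2)
agent = ODICE(vf=vf, policy=policy, max_steps=1000000, f_name="Pearson_chi_square",
              Lambda=0.6, eta=1.0, use_twin_v=True, value_lr=2e-4, policy_lr=2e-4,
              weight_decay=1e-4, discount=0.99)

# One orthogonal-gradient value update plus one weighted-BC policy update on a synthetic batch.
agent.orthogonal_true_g_update(
    observations=torch.randn(batch_size, obs_dim, device=DEFAULT_DEVICE),
    actions=torch.randn(batch_size, act_dim, device=DEFAULT_DEVICE).clamp(-1.0, 1.0),
    next_observations=torch.randn(batch_size, obs_dim, device=DEFAULT_DEVICE),
    rewards=torch.randn(batch_size, device=DEFAULT_DEVICE),
    terminals=torch.zeros(batch_size, device=DEFAULT_DEVICE),
)
print("finished one update, internal step counter =", agent.step)  # -> 1
```

Note that `odice.py` imports `wandb` at module level, so wandb must be installed even for a smoke test like this, although the update methods only log to wandb once every 10000 steps.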