├── .gitignore ├── IQL.py ├── README.md ├── actor.py ├── common.py ├── critic.py ├── imgs ├── antmaze_results.png └── mujoco_results.png ├── log.py ├── main_iql.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | runs 2 | __pycache__ 3 | .git 4 | make_tasks.py 5 | tasks*.txt 6 | models 7 | -------------------------------------------------------------------------------- /IQL.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | 4 | import os 5 | from actor import Actor 6 | from critic import Critic, ValueCritic 7 | 8 | 9 | def loss(diff, expectile=0.8): 10 | weight = torch.where(diff > 0, expectile, (1 - expectile)) 11 | return weight * (diff**2) 12 | 13 | 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | 16 | 17 | class IQL(object): 18 | def __init__( 19 | self, 20 | state_dim, 21 | action_dim, 22 | expectile, 23 | discount, 24 | tau, 25 | temperature, 26 | ): 27 | 28 | self.actor = Actor(state_dim, action_dim, 256, 3).to(device) 29 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4) 30 | self.actor_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.actor_optimizer, T_max=int(1e6)) 31 | 32 | self.critic = Critic(state_dim, action_dim).to(device) 33 | self.critic_target = copy.deepcopy(self.critic) 34 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4) 35 | 36 | self.value = ValueCritic(state_dim, 256, 3).to(device) 37 | self.value_optimizer = torch.optim.Adam(self.value.parameters(), lr=3e-4) 38 | 39 | self.discount = discount 40 | self.tau = tau 41 | self.temperature = temperature 42 | 43 | self.total_it = 0 44 | self.expectile = expectile 45 | 46 | def update_v(self, states, actions, logger=None): 47 | with torch.no_grad(): 48 | q1, q2 = self.critic_target(states, actions) 49 | q = torch.minimum(q1, q2).detach() 50 | 51 | v = self.value(states) 52 | value_loss = loss(q - v, self.expectile).mean() 53 | 54 | self.value_optimizer.zero_grad() 55 | value_loss.backward() 56 | self.value_optimizer.step() 57 | 58 | logger.log('train/value_loss', value_loss, self.total_it) 59 | logger.log('train/v', v.mean(), self.total_it) 60 | 61 | def update_q(self, states, actions, rewards, next_states, not_dones, logger=None): 62 | with torch.no_grad(): 63 | next_v = self.value(next_states) 64 | target_q = (rewards + self.discount * not_dones * next_v).detach() 65 | 66 | q1, q2 = self.critic(states, actions) 67 | critic_loss = ((q1 - target_q)**2 + (q2 - target_q)**2).mean() 68 | 69 | self.critic_optimizer.zero_grad() 70 | critic_loss.backward() 71 | self.critic_optimizer.step() 72 | 73 | logger.log('train/critic_loss', critic_loss, self.total_it) 74 | logger.log('train/q1', q1.mean(), self.total_it) 75 | logger.log('train/q2', q2.mean(), self.total_it) 76 | 77 | def update_target(self): 78 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 79 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 80 | 81 | def update_actor(self, states, actions, logger=None): 82 | with torch.no_grad(): 83 | v = self.value(states) 84 | q1, q2 = self.critic_target(states, actions) 85 | q = torch.minimum(q1, q2) 86 | exp_a = torch.exp((q - v) * self.temperature) 87 | exp_a = torch.clamp(exp_a, max=100.0).squeeze(-1).detach() 88 | 89 | mu = self.actor(states) 90 | actor_loss = (exp_a.unsqueeze(-1) * ((mu - actions)**2)).mean() 91 | 92 | 
self.actor_optimizer.zero_grad() 93 | actor_loss.backward() 94 | self.actor_optimizer.step() 95 | self.actor_scheduler.step() 96 | 97 | logger.log('train/actor_loss', actor_loss, self.total_it) 98 | logger.log('train/adv', (q - v).mean(), self.total_it) 99 | 100 | def select_action(self, state): 101 | state = torch.FloatTensor(state.reshape(1, -1)).to(device) 102 | return self.actor.get_action(state).cpu().data.numpy().flatten() 103 | 104 | def train(self, replay_buffer, batch_size=256, logger=None): 105 | self.total_it += 1 106 | 107 | # Sample replay buffer 108 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 109 | 110 | # Update 111 | self.update_v(state, action, logger) 112 | self.update_actor(state, action, logger) 113 | self.update_q(state, action, reward, next_state, not_done, logger) 114 | self.update_target() 115 | 116 | def save(self, model_dir): 117 | torch.save(self.critic.state_dict(), os.path.join(model_dir, f"critic_s{str(self.total_it)}.pth")) 118 | torch.save(self.critic_target.state_dict(), os.path.join(model_dir, f"critic_target_s{str(self.total_it)}.pth")) 119 | torch.save(self.critic_optimizer.state_dict(), os.path.join( 120 | model_dir, f"critic_optimizer_s{str(self.total_it)}.pth")) 121 | 122 | torch.save(self.actor.state_dict(), os.path.join(model_dir, f"actor_s{str(self.total_it)}.pth")) 123 | torch.save(self.actor_optimizer.state_dict(), os.path.join( 124 | model_dir, f"actor_optimizer_s{str(self.total_it)}.pth")) 125 | torch.save(self.actor_scheduler.state_dict(), os.path.join( 126 | model_dir, f"actor_scheduler_s{str(self.total_it)}.pth")) 127 | 128 | torch.save(self.value.state_dict(), os.path.join(model_dir, f"value_s{str(self.total_it)}.pth")) 129 | torch.save(self.value_optimizer.state_dict(), os.path.join( 130 | model_dir, f"value_optimizer_s{str(self.total_it)}.pth")) 131 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # IQL Implementation in PyTorch 2 | 3 | ## IQL 4 | 5 | This repo is an unofficial implementation of **Implicit Q-Learning (In-sample Q-Learning)** in PyTorch. 6 | 7 | ``` 8 | @inproceedings{ 9 | kostrikov2022offline, 10 | title={Offline Reinforcement Learning with Implicit Q-Learning}, 11 | author={Ilya Kostrikov and Ashvin Nair and Sergey Levine}, 12 | booktitle={International Conference on Learning Representations}, 13 | year={2022}, 14 | url={https://openreview.net/forum?id=68n2s9ZJWF8} 15 | } 16 | ``` 17 | 18 | **Note**: Reward standardization (_We standardize MuJoCo locomotion task rewards by dividing by the difference of returns of the best and worst trajectories in each dataset_) used in the [official implementation](https://github.com/ikostrikov/implicit_q_learning/blob/09d700248117881a75cb21f0adb95c6c8a694cb2/train_offline.py#L51C18-L51C18) is not included in this implementation; you can easily add it yourself (see the sketch below).
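A minimal sketch of that standardization, adapted to this repo's `ReplayBuffer`. The helper name `standardize_rewards` and the trajectory split (terminal flag or observation mismatch, since `d4rl.qlearning_dataset` does not flag timeouts) are illustrative rather than taken from the official code; the `scale` factor mirrors the rescaling used there. Call it in `main_iql.py` right after `replay_buffer.convert_D4RL(...)`, for the locomotion tasks only (AntMaze keeps the existing `reward - 1.0` shift instead):

```
import numpy as np

def standardize_rewards(replay_buffer, scale=1000.0):
    # Split the dataset into trajectories: a trajectory ends at a terminal
    # transition or wherever next_state[i] and state[i + 1] disagree.
    returns, ret = [], 0.0
    n = replay_buffer.size
    for i in range(n):
        ret += float(replay_buffer.reward[i, 0])
        end_of_traj = (replay_buffer.not_done[i, 0] == 0.0 or i == n - 1
                       or not np.allclose(replay_buffer.next_state[i], replay_buffer.state[i + 1]))
        if end_of_traj:
            returns.append(ret)
            ret = 0.0
    # Divide by the return gap between the best and worst trajectory.
    replay_buffer.reward = replay_buffer.reward / (max(returns) - min(returns)) * scale
```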
19 | 20 | ## Train 21 | 22 | ### Gym-MuJoCo 23 | 24 | ``` 25 | python main_iql.py --env halfcheetah-medium-v2 --expectile 0.7 --temperature 3.0 --eval_freq 5000 --eval_episodes 10 --normalize 26 | ``` 27 | 28 | ### AntMaze 29 | 30 | ``` 31 | python main_iql.py --env antmaze-medium-play-v2 --expectile 0.9 --temperature 10.0 --eval_freq 50000 --eval_episodes 100 32 | ``` 33 | 34 | ## Results 35 | 36 | ![mujoco_results](imgs/mujoco_results.png) 37 | 38 | ![antmaze_results](imgs/antmaze_results.png) 39 | 40 | ## Acknowledgement 41 | 42 | This repo borrows heavily from [sfujim/TD3_BC](https://github.com/sfujim/TD3_BC) and [ikostrikov/implicit_q_learning](https://github.com/ikostrikov/implicit_q_learning). 43 | -------------------------------------------------------------------------------- /actor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.distributions 4 | from common import MLP 5 | 6 | 7 | class Actor(nn.Module): 8 | """MLP actor network.""" 9 | 10 | def __init__( 11 | self, state_dim, action_dim, hidden_dim, n_layers, dropout_rate=None, 12 | log_std_min=-10.0, log_std_max=2.0, 13 | ): 14 | super().__init__() 15 | 16 | self.mlp = MLP( 17 | state_dim, 2 * action_dim, hidden_dim, n_layers, dropout_rate=dropout_rate 18 | ) 19 | 20 | self.log_std_min = log_std_min 21 | self.log_std_max = log_std_max 22 | 23 | def forward( 24 | self, states 25 | ): 26 | mu, log_std = self.mlp(states).chunk(2, dim=-1) 27 | mu = torch.tanh(mu) 28 | return mu 29 | 30 | def get_action(self, states): 31 | mu = self.forward(states) 32 | return mu 33 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from typing import Callable, Optional 3 | 4 | from torch.nn.modules.dropout import Dropout 5 | 6 | 7 | class MLP(nn.Module): 8 | 9 | def __init__( 10 | self, 11 | in_dim, 12 | out_dim, 13 | hidden_dim, 14 | n_layers, 15 | activations: Callable = nn.ReLU, 16 | activate_final: int = False, 17 | dropout_rate: Optional[float] = None 18 | ) -> None: 19 | super().__init__() 20 | 21 | self.affines = [] 22 | self.affines.append(nn.Linear(in_dim, hidden_dim)) 23 | for i in range(n_layers-2): 24 | self.affines.append(nn.Linear(hidden_dim, hidden_dim)) 25 | self.affines.append(nn.Linear(hidden_dim, out_dim)) 26 | self.affines = nn.ModuleList(self.affines) 27 | 28 | self.activations = activations() 29 | self.activate_final = activate_final 30 | self.dropout_rate = dropout_rate 31 | if dropout_rate is not None: 32 | self.dropout = Dropout(self.dropout_rate) 33 | 34 | def forward(self, x): 35 | for i in range(len(self.affines)): 36 | x = self.affines[i](x) 37 | if i != len(self.affines)-1 or self.activate_final: 38 | x = self.activations(x) 39 | if self.dropout_rate is not None: 40 | x = self.dropout(x) 41 | return x 42 | -------------------------------------------------------------------------------- /critic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from common import MLP 6 | 7 | 8 | class ValueCritic(nn.Module): 9 | def __init__( 10 | self, 11 | in_dim, 12 | hidden_dim, 13 | n_layers, 14 | **kwargs 15 | ) -> None: 16 | super().__init__() 17 | self.mlp = MLP(in_dim, 1, hidden_dim, n_layers, **kwargs) 18 | 19 | def forward(self, state): 20 | return 
self.mlp(state) 21 | 22 | 23 | class Critic(nn.Module): 24 | """ 25 | From TD3+BC 26 | """ 27 | 28 | def __init__(self, state_dim, action_dim): 29 | super(Critic, self).__init__() 30 | 31 | # Q1 architecture 32 | self.l1 = nn.Linear(state_dim + action_dim, 256) 33 | self.l2 = nn.Linear(256, 256) 34 | self.l3 = nn.Linear(256, 1) 35 | 36 | # Q2 architecture 37 | self.l4 = nn.Linear(state_dim + action_dim, 256) 38 | self.l5 = nn.Linear(256, 256) 39 | self.l6 = nn.Linear(256, 1) 40 | 41 | def forward(self, state, action): 42 | sa = torch.cat([state, action], 1) 43 | 44 | q1 = F.relu(self.l1(sa)) 45 | q1 = F.relu(self.l2(q1)) 46 | q1 = self.l3(q1) 47 | 48 | q2 = F.relu(self.l4(sa)) 49 | q2 = F.relu(self.l5(q2)) 50 | q2 = self.l6(q2) 51 | return q1, q2 52 | 53 | def Q1(self, state, action): 54 | sa = torch.cat([state, action], 1) 55 | 56 | q1 = F.relu(self.l1(sa)) 57 | q1 = F.relu(self.l2(q1)) 58 | q1 = self.l3(q1) 59 | return q1 60 | -------------------------------------------------------------------------------- /imgs/antmaze_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Manchery/iql-pytorch/83d8364667a60438591666fd3f932dccac22f629/imgs/antmaze_results.png -------------------------------------------------------------------------------- /imgs/mujoco_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Manchery/iql-pytorch/83d8364667a60438591666fd3f932dccac22f629/imgs/mujoco_results.png -------------------------------------------------------------------------------- /log.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | from collections import defaultdict 3 | import json 4 | import os 5 | import shutil 6 | import torch 7 | import torchvision 8 | import numpy as np 9 | from termcolor import colored 10 | 11 | FORMAT_CONFIG = { 12 | 'rl': { 13 | 'train': [ 14 | ('episode', 'E', 'int'), ('step', 'S', 'int'), 15 | ('duration', 'D', 'time'), ('episode_reward', 'R', 'float'), 16 | ('batch_reward', 'BR', 'float'), ('actor_loss', 'A_LOSS', 'float'), 17 | ('critic_loss', 'CR_LOSS', 'float'), 18 | ], 19 | 'eval': [('step', 'S', 'int'), ('episode_reward', 'ER', 'float'), ('episode_reward_test_env', 'ERTEST', 'float')] 20 | } 21 | } 22 | 23 | 24 | class AverageMeter(object): 25 | def __init__(self): 26 | self._sum = 0 27 | self._count = 0 28 | 29 | def update(self, value, n=1): 30 | self._sum += value 31 | self._count += n 32 | 33 | def value(self): 34 | return self._sum / max(1, self._count) 35 | 36 | 37 | class MetersGroup(object): 38 | def __init__(self, file_name, formating): 39 | self._file_name = file_name 40 | if os.path.exists(file_name): 41 | os.remove(file_name) 42 | self._formating = formating 43 | self._meters = defaultdict(AverageMeter) 44 | 45 | def log(self, key, value, n=1): 46 | self._meters[key].update(value, n) 47 | 48 | def _prime_meters(self): 49 | data = dict() 50 | for key, meter in self._meters.items(): 51 | if key.startswith('train'): 52 | key = key[len('train') + 1:] 53 | else: 54 | key = key[len('eval') + 1:] 55 | key = key.replace('/', '_') 56 | data[key] = meter.value() 57 | return data 58 | 59 | def _dump_to_file(self, data): 60 | with open(self._file_name, 'a') as f: 61 | f.write(json.dumps(data) + '\n') 62 | 63 | def _format(self, key, value, ty): 64 | template = '%s: ' 65 | if ty == 'int': 66 | template += '%d' 67 | elif ty == 
'float': 68 | template += '%.04f' 69 | elif ty == 'time': 70 | template += '%.01f s' 71 | else: 72 | raise ValueError('invalid format type: %s' % ty) 73 | return template % (key, value) 74 | 75 | def _dump_to_console(self, data, prefix): 76 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green') 77 | pieces = ['{:5}'.format(prefix)] 78 | for key, disp_key, ty in self._formating: 79 | value = data.get(key, 0) 80 | pieces.append(self._format(disp_key, value, ty)) 81 | print('| %s' % (' | '.join(pieces))) 82 | 83 | def dump(self, step, prefix): 84 | if len(self._meters) == 0: 85 | return 86 | data = self._prime_meters() 87 | data['step'] = step 88 | self._dump_to_file(data) 89 | self._dump_to_console(data, prefix) 90 | self._meters.clear() 91 | 92 | 93 | class Logger(object): 94 | def __init__(self, log_dir, use_tb=True, config='rl'): 95 | self._log_dir = log_dir 96 | if use_tb: 97 | tb_dir = os.path.join(log_dir, 'tb') 98 | if os.path.exists(tb_dir): 99 | shutil.rmtree(tb_dir) 100 | self._sw = SummaryWriter(tb_dir) 101 | else: 102 | self._sw = None 103 | self._train_mg = MetersGroup( 104 | os.path.join(log_dir, 'train.log'), 105 | formating=FORMAT_CONFIG[config]['train'] 106 | ) 107 | self._eval_mg = MetersGroup( 108 | os.path.join(log_dir, 'eval.log'), 109 | formating=FORMAT_CONFIG[config]['eval'] 110 | ) 111 | 112 | def _try_sw_log(self, key, value, step): 113 | if self._sw is not None: 114 | self._sw.add_scalar(key, value, step) 115 | 116 | def _try_sw_log_image(self, key, image, step): 117 | if self._sw is not None: 118 | assert image.dim() == 3 119 | # grid = torchvision.utils.make_grid(image.unsqueeze(1)) 120 | self._sw.add_image(key, image, step) 121 | 122 | def _try_sw_log_video(self, key, frames, step): 123 | if self._sw is not None: 124 | frames = torch.from_numpy(np.array(frames)) 125 | frames = frames.unsqueeze(0) 126 | self._sw.add_video(key, frames, step, fps=30) 127 | 128 | def _try_sw_log_histogram(self, key, histogram, step): 129 | if self._sw is not None: 130 | self._sw.add_histogram(key, histogram, step) 131 | 132 | def log(self, key, value, step, n=1): 133 | assert key.startswith('train') or key.startswith('eval') 134 | if type(value) == torch.Tensor: 135 | value = value.item() 136 | self._try_sw_log(key, value / n, step) 137 | mg = self._train_mg if key.startswith('train') else self._eval_mg 138 | mg.log(key, value, n) 139 | 140 | def log_param(self, key, param, step): 141 | self.log_histogram(key + '_w', param.weight.data, step) 142 | if hasattr(param.weight, 'grad') and param.weight.grad is not None: 143 | self.log_histogram(key + '_w_g', param.weight.grad.data, step) 144 | if hasattr(param, 'bias'): 145 | self.log_histogram(key + '_b', param.bias.data, step) 146 | if hasattr(param.bias, 'grad') and param.bias.grad is not None: 147 | self.log_histogram(key + '_b_g', param.bias.grad.data, step) 148 | 149 | def log_image(self, key, image, step): 150 | assert key.startswith('train') or key.startswith('eval') 151 | self._try_sw_log_image(key, image, step) 152 | 153 | def log_video(self, key, frames, step): 154 | assert key.startswith('train') or key.startswith('eval') 155 | self._try_sw_log_video(key, frames, step) 156 | 157 | def log_histogram(self, key, histogram, step): 158 | assert key.startswith('train') or key.startswith('eval') 159 | self._try_sw_log_histogram(key, histogram, step) 160 | 161 | def dump(self, step): 162 | self._train_mg.dump(step, 'train') 163 | self._eval_mg.dump(step, 'eval') 164 | 
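# --- Illustrative usage (sketch, not part of the original file) -------------
# Keys must be prefixed with 'train' or 'eval'; log() accumulates running
# averages, and dump() flushes them to the console, to train.log / eval.log
# (one JSON object per line), and to TensorBoard when use_tb=True.
# 'runs/logger_demo' is a hypothetical directory.
if __name__ == '__main__':
    os.makedirs('runs/logger_demo', exist_ok=True)
    demo_logger = Logger('runs/logger_demo', use_tb=False)
    for step in range(1, 11):
        demo_logger.log('train/critic_loss', 0.5 / step, step)
        demo_logger.log('train/actor_loss', 0.1 / step, step)
    demo_logger.dump(10)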
-------------------------------------------------------------------------------- /main_iql.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import gym 4 | import argparse 5 | import os 6 | import d4rl 7 | from tqdm import trange 8 | from coolname import generate_slug 9 | import time 10 | import json 11 | from log import Logger 12 | 13 | import utils 14 | from utils import VideoRecorder 15 | import IQL 16 | 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | 19 | # Runs policy for X episodes and returns average reward 20 | # A fixed seed is used for the eval environment 21 | 22 | 23 | def eval_policy(args, iter, video: VideoRecorder, logger: Logger, policy, env_name, seed, mean, std, seed_offset=100, eval_episodes=10): 24 | eval_env = gym.make(env_name) 25 | eval_env.seed(seed + seed_offset) 26 | 27 | lengths = [] 28 | returns = [] 29 | avg_reward = 0. 30 | for _ in range(eval_episodes): 31 | video.init(enabled=(args.save_video and _ == 0)) 32 | state, done = eval_env.reset(), False 33 | video.record(eval_env) 34 | steps = 0 35 | episode_return = 0 36 | while not done: 37 | state = (np.array(state).reshape(1, -1) - mean)/std 38 | action = policy.select_action(state) 39 | state, reward, done, _ = eval_env.step(action) 40 | video.record(eval_env) 41 | avg_reward += reward 42 | episode_return += reward 43 | steps += 1 44 | lengths.append(steps) 45 | returns.append(episode_return) 46 | video.save(f'eval_s{iter}_r{str(episode_return)}.mp4') 47 | 48 | avg_reward /= eval_episodes 49 | d4rl_score = eval_env.get_normalized_score(avg_reward) 50 | 51 | logger.log('eval/lengths_mean', np.mean(lengths), iter) 52 | logger.log('eval/lengths_std', np.std(lengths), iter) 53 | logger.log('eval/returns_mean', np.mean(returns), iter) 54 | logger.log('eval/returns_std', np.std(returns), iter) 55 | logger.log('eval/d4rl_score', d4rl_score, iter) 56 | 57 | print("---------------------------------------") 58 | print(f"Evaluation over {eval_episodes} episodes: {d4rl_score:.3f}") 59 | print("---------------------------------------") 60 | return d4rl_score 61 | 62 | 63 | if __name__ == "__main__": 64 | 65 | parser = argparse.ArgumentParser() 66 | # Experiment 67 | parser.add_argument("--policy", default="IQL") # Policy name 68 | parser.add_argument("--env", default="halfcheetah-medium-v2") # OpenAI gym environment name 69 | parser.add_argument("--seed", default=0, type=int) # Sets Gym, PyTorch and Numpy seeds 70 | parser.add_argument("--eval_freq", default=1e4, type=int) # How often (time steps) we evaluate 71 | parser.add_argument("--max_timesteps", default=1e6, type=int) # Max time steps to run environment 72 | parser.add_argument("--save_model", action="store_true", default=False) # Save model and optimizer parameters 73 | parser.add_argument('--eval_episodes', default=10, type=int) 74 | parser.add_argument('--save_video', default=False, action='store_true') 75 | parser.add_argument("--normalize", default=False, action='store_true') 76 | # IQL 77 | parser.add_argument("--batch_size", default=256, type=int) # Batch size for both actor and critic 78 | parser.add_argument("--temperature", default=3.0, type=float) 79 | parser.add_argument("--expectile", default=0.7, type=float) 80 | parser.add_argument("--tau", default=0.005, type=float) 81 | parser.add_argument("--discount", default=0.99, type=float) # Discount factor 82 | # Work dir 83 | parser.add_argument('--work_dir', default='tmp', type=str) 84 | args = 
parser.parse_args() 85 | args.cooldir = generate_slug(2) 86 | 87 | # Build work dir 88 | base_dir = 'runs' 89 | utils.make_dir(base_dir) 90 | base_dir = os.path.join(base_dir, args.work_dir) 91 | utils.make_dir(base_dir) 92 | args.work_dir = os.path.join(base_dir, args.env) 93 | utils.make_dir(args.work_dir) 94 | 95 | # make directory 96 | ts = time.gmtime() 97 | ts = time.strftime("%m-%d-%H-%M", ts) 98 | exp_name = str(args.env) + '-' + ts + '-bs' + str(args.batch_size) + '-s' + str(args.seed) 99 | if args.policy == 'IQL': 100 | exp_name += '-t' + str(args.temperature) + '-e' + str(args.expectile) 101 | else: 102 | raise NotImplementedError 103 | exp_name += '-' + args.cooldir 104 | args.work_dir = args.work_dir + '/' + exp_name 105 | utils.make_dir(args.work_dir) 106 | 107 | args.model_dir = os.path.join(args.work_dir, 'model') 108 | utils.make_dir(args.model_dir) 109 | args.video_dir = os.path.join(args.work_dir, 'video') 110 | utils.make_dir(args.video_dir) 111 | 112 | with open(os.path.join(args.work_dir, 'args.json'), 'w') as f: 113 | json.dump(vars(args), f, sort_keys=True, indent=4) 114 | 115 | utils.snapshot_src('.', os.path.join(args.work_dir, 'src'), '.gitignore') 116 | 117 | print("---------------------------------------") 118 | print(f"Policy: {args.policy}, Env: {args.env}, Seed: {args.seed}") 119 | print("---------------------------------------") 120 | 121 | env = gym.make(args.env) 122 | 123 | # Set seeds 124 | env.seed(args.seed) 125 | env.action_space.seed(args.seed) 126 | torch.manual_seed(args.seed) 127 | np.random.seed(args.seed) 128 | 129 | state_dim = env.observation_space.shape[0] 130 | action_dim = env.action_space.shape[0] 131 | max_action = float(env.action_space.high[0]) 132 | 133 | kwargs = { 134 | "state_dim": state_dim, 135 | "action_dim": action_dim, 136 | # IQL 137 | "discount": args.discount, 138 | "tau": args.tau, 139 | "temperature": args.temperature, 140 | "expectile": args.expectile, 141 | } 142 | 143 | # Initialize policy 144 | if args.policy == 'IQL': 145 | policy = IQL.IQL(**kwargs) 146 | else: 147 | raise NotImplementedError 148 | 149 | replay_buffer = utils.ReplayBuffer(state_dim, action_dim) 150 | replay_buffer.convert_D4RL(d4rl.qlearning_dataset(env)) 151 | if 'antmaze' in args.env: 152 | # Center reward for Ant-Maze 153 | # See https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22 154 | replay_buffer.reward = replay_buffer.reward - 1.0 155 | if args.normalize: 156 | mean, std = replay_buffer.normalize_states() 157 | else: 158 | mean, std = 0, 1 159 | 160 | logger = Logger(args.work_dir, use_tb=True) 161 | video = VideoRecorder(dir_name=args.video_dir) 162 | 163 | for t in trange(int(args.max_timesteps)): 164 | policy.train(replay_buffer, args.batch_size, logger=logger) 165 | # Evaluate episode 166 | if (t + 1) % args.eval_freq == 0: 167 | eval_episodes = 100 if t+1 == int(args.max_timesteps) else args.eval_episodes 168 | d4rl_score = eval_policy(args, t+1, video, logger, policy, args.env, 169 | args.seed, mean, std, eval_episodes=eval_episodes) 170 | if args.save_model: 171 | policy.save(args.model_dir) 172 | 173 | logger._sw.close() 174 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import os 5 | import random 6 | import imageio 7 | import gym 8 | from tqdm import trange 9 | import pickle 10 | 11 | 12 | class ReplayBuffer(object): 13 | 
def __init__(self, state_dim, action_dim, max_size=int(1e6)): 14 | self.max_size = max_size 15 | self.ptr = 0 16 | self.size = 0 17 | 18 | self.state = np.zeros((max_size, state_dim)) 19 | self.action = np.zeros((max_size, action_dim)) 20 | self.next_state = np.zeros((max_size, state_dim)) 21 | self.reward = np.zeros((max_size, 1)) 22 | self.not_done = np.zeros((max_size, 1)) 23 | 24 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | 26 | def add(self, state, action, next_state, reward, done): 27 | self.state[self.ptr] = state 28 | self.action[self.ptr] = action 29 | self.next_state[self.ptr] = next_state 30 | self.reward[self.ptr] = reward 31 | self.not_done[self.ptr] = 1. - done 32 | 33 | self.ptr = (self.ptr + 1) % self.max_size 34 | self.size = min(self.size + 1, self.max_size) 35 | 36 | def sample(self, batch_size): 37 | ind = np.random.randint(0, self.size, size=batch_size) 38 | 39 | return ( 40 | torch.FloatTensor(self.state[ind]).to(self.device), 41 | torch.FloatTensor(self.action[ind]).to(self.device), 42 | torch.FloatTensor(self.next_state[ind]).to(self.device), 43 | torch.FloatTensor(self.reward[ind]).to(self.device), 44 | torch.FloatTensor(self.not_done[ind]).to(self.device) 45 | ) 46 | 47 | def convert_D4RL(self, dataset): 48 | self.state = dataset['observations'] 49 | self.action = dataset['actions'] 50 | self.next_state = dataset['next_observations'] 51 | self.reward = dataset['rewards'].reshape(-1, 1) 52 | self.not_done = 1. - dataset['terminals'].reshape(-1, 1) 53 | self.size = self.state.shape[0] 54 | 55 | def normalize_states(self, eps=1e-3): 56 | mean = self.state.mean(0, keepdims=True) 57 | std = self.state.std(0, keepdims=True) + eps 58 | self.state = (self.state - mean)/std 59 | self.next_state = (self.next_state - mean)/std 60 | return mean, std 61 | 62 | 63 | def make_dir(dir_path): 64 | try: 65 | os.mkdir(dir_path) 66 | except OSError: 67 | pass 68 | return dir_path 69 | 70 | 71 | def set_seed_everywhere(seed): 72 | torch.manual_seed(seed) 73 | if torch.cuda.is_available(): 74 | torch.cuda.manual_seed_all(seed) 75 | np.random.seed(seed) 76 | random.seed(seed) 77 | 78 | 79 | def get_lr(optimizer): 80 | for param_group in optimizer.param_groups: 81 | return param_group['lr'] 82 | 83 | 84 | def snapshot_src(src, target, exclude_from): 85 | make_dir(target) 86 | os.system(f"rsync -rv --exclude-from={exclude_from} {src} {target}") 87 | 88 | 89 | class VideoRecorder(object): 90 | def __init__(self, dir_name, height=256, width=256, camera_id=0, fps=30): 91 | self.dir_name = dir_name 92 | self.height = height 93 | self.width = width 94 | self.camera_id = camera_id 95 | self.fps = fps 96 | self.frames = [] 97 | 98 | def init(self, enabled=True): 99 | self.frames = [] 100 | self.enabled = self.dir_name is not None and enabled 101 | 102 | def record(self, env): 103 | if self.enabled: 104 | frame = env.render( 105 | mode='rgb_array', 106 | height=self.height, 107 | width=self.width, 108 | camera_id=self.camera_id 109 | ) 110 | self.frames.append(frame) 111 | 112 | def save(self, file_name): 113 | if self.enabled: 114 | path = os.path.join(self.dir_name, file_name) 115 | imageio.mimsave(path, self.frames, fps=self.fps) 116 | --------------------------------------------------------------------------------