├── .gitignore ├── LICENSE ├── README.md ├── main.py ├── requirements.txt ├── results.py └── src ├── iql.py ├── policy.py ├── util.py └── value_functions.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Garrett Thomas 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Implicit Q-Learning (IQL) in PyTorch 2 | This repository houses a minimal PyTorch implementation of [Implicit Q-Learning (IQL)](https://arxiv.org/abs/2110.06169), an offline reinforcement learning algorithm, along with a script to run IQL on tasks from the [D4RL](https://github.com/rail-berkeley/d4rl) benchmark. 3 | 4 | To install the dependencies, use `pip install -r requirements.txt`. 5 | 6 | You can run the script like so: 7 | ``` 8 | python main.py --log-dir /path/where/results/will/go --env-name hopper-medium-v2 --tau 0.7 --beta 3.0 9 | ``` 10 | 11 | Note that the paper's authors have published [their official implementation](https://github.com/ikostrikov/implicit_q_learning), which is based on JAX. My implementation is intended to be an alternative for PyTorch users, and my general recommendation is to use the authors' code unless you specifically want/need PyTorch for some reason. 12 | 13 | I am validating my implementation against the results stated in the paper as compute permits. 14 | Below are results for the MuJoCo locomotion tasks, normalized return at the end of training, averaged (+/- standard deviation) over 3 seeds: 15 | 16 | | Environment | This implementation | Official implementation | 17 | | ----------- | ------------------- | ----------------------- | 18 | | halfcheetah-medium-v2 | 47.7 +/- 0.2 | 47.4 | 19 | | hopper-medium-v2 | 61.2 +/- 6.4 | 66.3 | 20 | | walker2d-medium-v2 | 78.7 +/- 4.5 | 78.3 | 21 | | halfcheetah-medium-replay-v2 | 42.9 +/- 1.7 | 44.2 | 22 | | hopper-medium-replay-v2 | 86.8 +/- 15.5 | 94.7 | 23 | | walker2d-medium-replay-v2 | 68.3 +/- 6.4 | 73.9 | 24 | | halfcheetah-medium-expert-v2 | 88.3 +/- 2.8 | 86.7 | 25 | | hopper-medium-expert-v2 | 76.6 +/- 34.9 | 91.5 | 26 | | walker2d-medium-expert-v2 | 108.7 +/- 2.2 | 109.6 | 27 | 28 | We can see that the performance is mostly similar to what is stated in the paper, but slightly worse on a few tasks. Note that these results were obtained using a small simplification (deterministic policy and least-squares loss rather than a Gaussian distribution and negative log likelihood), which may explain the discrepancy. 
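Concretely, that simplification corresponds to passing `--deterministic-policy` to `main.py`, which selects `DeterministicPolicy` together with the squared-error behavior-cloning loss in `src/iql.py` instead of the Gaussian negative log likelihood. A sketch of such a run (the log directory is a placeholder):
```
python main.py --log-dir /path/where/results/will/go --env-name hopper-medium-v2 --tau 0.7 --beta 3.0 --deterministic-policy
```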
29 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import gym 4 | import d4rl 5 | import numpy as np 6 | import torch 7 | from tqdm import trange 8 | 9 | from src.iql import ImplicitQLearning 10 | from src.policy import GaussianPolicy, DeterministicPolicy 11 | from src.value_functions import TwinQ, ValueFunction 12 | from src.util import return_range, set_seed, Log, sample_batch, torchify, evaluate_policy 13 | 14 | 15 | def get_env_and_dataset(log, env_name, max_episode_steps): 16 | env = gym.make(env_name) 17 | dataset = d4rl.qlearning_dataset(env) 18 | 19 | if any(s in env_name for s in ('halfcheetah', 'hopper', 'walker2d')): 20 | min_ret, max_ret = return_range(dataset, max_episode_steps) 21 | log(f'Dataset returns have range [{min_ret}, {max_ret}]') 22 | dataset['rewards'] /= (max_ret - min_ret) # rescale rewards by the range of dataset returns... 23 | dataset['rewards'] *= max_episode_steps # ...so that dataset returns span max_episode_steps 24 | elif 'antmaze' in env_name: 25 | dataset['rewards'] -= 1. # shift sparse antmaze rewards from {0, 1} to {-1, 0} 26 | 27 | for k, v in dataset.items(): 28 | dataset[k] = torchify(v) 29 | 30 | return env, dataset 31 | 32 | 33 | def main(args): 34 | torch.set_num_threads(1) 35 | log = Log(Path(args.log_dir)/args.env_name, vars(args)) 36 | log(f'Log dir: {log.dir}') 37 | 38 | env, dataset = get_env_and_dataset(log, args.env_name, args.max_episode_steps) 39 | obs_dim = dataset['observations'].shape[1] 40 | act_dim = dataset['actions'].shape[1] # this assumes continuous actions 41 | set_seed(args.seed, env=env) 42 | 43 | if args.deterministic_policy: 44 | policy = DeterministicPolicy(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden) 45 | else: 46 | policy = GaussianPolicy(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden) 47 | def eval_policy(): 48 | eval_returns = np.array([evaluate_policy(env, policy, args.max_episode_steps) \ 49 | for _ in range(args.n_eval_episodes)]) 50 | normalized_returns = d4rl.get_normalized_score(args.env_name, eval_returns) * 100.0 51 | log.row({ 52 | 'return mean': eval_returns.mean(), 53 | 'return std': eval_returns.std(), 54 | 'normalized return mean': normalized_returns.mean(), 55 | 'normalized return std': normalized_returns.std(), 56 | }) 57 | 58 | iql = ImplicitQLearning( 59 | qf=TwinQ(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden), 60 | vf=ValueFunction(obs_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden), 61 | policy=policy, 62 | optimizer_factory=lambda params: torch.optim.Adam(params, lr=args.learning_rate), 63 | max_steps=args.n_steps, 64 | tau=args.tau, 65 | beta=args.beta, 66 | alpha=args.alpha, 67 | discount=args.discount 68 | ) 69 | 70 | for step in trange(args.n_steps): 71 | iql.update(**sample_batch(dataset, args.batch_size)) 72 | if (step+1) % args.eval_period == 0: 73 | eval_policy() 74 | 75 | torch.save(iql.state_dict(), log.dir/'final.pt') 76 | log.close() 77 | 78 | 79 | if __name__ == '__main__': 80 | from argparse import ArgumentParser 81 | parser = ArgumentParser() 82 | parser.add_argument('--env-name', required=True) 83 | parser.add_argument('--log-dir', required=True) 84 | parser.add_argument('--seed', type=int, default=0) 85 | parser.add_argument('--discount', type=float, default=0.99) 86 | parser.add_argument('--hidden-dim', type=int, default=256) 87 | parser.add_argument('--n-hidden', type=int, default=2) 88 | parser.add_argument('--n-steps', type=int, default=10**6) 89 | parser.add_argument('--batch-size', type=int,
default=256) 90 | parser.add_argument('--learning-rate', type=float, default=3e-4) 91 | parser.add_argument('--alpha', type=float, default=0.005) 92 | parser.add_argument('--tau', type=float, default=0.7) 93 | parser.add_argument('--beta', type=float, default=3.0) 94 | parser.add_argument('--deterministic-policy', action='store_true') 95 | parser.add_argument('--eval-period', type=int, default=5000) 96 | parser.add_argument('--n-eval-episodes', type=int, default=10) 97 | parser.add_argument('--max-episode-steps', type=int, default=1000) 98 | main(parser.parse_args()) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | pandas 4 | tqdm 5 | gym[mujoco] >= 0.18.0 6 | torch>=1.7.0 7 | git+https://github.com/rail-berkeley/d4rl@master#egg=d4rl -------------------------------------------------------------------------------- /results.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | 7 | LOCOMOTION_ENVS = { 8 | 'halfcheetah-medium-v2': 47.4, 9 | 'hopper-medium-v2': 66.3, 10 | 'walker2d-medium-v2': 78.3, 11 | 'halfcheetah-medium-replay-v2': 44.2, 12 | 'hopper-medium-replay-v2': 94.7, 13 | 'walker2d-medium-replay-v2': 73.9, 14 | 'halfcheetah-medium-expert-v2': 86.7, 15 | 'hopper-medium-expert-v2': 91.5, 16 | 'walker2d-medium-expert-v2': 109.6 17 | } 18 | 19 | ANTMAZE_ENVS = { 20 | 'antmaze-umaze-v0': 87.5, 21 | 'antmaze-umaze-diverse-v0': 62.2, 22 | 'antmaze-medium-play-v0': 71.2, 23 | 'antmaze-medium-diverse-v0': 70.0, 24 | 'antmaze-large-play-v0': 39.6, 25 | 'antmaze-large-diverse-v0': 47.5 26 | } 27 | 28 | KITCHEN_ENVS = { 29 | 'kitchen-complete-v0': 62.5, 30 | 'kitchen-partial-v0': 46.3, 31 | 'kitchen-mixed-v0': 51.0 32 | } 33 | 34 | ADROIT_ENVS = { 35 | 'pen-human-v0': 71.5, 36 | 'hammer-human-v0': 1.4, 37 | 'door-human-v0': 4.3, 38 | 'relocate-human-v0': 0.1, 39 | 'pen-cloned-v0': 37.3, 40 | 'hammer-cloned-v0': 2.1, 41 | 'door-cloned-v0': 1.6, 42 | 'relocate-cloned-v0': -0.2 43 | } 44 | 45 | ENV_COLLECTIONS = { 46 | 'locomotion-all': LOCOMOTION_ENVS, 47 | 'antmaze-all': ANTMAZE_ENVS, 48 | 'kitchen-all': KITCHEN_ENVS, 49 | 'adroit-all': ADROIT_ENVS 50 | } 51 | 52 | 53 | def main(args): 54 | dir = Path(args.dir) 55 | assert dir.is_dir(), f'{dir} is not a directory' 56 | print('| Environment | This implementation | Official implementation |\n' 57 | '| ----------- | ------------------- | ----------------------- |') 58 | envs = ENV_COLLECTIONS[args.envs] 59 | for env, ref_score in envs.items(): 60 | env_dir = dir/env 61 | assert env_dir.is_dir(), f'{env_dir} is not a directory' 62 | run_dirs = [d for d in env_dir.iterdir() if d.is_dir()] 63 | final_perfs = [] 64 | for run_dir in run_dirs: 65 | data = pd.read_csv(run_dir/'progress.csv') 66 | normalized_returns = data['normalized return mean'].to_numpy() 67 | final_perfs.append(normalized_returns[-args.last_k:]) 68 | print(f'| {env} | {np.mean(final_perfs):.1f} +/- {np.std(final_perfs):.1f} | {ref_score:.1f} |') 69 | 70 | 71 | if __name__ == '__main__': 72 | from argparse import ArgumentParser 73 | parser = ArgumentParser() 74 | parser.add_argument('-d', '--dir', required=True) 75 | parser.add_argument('-e', '--envs', required=True) 76 | parser.add_argument('-k', '--last-k', type=int, default=10) # average over last k evals 77 | main(parser.parse_args()) 
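A usage note for `results.py`: it aggregates the `normalized return mean` column of each run's `progress.csv` (as written by `Log.row`) under `<dir>/<env>/<run>/` and prints a markdown table in the README's format next to the official scores. A sketch of an invocation, assuming the runs were launched with the same `--log-dir` as above (the path is a placeholder):
```
python results.py --dir /path/where/results/will/go --envs locomotion-all
```
By default the last 10 evaluations of each run are averaged; this can be changed with `--last-k`.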
-------------------------------------------------------------------------------- /src/iql.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.optim.lr_scheduler import CosineAnnealingLR 7 | 8 | from .util import DEFAULT_DEVICE, compute_batched, update_exponential_moving_average 9 | 10 | 11 | EXP_ADV_MAX = 100. 12 | 13 | 14 | def asymmetric_l2_loss(u, tau): 15 | return torch.mean(torch.abs(tau - (u < 0).float()) * u**2) 16 | 17 | 18 | class ImplicitQLearning(nn.Module): 19 | def __init__(self, qf, vf, policy, optimizer_factory, max_steps, 20 | tau, beta, discount=0.99, alpha=0.005): 21 | super().__init__() 22 | self.qf = qf.to(DEFAULT_DEVICE) 23 | self.q_target = copy.deepcopy(qf).requires_grad_(False).to(DEFAULT_DEVICE) 24 | self.vf = vf.to(DEFAULT_DEVICE) 25 | self.policy = policy.to(DEFAULT_DEVICE) 26 | self.v_optimizer = optimizer_factory(self.vf.parameters()) 27 | self.q_optimizer = optimizer_factory(self.qf.parameters()) 28 | self.policy_optimizer = optimizer_factory(self.policy.parameters()) 29 | self.policy_lr_schedule = CosineAnnealingLR(self.policy_optimizer, max_steps) 30 | self.tau = tau 31 | self.beta = beta 32 | self.discount = discount 33 | self.alpha = alpha 34 | 35 | def update(self, observations, actions, next_observations, rewards, terminals): 36 | with torch.no_grad(): 37 | target_q = self.q_target(observations, actions) 38 | next_v = self.vf(next_observations) 39 | 40 | # v, next_v = compute_batched(self.vf, [observations, next_observations]) 41 | 42 | # Update value function 43 | v = self.vf(observations) 44 | adv = target_q - v 45 | v_loss = asymmetric_l2_loss(adv, self.tau) 46 | self.v_optimizer.zero_grad(set_to_none=True) 47 | v_loss.backward() 48 | self.v_optimizer.step() 49 | 50 | # Update Q function 51 | targets = rewards + (1. 
- terminals.float()) * self.discount * next_v.detach() 52 | qs = self.qf.both(observations, actions) 53 | q_loss = sum(F.mse_loss(q, targets) for q in qs) / len(qs) 54 | self.q_optimizer.zero_grad(set_to_none=True) 55 | q_loss.backward() 56 | self.q_optimizer.step() 57 | 58 | # Update target Q network 59 | update_exponential_moving_average(self.q_target, self.qf, self.alpha) 60 | 61 | # Update policy 62 | exp_adv = torch.exp(self.beta * adv.detach()).clamp(max=EXP_ADV_MAX) 63 | policy_out = self.policy(observations) 64 | if isinstance(policy_out, torch.distributions.Distribution): 65 | bc_losses = -policy_out.log_prob(actions) 66 | elif torch.is_tensor(policy_out): 67 | assert policy_out.shape == actions.shape 68 | bc_losses = torch.sum((policy_out - actions)**2, dim=1) 69 | else: 70 | raise NotImplementedError 71 | policy_loss = torch.mean(exp_adv * bc_losses) 72 | self.policy_optimizer.zero_grad(set_to_none=True) 73 | policy_loss.backward() 74 | self.policy_optimizer.step() 75 | self.policy_lr_schedule.step() -------------------------------------------------------------------------------- /src/policy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.distributions import MultivariateNormal 4 | 5 | from .util import mlp 6 | 7 | 8 | LOG_STD_MIN = -5.0 9 | LOG_STD_MAX = 2.0 10 | 11 | 12 | class GaussianPolicy(nn.Module): 13 | def __init__(self, obs_dim, act_dim, hidden_dim=256, n_hidden=2): 14 | super().__init__() 15 | self.net = mlp([obs_dim, *([hidden_dim] * n_hidden), act_dim]) 16 | self.log_std = nn.Parameter(torch.zeros(act_dim, dtype=torch.float32)) 17 | 18 | def forward(self, obs): 19 | mean = self.net(obs) 20 | std = torch.exp(self.log_std.clamp(LOG_STD_MIN, LOG_STD_MAX)) 21 | scale_tril = torch.diag(std) 22 | return MultivariateNormal(mean, scale_tril=scale_tril) 23 | # if mean.ndim > 1: 24 | # batch_size = len(obs) 25 | # return MultivariateNormal(mean, scale_tril=scale_tril.repeat(batch_size, 1, 1)) 26 | # else: 27 | # return MultivariateNormal(mean, scale_tril=scale_tril) 28 | 29 | def act(self, obs, deterministic=False, enable_grad=False): 30 | with torch.set_grad_enabled(enable_grad): 31 | dist = self(obs) 32 | return dist.mean if deterministic else dist.sample() 33 | 34 | 35 | class DeterministicPolicy(nn.Module): 36 | def __init__(self, obs_dim, act_dim, hidden_dim=256, n_hidden=2): 37 | super().__init__() 38 | self.net = mlp([obs_dim, *([hidden_dim] * n_hidden), act_dim], 39 | output_activation=nn.Tanh) 40 | 41 | def forward(self, obs): 42 | return self.net(obs) 43 | 44 | def act(self, obs, deterministic=False, enable_grad=False): 45 | with torch.set_grad_enabled(enable_grad): 46 | return self(obs) -------------------------------------------------------------------------------- /src/util.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from datetime import datetime 3 | import json 4 | from pathlib import Path 5 | import random 6 | import string 7 | import sys 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | 13 | 14 | DEFAULT_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | 16 | 17 | class Squeeze(nn.Module): 18 | def __init__(self, dim=None): 19 | super().__init__() 20 | self.dim = dim 21 | 22 | def forward(self, x): 23 | return x.squeeze(dim=self.dim) 24 | 25 | 26 | def mlp(dims, activation=nn.ReLU, output_activation=None, squeeze_output=False): 27 | n_dims = len(dims) 28 | 
assert n_dims >= 2, 'MLP requires at least two dims (input and output)' 29 | 30 | layers = [] 31 | for i in range(n_dims - 2): 32 | layers.append(nn.Linear(dims[i], dims[i+1])) 33 | layers.append(activation()) 34 | layers.append(nn.Linear(dims[-2], dims[-1])) 35 | if output_activation is not None: 36 | layers.append(output_activation()) 37 | if squeeze_output: 38 | assert dims[-1] == 1 39 | layers.append(Squeeze(-1)) 40 | net = nn.Sequential(*layers) 41 | net.to(dtype=torch.float32) 42 | return net 43 | 44 | 45 | def compute_batched(f, xs): 46 | return f(torch.cat(xs, dim=0)).split([len(x) for x in xs]) 47 | 48 | 49 | def update_exponential_moving_average(target, source, alpha): 50 | for target_param, source_param in zip(target.parameters(), source.parameters()): 51 | target_param.data.mul_(1. - alpha).add_(source_param.data, alpha=alpha) 52 | 53 | 54 | def torchify(x): 55 | x = torch.from_numpy(x) 56 | if x.dtype is torch.float64: 57 | x = x.float() 58 | x = x.to(device=DEFAULT_DEVICE) 59 | return x 60 | 61 | 62 | 63 | def return_range(dataset, max_episode_steps): 64 | returns, lengths = [], [] 65 | ep_ret, ep_len = 0., 0 66 | for r, d in zip(dataset['rewards'], dataset['terminals']): 67 | ep_ret += float(r) 68 | ep_len += 1 69 | if d or ep_len == max_episode_steps: 70 | returns.append(ep_ret) 71 | lengths.append(ep_len) 72 | ep_ret, ep_len = 0., 0 73 | # returns.append(ep_ret) # incomplete trajectory 74 | lengths.append(ep_len) # but still keep track of number of steps 75 | assert sum(lengths) == len(dataset['rewards']) 76 | return min(returns), max(returns) 77 | 78 | 79 | # dataset is a dict, values of which are tensors of same first dimension 80 | def sample_batch(dataset, batch_size): 81 | k = list(dataset.keys())[0] 82 | n, device = len(dataset[k]), dataset[k].device 83 | for v in dataset.values(): 84 | assert len(v) == n, 'Dataset values must have same length' 85 | indices = torch.randint(low=0, high=n, size=(batch_size,), device=device) 86 | return {k: v[indices] for k, v in dataset.items()} 87 | 88 | 89 | def evaluate_policy(env, policy, max_episode_steps, deterministic=True): 90 | obs = env.reset() 91 | total_reward = 0. 
92 | for _ in range(max_episode_steps): 93 | with torch.no_grad(): 94 | action = policy.act(torchify(obs), deterministic=deterministic).cpu().numpy() 95 | next_obs, reward, done, info = env.step(action) 96 | total_reward += reward 97 | if done: 98 | break 99 | else: 100 | obs = next_obs 101 | return total_reward 102 | 103 | 104 | def set_seed(seed, env=None): 105 | torch.manual_seed(seed) 106 | if torch.cuda.is_available(): 107 | torch.cuda.manual_seed_all(seed) 108 | np.random.seed(seed) 109 | random.seed(seed) 110 | if env is not None: 111 | env.seed(seed) 112 | 113 | 114 | def _gen_dir_name(): 115 | now_str = datetime.now().strftime('%m-%d-%y_%H.%M.%S') 116 | rand_str = ''.join(random.choices(string.ascii_lowercase, k=4)) 117 | return f'{now_str}_{rand_str}' 118 | 119 | class Log: 120 | def __init__(self, root_log_dir, cfg_dict, 121 | txt_filename='log.txt', 122 | csv_filename='progress.csv', 123 | cfg_filename='config.json', 124 | flush=True): 125 | self.dir = Path(root_log_dir)/_gen_dir_name() 126 | self.dir.mkdir(parents=True) 127 | self.txt_file = open(self.dir/txt_filename, 'w') 128 | self.csv_file = None 129 | (self.dir/cfg_filename).write_text(json.dumps(cfg_dict)) 130 | self.txt_filename = txt_filename 131 | self.csv_filename = csv_filename 132 | self.cfg_filename = cfg_filename 133 | self.flush = flush 134 | 135 | def write(self, message, end='\n'): 136 | now_str = datetime.now().strftime('%H:%M:%S') 137 | message = f'[{now_str}] ' + message 138 | for f in [sys.stdout, self.txt_file]: 139 | print(message, end=end, file=f, flush=self.flush) 140 | 141 | def __call__(self, *args, **kwargs): 142 | self.write(*args, **kwargs) 143 | 144 | def row(self, dict): 145 | if self.csv_file is None: 146 | self.csv_file = open(self.dir/self.csv_filename, 'w', newline='') 147 | self.csv_writer = csv.DictWriter(self.csv_file, list(dict.keys())) 148 | self.csv_writer.writeheader() 149 | 150 | self(str(dict)) 151 | self.csv_writer.writerow(dict) 152 | if self.flush: 153 | self.csv_file.flush() 154 | 155 | def close(self): 156 | self.txt_file.close() 157 | if self.csv_file is not None: 158 | self.csv_file.close() -------------------------------------------------------------------------------- /src/value_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .util import mlp 4 | 5 | 6 | class TwinQ(nn.Module): 7 | def __init__(self, state_dim, action_dim, hidden_dim=256, n_hidden=2): 8 | super().__init__() 9 | dims = [state_dim + action_dim, *([hidden_dim] * n_hidden), 1] 10 | self.q1 = mlp(dims, squeeze_output=True) 11 | self.q2 = mlp(dims, squeeze_output=True) 12 | 13 | def both(self, state, action): 14 | sa = torch.cat([state, action], 1) 15 | return self.q1(sa), self.q2(sa) 16 | 17 | def forward(self, state, action): 18 | return torch.min(*self.both(state, action)) 19 | 20 | 21 | class ValueFunction(nn.Module): 22 | def __init__(self, state_dim, hidden_dim=256, n_hidden=2): 23 | super().__init__() 24 | dims = [state_dim, *([hidden_dim] * n_hidden), 1] 25 | self.v = mlp(dims, squeeze_output=True) 26 | 27 | def forward(self, state): 28 | return self.v(state) --------------------------------------------------------------------------------
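`main.py` saves the full `ImplicitQLearning` state dict to `final.pt`, but the repository does not include a loader. Below is a minimal sketch of how a saved checkpoint could be reloaded and evaluated with the classes above; the file name (`eval_checkpoint.py`), checkpoint path, environment name, and hyperparameters are assumptions and must match the flags used for training (in particular, swap in `DeterministicPolicy` if the run used `--deterministic-policy`).
```python
# eval_checkpoint.py -- hypothetical helper, not part of this repository.
# Reloads a checkpoint written by main.py and reports the D4RL-normalized return.
from pathlib import Path

import d4rl  # importing d4rl registers the offline environments with gym
import gym
import numpy as np
import torch

from src.iql import ImplicitQLearning
from src.policy import GaussianPolicy  # use DeterministicPolicy for --deterministic-policy runs
from src.value_functions import TwinQ, ValueFunction
from src.util import evaluate_policy, set_seed

ENV_NAME = 'hopper-medium-v2'                   # assumption: environment used for training
CHECKPOINT = Path('/path/to/log/dir/final.pt')  # assumption: file written by main.py
HIDDEN_DIM, N_HIDDEN = 256, 2                   # must match --hidden-dim / --n-hidden

env = gym.make(ENV_NAME)
set_seed(0, env=env)
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

# Rebuild the same module structure that was saved; the optimizer settings are
# irrelevant for evaluation, but the constructor requires them.
policy = GaussianPolicy(obs_dim, act_dim, hidden_dim=HIDDEN_DIM, n_hidden=N_HIDDEN)
iql = ImplicitQLearning(
    qf=TwinQ(obs_dim, act_dim, hidden_dim=HIDDEN_DIM, n_hidden=N_HIDDEN),
    vf=ValueFunction(obs_dim, hidden_dim=HIDDEN_DIM, n_hidden=N_HIDDEN),
    policy=policy,
    optimizer_factory=lambda params: torch.optim.Adam(params, lr=3e-4),
    max_steps=10**6, tau=0.7, beta=3.0,
)
iql.load_state_dict(torch.load(CHECKPOINT, map_location='cpu'))

# Roll out the mean action for 10 episodes, as in main.py's evaluation loop.
returns = np.array([evaluate_policy(env, policy, max_episode_steps=1000)
                    for _ in range(10)])
normalized = d4rl.get_normalized_score(ENV_NAME, returns) * 100.0
print(f'normalized return: {normalized.mean():.1f} +/- {normalized.std():.1f}')
```
Rebuilding the whole `ImplicitQLearning` module, rather than only the policy, keeps the state-dict keys aligned, since the checkpoint also contains the Q, target-Q, and value networks.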