├── offline
│   ├── wrappers
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── single_precision.py
│   │   └── episode_monitor.py
│   ├── configs
│   │   ├── kitchen_config.py
│   │   ├── mujoco_config.py
│   │   └── antmaze_config.py
│   ├── evaluation.py
│   ├── logging_utils
│   │   ├── serialization_utils.py
│   │   ├── mpi_tools.py
│   │   └── logx.py
│   ├── actor.py
│   ├── README.md
│   ├── policy.py
│   ├── value_net.py
│   ├── common.py
│   ├── dataset_utils.py
│   ├── critic.py
│   ├── learner.py
│   ├── train_offline.py
│   └── environment.yml
├── README.md
└── .gitignore

/offline/wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | from wrappers.episode_monitor import EpisodeMonitor
2 | from wrappers.single_precision import SinglePrecision
3 | 
--------------------------------------------------------------------------------
/offline/wrappers/common.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | 
3 | import numpy as np
4 | 
5 | TimeStep = Tuple[np.ndarray, float, bool, dict]
6 | 
--------------------------------------------------------------------------------
/offline/configs/kitchen_config.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 | 
3 | 
4 | def get_config():
5 |     config = ml_collections.ConfigDict()
6 | 
7 |     config.actor_lr = 3e-4
8 |     config.value_lr = 3e-4
9 |     config.critic_lr = 3e-4
10 | 
11 |     config.hidden_dims = (256, 256)
12 | 
13 |     config.discount = 0.99
14 | 
15 |     config.expectile = 0.7  # The actual tau for expectiles.
16 |     config.temperature = 0.5
17 |     config.dropout_rate = 0.1
18 |     config.layernorm = True
19 | 
20 |     config.tau = 0.005  # For soft target updates.
21 | 
22 |     return config
23 | 
--------------------------------------------------------------------------------
/offline/configs/mujoco_config.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 | 
3 | 
4 | def get_config():
5 |     config = ml_collections.ConfigDict()
6 | 
7 |     config.actor_lr = 3e-4
8 |     config.value_lr = 3e-4
9 |     config.critic_lr = 3e-4
10 | 
11 |     config.hidden_dims = (256, 256)
12 | 
13 |     config.discount = 0.99
14 | 
15 |     config.expectile = 0.7  # The actual tau for expectiles.
16 |     config.temperature = 3.0
17 |     config.dropout_rate = 0.0
18 |     config.layernorm = True
19 | 
20 |     config.tau = 0.005  # For soft target updates.
21 | 
22 |     return config
23 | 
--------------------------------------------------------------------------------
/offline/configs/antmaze_config.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 | 
3 | 
4 | def get_config():
5 |     config = ml_collections.ConfigDict(type_safe=False)
6 | 
7 |     config.actor_lr = 3e-4
8 |     config.value_lr = 3e-4
9 |     config.critic_lr = 3e-4
10 | 
11 |     config.hidden_dims = (256, 256)
12 | 
13 |     config.discount = 0.99
14 | 
15 |     config.expectile = 0.9  # The actual tau for expectiles.
16 |     config.temperature = 10.0  # 10.0
17 |     config.dropout_rate = 0.0
18 |     config.layernorm = True
19 |     config.value_dropout_rate = 0.5
20 | 
21 |     config.tau = 0.005  # For soft target updates.
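# Every key set on this config is forwarded unchanged to Learner(**kwargs) by
# train_offline.py (via kwargs = dict(FLAGS.config)), so the names used here must
# match keyword arguments of Learner.__init__ in learner.py.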
22 | config.opt_decay_schedule = None # Don't decay optimizer lr 23 | 24 | return config 25 | -------------------------------------------------------------------------------- /offline/evaluation.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import flax.linen as nn 4 | import gym 5 | import numpy as np 6 | 7 | 8 | def evaluate(agent: nn.Module, env: gym.Env, 9 | num_episodes: int, verbose: bool = False) -> Dict[str, float]: 10 | stats = {'return': [], 'length': []} 11 | 12 | for _ in range(num_episodes): 13 | observation, done = env.reset(), False 14 | 15 | while not done: 16 | action = agent.sample_actions(observation, temperature=0.0) 17 | observation, _, done, info = env.step(action) 18 | 19 | for k in stats.keys(): 20 | stats[k].append(info['episode'][k]) 21 | if verbose: 22 | v = info['episode'][k] 23 | print(f'{k}:{v}') 24 | 25 | for k, v in stats.items(): 26 | stats[k] = np.mean(v) 27 | 28 | return stats 29 | -------------------------------------------------------------------------------- /offline/logging_utils/serialization_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | def convert_json(obj): 4 | """ Convert obj to a version which can be serialized with JSON. """ 5 | if is_json_serializable(obj): 6 | return obj 7 | else: 8 | if isinstance(obj, dict): 9 | return {convert_json(k): convert_json(v) 10 | for k,v in obj.items()} 11 | 12 | elif isinstance(obj, tuple): 13 | return (convert_json(x) for x in obj) 14 | 15 | elif isinstance(obj, list): 16 | return [convert_json(x) for x in obj] 17 | 18 | elif hasattr(obj,'__name__') and not('lambda' in obj.__name__): 19 | return convert_json(obj.__name__) 20 | 21 | elif hasattr(obj,'__dict__') and obj.__dict__: 22 | obj_dict = {convert_json(k): convert_json(v) 23 | for k,v in obj.__dict__.items()} 24 | return {str(obj): obj_dict} 25 | 26 | return str(obj) 27 | 28 | def is_json_serializable(v): 29 | try: 30 | json.dumps(v) 31 | return True 32 | except: 33 | return False -------------------------------------------------------------------------------- /offline/actor.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import jax 3 | import jax.numpy as jnp 4 | from common import Batch, InfoDict, Model, Params, PRNGKey 5 | 6 | 7 | def update(key: PRNGKey, actor: Model, critic: Model, value: Model, 8 | batch: Batch, temperature: float, double: bool) -> Tuple[Model, InfoDict]: 9 | v = value(batch.observations) 10 | 11 | q1, q2 = critic(batch.observations, batch.actions) 12 | if double: 13 | q = jnp.minimum(q1, q2) 14 | else: 15 | q = q1 16 | exp_a = jnp.exp((q - v) * temperature) 17 | exp_a = jnp.minimum(exp_a, 100.0) 18 | 19 | def actor_loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]: 20 | dist = actor.apply({'params': actor_params}, 21 | batch.observations, 22 | training=True, 23 | rngs={'dropout': key}) 24 | log_probs = dist.log_prob(batch.actions) 25 | actor_loss = -(exp_a * log_probs).mean() 26 | 27 | return actor_loss, {'actor_loss': actor_loss, 'adv': q - v} 28 | 29 | new_actor, info = actor.apply_gradient(actor_loss_fn) 30 | 31 | return new_actor, info 32 | -------------------------------------------------------------------------------- /offline/wrappers/single_precision.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import gym 3 | import numpy as np 4 | from 
gym.spaces import Box, Dict 5 | 6 | 7 | class SinglePrecision(gym.ObservationWrapper): 8 | def __init__(self, env): 9 | super().__init__(env) 10 | 11 | if isinstance(self.observation_space, Box): 12 | obs_space = self.observation_space 13 | self.observation_space = Box(obs_space.low, obs_space.high, 14 | obs_space.shape) 15 | elif isinstance(self.observation_space, Dict): 16 | obs_spaces = copy.copy(self.observation_space.spaces) 17 | for k, v in obs_spaces.items(): 18 | obs_spaces[k] = Box(v.low, v.high, v.shape) 19 | self.observation_space = Dict(obs_spaces) 20 | else: 21 | raise NotImplementedError 22 | 23 | def observation(self, observation: np.ndarray) -> np.ndarray: 24 | if isinstance(observation, np.ndarray): 25 | return observation.astype(np.float32) 26 | elif isinstance(observation, dict): 27 | observation = copy.copy(observation) 28 | for k, v in observation.items(): 29 | observation[k] = v.astype(np.float32) 30 | return observation 31 | -------------------------------------------------------------------------------- /offline/wrappers/episode_monitor.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import gym 4 | import numpy as np 5 | 6 | from wrappers.common import TimeStep 7 | 8 | 9 | class EpisodeMonitor(gym.ActionWrapper): 10 | """A class that computes episode returns and lengths.""" 11 | def __init__(self, env: gym.Env): 12 | super().__init__(env) 13 | self._reset_stats() 14 | self.total_timesteps = 0 15 | 16 | def _reset_stats(self): 17 | self.reward_sum = 0.0 18 | self.episode_length = 0 19 | self.start_time = time.time() 20 | 21 | def step(self, action: np.ndarray) -> TimeStep: 22 | observation, reward, done, info = self.env.step(action) 23 | 24 | self.reward_sum += reward 25 | self.episode_length += 1 26 | self.total_timesteps += 1 27 | info['total'] = {'timesteps': self.total_timesteps} 28 | 29 | if done: 30 | info['episode'] = {} 31 | info['episode']['return'] = self.reward_sum 32 | info['episode']['length'] = self.episode_length 33 | info['episode']['duration'] = time.time() - self.start_time 34 | 35 | if hasattr(self, 'get_normalized_score'): 36 | info['episode']['return'] = self.get_normalized_score( 37 | info['episode']['return']) * 100.0 38 | 39 | return observation, reward, done, info 40 | 41 | def reset(self) -> np.ndarray: 42 | self._reset_stats() 43 | return self.env.reset() -------------------------------------------------------------------------------- /offline/README.md: -------------------------------------------------------------------------------- 1 | # Dual V-Learning 2 | 3 | Official code base for Dual RL: Dual RL: Unification and New Methods for Reinforcement and Imitation Learning 4 | 5 | 6 | 7 | ## How to run the code 8 | 9 | ### Install dependencies 10 | 11 | Create an empty conda environment and follow the commands below. 12 | 13 | ```bash 14 | conda create -n dvl python=3.9 15 | 16 | conda install -c conda-forge cudnn 17 | 18 | pip install --upgrade pip 19 | 20 | # Install 1 of the below jax versions depending on your CUDA version 21 | ## 1. CUDA 12 installation 22 | pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 23 | 24 | ## 2. 
CUDA 11 installation 25 | pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 26 | 27 | 28 | pip install -r requirements.txt 29 | 30 | ``` 31 | 32 | ### Example training code 33 | 34 | Locomotion 35 | ```bash 36 | python train_offline.py --env_name=halfcheetah-medium-expert-v2 --f=chi-square --config=configs/mujoco_config.py --max_clip=5 --sample_random_times=1 --temp=1 37 | ``` 38 | 39 | AntMaze 40 | ```bash 41 | python train_offline.py --env_name=antmaze-large-play-v0 --f=total-variation --config=configs/antmaze_config.py --eval_episodes=100 --eval_interval=100000 --max_clip=5 --temp=0.8 42 | ``` 43 | 44 | Kitchen and Adroit 45 | ```bash 46 | python train_offline.py --env_name=pen-human-v0 --f=reverse-KL --config=configs/kitchen_config.py --max_clip=5 --sample_random_times=1 --temp=8 47 | ``` 48 | 49 | 50 | 51 | ## Acknowledgement and Reference 52 | 53 | This code base heavily builds upon the following code bases: [Extreme Q-learning](https://github.com/Div99/XQL) and [Implicit Q-Learning](https://github.com/ikostrikov/implicit_q_learning). 54 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dual V-Learning (DVL) 2 | 3 | ### [**[Project Page](https://hari-sikchi.github.io/dual-rl/)**] 4 | 5 | ### [**[Paper](https://arxiv.org/abs/2302.08560)**] 6 | 7 | 8 | 9 | Official code base for **[Dual RL: Unification and New Methods for Reinforcement and Imitation Learning](https://arxiv.org/abs/2302.08560)** by [Harshit Sikchi](https://hari-sikchi.github.io/), [Qinqing Zheng](https://enosair.github.io/), [Amy Zhang](https://www.ece.utexas.edu/people/faculty/amy-zhang), and [Scott Niekum](https://people.cs.umass.edu/~sniekum/). 10 | 11 | 12 | 13 | 14 | This repository contains code for **Dual V-Learning (DVL)** framework for Reinforcement Learning proposed in our paper. 15 | 16 | 17 | Please refer to instructions inside the **offline** folder to get started with installation and running the code. 18 | 19 | 20 | ## Benefits of DVL over other offline RL methods 21 | ✅ Fixes the instability of Extreme Q Learning \(XQL\) \ 22 | ✅ Directly models V* in continuous action spaces \ 23 | ✅ Implict, no OOD Sampling or actor-critic formulation \ 24 | ✅ Conservative with respect to the induced behavior policy distribution \ 25 | ✅ Improves performance on the D4RL benchmark versus similar approaches 26 | 27 | ### Citation 28 | ``` 29 | @misc{sikchi2023dual, 30 | title={Dual RL: Unification and New Methods for Reinforcement and Imitation Learning}, 31 | author={Harshit Sikchi and Qinqing Zheng and Amy Zhang and Scott Niekum}, 32 | year={2023}, 33 | eprint={2302.08560}, 34 | archivePrefix={arXiv}, 35 | primaryClass={cs.LG} 36 | } 37 | ``` 38 | 39 | 40 | 41 | 42 | ## Questions 43 | Please feel free to email us if you have any questions. 44 | 45 | Harshit Sikchi ([hsikchi@utexas.edu](mailto:hsikchi@utexas.edu?subject=[GitHub]%DVL)) 46 | 47 | 48 | ## Acknowledgement 49 | 50 | This repository builds heavily on the XQL(https://github.com/Div99/xql) and IQL(https://github.com/ikostrikov/implicit_q_learning) codebases. Please make sure to cite them as well when using this code. 
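For a quick sense of what the offline code optimises, here is a minimal, illustrative sketch (simplified from `offline/critic.py`, not the exact training code) of the f-divergence value losses that the `--f` flag selects between; the value network is regressed against the residual `q - v` under one of these losses.

```python
# Illustrative sketch of the f-divergence value losses in offline/critic.py.
# `diff` stands for q(s, a) - v(s); `alpha` is the loss temperature (--temp).
import jax
import jax.numpy as jnp


def chi_square_loss(diff, alpha):
    # chi^2 divergence: quadratic penalty on positive residuals.
    return alpha * jnp.maximum(diff + diff**2 / 4, 0) - (1 - alpha) * diff


def total_variation_loss(diff, alpha):
    # Total variation: linear penalty on positive residuals.
    return alpha * jnp.maximum(diff, 0) - (1 - alpha) * diff


def reverse_kl_loss(diff, alpha, max_clip=7.0):
    # Reverse KL (the Gumbel loss of XQL), with exponent clipping and
    # batch-max rescaling for numerical stability, as in critic.py.
    z = jnp.minimum(diff / alpha, max_clip)
    max_z = jax.lax.stop_gradient(jnp.maximum(jnp.max(z, axis=0), -1.0))
    return jnp.exp(z - max_z) - z * jnp.exp(-max_z) - jnp.exp(-max_z)


# Dummy residuals in place of real critic/value outputs.
diff = jnp.array([1.0, 0.2, -0.3])
print(chi_square_loss(diff, alpha=1.0).mean())
```

In `update_v` these losses are applied to `q - v` with `q` taken from the target critic (the minimum over the two critic heads when double Q-learning is enabled), while `update_q` fits the critic to `r + discount * mask * V(s')` using a Huber loss (delta = 20).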
51 | 52 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | */.DS_STORE 6 | # C extensions 7 | *.so 8 | */results/* 9 | */tmp/* 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /offline/logging_utils/mpi_tools.py: -------------------------------------------------------------------------------- 1 | from mpi4py import MPI 2 | import os, subprocess, sys 3 | import numpy as np 4 | 5 | 6 | def mpi_fork(n, bind_to_core=False): 7 | """ 8 | Re-launches the current script with workers linked by MPI. 9 | 10 | Also, terminates the original process that launched it. 11 | 12 | Taken almost without modification from the Baselines function of the 13 | `same name`_. 14 | 15 | .. _`same name`: https://github.com/openai/baselines/blob/master/baselines/common/mpi_fork.py 16 | 17 | Args: 18 | n (int): Number of process to split into. 19 | 20 | bind_to_core (bool): Bind each MPI process to a core. 
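    Example (illustrative)::

        mpi_fork(4)  # relaunches this script under `mpirun -np 4` (with MKL/OMP
                     # threads pinned to 1) and terminates the original process;
                     # the code below this call runs once in each MPI worker.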
21 | """ 22 | if n<=1: 23 | return 24 | if os.getenv("IN_MPI") is None: 25 | env = os.environ.copy() 26 | env.update( 27 | MKL_NUM_THREADS="1", 28 | OMP_NUM_THREADS="1", 29 | IN_MPI="1" 30 | ) 31 | args = ["mpirun", "-np", str(n)] 32 | if bind_to_core: 33 | args += ["-bind-to", "core"] 34 | args += [sys.executable] + sys.argv 35 | subprocess.check_call(args, env=env) 36 | sys.exit() 37 | 38 | 39 | def msg(m, string=''): 40 | print(('Message from %d: %s \t '%(MPI.COMM_WORLD.Get_rank(), string))+str(m)) 41 | 42 | def proc_id(): 43 | """Get rank of calling process.""" 44 | return MPI.COMM_WORLD.Get_rank() 45 | 46 | def allreduce(*args, **kwargs): 47 | return MPI.COMM_WORLD.Allreduce(*args, **kwargs) 48 | 49 | def num_procs(): 50 | """Count active MPI processes.""" 51 | return MPI.COMM_WORLD.Get_size() 52 | 53 | def broadcast(x, root=0): 54 | MPI.COMM_WORLD.Bcast(x, root=root) 55 | 56 | def mpi_op(x, op): 57 | x, scalar = ([x], True) if np.isscalar(x) else (x, False) 58 | x = np.asarray(x, dtype=np.float32) 59 | buff = np.zeros_like(x, dtype=np.float32) 60 | allreduce(x, buff, op=op) 61 | return buff[0] if scalar else buff 62 | 63 | def mpi_sum(x): 64 | return mpi_op(x, MPI.SUM) 65 | 66 | def mpi_avg(x): 67 | """Average a scalar or vector over MPI processes.""" 68 | return mpi_sum(x) / num_procs() 69 | 70 | def mpi_statistics_scalar(x, with_min_and_max=False): 71 | """ 72 | Get mean/std and optional min/max of scalar x across MPI processes. 73 | 74 | Args: 75 | x: An array containing samples of the scalar to produce statistics 76 | for. 77 | 78 | with_min_and_max (bool): If true, return min and max of x in 79 | addition to mean and std. 80 | """ 81 | x = np.array(x, dtype=np.float32) 82 | global_sum, global_n = mpi_sum([np.sum(x), len(x)]) 83 | mean = global_sum / global_n 84 | 85 | global_sum_sq = mpi_sum(np.sum((x - mean)**2)) 86 | std = np.sqrt(global_sum_sq / global_n) # compute global std 87 | 88 | if with_min_and_max: 89 | global_min = mpi_op(np.min(x) if len(x) > 0 else np.inf, op=MPI.MIN) 90 | global_max = mpi_op(np.max(x) if len(x) > 0 else -np.inf, op=MPI.MAX) 91 | return mean, std, global_min, global_max 92 | return mean, std -------------------------------------------------------------------------------- /offline/policy.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Optional, Sequence, Tuple 3 | 4 | import flax.linen as nn 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | from tensorflow_probability.substrates import jax as tfp 9 | 10 | tfd = tfp.distributions 11 | tfb = tfp.bijectors 12 | 13 | from common import MLP, Params, PRNGKey, default_init 14 | 15 | LOG_STD_MIN = -10.0 16 | LOG_STD_MAX = 2.0 17 | 18 | 19 | class NormalTanhPolicy(nn.Module): 20 | hidden_dims: Sequence[int] 21 | action_dim: int 22 | state_dependent_std: bool = True 23 | dropout_rate: Optional[float] = None 24 | log_std_scale: float = 1.0 25 | log_std_min: Optional[float] = None 26 | log_std_max: Optional[float] = None 27 | tanh_squash_distribution: bool = True 28 | 29 | @nn.compact 30 | def __call__(self, 31 | observations: jnp.ndarray, 32 | temperature: float = 1.0, 33 | training: bool = False) -> tfd.Distribution: 34 | outputs = MLP(self.hidden_dims, 35 | activate_final=True, 36 | dropout_rate=self.dropout_rate)(observations, 37 | training=training) 38 | 39 | means = nn.Dense(self.action_dim, kernel_init=default_init())(outputs) 40 | 41 | if self.state_dependent_std: 42 | log_stds = 
nn.Dense(self.action_dim, 43 | kernel_init=default_init( 44 | self.log_std_scale))(outputs) 45 | else: 46 | log_stds = self.param('log_stds', nn.initializers.zeros, 47 | (self.action_dim, )) 48 | 49 | log_std_min = self.log_std_min or LOG_STD_MIN 50 | log_std_max = self.log_std_max or LOG_STD_MAX 51 | log_stds = jnp.clip(log_stds, log_std_min, log_std_max) 52 | 53 | if not self.tanh_squash_distribution: 54 | means = nn.tanh(means) 55 | 56 | base_dist = tfd.MultivariateNormalDiag(loc=means, 57 | scale_diag=jnp.exp(log_stds) * 58 | temperature) 59 | if self.tanh_squash_distribution: 60 | return tfd.TransformedDistribution(distribution=base_dist, 61 | bijector=tfb.Tanh()) 62 | else: 63 | return base_dist 64 | 65 | 66 | @functools.partial(jax.jit, static_argnames=('actor_def', 'distribution')) 67 | def _sample_actions(rng: PRNGKey, 68 | actor_def: nn.Module, 69 | actor_params: Params, 70 | observations: np.ndarray, 71 | temperature: float = 1.0) -> Tuple[PRNGKey, jnp.ndarray]: 72 | dist = actor_def.apply({'params': actor_params}, observations, temperature) 73 | rng, key = jax.random.split(rng) 74 | return rng, dist.sample(seed=key) 75 | 76 | 77 | def sample_actions(rng: PRNGKey, 78 | actor_def: nn.Module, 79 | actor_params: Params, 80 | observations: np.ndarray, 81 | temperature: float = 1.0) -> Tuple[PRNGKey, jnp.ndarray]: 82 | return _sample_actions(rng, actor_def, actor_params, observations, 83 | temperature) 84 | -------------------------------------------------------------------------------- /offline/value_net.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Sequence, Tuple 2 | 3 | import jax.numpy as jnp 4 | from flax import linen as nn 5 | 6 | from common import MLP 7 | 8 | 9 | class ValueCritic(nn.Module): 10 | hidden_dims: Sequence[int] 11 | layer_norm: bool = False 12 | dropout_rate: Optional[float] = 0.0 13 | 14 | @nn.compact 15 | def __call__(self, observations: jnp.ndarray) -> jnp.ndarray: 16 | critic = MLP((*self.hidden_dims, 1), layer_norm=self.layer_norm, dropout_rate=self.dropout_rate)(observations) 17 | return jnp.squeeze(critic, -1) 18 | 19 | 20 | class Critic(nn.Module): 21 | hidden_dims: Sequence[int] 22 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu 23 | layer_norm: bool = False 24 | 25 | @nn.compact 26 | def __call__(self, observations: jnp.ndarray, 27 | actions: jnp.ndarray) -> jnp.ndarray: 28 | inputs = jnp.concatenate([observations, actions], -1) 29 | critic = MLP((*self.hidden_dims, 1), 30 | activations=self.activations, 31 | layer_norm=self.layer_norm)(inputs) 32 | return jnp.squeeze(critic, -1) 33 | 34 | def grad_norm(self, obs, action, interpolate=False, lambda_=1): 35 | 36 | data = jnp.concatenate([obs, action], 1) 37 | if interpolate: 38 | expert_data = jnp.concatenate([obs1, action1], 1) 39 | policy_data = jnp.concatenate([obs2, action2], 1) 40 | 41 | # Interpolate between fake and real images with epsilon 42 | alpha = jax.random.uniform(key, shape=(expert_data.shape[0], 1)) 43 | alpha = alpha.expand_as(expert_data).to(expert_data.device) 44 | data_mix = data * epsilon + fake_data * (1 - epsilon) 45 | 46 | # Fetch the gradient penalty 47 | gradients = critic_forward(params_c, vars_c, data_mix) 48 | gradients = gradients.reshape((gradients.shape[0], -1)) 49 | 50 | alpha = torch.rand(expert_data.size()[0], 1) 51 | alpha = alpha.expand_as(expert_data).to(expert_data.device) 52 | 53 | interpolated = alpha * expert_data + (1 - alpha) * policy_data 54 | interpolated = 
Variable(interpolated, requires_grad=True) 55 | 56 | interpolated_state, interpolated_action = torch.split( 57 | interpolated, [self.obs_dim, self.action_dim], dim=1) 58 | q = self.forward(interpolated_state, interpolated_action) 59 | ones = torch.ones(q.size()).to(policy_data.device) 60 | gradient = grad( 61 | outputs=q, 62 | inputs=interpolated, 63 | grad_outputs=ones, 64 | create_graph=True, 65 | retain_graph=True, 66 | only_inputs=True, 67 | )[0] 68 | grad_pen = lambda_ * (jnp.linalg.norm(gradient, axis=1) - 1).pow(2).mean() 69 | return grad_pen 70 | 71 | 72 | class DoubleCritic(nn.Module): 73 | hidden_dims: Sequence[int] 74 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu 75 | layer_norm: bool = False 76 | 77 | @nn.compact 78 | def __call__(self, observations: jnp.ndarray, 79 | actions: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: 80 | critic1 = Critic(self.hidden_dims, 81 | activations=self.activations, 82 | layer_norm=self.layer_norm)(observations, actions) 83 | critic2 = Critic(self.hidden_dims, 84 | activations=self.activations, 85 | layer_norm=self.layer_norm)(observations, actions) 86 | return critic1, critic2 87 | -------------------------------------------------------------------------------- /offline/common.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | from typing import Any, Callable, Dict, Optional, Sequence, Tuple 4 | 5 | import flax 6 | import flax.linen as nn 7 | import jax 8 | import jax.numpy as jnp 9 | import optax 10 | 11 | Batch = collections.namedtuple( 12 | 'Batch', 13 | ['observations', 'actions', 'rewards', 'masks', 'next_observations']) 14 | 15 | 16 | def default_init(scale: Optional[float] = jnp.sqrt(2)): 17 | return nn.initializers.orthogonal(scale) 18 | 19 | 20 | PRNGKey = Any 21 | Params = flax.core.FrozenDict[str, Any] 22 | PRNGKey = Any 23 | Shape = Sequence[int] 24 | Dtype = Any # this could be a real type? 
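# Spectral-norm helpers: _power_iteration estimates the first right-singular vector of a
# linearised map by power iteration (note it calls lax.scan, so `lax` must be importable,
# e.g. `from jax import lax`), and estimate_spectral_norm applies it to f linearised at x.
# The offline training code in this folder does not call these helpers.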
25 | InfoDict = Dict[str, float] 26 | 27 | 28 | def _l2_normalize(x, eps=1e-4): 29 | return x * jax.lax.rsqrt((x ** 2).sum() + eps) 30 | 31 | 32 | def _l2_norm(x): 33 | return jnp.sqrt((x ** 2).sum()) 34 | 35 | 36 | def _power_iteration(A, u, n_steps=10): 37 | """Update an estimate of the first right-singular vector of A().""" 38 | def fun(u, _): 39 | v, A_transpose = jax.vjp(A, u) 40 | u, = A_transpose(v) 41 | u = _l2_normalize(u) 42 | return u, None 43 | u, _ = lax.scan(fun, u, xs=None, length=n_steps) 44 | return u 45 | 46 | 47 | def estimate_spectral_norm(f, x, seed=0, n_steps=10): 48 | """Estimate the spectral norm of f(x) linearized at x.""" 49 | rng = jax.random.PRNGKey(seed) 50 | u0 = jax.random.normal(rng, x.shape) 51 | _, f_jvp = jax.linearize(f, x) 52 | u = _power_iteration(f_jvp, u0, n_steps) 53 | sigma = _l2_norm(f_jvp(u)) 54 | return sigma 55 | 56 | 57 | class MLP(nn.Module): 58 | hidden_dims: Sequence[int] 59 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu 60 | activate_final: int = False 61 | layer_norm: bool = False 62 | dropout_rate: Optional[float] = None 63 | 64 | @nn.compact 65 | def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray: 66 | for i, size in enumerate(self.hidden_dims): 67 | x = nn.Dense(size, kernel_init=default_init())(x) 68 | if i + 1 < len(self.hidden_dims) or self.activate_final: 69 | if self.layer_norm: 70 | x = nn.LayerNorm()(x) 71 | x = self.activations(x) 72 | if self.dropout_rate is not None and self.dropout_rate > 0: 73 | x = nn.Dropout(rate=self.dropout_rate)( 74 | x, deterministic=not training) 75 | return x 76 | 77 | 78 | @flax.struct.dataclass 79 | class Model: 80 | step: int 81 | apply_fn: nn.Module = flax.struct.field(pytree_node=False) 82 | params: Params 83 | tx: Optional[optax.GradientTransformation] = flax.struct.field( 84 | pytree_node=False) 85 | opt_state: Optional[optax.OptState] = None 86 | 87 | @classmethod 88 | def create(cls, 89 | model_def: nn.Module, 90 | inputs: Sequence[jnp.ndarray], 91 | tx: Optional[optax.GradientTransformation] = None) -> 'Model': 92 | variables = model_def.init(*inputs) 93 | 94 | _, params = variables.pop('params') 95 | 96 | if tx is not None: 97 | opt_state = tx.init(params) 98 | else: 99 | opt_state = None 100 | 101 | return cls(step=1, 102 | apply_fn=model_def, 103 | params=params, 104 | tx=tx, 105 | opt_state=opt_state) 106 | 107 | def __call__(self, *args, **kwargs): 108 | return self.apply_fn.apply({'params': self.params}, *args, **kwargs) 109 | 110 | def apply(self, *args, **kwargs): 111 | return self.apply_fn.apply(*args, **kwargs) 112 | 113 | def apply_gradient(self, loss_fn) -> Tuple[Any, 'Model']: 114 | grad_fn = jax.grad(loss_fn, has_aux=True) 115 | grads, info = grad_fn(self.params) 116 | 117 | updates, new_opt_state = self.tx.update(grads, self.opt_state, 118 | self.params) 119 | new_params = optax.apply_updates(self.params, updates) 120 | 121 | return self.replace(step=self.step + 1, 122 | params=new_params, 123 | opt_state=new_opt_state), info 124 | 125 | def save(self, save_path: str): 126 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 127 | with open(save_path, 'wb') as f: 128 | f.write(flax.serialization.to_bytes(self.params)) 129 | 130 | def load(self, load_path: str) -> 'Model': 131 | with open(load_path, 'rb') as f: 132 | params = flax.serialization.from_bytes(self.params, f.read()) 133 | return self.replace(params=params) 134 | -------------------------------------------------------------------------------- /offline/dataset_utils.py: 
-------------------------------------------------------------------------------- 1 | import collections 2 | from typing import Optional 3 | 4 | import d4rl 5 | import gym 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | Batch = collections.namedtuple( 10 | 'Batch', 11 | ['observations', 'actions', 'rewards', 'masks', 'next_observations']) 12 | 13 | 14 | def split_into_trajectories(observations, actions, rewards, masks, dones_float, 15 | next_observations): 16 | trajs = [[]] 17 | 18 | for i in tqdm(range(len(observations))): 19 | trajs[-1].append((observations[i], actions[i], rewards[i], masks[i], 20 | dones_float[i], next_observations[i])) 21 | if dones_float[i] == 1.0 and i + 1 < len(observations): 22 | trajs.append([]) 23 | 24 | return trajs 25 | 26 | 27 | def merge_trajectories(trajs): 28 | observations = [] 29 | actions = [] 30 | rewards = [] 31 | masks = [] 32 | dones_float = [] 33 | next_observations = [] 34 | 35 | for traj in trajs: 36 | for (obs, act, rew, mask, done, next_obs) in traj: 37 | observations.append(obs) 38 | actions.append(act) 39 | rewards.append(rew) 40 | masks.append(mask) 41 | dones_float.append(done) 42 | next_observations.append(next_obs) 43 | 44 | return np.stack(observations), np.stack(actions), np.stack( 45 | rewards), np.stack(masks), np.stack(dones_float), np.stack( 46 | next_observations) 47 | 48 | 49 | class Dataset(object): 50 | def __init__(self, observations: np.ndarray, actions: np.ndarray, 51 | rewards: np.ndarray, masks: np.ndarray, 52 | dones_float: np.ndarray, next_observations: np.ndarray, 53 | size: int): 54 | self.observations = observations 55 | self.actions = actions 56 | self.rewards = rewards 57 | self.masks = masks 58 | self.dones_float = dones_float 59 | self.next_observations = next_observations 60 | self.size = size 61 | 62 | def sample(self, batch_size: int) -> Batch: 63 | indx = np.random.randint(self.size, size=batch_size) 64 | return Batch(observations=self.observations[indx], 65 | actions=self.actions[indx], 66 | rewards=self.rewards[indx], 67 | masks=self.masks[indx], 68 | next_observations=self.next_observations[indx]) 69 | 70 | 71 | class D4RLDataset(Dataset): 72 | def __init__(self, 73 | env: gym.Env, 74 | clip_to_eps: bool = True, 75 | eps: float = 1e-5): 76 | dataset = d4rl.qlearning_dataset(env) 77 | 78 | if clip_to_eps: 79 | lim = 1 - eps 80 | dataset['actions'] = np.clip(dataset['actions'], -lim, lim) 81 | 82 | dones_float = np.zeros_like(dataset['rewards']) 83 | 84 | for i in range(len(dones_float) - 1): 85 | if np.linalg.norm(dataset['observations'][i + 1] - 86 | dataset['next_observations'][i] 87 | ) > 1e-6 or dataset['terminals'][i] == 1.0: 88 | dones_float[i] = 1 89 | else: 90 | dones_float[i] = 0 91 | 92 | dones_float[-1] = 1 93 | 94 | super().__init__(dataset['observations'].astype(np.float32), 95 | actions=dataset['actions'].astype(np.float32), 96 | rewards=dataset['rewards'].astype(np.float32), 97 | masks=1.0 - dataset['terminals'].astype(np.float32), 98 | dones_float=dones_float.astype(np.float32), 99 | next_observations=dataset['next_observations'].astype( 100 | np.float32), 101 | size=len(dataset['observations'])) 102 | 103 | 104 | class ReplayBuffer(Dataset): 105 | def __init__(self, observation_space: gym.spaces.Box, action_dim: int, 106 | capacity: int): 107 | 108 | observations = np.empty((capacity, *observation_space.shape), 109 | dtype=observation_space.dtype) 110 | actions = np.empty((capacity, action_dim), dtype=np.float32) 111 | rewards = np.empty((capacity, ), dtype=np.float32) 112 | masks = 
np.empty((capacity, ), dtype=np.float32) 113 | dones_float = np.empty((capacity, ), dtype=np.float32) 114 | next_observations = np.empty((capacity, *observation_space.shape), 115 | dtype=observation_space.dtype) 116 | super().__init__(observations=observations, 117 | actions=actions, 118 | rewards=rewards, 119 | masks=masks, 120 | dones_float=dones_float, 121 | next_observations=next_observations, 122 | size=0) 123 | 124 | self.size = 0 125 | 126 | self.insert_index = 0 127 | self.capacity = capacity 128 | 129 | def initialize_with_dataset(self, dataset: Dataset, 130 | num_samples: Optional[int]): 131 | assert self.insert_index == 0, 'Can insert a batch online in an empty replay buffer.' 132 | 133 | dataset_size = len(dataset.observations) 134 | 135 | if num_samples is None: 136 | num_samples = dataset_size 137 | else: 138 | num_samples = min(dataset_size, num_samples) 139 | assert self.capacity >= num_samples, 'Dataset cannot be larger than the replay buffer capacity.' 140 | 141 | if num_samples < dataset_size: 142 | perm = np.random.permutation(dataset_size) 143 | indices = perm[:num_samples] 144 | else: 145 | indices = np.arange(num_samples) 146 | 147 | self.observations[:num_samples] = dataset.observations[indices] 148 | self.actions[:num_samples] = dataset.actions[indices] 149 | self.rewards[:num_samples] = dataset.rewards[indices] 150 | self.masks[:num_samples] = dataset.masks[indices] 151 | self.dones_float[:num_samples] = dataset.dones_float[indices] 152 | self.next_observations[:num_samples] = dataset.next_observations[ 153 | indices] 154 | 155 | self.insert_index = num_samples 156 | self.size = num_samples 157 | 158 | def insert(self, observation: np.ndarray, action: np.ndarray, 159 | reward: float, mask: float, done_float: float, 160 | next_observation: np.ndarray): 161 | self.observations[self.insert_index] = observation 162 | self.actions[self.insert_index] = action 163 | self.rewards[self.insert_index] = reward 164 | self.masks[self.insert_index] = mask 165 | self.dones_float[self.insert_index] = done_float 166 | self.next_observations[self.insert_index] = next_observation 167 | 168 | self.insert_index = (self.insert_index + 1) % self.capacity 169 | self.size = min(self.size + 1, self.capacity) 170 | -------------------------------------------------------------------------------- /offline/critic.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import jax.numpy as jnp 4 | import jax 5 | from functools import partial 6 | 7 | from common import Batch, InfoDict, Model, Params, PRNGKey 8 | 9 | 10 | def chi_square_loss(diff, alpha, args=None): 11 | loss = alpha*jnp.maximum(diff+diff**2/4,0) - (1-alpha)*diff 12 | return loss 13 | 14 | def total_variation_loss(diff, alpha, args=None): 15 | loss = alpha*jnp.maximum(diff,0) - (1-alpha)*diff 16 | return loss 17 | 18 | def reverse_kl_loss(diff, alpha, args=None): 19 | """ Gumbel loss J: E[e^x - x - 1]. For stability to outliers, we scale the gradients with the max value over a batch 20 | and optionally clip the exponent. This has the effect of training with an adaptive lr. 
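    Concretely, with z = (q - v) / alpha (optionally clipped at args.max_clip) and
    m = stop_gradient(max(z over the batch, floored at -1)), the loss below is
    exp(z - m) - z * exp(-m) - exp(-m); its gradient in z is exp(-m) * (exp(z) - 1),
    i.e. the usual Gumbel-loss gradient rescaled by the batch-dependent factor exp(-m).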
21 | """ 22 | z = diff/alpha 23 | if args.max_clip is not None: 24 | z = jnp.minimum(z, args.max_clip) # clip max value 25 | max_z = jnp.max(z, axis=0) 26 | max_z = jnp.where(max_z < -1.0, -1.0, max_z) 27 | max_z = jax.lax.stop_gradient(max_z) # Detach the gradients 28 | loss = jnp.exp(z - max_z) - z*jnp.exp(-max_z) - jnp.exp(-max_z) # scale by e^max_z 29 | return loss 30 | 31 | def expectile_loss(diff, expectile=0.8): 32 | weight = jnp.where(diff > 0, expectile, (1 - expectile)) 33 | return weight * (diff**2) 34 | 35 | def update_v(critic: Model, value: Model, batch: Batch, 36 | expectile: float, loss_temp: float, double: bool, vanilla: bool, key: PRNGKey, args) -> Tuple[Model, InfoDict]: 37 | actions = batch.actions 38 | 39 | rng1, rng2 = jax.random.split(key) 40 | if args.sample_random_times > 0: 41 | # add random actions to smooth loss computation (use 1/2(rho + Unif)) 42 | times = args.sample_random_times 43 | random_action = jax.random.uniform( 44 | rng1, shape=(times * actions.shape[0], 45 | actions.shape[1]), 46 | minval=-1.0, maxval=1.0) 47 | obs = jnp.concatenate([batch.observations, jnp.repeat( 48 | batch.observations, times, axis=0)], axis=0) 49 | acts = jnp.concatenate([batch.actions, random_action], axis=0) 50 | else: 51 | obs = batch.observations 52 | acts = batch.actions 53 | 54 | if args.noise: 55 | std = args.noise_std 56 | noise = jax.random.normal(rng2, shape=(acts.shape[0], acts.shape[1])) 57 | noise = jnp.clip(noise * std, -0.5, 0.5) 58 | acts = (batch.actions + noise) 59 | acts = jnp.clip(acts, -1, 1) 60 | 61 | q1, q2 = critic(obs, acts) 62 | if double: 63 | q = jnp.minimum(q1, q2) 64 | else: 65 | q = q1 66 | 67 | def value_loss_fn(value_params: Params) -> Tuple[jnp.ndarray, InfoDict]: 68 | v = value.apply({'params': value_params}, obs) 69 | 70 | if args.f=='chi-square': 71 | value_loss = chi_square_loss(q - v, alpha=loss_temp, args=args).mean() 72 | elif args.f=='total-variation': 73 | value_loss = total_variation_loss(q - v, alpha=loss_temp, args=args).mean() 74 | elif args.f=='reverse-kl': # Same as XQL 75 | value_loss = reverse_kl_loss(q - v, alpha=loss_temp, args=args).mean() 76 | 77 | return value_loss, { 78 | 'value_loss': value_loss, 79 | 'v': v.mean(), 80 | } 81 | 82 | new_value, info = value.apply_gradient(value_loss_fn) 83 | 84 | return new_value, info 85 | 86 | 87 | def update_q(critic: Model, target_value: Model, batch: Batch, 88 | discount: float, double: bool, key: PRNGKey, loss_temp: float, args) -> Tuple[Model, InfoDict]: 89 | next_v = target_value(batch.next_observations) 90 | 91 | target_q = batch.rewards + discount * batch.masks * next_v 92 | 93 | def critic_loss_fn(critic_params: Params) -> Tuple[jnp.ndarray, InfoDict]: 94 | acts = batch.actions 95 | q1, q2 = critic.apply({'params': critic_params}, batch.observations, acts) 96 | v = target_value(batch.observations) 97 | 98 | def mse_loss(q, q_target, *args): 99 | loss_dict = {} 100 | 101 | x = q-q_target 102 | loss = huber_loss(x, delta=20.0) # x**2 103 | loss_dict['critic_loss'] = loss.mean() 104 | 105 | return loss.mean(), loss_dict 106 | 107 | critic_loss = mse_loss 108 | 109 | if double: 110 | loss1, dict1 = critic_loss(q1, target_q, v, loss_temp) 111 | loss2, dict2 = critic_loss(q2, target_q, v, loss_temp) 112 | 113 | critic_loss = (loss1 + loss2).mean() 114 | for k, v in dict2.items(): 115 | dict1[k] += v 116 | loss_dict = dict1 117 | else: 118 | # critic_loss, loss_dict = dual_q_loss(q1, target_q, v, loss_temp) 119 | critic_loss, loss_dict = critic_loss(q1, target_q, v, loss_temp) 120 | 
121 | if args.grad_pen: 122 | lambda_ = args.lambda_gp 123 | q1_grad, q2_grad = grad_norm(critic, critic_params, batch.observations, acts) 124 | loss_dict['q1_grad'] = q1_grad.mean() 125 | loss_dict['q2_grad'] = q2_grad.mean() 126 | 127 | if double: 128 | gp_loss = (q1_grad + q2_grad).mean() 129 | else: 130 | gp_loss = q1_grad.mean() 131 | 132 | critic_loss += lambda_ * gp_loss 133 | 134 | loss_dict.update({ 135 | 'q1': q1.mean(), 136 | 'q2': q2.mean() 137 | }) 138 | return critic_loss, loss_dict 139 | 140 | new_critic, info = critic.apply_gradient(critic_loss_fn) 141 | 142 | return new_critic, info 143 | 144 | 145 | def grad_norm(model, params, obs, action, lambda_=10): 146 | 147 | @partial(jax.vmap, in_axes=(0, 0)) 148 | @partial(jax.jacrev, argnums=1) 149 | def input_grad_fn(obs, action): 150 | return model.apply({'params': params}, obs, action) 151 | 152 | def grad_pen_fn(grad): 153 | # We use gradient penalties inspired from WGAN-LP loss which penalizes grad_norm > 1 154 | penalty = jnp.maximum(jnp.linalg.norm(grad1, axis=-1) - 1, 0)**2 155 | return penalty 156 | 157 | grad1, grad2 = input_grad_fn(obs, action) 158 | 159 | return grad_pen_fn(grad1), grad_pen_fn(grad2) 160 | 161 | 162 | def huber_loss(x, delta: float = 1.): 163 | """Huber loss, similar to L2 loss close to zero, L1 loss away from zero. 164 | See "Robust Estimation of a Location Parameter" by Huber. 165 | (https://projecteuclid.org/download/pdf_1/euclid.aoms/1177703732). 166 | Args: 167 | x: a vector of arbitrary shape. 168 | delta: the bounds for the huber loss transformation, defaults at 1. 169 | Note `grad(huber_loss(x))` is equivalent to `grad(0.5 * clip_gradient(x)**2)`. 170 | Returns: 171 | a vector of same shape of `x`. 172 | """ 173 | # 0.5 * x^2 if |x| <= d 174 | # 0.5 * d^2 + d * (|x| - d) if |x| > d 175 | abs_x = jnp.abs(x) 176 | quadratic = jnp.minimum(abs_x, delta) 177 | # Same as max(abs_x - delta, 0) but avoids potentially doubling gradient. 
178 | linear = abs_x - quadratic 179 | return 0.5 * quadratic**2 + delta * linear 180 | -------------------------------------------------------------------------------- /offline/learner.py: -------------------------------------------------------------------------------- 1 | """Implementations of algorithms for continuous control.""" 2 | 3 | from typing import Optional, Sequence, Tuple 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | import optax 9 | import os 10 | 11 | import policy 12 | import value_net 13 | from actor import update as awr_update_actor 14 | from common import Batch, InfoDict, Model, PRNGKey 15 | from critic import update_q, update_v 16 | # from dual_critic import update_q_dual, update_v_dual 17 | 18 | from functools import partial 19 | 20 | 21 | def target_update(critic: Model, target_critic: Model, tau: float) -> Model: 22 | new_target_params = jax.tree_map( 23 | lambda p, tp: p * tau + tp * (1 - tau), critic.params, 24 | target_critic.params) 25 | 26 | return target_critic.replace(params=new_target_params) 27 | 28 | 29 | @partial(jax.jit, static_argnames=['double', 'vanilla', 'args']) 30 | def _update_jit( 31 | rng: PRNGKey, actor: Model, critic: Model, value: Model, 32 | target_critic: Model, batch: Batch, discount: float, tau: float, 33 | expectile: float, temperature: float, loss_temp: float, double: bool, vanilla: bool, args, 34 | ) -> Tuple[PRNGKey, Model, Model, Model, Model, Model, InfoDict]: 35 | 36 | key, rng = jax.random.split(rng) 37 | for i in range(args.num_v_updates): 38 | new_value, value_info = update_v(target_critic, value, batch, expectile, loss_temp, double, vanilla, key, args) 39 | value = new_value 40 | new_actor, actor_info = awr_update_actor(key, actor, target_critic, 41 | new_value, batch, temperature, double) 42 | 43 | new_critic, critic_info = update_q(critic, new_value, batch, discount, double, key, loss_temp, args) 44 | 45 | new_target_critic = target_update(new_critic, target_critic, tau) 46 | 47 | return rng, new_actor, new_critic, new_value, new_target_critic, { 48 | **critic_info, 49 | **value_info, 50 | **actor_info 51 | } 52 | 53 | 54 | class Learner(object): 55 | def __init__(self, 56 | seed: int, 57 | observations: jnp.ndarray, 58 | actions: jnp.ndarray, 59 | actor_lr: float = 3e-4, 60 | value_lr: float = 3e-4, 61 | critic_lr: float = 3e-4, 62 | hidden_dims: Sequence[int] = (256, 256), 63 | discount: float = 0.99, 64 | tau: float = 0.005, 65 | expectile: float = 0.8, 66 | temperature: float = 0.1, 67 | dropout_rate: Optional[float] = None, 68 | layernorm: bool = False, 69 | value_dropout_rate: Optional[float] = None, 70 | max_steps: Optional[int] = None, 71 | loss_temp: float = 1.0, 72 | double_q: bool = True, 73 | vanilla: bool = True, 74 | opt_decay_schedule: str = "cosine", 75 | args=None): 76 | """ 77 | An implementation of the version of Soft-Actor-Critic described in https://arxiv.org/abs/1801.01290 78 | """ 79 | 80 | self.expectile = expectile 81 | self.tau = tau 82 | self.discount = discount 83 | self.temperature = temperature 84 | self.loss_temp = loss_temp 85 | self.double_q = double_q 86 | self.vanilla = vanilla 87 | self.args = args 88 | 89 | rng = jax.random.PRNGKey(seed) 90 | rng, actor_key, critic_key, value_key = jax.random.split(rng, 4) 91 | 92 | action_dim = actions.shape[-1] 93 | actor_def = policy.NormalTanhPolicy(hidden_dims, 94 | action_dim, 95 | log_std_scale=1e-3, 96 | log_std_min=-5.0, 97 | dropout_rate=dropout_rate, 98 | state_dependent_std=False, 99 | tanh_squash_distribution=False) 
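        # When opt_decay_schedule == "cosine" (the default), the actor optimiser chains
        # scale_by_adam() with a schedule initialised at -actor_lr, so the Adam update
        # direction is applied with a step size that cosine-decays from actor_lr to 0
        # over max_steps (the minus sign turns optax's additive update into descent).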
100 | 101 | if opt_decay_schedule == "cosine": 102 | schedule_fn = optax.cosine_decay_schedule(-actor_lr, max_steps) 103 | optimiser = optax.chain(optax.scale_by_adam(), 104 | optax.scale_by_schedule(schedule_fn)) 105 | else: 106 | optimiser = optax.adam(learning_rate=actor_lr) 107 | 108 | actor = Model.create(actor_def, 109 | inputs=[actor_key, observations], 110 | tx=optimiser) 111 | 112 | critic_def = value_net.DoubleCritic(hidden_dims) 113 | 114 | critic = Model.create(critic_def, 115 | inputs=[critic_key, observations, actions], 116 | tx=optax.adam(learning_rate=critic_lr)) 117 | 118 | value_def = value_net.ValueCritic(hidden_dims, 119 | layer_norm=layernorm, 120 | dropout_rate=value_dropout_rate) 121 | value = Model.create(value_def, 122 | inputs=[value_key, observations], 123 | tx=optax.adam(learning_rate=value_lr)) 124 | 125 | target_critic = Model.create( 126 | critic_def, inputs=[critic_key, observations, actions]) 127 | 128 | self.actor = actor 129 | self.critic = critic 130 | self.value = value 131 | self.target_critic = target_critic 132 | self.rng = rng 133 | 134 | def sample_actions(self, 135 | observations: np.ndarray, 136 | temperature: float = 1.0) -> jnp.ndarray: 137 | rng, actions = policy.sample_actions(self.rng, self.actor.apply_fn, 138 | self.actor.params, observations, 139 | temperature) 140 | self.rng = rng 141 | 142 | actions = np.asarray(actions) 143 | return np.clip(actions, -1, 1) 144 | 145 | def update(self, batch: Batch) -> InfoDict: 146 | new_rng, new_actor, new_critic, new_value, new_target_critic, info = _update_jit( 147 | self.rng, self.actor, self.critic, self.value, self.target_critic, 148 | batch, self.discount, self.tau, self.expectile, self.temperature, self.loss_temp, self.double_q, self.vanilla, self.args) 149 | 150 | self.rng = new_rng 151 | self.actor = new_actor 152 | self.critic = new_critic 153 | self.value = new_value 154 | self.target_critic = new_target_critic 155 | 156 | return info 157 | 158 | def load(self, save_dir: str): 159 | self.actor = self.actor.load(os.path.join(save_dir, 'actor')) 160 | self.critic = self.critic.load(os.path.join(save_dir, 'critic')) 161 | self.value = self.value.load(os.path.join(save_dir, 'value')) 162 | self.target_critic = self.target_critic.load(os.path.join(save_dir, 'critic')) 163 | 164 | def save(self, save_dir: str): 165 | self.actor.save(os.path.join(save_dir, 'actor')) 166 | self.critic.save(os.path.join(save_dir, 'critic')) 167 | self.value.save(os.path.join(save_dir, 'value')) 168 | -------------------------------------------------------------------------------- /offline/train_offline.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' 3 | from jax.config import config 4 | 5 | from typing import Tuple 6 | 7 | import datetime 8 | import gym 9 | import numpy as np 10 | import tqdm 11 | import time 12 | import absl 13 | import sys 14 | from absl import app, flags 15 | from ml_collections import config_flags 16 | from tensorboardX import SummaryWriter 17 | from dataclasses import dataclass 18 | 19 | import wrappers 20 | from dataset_utils import D4RLDataset, split_into_trajectories 21 | from evaluation import evaluate 22 | from learner import Learner 23 | import warnings 24 | from logging_utils.logx import EpochLogger 25 | 26 | 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | flags.DEFINE_string('env_name', 'halfcheetah-expert-v2', 'Environment name.') 31 | flags.DEFINE_string('f', 'chi-square', 'f-divergence 
to use.[chi-square, total-variation, reverse-KL(XQL)]') 32 | flags.DEFINE_string('save_dir', './tmp/', 'Tensorboard logging dir.') 33 | flags.DEFINE_string('exp_name', 'dump', 'Epoch logging dir.') 34 | flags.DEFINE_integer('seed', 42, 'Random seed.') 35 | flags.DEFINE_integer('eval_episodes', 10, 36 | 'Number of episodes used for evaluation.') 37 | flags.DEFINE_integer('log_interval', 5000, 'Logging interval.') 38 | flags.DEFINE_integer('eval_interval', 5000, 'Eval interval.') 39 | flags.DEFINE_integer('batch_size', 1024, 'Mini batch size.') 40 | flags.DEFINE_float('temp', 1.0, 'Loss temperature') 41 | flags.DEFINE_boolean('double', True, 'Use double q-learning') 42 | flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.') 43 | flags.DEFINE_boolean('tqdm', True, 'Use tqdm progress bar.') 44 | flags.DEFINE_integer('sample_random_times', 0, 'Number of random actions to add to smooth dataset') 45 | flags.DEFINE_boolean('grad_pen', False, 'Add a gradient penalty to critic network') 46 | flags.DEFINE_float('lambda_gp', 1, 'Gradient penalty coefficient') 47 | flags.DEFINE_float('max_clip', 7., 'Loss clip value') 48 | flags.DEFINE_integer('num_v_updates', 1, 'Number of value updates per iter') 49 | flags.DEFINE_boolean('log_loss', False, 'Use log gumbel loss') 50 | flags.DEFINE_boolean('noise', False, 'Add noise to actions') 51 | flags.DEFINE_float('noise_std', 0.1, 'Noise std for actions') 52 | 53 | config_flags.DEFINE_config_file( 54 | 'config', 55 | 'default.py', 56 | 'File path to the training hyperparameter configuration.', 57 | lock_config=False) 58 | 59 | 60 | 61 | @dataclass(frozen=True) 62 | class ConfigArgs: 63 | f : str 64 | sample_random_times: int 65 | grad_pen: bool 66 | noise: bool 67 | noise_std: float 68 | lambda_gp: int 69 | max_clip: float 70 | num_v_updates: int 71 | log_loss: bool 72 | 73 | 74 | def normalize(dataset): 75 | 76 | trajs = split_into_trajectories(dataset.observations, dataset.actions, 77 | dataset.rewards, dataset.masks, 78 | dataset.dones_float, 79 | dataset.next_observations) 80 | 81 | def compute_returns(traj): 82 | episode_return = 0 83 | for _, _, rew, _, _, _ in traj: 84 | episode_return += rew 85 | 86 | return episode_return 87 | 88 | trajs.sort(key=compute_returns) 89 | 90 | dataset.rewards /= compute_returns(trajs[-1]) - compute_returns(trajs[0]) 91 | dataset.rewards *= 1000.0 92 | 93 | 94 | def make_env_and_dataset(env_name: str, 95 | seed: int) -> Tuple[gym.Env, D4RLDataset]: 96 | env = gym.make(env_name) 97 | 98 | env = wrappers.EpisodeMonitor(env) 99 | env = wrappers.SinglePrecision(env) 100 | 101 | env.seed(seed) 102 | env.action_space.seed(seed) 103 | env.observation_space.seed(seed) 104 | 105 | dataset = D4RLDataset(env) 106 | 107 | if 'antmaze' in FLAGS.env_name: 108 | dataset.rewards -= 1.0 109 | # dataset.rewards = (dataset.rewards - 0.5) * 4 110 | # See https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22 111 | elif ('halfcheetah' in FLAGS.env_name or 'walker2d' in FLAGS.env_name 112 | or 'hopper' in FLAGS.env_name): 113 | normalize(dataset) 114 | 115 | return env, dataset 116 | 117 | 118 | def main(_): 119 | ts_str = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d_%H-%M-%S") 120 | save_dir = os.path.join(FLAGS.save_dir, ts_str) 121 | exp_id = f"results/offline_rl/{FLAGS.env_name}/" + FLAGS.exp_name 122 | log_folder = exp_id + '/'+FLAGS.exp_name+'_s'+str(FLAGS.seed) 123 | logger_kwargs={'output_dir':log_folder, 'exp_name':FLAGS.exp_name} 124 | 125 | import pandas as pd 126 | 
if os.path.isfile(log_folder+'/progress.txt'): 127 | try: 128 | df = pd.read_csv(log_folder+'/progress.txt',delim_whitespace=True) 129 | if df['Iterations'].to_numpy()[-1]>900000: 130 | print("Exiting because already trained") 131 | exit() 132 | except: 133 | print("Cannot read progress.txt") 134 | pass 135 | e_logger = EpochLogger(**logger_kwargs) 136 | hparam_str_dict = dict(seed=FLAGS.seed, env=FLAGS.env_name) 137 | hparam_str = ','.join([ 138 | '%s=%s' % (k, str(hparam_str_dict[k])) 139 | for k in sorted(hparam_str_dict.keys()) 140 | ]) 141 | 142 | os.makedirs(save_dir, exist_ok=True) 143 | 144 | env, dataset = make_env_and_dataset(FLAGS.env_name, FLAGS.seed) 145 | 146 | kwargs = dict(FLAGS.config) 147 | args = ConfigArgs(f = FLAGS.f, 148 | sample_random_times=FLAGS.sample_random_times, 149 | grad_pen=FLAGS.grad_pen, 150 | lambda_gp=FLAGS.lambda_gp, 151 | noise=FLAGS.noise, 152 | max_clip=FLAGS.max_clip, 153 | num_v_updates=FLAGS.num_v_updates, 154 | log_loss=FLAGS.log_loss, 155 | noise_std=FLAGS.noise_std) 156 | agent = Learner(FLAGS.seed, 157 | env.observation_space.sample()[np.newaxis], 158 | env.action_space.sample()[np.newaxis], 159 | max_steps=FLAGS.max_steps, 160 | loss_temp=FLAGS.temp, 161 | double_q=FLAGS.double, 162 | vanilla=False, 163 | args=args, 164 | **kwargs) 165 | 166 | best_eval_returns = -np.inf 167 | eval_returns = [] 168 | for i in range(1, FLAGS.max_steps + 1): # Remove TQDM 169 | batch = dataset.sample(FLAGS.batch_size) 170 | 171 | update_info = agent.update(batch) 172 | 173 | if i % FLAGS.eval_interval == 0: 174 | eval_stats = evaluate(agent, env, FLAGS.eval_episodes) 175 | if eval_stats['return'] >= best_eval_returns: 176 | # Store best eval returns 177 | best_eval_returns = eval_stats['return'] 178 | e_logger.log_tabular('Iterations', i) 179 | e_logger.log_tabular('AverageNormalizedReturn', eval_stats['return']) 180 | e_logger.dump_tabular() 181 | 182 | eval_returns.append((i, eval_stats['return'])) 183 | 184 | 185 | sys.exit(0) 186 | os._exit(0) 187 | raise SystemExit 188 | 189 | 190 | if __name__ == '__main__': 191 | app.run(main) 192 | -------------------------------------------------------------------------------- /offline/environment.yml: -------------------------------------------------------------------------------- 1 | name: dualrl 2 | channels: 3 | - pytorch 4 | - borismarin 5 | - nvidia 6 | - defaults 7 | - conda-forge 8 | dependencies: 9 | - _libgcc_mutex=0.1=conda_forge 10 | - _openmp_mutex=4.5=2_kmp_llvm 11 | - aom=3.5.0=h27087fc_0 12 | - blas=2.3=openblas 13 | - brotlipy=0.7.0=py38h0a891b7_1005 14 | - bzip2=1.0.8=h7f98852_4 15 | - ca-certificates=2022.12.7=ha878542_0 16 | - certifi=2022.12.7=pyhd8ed1ab_0 17 | - cffi=1.15.1=py38h4a40e3a_3 18 | - charset-normalizer=2.1.1=pyhd8ed1ab_0 19 | - cryptography=39.0.2=py38h3d167d9_0 20 | - cuda=11.7.1=0 21 | - cuda-cccl=11.7.91=0 22 | - cuda-command-line-tools=11.7.1=0 23 | - cuda-compiler=11.7.1=0 24 | - cuda-cudart=11.7.99=0 25 | - cuda-cudart-dev=11.7.99=0 26 | - cuda-cuobjdump=11.7.91=0 27 | - cuda-cupti=11.7.101=0 28 | - cuda-cuxxfilt=11.7.91=0 29 | - cuda-demo-suite=12.1.55=0 30 | - cuda-documentation=12.1.55=0 31 | - cuda-driver-dev=11.7.99=0 32 | - cuda-gdb=12.1.55=0 33 | - cuda-libraries=11.7.1=0 34 | - cuda-libraries-dev=11.7.1=0 35 | - cuda-memcheck=11.8.86=0 36 | - cuda-nsight=12.1.55=0 37 | - cuda-nsight-compute=12.1.0=0 38 | - cuda-nvcc=11.7.99=0 39 | - cuda-nvdisasm=12.1.55=0 40 | - cuda-nvml-dev=11.7.91=0 41 | - cuda-nvprof=12.1.55=0 42 | - cuda-nvprune=11.7.91=0 43 | - cuda-nvrtc=11.7.99=0 
44 | - cuda-nvrtc-dev=11.7.99=0 45 | - cuda-nvtx=11.7.91=0 46 | - cuda-nvvp=12.1.55=0 47 | - cuda-runtime=11.7.1=0 48 | - cuda-sanitizer-api=12.1.55=0 49 | - cuda-toolkit=11.7.1=0 50 | - cuda-tools=11.7.1=0 51 | - cuda-visual-tools=11.7.1=0 52 | - cudatoolkit=11.8.0=h37601d7_11 53 | - cudnn=8.4.1.50=hed8a83a_0 54 | - cycler=0.11.0=pyhd8ed1ab_0 55 | - expat=2.5.0=h27087fc_0 56 | - ffmpeg=5.1.2=gpl_h8dda1f0_106 57 | - flit-core=3.8.0=pyhd8ed1ab_0 58 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 59 | - font-ttf-inconsolata=3.000=h77eed37_0 60 | - font-ttf-source-code-pro=2.038=h77eed37_0 61 | - font-ttf-ubuntu=0.83=hab24e00_0 62 | - fontconfig=2.14.2=h14ed4e7_0 63 | - fonts-conda-ecosystem=1=0 64 | - fonts-conda-forge=1=0 65 | - freetype=2.12.1=hca18f0e_1 66 | - gds-tools=1.6.0.25=0 67 | - gettext=0.21.1=h27087fc_0 68 | - giflib=5.2.1=h36c2ea0_2 69 | - glew=2.1.0=h9c3ff4c_2 70 | - gmp=6.2.1=h58526e2_0 71 | - gnutls=3.7.8=hf3e180e_0 72 | - icu=70.1=h27087fc_0 73 | - idna=3.4=pyhd8ed1ab_0 74 | - intel-openmp=2021.4.0=h06a4308_3561 75 | - jpeg=9e=h0b41bf4_3 76 | - kiwisolver=1.4.4=py38h43d8883_1 77 | - lame=3.100=h166bdaf_1003 78 | - lcms2=2.14=hfd0df8a_1 79 | - ld_impl_linux-64=2.40=h41732ed_0 80 | - lerc=4.0.0=h27087fc_0 81 | - libblas=3.9.0=3_openblas 82 | - libcblas=3.9.0=3_openblas 83 | - libcublas=11.10.3.66=0 84 | - libcublas-dev=11.10.3.66=0 85 | - libcufft=10.7.2.124=h4fbf590_0 86 | - libcufft-dev=10.7.2.124=h98a8f43_0 87 | - libcufile=1.6.0.25=0 88 | - libcufile-dev=1.6.0.25=0 89 | - libcurand=10.3.2.56=0 90 | - libcurand-dev=10.3.2.56=0 91 | - libcusolver=11.4.0.1=0 92 | - libcusolver-dev=11.4.0.1=0 93 | - libcusparse=11.7.4.91=0 94 | - libcusparse-dev=11.7.4.91=0 95 | - libdeflate=1.17=h0b41bf4_0 96 | - libdrm=2.4.114=h166bdaf_0 97 | - libffi=3.4.2=h7f98852_5 98 | - libgcc-ng=12.2.0=h65d4601_19 99 | - libgfortran-ng=7.5.0=h14aa051_20 100 | - libgfortran4=7.5.0=h14aa051_20 101 | - libglu=9.0.0=he1b5a44_1001 102 | - libgomp=12.2.0=h65d4601_19 103 | - libhwloc=2.9.0=hd6dc26d_0 104 | - libiconv=1.17=h166bdaf_0 105 | - libidn2=2.3.4=h166bdaf_0 106 | - liblapack=3.9.0=3_openblas 107 | - liblapacke=3.9.0=3_openblas 108 | - libnpp=11.7.4.75=0 109 | - libnpp-dev=11.7.4.75=0 110 | - libnsl=2.0.0=h7f98852_0 111 | - libnvjpeg=11.8.0.2=0 112 | - libnvjpeg-dev=11.8.0.2=0 113 | - libopenblas=0.3.12=pthreads_hb3c22a3_1 114 | - libopus=1.3.1=h7f98852_1 115 | - libpciaccess=0.17=h166bdaf_0 116 | - libpng=1.6.39=h753d276_0 117 | - libprotobuf=3.21.12=h3eb15da_0 118 | - libsqlite=3.40.0=h753d276_0 119 | - libstdcxx-ng=12.2.0=h46fd767_19 120 | - libtasn1=4.19.0=h166bdaf_0 121 | - libtiff=4.5.0=h6adf6a1_2 122 | - libunistring=0.9.10=h7f98852_0 123 | - libuuid=2.32.1=h7f98852_1000 124 | - libva=2.17.0=h0b41bf4_0 125 | - libvpx=1.11.0=h9c3ff4c_3 126 | - libwebp=1.2.4=h1daa5a0_1 127 | - libwebp-base=1.2.4=h166bdaf_0 128 | - libx11=1.6.2=0 129 | - libxcb=1.13=h7f98852_1004 130 | - libxml2=2.10.3=h7463322_0 131 | - libzlib=1.2.13=h166bdaf_4 132 | - llvm-openmp=15.0.7=h0cdce71_0 133 | - lz4-c=1.9.4=hcb278e6_0 134 | - magma=2.6.2=hc72dce7_0 135 | - matplotlib=3.3.2=0 136 | - matplotlib-base=3.3.2=py38h5c7f4ab_1 137 | - mkl=2022.2.1=h84fe81f_16997 138 | - mkl-service=2.4.0=py38h80f09db_0 139 | - mkl_fft=1.3.1=py38hf8530d2_4 140 | - mkl_random=1.2.2=py38h5ca245f_1 141 | - mpi=1.0=openmpi 142 | - mpi4py=3.1.2=py38h3e8e7aa_0 143 | - nccl=2.14.3.1=h0800d71_0 144 | - ncurses=6.3=h27087fc_1 145 | - nettle=3.8.1=hc379101_1 146 | - ninja=1.11.1=h924138e_0 147 | - nsight-compute=2023.1.0.15=0 148 | - 
numpy-base=1.18.5=py38h2f8d375_0 149 | - openh264=2.3.1=hcb278e6_2 150 | - openjpeg=2.5.0=hfec8fc6_2 151 | - openmpi=4.1.3=hbea3300_101 152 | - openssl=3.1.0=h0b41bf4_0 153 | - p11-kit=0.24.1=hc5aa10d_0 154 | - pthread-stubs=0.4=h36c2ea0_1001 155 | - pycparser=2.21=pyhd8ed1ab_0 156 | - pyopenssl=23.0.0=pyhd8ed1ab_0 157 | - pysocks=1.7.1=pyha2e5f31_6 158 | - python=3.8.13=ha86cf86_0_cpython 159 | - python-dateutil=2.8.2=pyhd8ed1ab_0 160 | - python_abi=3.8=3_cp38 161 | - pytorch=1.13.1=cuda112py38hd94e077_200 162 | - pytorch-cuda=11.7=h67b0de4_1 163 | - pytorch-mutex=1.0=cuda 164 | - readline=8.1.2=h0f457ee_0 165 | - requests=2.28.2=pyhd8ed1ab_0 166 | - setuptools=67.4.0=pyhd8ed1ab_0 167 | - six=1.16.0=pyh6c4a22f_0 168 | - sleef=3.5.1=h9b69904_2 169 | - sqlite=3.40.0=h4ff8645_0 170 | - svt-av1=1.4.1=hcb278e6_0 171 | - tbb=2021.8.0=hf52228f_0 172 | - tk=8.6.12=h27826a3_0 173 | - torchaudio=0.13.1=py38_cu117 174 | - torchvision=0.14.1=cuda112py38hebebe89_0 175 | - tornado=6.3=py38h1de0b5d_0 176 | - typing_extensions=4.4.0=pyha770c72_0 177 | - wheel=0.38.4=pyhd8ed1ab_0 178 | - x264=1!164.3095=h166bdaf_2 179 | - x265=3.5=h924138e_3 180 | - xorg-fixesproto=5.0=h7f98852_1002 181 | - xorg-kbproto=1.0.7=h7f98852_1002 182 | - xorg-libx11=1.6.12=h36c2ea0_0 183 | - xorg-libxau=1.0.9=h7f98852_0 184 | - xorg-libxdmcp=1.1.3=h7f98852_0 185 | - xorg-libxext=1.3.4=h516909a_0 186 | - xorg-libxfixes=5.0.3=h516909a_1004 187 | - xorg-xextproto=7.3.0=h0b41bf4_1003 188 | - xorg-xproto=7.0.31=h7f98852_1007 189 | - xz=5.2.6=h166bdaf_0 190 | - zlib=1.2.13=h166bdaf_4 191 | - zstd=1.5.2=h3eb15da_6 192 | - pip: 193 | - absl-py==1.3.0 194 | - asttokens==2.2.1 195 | - backcall==0.2.0 196 | - beautifulsoup4==4.11.1 197 | - cachetools==5.2.0 198 | - chex==0.1.5 199 | - click==8.1.3 200 | - cloudpickle==2.2.0 201 | - colorama==0.4.6 202 | - commonmark==0.9.1 203 | - contextlib2==21.6.0 204 | - contourpy==1.0.6 205 | - cython==0.29.32 206 | - d4rl==1.1 207 | - decorator==5.1.1 208 | - dm-control==1.0.8 209 | - dm-env==1.5 210 | - dm-tree==0.1.7 211 | - docker-pycreds==0.4.0 212 | - etils==0.9.0 213 | - executing==1.2.0 214 | - fasteners==0.18 215 | - filelock==3.8.0 216 | - flax==0.5.3 217 | - fonttools==4.38.0 218 | - gast==0.5.3 219 | - gdown==4.5.3 220 | - gitdb==4.0.9 221 | - gitpython==3.1.29 222 | - glfw==2.5.5 223 | - google-auth==2.14.1 224 | - google-auth-oauthlib==0.4.6 225 | - grpcio==1.50.0 226 | - gym==0.23.1 227 | - gym-notices==0.0.8 228 | - h5py==3.8.0 229 | - imageio==2.22.4 230 | - imageio-ffmpeg==0.4.7 231 | - importlib-metadata==5.0.0 232 | - importlib-resources==5.10.0 233 | - install==1.3.5 234 | - ipdb==0.13.13 235 | - ipython==8.11.0 236 | - jax==0.4.5 237 | - jaxlib==0.4.4+cuda11.cudnn82 238 | - jedi==0.18.2 239 | - joblib==1.2.0 240 | - labmaze==1.0.5 241 | - lockfile==0.12.2 242 | - lxml==4.9.1 243 | - markdown==3.4.1 244 | - markupsafe==2.1.1 245 | - matplotlib-inline==0.1.6 246 | - mjrl==1.0.0 247 | - ml-collections==0.1.1 248 | - msgpack==1.0.4 249 | - mujoco==2.3.2 250 | - mujoco-py==2.1.2.14 251 | - numpy==1.23.4 252 | - nvidia-cublas-cu11==11.10.3.66 253 | - nvidia-cuda-nvrtc-cu11==11.7.99 254 | - nvidia-cuda-runtime-cu11==11.7.99 255 | - nvidia-cudnn-cu11==8.5.0.96 256 | - oauthlib==3.2.2 257 | - opencv-python==4.7.0.72 258 | - opt-einsum==3.3.0 259 | - optax==0.1.3 260 | - packaging==21.3 261 | - pandas==1.3.2 262 | - parso==0.8.3 263 | - patchelf==0.17.2.1 264 | - pathtools==0.1.2 265 | - pexpect==4.8.0 266 | - pickleshare==0.7.5 267 | - pillow==9.3.0 268 | - pip==23.0.1 269 | - 
plotly==5.14.1 270 | - promise==2.3 271 | - prompt-toolkit==3.0.38 272 | - protobuf==3.19.6 273 | - psutil==5.9.4 274 | - ptyprocess==0.7.0 275 | - pure-eval==0.2.2 276 | - pyasn1==0.4.8 277 | - pyasn1-modules==0.2.8 278 | - pybullet==3.2.5 279 | - pygments==2.13.0 280 | - pyopengl==3.1.6 281 | - pyparsing==2.4.7 282 | - pytz==2023.3 283 | - pyyaml==6.0 284 | - requests-oauthlib==1.3.1 285 | - rich==11.2.0 286 | - rsa==4.9 287 | - scipy==1.9.3 288 | - sentry-sdk==1.10.1 289 | - setproctitle==1.3.2 290 | - shortuuid==1.0.9 291 | - sk-video==1.1.10 292 | - smmap==5.0.0 293 | - soupsieve==2.3.2.post1 294 | - stack-data==0.6.2 295 | - tenacity==8.2.2 296 | - tensorboard==2.10.1 297 | - tensorboard-data-server==0.6.1 298 | - tensorboard-plugin-wit==1.8.1 299 | - tensorboardx==2.5.1 300 | - tensorflow-probability==0.18.0 301 | - tensorstore==0.1.28 302 | - termcolor==2.1.0 303 | - tomli==2.0.1 304 | - toolz==0.12.0 305 | - tqdm==4.64.1 306 | - traitlets==5.9.0 307 | - urllib3==1.26.12 308 | - wandb==0.13.5 309 | - wcwidth==0.2.6 310 | - werkzeug==2.2.2 311 | - zipp==3.10.0 312 | -------------------------------------------------------------------------------- /offline/logging_utils/logx.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Some simple logging functionality, inspired by rllab's logging. 4 | 5 | Logs to a tab-separated-values file (path/to/output_directory/progress.txt) 6 | 7 | """ 8 | import json 9 | import joblib 10 | import shutil 11 | import numpy as np 12 | import os.path as osp, time, atexit, os 13 | import warnings 14 | from logging_utils.serialization_utils import convert_json 15 | 16 | color2num = dict( 17 | gray=30, 18 | red=31, 19 | green=32, 20 | yellow=33, 21 | blue=34, 22 | magenta=35, 23 | cyan=36, 24 | white=37, 25 | crimson=38 26 | ) 27 | 28 | def colorize(string, color, bold=False, highlight=False): 29 | """ 30 | Colorize a string. 31 | 32 | This function was originally written by John Schulman. 33 | """ 34 | attr = [] 35 | num = color2num[color] 36 | if highlight: num += 10 37 | attr.append(str(num)) 38 | if bold: attr.append('1') 39 | return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string) 40 | 41 | 42 | class Logger: 43 | """ 44 | A general-purpose logger. 45 | 46 | Makes it easy to save diagnostics, hyperparameter configurations, the 47 | state of a training run, and the trained model. 48 | """ 49 | 50 | def __init__(self, output_dir=None, output_fname='progress.txt', exp_name=None): 51 | """ 52 | Initialize a Logger. 53 | 54 | Args: 55 | output_dir (string): A directory for saving results to. If 56 | ``None``, defaults to a temp directory of the form 57 | ``/tmp/experiments/somerandomnumber``. 58 | 59 | output_fname (string): Name for the tab-separated-value file 60 | containing metrics logged throughout a training run. 61 | Defaults to ``progress.txt``. 62 | 63 | exp_name (string): Experiment name. If you run multiple training 64 | runs and give them all the same ``exp_name``, the plotter 65 | will know to group them. (Use case: if you run the same 66 | hyperparameter configuration with multiple random seeds, you 67 | should give them all the same ``exp_name``.) 68 | """ 69 | # if proc_id()==0: 70 | self.output_dir = output_dir or "/tmp/experiments/%i"%int(time.time()) 71 | if osp.exists(self.output_dir): 72 | print("Warning: Log dir %s already exists! 
Storing info there anyway."%self.output_dir) 73 | else: 74 | os.makedirs(self.output_dir) 75 | self.output_file = open(osp.join(self.output_dir, output_fname), 'w') 76 | atexit.register(self.output_file.close) 77 | print(colorize("Logging data to %s"%self.output_file.name, 'green', bold=True)) 78 | 79 | self.first_row=True 80 | self.log_headers = [] 81 | self.log_current_row = {} 82 | self.exp_name = exp_name 83 | 84 | def log(self, msg, color='green'): 85 | """Print a colorized message to stdout.""" 86 | # if proc_id()==0: 87 | print(colorize(msg, color, bold=True)) 88 | 89 | def log_tabular(self, key, val): 90 | """ 91 | Log a value of some diagnostic. 92 | 93 | Call this only once for each diagnostic quantity, each iteration. 94 | After using ``log_tabular`` to store values for each diagnostic, 95 | make sure to call ``dump_tabular`` to write them out to file and 96 | stdout (otherwise they will not get saved anywhere). 97 | """ 98 | if self.first_row: 99 | self.log_headers.append(key) 100 | else: 101 | assert key in self.log_headers, "Trying to introduce a new key %s that you didn't include in the first iteration"%key 102 | assert key not in self.log_current_row, "You already set %s this iteration. Maybe you forgot to call dump_tabular()"%key 103 | self.log_current_row[key] = val 104 | 105 | def save_config(self, config): 106 | """ 107 | Log an experiment configuration. 108 | 109 | Call this once at the top of your experiment, passing in all important 110 | config vars as a dict. This will serialize the config to JSON, while 111 | handling anything which can't be serialized in a graceful way (writing 112 | as informative a string as possible). 113 | 114 | Example use: 115 | 116 | .. code-block:: python 117 | 118 | logger = EpochLogger(**logger_kwargs) 119 | logger.save_config(locals()) 120 | """ 121 | config_json = convert_json(config) 122 | if self.exp_name is not None: 123 | config_json['exp_name'] = self.exp_name 124 | # if proc_id()==0: 125 | output = json.dumps(config_json, separators=(',',':\t'), indent=4, sort_keys=True) 126 | print(colorize('Saving config:\n', color='cyan', bold=True)) 127 | print(output) 128 | with open(osp.join(self.output_dir, "config.json"), 'w') as out: 129 | out.write(output) 130 | 131 | def setup_pytorch_saver(self, what_to_save): 132 | """ 133 | Set up easy model saving for a single PyTorch model. 134 | 135 | Because PyTorch saving and loading is especially painless, this is 136 | very minimal; we just need references to whatever we would like to 137 | pickle. This is integrated into the logger because the logger 138 | knows where the user would like to save information about this 139 | training run. 140 | 141 | Args: 142 | what_to_save: Any PyTorch model or serializable object containing 143 | PyTorch models. 144 | """ 145 | self.pytorch_saver_elements = what_to_save 146 | 147 | def _pytorch_simple_save(self, itr=None): 148 | """ 149 | Saves the PyTorch model (or models). 150 | """ 151 | # if proc_id()==0: 152 | assert hasattr(self, 'pytorch_saver_elements'), \ 153 | "First have to setup saving with self.setup_pytorch_saver" 154 | fpath = 'pyt_save' 155 | fpath = osp.join(self.output_dir, fpath) 156 | fname = 'model' + ('%d'%itr if itr is not None else '') + '.pt' 157 | fname = osp.join(fpath, fname) 158 | os.makedirs(fpath, exist_ok=True) 159 | 160 | def dump_tabular(self): 161 | """ 162 | Write all of the diagnostics from the current iteration. 163 | 164 | Writes both to stdout, and to the output file. 
165 | """ 166 | # if proc_id()==0: 167 | vals = [] 168 | key_lens = [len(key) for key in self.log_headers] 169 | max_key_len = max(15,max(key_lens)) 170 | keystr = '%'+'%d'%max_key_len 171 | fmt = "| " + keystr + "s | %15s |" 172 | n_slashes = 22 + max_key_len 173 | print("-"*n_slashes) 174 | for key in self.log_headers: 175 | val = self.log_current_row.get(key, "") 176 | valstr = "%8.3g"%val if hasattr(val, "__float__") else val 177 | print(fmt%(key, valstr)) 178 | vals.append(val) 179 | print("-"*n_slashes, flush=True) 180 | if self.output_file is not None: 181 | if self.first_row: 182 | self.output_file.write("\t".join(self.log_headers)+"\n") 183 | self.output_file.write("\t".join(map(str,vals))+"\n") 184 | self.output_file.flush() 185 | self.log_current_row.clear() 186 | self.first_row=False 187 | 188 | class EpochLogger(Logger): 189 | """ 190 | A variant of Logger tailored for tracking average values over epochs. 191 | 192 | Typical use case: there is some quantity which is calculated many times 193 | throughout an epoch, and at the end of the epoch, you would like to 194 | report the average / std / min / max value of that quantity. 195 | 196 | With an EpochLogger, each time the quantity is calculated, you would 197 | use 198 | 199 | .. code-block:: python 200 | 201 | epoch_logger.store(NameOfQuantity=quantity_value) 202 | 203 | to load it into the EpochLogger's state. Then at the end of the epoch, you 204 | would use 205 | 206 | .. code-block:: python 207 | 208 | epoch_logger.log_tabular(NameOfQuantity, **options) 209 | 210 | to record the desired values. 211 | """ 212 | 213 | def __init__(self, *args, **kwargs): 214 | super().__init__(*args, **kwargs) 215 | self.epoch_dict = dict() 216 | 217 | def store(self, **kwargs): 218 | """ 219 | Save something into the epoch_logger's current state. 220 | 221 | Provide an arbitrary number of keyword arguments with numerical 222 | values. 223 | """ 224 | for k,v in kwargs.items(): 225 | if not(k in self.epoch_dict.keys()): 226 | self.epoch_dict[k] = [] 227 | self.epoch_dict[k].append(v) 228 | 229 | def log_tabular(self, key, val=None, with_min_and_max=False, average_only=False): 230 | """ 231 | Log a value or possibly the mean/std/min/max values of a diagnostic. 232 | 233 | Args: 234 | key (string): The name of the diagnostic. If you are logging a 235 | diagnostic whose state has previously been saved with 236 | ``store``, the key here has to match the key you used there. 237 | 238 | val: A value for the diagnostic. If you have previously saved 239 | values for this key via ``store``, do *not* provide a ``val`` 240 | here. 241 | 242 | with_min_and_max (bool): If true, log min and max values of the 243 | diagnostic over the epoch. 244 | 245 | average_only (bool): If true, do not log the standard deviation 246 | of the diagnostic over the epoch. 
247 | """ 248 | if val is not None: 249 | super().log_tabular(key,val) 250 | else: 251 | v = self.epoch_dict[key] 252 | vals = np.concatenate(v) if isinstance(v[0], np.ndarray) and len(v[0].shape)>0 else v 253 | # stats = mpi_statistics_scalar(vals, with_min_and_max=with_min_and_max) 254 | stats = self.stats_scalar(vals, with_min_and_max=with_min_and_max) 255 | super().log_tabular(key if average_only else 'Average' + key, stats[0]) 256 | if not(average_only): 257 | super().log_tabular('Std'+key, stats[1]) 258 | if with_min_and_max: 259 | super().log_tabular('Max'+key, stats[3]) 260 | super().log_tabular('Min'+key, stats[2]) 261 | self.epoch_dict[key] = [] 262 | 263 | def get_stats(self, key): 264 | """ 265 | Lets an algorithm ask the logger for the mean/std of a diagnostic. 266 | """ 267 | v = self.epoch_dict[key] 268 | vals = np.concatenate(v) if isinstance(v[0], np.ndarray) and len(v[0].shape)>0 else v 269 | return np.mean(vals), np.std(vals) 270 | # return mpi_statistics_scalar(vals) 271 | 272 | def stats_scalar(self, vals, with_min_and_max=False): 273 | if with_min_and_max: 274 | return np.mean(vals), np.std(vals), np.min(vals), np.max(vals) 275 | else: 276 | return np.mean(vals), np.std(vals) --------------------------------------------------------------------------------
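Usage note: the sketch below shows, in minimal form, how the EpochLogger defined in logging_utils/logx.py is typically driven; train_offline.py follows the same log_tabular / dump_tabular pattern. It assumes the offline/ directory is on the import path, and the output directory, experiment name, and metric values are placeholders rather than anything from this repository.

# Hypothetical usage sketch for logging_utils.logx.EpochLogger; paths and values are made up.
import numpy as np
from logging_utils.logx import EpochLogger

logger = EpochLogger(output_dir='/tmp/experiments/demo', exp_name='demo_run')
logger.save_config({'seed': 0, 'env_name': 'halfcheetah-medium-v2'})  # serialized to config.json

for epoch in range(3):
    # store() accumulates per-step values; log_tabular() without `val` reports their statistics.
    for _ in range(10):
        logger.store(Return=float(np.random.randn()))
    logger.log_tabular('Iterations', epoch)
    logger.log_tabular('Return', with_min_and_max=True)
    logger.dump_tabular()  # prints a table and appends one row to progress.txt

Each dump_tabular() call writes the header row once and then one whitespace-separated row per call, which is the format the progress.txt check at the top of train_offline.py reads back with pandas.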