├── jaxrl ├── __init__.py ├── agents │ ├── bc │ │ ├── __init__.py │ │ ├── actor.py │ │ └── bc_learner.py │ ├── drq │ │ ├── __init__.py │ │ ├── augmentations.py │ │ ├── networks.py │ │ └── drq_learner.py │ ├── sac │ │ ├── __init__.py │ │ ├── actor.py │ │ ├── temperature.py │ │ └── critic.py │ ├── awac │ │ ├── __init__.py │ │ ├── value.py │ │ ├── actor.py │ │ └── awac_learner.py │ ├── ddpg │ │ ├── __init__.py │ │ ├── actor.py │ │ ├── critic.py │ │ └── ddpg_learner.py │ ├── redq │ │ ├── __init__.py │ │ ├── actor.py │ │ ├── critic.py │ │ └── redq_learner.py │ ├── sac_v1 │ │ ├── __init__.py │ │ ├── critic.py │ │ └── sac_v1_learner.py │ └── __init__.py ├── networks │ ├── __init__.py │ ├── critic_net.py │ ├── lion_optax.py │ ├── autoregressive_policy.py │ └── policies.py ├── datasets │ ├── rl_unplugged │ │ ├── __init__.py │ │ ├── README.md │ │ └── preprocess.py │ ├── __init__.py │ ├── dataset_utils.py │ ├── rl_unplugged_dataset.py │ ├── d4rl_dataset.py │ ├── replay_buffer.py │ ├── awac_dataset.py │ └── dataset.py ├── wrappers │ ├── common.py │ ├── __init__.py │ ├── sticky_actions.py │ ├── take_key.py │ ├── rgb2gray.py │ ├── repeat_action.py │ ├── single_precision.py │ ├── episode_monitor.py │ ├── absorbing_states.py │ ├── frame_stack.py │ ├── dmc_env.py │ └── normalization.py ├── utils.py ├── evaluation.py └── dict_learning │ ├── plot_utils.py │ └── task_dict.py ├── configs ├── hard_update.py ├── bc_default.py ├── sac_hat.py ├── ddpg_default.py ├── sac_cw_task_descriptor.py ├── sac_v1_default.py ├── redq_default.py ├── sac_default.py ├── drq_default.py ├── drq_faster.py ├── awac_default.py ├── sac_tadell.yaml ├── sac_cotasp.yaml ├── sac_cw.py └── sac_cotasp.py ├── README.md ├── .gitignore ├── continual_world.py ├── train_tadell.py └── train_cotasp.py /jaxrl/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /jaxrl/agents/bc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /jaxrl/agents/drq/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /jaxrl/agents/sac/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /jaxrl/networks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /jaxrl/agents/awac/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /jaxrl/agents/ddpg/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /jaxrl/agents/redq/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /jaxrl/agents/sac_v1/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/jaxrl/datasets/rl_unplugged/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /jaxrl/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from jaxrl.agents.sac.sac_learner import SACLearner 2 | -------------------------------------------------------------------------------- /jaxrl/wrappers/common.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | 5 | TimeStep = Tuple[np.ndarray, float, bool, dict] 6 | -------------------------------------------------------------------------------- /jaxrl/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from jaxrl.datasets.dataset import Batch 2 | from jaxrl.datasets.dataset_utils import make_env_and_dataset 3 | from jaxrl.datasets.replay_buffer import ReplayBuffer 4 | -------------------------------------------------------------------------------- /configs/hard_update.py: -------------------------------------------------------------------------------- 1 | from configs import sac_default as default_lib 2 | 3 | 4 | def get_config(): 5 | config = default_lib.get_config() 6 | 7 | config.tau = 1.0 8 | config.target_update_period = 50 9 | 10 | return config 11 | -------------------------------------------------------------------------------- /configs/bc_default.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def get_config(): 5 | config = ml_collections.ConfigDict() 6 | 7 | config.distribution = 'det' 8 | # det (deterministic) or mog (mixture of gaussians) or made_mog or made_d (discretized) 9 | 10 | config.actor_lr = 1e-3 11 | config.hidden_dims = (256, 256) 12 | 13 | return config 14 | -------------------------------------------------------------------------------- /configs/sac_hat.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def get_config(): 5 | config = ml_collections.ConfigDict() 6 | 7 | config.actor_lr = 1e-3 8 | config.critic_lr = 1e-3 9 | config.temp_lr = 1e-3 10 | 11 | config.hidden_dims = (256, 256, 256, 256) 12 | config.name_activation = 'leaky_relu' 13 | config.use_layer_norm = True 14 | 15 | config.init_temperature = 0.02 16 | config.backup_entropy = True 17 | 18 | return config 19 | -------------------------------------------------------------------------------- /configs/ddpg_default.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def get_config(): 5 | config = ml_collections.ConfigDict() 6 | 7 | config.algo = 'ddpg' 8 | 9 | config.actor_lr = 3e-4 10 | config.critic_lr = 3e-4 11 | 12 | config.hidden_dims = (256, 256) 13 | 14 | config.discount = 0.99 15 | 16 | config.tau = 0.005 17 | config.target_update_period = 1 18 | 19 | config.replay_buffer_size = None 20 | 21 | config.exploration_noise = 0.1 22 | 23 | return config 24 | -------------------------------------------------------------------------------- /configs/sac_cw_task_descriptor.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def get_config(): 5 | config = ml_collections.ConfigDict() 6 | 7 | config.coder_lr = 1e-3 8 | config.actor_lr = 1e-3 9 | config.critic_lr = 
1e-3 10 | config.temp_lr = 1e-3 11 | 12 | config.hidden_dims = (256, 256, 256, 256) 13 | config.name_activation = 'leaky_relu' 14 | config.use_layer_norm = True 15 | 16 | config.init_temperature = 0.1 17 | config.backup_entropy = True 18 | 19 | return config 20 | -------------------------------------------------------------------------------- /jaxrl/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from jaxrl.wrappers.absorbing_states import AbsorbingStatesWrapper 2 | from jaxrl.wrappers.dmc_env import DMCEnv 3 | from jaxrl.wrappers.episode_monitor import EpisodeMonitor 4 | from jaxrl.wrappers.frame_stack import FrameStack 5 | from jaxrl.wrappers.repeat_action import RepeatAction 6 | from jaxrl.wrappers.rgb2gray import RGB2Gray 7 | from jaxrl.wrappers.single_precision import SinglePrecision 8 | from jaxrl.wrappers.sticky_actions import StickyActionEnv 9 | from jaxrl.wrappers.take_key import TakeKey 10 | -------------------------------------------------------------------------------- /configs/sac_v1_default.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def get_config(): 5 | config = ml_collections.ConfigDict() 6 | 7 | config.algo = 'sac_v1' 8 | 9 | config.actor_lr = 3e-4 10 | config.value_lr = 3e-4 11 | config.critic_lr = 3e-4 12 | config.temp_lr = 3e-4 13 | 14 | config.hidden_dims = (256, 256) 15 | 16 | config.discount = 0.99 17 | 18 | config.tau = 0.005 19 | config.target_update_period = 1 20 | 21 | config.init_temperature = 1.0 22 | config.target_entropy = None 23 | 24 | config.replay_buffer_size = None 25 | 26 | return config 27 | -------------------------------------------------------------------------------- /jaxrl/datasets/rl_unplugged/README.md: -------------------------------------------------------------------------------- 1 | Run to preprocess the datasets: 2 | 3 | ```bash 4 | python preprocess.py --task_name cartpole_swingup 5 | python preprocess.py --task_name cheetah_run 6 | python preprocess.py --task_name finger_turn_hard 7 | python preprocess.py --task_name fish_swim 8 | python preprocess.py --task_name humanoid_run 9 | python preprocess.py --task_name manipulator_insert_ball 10 | python preprocess.py --task_name manipulator_insert_peg 11 | python preprocess.py --task_name walker_stand 12 | python preprocess.py --task_name walker_walk 13 | ``` 14 | 15 | The datasets will be saved in the default d4rl folder. 
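After preprocessing, a task can be loaded through `RLUnpluggedDataset` (see `jaxrl/datasets/rl_unplugged_dataset.py`). A minimal loading sketch (illustrative only; it assumes the `.npz` files above were written successfully, that `d4rl` is importable so the default dataset path resolves, and that the `Dataset` base class exposes the arrays as attributes, as `ReplayBuffer` does):

```python
from jaxrl.datasets.rl_unplugged_dataset import RLUnpluggedDataset

# Reads <default d4rl path>/rl_unplugged/cartpole_swingup.npz written by preprocess.py.
dataset = RLUnpluggedDataset('cartpole_swingup')
print(dataset.observations.shape, dataset.actions.shape)
```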
-------------------------------------------------------------------------------- /jaxrl/wrappers/sticky_actions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Take from 3 | https://github.com/openai/atari-reset/blob/master/atari_reset/wrappers.py 4 | """ 5 | 6 | import gym 7 | import numpy as np 8 | 9 | 10 | class StickyActionEnv(gym.Wrapper): 11 | 12 | def __init__(self, env, p=0.25): 13 | super().__init__(env) 14 | self.p = p 15 | self.last_action = 0 16 | 17 | def step(self, action): 18 | if np.random.uniform() < self.p: 19 | action = self.last_action 20 | self.last_action = action 21 | obs, reward, done, info = self.env.step(action) 22 | return obs, reward, done, info 23 | -------------------------------------------------------------------------------- /jaxrl/wrappers/take_key.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import gym 4 | 5 | 6 | class TakeKey(gym.ObservationWrapper): 7 | 8 | def __init__(self, env, take_key): 9 | super(TakeKey, self).__init__(env) 10 | self._take_key = take_key 11 | 12 | assert take_key in self.observation_space.spaces 13 | self.observation_space = self.env.observation_space[take_key] 14 | 15 | def observation(self, observation): 16 | observation = copy.copy(observation) 17 | taken_observation = observation.pop(self._take_key) 18 | self._ignored_observations = observation 19 | return taken_observation 20 | -------------------------------------------------------------------------------- /jaxrl/agents/drq/augmentations.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | 5 | def random_crop(key, img, padding): 6 | crop_from = jax.random.randint(key, (2, ), 0, 2 * padding + 1) 7 | crop_from = jnp.concatenate([crop_from, jnp.zeros((1, ), dtype=jnp.int32)]) 8 | padded_img = jnp.pad(img, ((padding, padding), (padding, padding), (0, 0)), 9 | mode='edge') 10 | return jax.lax.dynamic_slice(padded_img, crop_from, img.shape) 11 | 12 | 13 | def batched_random_crop(key, imgs, padding=4): 14 | keys = jax.random.split(key, imgs.shape[0]) 15 | return jax.vmap(random_crop, (0, 0, None))(keys, imgs, padding) 16 | -------------------------------------------------------------------------------- /jaxrl/wrappers/rgb2gray.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | 5 | class RGB2Gray(gym.ObservationWrapper): 6 | 7 | def __init__(self, env): 8 | super().__init__(env) 9 | 10 | obs_shape = env.observation_space.shape 11 | self.observation_space = gym.spaces.Box(low=0, 12 | high=255, 13 | shape=(*obs_shape[:2], 1), 14 | dtype=np.uint8) 15 | 16 | def observation(self, observation): 17 | observation = np.dot(observation, [[0.299], [0.587], [0.114]]) 18 | return observation.astype(np.uint8) 19 | -------------------------------------------------------------------------------- /configs/redq_default.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def get_config(): 5 | # https://arxiv.org/abs/2101.05982 6 | config = ml_collections.ConfigDict() 7 | 8 | config.algo = 'redq' 9 | 10 | config.actor_lr = 3e-4 11 | config.critic_lr = 3e-4 12 | config.temp_lr = 3e-4 13 | 14 | config.n = 10 15 | config.m = 2 16 | 17 | config.hidden_dims = (256, 256) 18 | 19 | config.discount = 0.99 20 | 21 | config.tau = 0.005 22 | config.target_update_period = 
1 23 | 24 | config.init_temperature = 1.0 25 | config.target_entropy = None 26 | config.backup_entropy = True 27 | 28 | config.replay_buffer_size = None 29 | 30 | return config 31 | -------------------------------------------------------------------------------- /configs/sac_default.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def get_config(): 5 | config = ml_collections.ConfigDict() 6 | 7 | config.algo = 'sac' 8 | 9 | config.actor_lr = 1e-3 10 | config.critic_lr = 1e-3 11 | config.temp_lr = 1e-3 12 | 13 | config.hidden_dims = (256, 256, 256, 256) 14 | config.name_activation = 'leaky_relu' 15 | config.use_layer_norm = True 16 | 17 | config.discount = 0.99 18 | 19 | config.tau = 5e-3 20 | config.target_update_period = 1 21 | 22 | config.init_temperature = 1.0 23 | config.target_entropy = None 24 | config.backup_entropy = True 25 | 26 | config.replay_buffer_size = int(1e6) 27 | 28 | return config 29 | -------------------------------------------------------------------------------- /jaxrl/agents/ddpg/actor.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import jax.numpy as jnp 4 | 5 | from jaxrl.datasets import Batch 6 | from jaxrl.networks.common import InfoDict, Model, Params 7 | 8 | 9 | def update(actor: Model, critic: Model, 10 | batch: Batch) -> Tuple[Model, InfoDict]: 11 | 12 | def actor_loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]: 13 | actions = actor.apply_fn({'params': actor_params}, batch.observations) 14 | q1, q2 = critic(batch.observations, actions) 15 | q = jnp.minimum(q1, q2) 16 | actor_loss = -q.mean() 17 | return actor_loss, {'actor_loss': actor_loss} 18 | 19 | new_actor, info = actor.apply_gradient(actor_loss_fn) 20 | 21 | return new_actor, info 22 | -------------------------------------------------------------------------------- /jaxrl/wrappers/repeat_action.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | from jaxrl.wrappers.common import TimeStep 5 | 6 | 7 | class RepeatAction(gym.Wrapper): 8 | 9 | def __init__(self, env, action_repeat=4): 10 | super().__init__(env) 11 | self._action_repeat = action_repeat 12 | 13 | def step(self, action: np.ndarray) -> TimeStep: 14 | total_reward = 0.0 15 | done = None 16 | combined_info = {} 17 | 18 | for _ in range(self._action_repeat): 19 | obs, reward, done, info = self.env.step(action) 20 | total_reward += reward 21 | combined_info.update(info) 22 | if done: 23 | break 24 | 25 | return obs, total_reward, done, combined_info 26 | -------------------------------------------------------------------------------- /configs/drq_default.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def get_config(): 5 | config = ml_collections.ConfigDict() 6 | 7 | config.algo = 'drq' 8 | 9 | config.actor_lr = 3e-4 10 | config.critic_lr = 3e-4 11 | config.temp_lr = 3e-4 12 | 13 | config.hidden_dims = (256, 256) 14 | 15 | config.cnn_features = (32, 32, 32, 32) 16 | config.cnn_strides = (2, 1, 1, 1) 17 | config.cnn_padding = 'VALID' 18 | config.latent_dim = 50 19 | 20 | config.discount = 0.99 21 | 22 | config.tau = 0.005 23 | config.target_update_period = 1 24 | 25 | config.init_temperature = 0.1 26 | config.target_entropy = None 27 | 28 | config.replay_buffer_size = 100_000 29 | 30 | config.gray_scale = False 31 | config.image_size = 84 
32 | 33 | return config 34 | -------------------------------------------------------------------------------- /configs/drq_faster.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def get_config(): 5 | config = ml_collections.ConfigDict() 6 | 7 | config.algo = 'drq' 8 | 9 | config.actor_lr = 3e-4 10 | config.critic_lr = 3e-4 11 | config.temp_lr = 3e-4 12 | 13 | config.hidden_dims = (256, 256) 14 | 15 | config.cnn_features = (32, 64, 128, 256) 16 | config.cnn_strides = (2, 2, 2, 2) 17 | config.cnn_padding = 'SAME' 18 | config.latent_dim = 50 19 | 20 | config.discount = 0.99 21 | 22 | config.tau = 0.005 23 | config.target_update_period = 1 24 | 25 | config.init_temperature = 0.1 26 | config.target_entropy = None 27 | 28 | config.replay_buffer_size = 100_000 29 | 30 | config.gray_scale = True 31 | config.image_size = 64 32 | 33 | return config 34 | -------------------------------------------------------------------------------- /jaxrl/agents/awac/value.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from jaxrl.datasets import Batch 7 | from jaxrl.networks.common import Model, PRNGKey 8 | 9 | 10 | def get_value(key: PRNGKey, actor: Model, critic: Model, batch: Batch, 11 | num_samples: int) -> Tuple[jnp.ndarray, jnp.ndarray]: 12 | dist = actor(batch.observations) 13 | 14 | policy_actions = dist.sample(seed=key, sample_shape=[num_samples]) 15 | 16 | n_observations = jnp.repeat(batch.observations[jnp.newaxis], 17 | num_samples, 18 | axis=0) 19 | q_pi1, q_pi2 = critic(n_observations, policy_actions) 20 | 21 | def get_v(q): 22 | return jnp.mean(q, axis=0) 23 | 24 | return get_v(q_pi1), get_v(q_pi2) 25 | -------------------------------------------------------------------------------- /configs/awac_default.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | # Hyper parameters from the official implementation. 
5 | def get_config(): 6 | config = ml_collections.ConfigDict() 7 | 8 | config.algo = 'awac' 9 | 10 | config.actor_optim_kwargs = ml_collections.ConfigDict() 11 | config.actor_optim_kwargs.learning_rate = 3e-4 12 | config.actor_optim_kwargs.weight_decay = 1e-4 13 | config.actor_hidden_dims = (256, 256, 256, 256) 14 | config.state_dependent_std = False 15 | 16 | config.critic_lr = 3e-4 17 | config.critic_hidden_dims = (256, 256) 18 | config.discount = 0.99 19 | 20 | config.tau = 0.005 21 | config.target_update_period = 1 22 | 23 | config.beta = 2.0 24 | 25 | config.num_samples = 1 26 | 27 | config.replay_buffer_size = None 28 | 29 | return config 30 | -------------------------------------------------------------------------------- /configs/sac_tadell.yaml: -------------------------------------------------------------------------------- 1 | actor_configs: 2 | clip_mean: 2.0 3 | final_fc_init_scale: 0.001 4 | hidden_dims: !!python/tuple 5 | - 256 6 | - 256 7 | - 256 8 | name_activation: leaky_relu 9 | state_dependent_std: true 10 | use_layer_norm: true 11 | critic_configs: 12 | hidden_dims: !!python/tuple 13 | - 256 14 | - 256 15 | - 256 16 | name_activation: leaky_relu 17 | use_layer_norm: true 18 | dict_configs: 19 | alpha: 0.001 20 | c: 1.0 21 | method: lasso_lars 22 | positive_code: false 23 | scale_code: false 24 | init_temperature: 1.0 25 | pi_opt_configs: 26 | clip_method: global_clip 27 | lr: 0.0003 28 | max_norm: 1.0 29 | optim_algo: adam 30 | q_opt_configs: 31 | clip_method: global_clip 32 | lr: 0.0003 33 | max_norm: 1.0 34 | optim_algo: adam 35 | t_opt_configs: 36 | clip_method: none 37 | lr: 0.0003 38 | max_norm: -1 39 | optim_algo: adam 40 | target_entropy: -2.0 41 | tau: 0.005 42 | -------------------------------------------------------------------------------- /jaxrl/datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import gym 4 | 5 | from jaxrl.datasets.awac_dataset import AWACDataset 6 | from jaxrl.datasets.d4rl_dataset import D4RLDataset 7 | from jaxrl.datasets.dataset import Dataset 8 | from jaxrl.datasets.rl_unplugged_dataset import RLUnpluggedDataset 9 | from jaxrl.utils import make_env 10 | 11 | 12 | def make_env_and_dataset(env_name: str, seed: int, dataset_name: str, 13 | video_save_folder: str) -> Tuple[gym.Env, Dataset]: 14 | env = make_env(env_name, seed, video_save_folder) 15 | 16 | if 'd4rl' in dataset_name: 17 | dataset = D4RLDataset(env) 18 | elif 'awac' in dataset_name: 19 | dataset = AWACDataset(env_name) 20 | elif 'rl_unplugged' in dataset_name: 21 | dataset = RLUnpluggedDataset(env_name.replace('-', '_')) 22 | else: 23 | raise NotImplementedError(f'{dataset_name} is not available!') 24 | 25 | return env, dataset 26 | -------------------------------------------------------------------------------- /jaxrl/agents/redq/actor.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import jax.numpy as jnp 4 | 5 | from jaxrl.datasets import Batch 6 | from jaxrl.networks.common import InfoDict, Model, Params, PRNGKey 7 | 8 | 9 | def update(key: PRNGKey, actor: Model, critic: Model, temp: Model, 10 | batch: Batch) -> Tuple[Model, InfoDict]: 11 | 12 | def actor_loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]: 13 | dist = actor.apply_fn({'params': actor_params}, batch.observations) 14 | actions = dist.sample(seed=key) 15 | log_probs = dist.log_prob(actions) 16 | qs = critic(batch.observations, 
actions) 17 | q = jnp.mean(qs, 0) 18 | actor_loss = (log_probs * temp() - q).mean() 19 | return actor_loss, { 20 | 'actor_loss': actor_loss, 21 | 'entropy': -log_probs.mean() 22 | } 23 | 24 | new_actor, info = actor.apply_gradient(actor_loss_fn) 25 | 26 | return new_actor, info 27 | -------------------------------------------------------------------------------- /jaxrl/agents/sac/actor.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from jaxrl.datasets import Batch 7 | from jaxrl.networks.common import InfoDict, TrainState, Params, PRNGKey 8 | 9 | 10 | def update(key: PRNGKey, actor: TrainState, critic: TrainState, 11 | temp: TrainState, batch: Batch) -> Tuple[TrainState, InfoDict]: 12 | 13 | def actor_loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]: 14 | dist = actor.apply_fn({'params': actor_params}, batch.observations) 15 | actions = dist.sample(seed=key) 16 | log_probs = dist.log_prob(actions) 17 | q1, q2 = critic(batch.observations, actions) 18 | q = jnp.minimum(q1, q2) 19 | actor_loss = (log_probs * temp() - q).mean() 20 | return actor_loss, { 21 | 'actor_loss': actor_loss, 22 | 'entropy': -log_probs.mean(), 23 | 'means': actions.mean() 24 | } 25 | 26 | grads_actor, info = jax.grad(actor_loss_fn, has_aux=True)(actor.params) 27 | new_actor = actor.apply_gradients(grads=grads_actor) 28 | 29 | return new_actor, info -------------------------------------------------------------------------------- /jaxrl/agents/ddpg/critic.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import jax.numpy as jnp 4 | 5 | from jaxrl.datasets import Batch 6 | from jaxrl.networks.common import InfoDict, Model, Params 7 | 8 | 9 | def update(actor: Model, critic: Model, target_critic: Model, batch: Batch, 10 | discount: float) -> Tuple[Model, InfoDict]: 11 | next_actions = actor(batch.next_observations) 12 | next_q1, next_q2 = target_critic(batch.next_observations, next_actions) 13 | next_q = jnp.minimum(next_q1, next_q2) 14 | 15 | target_q = batch.rewards + discount * batch.masks * next_q 16 | 17 | def critic_loss_fn(critic_params: Params) -> Tuple[jnp.ndarray, InfoDict]: 18 | q1, q2 = critic.apply_fn({'params': critic_params}, batch.observations, 19 | batch.actions) 20 | critic_loss = ((q1 - target_q)**2 + (q2 - target_q)**2).mean() 21 | return critic_loss, { 22 | 'critic_loss': critic_loss, 23 | 'q1': q1.mean(), 24 | 'q2': q2.mean() 25 | } 26 | 27 | new_critic, info = critic.apply_gradient(critic_loss_fn) 28 | 29 | return new_critic, info 30 | -------------------------------------------------------------------------------- /configs/sac_cotasp.yaml: -------------------------------------------------------------------------------- 1 | # networks 2 | actor_configs: 3 | clip_mean: 1.0 4 | final_fc_init_scale: 1.0e-4 5 | hidden_dims: !!python/tuple 6 | - 1024 7 | - 1024 8 | - 1024 9 | - 1024 10 | name_activation: leaky_relu 11 | state_dependent_std: true 12 | use_layer_norm: true 13 | critic_configs: 14 | hidden_dims: !!python/tuple 15 | - 256 16 | - 256 17 | - 256 18 | - 256 19 | name_activation: leaky_relu 20 | use_layer_norm: true 21 | # dictionaries 22 | update_coef: true 23 | update_dict: true 24 | dict_configs: 25 | alpha: 0.001 26 | c: 1.0 27 | method: lasso_lars 28 | positive_code: false 29 | scale_code: false 30 | # optimizers 31 | pi_opt_configs: 32 | optim_algo: adam 33 | clip_method: none 
34 | max_norm: -1 35 | opt_kargs: 36 | learning_rate: 3.0e-4 37 | q_opt_configs: 38 | optim_algo: adam 39 | clip_method: none 40 | max_norm: -1 41 | opt_kargs: 42 | learning_rate: 3.0e-4 43 | t_opt_configs: 44 | optim_algo: adam 45 | clip_method: none 46 | max_norm: -1 47 | opt_kargs: 48 | learning_rate: 3.0e-4 49 | # SAC misc 50 | init_temperature: 1.0 51 | target_update_period: 1 52 | target_entropy: -2.0 53 | tau: 0.005 54 | -------------------------------------------------------------------------------- /jaxrl/wrappers/single_precision.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import gym 4 | import numpy as np 5 | from gym.spaces import Box, Dict 6 | 7 | 8 | class SinglePrecision(gym.ObservationWrapper): 9 | 10 | def __init__(self, env): 11 | super().__init__(env) 12 | 13 | if isinstance(self.observation_space, Box): 14 | obs_space = self.observation_space 15 | self.observation_space = Box(obs_space.low, obs_space.high, 16 | obs_space.shape) 17 | elif isinstance(self.observation_space, Dict): 18 | obs_spaces = copy.copy(self.observation_space.spaces) 19 | for k, v in obs_spaces.items(): 20 | obs_spaces[k] = Box(v.low, v.high, v.shape) 21 | self.observation_space = Dict(obs_spaces) 22 | else: 23 | raise NotImplementedError 24 | 25 | def observation(self, observation: np.ndarray) -> np.ndarray: 26 | if isinstance(observation, np.ndarray): 27 | return observation.astype(np.float32) 28 | elif isinstance(observation, dict): 29 | observation = copy.copy(observation) 30 | for k, v in observation.items(): 31 | observation[k] = v.astype(np.float32) 32 | return observation 33 | -------------------------------------------------------------------------------- /jaxrl/agents/sac/temperature.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from flax import linen as nn 6 | 7 | from jaxrl.networks.common import InfoDict, TrainState 8 | 9 | 10 | LOG_TEMP_MAX = 5.0 11 | LOG_TEMP_MIN = -10.0 12 | 13 | def temp_activation(temp, temp_min=LOG_TEMP_MIN, temp_max=LOG_TEMP_MAX): 14 | return temp_min + 0.5 * (temp_max - temp_min) * (jnp.tanh(temp) + 1.) 
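# temp_activation maps an unconstrained input into the open interval (temp_min, temp_max):
# tanh(temp) lies in (-1, 1), so 0.5 * (jnp.tanh(temp) + 1.) lies in (0, 1) and the result is a
# convex combination of temp_min and temp_max.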
15 | 16 | 17 | class Temperature(nn.Module): 18 | init_log_temp: float = 1.0 19 | 20 | @nn.compact 21 | def __call__(self) -> jnp.ndarray: 22 | log_temp = self.param('log_temp', 23 | init_fn=lambda key: jnp.full( 24 | (), self.init_log_temp)) 25 | return jnp.exp(log_temp) 26 | 27 | 28 | def update(temp: TrainState, entropy: float, 29 | target_entropy: float) -> Tuple[TrainState, InfoDict]: 30 | 31 | def temperature_loss_fn(temp_params): 32 | temperature = temp.apply_fn({'params': temp_params}) 33 | temp_loss = temperature * (entropy - target_entropy).mean() 34 | return temp_loss, {'temperature': temperature, 'temp_loss': temp_loss} 35 | 36 | grads_temp, info = jax.grad(temperature_loss_fn, has_aux=True)(temp.params) 37 | new_temp = temp.apply_gradients(grads=grads_temp) 38 | 39 | return new_temp, info 40 | -------------------------------------------------------------------------------- /jaxrl/datasets/rl_unplugged_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import d4rl 4 | import numpy as np 5 | 6 | from jaxrl.datasets.dataset import Dataset 7 | 8 | 9 | class RLUnpluggedDataset(Dataset): 10 | 11 | def __init__(self, 12 | task_name: str, 13 | clip_to_eps: bool = True, 14 | eps: float = 1e-5): 15 | save_dir = os.path.join(d4rl.offline_env.DATASET_PATH, 'rl_unplugged') 16 | os.makedirs(save_dir, exist_ok=True) 17 | dataset = {} 18 | with open(os.path.join(save_dir, f'{task_name}.npz'), 'rb') as f: 19 | dataset_file = np.load(f) 20 | for k, v in dataset_file.items(): 21 | dataset[k] = v 22 | if clip_to_eps: 23 | lim = 1 - eps 24 | dataset['actions'] = np.clip(dataset['actions'], -lim, lim) 25 | 26 | super().__init__(dataset['observations'].astype(np.float32), 27 | actions=dataset['actions'].astype(np.float32), 28 | rewards=dataset['rewards'].astype(np.float32), 29 | masks=dataset['masks'].astype(np.float32), 30 | dones_float=dataset['done_floats'].astype(np.float32), 31 | next_observations=dataset['next_observations'].astype( 32 | np.float32), 33 | size=len(dataset['observations'])) 34 | -------------------------------------------------------------------------------- /jaxrl/agents/awac/actor.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from jaxrl.agents.awac.value import get_value 7 | from jaxrl.datasets import Batch 8 | from jaxrl.networks.common import InfoDict, Model, Params, PRNGKey 9 | 10 | 11 | def update(key: PRNGKey, actor: Model, critic: Model, batch: Batch, 12 | num_samples: int, beta: float) -> Tuple[Model, InfoDict]: 13 | 14 | v1, v2 = get_value(key, actor, critic, batch, num_samples) 15 | v = jnp.minimum(v1, v2) 16 | 17 | def actor_loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]: 18 | dist = actor.apply_fn({'params': actor_params}, batch.observations) 19 | lim = 1 - 1e-5 20 | actions = jnp.clip(batch.actions, -lim, lim) 21 | log_probs = dist.log_prob(actions) 22 | 23 | q1, q2 = critic(batch.observations, actions) 24 | q = jnp.minimum(q1, q2) 25 | a = q - v 26 | 27 | # we could have used exp(a / beta) here but 28 | # exp(a / beta) is unbiased but high variance, 29 | # softmax(a / beta) is biased but lower variance. 30 | # sum() instead of mean(), because it should be multiplied by batch size. 
31 | actor_loss = -(jax.nn.softmax(a / beta) * log_probs).sum() 32 | 33 | return actor_loss, {'actor_loss': actor_loss} 34 | 35 | new_actor, info = actor.apply_gradient(actor_loss_fn) 36 | 37 | return new_actor, info 38 | -------------------------------------------------------------------------------- /jaxrl/wrappers/episode_monitor.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import gym 4 | import numpy as np 5 | 6 | from jaxrl.wrappers.common import TimeStep 7 | 8 | 9 | class EpisodeMonitor(gym.ActionWrapper): 10 | """A class that computes episode returns and lengths.""" 11 | 12 | def __init__(self, env: gym.Env): 13 | super().__init__(env) 14 | self._reset_stats() 15 | self.total_timesteps = 0 16 | 17 | def _reset_stats(self): 18 | self.reward_sum = 0.0 19 | self.episode_length = 0 20 | self.start_time = time.time() 21 | 22 | def step(self, action: np.ndarray) -> TimeStep: 23 | 24 | observation, reward, done, info = self.env.step(action) 25 | 26 | self.reward_sum += reward 27 | self.episode_length += 1 28 | self.total_timesteps += 1 29 | info['total'] = {'timesteps': self.total_timesteps} 30 | 31 | if done: 32 | info['episode'] = {} 33 | info['episode']['return'] = self.reward_sum 34 | info['episode']['length'] = self.episode_length 35 | info['episode']['duration'] = time.time() - self.start_time 36 | 37 | if hasattr(self, 'get_normalized_score'): 38 | info['episode']['return'] = self.get_normalized_score( 39 | info['episode']['return']) * 100.0 40 | 41 | return observation, reward, done, info 42 | 43 | def reset(self, **kwargs) -> np.ndarray: 44 | self._reset_stats() 45 | return self.env.reset() 46 | -------------------------------------------------------------------------------- /jaxrl/agents/bc/actor.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from jaxrl.datasets import Batch 7 | from jaxrl.networks.common import InfoDict, Model, Params, PRNGKey 8 | 9 | 10 | def log_prob_update(actor: Model, batch: Batch, 11 | rng: PRNGKey) -> Tuple[Model, InfoDict]: 12 | rng, key = jax.random.split(rng) 13 | 14 | def loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]: 15 | dist = actor.apply_fn({'params': actor_params}, 16 | batch.observations, 17 | training=True, 18 | rngs={'dropout': key}) 19 | log_probs = dist.log_prob(batch.actions) 20 | actor_loss = -log_probs.mean() 21 | return actor_loss, {'actor_loss': actor_loss} 22 | 23 | return (rng, *actor.apply_gradient(loss_fn)) 24 | 25 | 26 | def mse_update(actor: Model, batch: Batch, 27 | rng: PRNGKey) -> Tuple[Model, InfoDict]: 28 | rng, key = jax.random.split(rng) 29 | 30 | def loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]: 31 | actions = actor.apply_fn({'params': actor_params}, 32 | batch.observations, 33 | training=True, 34 | rngs={'dropout': key}) 35 | actor_loss = ((actions - batch.actions)**2).mean() 36 | return actor_loss, {'actor_loss': actor_loss} 37 | 38 | return (rng, *actor.apply_gradient(loss_fn)) 39 | -------------------------------------------------------------------------------- /jaxrl/datasets/d4rl_dataset.py: -------------------------------------------------------------------------------- 1 | import d4rl 2 | import gym 3 | import numpy as np 4 | 5 | from jaxrl.datasets.dataset import Batch, Dataset 6 | 7 | 8 | class D4RLDataset(Dataset): 9 | 10 | def __init__(self, 11 | env: gym.Env, 12 | clip_to_eps: bool = True, 13 | 
eps: float = 1e-5): 14 | dataset = d4rl.qlearning_dataset(env) 15 | 16 | if clip_to_eps: 17 | lim = 1 - eps 18 | dataset['actions'] = np.clip(dataset['actions'], -lim, lim) 19 | 20 | dones_float = np.zeros_like(dataset['rewards']) 21 | 22 | for i in range(len(dones_float) - 1): 23 | if np.linalg.norm(dataset['observations'][i + 1] - 24 | dataset['next_observations'][i] 25 | ) > 1e-6 or dataset['terminals'][i] == 1.0: 26 | dones_float[i] = 1 27 | else: 28 | dones_float[i] = 0 29 | 30 | dones_float[-1] = 1 31 | 32 | super().__init__(dataset['observations'].astype(np.float32), 33 | actions=dataset['actions'].astype(np.float32), 34 | rewards=dataset['rewards'].astype(np.float32), 35 | masks=1.0 - dataset['terminals'].astype(np.float32), 36 | dones_float=dones_float.astype(np.float32), 37 | next_observations=dataset['next_observations'].astype( 38 | np.float32), 39 | size=len(dataset['observations'])) 40 | -------------------------------------------------------------------------------- /configs/sac_cw.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def get_config(): 5 | config = ml_collections.ConfigDict() 6 | 7 | config.optim_configs = ml_collections.ConfigDict() 8 | config.optim_configs.lr = 3e-4 # [3e-4, 1e-3] 9 | config.optim_configs.max_norm = 1.0 # [*1e3*, 1e4, 1e5, 1e6, 1e7] 10 | config.optim_configs.optim_algo = 'adam' # unadjustable 11 | config.optim_configs.clip_method = 'global_clip' # unadjustable 12 | 13 | config.actor_configs = ml_collections.ConfigDict() 14 | config.actor_configs.hidden_dims = (256, 256, 256) # unadjustable 15 | config.actor_configs.name_activation = 'leaky_relu' # unadjustable 16 | config.actor_configs.use_rms_norm = False # unadjustable 17 | config.actor_configs.use_layer_norm = False # unadjustable 18 | config.actor_configs.final_fc_init_scale = 1e-3 # unadjustable 19 | config.actor_configs.clip_mean = 1.0 # unadjustable 20 | config.actor_configs.state_dependent_std = True # unadjustable 21 | 22 | config.critic_configs = ml_collections.ConfigDict() 23 | config.critic_configs.hidden_dims = (256, 256, 256) # unadjustable 24 | config.critic_configs.name_activation = 'leaky_relu' # unadjustable 25 | config.critic_configs.use_layer_norm = False # unadjustable 26 | 27 | config.tau = 0.005 28 | config.init_temperature = 1.0 # unadjustable 29 | config.target_entropy = -4.0 # by default 30 | 31 | return config 32 | -------------------------------------------------------------------------------- /jaxrl/wrappers/absorbing_states.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | from gym import Wrapper 4 | 5 | 6 | def make_non_absorbing(observation): 7 | return np.concatenate([observation, [0.0]], -1) 8 | 9 | 10 | class AbsorbingStatesWrapper(Wrapper): 11 | 12 | def __init__(self, env): 13 | super().__init__(env) 14 | low = env.observation_space.low 15 | high = env.observation_space.high 16 | self._absorbing_state = np.concatenate([np.zeros_like(low), [1.0]], 0) 17 | low = np.concatenate([low, [0]], 0) 18 | high = np.concatenate([high, [1]], 0) 19 | 20 | self.observation_space = gym.spaces.Box( 21 | low=low, high=high, dtype=env.observation_space.dtype) 22 | 23 | def reset(self, **kwargs): 24 | self._done = False 25 | self._absorbing = False 26 | self._info = {} 27 | return make_non_absorbing(self.env.reset(**kwargs)) 28 | 29 | def step(self, action): 30 | if not self._done: 31 | observation, reward, done, info = 
self.env.step(action) 32 | observation = make_non_absorbing(observation) 33 | self._done = done 34 | self._info = info 35 | truncated_done = 'TimeLimit.truncated' in info 36 | return observation, reward, truncated_done, info 37 | else: 38 | if not self._absorbing: 39 | self._absorbing = True 40 | return self._absorbing_state, 0.0, False, self._info 41 | else: 42 | return self._absorbing_state, 0.0, True, self._info 43 | 44 | 45 | if __name__ == '__main__': 46 | env = gym.make('Hopper-v2') 47 | env = AbsorbingStatesWrapper(env) 48 | env.reset() 49 | 50 | done = False 51 | while not done: 52 | action = env.action_space.sample() 53 | obs, reward, done, info = env.step(action) 54 | print(obs, done) 55 | -------------------------------------------------------------------------------- /jaxrl/agents/sac/critic.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional 2 | 3 | import jax 4 | import optax 5 | import jax.numpy as jnp 6 | 7 | from jaxrl.datasets import Batch 8 | from jaxrl.networks.common import InfoDict, TrainState, Params, PRNGKey 9 | 10 | 11 | def target_update(critic: TrainState, target_critic: TrainState, tau: float) -> TrainState: 12 | # new_target_params = jax.tree_multimap( 13 | # lambda p, tp: p * tau + tp * (1 - tau), critic.params, 14 | # target_critic.params) 15 | # use optax's implementation 16 | new_target_params = optax.incremental_update( 17 | critic.params, target_critic.params, tau 18 | ) 19 | return target_critic.replace(params=new_target_params) 20 | 21 | 22 | def update(key: PRNGKey, actor: TrainState, critic: TrainState, target_critic: TrainState, 23 | temp: TrainState, batch: Batch, discount: float) -> Tuple[TrainState, InfoDict]: 24 | 25 | dist = actor(batch.next_observations) 26 | next_actions = dist.sample(seed=key) 27 | next_log_probs = dist.log_prob(next_actions) 28 | next_q1, next_q2 = target_critic(batch.next_observations, next_actions) 29 | next_q = jnp.minimum(next_q1, next_q2) 30 | next_q -= temp() * next_log_probs 31 | target_q = batch.rewards + discount * batch.masks * next_q 32 | 33 | def critic_loss_fn(critic_params: Params) -> Tuple[jnp.ndarray, InfoDict]: 34 | q1, q2 = critic.apply_fn({'params': critic_params}, batch.observations, 35 | batch.actions) 36 | critic_loss = ((q1 - target_q)**2 + (q2 - target_q)**2).mean() 37 | return critic_loss, { 38 | 'critic_loss': critic_loss, 39 | 'q1': q1.mean(), 40 | 'q2': q2.mean() 41 | } 42 | 43 | grads_critic, info = jax.grad(critic_loss_fn, has_aux=True)(critic.params) 44 | new_critic = critic.apply_gradients(grads=grads_critic) 45 | 46 | return new_critic, info 47 | -------------------------------------------------------------------------------- /jaxrl/agents/sac_v1/critic.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import jax.numpy as jnp 4 | 5 | from jaxrl.datasets import Batch 6 | from jaxrl.networks.common import InfoDict, Model, Params, PRNGKey 7 | 8 | 9 | def update_v(key: PRNGKey, actor: Model, critic: Model, value: Model, 10 | temp: Model, batch: Batch, 11 | soft_critic: bool) -> Tuple[Model, InfoDict]: 12 | dist = actor(batch.observations) 13 | actions = dist.sample(seed=key) 14 | log_probs = dist.log_prob(actions) 15 | q1, q2 = critic(batch.observations, actions) 16 | target_v = jnp.minimum(q1, q2) 17 | 18 | if soft_critic: 19 | target_v -= temp() * log_probs 20 | 21 | def value_loss_fn(value_params: Params) -> Tuple[jnp.ndarray, InfoDict]: 22 | v = 
value.apply_fn({'params': value_params}, batch.observations) 23 | value_loss = ((v - target_v)**2).mean() 24 | return value_loss, { 25 | 'value_loss': value_loss, 26 | 'v': v.mean(), 27 | } 28 | 29 | new_value, info = value.apply_gradient(value_loss_fn) 30 | 31 | return new_value, info 32 | 33 | 34 | def update_q(critic: Model, target_value: Model, batch: Batch, 35 | discount: float) -> Tuple[Model, InfoDict]: 36 | next_v = target_value(batch.next_observations) 37 | 38 | target_q = batch.rewards + discount * batch.masks * next_v 39 | 40 | def critic_loss_fn(critic_params: Params) -> Tuple[jnp.ndarray, InfoDict]: 41 | q1, q2 = critic.apply_fn({'params': critic_params}, batch.observations, 42 | batch.actions) 43 | critic_loss = ((q1 - target_q)**2 + (q2 - target_q)**2).mean() 44 | return critic_loss, { 45 | 'critic_loss': critic_loss, 46 | 'q1': q1.mean(), 47 | 'q2': q2.mean() 48 | } 49 | 50 | new_critic, info = critic.apply_gradient(critic_loss_fn) 51 | 52 | return new_critic, info 53 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CoTASP 2 | Code for "[**Co**ntinual **T**ask **A**llocation in Meta-Policy Network via **S**parse **P**rompting](https://arxiv.org/abs/2305.18444)", presented at ICML 2023. 3 | 4 | ## Key Dependencies 5 | Some of the dependencies are outdated, e.g., the ones listed below. You **might** need to install their latest versions to run this project. 6 | ```console 7 | python==3.7.13 8 | - jax==0.3.17 9 | - jaxlib==0.3.15+cuda11.cudnn82 10 | - flax==0.6.4 11 | - optax==0.1.4 12 | - scikit-learn==1.0.2 13 | - tensorflow-probability==0.18.0 14 | - sentence-transformers==2.2.2 15 | ``` 16 | Refer to [this repo](https://github.com/awarelab/continual_world) for the installation of Continual World. 17 | 18 | ## Quick Start 19 | ```bash 20 | python train_cotasp.py 21 | ``` 22 | 23 | ## Reproducibility 24 | Experiments on CW20 are tracked via [Weights & Biases](https://api.wandb.ai/links/yang-yj/5kbiuz7h). 25 | 26 | ## Citing CoTASP 27 | If you use the code in CoTASP, please cite our paper using the following BibTeX entry. 
28 | ``` 29 | @InProceedings{pmlr-v202-yang23t, 30 | title = {Continual Task Allocation in Meta-Policy Network via Sparse Prompting}, 31 | author = {Yang, Yijun and Zhou, Tianyi and Jiang, Jing and Long, Guodong and Shi, Yuhui}, 32 | booktitle = {Proceedings of the 40th International Conference on Machine Learning}, 33 | pages = {39623--39638}, 34 | year = {2023}, 35 | volume = {202}, 36 | series = {Proceedings of Machine Learning Research}, 37 | month = {23--29 Jul}, 38 | publisher = {PMLR}, 39 | pdf = {https://proceedings.mlr.press/v202/yang23t/yang23t.pdf}, 40 | url = {https://proceedings.mlr.press/v202/yang23t.html}, 41 | } 42 | ``` 43 | 44 | ## Acknowledgement 45 | We appreciate the open source of the following projects: 46 | 47 | [Continual World](https://github.com/awarelab/continual_world), [Meta World](https://github.com/Farama-Foundation/Metaworld), and [JaxRL](https://github.com/ikostrikov/jaxrl) 48 | -------------------------------------------------------------------------------- /jaxrl/networks/critic_net.py: -------------------------------------------------------------------------------- 1 | """Implementations of algorithms for continuous control.""" 2 | 3 | from typing import Callable, Sequence, Tuple 4 | 5 | import jax.numpy as jnp 6 | from flax import linen as nn 7 | 8 | from jaxrl.networks.common import MLP, activation_fn 9 | 10 | 11 | class ValueCritic(nn.Module): 12 | hidden_dims: Sequence[int] 13 | 14 | @nn.compact 15 | def __call__(self, observations: jnp.ndarray) -> jnp.ndarray: 16 | critic = MLP((*self.hidden_dims, 1))(observations) 17 | return jnp.squeeze(critic, -1) 18 | 19 | 20 | class Critic(nn.Module): 21 | hidden_dims: Sequence[int] 22 | name_activation: str = 'leaky_relu' 23 | use_layer_norm: bool = True 24 | 25 | @nn.compact 26 | def __call__(self, observations: jnp.ndarray, 27 | actions: jnp.ndarray) -> jnp.ndarray: 28 | inputs = jnp.concatenate([observations, actions], -1) 29 | critic = MLP( 30 | (*self.hidden_dims, 1), 31 | activations=activation_fn(self.name_activation), 32 | use_layer_norm=self.use_layer_norm, 33 | activate_final=False)(inputs) 34 | return jnp.squeeze(critic, -1) 35 | 36 | 37 | class DoubleCritic(nn.Module): 38 | hidden_dims: Sequence[int] 39 | name_activation: str = 'leaky_relu' 40 | use_layer_norm: bool = True 41 | num_qs: int = 2 42 | 43 | @nn.compact 44 | def __call__(self, states, actions): 45 | 46 | VmapCritic = nn.vmap(Critic, 47 | variable_axes={'params': 0}, 48 | split_rngs={'params': True}, 49 | in_axes=None, 50 | out_axes=0, 51 | axis_size=self.num_qs) 52 | qs = VmapCritic(self.hidden_dims, 53 | name_activation=self.name_activation, 54 | use_layer_norm=self.use_layer_norm)(states, actions) 55 | return qs 56 | -------------------------------------------------------------------------------- /jaxrl/agents/redq/critic.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from jaxrl.datasets import Batch 7 | from jaxrl.networks.common import InfoDict, Model, Params, PRNGKey 8 | 9 | 10 | def target_update(critic: Model, target_critic: Model, tau: float) -> Model: 11 | new_target_params = jax.tree_multimap( 12 | lambda p, tp: p * tau + tp * (1 - tau), critic.params, 13 | target_critic.params) 14 | 15 | return target_critic.replace(params=new_target_params) 16 | 17 | 18 | def update(rng: PRNGKey, actor: Model, critic: Model, target_critic: Model, 19 | temp: Model, batch: Batch, discount: float, backup_entropy: bool, 20 | 
n: int, m: int) -> Tuple[Model, InfoDict]: 21 | dist = actor(batch.next_observations) 22 | rng, key = jax.random.split(rng) 23 | next_actions = dist.sample(seed=key) 24 | next_log_probs = dist.log_prob(next_actions) 25 | 26 | all_indx = jnp.arange(0, n) 27 | rng, key = jax.random.split(rng) 28 | indx = jax.random.choice(key, a=all_indx, shape=(m, ), replace=False) 29 | params = jax.tree_util.tree_map(lambda param: param[indx], 30 | target_critic.params) 31 | next_qs = target_critic.apply_fn({'params': params}, 32 | batch.next_observations, next_actions) 33 | next_q = jnp.min(next_qs, axis=0) 34 | 35 | target_q = batch.rewards + discount * batch.masks * next_q 36 | 37 | if backup_entropy: 38 | target_q -= discount * batch.masks * temp() * next_log_probs 39 | 40 | def critic_loss_fn(critic_params: Params) -> Tuple[jnp.ndarray, InfoDict]: 41 | qs = critic.apply_fn({'params': critic_params}, batch.observations, 42 | batch.actions) 43 | critic_loss = ((qs - target_q)**2).mean() 44 | return critic_loss, {'critic_loss': critic_loss, 'qs': qs.mean()} 45 | 46 | new_critic, info = critic.apply_gradient(critic_loss_fn) 47 | 48 | return new_critic, info 49 | -------------------------------------------------------------------------------- /jaxrl/wrappers/frame_stack.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import gym 4 | import numpy as np 5 | from gym.spaces import Box 6 | 7 | 8 | # From https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py#L229 9 | # and modified for memory efficiency. 10 | class LazyFrames(object): 11 | 12 | def __init__(self, frames, stack_axis=-1): 13 | self._frames = frames 14 | self._stack_axis = stack_axis 15 | 16 | def __array__(self, dtype=None): 17 | out = np.concatenate(self._frames, axis=self._stack_axis) 18 | if dtype is not None: 19 | out = out.astype(dtype) 20 | return out 21 | 22 | 23 | class FrameStack(gym.Wrapper): 24 | 25 | def __init__(self, env, num_stack: int, stack_axis=-1, lazy=False): 26 | super().__init__(env) 27 | self._num_stack = num_stack 28 | self._stack_axis = stack_axis 29 | self._lazy = lazy 30 | 31 | self._frames = collections.deque([], maxlen=num_stack) 32 | 33 | low = np.repeat(self.observation_space.low, num_stack, axis=stack_axis) 34 | high = np.repeat(self.observation_space.high, 35 | num_stack, 36 | axis=stack_axis) 37 | self.observation_space = Box(low=low, 38 | high=high, 39 | dtype=self.observation_space.dtype) 40 | 41 | def reset(self): 42 | obs = self.env.reset() 43 | for _ in range(self._num_stack): 44 | self._frames.append(obs) 45 | return self._get_obs() 46 | 47 | def step(self, action): 48 | obs, reward, done, info = self.env.step(action) 49 | self._frames.append(obs) 50 | return self._get_obs(), reward, done, info 51 | 52 | def _get_obs(self): 53 | assert len(self._frames) == self._num_stack 54 | if self._lazy: 55 | return LazyFrames(list(self._frames), stack_axis=self._stack_axis) 56 | else: 57 | return np.concatenate(list(self._frames), axis=self._stack_axis) 58 | -------------------------------------------------------------------------------- /configs/sac_cotasp.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import ml_collections 3 | 4 | 5 | def get_config(): 6 | config = ml_collections.ConfigDict() 7 | config.update_dict = True 8 | config.update_coef = True 9 | 10 | config.dict_configs = ml_collections.ConfigDict() 11 | config.dict_configs.c = 1.0 12 | 
config.dict_configs.alpha = 1e-3 13 | config.dict_configs.method = 'lasso_lars' 14 | config.dict_configs.positive_code = False 15 | config.dict_configs.scale_code = False 16 | 17 | config.optim_configs = ml_collections.ConfigDict() 18 | config.optim_configs.lr = 3e-4 19 | config.optim_configs.max_norm = 1.0 20 | config.optim_configs.optim_algo = 'adam' 21 | config.optim_configs.clip_method = 'global_clip' 22 | 23 | config.actor_configs = ml_collections.ConfigDict() 24 | config.actor_configs.hidden_dims = (1024, 1024, 1024) 25 | config.actor_configs.name_activation = 'gelu' 26 | config.actor_configs.use_rms_norm = False 27 | config.actor_configs.use_layer_norm = False 28 | config.actor_configs.final_fc_init_scale = 1e-3 29 | config.actor_configs.clip_mean = 1.0 30 | config.actor_configs.state_dependent_std = True 31 | 32 | config.critic_configs = ml_collections.ConfigDict() 33 | config.critic_configs.hidden_dims = (256, 256, 256) 34 | config.critic_configs.name_activation = 'gelu' 35 | config.critic_configs.use_layer_norm = False 36 | 37 | config.tau = 0.005 38 | config.init_temperature = 1.0 39 | config.target_entropy = -2.0 # by default 40 | 41 | return config 42 | 43 | 44 | if __name__ == "__main__": 45 | 46 | # kwargs = dict(get_config()) 47 | # print(kwargs) 48 | yaml_dict = get_config().to_dict() 49 | with open('sac_cotasp.yaml', 'w') as file: 50 | yaml.dump(yaml_dict, file) 51 | 52 | # with open('sac_cotasp.yaml', 'r') as file: 53 | # yaml_dict = yaml.unsafe_load(file) 54 | 55 | # yaml_dict = yaml.unsafe_load(get_config().to_yaml()) 56 | # config = ml_collections.ConfigDict(yaml_dict) 57 | # print(dict(yaml_dict)) 58 | -------------------------------------------------------------------------------- /jaxrl/agents/drq/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, Tuple 2 | 3 | import flax.linen as nn 4 | import jax 5 | import jax.numpy as jnp 6 | from tensorflow_probability.substrates import jax as tfp 7 | 8 | tfd = tfp.distributions 9 | 10 | from jaxrl.networks.common import default_init 11 | from jaxrl.networks.critic_net import DoubleCritic 12 | from jaxrl.networks.policies import NormalTanhPolicy 13 | 14 | 15 | class Encoder(nn.Module): 16 | features: Sequence[int] = (32, 32, 32, 32) 17 | strides: Sequence[int] = (2, 1, 1, 1) 18 | padding: str = 'VALID' 19 | 20 | @nn.compact 21 | def __call__(self, observations: jnp.ndarray) -> jnp.ndarray: 22 | assert len(self.features) == len(self.strides) 23 | 24 | x = observations.astype(jnp.float32) / 255.0 25 | for features, stride in zip(self.features, self.strides): 26 | x = nn.Conv(features, 27 | kernel_size=(3, 3), 28 | strides=(stride, stride), 29 | kernel_init=default_init(), 30 | padding=self.padding)(x) 31 | x = nn.relu(x) 32 | 33 | if len(x.shape) == 4: 34 | x = x.reshape([x.shape[0], -1]) 35 | else: 36 | x = x.reshape([-1]) 37 | return x 38 | 39 | 40 | class DrQDoubleCritic(nn.Module): 41 | hidden_dims: Sequence[int] 42 | cnn_features: Sequence[int] = (32, 32, 32, 32) 43 | cnn_strides: Sequence[int] = (2, 1, 1, 1) 44 | cnn_padding: str = 'VALID' 45 | latent_dim: int = 50 46 | 47 | @nn.compact 48 | def __call__(self, observations: jnp.ndarray, 49 | actions: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: 50 | x = Encoder(self.cnn_features, 51 | self.cnn_strides, 52 | self.cnn_padding, 53 | name='SharedEncoder')(observations) 54 | 55 | x = nn.Dense(self.latent_dim)(x) 56 | x = nn.LayerNorm()(x) 57 | x = nn.tanh(x) 58 | 59 | return 
DoubleCritic(self.hidden_dims)(x, actions) 60 | 61 | 62 | class DrQPolicy(nn.Module): 63 | hidden_dims: Sequence[int] 64 | action_dim: int 65 | cnn_features: Sequence[int] = (32, 32, 32, 32) 66 | cnn_strides: Sequence[int] = (2, 1, 1, 1) 67 | cnn_padding: str = 'VALID' 68 | latent_dim: int = 50 69 | 70 | @nn.compact 71 | def __call__(self, 72 | observations: jnp.ndarray, 73 | temperature: float = 1.0) -> tfd.Distribution: 74 | x = Encoder(self.cnn_features, 75 | self.cnn_strides, 76 | self.cnn_padding, 77 | name='SharedEncoder')(observations) 78 | 79 | # We do not update conv layers with policy gradients. 80 | x = jax.lax.stop_gradient(x) 81 | 82 | x = nn.Dense(self.latent_dim)(x) 83 | x = nn.LayerNorm()(x) 84 | x = nn.tanh(x) 85 | 86 | return NormalTanhPolicy(self.hidden_dims, self.action_dim)(x, 87 | temperature) 88 | -------------------------------------------------------------------------------- /jaxrl/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import gym 4 | from gym.wrappers import RescaleAction 5 | from gym.wrappers.pixel_observation import PixelObservationWrapper 6 | import pandas as pd 7 | 8 | from jaxrl import wrappers 9 | 10 | 11 | def make_env(env_name: str, 12 | seed: int, 13 | save_folder: Optional[str] = None, 14 | add_episode_monitor: bool = True, 15 | action_repeat: int = 1, 16 | frame_stack: int = 1, 17 | from_pixels: bool = False, 18 | pixels_only: bool = True, 19 | image_size: int = 84, 20 | sticky: bool = False, 21 | gray_scale: bool = False, 22 | flatten: bool = True) -> gym.Env: 23 | # Check if the env is in gym. 24 | all_envs = gym.envs.registry.all() 25 | env_ids = [env_spec.id for env_spec in all_envs] 26 | 27 | if env_name in env_ids: 28 | env = gym.make(env_name) 29 | else: 30 | domain_name, task_name = env_name.split('-') 31 | env = wrappers.DMCEnv(domain_name=domain_name, 32 | task_name=task_name, 33 | task_kwargs={'random': seed}) 34 | 35 | if flatten and isinstance(env.observation_space, gym.spaces.Dict): 36 | env = gym.wrappers.FlattenObservation(env) 37 | 38 | if add_episode_monitor: 39 | env = wrappers.EpisodeMonitor(env) 40 | 41 | if action_repeat > 1: 42 | env = wrappers.RepeatAction(env, action_repeat) 43 | 44 | env = RescaleAction(env, -1.0, 1.0) 45 | 46 | if save_folder is not None: 47 | env = gym.wrappers.RecordVideo(env, save_folder) 48 | 49 | if from_pixels: 50 | if env_name in env_ids: 51 | camera_id = 0 52 | else: 53 | camera_id = 2 if domain_name == 'quadruped' else 0 54 | env = PixelObservationWrapper(env, 55 | pixels_only=pixels_only, 56 | render_kwargs={ 57 | 'pixels': { 58 | 'height': image_size, 59 | 'width': image_size, 60 | 'camera_id': camera_id 61 | } 62 | }) 63 | env = wrappers.TakeKey(env, take_key='pixels') 64 | if gray_scale: 65 | env = wrappers.RGB2Gray(env) 66 | else: 67 | env = wrappers.SinglePrecision(env) 68 | 69 | if frame_stack > 1: 70 | env = wrappers.FrameStack(env, num_stack=frame_stack) 71 | 72 | if sticky: 73 | env = wrappers.StickyActionEnv(env) 74 | 75 | env.seed(seed) 76 | env.action_space.seed(seed) 77 | env.observation_space.seed(seed) 78 | 79 | return env 80 | 81 | 82 | class Logger(object): 83 | def __init__(self, base_dir: str) -> None: 84 | self.base_dir = base_dir 85 | self.data = pd.DataFrame() 86 | 87 | def update(self, point: dict) -> None: 88 | self.data = self.data.append(point, ignore_index=True) 89 | 90 | def save(self) -> None: 91 | self.data.to_csv( 92 | self.base_dir+'/progress.csv', 93 | index=False) 94 | 
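# Illustrative usage sketch (assumes a DMC task named 'cheetah-run'; all argument values are
# arbitrary examples). 'cheetah-run' follows the '<domain>-<task>' naming handled by the DMCEnv
# branch of make_env above.
# env = make_env('cheetah-run', seed=0, action_repeat=4, frame_stack=3, from_pixels=True)
# logger = Logger(base_dir='./logs')
# logger.update({'step': 0, 'return': 0.0})
# logger.save()  # writes ./logs/progress.csv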
-------------------------------------------------------------------------------- /jaxrl/agents/bc/bc_learner.py: -------------------------------------------------------------------------------- 1 | """Implementations of algorithms for continuous control.""" 2 | 3 | from typing import Sequence 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | import optax 9 | 10 | from jaxrl.agents.bc import actor 11 | from jaxrl.datasets import Batch 12 | from jaxrl.networks import autoregressive_policy, policies 13 | from jaxrl.networks.common import InfoDict, Model 14 | 15 | _log_prob_update_jit = jax.jit(actor.log_prob_update) 16 | _mse_update_jit = jax.jit(actor.mse_update) 17 | 18 | 19 | class BCLearner(object): 20 | 21 | def __init__(self, 22 | seed: int, 23 | observations: jnp.ndarray, 24 | actions: jnp.ndarray, 25 | actor_lr: float = 1e-3, 26 | num_steps: int = int(1e6), 27 | hidden_dims: Sequence[int] = (256, 256), 28 | distribution: str = 'det'): 29 | 30 | self.distribution = distribution 31 | 32 | rng = jax.random.PRNGKey(seed) 33 | rng, actor_key = jax.random.split(rng) 34 | 35 | action_dim = actions.shape[-1] 36 | if distribution == 'det': 37 | actor_def = policies.MSEPolicy(hidden_dims, 38 | action_dim, 39 | dropout_rate=0.1) 40 | elif distribution == 'mog': 41 | actor_def = policies.NormalTanhMixturePolicy(hidden_dims, 42 | action_dim, 43 | dropout_rate=0.1) 44 | elif distribution == 'made_mog': 45 | actor_def = autoregressive_policy.MADETanhMixturePolicy( 46 | hidden_dims, action_dim, dropout_rate=0.1) 47 | elif distribution == 'made_d': 48 | actor_def = autoregressive_policy.MADEDiscretizedPolicy( 49 | hidden_dims, action_dim, dropout_rate=0.1) 50 | else: 51 | raise NotImplemented 52 | 53 | schedule_fn = optax.cosine_decay_schedule(-actor_lr, num_steps) 54 | optimiser = optax.chain(optax.scale_by_adam(), 55 | optax.scale_by_schedule(schedule_fn)) 56 | 57 | self.actor = Model.create(actor_def, 58 | inputs=[actor_key, observations], 59 | tx=optimiser) 60 | self.rng = rng 61 | 62 | def sample_actions(self, 63 | observations: np.ndarray, 64 | temperature: float = 1.0) -> jnp.ndarray: 65 | self.rng, actions = policies.sample_actions(self.rng, 66 | self.actor.apply_fn, 67 | self.actor.params, 68 | observations, temperature, 69 | self.distribution) 70 | 71 | actions = np.asarray(actions) 72 | return np.clip(actions, -1, 1) 73 | 74 | def update(self, batch: Batch) -> InfoDict: 75 | if self.distribution == 'det': 76 | self.rng, self.actor, info = _mse_update_jit( 77 | self.actor, batch, self.rng) 78 | else: 79 | self.rng, self.actor, info = _log_prob_update_jit( 80 | self.actor, batch, self.rng) 81 | return info 82 | -------------------------------------------------------------------------------- /jaxrl/datasets/replay_buffer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import gym 4 | import numpy as np 5 | 6 | from jaxrl.datasets.dataset import Dataset 7 | 8 | 9 | class ReplayBuffer(Dataset): 10 | 11 | def __init__(self, observation_space: gym.spaces.Box, 12 | action_space: Union[gym.spaces.Discrete, 13 | gym.spaces.Box], capacity: int): 14 | 15 | observations = np.empty((capacity, *observation_space.shape), 16 | dtype=observation_space.dtype) 17 | actions = np.empty((capacity, *action_space.shape), 18 | dtype=action_space.dtype) 19 | rewards = np.empty((capacity, ), dtype=np.float32) 20 | masks = np.empty((capacity, ), dtype=np.float32) 21 | dones_float = np.empty((capacity, ), 
dtype=np.float32) 22 | next_observations = np.empty((capacity, *observation_space.shape), 23 | dtype=observation_space.dtype) 24 | super().__init__(observations=observations, 25 | actions=actions, 26 | rewards=rewards, 27 | masks=masks, 28 | dones_float=dones_float, 29 | next_observations=next_observations, 30 | size=0) 31 | 32 | self.size = 0 33 | 34 | self.insert_index = 0 35 | self.capacity = capacity 36 | 37 | def initialize_with_dataset(self, dataset: Dataset, 38 | num_samples: Optional[int]): 39 | assert self.insert_index == 0, 'Can insert a batch online in an empty replay buffer.' 40 | 41 | dataset_size = len(dataset.observations) 42 | 43 | if num_samples is None: 44 | num_samples = dataset_size 45 | else: 46 | num_samples = min(dataset_size, num_samples) 47 | assert self.capacity >= num_samples, 'Dataset cannot be larger than the replay buffer capacity.' 48 | 49 | if num_samples < dataset_size: 50 | perm = np.random.permutation(dataset_size) 51 | indices = perm[:num_samples] 52 | else: 53 | indices = np.arange(num_samples) 54 | 55 | self.observations[:num_samples] = dataset.observations[indices] 56 | self.actions[:num_samples] = dataset.actions[indices] 57 | self.rewards[:num_samples] = dataset.rewards[indices] 58 | self.masks[:num_samples] = dataset.masks[indices] 59 | self.dones_float[:num_samples] = dataset.dones_float[indices] 60 | self.next_observations[:num_samples] = dataset.next_observations[ 61 | indices] 62 | 63 | self.insert_index = num_samples 64 | self.size = num_samples 65 | 66 | def insert(self, observation: np.ndarray, action: np.ndarray, 67 | reward: float, mask: float, done_float: float, 68 | next_observation: np.ndarray): 69 | self.observations[self.insert_index] = observation 70 | self.actions[self.insert_index] = action 71 | self.rewards[self.insert_index] = reward 72 | self.masks[self.insert_index] = mask 73 | self.dones_float[self.insert_index] = done_float 74 | self.next_observations[self.insert_index] = next_observation 75 | 76 | self.insert_index = (self.insert_index + 1) % self.capacity 77 | self.size = min(self.size + 1, self.capacity) 78 | -------------------------------------------------------------------------------- /jaxrl/wrappers/dmc_env.py: -------------------------------------------------------------------------------- 1 | # Taken from 2 | # https://github.com/denisyarats/dmc2gym 3 | # and modified to exclude duplicated code. 
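# What the wrapper below does: dm_env observation and action specs are
# converted to gym.spaces via `dmc_spec2gym_space`, `step` returns the usual
# Gym tuple (observation, reward, done, info), and `info['TimeLimit.truncated']`
# is set when an episode ends with discount == 1.0, i.e. a time-limit cut-off
# rather than a true terminal state.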
4 | 5 | import copy 6 | from typing import Dict, Optional, OrderedDict 7 | 8 | import dm_env 9 | import numpy as np 10 | from dm_control import suite 11 | from gym import core, spaces 12 | 13 | from jaxrl.wrappers.common import TimeStep 14 | 15 | 16 | def dmc_spec2gym_space(spec): 17 | if isinstance(spec, OrderedDict) or isinstance(spec, dict): 18 | spec = copy.copy(spec) 19 | for k, v in spec.items(): 20 | spec[k] = dmc_spec2gym_space(v) 21 | return spaces.Dict(spec) 22 | elif isinstance(spec, dm_env.specs.BoundedArray): 23 | return spaces.Box(low=spec.minimum, 24 | high=spec.maximum, 25 | shape=spec.shape, 26 | dtype=spec.dtype) 27 | elif isinstance(spec, dm_env.specs.Array): 28 | return spaces.Box(low=-float('inf'), 29 | high=float('inf'), 30 | shape=spec.shape, 31 | dtype=spec.dtype) 32 | else: 33 | raise NotImplementedError 34 | 35 | 36 | class DMCEnv(core.Env): 37 | 38 | def __init__(self, 39 | domain_name: Optional[str] = None, 40 | task_name: Optional[str] = None, 41 | env: Optional[dm_env.Environment] = None, 42 | task_kwargs: Optional[Dict] = {}, 43 | environment_kwargs=None): 44 | assert 'random' in task_kwargs, 'Please specify a seed, for deterministic behaviour.' 45 | assert ( 46 | env is not None 47 | or (domain_name is not None and task_name is not None) 48 | ), 'You must provide either an environment or domain and task names.' 49 | 50 | if env is None: 51 | env = suite.load(domain_name=domain_name, 52 | task_name=task_name, 53 | task_kwargs=task_kwargs, 54 | environment_kwargs=environment_kwargs) 55 | 56 | self._env = env 57 | self.action_space = dmc_spec2gym_space(self._env.action_spec()) 58 | 59 | self.observation_space = dmc_spec2gym_space( 60 | self._env.observation_spec()) 61 | 62 | self.seed(seed=task_kwargs['random']) 63 | 64 | def __getattr__(self, name): 65 | return getattr(self._env, name) 66 | 67 | def step(self, action: np.ndarray) -> TimeStep: 68 | assert self.action_space.contains(action) 69 | 70 | time_step = self._env.step(action) 71 | reward = time_step.reward or 0 72 | done = time_step.last() 73 | obs = time_step.observation 74 | 75 | info = {} 76 | if done and time_step.discount == 1.0: 77 | info['TimeLimit.truncated'] = True 78 | 79 | return obs, reward, done, info 80 | 81 | def reset(self): 82 | time_step = self._env.reset() 83 | return time_step.observation 84 | 85 | def render(self, 86 | mode='rgb_array', 87 | height: int = 84, 88 | width: int = 84, 89 | camera_id: int = 0): 90 | assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode 91 | return self._env.physics.render(height=height, 92 | width=width, 93 | camera_id=camera_id) 94 | -------------------------------------------------------------------------------- /jaxrl/datasets/awac_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import d4rl 4 | import gym 5 | import numpy as np 6 | 7 | from jaxrl.datasets.dataset import Batch, Dataset 8 | 9 | # awac_demos corresponds to expert demonstrations. 10 | # awac_off corresponds to additional data collected 11 | # with BC trained on demonstrations. 
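# `gdown.cached_download` is called further down to fetch the dataset archive,
# but `gdown` does not appear among the imports at the top of this file, so it
# is imported here to keep the module runnable.
import gdown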
12 | ENV_NAME_TO_FILE = { 13 | 'HalfCheetah-v2': { 14 | 'awac_off': 'hc_off_policy_15_demos_100.npy', 15 | 'awac_demo': 'hc_action_noise_15.npy' 16 | }, 17 | 'Walker2d-v2': { 18 | 'awac_off': 'walker_off_policy_15_demos_100.npy', 19 | 'awac_demo': 'walker_action_noise_15.npy' 20 | }, 21 | 'Ant-v2': { 22 | 'awac_off': 'ant_off_policy_15_demos_100.npy', 23 | 'awac_demo': 'ant_action_noise_15.npy' 24 | } 25 | } 26 | 27 | 28 | class AWACDataset(Dataset): 29 | 30 | def __init__(self, 31 | env_name: str, 32 | clip_to_eps: bool = True, 33 | eps: float = 1e-5): 34 | 35 | # Reuse d4rl path for now. 36 | dataset_path = os.path.join(d4rl.offline_env.DATASET_PATH, 'avac') 37 | zip_path = os.path.join(dataset_path, 'all.zip') 38 | 39 | url = 'https://drive.google.com/u/0/uc?id=1edcuicVv2d-PqH1aZUVbO5CeRq3lqK89' 40 | gdown.cached_download(url, zip_path, postprocess=gdown.extractall) 41 | 42 | observations = [] 43 | actions = [] 44 | rewards = [] 45 | terminals = [] 46 | dones_float = [] 47 | next_observations = [] 48 | 49 | env = gym.make(env_name) 50 | # Contacentate both datasets for now. 51 | for dataset_name in ['awac_off', 'awac_demo']: 52 | file_name = ENV_NAME_TO_FILE[env_name][dataset_name] 53 | 54 | dataset = np.load(os.path.join(dataset_path, file_name), 55 | allow_pickle=True) 56 | 57 | for trajectory in dataset: 58 | if len(trajectory['observations']) == env._max_episode_steps: 59 | trajectory['terminals'][-1] = False 60 | 61 | observations.append(trajectory['observations']) 62 | actions.append(trajectory['actions']) 63 | rewards.append(trajectory['rewards']) 64 | terminals.append(trajectory['terminals']) 65 | done_float = np.zeros_like(trajectory['rewards']) 66 | done_float[-1] = 1.0 67 | dones_float.append(done_float) 68 | next_observations.append(trajectory['next_observations']) 69 | 70 | observations = np.concatenate(observations, 0) 71 | actions = np.concatenate(actions, 0) 72 | rewards = np.concatenate(rewards, 0) 73 | terminals = np.concatenate(terminals, 0) 74 | dones_float = np.concatenate(dones_float, 0) 75 | next_observations = np.concatenate(next_observations, 0) 76 | 77 | if clip_to_eps: 78 | lim = 1 - eps 79 | actions = np.clip(actions, -lim, lim) 80 | 81 | super().__init__(observations=observations.astype(np.float32), 82 | actions=actions.astype(np.float32), 83 | rewards=rewards.astype(np.float32), 84 | masks=1.0 - terminals.astype(np.float32), 85 | dones_float=dones_float.astype(np.float32), 86 | next_observations=next_observations.astype( 87 | np.float32), 88 | size=len(observations)) 89 | -------------------------------------------------------------------------------- /jaxrl/evaluation.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import gym 4 | import numpy as np 5 | import jax.numpy as jnp 6 | 7 | 8 | def evaluate(agent, env: gym.Env, num_episodes: int, with_task_embed=False, task_i=None) -> Dict[str, float]: 9 | stats = {'return': [], 'length': []} 10 | successes = None 11 | for _ in range(num_episodes): 12 | observation, done = env.reset(), False 13 | while not done: 14 | if with_task_embed: 15 | action = agent.sample_a_with_task_embed(observation[np.newaxis], temperature=0.0) 16 | else: 17 | if task_i is None: 18 | action = agent.sample_actions(observation[np.newaxis], temperature=0.0) 19 | else: 20 | action = agent.sample_actions(observation[np.newaxis], task_i, temperature=0.0) 21 | observation, _, done, info = env.step(action) 22 | for k in stats.keys(): 23 | 
stats[k].append(info['episode'][k]) 24 | 25 | if 'success' in info: 26 | if successes is None: 27 | successes = 0.0 28 | successes += info['success'] 29 | 30 | for k, v in stats.items(): 31 | stats[k] = np.mean(v) 32 | 33 | if successes is not None: 34 | stats['success'] = successes / num_episodes 35 | return stats 36 | 37 | def evaluate_cl(agent, envs: List[gym.Env], num_episodes: int, naive_sac=False, tadell=False) -> Dict[str, float]: 38 | stats = {} 39 | sum_return = 0.0 40 | sum_success = 0.0 41 | list_log_keys = ['return'] 42 | 43 | # dummy inputs 44 | # dummy_obs = jnp.ones((128, 12)) 45 | 46 | for task_i, env in enumerate(envs): 47 | for k in list_log_keys: 48 | stats[f'{task_i}-{env.name}/{k}'] = [] 49 | successes = None 50 | 51 | if tadell: 52 | agent.select_actor(task_i) 53 | 54 | for _ in range(num_episodes): 55 | observation, done = env.reset(), False 56 | while not done: 57 | 58 | if naive_sac: 59 | action = agent.sample_actions(observation[np.newaxis], temperature=0) 60 | action = np.asarray(action, dtype=np.float32).flatten() 61 | elif tadell: 62 | action = agent.sample_actions(observation[np.newaxis], temperature=0, eval_mode=True) 63 | action = np.asarray(action, dtype=np.float32).flatten() 64 | else: 65 | action = agent.sample_actions(observation[np.newaxis], task_i, temperature=0) 66 | action = np.asarray(action, dtype=np.float32).flatten() 67 | 68 | observation, _, done, info = env.step(action) 69 | 70 | for k in list_log_keys: 71 | stats[f'{task_i}-{env.name}/{k}'].append(info['episode'][k]) 72 | 73 | if 'success' in info: 74 | if successes is None: 75 | successes = 0.0 76 | successes += info['success'] 77 | 78 | for k in list_log_keys: 79 | stats[f'{task_i}-{env.name}/{k}'] = np.mean(stats[f'{task_i}-{env.name}/{k}']) 80 | 81 | if successes is not None: 82 | stats[f'{task_i}-{env.name}/success'] = successes / num_episodes 83 | sum_success += stats[f'{task_i}-{env.name}/success'] 84 | 85 | sum_return += stats[f'{task_i}-{env.name}/return'] 86 | 87 | # stats[f'{task_i}-{env.name}/check_dummy_action'] = agent.sample_actions(dummy_obs, task_i, temperature=0).mean() 88 | 89 | stats['avg_return'] = sum_return / len(envs) 90 | stats['test/deterministic/average_success'] = sum_success / len(envs) 91 | 92 | return stats 93 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Ignore all files in the 'logs' subfolder. 10 | logs/ 11 | *.pdf 12 | 13 | # Ignore all files in the 'configs' subfolder. 14 | configs/ 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # poetry 105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 109 | #poetry.lock 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | #pdm.lock 114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 115 | # in version control. 116 | # https://pdm.fming.dev/#use-with-ide 117 | .pdm.toml 118 | 119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
167 | #.idea/ -------------------------------------------------------------------------------- /jaxrl/agents/ddpg/ddpg_learner.py: -------------------------------------------------------------------------------- 1 | """Implementations of algorithms for continuous control.""" 2 | 3 | import functools 4 | from typing import Sequence, Tuple 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | import optax 10 | 11 | from jaxrl.agents.ddpg.actor import update as update_actor 12 | from jaxrl.agents.ddpg.critic import update as update_critic 13 | from jaxrl.agents.sac.critic import target_update 14 | from jaxrl.datasets import Batch 15 | from jaxrl.networks import critic_net, policies 16 | from jaxrl.networks.common import InfoDict, Model, PRNGKey 17 | 18 | 19 | @functools.partial(jax.jit, static_argnames=('update_target')) 20 | def _update_jit( 21 | actor: Model, critic: Model, target_critic: Model, batch: Batch, 22 | discount: float, tau: float, update_target: bool 23 | ) -> Tuple[PRNGKey, Model, Model, Model, Model, InfoDict]: 24 | 25 | new_critic, critic_info = update_critic(actor, critic, target_critic, 26 | batch, discount) 27 | if update_target: 28 | new_target_critic = target_update(new_critic, target_critic, tau) 29 | else: 30 | new_target_critic = target_critic 31 | 32 | new_actor, actor_info = update_actor(actor, new_critic, batch) 33 | 34 | return new_actor, new_critic, new_target_critic, { 35 | **critic_info, 36 | **actor_info, 37 | } 38 | 39 | 40 | class DDPGLearner(object): 41 | 42 | def __init__(self, 43 | seed: int, 44 | observations: jnp.ndarray, 45 | actions: jnp.ndarray, 46 | actor_lr: float = 3e-4, 47 | critic_lr: float = 3e-4, 48 | hidden_dims: Sequence[int] = (256, 256), 49 | discount: float = 0.99, 50 | tau: float = 0.005, 51 | target_update_period: int = 1, 52 | exploration_noise: float = 0.1): 53 | """ 54 | An implementation of [Deep Deterministic Policy Gradient](https://arxiv.org/abs/1509.02971) 55 | and Clipped Double Q-Learning (https://arxiv.org/abs/1802.09477). 
56 | """ 57 | 58 | action_dim = actions.shape[-1] 59 | 60 | self.tau = tau 61 | self.target_update_period = target_update_period 62 | self.discount = discount 63 | self.exploration_noise = exploration_noise 64 | 65 | rng = jax.random.PRNGKey(seed) 66 | rng, actor_key, critic_key = jax.random.split(rng, 3) 67 | 68 | actor_def = policies.MSEPolicy(hidden_dims, action_dim) 69 | actor = Model.create(actor_def, 70 | inputs=[actor_key, observations], 71 | tx=optax.adam(learning_rate=actor_lr)) 72 | 73 | critic_def = critic_net.DoubleCritic(hidden_dims) 74 | critic = Model.create(critic_def, 75 | inputs=[critic_key, observations, actions], 76 | tx=optax.adam(learning_rate=critic_lr)) 77 | target_critic = Model.create( 78 | critic_def, inputs=[critic_key, observations, actions]) 79 | 80 | self.actor = actor 81 | self.critic = critic 82 | self.target_critic = target_critic 83 | self.rng = rng 84 | 85 | self.step = 1 86 | 87 | def sample_actions(self, 88 | observations: np.ndarray, 89 | temperature: float = 1.0) -> jnp.ndarray: 90 | rng, actions = policies.sample_actions(self.rng, 91 | self.actor.apply_fn, 92 | self.actor.params, 93 | observations, 94 | temperature, 95 | distribution='det') 96 | self.rng = rng 97 | 98 | actions = np.asarray(actions) 99 | actions = actions + np.random.normal( 100 | size=actions.shape) * self.exploration_noise * temperature 101 | return np.clip(actions, -1, 1) 102 | 103 | def update(self, batch: Batch) -> InfoDict: 104 | self.step += 1 105 | 106 | new_actor, new_critic, new_target_critic, info = _update_jit( 107 | self.actor, self.critic, self.target_critic, batch, self.discount, 108 | self.tau, self.step % self.target_update_period == 0) 109 | 110 | self.actor = new_actor 111 | self.critic = new_critic 112 | self.target_critic = new_target_critic 113 | 114 | return info 115 | -------------------------------------------------------------------------------- /jaxrl/dict_learning/plot_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def heatmap(data, row_labels, col_labels, ax=None, 7 | cbar_kw=None, font_label=None, x_label="", 8 | y_label="", cbarlabel="", **kwargs): 9 | """ 10 | Create a heatmap from a numpy array and two lists of labels. 11 | 12 | Parameters 13 | ---------- 14 | data 15 | A 2D numpy array of shape (M, N). 16 | row_labels 17 | A list or array of length M with the labels for the rows. 18 | col_labels 19 | A list or array of length N with the labels for the columns. 20 | ax 21 | A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If 22 | not provided, use current axes or create a new one. Optional. 23 | cbar_kw 24 | A dictionary with arguments to `matplotlib.Figure.colorbar`. Optional. 25 | cbarlabel 26 | The label for the colorbar. Optional. 27 | **kwargs 28 | All other arguments are forwarded to `imshow`. 29 | """ 30 | 31 | if ax is None: 32 | ax = plt.gca() 33 | 34 | if cbar_kw is None: 35 | cbar_kw = {} 36 | 37 | # Plot the heatmap 38 | im = ax.imshow(data, **kwargs) 39 | 40 | # Create colorbar 41 | cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw) 42 | cbar.ax.set_ylabel(cbarlabel, font_label, rotation=-90, va="bottom") 43 | 44 | # set XY labels 45 | ax.set_xlabel(x_label, font_label) 46 | ax.set_ylabel(y_label, font_label) 47 | 48 | # Show all ticks and label them with the respective list entries. 
49 | ax.set_xticks(np.arange(data.shape[1]), labels=col_labels) 50 | ax.set_yticks(np.arange(data.shape[0]), labels=row_labels) 51 | 52 | # Let the horizontal axes labeling appear on top. 53 | ax.tick_params(top=True, bottom=False, 54 | labeltop=True, labelbottom=False) 55 | 56 | # Rotate the tick labels and set their alignment. 57 | plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", 58 | rotation_mode="anchor") 59 | 60 | # Turn spines off and create white grid. 61 | ax.spines[:].set_visible(False) 62 | 63 | ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True) 64 | ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True) 65 | ax.grid(which="minor", color="w", linestyle='-', linewidth=3) 66 | ax.tick_params(which="minor", bottom=False, left=False) 67 | 68 | return im, cbar 69 | 70 | 71 | def annotate_heatmap(im, data=None, valfmt="{x:.2f}", 72 | textcolors=("black", "white"), 73 | threshold=None, **textkw): 74 | """ 75 | A function to annotate a heatmap. 76 | 77 | Parameters 78 | ---------- 79 | im 80 | The AxesImage to be labeled. 81 | data 82 | Data used to annotate. If None, the image's data is used. Optional. 83 | valfmt 84 | The format of the annotations inside the heatmap. This should either 85 | use the string format method, e.g. "$ {x:.2f}", or be a 86 | `matplotlib.ticker.Formatter`. Optional. 87 | textcolors 88 | A pair of colors. The first is used for values below a threshold, 89 | the second for those above. Optional. 90 | threshold 91 | Value in data units according to which the colors from textcolors are 92 | applied. If None (the default) uses the middle of the colormap as 93 | separation. Optional. 94 | **kwargs 95 | All other arguments are forwarded to each call to `text` used to create 96 | the text labels. 97 | """ 98 | 99 | if not isinstance(data, (list, np.ndarray)): 100 | data = im.get_array() 101 | 102 | # Normalize the threshold to the images color range. 103 | if threshold is not None: 104 | threshold = im.norm(threshold) 105 | else: 106 | threshold = im.norm(data.max())/2. 107 | 108 | # Set default alignment to center, but allow it to be 109 | # overwritten by textkw. 110 | kw = dict(horizontalalignment="center", 111 | verticalalignment="center") 112 | kw.update(textkw) 113 | 114 | # Get the formatter in case a string is supplied 115 | if isinstance(valfmt, str): 116 | valfmt = matplotlib.ticker.StrMethodFormatter(valfmt) 117 | 118 | # Loop over the data and create a `Text` for each "pixel". 119 | # Change the text's color depending on the data. 120 | texts = [] 121 | for i in range(data.shape[0]): 122 | for j in range(data.shape[1]): 123 | kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)]) 124 | text = im.axes.text(j, i, valfmt(data[i, j], None), **kw) 125 | texts.append(text) 126 | 127 | return texts -------------------------------------------------------------------------------- /jaxrl/networks/lion_optax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Optax implementation of the Lion optimizer.""" 16 | 17 | from typing import Any, Callable, NamedTuple, Optional, Union 18 | 19 | import chex 20 | import jax 21 | import jax.numpy as jnp 22 | import optax 23 | 24 | 25 | def _scale_by_learning_rate( 26 | learning_rate: optax.ScalarOrSchedule, flip_sign=True): 27 | m = -1 if flip_sign else 1 28 | if callable(learning_rate): 29 | return optax.scale_by_schedule(lambda count: m * learning_rate(count)) 30 | return optax.scale(m * learning_rate) 31 | 32 | 33 | def lion( 34 | learning_rate: optax.ScalarOrSchedule, 35 | b1: float = 0.9, 36 | b2: float = 0.99, 37 | mu_dtype: Optional[Any] = None, 38 | weight_decay: float = 0.0, 39 | mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None, 40 | ) -> optax.GradientTransformation: 41 | """Lion. 42 | 43 | Args: 44 | learning_rate: A fixed global scaling factor. 45 | b1: Exponential decay rate to combine the gradient and the moment. 46 | b2: Exponential decay rate to track the moment of past gradients. 47 | mu_dtype: Optional `dtype` to be used for the first order accumulator; if 48 | `None` then the `dtype` is inferred from `params` and `updates`. 49 | weight_decay: Strength of the weight decay regularization. Note that this 50 | weight decay is multiplied with the learning rate. This is consistent 51 | with other frameworks such as PyTorch, but different from 52 | (Loshchilov et al, 2019) where the weight decay is only multiplied with 53 | the "schedule multiplier", but not the base learning rate. 54 | mask: A tree with same structure as (or a prefix of) the params PyTree, 55 | or a Callable that returns such a pytree given the params/updates. 56 | The leaves should be booleans, `True` for leaves/subtrees you want to 57 | apply the weight decay to, and `False` for those you want to skip. Note 58 | that the Adam gradient transformations are applied to all parameters. 59 | 60 | Returns: 61 | The corresponding `GradientTransformation`. 62 | """ 63 | return optax.chain( 64 | scale_by_lion( 65 | b1=b1, b2=b2, mu_dtype=mu_dtype), 66 | optax.add_decayed_weights(weight_decay, mask), 67 | _scale_by_learning_rate(learning_rate), 68 | ) 69 | 70 | 71 | def update_moment(updates, moments, decay, order): 72 | """Compute the exponential moving average of the `order`-th moment.""" 73 | return jax.tree_util.tree_map( 74 | lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments) 75 | 76 | 77 | class ScaleByLionState(NamedTuple): 78 | """State for the Lion algorithm.""" 79 | count: chex.Array # shape=(), dtype=jnp.int32. 80 | mu: optax.Updates 81 | 82 | 83 | def scale_by_lion( 84 | b1: float = 0.9, 85 | b2: float = 0.99, 86 | mu_dtype: Optional[Any] = None, 87 | ) -> optax.GradientTransformation: 88 | """Rescale updates according to the Lion algorithm. 89 | 90 | Args: 91 | b1: rate for combining moment and the current grad. 92 | b2: decay rate for the exponentially weighted average of grads. 93 | mu_dtype: optional `dtype` to be used for the first order accumulator; if 94 | `None` then the `dtype is inferred from `params` and `updates`. 95 | 96 | Returns: 97 | A `GradientTransformation` object. 
98 | """ 99 | 100 | def init_fn(params): 101 | mu = jax.tree_util.tree_map( # moment 102 | lambda t: jnp.zeros_like(t, dtype=mu_dtype), params) 103 | return ScaleByLionState(count=jnp.zeros([], jnp.int32), mu=mu) 104 | 105 | def update_fn(updates, state, params=None): 106 | del params 107 | mu = update_moment(updates, state.mu, b2, 1) 108 | mu = jax.tree_map(lambda x: x.astype(mu_dtype), mu) 109 | count_inc = optax.safe_int32_increment(state.count) 110 | updates = jax.tree_util.tree_map( 111 | lambda g, m: jnp.sign((1. - b1) * g + b1 * m), updates, state.mu) 112 | return updates, ScaleByLionState(count=count_inc, mu=mu) 113 | 114 | return optax.GradientTransformation(init_fn, update_fn) -------------------------------------------------------------------------------- /jaxrl/wrappers/normalization.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | """Set of wrappers for normalizing actions and observations.""" 5 | import numpy as np 6 | 7 | import gym 8 | 9 | 10 | # taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py 11 | class RunningMeanStd: 12 | """Tracks the mean, variance and count of values.""" 13 | 14 | # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 15 | def __init__(self, epsilon=1e-4, shape=()): 16 | """Tracks the mean, variance and count of values.""" 17 | self.mean = np.zeros(shape, "float64") 18 | self.var = np.ones(shape, "float64") 19 | self.count = epsilon 20 | 21 | def update(self, x): 22 | """Updates the mean, var and count from a batch of samples.""" 23 | batch_mean = np.mean(x, axis=0) 24 | batch_var = np.var(x, axis=0) 25 | batch_count = x.shape[0] 26 | self.update_from_moments(batch_mean, batch_var, batch_count) 27 | 28 | def update_from_moments(self, batch_mean, batch_var, batch_count): 29 | """Updates from batch mean, variance and count moments.""" 30 | self.mean, self.var, self.count = update_mean_var_count_from_moments( 31 | self.mean, self.var, self.count, batch_mean, batch_var, batch_count 32 | ) 33 | 34 | 35 | def update_mean_var_count_from_moments( 36 | mean, var, count, batch_mean, batch_var, batch_count 37 | ): 38 | """Updates the mean, var and count using the previous mean, var, count and batch values.""" 39 | delta = batch_mean - mean 40 | tot_count = count + batch_count 41 | 42 | new_mean = mean + delta * batch_count / tot_count 43 | m_a = var * count 44 | m_b = batch_var * batch_count 45 | M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count 46 | new_var = M2 / tot_count 47 | new_count = tot_count 48 | 49 | return new_mean, new_var, new_count 50 | 51 | 52 | class NormalizeReward(gym.core.Wrapper): 53 | r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance. 54 | The exponential moving average will have variance :math:`(1 - \gamma)^2`. 55 | Note: 56 | The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly 57 | instantiated or the policy was changed recently. 
58 | """ 59 | 60 | def __init__( 61 | self, 62 | env: gym.Env, 63 | epsilon: float = 1e-8, 64 | reward_alpha=0.001, 65 | ): 66 | """This wrapper will normalize immediate rewards 67 | Args: 68 | env (env): The environment to apply the wrapper 69 | epsilon (float): A stability parameter 70 | """ 71 | super().__init__(env) 72 | self.epsilon = epsilon 73 | self._reward_alpha = reward_alpha 74 | self._reward_mean = 0. 75 | self._reward_var = 1. 76 | 77 | def step(self, action): 78 | """Steps through the environment, normalizing the rewards returned.""" 79 | obs, rews, dones, infos = self.env.step(action) 80 | rews = self._apply_normalize_reward(rews) 81 | return obs, rews, dones, infos 82 | 83 | def _update_reward_estimate(self, reward): 84 | self._reward_mean = (1 - self._reward_alpha) * \ 85 | self._reward_mean + self._reward_alpha * reward 86 | self._reward_var = ( 87 | 1 - self._reward_alpha 88 | ) * self._reward_var + self._reward_alpha * np.square( 89 | reward - self._reward_mean) 90 | 91 | def _apply_normalize_reward(self, reward): 92 | """Compute normalized reward. 93 | Args: 94 | reward (float): Reward. 95 | Returns: 96 | float: Normalized reward. 97 | """ 98 | self._update_reward_estimate(reward) 99 | return reward / (np.sqrt(self._reward_var) + self.epsilon) 100 | 101 | 102 | class RescaleReward(gym.core.Wrapper): 103 | ''' 104 | This wrapper will rescale immediate rewards based on a constant factor. 105 | ''' 106 | def __init__(self, env: gym.Env, reward_scale: float = 1.0): 107 | super().__init__(env) 108 | self.reward_scale = reward_scale 109 | 110 | def step(self, action): 111 | obs, rews, dones, infos = self.env.step(action) 112 | rews = rews * self.reward_scale 113 | return obs, rews, dones, infos 114 | 115 | 116 | if __name__ == "__main__": 117 | 118 | def print_reward(env: gym.Env): 119 | obs, done = env.reset(), False 120 | i = 0 121 | while not done: 122 | i += 1 123 | next_obs, rew, done, _ = env.step(env.action_space.sample()) 124 | print(i, rew) 125 | 126 | env = gym.make('Hopper-v3') 127 | env_wrapped = NormalizeReward(env) 128 | 129 | print_reward(env) 130 | print_reward(env_wrapped) 131 | -------------------------------------------------------------------------------- /jaxrl/datasets/rl_unplugged/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import d4rl 4 | import dm_control_suite 5 | import gym 6 | import numpy as np 7 | from absl import app, flags 8 | from scipy.spatial import cKDTree 9 | from tqdm import tqdm 10 | 11 | from jaxrl import wrappers 12 | 13 | flags.DEFINE_string('path', '/home/kostrikov/datasets/', 'Path to dataset.') 14 | flags.DEFINE_string('task_name', 'cheetah_run', 'Game.') 15 | flags.DEFINE_enum('task_class', 'control_suite', 16 | ['humanoid', 'rodent', 'control_suite'], 'Task classes.') 17 | 18 | FLAGS = flags.FLAGS 19 | 20 | 21 | def main(_): 22 | if FLAGS.task_class == 'control_suite': 23 | task = dm_control_suite.ControlSuite(task_name=FLAGS.task_name) 24 | elif FLAGS.task_class == 'humanoid': 25 | task = dm_control_suite.CmuThirdParty(task_name=FLAGS.task_name) 26 | elif FLAGS.task_class == 'rodent': 27 | task = dm_control_suite.Rodent(task_name=FLAGS.task_name) 28 | 29 | environment = task.environment 30 | env = wrappers.DMCEnv(env=environment, task_kwargs={'random': 0}) 31 | 32 | ds = dm_control_suite.dataset(root_path=FLAGS.path, 33 | data_path=task.data_path, 34 | shapes=task.shapes, 35 | num_threads=1, 36 | uint8_features=task.uint8_features, 37 | num_shards=100) 38 
| observations = [] 39 | actions = [] 40 | rewards = [] 41 | masks = [] 42 | next_observations = [] 43 | 44 | print("Reading the dataset") 45 | for i, sample in tqdm(enumerate(ds)): 46 | obs = gym.spaces.flatten(env.observation_space, 47 | sample.data[0]).astype(np.float32) 48 | action = gym.spaces.flatten(env.action_space, 49 | sample.data[1]).astype(np.float32) 50 | reward = float(sample.data[2].numpy().item()) 51 | mask = float(sample.data[3].numpy().item()) 52 | next_obs = gym.spaces.flatten(env.observation_space, 53 | sample.data[4]).astype(np.float32) 54 | 55 | observations.append(obs) 56 | actions.append(action) 57 | rewards.append(reward) 58 | masks.append(mask) 59 | next_observations.append(next_obs) 60 | 61 | # The datasets are shuffles even in the original files. 62 | # The code below unshuffles them. 63 | kdtree = cKDTree(observations) 64 | _, inds = kdtree.query(next_observations, distance_upper_bound=1e-5) 65 | 66 | kdtree = cKDTree(next_observations) 67 | dists_, _ = kdtree.query(observations, distance_upper_bound=1e-5) 68 | 69 | ordered_observations = [] 70 | ordered_actions = [] 71 | ordered_rewards = [] 72 | ordered_masks = [] 73 | ordered_next_observations = [] 74 | 75 | print("Reordering the dataset") 76 | for i in tqdm(range(len(observations))): 77 | if dists_[i] > 0: 78 | j = i 79 | while j < len(observations): 80 | ordered_observations.append(observations[j]) 81 | ordered_actions.append(actions[j]) 82 | ordered_rewards.append(rewards[j]) 83 | ordered_masks.append(masks[j]) 84 | ordered_next_observations.append(next_observations[j]) 85 | j = inds[j] 86 | 87 | print("Verifying the dataset") 88 | prev_i = -1 89 | 90 | ordered_done_floats = [] 91 | for i in tqdm(range(len(ordered_observations))): 92 | if (i == len(ordered_observations) - 1 93 | or np.linalg.norm(ordered_observations[i + 1] - 94 | ordered_next_observations[i]) > 1e-5): 95 | assert i - prev_i == 1000 96 | prev_i = i 97 | ordered_done_floats.append(1.0) 98 | else: 99 | ordered_done_floats.append(0.0) 100 | 101 | print(f"Dataset size: {len(ordered_observations)}") 102 | 103 | ordered_observations = np.stack(ordered_observations) 104 | ordered_actions = np.stack(ordered_actions) 105 | ordered_rewards = np.stack(ordered_rewards) 106 | ordered_masks = np.stack(ordered_masks) 107 | ordered_done_floats = np.stack(ordered_done_floats) 108 | ordered_next_observations = np.stack(ordered_next_observations) 109 | 110 | save_dir = os.path.join(d4rl.offline_env.DATASET_PATH, 'rl_unplugged') 111 | os.makedirs(save_dir, exist_ok=True) 112 | with open(os.path.join(save_dir, f'{FLAGS.task_name}.npz'), 'wb') as f: 113 | np.savez_compressed(f, 114 | observations=ordered_observations, 115 | actions=ordered_actions, 116 | rewards=ordered_rewards, 117 | masks=ordered_masks, 118 | done_floats=ordered_done_floats, 119 | next_observations=ordered_next_observations) 120 | 121 | 122 | if __name__ == '__main__': 123 | app.run(main) 124 | -------------------------------------------------------------------------------- /jaxrl/agents/awac/awac_learner.py: -------------------------------------------------------------------------------- 1 | """Implementations of algorithms for continuous control.""" 2 | 3 | import functools 4 | from typing import Sequence, Tuple 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | import optax 10 | 11 | import jaxrl.agents.awac.actor as awr_actor 12 | import jaxrl.agents.sac.critic as sac_critic 13 | from jaxrl.datasets import Batch 14 | from jaxrl.networks import critic_net, policies 
15 | from jaxrl.networks.common import InfoDict, Model, PRNGKey 16 | 17 | 18 | @functools.partial(jax.jit, static_argnames=('update_target', 'num_samples')) 19 | def _update_jit( 20 | rng: PRNGKey, actor: Model, critic: Model, target_critic: Model, 21 | batch: Batch, discount: float, tau: float, num_samples: int, 22 | beta: float, 23 | update_target: bool) -> Tuple[PRNGKey, Model, Model, Model, InfoDict]: 24 | 25 | rng, key = jax.random.split(rng) 26 | new_critic, critic_info = sac_critic.update(key, 27 | actor, 28 | critic, 29 | target_critic, 30 | None, 31 | batch, 32 | discount, 33 | soft_critic=False) 34 | if update_target: 35 | new_target_critic = sac_critic.target_update(new_critic, target_critic, 36 | tau) 37 | else: 38 | new_target_critic = target_critic 39 | 40 | rng, key = jax.random.split(rng) 41 | new_actor, actor_info = awr_actor.update(key, actor, new_critic, batch, 42 | num_samples, beta) 43 | 44 | return rng, new_actor, new_critic, new_target_critic, { 45 | **critic_info, 46 | **actor_info 47 | } 48 | 49 | 50 | class AWACLearner(object): 51 | 52 | def __init__(self, 53 | seed: int, 54 | observations: jnp.ndarray, 55 | actions: jnp.ndarray, 56 | actor_optim_kwargs: dict = { 57 | 'learning_rate': 3e-4, 58 | 'weight_decay': 1e-4 59 | }, 60 | actor_hidden_dims: Sequence[int] = (256, 256, 256, 256), 61 | state_dependent_std: bool = False, 62 | critic_lr: float = 3e-4, 63 | critic_hidden_dims: Sequence[int] = (256, 256), 64 | num_samples: int = 1, 65 | discount: float = 0.99, 66 | tau: float = 0.005, 67 | target_update_period: int = 1, 68 | beta: float = 1.0): 69 | 70 | action_dim = actions.shape[-1] 71 | 72 | self.tau = tau 73 | self.target_update_period = target_update_period 74 | self.discount = discount 75 | self.num_samples = num_samples 76 | self.beta = beta 77 | 78 | rng = jax.random.PRNGKey(seed) 79 | rng, actor_key, critic_key = jax.random.split(rng, 3) 80 | 81 | actor_def = policies.NormalTanhPolicy( 82 | actor_hidden_dims, 83 | action_dim, 84 | state_dependent_std=state_dependent_std, 85 | tanh_squash_distribution=False) 86 | actor = Model.create(actor_def, 87 | inputs=[actor_key, observations], 88 | tx=optax.adamw(**actor_optim_kwargs)) 89 | 90 | critic_def = critic_net.DoubleCritic(critic_hidden_dims) 91 | critic = Model.create(critic_def, 92 | inputs=[critic_key, observations, actions], 93 | tx=optax.adam(learning_rate=critic_lr)) 94 | 95 | target_critic = Model.create( 96 | critic_def, inputs=[critic_key, observations, actions]) 97 | 98 | self.actor = actor 99 | self.critic = critic 100 | self.target_critic = target_critic 101 | self.rng = rng 102 | self.step = 1 103 | 104 | def sample_actions(self, 105 | observations: np.ndarray, 106 | temperature: float = 1.0) -> jnp.ndarray: 107 | rng, actions = policies.sample_actions(self.rng, self.actor.apply_fn, 108 | self.actor.params, observations, 109 | temperature) 110 | 111 | self.rng = rng 112 | 113 | actions = np.asarray(actions) 114 | return np.clip(actions, -1, 1) 115 | 116 | def update(self, batch: Batch) -> InfoDict: 117 | self.step += 1 118 | new_rng, new_actor, new_critic, new_target_network, info = _update_jit( 119 | self.rng, self.actor, self.critic, self.target_critic, batch, 120 | self.discount, self.tau, self.num_samples, self.beta, 121 | self.step % self.target_update_period == 0) 122 | 123 | self.rng = new_rng 124 | self.actor = new_actor 125 | self.critic = new_critic 126 | self.target_critic = new_target_network 127 | 128 | return info 129 | 
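A minimal construction sketch for the `AWACLearner` above, using only the signatures visible in this file; the environment choice is an illustrative assumption, and actual training would additionally require batches from one of the offline datasets in `jaxrl/datasets`.

import gym
import numpy as np

from jaxrl.agents.awac.awac_learner import AWACLearner

# 'HalfCheetah-v2' is one of the tasks covered by AWACDataset, but any
# continuous-control env with Box spaces would do for this sketch.
env = gym.make('HalfCheetah-v2')
obs = env.observation_space.sample()[np.newaxis]  # (1, obs_dim) template input
act = env.action_space.sample()[np.newaxis]       # (1, act_dim) template input

agent = AWACLearner(seed=0, observations=obs, actions=act)
action = agent.sample_actions(obs, temperature=0.0)  # greedy action at temperature 0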
-------------------------------------------------------------------------------- /jaxrl/agents/sac_v1/sac_v1_learner.py: -------------------------------------------------------------------------------- 1 | """Implementations of algorithms for continuous control.""" 2 | 3 | import functools 4 | from typing import Optional, Sequence, Tuple 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | import optax 10 | 11 | from jaxrl.agents.sac import temperature 12 | from jaxrl.agents.sac.actor import update as update_actor 13 | from jaxrl.agents.sac.critic import target_update 14 | from jaxrl.agents.sac_v1.critic import update_q, update_v 15 | from jaxrl.datasets import Batch 16 | from jaxrl.networks import critic_net, policies 17 | from jaxrl.networks.common import InfoDict, Model, PRNGKey 18 | 19 | 20 | @functools.partial(jax.jit, static_argnames=('update_target')) 21 | def _update_jit( 22 | rng: PRNGKey, actor: Model, critic: Model, value: Model, 23 | target_value: Model, temp: Model, batch: Batch, discount: float, 24 | tau: float, target_entropy: float, update_target: bool 25 | ) -> Tuple[PRNGKey, Model, Model, Model, Model, Model, InfoDict]: 26 | 27 | new_critic, critic_info = update_q(critic, target_value, batch, discount) 28 | 29 | rng, key = jax.random.split(rng) 30 | new_actor, actor_info = update_actor(key, actor, new_critic, temp, batch) 31 | 32 | rng, key = jax.random.split(rng) 33 | new_value, value_info = update_v(key, new_actor, new_critic, value, temp, 34 | batch, True) 35 | 36 | if update_target: 37 | new_target_value = target_update(new_value, target_value, tau) 38 | else: 39 | new_target_value = target_value 40 | 41 | new_temp, alpha_info = temperature.update(temp, actor_info['entropy'], 42 | target_entropy) 43 | 44 | return rng, new_actor, new_critic, new_value, new_target_value, new_temp, { 45 | **critic_info, 46 | **value_info, 47 | **actor_info, 48 | **alpha_info 49 | } 50 | 51 | 52 | class SACV1Learner(object): 53 | 54 | def __init__(self, 55 | seed: int, 56 | observations: jnp.ndarray, 57 | actions: jnp.ndarray, 58 | actor_lr: float = 3e-4, 59 | value_lr: float = 3e-4, 60 | critic_lr: float = 3e-4, 61 | temp_lr: float = 3e-4, 62 | hidden_dims: Sequence[int] = (256, 256), 63 | discount: float = 0.99, 64 | tau: float = 0.005, 65 | target_update_period: int = 1, 66 | target_entropy: Optional[float] = None, 67 | init_temperature: float = 1.0): 68 | """ 69 | An implementation of the version of Soft-Actor-Critic described in https://arxiv.org/abs/1801.01290 70 | """ 71 | 72 | action_dim = actions.shape[-1] 73 | 74 | if target_entropy is None: 75 | self.target_entropy = -action_dim / 2 76 | else: 77 | self.target_entropy = target_entropy 78 | 79 | self.tau = tau 80 | self.target_update_period = target_update_period 81 | self.discount = discount 82 | 83 | rng = jax.random.PRNGKey(seed) 84 | rng, actor_key, critic_key, temp_key = jax.random.split(rng, 4) 85 | 86 | actor_def = policies.NormalTanhPolicy(hidden_dims, action_dim) 87 | actor = Model.create(actor_def, 88 | inputs=[actor_key, observations], 89 | tx=optax.adam(learning_rate=actor_lr)) 90 | 91 | critic_def = critic_net.DoubleCritic(hidden_dims) 92 | critic = Model.create(critic_def, 93 | inputs=[critic_key, observations, actions], 94 | tx=optax.adam(learning_rate=critic_lr)) 95 | 96 | value_def = critic_net.ValueCritic(hidden_dims) 97 | value = Model.create(value_def, 98 | inputs=[critic_key, observations], 99 | tx=optax.adam(learning_rate=value_lr)) 100 | 101 | target_value = Model.create(value_def, 102 
| inputs=[critic_key, observations]) 103 | 104 | temp = Model.create(temperature.Temperature(init_temperature), 105 | inputs=[temp_key], 106 | tx=optax.adam(learning_rate=temp_lr)) 107 | 108 | self.actor = actor 109 | self.critic = critic 110 | self.value = value 111 | self.target_value = target_value 112 | self.temp = temp 113 | self.rng = rng 114 | 115 | self.step = 1 116 | 117 | def sample_actions(self, 118 | observations: np.ndarray, 119 | temperature: float = 1.0) -> jnp.ndarray: 120 | rng, actions = policies.sample_actions(self.rng, self.actor.apply_fn, 121 | self.actor.params, observations, 122 | temperature) 123 | self.rng = rng 124 | 125 | actions = np.asarray(actions) 126 | return np.clip(actions, -1, 1) 127 | 128 | def update(self, batch: Batch) -> InfoDict: 129 | self.step += 1 130 | 131 | new_rng, new_actor, new_critic, new_value, new_target_value, new_temp, info = _update_jit( 132 | self.rng, self.actor, self.critic, self.value, self.target_value, 133 | self.temp, batch, self.discount, self.tau, self.target_entropy, 134 | self.step % self.target_update_period == 0) 135 | 136 | self.rng = new_rng 137 | self.actor = new_actor 138 | self.critic = new_critic 139 | self.value = new_value 140 | self.target_value = new_target_value 141 | self.temp = new_temp 142 | 143 | return info 144 | -------------------------------------------------------------------------------- /jaxrl/agents/drq/drq_learner.py: -------------------------------------------------------------------------------- 1 | """Implementations of algorithms for continuous control.""" 2 | 3 | import functools 4 | from typing import Optional, Sequence, Tuple 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | import optax 10 | 11 | from jaxrl.agents.drq.augmentations import batched_random_crop 12 | from jaxrl.agents.drq.networks import DrQDoubleCritic, DrQPolicy 13 | from jaxrl.agents.sac import temperature 14 | from jaxrl.agents.sac.actor import update as update_actor 15 | from jaxrl.agents.sac.critic import target_update 16 | from jaxrl.agents.sac.critic import update as update_critic 17 | from jaxrl.datasets import Batch 18 | from jaxrl.networks import policies 19 | from jaxrl.networks.common import InfoDict, Model, PRNGKey 20 | 21 | 22 | @functools.partial(jax.jit, static_argnames=('update_target')) 23 | def _update_jit( 24 | rng: PRNGKey, actor: Model, critic: Model, target_critic: Model, 25 | temp: Model, batch: Batch, discount: float, tau: float, 26 | target_entropy: float, update_target: bool 27 | ) -> Tuple[PRNGKey, Model, Model, Model, Model, InfoDict]: 28 | 29 | rng, key = jax.random.split(rng) 30 | observations = batched_random_crop(key, batch.observations) 31 | rng, key = jax.random.split(rng) 32 | next_observations = batched_random_crop(key, batch.next_observations) 33 | 34 | batch = batch._replace(observations=observations, 35 | next_observations=next_observations) 36 | 37 | rng, key = jax.random.split(rng) 38 | new_critic, critic_info = update_critic(key, 39 | actor, 40 | critic, 41 | target_critic, 42 | temp, 43 | batch, 44 | discount, 45 | backup_entropy=True) 46 | if update_target: 47 | new_target_critic = target_update(new_critic, target_critic, tau) 48 | else: 49 | new_target_critic = target_critic 50 | 51 | # Use critic conv layers in actor: 52 | new_actor_params = actor.params.copy( 53 | add_or_replace={'SharedEncoder': new_critic.params['SharedEncoder']}) 54 | actor = actor.replace(params=new_actor_params) 55 | 56 | rng, key = jax.random.split(rng) 57 | new_actor, actor_info = 
update_actor(key, actor, new_critic, temp, batch) 58 | new_temp, alpha_info = temperature.update(temp, actor_info['entropy'], 59 | target_entropy) 60 | 61 | return rng, new_actor, new_critic, new_target_critic, new_temp, { 62 | **critic_info, 63 | **actor_info, 64 | **alpha_info 65 | } 66 | 67 | 68 | class DrQLearner(object): 69 | 70 | def __init__(self, 71 | seed: int, 72 | observations: jnp.ndarray, 73 | actions: jnp.ndarray, 74 | actor_lr: float = 3e-4, 75 | critic_lr: float = 3e-4, 76 | temp_lr: float = 3e-4, 77 | hidden_dims: Sequence[int] = (256, 256), 78 | cnn_features: Sequence[int] = (32, 32, 32, 32), 79 | cnn_strides: Sequence[int] = (2, 1, 1, 1), 80 | cnn_padding: str = 'VALID', 81 | latent_dim: int = 50, 82 | discount: float = 0.99, 83 | tau: float = 0.005, 84 | target_update_period: int = 1, 85 | target_entropy: Optional[float] = None, 86 | init_temperature: float = 0.1): 87 | 88 | action_dim = actions.shape[-1] 89 | 90 | if target_entropy is None: 91 | self.target_entropy = -action_dim 92 | else: 93 | self.target_entropy = target_entropy 94 | 95 | self.tau = tau 96 | self.target_update_period = target_update_period 97 | self.discount = discount 98 | 99 | rng = jax.random.PRNGKey(seed) 100 | rng, actor_key, critic_key, temp_key = jax.random.split(rng, 4) 101 | 102 | actor_def = DrQPolicy(hidden_dims, action_dim, cnn_features, 103 | cnn_strides, cnn_padding, latent_dim) 104 | actor = Model.create(actor_def, 105 | inputs=[actor_key, observations], 106 | tx=optax.adam(learning_rate=actor_lr)) 107 | 108 | critic_def = DrQDoubleCritic(hidden_dims, cnn_features, cnn_strides, 109 | cnn_padding, latent_dim) 110 | critic = Model.create(critic_def, 111 | inputs=[critic_key, observations, actions], 112 | tx=optax.adam(learning_rate=critic_lr)) 113 | target_critic = Model.create( 114 | critic_def, inputs=[critic_key, observations, actions]) 115 | 116 | temp = Model.create(temperature.Temperature(init_temperature), 117 | inputs=[temp_key], 118 | tx=optax.adam(learning_rate=temp_lr)) 119 | 120 | self.actor = actor 121 | self.critic = critic 122 | self.target_critic = target_critic 123 | self.temp = temp 124 | self.rng = rng 125 | self.step = 0 126 | 127 | def sample_actions(self, 128 | observations: np.ndarray, 129 | temperature: float = 1.0) -> jnp.ndarray: 130 | rng, actions = policies.sample_actions(self.rng, self.actor.apply_fn, 131 | self.actor.params, observations, 132 | temperature) 133 | 134 | self.rng = rng 135 | 136 | actions = np.asarray(actions) 137 | return np.clip(actions, -1, 1) 138 | 139 | def update(self, batch: Batch) -> InfoDict: 140 | self.step += 1 141 | new_rng, new_actor, new_critic, new_target_critic, new_temp, info = _update_jit( 142 | self.rng, self.actor, self.critic, self.target_critic, self.temp, 143 | batch, self.discount, self.tau, self.target_entropy, 144 | self.step % self.target_update_period == 0) 145 | 146 | self.rng = new_rng 147 | self.actor = new_actor 148 | self.critic = new_critic 149 | self.target_critic = new_target_critic 150 | self.temp = new_temp 151 | 152 | return info 153 | -------------------------------------------------------------------------------- /continual_world.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import random 3 | import gym 4 | import metaworld 5 | import numpy as np 6 | from gym.wrappers import TimeLimit 7 | 8 | from jaxrl import wrappers 9 | 10 | from jaxrl.wrappers.normalization import RescaleReward 11 | 12 | def get_mt50() -> metaworld.MT50: 13 | 
saved_random_state = np.random.get_state() 14 | np.random.seed(999) 15 | random.seed(999) 16 | MT50 = metaworld.MT50() 17 | np.random.set_state(saved_random_state) 18 | return MT50 19 | 20 | TASK_SEQS = { 21 | "cw10": [ 22 | {'task': "hammer-v1", 'hint': 'Hammer a screw on the wall.'}, 23 | {'task': "push-wall-v1", 'hint': 'Bypass a wall and push a puck to a goal.'}, 24 | {'task': "faucet-close-v1", 'hint': 'Rotate the faucet clockwise.'}, 25 | {'task': "push-back-v1", 'hint': 'Pull a puck to a goal.'}, 26 | {'task': "stick-pull-v1", 'hint': 'Grasp a stick and pull a box with the stick.'}, 27 | {'task': "handle-press-side-v1", 'hint': 'Press a handle down sideways.'}, 28 | {'task': "push-v1", 'hint': 'Push the puck to a goal.'}, 29 | {'task': "shelf-place-v1", 'hint': 'Pick and place a puck onto a shelf.'}, 30 | {'task': "window-close-v1", 'hint': 'Push and close a window.'}, 31 | {'task': "peg-unplug-side-v1", 'hint': 'Unplug a peg sideways.'}, 32 | ], 33 | "cw1-hammer": [ 34 | "hammer-v1" 35 | ], 36 | "cw1-push-back": [ 37 | "push-back-v1" 38 | ], 39 | "cw1-push": [ 40 | "push-v1" 41 | ], 42 | "cw2-test": [ 43 | {'task': "push-wall-v1", 'hint': 'Bypass a wall and push a puck to a goal.'}, 44 | {'task': "hammer-v1", 'hint': 'Hammer a screw on the wall.'}, 45 | ], 46 | "cw2-ab-coffee-button": [ 47 | {'task': "hammer-v1", 'hint': 'Hammer a screw on the wall.'}, 48 | {'task': "coffee-button-v1", 'hint': 'Push a button on the coffee machine.'} 49 | ], 50 | "cw2-ab-handle-press": [ 51 | {'task': "hammer-v1", 'hint': 'Hammer a screw on the wall.'}, 52 | {'task': "handle-press-v1", 'hint': 'Press a handle down.'} 53 | ], 54 | "cw2-ab-window-open": [ 55 | {'task': "hammer-v1", 'hint': 'Hammer a screw on the wall.'}, 56 | {'task': "window-open-v1", 'hint': 'Push and open a window.'} 57 | ], 58 | "cw2-ab-reach": [ 59 | {'task': "hammer-v1", 'hint': 'Hammer a screw on the wall.'}, 60 | {'task': "reach-v1", 'hint': 'Reach a goal position.'} 61 | ], 62 | "cw2-ab-button-press": [ 63 | {'task': "hammer-v1", 'hint': 'Hammer a screw on the wall.'}, 64 | {'task': "button-press-v1", 'hint': 'Press a button.'} 65 | ], 66 | "cw3-test": [ 67 | {'task': "stick-pull-v1", 'hint': 'Grasp a stick and pull a box with the stick.'}, 68 | {'task': "push-back-v1", 'hint': 'Pull a puck to a goal.'}, 69 | {'task': "shelf-place-v1", 'hint': 'Pick and place a puck onto a shelf.'}, 70 | ] 71 | } 72 | 73 | TASK_SEQS["cw20"] = TASK_SEQS["cw10"] + TASK_SEQS["cw10"] 74 | META_WORLD_TIME_HORIZON = 200 75 | MT50 = get_mt50() 76 | 77 | class RandomizationWrapper(gym.Wrapper): 78 | """Manages randomization settings in MetaWorld environments.""" 79 | 80 | ALLOWED_KINDS = [ 81 | "deterministic", 82 | "random_init_all", 83 | "random_init_fixed20", 84 | "random_init_small_box", 85 | ] 86 | 87 | def __init__(self, env: gym.Env, subtasks: List[metaworld.Task], kind: str) -> None: 88 | assert kind in RandomizationWrapper.ALLOWED_KINDS 89 | super().__init__(env) 90 | self.subtasks = subtasks 91 | self.kind = kind 92 | 93 | env.set_task(subtasks[0]) 94 | if kind == "random_init_all": 95 | env._freeze_rand_vec = False 96 | env.seeded_rand_vec = True 97 | 98 | if kind == "random_init_fixed20": 99 | assert len(subtasks) >= 20 100 | 101 | if kind == "random_init_small_box": 102 | diff = env._random_reset_space.high - env._random_reset_space.low 103 | self.reset_space_low = env._random_reset_space.low + 0.45 * diff 104 | self.reset_space_high = env._random_reset_space.low + 0.55 * diff 105 | 106 | def reset(self, **kwargs) -> np.ndarray: 107 | if 
self.kind == "random_init_fixed20": 108 | self.env.set_task(self.subtasks[random.randint(0, 19)]) 109 | elif self.kind == "random_init_small_box": 110 | rand_vec = np.random.uniform( 111 | self.reset_space_low, self.reset_space_high, size=self.reset_space_low.size 112 | ) 113 | self.env._last_rand_vec = rand_vec 114 | 115 | return self.env.reset(**kwargs) 116 | 117 | 118 | def get_subtasks(name: str) -> List[metaworld.Task]: 119 | return [s for s in MT50.train_tasks if s.env_name == name] 120 | 121 | 122 | def get_single_env( 123 | name, seed, 124 | randomization="random_init_all", 125 | add_episode_monitor=True, 126 | normalize_reward=False 127 | ): 128 | if name == "HalfCheetah-v3" or name == "Ant-v3": 129 | env = gym.make(name) 130 | env.seed(seed) 131 | env.action_space.seed(seed) 132 | env.observation_space.seed(seed) 133 | else: 134 | env = MT50.train_classes[name]() 135 | env.seed(seed) 136 | env = RandomizationWrapper(env, get_subtasks(name), randomization) 137 | env.name = name 138 | env = TimeLimit(env, META_WORLD_TIME_HORIZON) 139 | env = gym.wrappers.ClipAction(env) 140 | if normalize_reward: 141 | env = RescaleReward(env, reward_scale=1.0 / META_WORLD_TIME_HORIZON) 142 | if add_episode_monitor: 143 | env = wrappers.EpisodeMonitor(env) 144 | return env 145 | 146 | 147 | if __name__ == "__main__": 148 | import time 149 | 150 | # def print_reward(env: gym.Env): 151 | # obs, done = env.reset(), False 152 | # i = 0 153 | # while not done: 154 | # i += 1 155 | # next_obs, rew, done, _ = env.step(env.action_space.sample()) 156 | # print(i, rew) 157 | 158 | # tasks_list = TASK_SEQS["cw1-push"] 159 | # env = get_single_env(tasks_list[0], 1, "deterministic", normalize_reward=False) 160 | # env_normalized = get_single_env(tasks_list[0], 1, "deterministic", normalize_reward=True) 161 | 162 | # print_reward(env) 163 | # print_reward(env_normalized) 164 | 165 | tasks_list = TASK_SEQS["cw1-push"] 166 | s = time.time() 167 | env = get_single_env(tasks_list[0], 1, "random_init_all") 168 | print(time.time() - s) 169 | s = time.time() 170 | env = get_single_env(tasks_list[0], 1, "random_init_all") 171 | print(time.time() - s) 172 | 173 | o = env.reset() 174 | _, _, _, _ = env.step(np.array([np.nan, 1.0, -1.0, 0.0])) 175 | o_new = env.reset() 176 | print(o) 177 | print(o_new) 178 | -------------------------------------------------------------------------------- /jaxrl/agents/redq/redq_learner.py: -------------------------------------------------------------------------------- 1 | """Implementations of RedQ. 
2 | https://arxiv.org/abs/2101.05982 3 | """ 4 | 5 | import functools 6 | from typing import Optional, Sequence, Tuple 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | import numpy as np 11 | import optax 12 | 13 | from jaxrl.agents.redq.actor import update as update_actor 14 | from jaxrl.agents.redq.critic import target_update 15 | from jaxrl.agents.redq.critic import update as update_critic 16 | from jaxrl.agents.sac import temperature 17 | from jaxrl.datasets import Batch 18 | from jaxrl.networks import critic_net, policies 19 | from jaxrl.networks.common import InfoDict, Model, PRNGKey 20 | 21 | 22 | @functools.partial(jax.jit, 23 | static_argnames=('backup_entropy', 'n', 'm', 24 | 'update_target', 'update_policy')) 25 | def _update_jit( 26 | rng: PRNGKey, actor: Model, critic: Model, target_critic: Model, 27 | temp: Model, batch: Batch, discount: float, tau: float, 28 | target_entropy: float, backup_entropy: bool, n: int, m: int, 29 | update_target: bool, update_policy: bool 30 | ) -> Tuple[PRNGKey, Model, Model, Model, Model, InfoDict]: 31 | 32 | rng, key = jax.random.split(rng) 33 | new_critic, critic_info = update_critic(key, 34 | actor, 35 | critic, 36 | target_critic, 37 | temp, 38 | batch, 39 | discount, 40 | backup_entropy=backup_entropy, 41 | n=n, 42 | m=m) 43 | if update_target: 44 | new_target_critic = target_update(new_critic, target_critic, tau) 45 | else: 46 | new_target_critic = target_critic 47 | 48 | if update_policy: 49 | rng, key = jax.random.split(rng) 50 | new_actor, actor_info = update_actor(key, actor, new_critic, temp, 51 | batch) 52 | new_temp, alpha_info = temperature.update(temp, actor_info['entropy'], 53 | target_entropy) 54 | else: 55 | new_actor, actor_info = actor, {} 56 | new_temp, alpha_info = temp, {} 57 | 58 | return rng, new_actor, new_critic, new_target_critic, new_temp, { 59 | **critic_info, 60 | **actor_info, 61 | **alpha_info 62 | } 63 | 64 | 65 | class REDQLearner(object): 66 | 67 | def __init__( 68 | self, 69 | seed: int, 70 | observations: jnp.ndarray, 71 | actions: jnp.ndarray, 72 | actor_lr: float = 3e-4, 73 | critic_lr: float = 3e-4, 74 | temp_lr: float = 3e-4, 75 | n: int = 10, # Number of critics. 76 | m: int = 2, # Nets to use for critic backups. 77 | policy_update_delay: int = 20, # See the original implementation. 
78 | hidden_dims: Sequence[int] = (256, 256), 79 | discount: float = 0.99, 80 | tau: float = 0.005, 81 | target_update_period: int = 1, 82 | target_entropy: Optional[float] = None, 83 | backup_entropy: bool = True, 84 | init_temperature: float = 1.0, 85 | init_mean: Optional[np.ndarray] = None, 86 | policy_final_fc_init_scale: float = 1.0): 87 | """ 88 | An implementation of Randomized Ensembled Double Q-Learning (REDQ, https://arxiv.org/abs/2101.05982), built on the SAC variant described in https://arxiv.org/abs/1812.05905. 89 | """ 90 | 91 | action_dim = actions.shape[-1] 92 | 93 | if target_entropy is None: 94 | self.target_entropy = -action_dim / 2 95 | else: 96 | self.target_entropy = target_entropy 97 | 98 | self.backup_entropy = backup_entropy 99 | self.n = n 100 | self.m = m 101 | self.policy_update_delay = policy_update_delay 102 | 103 | self.tau = tau 104 | self.target_update_period = target_update_period 105 | self.discount = discount 106 | 107 | rng = jax.random.PRNGKey(seed) 108 | rng, actor_key, critic_key, temp_key = jax.random.split(rng, 4) 109 | actor_def = policies.NormalTanhPolicy( 110 | hidden_dims, 111 | action_dim, 112 | init_mean=init_mean, 113 | final_fc_init_scale=policy_final_fc_init_scale) 114 | actor = Model.create(actor_def, 115 | inputs=[actor_key, observations], 116 | tx=optax.adam(learning_rate=actor_lr)) 117 | 118 | critic_def = critic_net.DoubleCritic(hidden_dims, num_qs=n) 119 | critic = Model.create(critic_def, 120 | inputs=[critic_key, observations, actions], 121 | tx=optax.adam(learning_rate=critic_lr)) 122 | target_critic = Model.create( 123 | critic_def, inputs=[critic_key, observations, actions]) 124 | 125 | temp = Model.create(temperature.Temperature(init_temperature), 126 | inputs=[temp_key], 127 | tx=optax.adam(learning_rate=temp_lr)) 128 | 129 | self.actor = actor 130 | self.critic = critic 131 | self.target_critic = target_critic 132 | self.temp = temp 133 | self.rng = rng 134 | 135 | self.step = 0 136 | 137 | def sample_actions(self, 138 | observations: np.ndarray, 139 | temperature: float = 1.0) -> jnp.ndarray: 140 | rng, actions = policies.sample_actions(self.rng, self.actor.apply_fn, 141 | self.actor.params, observations, 142 | temperature) 143 | self.rng = rng 144 | 145 | actions = np.asarray(actions) 146 | return np.clip(actions, -1, 1) 147 | 148 | def update(self, batch: Batch) -> InfoDict: 149 | self.step += 1 150 | 151 | new_rng, new_actor, new_critic, new_target_critic, new_temp, info = _update_jit( 152 | self.rng, 153 | self.actor, 154 | self.critic, 155 | self.target_critic, 156 | self.temp, 157 | batch, 158 | self.discount, 159 | self.tau, 160 | self.target_entropy, 161 | self.backup_entropy, 162 | self.n, 163 | self.m, 164 | update_target=self.step % self.target_update_period == 0, 165 | update_policy=self.step % self.policy_update_delay == 0) 166 | 167 | self.rng = new_rng 168 | self.actor = new_actor 169 | self.critic = new_critic 170 | self.target_critic = new_target_critic 171 | self.temp = new_temp 172 | 173 | return info 174 | -------------------------------------------------------------------------------- /train_tadell.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | 4 | import numpy as np 5 | import tqdm 6 | import wandb 7 | import yaml 8 | from absl import app, flags 9 | from ml_collections import config_flags, ConfigDict 10 | 11 | from jaxrl.datasets import ReplayBuffer 12 | from jaxrl.evaluation import evaluate_cl 13 | from jaxrl.utils import Logger 14 | 15 | from jaxrl.agents.sac.sac_learner import TaDeLL 16 | from
continual_world import TASK_SEQS, get_single_env 17 | 18 | FLAGS = flags.FLAGS 19 | 20 | flags.DEFINE_string('env_name', "cw2-test", 'Environment name.') 21 | flags.DEFINE_string('save_dir', '/home/yijunyan/Data/PyCode/CoTASP/logs', 'Tensorboard logging dir.') 22 | flags.DEFINE_integer('seed', 7409, 'Random seed.') 23 | flags.DEFINE_string('base_algo', 'tadell', 'base learning algorithm') 24 | 25 | flags.DEFINE_boolean('save_checkpoint', False, 'Save learned dictionary') 26 | 27 | flags.DEFINE_string('env_type', 'random_init_all', 'The type of env is either deterministic or random_init_all') 28 | flags.DEFINE_boolean('normalize_reward', False, 'Normalize rewards') 29 | flags.DEFINE_integer('eval_episodes', 10, 'Number of episodes used for evaluation.') 30 | flags.DEFINE_integer('log_interval', 1000, 'Logging interval.') 31 | flags.DEFINE_integer('eval_interval', 20000, 'Eval interval.') 32 | flags.DEFINE_integer('batch_size', 256, 'Mini batch size.') 33 | flags.DEFINE_integer('updates_per_step', 1, 'Gradient updates per step.') 34 | flags.DEFINE_integer('max_steps', int(2e4), 'Number of training steps for each task') 35 | flags.DEFINE_integer('start_training', int(1e4), 'Number of training steps to start training.') 36 | 37 | flags.DEFINE_integer('buffer_size', int(1e6), 'Size of replay buffer') 38 | 39 | flags.DEFINE_boolean('tqdm', False, 'Use tqdm progress bar.') 40 | flags.DEFINE_string('wandb_mode', 'disabled', 'Track experiments with Weights and Biases.') 41 | flags.DEFINE_string('wandb_project_name', "jaxrl_tadell", "The wandb's project name.") 42 | flags.DEFINE_string('wandb_entity', None, "the entity (team) of wandb's project") 43 | # YAML file path to tadell's hyperparameter configuration 44 | with open('configs/sac_tadell.yaml', 'r') as file: 45 | yaml_dict = yaml.unsafe_load(file) 46 | config_flags.DEFINE_config_dict( 47 | 'config', 48 | ConfigDict(yaml_dict), 49 | 'Training hyperparameter configuration.', 50 | lock_config=False 51 | ) 52 | 53 | 54 | def main(_): 55 | # config tasks 56 | seq_tasks = TASK_SEQS[FLAGS.env_name] 57 | 58 | kwargs = dict(FLAGS.config) 59 | algo = FLAGS.base_algo 60 | run_name = f"{FLAGS.env_name}__{algo}__{FLAGS.seed}__{int(time.time())}" 61 | 62 | if FLAGS.save_checkpoint: 63 | save_dict_dir = f"logs/saved_dicts/{run_name}.pkl" 64 | else: 65 | save_dict_dir = None 66 | 67 | wandb.init( 68 | project=FLAGS.wandb_project_name, 69 | entity=FLAGS.wandb_entity, 70 | sync_tensorboard=True, 71 | config=FLAGS, 72 | name=run_name, 73 | monitor_gym=False, 74 | save_code=False, 75 | mode=FLAGS.wandb_mode, 76 | dir=FLAGS.save_dir 77 | ) 78 | wandb.config.update({"algo": algo}) 79 | 80 | log = Logger(wandb.run.dir) 81 | 82 | # random numpy seeding 83 | np.random.seed(FLAGS.seed) 84 | random.seed(FLAGS.seed) 85 | 86 | # initialize SAC agent 87 | temp_env = get_single_env( 88 | TASK_SEQS[FLAGS.env_name][0]['task'], FLAGS.seed, 89 | randomization=FLAGS.env_type) 90 | if algo == 'tadell': 91 | agent = TaDeLL( 92 | FLAGS.seed, 93 | temp_env.observation_space.sample()[np.newaxis], 94 | temp_env.action_space.sample()[np.newaxis], 95 | **kwargs) 96 | del temp_env 97 | else: 98 | raise NotImplementedError() 99 | 100 | # continual learning loop 101 | eval_envs = [] 102 | for task_idx, dict_task in enumerate(seq_tasks): 103 | eval_envs.append(get_single_env(dict_task['task'], FLAGS.seed, randomization=FLAGS.env_type)) 104 | 105 | # continual learning loop 106 | total_env_steps = 0 107 | for task_idx, dict_task in enumerate(seq_tasks): 108 | print(f'Learning on task 
{task_idx+1}: {dict_task["task"]} for {FLAGS.max_steps} steps') 109 | 110 | ''' 111 | Learning subroutine for the current task 112 | ''' 113 | # start the current task 114 | agent.start_task(dict_task["hint"]) 115 | 116 | # set continual world environment 117 | env = get_single_env( 118 | dict_task['task'], FLAGS.seed, randomization=FLAGS.env_type, 119 | normalize_reward=FLAGS.normalize_reward 120 | ) 121 | 122 | # reset replay buffer 123 | replay_buffer = ReplayBuffer( 124 | env.observation_space, env.action_space, FLAGS.buffer_size or FLAGS.max_steps 125 | ) 126 | 127 | observation, done = env.reset(), False 128 | for i in tqdm.tqdm(range(1, FLAGS.max_steps + 1), 129 | smoothing=0.1, 130 | disable=not FLAGS.tqdm): 131 | if i < FLAGS.start_training: 132 | action = env.action_space.sample() 133 | else: 134 | action = agent.sample_actions(observation[np.newaxis]) 135 | action = np.asarray(action, dtype=np.float32).flatten() 136 | next_observation, reward, done, info = env.step(action) 137 | # counting total environment step 138 | total_env_steps += 1 139 | 140 | if not done or 'TimeLimit.truncated' in info: 141 | mask = 1.0 142 | else: 143 | mask = 0.0 144 | 145 | # only for meta-world 146 | assert mask == 1.0 147 | 148 | replay_buffer.insert(observation, action, reward, mask, float(done), 149 | next_observation) 150 | observation = next_observation 151 | 152 | if done: 153 | observation, done = env.reset(), False 154 | for k, v in info['episode'].items(): 155 | wandb.log({f'training/{k}': v, 'global_steps': total_env_steps}) 156 | 157 | if (i >= FLAGS.start_training) and (i % FLAGS.updates_per_step == 0): 158 | for _ in range(FLAGS.updates_per_step): 159 | batch = replay_buffer.sample(FLAGS.batch_size) 160 | update_info = agent.update(batch) 161 | 162 | if i % FLAGS.log_interval == 0: 163 | for k, v in update_info.items(): 164 | wandb.log({f'training/{k}': v, 'global_steps': total_env_steps}) 165 | 166 | if i % FLAGS.eval_interval == 0: 167 | eval_stats = evaluate_cl(agent, eval_envs, FLAGS.eval_episodes, tadell=True) 168 | 169 | for k, v in eval_stats.items(): 170 | wandb.log({f'evaluation/average_{k}s': v, 'global_steps': total_env_steps}) 171 | 172 | # Update the log with collected data 173 | eval_stats['cl_method'] = algo 174 | eval_stats['x'] = total_env_steps 175 | eval_stats['steps_per_task'] = FLAGS.max_steps 176 | log.update(eval_stats) 177 | 178 | ''' 179 | Updating miscellaneous things 180 | ''' 181 | print('End the current task') 182 | agent.end_task(save_dict_dir) 183 | 184 | # save log data 185 | log.save() 186 | 187 | if __name__ == '__main__': 188 | app.run(main) -------------------------------------------------------------------------------- /jaxrl/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | Batch = collections.namedtuple( 8 | 'Batch', 9 | ['observations', 'actions', 'rewards', 'masks', 'next_observations']) 10 | 11 | 12 | def split_into_trajectories(observations, actions, rewards, masks, dones_float, 13 | next_observations): 14 | trajs = [[]] 15 | 16 | for i in tqdm(range(len(observations))): 17 | trajs[-1].append((observations[i], actions[i], rewards[i], masks[i], 18 | dones_float[i], next_observations[i])) 19 | if dones_float[i] == 1.0 and i + 1 < len(observations): 20 | trajs.append([]) 21 | 22 | return trajs 23 | 24 | 25 | def merge_trajectories(trajs): 26 | observations = [] 27 | actions = [] 
28 | rewards = [] 29 | masks = [] 30 | dones_float = [] 31 | next_observations = [] 32 | 33 | for traj in trajs: 34 | for (obs, act, rew, mask, done, next_obs) in traj: 35 | observations.append(obs) 36 | actions.append(act) 37 | rewards.append(rew) 38 | masks.append(mask) 39 | dones_float.append(done) 40 | next_observations.append(next_obs) 41 | 42 | return np.stack(observations), np.stack(actions), np.stack( 43 | rewards), np.stack(masks), np.stack(dones_float), np.stack( 44 | next_observations) 45 | 46 | 47 | class Dataset(object): 48 | 49 | def __init__(self, observations: np.ndarray, actions: np.ndarray, 50 | rewards: np.ndarray, masks: np.ndarray, 51 | dones_float: np.ndarray, next_observations: np.ndarray, 52 | size: int): 53 | self.observations = observations 54 | self.actions = actions 55 | self.rewards = rewards 56 | self.masks = masks 57 | self.dones_float = dones_float 58 | self.next_observations = next_observations 59 | self.size = size 60 | 61 | def sample(self, batch_size: int) -> Batch: 62 | indx = np.random.randint(self.size, size=batch_size) 63 | return Batch(observations=self.observations[indx], 64 | actions=self.actions[indx], 65 | rewards=self.rewards[indx], 66 | masks=self.masks[indx], 67 | next_observations=self.next_observations[indx]) 68 | 69 | def get_initial_states( 70 | self, 71 | and_action: bool = False 72 | ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: 73 | states = [] 74 | if and_action: 75 | actions = [] 76 | trajs = split_into_trajectories(self.observations, self.actions, 77 | self.rewards, self.masks, 78 | self.dones_float, 79 | self.next_observations) 80 | 81 | def compute_returns(traj): 82 | episode_return = 0 83 | for _, _, rew, _, _, _ in traj: 84 | episode_return += rew 85 | 86 | return episode_return 87 | 88 | trajs.sort(key=compute_returns) 89 | 90 | for traj in trajs: 91 | states.append(traj[0][0]) 92 | if and_action: 93 | actions.append(traj[0][1]) 94 | 95 | states = np.stack(states, 0) 96 | if and_action: 97 | actions = np.stack(actions, 0) 98 | return states, actions 99 | else: 100 | return states 101 | 102 | def get_monte_carlo_returns(self, discount) -> np.ndarray: 103 | trajs = split_into_trajectories(self.observations, self.actions, 104 | self.rewards, self.masks, 105 | self.dones_float, 106 | self.next_observations) 107 | mc_returns = [] 108 | for traj in trajs: 109 | mc_return = 0.0 110 | for i, (_, _, reward, _, _, _) in enumerate(traj): 111 | mc_return += reward * (discount**i) 112 | mc_returns.append(mc_return) 113 | 114 | return np.asarray(mc_returns) 115 | 116 | def take_top(self, percentile: float = 100.0): 117 | assert percentile > 0.0 and percentile <= 100.0 118 | 119 | trajs = split_into_trajectories(self.observations, self.actions, 120 | self.rewards, self.masks, 121 | self.dones_float, 122 | self.next_observations) 123 | 124 | def compute_returns(traj): 125 | episode_return = 0 126 | for _, _, rew, _, _, _ in traj: 127 | episode_return += rew 128 | 129 | return episode_return 130 | 131 | trajs.sort(key=compute_returns) 132 | 133 | N = int(len(trajs) * percentile / 100) 134 | N = max(1, N) 135 | 136 | trajs = trajs[-N:] 137 | 138 | (self.observations, self.actions, self.rewards, self.masks, 139 | self.dones_float, self.next_observations) = merge_trajectories(trajs) 140 | 141 | self.size = len(self.observations) 142 | 143 | def take_random(self, percentage: float = 100.0): 144 | assert percentage > 0.0 and percentage <= 100.0 145 | 146 | trajs = split_into_trajectories(self.observations, self.actions, 147 | self.rewards, 
self.masks, 148 | self.dones_float, 149 | self.next_observations) 150 | np.random.shuffle(trajs) 151 | 152 | N = int(len(trajs) * percentage / 100) 153 | N = max(1, N) 154 | 155 | trajs = trajs[-N:] 156 | 157 | (self.observations, self.actions, self.rewards, self.masks, 158 | self.dones_float, self.next_observations) = merge_trajectories(trajs) 159 | 160 | self.size = len(self.observations) 161 | 162 | def train_validation_split(self, 163 | train_fraction: float = 0.8 164 | ) -> Tuple['Dataset', 'Dataset']: 165 | trajs = split_into_trajectories(self.observations, self.actions, 166 | self.rewards, self.masks, 167 | self.dones_float, 168 | self.next_observations) 169 | train_size = int(train_fraction * len(trajs)) 170 | 171 | np.random.shuffle(trajs) 172 | 173 | (train_observations, train_actions, train_rewards, train_masks, 174 | train_dones_float, 175 | train_next_observations) = merge_trajectories(trajs[:train_size]) 176 | 177 | (valid_observations, valid_actions, valid_rewards, valid_masks, 178 | valid_dones_float, 179 | valid_next_observations) = merge_trajectories(trajs[train_size:]) 180 | 181 | train_dataset = Dataset(train_observations, 182 | train_actions, 183 | train_rewards, 184 | train_masks, 185 | train_dones_float, 186 | train_next_observations, 187 | size=len(train_observations)) 188 | valid_dataset = Dataset(valid_observations, 189 | valid_actions, 190 | valid_rewards, 191 | valid_masks, 192 | valid_dones_float, 193 | valid_next_observations, 194 | size=len(valid_observations)) 195 | 196 | return train_dataset, valid_dataset 197 | -------------------------------------------------------------------------------- /train_cotasp.py: -------------------------------------------------------------------------------- 1 | ''' 2 | CONTINUAL TASK ALLOCATION IN META-POLICY NETWORK VIA SPARSE PROMPTING 3 | ''' 4 | 5 | import itertools 6 | import random 7 | import time 8 | 9 | import numpy as np 10 | import wandb 11 | import yaml 12 | from absl import app, flags 13 | from ml_collections import config_flags, ConfigDict 14 | 15 | from jaxrl.datasets import ReplayBuffer 16 | from jaxrl.evaluation import evaluate_cl 17 | from jaxrl.utils import Logger 18 | from jaxrl.agents.sac.sac_learner import CoTASPLearner 19 | from continual_world import TASK_SEQS, get_single_env 20 | 21 | FLAGS = flags.FLAGS 22 | flags.DEFINE_string('env_name', 'cw20', 'Environment name.') 23 | flags.DEFINE_integer('seed', 110, 'Random seed.') 24 | flags.DEFINE_string('base_algo', 'cotasp', 'base learning algorithm') 25 | 26 | flags.DEFINE_string('env_type', 'random_init_all', 'The type of env is either deterministic or random_init_all') 27 | flags.DEFINE_boolean('normalize_reward', True, 'Normalize rewards') 28 | flags.DEFINE_integer('eval_episodes', 10, 'Number of episodes used for evaluation.') 29 | flags.DEFINE_integer('log_interval', 200, 'Logging interval.') 30 | flags.DEFINE_integer('eval_interval', 20000, 'Eval interval.') 31 | flags.DEFINE_integer('batch_size', 256, 'Mini batch size.') 32 | flags.DEFINE_integer('updates_per_step', 1, 'Gradient updating per # environment steps.') 33 | flags.DEFINE_integer('buffer_size', int(1e6), 'Size of replay buffer') 34 | flags.DEFINE_integer('max_step', int(1e6), 'Number of training steps for each task') 35 | flags.DEFINE_integer('start_training', int(1e4), 'Number of training steps to start training.') 36 | flags.DEFINE_integer('theta_step', int(990), 'Number of training steps for theta.') 37 | flags.DEFINE_integer('alpha_step', int(10), 'Number of finetune steps for alpha.') 
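# Note: theta_step and alpha_step drive the alternating update schedule built later in
# main() as itertools.cycle([False] * theta_step + [True] * alpha_step), so with the
# defaults above each cycle runs 990 updates with the flag False and 10 with it True;
# the boolean is forwarded to agent.update(), where it presumably switches between
# optimizing the policy parameters (theta) and fine-tuning the task codes (alpha).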
38 | 39 | flags.DEFINE_boolean('rnd_explore', True, 'random policy distillation') 40 | flags.DEFINE_integer('distill_steps', int(2e4), 'distillation steps') 41 | 42 | flags.DEFINE_boolean('tqdm', False, 'Use tqdm progress bar.') 43 | flags.DEFINE_string('wandb_mode', 'online', 'Track experiments with Weights and Biases.') 44 | flags.DEFINE_string('wandb_project_name', "CoTASP_Testing", "The wandb's project name.") 45 | flags.DEFINE_string('wandb_entity', None, "the entity (team) of wandb's project") 46 | flags.DEFINE_boolean('save_checkpoint', False, 'Save meta-policy network parameters') 47 | flags.DEFINE_string('save_dir', '/home/yijunyan/Data/PyCode/CoTASP/logs', 'Logging dir.') 48 | 49 | # YAML file path to cotasp's hyperparameter configuration 50 | with open('configs/sac_cotasp.yaml', 'r') as file: 51 | yaml_dict = yaml.unsafe_load(file) 52 | config_flags.DEFINE_config_dict( 53 | 'config', 54 | ConfigDict(yaml_dict), 55 | 'Training hyperparameter configuration.', 56 | lock_config=False 57 | ) 58 | 59 | def main(_): 60 | # config tasks 61 | seq_tasks = TASK_SEQS[FLAGS.env_name] 62 | algo_kwargs = dict(FLAGS.config) 63 | algo = FLAGS.base_algo 64 | run_name = f"{FLAGS.env_name}__{algo}__{FLAGS.seed}__{int(time.time())}" 65 | 66 | if FLAGS.save_checkpoint: 67 | save_policy_dir = f"logs/saved_actors/{run_name}.json" 68 | save_dict_dir = f"logs/saved_dicts/{run_name}" 69 | else: 70 | save_policy_dir = None 71 | save_dict_dir = None 72 | 73 | wandb.init( 74 | project=FLAGS.wandb_project_name, 75 | entity=FLAGS.wandb_entity, 76 | sync_tensorboard=True, 77 | config=FLAGS, 78 | name=run_name, 79 | monitor_gym=False, 80 | save_code=False, 81 | mode=FLAGS.wandb_mode, 82 | dir=FLAGS.save_dir 83 | ) 84 | wandb.config.update({"algo": algo}) 85 | 86 | log = Logger(wandb.run.dir) 87 | 88 | # random numpy seeding 89 | np.random.seed(FLAGS.seed) 90 | random.seed(FLAGS.seed) 91 | 92 | # initialize SAC agent 93 | temp_env = get_single_env( 94 | TASK_SEQS[FLAGS.env_name][0]['task'], FLAGS.seed, 95 | randomization=FLAGS.env_type) 96 | if algo == 'cotasp': 97 | agent = CoTASPLearner( 98 | FLAGS.seed, 99 | temp_env.observation_space.sample()[np.newaxis], 100 | temp_env.action_space.sample()[np.newaxis], 101 | len(seq_tasks), 102 | **algo_kwargs) 103 | del temp_env 104 | else: 105 | raise NotImplementedError() 106 | 107 | ''' 108 | continual learning loop 109 | ''' 110 | eval_envs = [] 111 | for idx, dict_task in enumerate(seq_tasks): 112 | eval_envs.append(get_single_env(dict_task['task'], FLAGS.seed, randomization=FLAGS.env_type)) 113 | 114 | total_env_steps = 0 115 | for task_idx, dict_task in enumerate(seq_tasks): 116 | 117 | ''' 118 | Learning subroutine for the current task 119 | ''' 120 | print(f'Learning on task {task_idx+1}: {dict_task["task"]} for {FLAGS.max_step} steps') 121 | # start the current task 122 | agent.start_task(task_idx, dict_task["hint"]) 123 | 124 | if task_idx > 0 and FLAGS.rnd_explore: 125 | ''' 126 | (Optional) Rand policy distillation for better exploration in the initial stage 127 | ''' 128 | for i in range(FLAGS.distill_steps): 129 | batch = replay_buffer.sample(FLAGS.batch_size) 130 | distill_info = agent.rand_net_distill(task_idx, batch) 131 | 132 | if i % (FLAGS.distill_steps // 10) == 0: 133 | print(i, distill_info) 134 | # reset actor's optimizer 135 | agent.reset_actor_optimizer() 136 | 137 | # set continual world environment 138 | env = get_single_env( 139 | dict_task['task'], FLAGS.seed, randomization=FLAGS.env_type, 140 | normalize_reward=FLAGS.normalize_reward 141 | ) 
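# Note: with FLAGS.normalize_reward set, get_single_env (see continual_world.py) wraps
# the task in RescaleReward with reward_scale = 1.0 / META_WORLD_TIME_HORIZON, i.e.
# per-step rewards are scaled down by the 200-step episode length.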
142 | # reset replay buffer 143 | replay_buffer = ReplayBuffer( 144 | env.observation_space, env.action_space, FLAGS.buffer_size or FLAGS.max_step 145 | ) 146 | # reset scheduler 147 | schedule = itertools.cycle([False]*FLAGS.theta_step + [True]*FLAGS.alpha_step) 148 | # reset environment 149 | observation, done = env.reset(), False 150 | for idx in range(FLAGS.max_step): 151 | if idx < FLAGS.start_training: 152 | # initial exploration strategy proposed in ClonEX-SAC 153 | if task_idx == 0: 154 | action = env.action_space.sample() 155 | else: 156 | # uniform-previous strategy 157 | mask_id = np.random.choice(task_idx) 158 | action = agent.sample_actions(observation[np.newaxis], mask_id) 159 | action = np.asarray(action, dtype=np.float32).flatten() 160 | 161 | # default initial exploration strategy 162 | # action = env.action_space.sample() 163 | else: 164 | action = agent.sample_actions(observation[np.newaxis], task_idx) 165 | action = np.asarray(action, dtype=np.float32).flatten() 166 | 167 | next_observation, reward, done, info = env.step(action) 168 | # counting total environment step 169 | total_env_steps += 1 170 | 171 | if not done or 'TimeLimit.truncated' in info: 172 | mask = 1.0 173 | else: 174 | mask = 0.0 175 | # only for meta-world 176 | assert mask == 1.0 177 | 178 | replay_buffer.insert( 179 | observation, action, reward, mask, float(done), next_observation 180 | ) 181 | 182 | # CRUCIAL step easy to overlook 183 | observation = next_observation 184 | 185 | if done: 186 | # EPISODIC ending 187 | observation, done = env.reset(), False 188 | for k, v in info['episode'].items(): 189 | wandb.log({f'training/{k}': v, 'global_steps': total_env_steps}) 190 | 191 | if (idx >= FLAGS.start_training) and (idx % FLAGS.updates_per_step == 0): 192 | for _ in range(FLAGS.updates_per_step): 193 | batch = replay_buffer.sample(FLAGS.batch_size) 194 | update_info = agent.update(task_idx, batch, next(schedule)) 195 | if idx % FLAGS.log_interval == 0: 196 | for k, v in update_info.items(): 197 | wandb.log({f'training/{k}': v, 'global_steps': total_env_steps}) 198 | 199 | if idx % FLAGS.eval_interval == 0: 200 | eval_stats = evaluate_cl(agent, eval_envs, FLAGS.eval_episodes) 201 | 202 | for k, v in eval_stats.items(): 203 | wandb.log({f'evaluation/{k}': v, 'global_steps': total_env_steps}) 204 | 205 | # Update the log with collected data 206 | eval_stats['cl_method'] = algo 207 | eval_stats['x'] = total_env_steps 208 | eval_stats['steps_per_task'] = FLAGS.max_step 209 | log.update(eval_stats) 210 | 211 | ''' 212 | Updating miscellaneous things 213 | ''' 214 | print('End of the current task') 215 | dict_stats = agent.end_task(task_idx, save_policy_dir, save_dict_dir) 216 | 217 | # save log data 218 | log.save() 219 | np.save(f'{wandb.run.dir}/dict_stats.npy', dict_stats) 220 | 221 | if __name__ == '__main__': 222 | app.run(main) -------------------------------------------------------------------------------- /jaxrl/networks/autoregressive_policy.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import typing 3 | 4 | import flax.linen as nn 5 | import jax 6 | import jax.numpy as jnp 7 | from tensorflow_probability.python.internal import reparameterization 8 | from tensorflow_probability.substrates import jax as tfp 9 | 10 | tfd = tfp.distributions 11 | tfb = tfp.bijectors 12 | 13 | LOG_STD_MIN = -5.0 14 | LOG_STD_MAX = 2.0 15 | 16 | 17 | class MaskType(enum.Enum): 18 | input = 1 19 | hidden = 2 20 | output = 3 21 | 22 | 23 | @jax.util.cache() 24 
| def get_mask(input_dim: int, output_dim: int, randvar_dim: int, 25 | mask_type: MaskType) -> jnp.DeviceArray: 26 | """ 27 | Create a mask for MADE. 28 | 29 | See Figure 1 for a better illustration: 30 | https://arxiv.org/pdf/1502.03509.pdf 31 | 32 | Args: 33 | input_dim: Dimensionality of the inputs. 34 | output_dim: Dimensionality of the outputs. 35 | rand_var_dim: Dimensionality of the random variable. 36 | mask_type: MaskType. 37 | 38 | Returns: 39 | A mask. 40 | """ 41 | if mask_type == MaskType.input: 42 | in_degrees = jnp.arange(input_dim) % randvar_dim 43 | else: 44 | in_degrees = jnp.arange(input_dim) % (randvar_dim - 1) 45 | 46 | if mask_type == MaskType.output: 47 | out_degrees = jnp.arange(output_dim) % randvar_dim - 1 48 | else: 49 | out_degrees = jnp.arange(output_dim) % (randvar_dim - 1) 50 | 51 | in_degrees = jnp.expand_dims(in_degrees, 0) 52 | out_degrees = jnp.expand_dims(out_degrees, -1) 53 | return (out_degrees >= in_degrees).astype(jnp.float32).transpose() 54 | 55 | 56 | class MaskedDense(nn.Dense): 57 | event_size: int = 1 58 | mask_type: MaskType = MaskType.hidden 59 | use_bias: bool = False 60 | 61 | @nn.compact 62 | def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: 63 | inputs = jnp.asarray(inputs, self.dtype) 64 | kernel = self.param('kernel', self.kernel_init, 65 | (inputs.shape[-1], self.features)) 66 | kernel = jnp.asarray(kernel, self.dtype) 67 | 68 | mask = get_mask(*kernel.shape, self.event_size, self.mask_type) 69 | kernel = kernel * mask 70 | 71 | y = jax.lax.dot_general(inputs, 72 | kernel, 73 | (((inputs.ndim - 1, ), (0, )), ((), ())), 74 | precision=self.precision) 75 | if self.use_bias: 76 | bias = self.param('bias', self.bias_init, (self.features, )) 77 | bias = jnp.asarray(bias, self.dtype) 78 | y = y + bias 79 | return y 80 | 81 | 82 | class MaskedMLP(nn.Module): 83 | features: typing.Sequence[int] 84 | activate_final: bool = False 85 | dropout_rate: typing.Optional[float] = 0.1 86 | 87 | @nn.compact 88 | def __call__(self, 89 | inputs: jnp.ndarray, 90 | conds: jnp.ndarray, 91 | training: bool = False) -> jnp.ndarray: 92 | x = inputs 93 | x_conds = conds 94 | for i, feat in enumerate(self.features): 95 | if i == 0: 96 | mask_type = MaskType.input 97 | elif i + 1 < len(self.features): 98 | mask_type = MaskType.hidden 99 | else: 100 | mask_type = MaskType.output 101 | x = MaskedDense(feat, 102 | event_size=inputs.shape[-1], 103 | mask_type=mask_type)(x) 104 | x_conds = nn.Dense(feat)(x_conds) 105 | x = x + x_conds 106 | if i + 1 < len(self.features) or self.activate_final: 107 | x = nn.relu(x) 108 | x_conds = nn.relu(x_conds) 109 | if self.dropout_rate is not None: 110 | if training: 111 | rng = self.make_rng('dropout') 112 | else: 113 | rng = None 114 | x_conds = nn.Dropout(rate=self.dropout_rate)( 115 | x_conds, deterministic=not training, rng=rng) 116 | x = nn.Dropout(rate=self.dropout_rate)( 117 | x, deterministic=not training, rng=rng) 118 | return x 119 | 120 | 121 | class Autoregressive(tfd.Distribution): 122 | 123 | def __init__(self, distr_fn: typing.Callable[[jnp.ndarray], 124 | tfd.Distribution], 125 | batch_shape: typing.Tuple[int], event_dim: int): 126 | super().__init__( 127 | dtype=jnp.float32, 128 | reparameterization_type=reparameterization.FULLY_REPARAMETERIZED, 129 | validate_args=False, 130 | allow_nan_stats=True) 131 | 132 | self._distr_fn = distr_fn 133 | self._event_dim = event_dim 134 | self.__batch_shape = batch_shape 135 | 136 | def _batch_shape(self): 137 | return self.__batch_shape 138 | 139 | def _sample_n(self, n: 
int, seed: jnp.ndarray) -> jnp.ndarray: 140 | keys = jax.random.split(seed, self._event_dim) 141 | 142 | samples = jnp.zeros((n, *self._batch_shape(), self._event_dim), 143 | jnp.float32) 144 | 145 | # TODO: Consider rewriting it with nn.scan. 146 | for i in range(self._event_dim): 147 | dist = self._distr_fn(samples) 148 | dim_samples = dist.sample(seed=keys[i]) 149 | samples = jax.ops.index_update(samples, jax.ops.index[..., i], 150 | dim_samples[..., i]) 151 | 152 | return samples 153 | 154 | def log_prob(self, values: jnp.ndarray) -> jnp.ndarray: 155 | return self._distr_fn(values).log_prob(values) 156 | 157 | @property 158 | def event_shape(self) -> int: 159 | return self._event_dim 160 | 161 | 162 | class MADETanhMixturePolicy(nn.Module): 163 | features: typing.Sequence[int] 164 | action_dim: int 165 | num_components: int = 10 166 | dropout_rate: typing.Optional[float] = None 167 | 168 | @nn.compact 169 | def __call__(self, 170 | states: jnp.ndarray, 171 | temperature: float = 1.0, 172 | training: bool = False) -> tfd.Distribution: 173 | is_initializing = not self.has_variable('params', 'means') 174 | masked_mlp = MaskedMLP( 175 | (*self.features, 3 * self.num_components * self.action_dim), 176 | dropout_rate=self.dropout_rate) 177 | means_init = self.param('means', nn.initializers.normal(1.0), 178 | (self.num_components * self.action_dim, )) 179 | 180 | if is_initializing: 181 | actions = jnp.zeros((*states.shape[:-1], self.action_dim), 182 | states.dtype) 183 | masked_mlp(actions, states) 184 | 185 | def distr_fn(actions: jnp.ndarray) -> tfd.Distribution: 186 | outputs = masked_mlp(actions, states, training=training) 187 | means, log_scales, logits = jnp.split(outputs, 3, axis=-1) 188 | means = means + means_init 189 | 190 | log_scales = jnp.clip(log_scales, LOG_STD_MIN, LOG_STD_MAX) 191 | 192 | def reshape(x): 193 | new_shape = (*x.shape[:-1], self.num_components, 194 | actions.shape[-1]) 195 | x = jnp.reshape(x, new_shape) 196 | return jnp.swapaxes(x, -1, -2) 197 | 198 | means = reshape(means) 199 | log_scales = reshape(log_scales) 200 | logits = reshape(logits) 201 | 202 | dist = tfd.Normal(loc=means, 203 | scale=jnp.exp(log_scales) * temperature) 204 | 205 | dist = tfd.MixtureSameFamily(tfd.Categorical(logits=logits), dist) 206 | 207 | return tfd.Independent(dist, reinterpreted_batch_ndims=1) 208 | 209 | dist = Autoregressive(distr_fn, states.shape[:-1], self.action_dim) 210 | return tfd.TransformedDistribution(dist, tfb.Tanh()) 211 | 212 | 213 | class MyUniform(tfd.Uniform): 214 | 215 | def _prob(self, inputs): 216 | return super()._prob(inputs) + 1e-8 217 | 218 | 219 | class MADEDiscretizedPolicy(nn.Module): 220 | features: typing.Sequence[int] 221 | action_dim: int 222 | num_components: int = 100 223 | dropout_rate: typing.Optional[float] = None 224 | 225 | @nn.compact 226 | def __call__(self, 227 | states: jnp.ndarray, 228 | temperature: float = 1.0, 229 | training: bool = False) -> tfd.Distribution: 230 | is_initializing = not self.has_variable('params', 'means') 231 | masked_mlp = MaskedMLP( 232 | (*self.features, self.num_components * self.action_dim), 233 | dropout_rate=self.dropout_rate) 234 | 235 | if is_initializing: 236 | actions = jnp.zeros((*states.shape[:-1], self.action_dim), 237 | states.dtype) 238 | masked_mlp(actions, states) 239 | 240 | def distr_fn(actions: jnp.ndarray) -> tfd.Distribution: 241 | logits = masked_mlp(actions, states, training=training) 242 | 243 | def reshape(x): 244 | new_shape = (*x.shape[:-1], self.num_components, 245 | actions.shape[-1]) 
246 | x = jnp.reshape(x, new_shape) 247 | return jnp.swapaxes(x, -1, -2) 248 | 249 | xs = jnp.linspace(-1, 1, self.num_components + 1) 250 | low = xs[:-1] 251 | high = xs[1:] 252 | 253 | logits = reshape(logits) 254 | 255 | dist = MyUniform(low=low, high=high) 256 | 257 | dist = tfd.MixtureSameFamily(tfd.Categorical(logits=logits), dist) 258 | 259 | return tfd.Independent(dist, reinterpreted_batch_ndims=1) 260 | 261 | return Autoregressive(distr_fn, states.shape[:-1], self.action_dim) 262 | -------------------------------------------------------------------------------- /jaxrl/networks/policies.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Any, Callable, Optional, Sequence, Tuple 3 | 4 | import flax.linen as nn 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | from jax import custom_jvp 9 | from tensorflow_probability.substrates import jax as tfp 10 | 11 | tfd = tfp.distributions 12 | tfb = tfp.bijectors 13 | 14 | from jaxrl.networks.common import MLP, Params, PRNGKey, \ 15 | default_init, activation_fn, MaskedLayerNorm 16 | 17 | # from common import MLP, Params, PRNGKey, default_init, \ 18 | # activation_fn, RMSNorm, create_mask, zero_grads 19 | 20 | 21 | LOG_STD_MAX = 2 22 | LOG_STD_MIN = -2 23 | 24 | 25 | class MSEPolicy(nn.Module): 26 | hidden_dims: Sequence[int] 27 | action_dim: int 28 | dropout_rate: Optional[float] = None 29 | 30 | @nn.compact 31 | def __call__(self, 32 | observations: jnp.ndarray, 33 | temperature: float = 1.0, 34 | training: bool = False) -> jnp.ndarray: 35 | outputs = MLP(self.hidden_dims, 36 | activate_final=True, 37 | dropout_rate=self.dropout_rate)(observations, 38 | training=training) 39 | 40 | actions = nn.Dense(self.action_dim, 41 | kernel_init=default_init())(outputs) 42 | return nn.tanh(actions) 43 | 44 | 45 | class TanhTransformedDistribution(tfd.TransformedDistribution): 46 | """Distribution followed by tanh.""" 47 | 48 | def __init__(self, distribution, threshold=.999, validate_args=False): 49 | """Initialize the distribution. 50 | Args: 51 | distribution: The distribution to transform. 52 | threshold: Clipping value of the action when computing the logprob. 53 | validate_args: Passed to super class. 54 | """ 55 | super().__init__( 56 | distribution=distribution, 57 | bijector=tfb.Tanh(), 58 | validate_args=validate_args) 59 | # Computes the log of the average probability distribution outside the 60 | # clipping range, i.e. on the interval [-inf, -atanh(threshold)] for 61 | # log_prob_left and [atanh(threshold), inf] for log_prob_right. 62 | self._threshold = threshold 63 | inverse_threshold = self.bijector.inverse(threshold) 64 | # average(pdf) = p/epsilon 65 | # So log(average(pdf)) = log(p) - log(epsilon) 66 | log_epsilon = jnp.log(1. - threshold) 67 | # Those 2 values are differentiable w.r.t. model parameters, such that the 68 | # gradient is defined everywhere. 69 | self._log_prob_left = self.distribution.log_cdf( 70 | -inverse_threshold) - log_epsilon 71 | self._log_prob_right = self.distribution.log_survival_function( 72 | inverse_threshold) - log_epsilon 73 | 74 | def log_prob(self, event): 75 | # Without this clip there would be NaNs in the inner tf.where and that 76 | # causes issues for some reasons. 77 | event = jnp.clip(event, -self._threshold, self._threshold) 78 | # The inverse image of {threshold} is the interval [atanh(threshold), inf] 79 | # which has a probability of "log_prob_right" under the given distribution. 
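# The resulting log-prob is therefore piecewise: _log_prob_left for events clipped at
# -threshold, _log_prob_right for events clipped at +threshold, and the exact
# change-of-variables log-prob of the tanh-transformed distribution in between.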
80 | return jnp.where( 81 | event <= -self._threshold, self._log_prob_left, 82 | jnp.where(event >= self._threshold, self._log_prob_right, 83 | super().log_prob(event))) 84 | 85 | def mode(self): 86 | return self.bijector.forward(self.distribution.mode()) 87 | 88 | def entropy(self, seed=None): 89 | # We return an estimation using a single sample of the log_det_jacobian. 90 | # We can still do some backpropagation with this estimate. 91 | return self.distribution.entropy() + self.bijector.forward_log_det_jacobian( 92 | self.distribution.sample(seed=seed), event_ndims=0) 93 | 94 | @classmethod 95 | def _parameter_properties(cls, dtype: Optional[Any], num_classes=None): 96 | td_properties = super()._parameter_properties(dtype, 97 | num_classes=num_classes) 98 | del td_properties['bijector'] 99 | return td_properties 100 | 101 | 102 | class NormalTanhPolicy(nn.Module): 103 | hidden_dims: Sequence[int] 104 | action_dim: int 105 | name_activation: str = 'leaky_relu' 106 | use_layer_norm: bool = False 107 | state_dependent_std: bool = True 108 | final_fc_init_scale: float = 1.0 109 | log_std_min: Optional[float] = None 110 | log_std_max: Optional[float] = None 111 | clip_mean: float = 1.0 112 | tanh_squash: bool = True 113 | 114 | @nn.compact 115 | def __call__(self, 116 | observations: jnp.ndarray, 117 | temperature: float = 1.0) -> tfd.Distribution: 118 | h = MLP(self.hidden_dims, 119 | activations=activation_fn(self.name_activation), 120 | activate_final=True, 121 | use_layer_norm=self.use_layer_norm)(observations) 122 | 123 | means = nn.Dense( 124 | self.action_dim, 125 | kernel_init=default_init(self.final_fc_init_scale))(h) 126 | 127 | if self.state_dependent_std: 128 | log_stds = nn.Dense( 129 | self.action_dim, 130 | kernel_init=default_init(self.final_fc_init_scale))(h) 131 | else: 132 | log_stds = self.param( 133 | 'log_stds', nn.initializers.zeros, (self.action_dim,) 134 | ) 135 | 136 | log_std_min = self.log_std_min or LOG_STD_MIN 137 | log_std_max = self.log_std_max or LOG_STD_MAX 138 | log_stds = jnp.clip(log_stds, log_std_min, log_std_max) 139 | 140 | # Avoid numerical issues by limiting the mean of the Gaussian 141 | # to be in [-clip_mean, clip_mean] 142 | # means = jnp.where( 143 | # means > self.clip_mean, self.clip_mean, 144 | # jnp.where(means < -self.clip_mean, -self.clip_mean, means) 145 | # ) 146 | 147 | # numerically stable method 148 | base_dist = tfd.Normal(loc=means, scale=jnp.exp(log_stds) * temperature) 149 | 150 | if self.tanh_squash: 151 | return tfd.Independent(TanhTransformedDistribution(base_dist), 152 | reinterpreted_batch_ndims=1) 153 | else: 154 | return base_dist, {'means': means, 'stddev': jnp.exp(log_stds)} 155 | 156 | 157 | @custom_jvp 158 | def clip_fn(x): 159 | return jnp.minimum(jnp.maximum(x, 0), 1.0) 160 | 161 | @clip_fn.defjvp 162 | def f_jvp(primals, tangents): 163 | # Custom derivative rule for clip_fn 164 | # x' = 1, when 0 < x < 1; 165 | # x' = 0, otherwise. 166 | x, = primals 167 | x_dot, = tangents 168 | ans = clip_fn(x) 169 | ans_dot = jnp.where(x >= 1.0, 0, jnp.where(x <= 0, 0, 1.0)) * x_dot 170 | return ans, ans_dot 171 | 172 | def ste_step_fn(x): 173 | # Create an exactly-zero expression with Sterbenz lemma that has 174 | # an exactly-one gradient. 175 | # Straight-through estimator of step function 176 | # its derivative is equal to 1 when 0 < x < 1, 0 otherwise. 
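# In other words, the forward value equals jnp.heaviside(x, 0) exactly (the `zero`
# term below evaluates to 0), while the backward pass only sees d/dx clip_fn(x),
# which the custom_jvp rule above defines as 1 for 0 < x < 1 and 0 otherwise.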
177 | zero = clip_fn(x) - jax.lax.stop_gradient(clip_fn(x)) 178 | return zero + jax.lax.stop_gradient(jnp.heaviside(x, 0)) 179 | 180 | def sigma_activation(sigma, sigma_min=LOG_STD_MIN, sigma_max=LOG_STD_MAX): 181 | return sigma_min + 0.5 * (sigma_max - sigma_min) * (jnp.tanh(sigma) + 1.) 182 | 183 | def mu_activation(mu): 184 | return jnp.tanh(mu) 185 | 186 | 187 | class MetaPolicy(nn.Module): 188 | hidden_dims: Sequence[int] 189 | action_dim: int 190 | task_num: int 191 | state_dependent_std: bool = True 192 | name_activation: str = 'leaky_relu' 193 | use_layer_norm: bool = False 194 | final_fc_init_scale: float = 1.0 195 | clip_mean: float = 1.0 196 | log_std_min: Optional[float] = None 197 | log_std_max: Optional[float] = None 198 | tanh_squash: bool = True 199 | 200 | def setup(self): 201 | self.backbones = [nn.Dense(hidn, kernel_init=default_init()) \ 202 | for hidn in self.hidden_dims] 203 | self.embeds_bb = [nn.Embed(self.task_num, hidn, embedding_init=default_init()) \ 204 | for hidn in self.hidden_dims] 205 | 206 | self.mean_layer = nn.Dense( 207 | self.action_dim, 208 | kernel_init=default_init(self.final_fc_init_scale), 209 | use_bias=False) 210 | 211 | if self.state_dependent_std: 212 | self.log_std_layer = nn.Dense( 213 | self.action_dim, 214 | kernel_init=default_init(self.final_fc_init_scale), 215 | ) 216 | else: 217 | self.log_std_layer = self.param( 218 | 'log_std_layer', nn.initializers.zeros, 219 | (self.action_dim,) 220 | ) 221 | 222 | self.activation = activation_fn(self.name_activation) 223 | self.tanh = activation_fn('tanh') 224 | if self.use_layer_norm: 225 | self.masked_ln = MaskedLayerNorm(use_bias=False, use_scale=False) 226 | 227 | def __call__(self, 228 | x: jnp.ndarray, 229 | t: jnp.ndarray, 230 | temperature: float = 1.0): 231 | masks = {} 232 | for i, layer in enumerate(self.backbones): 233 | x = layer(x) 234 | # straight-through estimator 235 | phi_l = ste_step_fn(self.embeds_bb[i](t)) 236 | mask_l = jnp.broadcast_to(phi_l, x.shape) 237 | masks[layer.name] = mask_l 238 | # masking outputs 239 | x *= mask_l 240 | if self.use_layer_norm and i == 0: 241 | # layer-normalize output 242 | x = self.masked_ln(x, mask_l) 243 | x = self.tanh(x) 244 | else: 245 | x = self.activation(x) 246 | 247 | means = self.mean_layer(x) 248 | 249 | # Avoid numerical issues by limiting the mean of the Gaussian 250 | # to be in [-clip_mean, clip_mean] 251 | # means = self.hard_tanh(means) 252 | means = mu_activation(means) * self.clip_mean 253 | 254 | if self.state_dependent_std: 255 | log_stds = self.log_std_layer(x) 256 | else: 257 | log_stds = self.log_std_layer 258 | 259 | # squashing log_std 260 | log_std_min = self.log_std_min or LOG_STD_MIN 261 | log_std_max = self.log_std_max or LOG_STD_MAX 262 | log_stds = sigma_activation(log_stds, log_std_min, log_std_max) 263 | 264 | # numerically stable method 265 | base_dist = tfd.Normal(loc=means, scale=jax.nn.softplus(log_stds) * temperature) 266 | 267 | if self.tanh_squash: 268 | return tfd.Independent(TanhTransformedDistribution(base_dist), 269 | reinterpreted_batch_ndims=1), { 270 | 'masks': masks, 271 | 'means': means, 272 | 'stddev': jax.nn.softplus(log_stds) 273 | } 274 | else: 275 | return base_dist, {'masks': masks, 'means': means, 'stddev': jax.nn.softplus(log_stds)} 276 | 277 | def get_grad_masks(self, masks: dict, input_dim: int = 12): 278 | grad_masks = {} 279 | for i, layer in enumerate(self.backbones): 280 | if i == 0: 281 | post_m = masks[layer.name] 282 | grad_masks[(layer.name, 'kernel')] = 1 - jnp.broadcast_to( 283 
| post_m, (input_dim, self.hidden_dims[i]) 284 | ) 285 | grad_masks[(layer.name, 'bias')] = 1 - post_m.flatten() 286 | pre_m = masks[layer.name] 287 | else: 288 | post_m = masks[layer.name] 289 | grad_masks[(layer.name, 'kernel')] = 1 - jnp.minimum( 290 | jnp.broadcast_to(pre_m.reshape(-1, 1), (self.hidden_dims[i-1], self.hidden_dims[i])), 291 | jnp.broadcast_to(post_m, (self.hidden_dims[i-1], self.hidden_dims[i])) 292 | ) 293 | grad_masks[(layer.name, 'bias')] = 1 - post_m.flatten() 294 | pre_m = masks[layer.name] 295 | 296 | grad_masks[(self.mean_layer.name, 'kernel')] = 1 - jnp.broadcast_to( 297 | pre_m.reshape(-1, 1), (self.hidden_dims[-1], self.action_dim) 298 | ) 299 | 300 | return grad_masks 301 | 302 | 303 | class NormalTanhMixturePolicy(nn.Module): 304 | hidden_dims: Sequence[int] 305 | action_dim: int 306 | num_components: int = 5 307 | dropout_rate: Optional[float] = None 308 | 309 | @nn.compact 310 | def __call__(self, 311 | observations: jnp.ndarray, 312 | temperature: float = 1.0, 313 | training: bool = False) -> tfd.Distribution: 314 | outputs = MLP(self.hidden_dims, 315 | activate_final=True, 316 | dropout_rate=self.dropout_rate)(observations, 317 | training=training) 318 | 319 | logits = nn.Dense(self.action_dim * self.num_components, 320 | kernel_init=default_init())(outputs) 321 | means = nn.Dense(self.action_dim * self.num_components, 322 | kernel_init=default_init(), 323 | bias_init=nn.initializers.normal(stddev=1.0))(outputs) 324 | log_stds = nn.Dense(self.action_dim * self.num_components, 325 | kernel_init=default_init())(outputs) 326 | 327 | shape = list(observations.shape[:-1]) + [-1, self.num_components] 328 | logits = jnp.reshape(logits, shape) 329 | mu = jnp.reshape(means, shape) 330 | log_stds = jnp.reshape(log_stds, shape) 331 | 332 | log_stds = jnp.clip(log_stds, LOG_STD_MIN, LOG_STD_MAX) 333 | 334 | components_distribution = tfd.Normal(loc=mu, 335 | scale=jnp.exp(log_stds) * 336 | temperature) 337 | 338 | base_dist = tfd.MixtureSameFamily( 339 | mixture_distribution=tfd.Categorical(logits=logits), 340 | components_distribution=components_distribution) 341 | 342 | dist = tfd.TransformedDistribution(distribution=base_dist, 343 | bijector=tfb.Tanh()) 344 | 345 | return tfd.Independent(dist, 1) 346 | 347 | 348 | @functools.partial( 349 | jax.jit, static_argnames=('actor_apply_fn', 'distribution')) 350 | def _sample_actions( 351 | rng: PRNGKey, 352 | actor_apply_fn: Callable[..., Any], 353 | actor_params: Params, 354 | observations: np.ndarray, 355 | temperature: float = 1.0, 356 | distribution: str = 'log_prob') -> Tuple[PRNGKey, jnp.ndarray]: 357 | if distribution == 'det': 358 | return rng, actor_apply_fn({'params': actor_params}, observations, 359 | temperature) 360 | else: 361 | dist = actor_apply_fn( 362 | {'params': actor_params}, observations, temperature) 363 | 364 | rng, key = jax.random.split(rng) 365 | return rng, dist.sample(seed=key) 366 | 367 | 368 | def sample_actions( 369 | rng: PRNGKey, 370 | actor_apply_fn: Callable[..., Any], 371 | actor_params: Params, 372 | observations: jnp.ndarray, 373 | temperature: float = 1.0, 374 | distribution: str = 'log_prob') -> Tuple[PRNGKey, jnp.ndarray]: 375 | return _sample_actions(rng, actor_apply_fn, actor_params, observations, 376 | temperature, distribution) 377 | -------------------------------------------------------------------------------- /jaxrl/dict_learning/task_dict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from copy 
import deepcopy 4 | 5 | import jax 6 | import numpy as np 7 | import flax.linen as nn 8 | from scipy import spatial 9 | from sklearn.decomposition import DictionaryLearning, sparse_encode 10 | from sklearn.linear_model import Lasso 11 | from sklearn.utils import check_array, check_random_state 12 | from sklearn.utils.extmath import randomized_svd 13 | from sklearn.metrics.pairwise import cosine_similarity 14 | 15 | 16 | def overlap_1(arr1, arr2): 17 | return np.mean(arr1 == arr2) 18 | 19 | def overlap_2(arr1, arr2): 20 | return np.mean(np.logical_and(arr1, arr2)) 21 | 22 | def overlap_3(arr1, arr2): 23 | n_union = np.sum(np.logical_or(arr1, arr2)) 24 | n_intersected = np.sum(np.logical_and(arr1, arr2)) 25 | return n_intersected / n_union 26 | 27 | 28 | class BasisVectorLearner(object): 29 | 30 | def __init__(self, 31 | n_features: int, 32 | n_components: int, 33 | seed: int=0, 34 | scale: float=np.sqrt(2), 35 | verbose=False): 36 | self.rng = np.random.RandomState(seed) 37 | self._components = nn.initializers.orthogonal(scale)( 38 | jax.random.PRNGKey(seed), shape=(n_components, n_features)) 39 | self._codes = np.zeros((1, n_components)) 40 | self._n_components = n_components 41 | self._n_features = n_features 42 | self._n_samples = 0 43 | self._seed = seed 44 | self._verbose = verbose 45 | 46 | self.arch_samples = None 47 | 48 | def decompose(self, sample): 49 | assert sample.shape[1] == self._n_features 50 | self._n_samples += 1 51 | 52 | if self.arch_samples is None: 53 | self.arch_samples = sample 54 | else: 55 | self.arch_samples = np.vstack((self.arch_samples, sample)) 56 | 57 | dict_learner = DictionaryLearning( 58 | max_iter=5000, 59 | alpha=0.1, 60 | n_components=self._n_components, 61 | fit_algorithm='lars', 62 | code_init=self._codes, 63 | dict_init=self._components, 64 | transform_algorithm='lasso_lars', 65 | transform_alpha=0.1, 66 | transform_max_iter=5000, 67 | random_state=self._seed 68 | ) 69 | 70 | alphas = deepcopy(dict_learner.fit_transform(self.arch_samples)) 71 | next_alphas = np.zeros((1, self._n_components)) 72 | self._codes = np.vstack((alphas, next_alphas)) 73 | self._components = deepcopy(dict_learner.components_) 74 | 75 | if self._verbose: 76 | recon = np.dot(alphas, self._components) 77 | print(f'Number of samples: {self._n_samples}') 78 | print(f'Level of sparsity: {np.mean(alphas == 0):.4f}') 79 | print(f'Recontruction loss: {np.mean((recon - self.arch_samples)**2)}') 80 | print('Samples:\n', self.arch_samples) 81 | print('Reconst:\n', recon) 82 | 83 | def get_components(self): 84 | return deepcopy(self._components) 85 | 86 | def get_next_codes(self): 87 | return deepcopy(self._codes[-1, :].reshape(1, -1)) 88 | 89 | 90 | class OnlineDictLearner(object): 91 | def __init__(self, 92 | n_features: int, 93 | n_components: int, 94 | seed: int=0, 95 | scale: float=1.0, 96 | verbose=False): 97 | 98 | self.N = 1 99 | # d = n_features, k = n_components 100 | self.D = np.eye(n_features) 101 | self.I = np.eye(n_features*n_components) 102 | self.A = np.zeros((n_features*n_components, n_features*n_components)) 103 | self.b = np.zeros((n_features*n_components, 1)) 104 | self.L = jax.nn.initializers.variance_scaling(scale, 'fan_in', 'normal')( 105 | jax.random.PRNGKey(seed), shape=(n_features, n_components)) 106 | # self.L = (jax.random.uniform(jax.random.PRNGKey(seed), shape=(n_features, n_components)) - 0.5) * 2 * np.sqrt(1 / 12) * 1e-2 107 | self.s = None 108 | self.S = None 109 | self.arch_samples = None 110 | self._n_components = n_components 111 | self._verbose = 
verbose 112 | self.lasso_solver = Lasso(alpha=1e-5, fit_intercept=False, max_iter=5000, random_state=seed) 113 | 114 | def decompose(self, sample): 115 | self.lasso_solver.fit(self.L, sample.T) 116 | s = self.lasso_solver.coef_.reshape(-1,1) 117 | 118 | # collect coefs s: 119 | if self.S is None: 120 | self.S = s 121 | self.arch_samples = sample.T 122 | else: 123 | self.S = np.hstack([self.S, s]) 124 | self.arch_samples = np.hstack([self.arch_samples, sample.T]) 125 | 126 | # update stats 127 | self.A += np.kron(s.dot(s.T), self.D) 128 | self.b += np.kron(s.T, sample.dot(self.D)).T 129 | vals = np.linalg.inv(self.A / self.N + 1e-5 * self.I).dot(self.b / self.N) 130 | self.L = vals.reshape(self.L.shape, order='F') 131 | self.s = s 132 | 133 | if self._verbose: 134 | recon = np.dot(self.L, self.S) 135 | print(f'Number of samples: {self.N}') 136 | print(f'Level of sparsity: {np.mean(self.S == 0):.4f}') 137 | print(f'Recontruction loss: {np.mean((recon - self.arch_samples)**2)}') 138 | print('Samples:\n', self.arch_samples.T) 139 | print('Reconst:\n', recon.T) 140 | 141 | self.N += 1 142 | 143 | def get_components(self): 144 | return deepcopy(self.L.T) 145 | 146 | def get_codes(self): 147 | return deepcopy(self.s.T) 148 | 149 | def get_next_codes(self): 150 | return np.zeros((1, self._n_components)) 151 | 152 | 153 | def clip_by_norm(x, c): 154 | clip_coef = c / (np.linalg.norm(x) + 1e-6) 155 | clip_coef_clipped = min(1.0, clip_coef) 156 | return x * clip_coef_clipped 157 | 158 | def _update_dict( 159 | dictionary, 160 | Y, 161 | code, 162 | A=None, 163 | B=None, 164 | c=1e-1, 165 | verbose=False, 166 | random_state=None, 167 | positive=False, 168 | ): 169 | """Update the dense dictionary factor in place. 170 | 171 | Parameters 172 | ---------- 173 | dictionary : ndarray of shape (n_components, n_features) 174 | Value of the dictionary at the previous iteration. 175 | 176 | Y : ndarray of shape (n_samples, n_features) 177 | Data matrix. 178 | 179 | code : ndarray of shape (n_samples, n_components) 180 | Sparse coding of the data against which to optimize the dictionary. 181 | 182 | A : ndarray of shape (n_components, n_components), default=None 183 | Together with `B`, sufficient stats of the online model to update the 184 | dictionary. 185 | 186 | B : ndarray of shape (n_features, n_components), default=None 187 | Together with `A`, sufficient stats of the online model to update the 188 | dictionary. 189 | 190 | verbose: bool, default=False 191 | Degree of output the procedure will print. 192 | 193 | random_state : int, RandomState instance or None, default=None 194 | Used for randomly initializing the dictionary. Pass an int for 195 | reproducible results across multiple function calls. 196 | See :term:`Glossary `. 197 | 198 | positive : bool, default=False 199 | Whether to enforce positivity when finding the dictionary. 200 | 201 | .. versionadded:: 0.20 202 | """ 203 | n_samples, n_components = code.shape 204 | random_state = check_random_state(random_state) 205 | 206 | if A is None: 207 | A = code.T @ code 208 | if B is None: 209 | B = Y.T @ code 210 | 211 | n_unused = 0 212 | for k in range(n_components): 213 | if A[k, k] > 1e-6: 214 | # 1e-6 is arbitrary but consistent with the spams implementation 215 | # -np.inf means that never resample atoms. 
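# Block coordinate descent on the k-th atom: update it from the sufficient statistics
# A and B (as in online dictionary learning, cf. Mairal et al., 2009; this routine is
# apparently adapted from scikit-learn's _update_dict), then, further below, project
# it back onto the constraint set ||V_k||_2 <= c.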
216 |             dictionary[k] += (B[:, k] - A[k] @ dictionary) / A[k, k]
217 |         else:
218 |             # kth atom is almost never used -> sample a new one from the data
219 |             newd = Y[random_state.choice(n_samples)]
220 |
221 |             # add small noise to avoid making the sparse coding ill conditioned
222 |             noise_level = 1.0 * (newd.std() or 1)  # avoid 0 std
223 |             noise = random_state.normal(0, noise_level, size=len(newd))
224 |             dictionary[k] = newd + noise
225 |             code[:, k] = 0
226 |             n_unused += 1
227 |
228 |             # randomly initialize kth atom
229 |             # scale = 1.0 * (newd.std() or 1) # avoid 0 std
230 |             # dictionary[k] = random_state.normal(loc=0.0, scale=scale, size=len(newd))
231 |             # code[:, k] = 0
232 |             # n_unused += 1
233 |
234 |             # pass
235 |
236 |         if positive:
237 |             np.clip(dictionary[k], 0, None, out=dictionary[k])
238 |
239 |         # Projection on the constraint set ||V_k||_2 <= c
240 |         dictionary[k] = clip_by_norm(dictionary[k], c)
241 |
242 |     if verbose and n_unused > 0:
243 |         print(f"{n_unused} unused atoms resampled.")
244 |
245 |     return dictionary
246 |
247 |
248 | class OnlineDictLearnerV2(object):
249 |     def __init__(self,
250 |                  n_features: int,
251 |                  n_components: int,
252 |                  seed: int=0,
253 |                  init_sample: np.ndarray=None,
254 |                  c: float=1e-2,
255 |                  scale: float=1.0,
256 |                  alpha: float=1e-3,
257 |                  method: str='lasso_lars', # ['lasso_cd', 'lasso_lars', 'threshold']
258 |                  positive_code: bool=False,
259 |                  scale_code: bool=False,
260 |                  verbose=True):
261 |
262 |         self.N = 0
263 |         self.rng = np.random.RandomState(seed=seed)
264 |         self.A = np.zeros((n_components, n_components))
265 |         self.B = np.zeros((n_features, n_components))
266 |
267 |         if init_sample is None:
268 |             dictionary = self.rng.normal(loc=0.0, scale=scale, size=(n_components, n_features))
269 |             # Projection on the constraint set ||V_k||_2 <= c
270 |             for j in range(n_components):
271 |                 dictionary[j] = clip_by_norm(dictionary[j], c)
272 |
273 |         else:
274 |             _, S, dictionary = randomized_svd(init_sample, n_components, random_state=self.rng)
275 |             dictionary = S[:, np.newaxis] * dictionary
276 |             r = len(dictionary)
277 |             if n_components <= r:
278 |                 dictionary = dictionary[:n_components, :]
279 |             else:
280 |                 dictionary = np.r_[
281 |                     dictionary,
282 |                     np.zeros((n_components - r, dictionary.shape[1]), dtype=dictionary.dtype),
283 |                 ]
284 |
285 |         dictionary = check_array(dictionary, order="F", copy=False)
286 |         self.D = np.require(dictionary, requirements="W")
287 |
288 |         self.C = None
289 |         self.c = c
290 |         self.alpha = alpha
291 |         self.method = method
292 |         self.archives = None
293 |         self._verbose = verbose
294 |         self.arch_code = None
295 |         self._positive_code = positive_code
296 |         self._scale_code = scale_code
297 |         self.change_of_dict = []
298 |
299 |     def get_alpha(self, sample: np.ndarray):
300 |         code = sparse_encode(
301 |             sample,
302 |             self.D,
303 |             algorithm=self.method,
304 |             alpha=self.alpha,
305 |             check_input=False,
306 |             positive=self._positive_code,
307 |             max_iter=10000)
308 |
309 |         # recording
310 |         if self.arch_code is None:
311 |             self.arch_code = code
312 |         else:
313 |             self.arch_code = np.vstack([self.arch_code, code])
314 |
315 |         if self._scale_code:
316 |             scaled_code = self._scale_coeffs(code)
317 |             assert np.max(scaled_code) == 1.0
318 |         else:
319 |             scaled_code = code
320 |
321 |         if self._verbose:
322 |             recon = np.dot(code, self.D)
323 |             print('Sparse Coding of Task Embedding')
324 |             print(f'Rate of deactivation: {1-np.mean(np.heaviside(scaled_code, 0)):.4f}')
325 |             print(f'Rate of activation: {np.mean(np.heaviside(scaled_code, 0)):.4f}')
326 |             print(f'Reconstruction loss: {np.mean((sample - recon)**2):.4e}')
327 |             print('----------------------------------')
328 |
329 |         return scaled_code
330 |
331 |     def update_dict(self, codes: np.ndarray, sample: np.ndarray):
332 |         self.N += 1
333 |         if self._scale_code:
334 |             codes = self._rescale_coeffs(codes)
335 |         # recording
336 |         if self.C is None:
337 |             self.C = codes
338 |             self.archives = sample
339 |         else:
340 |             self.C = np.vstack([self.C, codes])
341 |             self.archives = np.vstack([self.archives, sample])
342 |         assert self.C.shape[0] == self.N
343 |
344 |         # Update the auxiliary variables
345 |         # Improvement for mini-batch version
346 |         # batch_size = 1
347 |         # if self.N < batch_size - 1:
348 |         #     theta = float((self.N + 1) * batch_size)
349 |         # else:
350 |         #     theta = float(batch_size**2 + self.N + 1 - batch_size)
351 |         # beta = (theta + 1 - batch_size) / (theta + 1)
352 |
353 |         # self.A *= beta
354 |         self.A += np.dot(codes.T, codes)
355 |         # self.B *= beta
356 |         self.B += np.dot(sample.T, codes)
357 |
358 |         # pre-verbose
359 |         if self._verbose:
360 |             recons = np.dot(self.C, self.D)
361 |             print('Dictionary Learning')
362 |             print(f'Pre-MSE loss: {np.mean((self.archives - recons)**2):.4e}')
363 |
364 |         old_D = deepcopy(self.D)
365 |
366 |         # Update dictionary
367 |         self.D = _update_dict(
368 |             self.D,
369 |             sample,
370 |             codes,
371 |             self.A,
372 |             self.B,
373 |             self.c,
374 |             verbose=self._verbose,
375 |             random_state=self.rng,
376 |             positive=self._positive_code
377 |         )
378 |
379 |         self.change_of_dict.append(np.linalg.norm(old_D - self.D)**2/self.D.size)
380 |
381 |         # post-verbose
382 |         if self._verbose:
383 |             recons = np.dot(self.C, self.D)
384 |             print(f'Post-MSE loss: {np.mean((self.archives - recons)**2):.4e}')
385 |             print('----------------------------------')
386 |
387 |     def _scale_coeffs(self, alpha: np.ndarray):
388 |         # constrain the alpha to [0, 1]
389 |         assert self._positive_code
390 |         self.factor = np.max(alpha)
391 |         return alpha / self.factor
392 |
393 |     def _rescale_coeffs(self, alpha_scaled: np.ndarray):
394 |         assert self._positive_code
395 |         return alpha_scaled * self.factor
396 |
397 |     def _compute_overlapping(self):
398 |         binary_masks = np.heaviside(self.arch_code, 0)
399 |         overlap_mat = np.empty((binary_masks.shape[0], binary_masks.shape[0]))
400 |
401 |         for i in range(binary_masks.shape[0]):
402 |             for j in range(binary_masks.shape[0]):
403 |                 overlap_mat[i, j] = overlap_3(binary_masks[i], binary_masks[j])
404 |
405 |         return overlap_mat
406 |
407 |     def save(self, save_path: str):
408 |
409 |         saved_dict = {
410 |             "A": self.A,
411 |             "B": self.B,
412 |             "D": self.D,
413 |         }
414 |
415 |         os.makedirs(os.path.dirname(save_path), exist_ok=True)
416 |         with open(save_path, 'wb') as f:
417 |             pickle.dump(saved_dict, f)
418 |
419 |     def load(self, load_path: str):
420 |         with open(load_path, 'rb') as f:
421 |             loaded_dict = pickle.load(f)
422 |
423 |         self.A = loaded_dict["A"]
424 |         self.B = loaded_dict["B"]
425 |         self.D = loaded_dict["D"]
426 |
427 |
428 | if __name__ == "__main__":
429 |     x = np.ones(5) * 0.1
430 |     print(np.linalg.norm(x))
431 |     print(np.linalg.norm(clip_by_norm(x, 2.0)))
432 |
433 |     # import matplotlib.pyplot as plt
434 |     # from sentence_transformers import SentenceTransformer
435 |
436 |     # model = SentenceTransformer('all-MiniLM-L12-v2')
437 |
438 |     # # task hints
439 |     # hints = [
440 |     #     'Hammer a screw on the wall.',
441 |     #     'Bypass a wall and push a puck to a goal.',
442 |     #     'Rotate the faucet clockwise.',
443 |     #     'Pull a puck to a goal.',
444 |     #     'Grasp a stick and pull a box with the stick.',
445 |     #     'Press a handle down sideways.',
446 |     #     'Push the puck to a goal.',
447 |     #     'Pick and place a puck onto a shelf.',
448 |     #     'Push and close a window.',
449 |     #     'Unplug a peg sideways.',
450 |     #     'Hammer a screw on the wall.',
451 |     #     'Bypass a wall and push a puck to a goal.',
452 |     #     'Rotate the faucet clockwise.',
453 |     #     'Pull a puck to a goal.',
454 |     #     'Grasp a stick and pull a box with the stick.',
455 |     #     'Press a handle down sideways.',
456 |     #     'Push the puck to a goal.',
457 |     #     'Pick and place a puck onto a shelf.',
458 |     #     'Push and close a window.',
459 |     #     'Unplug a peg sideways.'
460 |     # ]
461 |     # task_idx = [
462 |     #     'task 1', 'task 2', 'task 3', 'task 4', 'task 5',
463 |     #     'task 6', 'task 7', 'task 8', 'task 9', 'task 10',
464 |     #     'task 11', 'task 12', 'task 13', 'task 14', 'task 15',
465 |     #     'task 16', 'task 17', 'task 18', 'task 19', 'task 20'
466 |     # ]
467 |
468 |     # init_embedding = model.encode('Press a handle down sideways.')
469 |
470 |     # dict_learner = OnlineDictLearnerV2(
471 |     #     n_features=384,
472 |     #     n_components=1024,
473 |     #     seed=0,
474 |     #     init_sample=None,
475 |     #     c=1.0,
476 |     #     alpha=1e-3,
477 |     #     method='lasso_lars',
478 |     #     positive_code=True,
479 |     #     scale_code=False,
480 |     #     verbose=True)
481 |
482 |     # # mimic training stage
483 |     # for idx, hint_task in enumerate(hints):
484 |     #     print(idx+1, hint_task)
485 |     #     task_embedding = model.encode(hint_task)
486 |
487 |     #     # compute code for current task
488 |     #     code = dict_learner.get_alpha(task_embedding[np.newaxis, :])
489 |
490 |     #     # mimic RL finetuning
491 |     #     code += np.random.normal(size=code.shape) * 1.0
492 |     #     code = np.clip(code, 0, 10.0)
493 |
494 |     #     # online update dictionary via CD
495 |     #     dict_learner.update_dict(code, task_embedding[np.newaxis, :])
496 |
--------------------------------------------------------------------------------