├── common ├── __init__.py ├── type_aliases.py ├── distributions.py ├── buffers.py ├── off_policy_algorithm.py └── policies.py ├── models ├── __init__.py ├── utils.py ├── actor_critic_evaluation_callback.py └── critic.py ├── diffusion ├── __init__.py ├── common │ ├── __init__.py │ ├── models │ │ ├── __init__.py │ │ └── pisgrad_net.py │ ├── diffusion_models.py │ ├── scheduler.py │ ├── learning_rate_scheduler.py │ ├── utils.py │ └── init_diffusion_model.py ├── od │ ├── __init__.py │ ├── od_sampling.py │ ├── od_integrators.py │ └── dis.py ├── diffusion_policy.py └── dime.py ├── .gitignore ├── paper_results ├── gym_ant_v3 │ ├── mean_return.npy │ └── env_interacts.npy ├── dmc_acrobot │ ├── env_interacts.npy │ └── mean_return.npy ├── dmc_dog_run │ ├── env_interacts.npy │ └── mean_return.npy ├── dmc_dog_stand │ ├── mean_return.npy │ └── env_interacts.npy ├── dmc_dog_trot │ ├── mean_return.npy │ └── env_interacts.npy ├── dmc_dog_walk │ ├── mean_return.npy │ └── env_interacts.npy ├── dmc_fish_swim │ ├── mean_return.npy │ └── env_interacts.npy ├── dmc_cheetah_run │ ├── mean_return.npy │ └── env_interacts.npy ├── dmc_finger_turn │ ├── mean_return.npy │ └── env_interacts.npy ├── dmc_hopper_hop │ ├── mean_return.npy │ └── env_interacts.npy ├── dmc_walker_run │ ├── mean_return.npy │ └── env_interacts.npy ├── gym_humanoid_v3 │ ├── mean_return.npy │ └── env_interacts.npy ├── dmc_humanoid_run │ ├── env_interacts.npy │ └── mean_return.npy ├── dmc_humanoid_stand │ ├── mean_return.npy │ └── env_interacts.npy ├── dmc_humanoid_walk │ ├── mean_return.npy │ └── env_interacts.npy ├── dmc_quadruped_run │ ├── mean_return.npy │ └── env_interacts.npy ├── dmc_pendulum_swingup │ ├── mean_return.npy │ └── env_interacts.npy ├── myo_hand_reach_rndm │ ├── mean_return.npy │ └── env_interacts.npy ├── myo_hand_key_turn_rndm │ ├── mean_return.npy │ └── env_interacts.npy ├── myo_hand_obj_hold_rndm │ ├── mean_return.npy │ └── env_interacts.npy └── myo_hand_pen_twirl_rndm │ ├── mean_return.npy │ └── env_interacts.npy ├── configs ├── setup.yaml ├── sampler │ ├── dt_schedule │ │ └── cosine.yaml │ ├── score_model │ │ └── cond_pisgradnet.yaml │ └── dis.yaml ├── base.yaml └── alg │ └── dime.yaml ├── requirements.txt ├── LICENSE ├── setup.sh ├── README.md └── run_dime.py /common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /diffusion/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /diffusion/od/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /diffusion/common/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .idea 3 | **.log 4 
| logs/ 5 | wandb/ 6 | eval_logs/ 7 | outputs/ -------------------------------------------------------------------------------- /paper_results/gym_ant_v3/mean_return.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/gym_ant_v3/mean_return.npy -------------------------------------------------------------------------------- /paper_results/dmc_acrobot/env_interacts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_acrobot/env_interacts.npy -------------------------------------------------------------------------------- /paper_results/dmc_acrobot/mean_return.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_acrobot/mean_return.npy -------------------------------------------------------------------------------- /paper_results/dmc_dog_run/env_interacts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_dog_run/env_interacts.npy -------------------------------------------------------------------------------- /paper_results/dmc_dog_run/mean_return.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_dog_run/mean_return.npy -------------------------------------------------------------------------------- /paper_results/dmc_dog_stand/mean_return.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_dog_stand/mean_return.npy -------------------------------------------------------------------------------- /paper_results/dmc_dog_trot/mean_return.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_dog_trot/mean_return.npy -------------------------------------------------------------------------------- /paper_results/dmc_dog_walk/mean_return.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_dog_walk/mean_return.npy -------------------------------------------------------------------------------- /paper_results/dmc_fish_swim/mean_return.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_fish_swim/mean_return.npy -------------------------------------------------------------------------------- /paper_results/gym_ant_v3/env_interacts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/gym_ant_v3/env_interacts.npy -------------------------------------------------------------------------------- /paper_results/dmc_cheetah_run/mean_return.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_cheetah_run/mean_return.npy -------------------------------------------------------------------------------- /paper_results/dmc_dog_stand/env_interacts.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_dog_stand/env_interacts.npy -------------------------------------------------------------------------------- /paper_results/dmc_dog_trot/env_interacts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_dog_trot/env_interacts.npy -------------------------------------------------------------------------------- /paper_results/dmc_dog_walk/env_interacts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_dog_walk/env_interacts.npy -------------------------------------------------------------------------------- /paper_results/dmc_finger_turn/mean_return.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_finger_turn/mean_return.npy -------------------------------------------------------------------------------- /paper_results/dmc_fish_swim/env_interacts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_fish_swim/env_interacts.npy -------------------------------------------------------------------------------- /paper_results/dmc_hopper_hop/mean_return.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_hopper_hop/mean_return.npy -------------------------------------------------------------------------------- /paper_results/dmc_walker_run/mean_return.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_walker_run/mean_return.npy -------------------------------------------------------------------------------- /paper_results/gym_humanoid_v3/mean_return.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/gym_humanoid_v3/mean_return.npy -------------------------------------------------------------------------------- /paper_results/dmc_cheetah_run/env_interacts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_cheetah_run/env_interacts.npy -------------------------------------------------------------------------------- /paper_results/dmc_finger_turn/env_interacts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_finger_turn/env_interacts.npy -------------------------------------------------------------------------------- /paper_results/dmc_hopper_hop/env_interacts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_hopper_hop/env_interacts.npy -------------------------------------------------------------------------------- /paper_results/dmc_humanoid_run/env_interacts.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_humanoid_run/env_interacts.npy -------------------------------------------------------------------------------- /paper_results/dmc_humanoid_run/mean_return.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_humanoid_run/mean_return.npy -------------------------------------------------------------------------------- /paper_results/dmc_humanoid_stand/mean_return.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_humanoid_stand/mean_return.npy -------------------------------------------------------------------------------- /paper_results/dmc_humanoid_walk/mean_return.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_humanoid_walk/mean_return.npy -------------------------------------------------------------------------------- /paper_results/dmc_quadruped_run/mean_return.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_quadruped_run/mean_return.npy -------------------------------------------------------------------------------- /paper_results/dmc_walker_run/env_interacts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_walker_run/env_interacts.npy -------------------------------------------------------------------------------- /paper_results/gym_humanoid_v3/env_interacts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/gym_humanoid_v3/env_interacts.npy -------------------------------------------------------------------------------- /paper_results/dmc_humanoid_stand/env_interacts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_humanoid_stand/env_interacts.npy -------------------------------------------------------------------------------- /paper_results/dmc_humanoid_walk/env_interacts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_humanoid_walk/env_interacts.npy -------------------------------------------------------------------------------- /paper_results/dmc_pendulum_swingup/mean_return.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_pendulum_swingup/mean_return.npy -------------------------------------------------------------------------------- /paper_results/dmc_quadruped_run/env_interacts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_quadruped_run/env_interacts.npy -------------------------------------------------------------------------------- /paper_results/myo_hand_reach_rndm/mean_return.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/myo_hand_reach_rndm/mean_return.npy 
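
A note on using the result files listed here: each environment folder under paper_results stores two NumPy arrays, env_interacts.npy and mean_return.npy. The sketch below plots one learning curve; the array layout (1-D vs. seed-stacked 2-D) and the use of matplotlib (not pinned in requirements.txt) are assumptions, and the helper is not part of the repository.

```python
# Hypothetical plotting helper for the learning-curve data in paper_results/.
# Assumes env_interacts.npy holds environment-step counts and mean_return.npy
# the matching returns for the same evaluation points.
import numpy as np
import matplotlib.pyplot as plt

env = "dmc_dog_run"  # any folder under paper_results/
steps = np.load(f"paper_results/{env}/env_interacts.npy")
returns = np.load(f"paper_results/{env}/mean_return.npy")

# If the arrays are stacked over seeds, reduce them to a single mean curve.
if returns.ndim > 1:
    returns = returns.mean(axis=0)
if steps.ndim > 1:
    steps = steps[0]

plt.plot(steps, returns)
plt.xlabel("environment interactions")
plt.ylabel("mean return")
plt.title(env)
plt.show()
```
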
-------------------------------------------------------------------------------- /paper_results/dmc_pendulum_swingup/env_interacts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/dmc_pendulum_swingup/env_interacts.npy -------------------------------------------------------------------------------- /paper_results/myo_hand_key_turn_rndm/mean_return.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/myo_hand_key_turn_rndm/mean_return.npy -------------------------------------------------------------------------------- /paper_results/myo_hand_obj_hold_rndm/mean_return.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/myo_hand_obj_hold_rndm/mean_return.npy -------------------------------------------------------------------------------- /paper_results/myo_hand_pen_twirl_rndm/mean_return.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/myo_hand_pen_twirl_rndm/mean_return.npy -------------------------------------------------------------------------------- /paper_results/myo_hand_reach_rndm/env_interacts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/myo_hand_reach_rndm/env_interacts.npy -------------------------------------------------------------------------------- /paper_results/myo_hand_key_turn_rndm/env_interacts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/myo_hand_key_turn_rndm/env_interacts.npy -------------------------------------------------------------------------------- /paper_results/myo_hand_obj_hold_rndm/env_interacts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/myo_hand_obj_hold_rndm/env_interacts.npy -------------------------------------------------------------------------------- /paper_results/myo_hand_pen_twirl_rndm/env_interacts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/DIME/HEAD/paper_results/myo_hand_pen_twirl_rndm/env_interacts.npy -------------------------------------------------------------------------------- /configs/setup.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | entity: null 3 | group: DIME_${env_name} 4 | project: DIME 5 | job_type: lr${alg.optimizer.lr_actor} 6 | activate: True 7 | -------------------------------------------------------------------------------- /configs/sampler/dt_schedule/cosine.yaml: -------------------------------------------------------------------------------- 1 | _target_: diffusion.common.scheduler.get_cosine_schedule 2 | total_steps: ${alg.actor.diff_steps} 3 | min: 0.001 4 | s: 0.008 5 | pow: 2 6 | -------------------------------------------------------------------------------- /configs/sampler/score_model/cond_pisgradnet.yaml: -------------------------------------------------------------------------------- 1 | use_target_score: ${use_target_score} 2 | num_layers: 3 3 | num_hid: 256 4 | 
outer_clip: 1e4 5 | inner_clip: 1e2 6 | 7 | weight_init: 1e-8 8 | bias_init: 0. 9 | layer_norm: false 10 | time_coder_out: ${sampler.score_model.num_hid} -------------------------------------------------------------------------------- /configs/sampler/dis.yaml: -------------------------------------------------------------------------------- 1 | # Time-Reversed Diffusion Sampler (DIS) 2 | name: dis 3 | underdamped: False 4 | integrator: EM 5 | init_std: 2.5 6 | friction: 1.0 7 | learn_prior: False 8 | learn_betas: False 9 | learn_friction: True 10 | learn_mass_matrix: False 11 | 12 | defaults: 13 | - score_model: cond_pisgradnet 14 | - dt_schedule: cosine 15 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | stable_baselines3==2.1.0 2 | gymnasium==0.29.1 3 | imageio==2.31.3 4 | mujoco==2.3.7 5 | optax==0.1.7 6 | tqdm==4.66.1 7 | rich==13.5.2 8 | rlax==0.1.6 9 | tensorboard==2.14.0 10 | tensorflow-probability==0.21.0 11 | wandb==0.18.5 12 | scipy==1.11.4 13 | shimmy==1.3.0 14 | hydra-core==1.3.2 15 | numpyro==0.15.3 16 | dm-control==1.0.14 -------------------------------------------------------------------------------- /diffusion/common/diffusion_models.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple, Callable 2 | 3 | 4 | class DiffusionModel(NamedTuple): 5 | num_steps: int 6 | forward_model: Callable 7 | backward_model: Callable 8 | drift_fn: Callable 9 | delta_t_fn: Callable 10 | friction_fn: Callable 11 | mass_fn: Callable 12 | prior_sampler: Callable 13 | prior_log_prob: Callable 14 | -------------------------------------------------------------------------------- /configs/base.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | defaults: 3 | - alg: dime 4 | - sampler: dis 5 | - setup 6 | - _self_ 7 | 8 | seed: 0 9 | use_jit: true 10 | tot_time_steps: 1e6 11 | log_freq: 100 12 | step_size: ${alg.optimizer.lr_actor} 13 | step_size_betas: ${alg.optimizer.lr_actor} 14 | use_path_gradient: False 15 | use_target_score: False 16 | dt: 0.1 17 | learn_dt: True 18 | per_step_dt: False 19 | per_dim_friction: True 20 | # Related to the learning rate scheduler (not used in DIME) 21 | use_step_size_scheduler: False 22 | warmup: const 23 | iters: ${tot_time_steps} 24 | warmup_iters: 60_000 25 | 26 | env_name: dm_control/dog-run 27 | -------------------------------------------------------------------------------- /common/type_aliases.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple, Any 2 | 3 | import flax 4 | import numpy as np 5 | from flax.training.train_state import TrainState 6 | 7 | 8 | class ActorTrainState(TrainState): 9 | batch_stats: flax.core.FrozenDict 10 | 11 | 12 | class RLTrainState(TrainState): # type: ignore[misc] 13 | target_params: flax.core.FrozenDict # type: ignore[misc] 14 | batch_stats: flax.core.FrozenDict 15 | target_batch_stats: flax.core.FrozenDict 16 | 17 | 18 | class ReplayBufferSamplesNp(NamedTuple): 19 | observations: np.ndarray 20 | actions: np.ndarray 21 | next_observations: np.ndarray 22 | dones: np.ndarray 23 | rewards: np.ndarray 24 | -------------------------------------------------------------------------------- /diffusion/common/scheduler.py: -------------------------------------------------------------------------------- 1 | 
import jax.numpy as jnp 2 | 3 | 4 | def get_linear_schedule(total_steps, min=0.01): 5 | def linear_noise_schedule(step): 6 | t = (total_steps - step) / total_steps 7 | return (1. - t) * min + t 8 | 9 | return linear_noise_schedule 10 | 11 | 12 | def get_cosine_schedule(total_steps, min=0.01, s=0.008, pow=2): 13 | def cosine_schedule(step): 14 | t = (total_steps - step) / total_steps 15 | offset = 1 + s 16 | return (1. - min) * jnp.cos(0.5 * jnp.pi * (offset - t) / offset) ** pow + min 17 | 18 | return cosine_schedule 19 | 20 | 21 | def get_constant_schedule(): 22 | def constant_schedule(step): 23 | return jnp.array(1.) 24 | 25 | return constant_schedule 26 | -------------------------------------------------------------------------------- /configs/alg/dime.yaml: -------------------------------------------------------------------------------- 1 | tau: 1.0 2 | policy_tau: 1.0 3 | utd: 2 4 | gamma: 0.99 5 | policy_delay: 3 6 | batch_size: 256 7 | buffer_size: 1000000 8 | learning_starts: 5000 9 | reset_models: False 10 | 11 | ent_coef: 12 | type: "auto" # const or auto 13 | init: 1.0 14 | target: 6.0 15 | 16 | critic: 17 | activation: 'relu' 18 | n_critics: 2 19 | hs: [2048, 2048] 20 | dropout_rate: null # None in yaml 21 | use_layer_norm: False 22 | n_atoms: 101 23 | v_min: -200 24 | v_max: 200 25 | entr_coeff: 0.005 26 | 27 | optimizer: 28 | bn: True 29 | bn_momentum: 0.99 30 | bn_mode: brn_actor 31 | bn_warmup: 100000 32 | lr_critic: 3.0e-4 33 | lr_actor: 3.0e-4 34 | b1: 0.5 35 | do_actor_grad_clip: True 36 | actor_grad_clip: 1.0 37 | 38 | 39 | actor: 40 | diff_steps: 16 41 | 42 | 43 | -------------------------------------------------------------------------------- /diffusion/common/learning_rate_scheduler.py: -------------------------------------------------------------------------------- 1 | import optax 2 | 3 | 4 | def get_learning_rate_scheduler(cfg, step_size): 5 | """Creates learning rate schedule.""" 6 | if cfg['warmup'] == 'linear': 7 | warmup_fn = optax.linear_schedule( 8 | init_value=0., end_value=step_size, 9 | transition_steps=cfg['warmup_iters']) 10 | """Creates learning rate schedule.""" 11 | elif cfg['warmup'] == 'const': 12 | warmup_fn = optax.constant_schedule(step_size) 13 | else: 14 | raise ValueError(f"No warmup scheme called {cfg['warmup']}") 15 | 16 | cosine_epochs = max(cfg['iters'] - cfg['warmup_iters'], 1) 17 | cosine_fn = optax.cosine_decay_schedule( 18 | init_value=step_size, 19 | decay_steps=cosine_epochs) 20 | 21 | schedule_fn = optax.join_schedules( 22 | schedules=[warmup_fn, cosine_fn], 23 | boundaries=[cfg['warmup_iters']]) 24 | return schedule_fn 25 | -------------------------------------------------------------------------------- /common/distributions.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | import jax.numpy as jnp 4 | import tensorflow_probability 5 | 6 | tfp = tensorflow_probability.substrates.jax 7 | tfd = tfp.distributions 8 | 9 | 10 | class TanhTransformedDistribution(tfd.TransformedDistribution): # type: ignore[name-defined] 11 | """ 12 | From https://github.com/ikostrikov/walk_in_the_park 13 | otherwise mode is not defined for Squashed Gaussian 14 | """ 15 | 16 | def __init__(self, distribution: tfd.Distribution, validate_args: bool = False): # type: ignore[name-defined] 17 | super().__init__(distribution=distribution, bijector=tfp.bijectors.Tanh(), validate_args=validate_args) 18 | 19 | def mode(self) -> jnp.ndarray: 20 | return 
self.bijector.forward(self.distribution.mode()) 21 | 22 | @classmethod 23 | def _parameter_properties(cls, dtype: Optional[Any], num_classes=None): 24 | td_properties = super()._parameter_properties(dtype, num_classes=num_classes) 25 | del td_properties["bijector"] 26 | return td_properties 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 DIME authors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import jax.numpy as jnp 4 | import flax.linen as nn 5 | 6 | 7 | def is_slurm_job(): 8 | """Checks whether the script is run within slurm""" 9 | return bool(len({k: v for k, v in os.environ.items() if 'SLURM' in k})) 10 | 11 | 12 | class ReLU(nn.Module): 13 | def __call__(self, x): 14 | return nn.relu(x) 15 | 16 | 17 | class ReLU6(nn.Module): 18 | def __call__(self, x): 19 | return nn.relu6(x) 20 | 21 | 22 | class Tanh(nn.Module): 23 | def __call__(self, x): 24 | return nn.tanh(x) 25 | 26 | 27 | class Sin(nn.Module): 28 | def __call__(self, x): 29 | return jnp.sin(x) 30 | 31 | 32 | class Elu(nn.Module): 33 | def __call__(self, x): 34 | return nn.elu(x) 35 | 36 | 37 | class GLU(nn.Module): 38 | def __call__(self, x): 39 | return nn.glu(x) 40 | 41 | 42 | class LayerNormedReLU(nn.Module): 43 | @nn.compact 44 | def __call__(self, x): 45 | return nn.LayerNorm()(nn.relu(x)) 46 | 47 | 48 | class ReLUOverMax(nn.Module): 49 | def __call__(self, x): 50 | act = nn.relu(x) 51 | return act / (jnp.max(act) + 1e-6) 52 | 53 | 54 | activation_fn = { 55 | # unbounded 56 | "relu": ReLU, 57 | "elu": Elu, 58 | "glu": GLU, 59 | # bounded 60 | "tanh": Tanh, 61 | "sin": Sin, 62 | "relu6": ReLU6, 63 | # unbounded with normalizer 64 | "layernormed_relu": LayerNormedReLU, # with normalizer 65 | "relu_over_max": ReLUOverMax, # with normalizer 66 | } -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Exit immediately if a command exits with a non-zero status. 
3 | set -e 4 | 5 | # Function to print messages with a timestamp 6 | log() { 7 | echo "[$(date +"%Y-%m-%d %H:%M:%S")] $1" 8 | } 9 | 10 | # Create a new conda environment named 'dime' with Python 3.11. 11 | log "Creating the conda environment 'dime' with Python 3.11..." 12 | conda create -n dime python=3.11 -y 13 | 14 | # Activate the new environment. 15 | # Ensure that the conda base environment is initialized for this shell. 16 | if [ -f "$(conda info --base)/etc/profile.d/conda.sh" ]; then 17 | source "$(conda info --base)/etc/profile.d/conda.sh" 18 | else 19 | log "Could not find conda.sh; make sure conda is installed and initialized." 20 | exit 1 21 | fi 22 | 23 | # Explicitly activate the 'dime' environment. 24 | log "Activating the 'dime' environment..." 25 | conda activate dime 26 | 27 | # Install the Python project requirements. 28 | if [ -f requirements.txt ]; then 29 | log "Installing project requirements from requirements.txt..." 30 | pip install -r requirements.txt 31 | else 32 | log "requirements.txt not found. Skipping pip install -r requirements.txt." 33 | fi 34 | 35 | # Install JAX with CUDA support. 36 | log "Installing jax with CUDA (cuda12_pip) support..." 37 | pip install "jax[cuda12_pip]==0.4.33" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 38 | 39 | # Install flax version 0.9.0. 40 | log "Installing flax==0.9.0..." 41 | pip install flax==0.9.0 42 | 43 | pip install jax==0.4.33 44 | 45 | # Install the specified PyTorch version. 46 | log "Installing torch==2.4.1..." 47 | pip install torch==2.4.1 48 | 49 | pip install orbax-checkpoint==0.6.4 50 | 51 | 52 | log "Setup complete! The 'dime' environment is ready for use." 53 | -------------------------------------------------------------------------------- /diffusion/od/od_sampling.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import distrax 3 | import jax.numpy as jnp 4 | 5 | 6 | def single_sample(seed, model_state, params, obs, integrator, diffusion_model, stop_grad=False): 7 | key, key_gen = jax.random.split(seed) 8 | 9 | init_x = diffusion_model.prior_sampler(params, key, 1) 10 | key, key_gen = jax.random.split(key_gen) 11 | init_x = jnp.squeeze(init_x, 0) 12 | if stop_grad: 13 | init_x = jax.lax.stop_gradient(init_x) 14 | key, key_gen = jax.random.split(key_gen) 15 | aux = (init_x, jnp.zeros(1), key) 16 | integrate = integrator(model_state, params, obs, stop_grad) 17 | aux, per_step_output = jax.lax.scan(integrate, aux, jnp.arange(0, diffusion_model.num_steps)) 18 | final_x, log_ratio, _ = aux 19 | 20 | terminal_costs = diffusion_model.prior_log_prob(init_x, params) 21 | running_cost = -(log_ratio + distrax.Tanh().forward_log_det_jacobian(final_x).sum()) 22 | # running_cost = -log_ratio 23 | 24 | final_x = distrax.Tanh().forward(final_x) 25 | x_t = per_step_output 26 | x_t = jnp.concatenate([jnp.expand_dims(init_x, 0), x_t]) 27 | x_t = x_t.at[-1].set(final_x) 28 | stochastic_costs = jnp.zeros_like(running_cost) 29 | return final_x, running_cost, stochastic_costs, terminal_costs.reshape(running_cost.shape), x_t, None 30 | 31 | 32 | def sample(key, model_state, params, obs, integrator, diffusion_model, stop_grad=False): 33 | keys = jax.random.split(key, num=obs.shape[0]) 34 | in_tuple = (keys, model_state, params, obs, integrator, diffusion_model, stop_grad) 35 | in_axes = (0, None, None, 0, None, None, None) 36 | rnd_result = jax.vmap(single_sample, in_axes=in_axes)(*in_tuple) 37 | x_0, running_costs, stochastic_costs, terminal_costs, x_t, 
_ = rnd_result 38 | 39 | return x_0, running_costs, stochastic_costs, terminal_costs, x_t, None 40 | 41 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Code for the DIME paper, published at ICML 2025. 2 | ## DIME: Diffusion-Based Maximum Entropy Reinforcement Learning 3 | 4 | This repository accompanies the paper "[DIME: Diffusion-Based Maximum Entropy Reinforcement Learning](https://arxiv.org/pdf/2502.02316)" published at ICML 2025. 5 | 6 | ### Learning Curves are available in *paper_results* 7 | **Update:** We have uploaded all the learning curve data to our repository. You can find the data in the *paper_results* folder. 8 | Additionally, we have run DIME on all remaining DMC environments and added the results to the same folder. 9 | 10 | ### Installation 11 | The file setup.sh provides a convenient way to set up the conda environment and install the required packages automatically via 12 | ```bash 13 | chmod +x setup.sh 14 | ./setup.sh 15 | ``` 16 | 17 | After installation is finished, the conda environment can be activated, and the code can be run using 18 | 19 | ```python 20 | python run_dime.py 21 | ``` 22 | 23 | ### Running DIME 24 | 25 | Specific parameters, such as the learning environment, can be set in the terminal using hydra's multirun function 26 | 27 | ```python 28 | python run_dime.py --multirun env_name=dm_control/humanoid-run 29 | ``` 30 | 31 | Detailed hyperparameter specifications are available in the configs directory. 32 | The current config file is adapted to the hyperparameters used for DMC. If you want to run DIME on the gym environments, 33 | you only need to change the v_min and v_max parameters of the critic as specified in the appendix of the paper. We used the same values for both gym environments. 34 | For example, if you would like to run DIME on gym's Humanoid-v3 environment, you can do so by running 35 | 36 | ```python 37 | python run_dime.py env_name=Humanoid-v3 alg.critic.v_min=-1600 alg.critic.v_max=1600 38 | ``` 39 | 40 | 41 | ## Acknowledgements 42 | Portions of the project are adapted from other repositories: 43 | - https://github.com/DenisBless/UnderdampedDiffusionBridges is licensed under MIT, 44 | - https://github.com/adityab/CrossQ is licensed under MIT and is built upon code from "[Stable Baselines Jax](https://github.com/araffin/sbx/)". -------------------------------------------------------------------------------- /diffusion/od/od_integrators.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | from diffusion.common.utils import sample_kernel, log_prob_kernel, check_stop_grad 5 | 6 | 7 | def get_integrator(cfg, diffusion_model): 8 | def integrator(model_state, params, obs, stop_grad=False): 9 | 10 | def integrate_EM(state, per_step_input): 11 | x, log_w, key_gen = state 12 | step = per_step_input 13 | 14 | step = step.astype(jnp.float32) 15 | 16 | # Compute SDE components 17 | dt = diffusion_model.delta_t_fn(step, params) 18 | sigma_square = 1.
/ diffusion_model.friction_fn(step, params) 19 | eta = dt * sigma_square 20 | scale = jnp.sqrt(2 * eta) 21 | 22 | # Forward kernel 23 | drift, aux = diffusion_model.drift_fn(step, x, params) 24 | fwd_mean = x + eta * (drift + diffusion_model.forward_model(step, x, obs, model_state, params, aux)) 25 | key, key_gen = jax.random.split(key_gen) 26 | x_new = sample_kernel(key, check_stop_grad(fwd_mean, stop_grad) if stop_grad else fwd_mean, scale) 27 | 28 | # Backward kernel 29 | drift_new, aux_new = diffusion_model.drift_fn(step + 1, x_new, params) 30 | bwd_mean = x_new + eta * ( 31 | drift_new + diffusion_model.backward_model(step + 1, x_new, obs, model_state, params, aux_new)) 32 | 33 | # Evaluate kernels 34 | fwd_log_prob = log_prob_kernel(x_new, fwd_mean, scale) 35 | bwd_log_prob = log_prob_kernel(x, bwd_mean, scale) 36 | 37 | # Update weight and return 38 | log_w += bwd_log_prob - fwd_log_prob 39 | 40 | key, key_gen = jax.random.split(key_gen) 41 | next_state = (x_new, log_w, key_gen) 42 | per_step_output = x_new 43 | return next_state, per_step_output 44 | 45 | if cfg.sampler.integrator == 'EM': 46 | integrate = integrate_EM 47 | else: 48 | raise ValueError(f'No integrator named {cfg.sampler.integrator}.') 49 | 50 | return integrate 51 | 52 | return integrator 53 | -------------------------------------------------------------------------------- /common/buffers.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import numpy as np 4 | from stable_baselines3.common.buffers import DictReplayBuffer 5 | from stable_baselines3.common.type_aliases import DictReplayBufferSamples 6 | from stable_baselines3.common.vec_env import VecNormalize 7 | 8 | ####### This class overwrites the DictReplayBuffer from stable baselines. It throws an exception when running the DMC's 9 | # humanoid tasks because of the head height observation. Either shimmy or dmc returns it as a 1-dim 10 | # observation, which is not aligned with the general framework. 
This class only takes care of dimensionality issues of 11 | # observations during sampling from the buffer and doesn't change anything else. 12 | 13 | 14 | class DMCCompatibleDictReplayBuffer(DictReplayBuffer): 15 | def _get_samples( 16 | self, 17 | batch_inds: np.ndarray, 18 | env: Optional[VecNormalize] = None, 19 | ) -> DictReplayBufferSamples: 20 | # type: ignore[signature-mismatch] 21 | # Sample randomly the env idx 22 | env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),)) 23 | 24 | # Normalize if needed and remove extra dimension (we are using only one env for now) 25 | obs_ = self._normalize_obs({key: np.atleast_3d(obs)[batch_inds, env_indices, :] for key, obs in self.observations.items()}, 26 | env) 27 | next_obs_ = self._normalize_obs( 28 | {key: np.atleast_3d(obs)[batch_inds, env_indices, :] for key, obs in self.next_observations.items()}, env 29 | ) 30 | 31 | # Convert to torch tensor 32 | observations = {key: self.to_torch(obs) for key, obs in obs_.items()} 33 | next_observations = {key: self.to_torch(obs) for key, obs in next_obs_.items()} 34 | 35 | return DictReplayBufferSamples( 36 | observations=observations, 37 | actions=self.to_torch(self.actions[batch_inds, env_indices]), 38 | next_observations=next_observations, 39 | # Only use dones that are not due to timeouts 40 | # deactivated by default (timeouts is initialized as an array of False) 41 | dones=self.to_torch( 42 | self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape( 43 | -1, 1 44 | ), 45 | rewards=self.to_torch(self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env)), 46 | ) -------------------------------------------------------------------------------- /diffusion/od/dis.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | from diffusion.common.utils import init_dt 5 | from diffusion.common.utils import inverse_softplus 6 | from diffusion.common.diffusion_models import DiffusionModel 7 | from diffusion.common.init_diffusion_model import init_od, init_langevin, init_model 8 | 9 | 10 | def init_dis(key, cfg, dim, obs_dim, target=None): 11 | 12 | params = {'params': {'betas': jnp.ones((cfg.alg.actor.diff_steps,)), 13 | 'prior_mean': jnp.zeros((dim,)), 14 | 'prior_std': jnp.ones((dim,)) * inverse_softplus(cfg.sampler.init_std), 15 | 'mass_std': jnp.ones(1) * inverse_softplus(1.), 16 | 'dt': init_dt(cfg), 17 | 'friction': jnp.ones(dim) * inverse_softplus(cfg.sampler.friction) if cfg.per_dim_friction else jnp.ones(1) * inverse_softplus(cfg.sampler.friction), 18 | }} 19 | 20 | prior_log_prob, prior_sampler, delta_t_fn, friction_fn, mass_fn = init_od(cfg, dim) 21 | if target is not None: 22 | langevin_fn = init_langevin(cfg, prior_log_prob, target.log_prob) 23 | 24 | def forward_model(step, x, obs, model_state, params, aux): 25 | langevin_vals = aux 26 | return model_state.apply_fn[0](params['params']['fwd_params'], x, obs, step, 27 | jax.lax.stop_gradient(langevin_vals)) 28 | 29 | def backward_model(step, x, obs, model_state, params, aux): 30 | return jnp.zeros_like(x) 31 | 32 | def drift_fn(step, x, params): 33 | if target is not None: 34 | if cfg.sampler.use_target_score: 35 | _, aux = langevin_fn(step, x, params) 36 | else: 37 | aux = None 38 | else: 39 | aux = None 40 | 41 | return jax.grad(prior_log_prob)(x, params), aux 42 | 43 | key, key_gen = jax.random.split(key) 44 | model_state = init_model(key, params, cfg, dim, obs_dim, learn_forward=True,
learn_backward=False) 45 | 46 | return DiffusionModel(num_steps=cfg.alg.actor.diff_steps, 47 | forward_model=forward_model, 48 | backward_model=backward_model, 49 | drift_fn=drift_fn, 50 | delta_t_fn=delta_t_fn, 51 | friction_fn=friction_fn, 52 | mass_fn=mass_fn, 53 | prior_sampler=prior_sampler, 54 | prior_log_prob=prior_log_prob, 55 | ), model_state 56 | -------------------------------------------------------------------------------- /diffusion/common/utils.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import jax 3 | 4 | import numpyro.distributions as npdist 5 | import wandb 6 | import matplotlib.pyplot as plt 7 | from flax import traverse_util 8 | 9 | 10 | def inverse_softplus(x): 11 | # Numerically stable implementation of inverse softplus 12 | # Threshold above which the approximation log(e^x - 1) ≈ x is used 13 | threshold = 20.0 14 | return jnp.where(x > threshold, x, jnp.log(jnp.expm1(x))) 15 | 16 | 17 | def check_stop_grad(expression, stop_grad): 18 | return jax.lax.stop_gradient(expression) if stop_grad else expression 19 | 20 | 21 | def sample_kernel(rng_key, mean, scale): 22 | eps = jax.random.normal(rng_key, shape=(mean.shape[0],)) 23 | return mean + scale * eps 24 | 25 | 26 | def log_prob_kernel(x, mean, scale): 27 | dist = npdist.Independent(npdist.Normal(loc=mean, scale=scale), 1) 28 | return dist.log_prob(x) 29 | 30 | 31 | def avg_list_entries(list, num): 32 | assert len(list) >= num 33 | print(range(0, len(list) - num)) 34 | return [sum(list[i:i + num]) / float(num) for i in range(0, len(list) - num + 1)] 35 | 36 | 37 | def reverse_transition_params(transition_params): 38 | flattened_params, tree = jax.tree_util.tree_flatten(transition_params, is_leaf=None) 39 | reversed_flattened_params = list(map(lambda w: jnp.flip(w, axis=0), flattened_params)) 40 | return jax.tree_util.tree_unflatten(tree, reversed_flattened_params) 41 | 42 | 43 | def interpolate_values(values, X): 44 | # Compute the interpolated values 45 | interpolated_values = [X] + [X + (X / 2 - X) * t for t in values[1:-1]] + [X / 2] 46 | return interpolated_values 47 | 48 | 49 | def flattened_traversal(fn): 50 | def mask(data): 51 | flat = traverse_util.flatten_dict(data) 52 | return traverse_util.unflatten_dict({k: fn(k, v) for k, v in flat.items()}) 53 | 54 | return mask 55 | 56 | 57 | def plot_annealing(model_state, cfg): 58 | if cfg.use_wandb: 59 | fig, ax = plt.subplots() 60 | b = jax.nn.softplus(model_state.params['params']['betas']) 61 | b = jnp.cumsum(b) / jnp.sum(b) 62 | 63 | ax.plot(b) 64 | return {"figures/annealing": [wandb.Image(fig)]} 65 | else: 66 | return {} 67 | 68 | 69 | def plot_timesteps(diffusion_model, model_state, cfg): 70 | if cfg.use_wandb: 71 | dt_fn = lambda step: diffusion_model.delta_t_fn(step, model_state.params) 72 | dts = jax.vmap(dt_fn)(jnp.arange(cfg.algorithm.num_steps)) 73 | fig, ax = plt.subplots() 74 | ax.plot(dts) 75 | return {"figures/timesteps": [wandb.Image(fig)]} 76 | else: 77 | return {} 78 | 79 | 80 | def init_dt(cfg): 81 | if cfg.per_step_dt: 82 | dt_schedule = cfg.sampler.dt_schedule 83 | return inverse_softplus(jnp.ones(cfg.alg.actor.diff_steps) * cfg.dt * dt_schedule(jnp.arange(cfg.alg.actor.diff_steps))) 84 | else: 85 | return jnp.ones(1) * inverse_softplus(cfg.dt) 86 | 87 | 88 | def get_sampler_init(alg_name): 89 | 90 | if alg_name == 'dis': 91 | from diffusion.od.dis import init_dis 92 | return init_dis 93 | 94 | else: 95 | raise ValueError(f'No sampler named {alg_name}.') 96 | 
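
One pattern worth calling out from the `inverse_softplus` helper above: positive quantities such as `init_std`, `friction`, and `dt` (see `diffusion/od/dis.py`) are stored in the parameter tree as `inverse_softplus(value)` and mapped back through `softplus` wherever they are used, so unconstrained gradient updates can never drive them negative. A minimal, self-contained sketch of that round trip follows; the constant 2.5 mirrors the default `init_std` in `configs/sampler/dis.yaml`, and nothing in the snippet is part of the repository's API.

```python
import jax
import jax.numpy as jnp


def inverse_softplus(x):
    # Same numerically stable form as in diffusion/common/utils.py:
    # above the threshold, log(exp(x) - 1) ~= x, so x is returned directly.
    threshold = 20.0
    return jnp.where(x > threshold, x, jnp.log(jnp.expm1(x)))


raw = inverse_softplus(jnp.array(2.5))   # stored, unconstrained parameter
value = jax.nn.softplus(raw)             # recovered when sampling
print(raw, value)                        # value ~= 2.5; stays positive for any update to raw
```
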
-------------------------------------------------------------------------------- /diffusion/common/models/pisgrad_net.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from flax import linen as nn 3 | 4 | 5 | class PISGRADNet(nn.Module): 6 | dim: int 7 | use_target_score: True 8 | layer_norm: bool = False 9 | time_coder_out: int = 64 10 | 11 | num_layers: int = 2 12 | num_hid: int = 64 13 | outer_clip: float = 1e4 14 | inner_clip: float = 1e2 15 | 16 | weight_init: float = 1e-8 17 | bias_init: float = 0. 18 | 19 | def setup(self): 20 | self.timestep_phase = self.param('timestep_phase', nn.initializers.zeros_init(), (1, self.num_hid)) 21 | self.timestep_coeff = jnp.linspace(start=0.1, stop=100, num=self.num_hid)[None] 22 | 23 | self.time_coder_state = nn.Sequential([ 24 | nn.Dense(self.num_hid), 25 | nn.gelu, 26 | nn.Dense(self.time_coder_out), 27 | ]) 28 | 29 | self.time_coder_grad = nn.Sequential([nn.Dense(self.num_hid)] + [nn.Sequential( 30 | [nn.gelu, nn.Dense(self.num_hid)]) for _ in range(self.num_layers)] + [ 31 | nn.Dense(self.dim, kernel_init=nn.initializers.constant(self.weight_init), 32 | bias_init=nn.initializers.constant(self.bias_init))]) 33 | 34 | if self.layer_norm: 35 | self.state_time_net = nn.Sequential([nn.Sequential( 36 | [nn.Dense(self.num_hid), nn.LayerNorm(), nn.gelu]) for _ in range(self.num_layers)] + [ 37 | nn.Dense(self.dim, kernel_init=nn.initializers.constant(1e-8), 38 | bias_init=nn.initializers.zeros_init())]) 39 | else: 40 | self.state_time_net = nn.Sequential([nn.Sequential( 41 | [nn.Dense(self.num_hid), nn.gelu]) for _ in range(self.num_layers)] + [ 42 | nn.Dense(self.dim, kernel_init=nn.initializers.constant(1e-8), 43 | bias_init=nn.initializers.zeros_init())]) 44 | 45 | def get_fourier_features(self, timesteps): 46 | sin_embed_cond = jnp.sin( 47 | (self.timestep_coeff * timesteps) + self.timestep_phase 48 | ) 49 | cos_embed_cond = jnp.cos( 50 | (self.timestep_coeff * timesteps) + self.timestep_phase 51 | ) 52 | return jnp.concatenate([sin_embed_cond, cos_embed_cond], axis=-1) 53 | 54 | def __call__(self, input_array, obs_array, time_array, target_score=None): 55 | time_array_emb = self.get_fourier_features(time_array) 56 | if len(input_array.shape) == 1: 57 | time_array_emb = time_array_emb[0] 58 | 59 | t_net1 = self.time_coder_state(time_array_emb) 60 | 61 | extended_input = jnp.concatenate((input_array, obs_array, t_net1), axis=-1) 62 | out_state = self.state_time_net(extended_input) 63 | out_state = jnp.clip(out_state, -self.outer_clip, self.outer_clip) 64 | if self.use_target_score: 65 | t_net2 = self.time_coder_grad(time_array_emb) 66 | target_score = jnp.clip(target_score, -self.inner_clip, self.inner_clip) 67 | return out_state + t_net2 * target_score 68 | else: 69 | return out_state 70 | -------------------------------------------------------------------------------- /run_dime.py: -------------------------------------------------------------------------------- 1 | import os 2 | import jax 3 | import time 4 | import hydra 5 | import wandb 6 | import omegaconf 7 | import traceback 8 | 9 | from common.buffers import DMCCompatibleDictReplayBuffer 10 | from diffusion.dime import DIME 11 | from omegaconf import DictConfig 12 | from models.utils import is_slurm_job 13 | from wandb.integration.sb3 import WandbCallback 14 | from stable_baselines3.common.env_util import make_vec_env 15 | from stable_baselines3.common.callbacks import CallbackList 16 | from 
models.actor_critic_evaluation_callback import EvalCallback 17 | 18 | 19 | 20 | def _create_alg(cfg: DictConfig): 21 | import gymnasium as gym 22 | try: 23 | import myosuite 24 | except ImportError: 25 | print("myosuite not installed") 26 | pass 27 | 28 | training_env = gym.make(cfg.env_name) 29 | eval_env = make_vec_env(cfg.env_name, n_envs=1, seed=cfg.seed) 30 | env_name_split = cfg.env_name.split('/') 31 | rb_class = None 32 | if env_name_split[0] == 'dm_control': 33 | rb_class = DMCCompatibleDictReplayBuffer if env_name_split[1].split('-')[0] in ['humanoid', 'fish', 'walker', 'quadruped','finger'] else None 34 | 35 | tensorboard_log_dir = f"./logs/{cfg.wandb['group']}/{cfg.wandb['job_type']}/seed= + {str(cfg.seed)}/" 36 | eval_log_dir = f"./eval_logs/{cfg.wandb['group']}/{cfg.wandb['job_type']}/seed= + {str(cfg.seed)}/eval/" 37 | 38 | 39 | model = DIME( 40 | "MultiInputPolicy" if isinstance(training_env.observation_space, gym.spaces.Dict) else "MlpPolicy", 41 | env=training_env, 42 | model_save_path=None, 43 | save_every_n_steps=int(cfg.tot_time_steps / 100000), 44 | cfg=cfg, 45 | tensorboard_log=tensorboard_log_dir, 46 | replay_buffer_class=rb_class 47 | ) 48 | 49 | # Create log dir where evaluation results will be saved 50 | os.makedirs(eval_log_dir, exist_ok=True) 51 | # Create callback that evaluates agent 52 | 53 | eval_callback = EvalCallback( 54 | eval_env, 55 | jax_random_key_for_seeds=cfg.seed, 56 | best_model_save_path=None, 57 | log_path=eval_log_dir, 58 | eval_freq=max(300000 // cfg.log_freq, 1), 59 | n_eval_episodes=5, deterministic=True, render=False 60 | ) 61 | if cfg.wandb["activate"]: 62 | callback_list = CallbackList([eval_callback, WandbCallback(verbose=0, )]) 63 | else: 64 | callback_list = CallbackList([eval_callback]) 65 | return model, callback_list 66 | 67 | 68 | def initialize_and_run(cfg): 69 | cfg = hydra.utils.instantiate(cfg) 70 | seed = cfg.seed 71 | if cfg.wandb["activate"]: 72 | name = f"seed_{seed}" 73 | wandb_config = omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) 74 | wandb.init( 75 | settings=wandb.Settings(_service_wait=300), 76 | project=cfg.wandb["project"], 77 | group=cfg.wandb["group"], 78 | job_type=cfg.wandb["job_type"], 79 | name=name, 80 | config=wandb_config, 81 | entity=cfg.wandb["entity"], 82 | sync_tensorboard=True, 83 | ) 84 | if is_slurm_job(): 85 | print(f"SLURM_JOB_ID: {os.environ.get('SLURM_JOB_ID')}") 86 | wandb.summary['SLURM_JOB_ID'] = os.environ.get('SLURM_JOB_ID') 87 | model, callback_list = _create_alg(cfg) 88 | model.learn(total_timesteps=cfg.tot_time_steps, progress_bar=True, callback=callback_list) 89 | 90 | 91 | @hydra.main(version_base=None, config_path="configs", config_name="base") 92 | def main(cfg: DictConfig) -> None: 93 | try: 94 | starting_time = time.time() 95 | if cfg.use_jit: 96 | initialize_and_run(cfg) 97 | else: 98 | with jax.disable_jit(): 99 | initialize_and_run(cfg) 100 | end_time = time.time() 101 | print(f"Training took: {(end_time - starting_time)/3600} hours") 102 | if cfg.wandb["activate"]: 103 | wandb.finish() 104 | except Exception as ex: 105 | print("-- exception occured. 
traceback :") 106 | traceback.print_tb(ex.__traceback__) 107 | print(ex, flush=True) 108 | print("--------------------------------\n") 109 | traceback.print_exception(ex) 110 | if cfg.wandb["activate"]: 111 | wandb.finish() 112 | 113 | 114 | if __name__ == "__main__": 115 | main() 116 | -------------------------------------------------------------------------------- /common/off_policy_algorithm.py: -------------------------------------------------------------------------------- 1 | from typing import Type, Union, Optional, Tuple, Dict, Any, List 2 | 3 | import jax 4 | import numpy as np 5 | from stable_baselines3.common.buffers import ReplayBuffer, DictReplayBuffer 6 | from stable_baselines3.common.noise import ActionNoise 7 | from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm 8 | from stable_baselines3.common.policies import BasePolicy 9 | from stable_baselines3.common.type_aliases import GymEnv, Schedule 10 | from stable_baselines3 import HerReplayBuffer 11 | from gymnasium import spaces 12 | 13 | 14 | class OffPolicyAlgorithmJax(OffPolicyAlgorithm): 15 | def __init__( 16 | self, 17 | policy: Type[BasePolicy], 18 | env: Union[GymEnv, str], 19 | learning_rate: Union[float, Schedule], 20 | qf_learning_rate: Optional[float] = None, 21 | buffer_size: int = 1_000_000, # 1e6 22 | learning_starts: int = 100, 23 | batch_size: int = 256, 24 | tau: float = 0.005, 25 | gamma: float = 0.99, 26 | train_freq: Union[int, Tuple[int, str]] = (1, "step"), 27 | gradient_steps: int = 1, 28 | action_noise: Optional[ActionNoise] = None, 29 | replay_buffer_class: Optional[Type[ReplayBuffer]] = None, 30 | replay_buffer_kwargs: Optional[Dict[str, Any]] = None, 31 | optimize_memory_usage: bool = False, 32 | policy_kwargs: Optional[Dict[str, Any]] = None, 33 | tensorboard_log: Optional[str] = None, 34 | verbose: int = 0, 35 | device: str = "auto", 36 | support_multi_env: bool = False, 37 | monitor_wrapper: bool = True, 38 | seed: Optional[int] = None, 39 | use_sde: bool = False, 40 | sde_sample_freq: int = -1, 41 | use_sde_at_warmup: bool = False, 42 | sde_support: bool = True, 43 | supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None, 44 | stats_window_size: int = 100, 45 | ): 46 | super().__init__( 47 | policy=policy, 48 | env=env, 49 | learning_rate=learning_rate, 50 | buffer_size=buffer_size, 51 | learning_starts=learning_starts, 52 | batch_size=batch_size, 53 | tau=tau, 54 | gamma=gamma, 55 | train_freq=train_freq, 56 | gradient_steps=gradient_steps, 57 | replay_buffer_class=replay_buffer_class, 58 | replay_buffer_kwargs=replay_buffer_kwargs, 59 | action_noise=action_noise, 60 | use_sde=use_sde, 61 | sde_sample_freq=sde_sample_freq, 62 | use_sde_at_warmup=use_sde_at_warmup, 63 | policy_kwargs=policy_kwargs, 64 | tensorboard_log=tensorboard_log, 65 | verbose=verbose, 66 | seed=seed, 67 | sde_support=sde_support, 68 | supported_action_spaces=supported_action_spaces, 69 | support_multi_env=support_multi_env, 70 | stats_window_size=stats_window_size, 71 | ) 72 | # Will be updated later 73 | self.key = jax.random.key(0) 74 | # Note: we do not allow schedule for it 75 | self.qf_learning_rate = qf_learning_rate 76 | 77 | def _get_torch_save_params(self): 78 | return [], [] 79 | 80 | def _excluded_save_params(self) -> List[str]: 81 | excluded = super()._excluded_save_params() 82 | excluded.remove("policy") 83 | return excluded 84 | 85 | def set_random_seed(self, seed: Optional[int]) -> None: # type: ignore[override] 86 | super().set_random_seed(seed) 87 | if seed is None: 
88 | # Sample random seed 89 | seed = np.random.randint(2**14) 90 | self.key = jax.random.key(seed) 91 | 92 | def _setup_model(self) -> None: 93 | if self.replay_buffer_class is None: # type: ignore[has-type] 94 | if isinstance(self.observation_space, spaces.Dict): 95 | self.replay_buffer_class = DictReplayBuffer 96 | else: 97 | self.replay_buffer_class = ReplayBuffer 98 | 99 | self._setup_lr_schedule() 100 | # By default qf_learning_rate = pi_learning_rate 101 | self.qf_learning_rate = self.qf_learning_rate or self.lr_schedule(1) 102 | self.set_random_seed(self.seed) 103 | # Make a local copy as we should not pickle 104 | # the environment when using HerReplayBuffer 105 | replay_buffer_kwargs = self.replay_buffer_kwargs.copy() 106 | if issubclass(self.replay_buffer_class, HerReplayBuffer): # type: ignore[arg-type] 107 | assert self.env is not None, "You must pass an environment when using `HerReplayBuffer`" 108 | replay_buffer_kwargs["env"] = self.env 109 | 110 | self.replay_buffer = self.replay_buffer_class( # type: ignore[misc] 111 | self.buffer_size, 112 | self.observation_space, 113 | self.action_space, 114 | device="cpu", # force cpu device to easy torch -> numpy conversion 115 | n_envs=self.n_envs, 116 | optimize_memory_usage=self.optimize_memory_usage, 117 | **replay_buffer_kwargs, 118 | ) 119 | # Convert train freq parameter to TrainFreq object 120 | self._convert_train_freq() 121 | -------------------------------------------------------------------------------- /diffusion/common/init_diffusion_model.py: -------------------------------------------------------------------------------- 1 | import distrax 2 | import jax.numpy as jnp 3 | import jax 4 | import optax 5 | from flax.training import train_state 6 | from jax._src.nn.functions import softplus 7 | 8 | from diffusion.common.learning_rate_scheduler import get_learning_rate_scheduler 9 | from diffusion.common.models.pisgrad_net import PISGRADNet 10 | from diffusion.common.utils import flattened_traversal 11 | 12 | 13 | def init_od(cfg, dim): 14 | alg_cfg = cfg.sampler 15 | 16 | def prior_sampler(params, key, n_samples): 17 | samples = distrax.MultivariateNormalDiag(params['params']['prior_mean'], 18 | jnp.ones(dim) * jax.nn.softplus(params['params']['prior_std'])).sample( 19 | seed=key, 20 | sample_shape=( 21 | n_samples,)) 22 | return samples if alg_cfg.learn_prior else jax.lax.stop_gradient(samples) 23 | 24 | if alg_cfg.learn_prior: 25 | def prior_log_prob(x, params): 26 | log_probs = distrax.MultivariateNormalDiag(params['params']['prior_mean'], 27 | jnp.ones(dim) * jax.nn.softplus( 28 | params['params']['prior_std'])).log_prob(x) 29 | return log_probs 30 | else: 31 | def prior_log_prob(x, params): 32 | log_probs = distrax.MultivariateNormalDiag(jnp.zeros(dim), jnp.ones(dim) * alg_cfg.init_std).log_prob(x) 33 | return log_probs 34 | 35 | dt_schedule = alg_cfg.dt_schedule 36 | 37 | def delta_t_fn(step, params): 38 | if cfg.per_step_dt: 39 | dt = params['params']['dt'][step.astype(int)] if cfg.learn_dt else jax.lax.stop_gradient(params['params']['dt'][step.astype(int)]) 40 | return softplus(dt) 41 | else: 42 | dt = params['params']['dt'] if cfg.learn_dt else jax.lax.stop_gradient(params['params']['dt']) 43 | return softplus(dt) * dt_schedule(step) 44 | 45 | def friction_fn(step, params): 46 | friction = jax.nn.softplus(params['params']['friction']) 47 | return friction if alg_cfg.learn_friction else jax.lax.stop_gradient(friction) 48 | 49 | def mass_fn(params): 50 | mass_std = jax.nn.softplus(params['params']['mass_std']) 51 | 
return mass_std if alg_cfg.learn_mass_matrix else jax.lax.stop_gradient(mass_std) 52 | 53 | return prior_log_prob, prior_sampler, delta_t_fn, friction_fn, mass_fn 54 | 55 | 56 | def init_langevin(cfg, prior_log_prob, target_log_prob): 57 | alg_cfg = cfg.algorithm 58 | dim = cfg.target.dim 59 | target_score_max_norm = alg_cfg.target_score_max_norm 60 | 61 | def get_betas(params): 62 | b = jax.nn.softplus(params['params']['betas']) 63 | b = jnp.cumsum(b) / jnp.sum(b) 64 | b = b if alg_cfg.learn_betas else jax.lax.stop_gradient(b) 65 | 66 | # Freeze first and last beta 67 | b = b.at[0].set(jax.lax.stop_gradient(b[0])) 68 | b = b.at[-1].set(jax.lax.stop_gradient(b[-1])) 69 | 70 | def get_beta(step): 71 | return b[jnp.array(step, int)] 72 | 73 | return get_beta 74 | 75 | def clip_target_score(target_score): 76 | target_score_norm = jnp.linalg.norm(target_score) 77 | target_score_clipped = jnp.where(target_score_norm > target_score_max_norm * jnp.sqrt(dim), 78 | (target_score_max_norm * jnp.sqrt(dim) * target_score) / target_score_norm, 79 | target_score) 80 | return target_score_clipped 81 | 82 | def langevin_fn(step, x, params): 83 | beta = get_betas(params)(step) 84 | target_score = jax.grad(lambda x: jnp.squeeze(target_log_prob(x)))(x) 85 | prior_score = jax.grad(lambda x: jnp.squeeze(prior_log_prob(x, params)))(x) 86 | if target_score_max_norm is None: 87 | return beta * target_score + (1 - beta) * prior_score, target_score 88 | 89 | else: 90 | target_score_clipped = clip_target_score(target_score) 91 | return beta * target_score_clipped + (1 - beta) * prior_score, target_score_clipped 92 | 93 | return langevin_fn 94 | 95 | 96 | def init_model(key, params, cfg, dim, obs_dim, learn_forward=True, learn_backward=True): 97 | # Define the model 98 | 99 | in_dim = 2 * dim if cfg.sampler.underdamped else dim 100 | 101 | key, key_gen = jax.random.split(key) 102 | if learn_forward: 103 | fwd_model = PISGRADNet(dim=dim, **cfg.sampler.score_model) 104 | fwd_params = fwd_model.init(key, jnp.ones([cfg.alg.batch_size, in_dim]), 105 | jnp.ones(([cfg.alg.batch_size, obs_dim])), 106 | jnp.ones([cfg.alg.batch_size, 1]), 107 | jnp.ones([cfg.alg.batch_size, dim])) 108 | params['params']['fwd_params'] = fwd_params 109 | fwd_apply_fn = fwd_model.apply 110 | else: 111 | fwd_apply_fn = None 112 | 113 | key, key_gen = jax.random.split(key_gen) 114 | if learn_backward: 115 | bwd_model = PISGRADNet(dim=dim, **cfg.sampler.score_model) 116 | bwd_params = bwd_model.init(key, jnp.ones([cfg.alg.batch_size, in_dim]), 117 | jnp.ones(([cfg.alg.batch_size, obs_dim])), 118 | jnp.ones([cfg.alg.batch_size, 1]), 119 | jnp.ones([cfg.alg.batch_size, dim])) 120 | params['params']['bwd_params'] = bwd_params 121 | bwd_apply_fn = bwd_model.apply 122 | else: 123 | bwd_apply_fn = None 124 | 125 | if cfg.use_step_size_scheduler: 126 | model_opt = optax.masked(optax.adam(get_learning_rate_scheduler(cfg, cfg.step_size), 127 | b1=cfg.alg.optimizer.b1), 128 | mask=flattened_traversal( 129 | lambda path, _: ('fwd_params' in path) or ('bwd_params' in path))) 130 | betas_opt = optax.masked(optax.adam(get_learning_rate_scheduler(cfg, cfg.step_size_betas), 131 | b1=cfg.alg.optimizer.b1), 132 | mask=flattened_traversal( 133 | lambda path, _: ('fwd_params' not in path) and ('bwd_params' not in path))) 134 | else: 135 | model_opt = optax.masked(optax.adam(cfg.step_size, b1=cfg.alg.optimizer.b1), 136 | mask=flattened_traversal( 137 | lambda path, _: ('fwd_params' in path) or ('bwd_params' in path))) 138 | betas_opt = 
optax.masked(optax.adam(cfg.step_size_betas, b1=cfg.alg.optimizer.b1), 139 | mask=flattened_traversal( 140 | lambda path, _: ('fwd_params' not in path) and ('bwd_params' not in path))) 141 | 142 | if cfg.alg.optimizer.do_actor_grad_clip: 143 | optimizer = optax.chain(optax.zero_nans(), 144 | optax.clip(cfg.alg.optimizer.actor_grad_clip), 145 | model_opt, betas_opt) 146 | else: 147 | optimizer = optax.chain(optax.zero_nans(), 148 | model_opt, betas_opt) 149 | 150 | model_state = train_state.TrainState.create(apply_fn=(fwd_apply_fn, bwd_apply_fn), params=params, tx=optimizer) 151 | 152 | return model_state 153 | -------------------------------------------------------------------------------- /diffusion/diffusion_policy.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import numpy as np 3 | import jax.numpy as jnp 4 | 5 | from functools import partial 6 | 7 | import optax 8 | from gymnasium import spaces 9 | from models.critic import VectorCritic 10 | from diffusion.common.utils import get_sampler_init 11 | from diffusion.od.od_integrators import get_integrator as get_integrator_od 12 | from diffusion.od.od_sampling import sample as sample_od 13 | from common.policies import BaseJaxPolicy 14 | from common.type_aliases import RLTrainState 15 | from stable_baselines3.common.type_aliases import Schedule 16 | 17 | from models.utils import activation_fn 18 | 19 | 20 | class DiffPol(BaseJaxPolicy): 21 | def __init__(self, 22 | observation_space: spaces.Space, 23 | action_space: spaces.Box, 24 | cfg, 25 | squash_output: bool = True, 26 | **kwargs, 27 | ): 28 | super().__init__(observation_space, 29 | action_space, 30 | features_extractor=None, 31 | features_extractor_kwargs=None, 32 | squash_output=squash_output) 33 | self.cfg = cfg 34 | self.use_sde = False 35 | 36 | def build(self, key, lr_schedule: Schedule, qf_learning_rate: float): 37 | key, score_key, stat_distr_key, qf_key, dropout_key, stat_distr_bn_key, bn_key = jax.random.split(key, 7) 38 | # Keep a key for the actor 39 | key, self.key = jax.random.split(key, 2) 40 | # Initialize noise 41 | self.reset_noise() 42 | 43 | if isinstance(self.observation_space, spaces.Dict): 44 | obs = jnp.array([spaces.flatten(self.observation_space, self.observation_space.sample())]) 45 | else: 46 | obs = jnp.array([self.observation_space.sample()]) 47 | action = jnp.array([self.action_space.sample()]) 48 | 49 | a_dim = self.action_space.shape[0] 50 | obs_dim = obs.shape[1] 51 | 52 | # initialize Q-function 53 | self.qf = VectorCritic( 54 | dropout_rate=self.cfg.alg.critic.dropout_rate, 55 | use_layer_norm=self.cfg.alg.critic.use_layer_norm, 56 | use_batch_norm=self.cfg.alg.optimizer.bn, 57 | bn_warmup=self.cfg.alg.optimizer.bn_warmup, 58 | batch_norm_momentum=self.cfg.alg.optimizer.bn_momentum, 59 | batch_norm_mode=self.cfg.alg.optimizer.bn_mode, 60 | net_arch=self.cfg.alg.critic.hs, 61 | activation_fn=activation_fn[self.cfg.alg.critic.activation], 62 | n_critics=self.cfg.alg.critic.n_critics, 63 | n_atoms=self.cfg.alg.critic.n_atoms, 64 | ) 65 | 66 | qf_init_variables = self.qf.init( 67 | {"params": qf_key, "dropout": dropout_key, "batch_stats": bn_key}, 68 | obs, 69 | action, 70 | train=False, 71 | ) 72 | target_qf_init_variables = self.qf.init( 73 | {"params": qf_key, "dropout": dropout_key, "batch_stats": bn_key}, 74 | obs, 75 | action, 76 | train=False, 77 | ) 78 | 79 | self.qf_state = RLTrainState.create( 80 | apply_fn=self.qf.apply, 81 | params=qf_init_variables["params"], 82 | 
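# target_params / target_batch_stats below hold a separate copy of the critic that
# DIME.soft_update moves towards the live parameters via Polyak averaging with coefficient tau.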
batch_stats=qf_init_variables["batch_stats"], 83 | target_params=target_qf_init_variables["params"], 84 | target_batch_stats=target_qf_init_variables["batch_stats"], 85 | tx=optax.adam( 86 | learning_rate=qf_learning_rate, # type: ignore[call-arg] 87 | **dict({ 88 | 'b1': self.cfg.alg.optimizer.b1, 89 | 'b2': 0.999 # default 90 | }), 91 | ), 92 | ) 93 | 94 | self.qf.apply = jax.jit( # type: ignore[method-assign] 95 | self.qf.apply, 96 | static_argnames=("dropout_rate", "use_layer_norm", 97 | "use_batch_norm", "batch_norm_momentum", "bn_mode"), 98 | ) 99 | 100 | # Initialize actor 101 | key, diff_key = jax.random.split(key, 2) 102 | self.actor_model, self.actor_state = get_sampler_init(self.cfg.sampler.name)(diff_key, self.cfg, a_dim, obs_dim) 103 | target_model_state = get_sampler_init(self.cfg.sampler.name)(diff_key, self.cfg, a_dim, obs_dim) 104 | self.actor_target_model, self.target_actor_state = target_model_state 105 | self.integrator = get_integrator_od(self.cfg, self.actor_model) 106 | self.target_integrator = get_integrator_od(self.cfg, self.actor_target_model) 107 | sampler = sample_od 108 | self.sampler = partial(sampler, integrator=self.integrator, diffusion_model=self.actor_model) 109 | self.target_sampler = partial(sampler, integrator=self.target_integrator, 110 | diffusion_model=self.actor_target_model) 111 | return key 112 | 113 | @staticmethod 114 | @partial(jax.jit, static_argnames=["sampler", "return_logprob"]) 115 | def sample_action(actor_state, actor_params, observations, key, sampler, return_logprob=False): 116 | out = sampler(key, actor_state, actor_params, observations, stop_grad=False) 117 | # terminal costs = prior log prob loss for od and prior log prob loss - momentum loss for ud 118 | final_action, running_costs, stochastic_costs, terminal_costs, a_t, v_t = out 119 | return final_action, running_costs, stochastic_costs, terminal_costs, a_t, v_t 120 | 121 | def _predict(self, observation: np.ndarray, deterministic: bool = False) -> np.ndarray: 122 | # Trick to use gSDE: repeat sampled noise by using the same noise key 123 | if not self.use_sde: 124 | self.reset_noise() 125 | actions, *_ = DiffPol.sample_action(self.actor_state, self.actor_state.params, observation, self.noise_key, 126 | self.sampler) 127 | return actions[0] 128 | 129 | def _predict2(self, observation: np.ndarray, deterministic: bool = False) -> np.ndarray: 130 | # Trick to use gSDE: repeat sampled noise by using the same noise key 131 | if not self.use_sde: 132 | self.reset_noise() 133 | actions, _, _, _, la, _ = DiffPol.sample_action(self.actor_state, self.actor_state.params, observation, 134 | self.noise_key, self.sampler) 135 | actions = (actions, la) 136 | return actions 137 | 138 | def reset_noise(self, batch_size: int = 1) -> None: 139 | """ 140 | Sample new weights for the exploration matrix, when using gSDE. 
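In this policy it simply splits ``self.key`` to obtain a fresh ``noise_key`` for the next action sample; ``batch_size`` is unused and only kept for API compatibility.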
141 | """ 142 | self.key, self.noise_key = jax.random.split(self.key, 2) 143 | 144 | def forward(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray: 145 | return self._predict(obs, deterministic=deterministic) 146 | 147 | def predict_critic(self, observation: np.ndarray, action: np.ndarray) -> np.ndarray: 148 | 149 | if not self.use_sde: 150 | self.reset_noise() 151 | 152 | def Q(params, batch_stats, o, a, dropout_key): 153 | return self.qf_state.apply_fn( 154 | {"params": params, "batch_stats": batch_stats}, 155 | o, a, 156 | rngs={"dropout": dropout_key}, 157 | train=False 158 | ) 159 | 160 | return jax.jit(Q)( 161 | self.qf_state.params, 162 | self.qf_state.batch_stats, 163 | observation, 164 | action, 165 | self.noise_key, 166 | ) -------------------------------------------------------------------------------- /common/policies.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.common.policies import BasePolicy 2 | from functools import partial 3 | import jax 4 | from typing import Dict, Optional, Tuple, Union, no_type_check 5 | 6 | import numpy as np 7 | from gymnasium import spaces 8 | from stable_baselines3.common.preprocessing import maybe_transpose, is_image_space 9 | from stable_baselines3.common.utils import is_vectorized_observation 10 | 11 | 12 | class BaseJaxPolicy(BasePolicy): 13 | def __init__(self, *args, **kwargs): 14 | super().__init__( 15 | *args, 16 | **kwargs, 17 | ) 18 | 19 | @staticmethod 20 | @partial(jax.jit, static_argnames=["return_logprob"]) 21 | def sample_action(actor_state, obervations, key, return_logprob=False): 22 | if hasattr(actor_state, "batch_stats"): 23 | dist = actor_state.apply_fn({"params": actor_state.params, "batch_stats": actor_state.batch_stats}, 24 | obervations, train=False) 25 | else: 26 | dist = actor_state.apply_fn(actor_state.params, obervations) 27 | action = dist.sample(seed=key) 28 | if not return_logprob: 29 | return action 30 | else: 31 | return action, dist.log_prob(action) 32 | 33 | @staticmethod 34 | @partial(jax.jit, static_argnames=["return_logprob"]) 35 | def select_action(actor_state, obervations, return_logprob=False): 36 | if hasattr(actor_state, "batch_stats"): 37 | dist = actor_state.apply_fn({"params": actor_state.params, "batch_stats": actor_state.batch_stats}, 38 | obervations, train=False) 39 | else: 40 | dist = actor_state.apply_fn(actor_state.params, obervations) 41 | action = dist.mode() 42 | 43 | if not return_logprob: 44 | return action 45 | else: 46 | return action, dist.log_prob(action) 47 | 48 | @no_type_check 49 | def predict( 50 | self, 51 | observation: Union[np.ndarray, Dict[str, np.ndarray]], 52 | state: Optional[Tuple[np.ndarray, ...]] = None, 53 | episode_start: Optional[np.ndarray] = None, 54 | deterministic: bool = False, 55 | ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: 56 | # self.set_training_mode(False) 57 | 58 | observation, vectorized_env = self.prepare_obs(observation) 59 | 60 | actions = self._predict(observation, deterministic=deterministic) 61 | 62 | # Convert to numpy, and reshape to the original action shape 63 | actions = np.array(actions).reshape((-1, *self.action_space.shape)) 64 | 65 | if isinstance(self.action_space, spaces.Box): 66 | if self.squash_output: 67 | # Clip due to numerical instability 68 | actions = np.clip(actions, -1, 1) 69 | # Rescale to proper domain when using squashing 70 | actions = self.unscale_action(actions) 71 | else: 72 | # Actions could be on arbitrary scale, so clip the 
actions to avoid 73 | # out of bound error (e.g. if sampling from a Gaussian distribution) 74 | actions = np.clip(actions, self.action_space.low, self.action_space.high) 75 | 76 | # Remove batch dimension if needed 77 | if not vectorized_env: 78 | actions = actions.squeeze(axis=0) # type: ignore[call-overload] 79 | 80 | return actions, state 81 | 82 | def prepare_obs(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[np.ndarray, bool]: 83 | vectorized_env = False 84 | if isinstance(observation, dict): 85 | assert isinstance(self.observation_space, spaces.Dict) 86 | # Minimal dict support: flatten 87 | keys = list(self.observation_space.keys()) 88 | vectorized_env = is_vectorized_observation(observation[keys[0]], self.observation_space[keys[0]]) 89 | 90 | # Add batch dim and concatenate 91 | observation = np.concatenate( 92 | [np.atleast_2d(observation[key].reshape(-1, *self.observation_space[key].shape)) for key in keys], 93 | axis=1, 94 | ) 95 | # need to copy the dict as the dict in VecFrameStack will become a torch tensor 96 | # observation = copy.deepcopy(observation) 97 | # for key, obs in observation.items(): 98 | # obs_space = self.observation_space.spaces[key] 99 | # if is_image_space(obs_space): 100 | # obs_ = maybe_transpose(obs, obs_space) 101 | # else: 102 | # obs_ = np.array(obs) 103 | # vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space) 104 | # # Add batch dimension if needed 105 | # observation[key] = obs_.reshape((-1, *self.observation_space[key].shape)) 106 | 107 | elif is_image_space(self.observation_space): 108 | # Handle the different cases for images 109 | # as PyTorch use channel first format 110 | observation = maybe_transpose(observation, self.observation_space) 111 | 112 | else: 113 | observation = np.array(observation) 114 | 115 | if not isinstance(self.observation_space, spaces.Dict): 116 | assert isinstance(observation, np.ndarray) 117 | vectorized_env = is_vectorized_observation(observation, self.observation_space) 118 | # Add batch dimension if needed 119 | observation = observation.reshape((-1, *self.observation_space.shape)) # type: ignore[misc] 120 | 121 | assert isinstance(observation, np.ndarray) 122 | return observation, vectorized_env 123 | 124 | def set_training_mode(self, mode: bool) -> None: 125 | # self.actor.set_training_mode(mode) 126 | # self.critic.set_training_mode(mode) 127 | self.training = mode 128 | 129 | @no_type_check 130 | def predict2( 131 | self, 132 | observation: Union[np.ndarray, Dict[str, np.ndarray]], 133 | state: Optional[Tuple[np.ndarray, ...]] = None, 134 | episode_start: Optional[np.ndarray] = None, 135 | deterministic: bool = False, 136 | ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: 137 | # self.set_training_mode(False) 138 | 139 | observation, vectorized_env = self.prepare_obs(observation) 140 | 141 | all_actions = self._predict2(observation, deterministic=deterministic) 142 | latent_actions = all_actions[1] 143 | actions = all_actions[0] 144 | # Convert to numpy, and reshape to the original action shape 145 | actions = np.array(actions).reshape((-1, *self.action_space.shape)) 146 | 147 | if isinstance(self.action_space, spaces.Box): 148 | if self.squash_output: 149 | # Clip due to numerical instability 150 | actions = np.clip(actions, -1, 1) 151 | # Rescale to proper domain when using squashing 152 | actions = self.unscale_action(actions) 153 | else: 154 | # Actions could be on arbitrary scale, so clip the actions to avoid 155 | # out of bound error (e.g. 
if sampling from a Gaussian distribution) 156 | actions = np.clip(actions, self.action_space.low, self.action_space.high) 157 | 158 | # Remove batch dimension if needed 159 | if not vectorized_env: 160 | actions = actions.squeeze(axis=0) # type: ignore[call-overload] 161 | 162 | return actions, latent_actions, state 163 | -------------------------------------------------------------------------------- /models/actor_critic_evaluation_callback.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from stable_baselines3.common.callbacks import EventCallback, BaseCallback 4 | from typing import Any, Dict, Optional, Union 5 | import gymnasium as gym 6 | from stable_baselines3.common.evaluation import evaluate_policy 7 | from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization 8 | import numpy as np 9 | import os 10 | import jax 11 | 12 | 13 | class EvalCallback(EventCallback): 14 | """ 15 | Callback for evaluating an agent. 16 | 17 | .. warning:: 18 | 19 | When using multiple environments, each call to ``env.step()`` 20 | will effectively correspond to ``n_envs`` steps. 21 | To account for that, you can use ``eval_freq = max(eval_freq // n_envs, 1)`` 22 | 23 | :param eval_env: The environment used for initialization 24 | :param callback_on_new_best: Callback to trigger 25 | when there is a new best model according to the ``mean_reward`` 26 | :param callback_after_eval: Callback to trigger after every evaluation 27 | :param n_eval_episodes: The number of episodes to test the agent 28 | :param eval_freq: Evaluate the agent every ``eval_freq`` call of the callback. 29 | :param log_path: Path to a folder where the evaluations (``evaluations.npz``) 30 | will be saved. It will be updated at each evaluation. 31 | :param best_model_save_path: Path to a folder where the best model 32 | according to performance on the eval env will be saved. 33 | :param deterministic: Whether the evaluation should 34 | use stochastic or deterministic actions. 
35 | :param render: Whether to render or not the environment during evaluation 36 | :param verbose: Verbosity level: 0 for no output, 1 for indicating information about evaluation results 37 | :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been 38 | wrapped with a Monitor wrapper) 39 | """ 40 | 41 | def __init__( 42 | self, 43 | eval_env: Union[gym.Env, VecEnv], 44 | jax_random_key_for_seeds: int, 45 | callback_on_new_best: Optional[BaseCallback] = None, 46 | callback_after_eval: Optional[BaseCallback] = None, 47 | n_eval_episodes: int = 5, 48 | eval_freq: int = 10000, 49 | log_path: Optional[str] = None, 50 | best_model_save_path: Optional[str] = None, 51 | deterministic: bool = True, 52 | render: bool = False, 53 | verbose: int = 1, 54 | warn: bool = True, 55 | ): 56 | super().__init__(callback_after_eval, verbose=verbose) 57 | 58 | self.callback_on_new_best = callback_on_new_best 59 | if self.callback_on_new_best is not None: 60 | # Give access to the parent 61 | self.callback_on_new_best.parent = self 62 | 63 | self.n_eval_episodes = n_eval_episodes 64 | self.eval_freq = eval_freq 65 | self.best_mean_reward = -np.inf 66 | self.last_mean_reward = -np.inf 67 | self.deterministic = deterministic 68 | self.render = render 69 | self.warn = warn 70 | 71 | # Convert to VecEnv for consistency 72 | if not isinstance(eval_env, VecEnv): 73 | eval_env = DummyVecEnv([lambda: eval_env]) 74 | 75 | self.eval_env = eval_env 76 | self.best_model_save_path = best_model_save_path 77 | # Logs will be written in ``evaluations.npz`` 78 | if log_path is not None: 79 | log_path = os.path.join(log_path, "evaluations") 80 | self.log_path = log_path 81 | self.evaluations_results = [] 82 | self.evaluations_timesteps = [] 83 | self.evaluations_length = [] 84 | # For computing success rate 85 | self._is_success_buffer = [] 86 | self._per_time_is_success_buffer = [] 87 | self.evaluations_successes = [] 88 | 89 | # generate a list of 1M random integers, using a jax random key supplied in the args 90 | seed_list = jax.random.randint(jax.random.key(jax_random_key_for_seeds), (10000000,), 0, 2 ** 30 - 1) 91 | # cast to numpy 92 | self.seed_list = np.array(seed_list) 93 | 94 | 95 | def _init_callback(self) -> None: 96 | # Does not work in some corner cases, where the wrapper is not the same 97 | if not isinstance(self.training_env, type(self.eval_env)): 98 | warnings.warn("Training and eval env are not of the same type" f"{self.training_env} != {self.eval_env}") 99 | 100 | # Create folders if needed 101 | if self.best_model_save_path is not None: 102 | os.makedirs(self.best_model_save_path, exist_ok=True) 103 | if self.log_path is not None: 104 | os.makedirs(os.path.dirname(self.log_path), exist_ok=True) 105 | 106 | # Init callback called on new best model 107 | if self.callback_on_new_best is not None: 108 | self.callback_on_new_best.init_callback(self.model) 109 | 110 | def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None: 111 | """ 112 | Callback passed to the ``evaluate_policy`` function 113 | in order to log the success rate (when applicable), 114 | for instance when using HER. 
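For MyoSuite tasks the ``solved`` info key is used instead, and a per-time-step success buffer is filled in addition to the end-of-episode one.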
115 | 116 | :param locals_: 117 | :param globals_: 118 | """ 119 | info = locals_["info"] 120 | maybe_is_success = info.get("is_success") or info.get("solved") # solved for myo suite 121 | local_per_time_step_buffer = [] 122 | if maybe_is_success is not None: 123 | local_per_time_step_buffer.append(maybe_is_success) 124 | if locals_["done"]: 125 | maybe_is_success = info.get("is_success") or info.get("solved") # solved for myo suite 126 | if maybe_is_success is not None: 127 | self._is_success_buffer.append(maybe_is_success) 128 | if len(local_per_time_step_buffer) > 0: 129 | local_per_time_step_buffer = np.array(local_per_time_step_buffer) 130 | local_per_time_step_buffer = np.mean(local_per_time_step_buffer, axis=0) 131 | self._per_time_is_success_buffer.append(1 if local_per_time_step_buffer > 0 else 0) 132 | 133 | def _on_step(self) -> bool: 134 | continue_training = True 135 | 136 | if self.eval_freq > 0 and (self.n_calls % self.eval_freq == 0 or self.n_calls == 1): 137 | # reset the env with a new seed at the current timestep (reproducibilty) 138 | self.eval_env.seed(int(self.seed_list[self.n_calls])) 139 | 140 | # Sync training and eval env if there is VecNormalize 141 | if self.model.get_vec_normalize_env() is not None: 142 | try: 143 | sync_envs_normalization(self.training_env, self.eval_env) 144 | except AttributeError as e: 145 | raise AssertionError( 146 | "Training and eval env are not wrapped the same way, " 147 | "see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback " 148 | "and warning above." 149 | ) from e 150 | 151 | # Reset success rate buffer 152 | self._is_success_buffer = [] 153 | self._per_time_is_success_buffer = [] 154 | 155 | episode_rewards, episode_lengths = evaluate_policy( 156 | self.model, 157 | self.eval_env, 158 | n_eval_episodes=self.n_eval_episodes, 159 | render=self.render, 160 | deterministic=self.deterministic, 161 | return_episode_rewards=True, 162 | warn=self.warn, 163 | callback=self._log_success_callback, 164 | ) 165 | 166 | if self.log_path is not None: 167 | self.evaluations_timesteps.append(self.num_timesteps) 168 | self.evaluations_results.append(episode_rewards) 169 | self.evaluations_length.append(episode_lengths) 170 | 171 | kwargs = {} 172 | # Save success log if present 173 | if len(self._is_success_buffer) > 0: 174 | self.evaluations_successes.append(self._is_success_buffer) 175 | kwargs = dict(successes=self.evaluations_successes) 176 | 177 | mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards) 178 | mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths) 179 | self.last_mean_reward = mean_reward 180 | 181 | if self.verbose >= 1: 182 | print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}") 183 | print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}") 184 | # Add to current Logger 185 | self.logger.record("eval/mean_reward", float(mean_reward)) 186 | self.logger.record("eval/mean_ep_length", mean_ep_length) 187 | 188 | if len(self._is_success_buffer) > 0: 189 | success_rate = np.mean(self._is_success_buffer) 190 | per_time_success_rate = np.mean(self._per_time_is_success_buffer) 191 | if self.verbose >= 1: 192 | print(f"Success rate: {100 * success_rate:.2f}%") 193 | self.logger.record("eval/success_rate", success_rate) 194 | self.logger.record("eval/per_time_success_rate", per_time_success_rate) 195 | 196 | # Dump log so the evaluation results are printed with the correct timestep 197 
| self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") 198 | self.logger.dump(self.num_timesteps) 199 | 200 | if mean_reward > self.best_mean_reward: 201 | if self.verbose >= 1: 202 | print("New best mean reward!") 203 | if self.best_model_save_path is not None: 204 | self.model.save(os.path.join(self.best_model_save_path, "best_model")) 205 | self.best_mean_reward = mean_reward 206 | # Trigger callback on new best model, if needed 207 | if self.callback_on_new_best is not None: 208 | continue_training = self.callback_on_new_best.on_step() 209 | 210 | # Trigger callback after every evaluation, if needed 211 | if self.callback is not None: 212 | continue_training = continue_training and self._on_event() 213 | 214 | return continue_training 215 | 216 | def update_child_locals(self, locals_: Dict[str, Any]) -> None: 217 | """ 218 | Update the references to the local variables. 219 | 220 | :param locals_: the local variables during rollout collection 221 | """ 222 | if self.callback: 223 | self.callback.update_locals(locals_) -------------------------------------------------------------------------------- /models/critic.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, Optional, Type, Callable, Any, Union, List, Dict, Tuple 2 | import jax 3 | import numpy as np 4 | 5 | import flax.linen as nn 6 | import jax.numpy as jnp 7 | import optax 8 | from flax.linen import initializers 9 | from flax.linen.module import Module, compact, merge_param 10 | from flax.linen.normalization import _canonicalize_axes, _compute_stats, _normalize 11 | from gymnasium import spaces 12 | from stable_baselines3.common.type_aliases import Schedule 13 | 14 | from common.distributions import TanhTransformedDistribution 15 | import tensorflow_probability 16 | 17 | from common.policies import BaseJaxPolicy 18 | from common.type_aliases import ActorTrainState, RLTrainState 19 | 20 | tfp = tensorflow_probability.substrates.jax 21 | tfd = tfp.distributions 22 | 23 | PRNGKey = Any 24 | Array = Any 25 | Shape = Tuple[int, ...] 26 | Dtype = Any # this could be a real type? 27 | Axes = Union[int, Sequence[int]] 28 | 29 | 30 | class BatchRenorm(Module): 31 | """BatchRenorm Module, implemented based on the Batch Renormalization paper (https://arxiv.org/abs/1702.03275). 32 | and adapted from Flax's BatchNorm implementation: 33 | https://github.com/google/flax/blob/ce8a3c74d8d1f4a7d8f14b9fb84b2cc76d7f8dbf/flax/linen/normalization.py#L228 34 | 35 | 36 | Attributes: 37 | use_running_average: if True, the statistics stored in batch_stats will be 38 | used instead of computing the batch statistics on the input. 39 | axis: the feature or non-batch axis of the input. 40 | momentum: decay rate for the exponential moving average of the batch 41 | statistics. 42 | epsilon: a small float added to variance to avoid dividing by zero. 43 | dtype: the dtype of the result (default: infer from input and params). 44 | param_dtype: the dtype passed to parameter initializers (default: float32). 45 | use_bias: if True, bias (beta) is added. 46 | use_scale: if True, multiply by scale (gamma). When the next layer is linear 47 | (also e.g. nn.relu), this can be disabled since the scaling will be done 48 | by the next layer. 49 | bias_init: initializer for bias, by default, zero. 50 | scale_init: initializer for scale, by default, one. 51 | axis_name: the axis name used to combine batch statistics from multiple 52 | devices. 
See `jax.pmap` for a description of axis names (default: None). 53 | axis_index_groups: groups of axis indices within that named axis 54 | representing subsets of devices to reduce over (default: None). For 55 | example, `[[0, 1], [2, 3]]` would independently batch-normalize over the 56 | examples on the first two and last two devices. See `jax.lax.psum` for 57 | more details. 58 | use_fast_variance: If true, use a faster, but less numerically stable, 59 | calculation for the variance. 60 | """ 61 | 62 | use_running_average: Optional[bool] = None 63 | axis: int = -1 64 | momentum: float = 0.999 65 | bn_warmup: int = 100_000 66 | epsilon: float = 0.001 67 | dtype: Optional[Dtype] = None 68 | param_dtype: Dtype = jnp.float32 69 | use_bias: bool = True 70 | use_scale: bool = True 71 | bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros 72 | scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones 73 | axis_name: Optional[str] = None 74 | axis_index_groups: Any = None 75 | use_fast_variance: bool = True 76 | 77 | @compact 78 | def __call__(self, x, use_running_average: Optional[bool] = None): 79 | """ 80 | Args: 81 | x: the input to be normalized. 82 | use_running_average: if true, the statistics stored in batch_stats will be 83 | used instead of computing the batch statistics on the input. 84 | 85 | Returns: 86 | Normalized inputs (the same shape as inputs). 87 | """ 88 | 89 | use_running_average = merge_param( 90 | 'use_running_average', self.use_running_average, use_running_average 91 | ) 92 | feature_axes = _canonicalize_axes(x.ndim, self.axis) 93 | reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes) 94 | feature_shape = [x.shape[ax] for ax in feature_axes] 95 | 96 | ra_mean = self.variable( 97 | 'batch_stats', 98 | 'mean', 99 | lambda s: jnp.zeros(s, jnp.float32), 100 | feature_shape, 101 | ) 102 | ra_var = self.variable( 103 | 'batch_stats', 'var', lambda s: jnp.ones(s, jnp.float32), feature_shape 104 | ) 105 | 106 | r_max = self.variable( 107 | 'batch_stats', 108 | 'r_max', 109 | lambda s: s, 110 | 3, 111 | ) 112 | d_max = self.variable( 113 | 'batch_stats', 114 | 'd_max', 115 | lambda s: s, 116 | 5, 117 | ) 118 | steps = self.variable( 119 | 'batch_stats', 120 | 'steps', 121 | lambda s: s, 122 | 0, 123 | ) 124 | 125 | if use_running_average: 126 | mean, var = ra_mean.value, ra_var.value 127 | custom_mean = mean 128 | custom_var = var 129 | else: 130 | mean, var = _compute_stats( 131 | x, 132 | reduction_axes, 133 | dtype=self.dtype, 134 | axis_name=self.axis_name if not self.is_initializing() else None, 135 | axis_index_groups=self.axis_index_groups, 136 | use_fast_variance=self.use_fast_variance, 137 | ) 138 | custom_mean = mean 139 | custom_var = var 140 | if not self.is_initializing(): 141 | # The code below is implemented following the Batch Renormalization paper 142 | std = jnp.sqrt(var + self.epsilon) 143 | ra_std = jnp.sqrt(ra_var.value + self.epsilon) 144 | r = jax.lax.stop_gradient(std / ra_std) 145 | r = jnp.clip(r, 1 / r_max.value, r_max.value) 146 | d = jax.lax.stop_gradient((mean - ra_mean.value) / ra_std) 147 | d = jnp.clip(d, -d_max.value, d_max.value) 148 | tmp_var = var / (r ** 2) 149 | tmp_mean = mean - d * jnp.sqrt(custom_var) / r 150 | 151 | # Warm up batch renorm for 100_000 steps to build up proper running statistics 152 | # warmed_up = jnp.greater_equal(steps.value, 100_000).astype(jnp.float32) 153 | warmed_up = jnp.greater_equal(steps.value, self.bn_warmup).astype(jnp.float32) 154 | custom_var = 
warmed_up * tmp_var + (1. - warmed_up) * custom_var 155 | custom_mean = warmed_up * tmp_mean + (1. - warmed_up) * custom_mean 156 | 157 | ra_mean.value = ( 158 | self.momentum * ra_mean.value + (1 - self.momentum) * mean 159 | ) 160 | ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var 161 | steps.value += 1 162 | 163 | return _normalize( 164 | self, 165 | x, 166 | custom_mean, 167 | custom_var, 168 | reduction_axes, 169 | feature_axes, 170 | self.dtype, 171 | self.param_dtype, 172 | self.epsilon, 173 | self.use_bias, 174 | self.use_scale, 175 | self.bias_init, 176 | self.scale_init, 177 | ) 178 | 179 | 180 | class Critic(nn.Module): 181 | net_arch: Sequence[int] 182 | activation_fn: Type[nn.Module] 183 | batch_norm_momentum: float 184 | bn_warmup: int = 100_000 185 | use_layer_norm: bool = False 186 | dropout_rate: Optional[float] = None 187 | use_batch_norm: bool = False 188 | bn_mode: str = "bn" 189 | n_atoms: int = 101 190 | 191 | @nn.compact 192 | def __call__(self, x: jnp.ndarray, action: jnp.ndarray, train) -> jnp.ndarray: 193 | if 'bn' in self.bn_mode: 194 | BN = nn.BatchNorm 195 | elif 'brn' in self.bn_mode: 196 | BN = BatchRenorm 197 | else: 198 | raise NotImplementedError 199 | 200 | x = jnp.concatenate([x, action], -1) 201 | 202 | if self.use_batch_norm: 203 | x = BN(bn_warmup=self.bn_warmup, use_running_average=not train, momentum=self.batch_norm_momentum)(x) 204 | else: 205 | # Hack to make flax return state_updates. Is only necessary such that the downstream 206 | # functions have the same function signature. 207 | x_dummy = BN(bn_warmup=self.bn_warmup, use_running_average=not train, momentum=self.batch_norm_momentum)(x) 208 | 209 | for n_units in self.net_arch: 210 | x = nn.Dense(n_units)(x) 211 | 212 | if self.dropout_rate is not None and self.dropout_rate > 0: 213 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False) 214 | 215 | if self.use_layer_norm: 216 | x = nn.LayerNorm()(x) 217 | 218 | x = self.activation_fn()(x) 219 | 220 | if self.use_batch_norm: 221 | x = BN(bn_warmup=self.bn_warmup,use_running_average=not train, momentum=self.batch_norm_momentum)(x) 222 | else: 223 | x_dummy = BN(bn_warmup=self.bn_warmup, use_running_average=not train, momentum=self.batch_norm_momentum)(x) 224 | x = nn.Dense(self.n_atoms)(x) 225 | # x = nn.Dense(1, kernel_init=nn.initializers.constant(1e-6), 226 | # bias_init=nn.initializers.constant(0.0))(x) 227 | if self.n_atoms > 1: 228 | x = jax.nn.softmax(x, axis=-1) 229 | return x 230 | 231 | 232 | class VectorCritic(nn.Module): 233 | net_arch: Sequence[int] 234 | activation_fn: Type[nn.Module] 235 | batch_norm_momentum: float 236 | bn_warmup: int = 100_000 237 | use_batch_norm: bool = False 238 | batch_norm_mode: str = "bn" 239 | use_layer_norm: bool = False 240 | dropout_rate: Optional[float] = None 241 | n_critics: int = 2 242 | n_atoms: int = 101 243 | 244 | @nn.compact 245 | def __call__(self, obs: jnp.ndarray, action: jnp.ndarray, train: bool = True): 246 | # Idea taken from https://github.com/perrin-isir/xpag 247 | # Similar to https://github.com/tinkoff-ai/CORL for PyTorch 248 | vmap_critic = nn.vmap( 249 | Critic, 250 | variable_axes={"params": 0, "batch_stats": 0}, 251 | split_rngs={"params": True, "dropout": True, "batch_stats": True}, 252 | in_axes=None, 253 | out_axes=0, 254 | axis_size=self.n_critics, 255 | ) 256 | q_values = vmap_critic( 257 | use_layer_norm=self.use_layer_norm, 258 | use_batch_norm=self.use_batch_norm, 259 | batch_norm_momentum=self.batch_norm_momentum, 260 | 
bn_warmup=self.bn_warmup, 261 | bn_mode=self.batch_norm_mode, 262 | dropout_rate=self.dropout_rate, 263 | net_arch=self.net_arch, 264 | activation_fn=self.activation_fn, 265 | n_atoms=self.n_atoms 266 | )(obs, action, train) 267 | return q_values 268 | 269 | 270 | -------------------------------------------------------------------------------- /diffusion/dime.py: -------------------------------------------------------------------------------- 1 | import os 2 | import jax 3 | import flax 4 | import optax 5 | import numpy as np 6 | import jax.numpy as jnp 7 | import flax.linen as nn 8 | 9 | from gymnasium import spaces 10 | from functools import partial 11 | 12 | from diffusion.diffusion_policy import DiffPol 13 | from flax.training.train_state import TrainState 14 | from stable_baselines3.common.noise import ActionNoise 15 | from stable_baselines3.common.buffers import ReplayBuffer 16 | from common.off_policy_algorithm import OffPolicyAlgorithmJax 17 | from common.type_aliases import ReplayBufferSamplesNp, RLTrainState 18 | from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union 19 | from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback 20 | 21 | 22 | class EntropyCoef(nn.Module): 23 | ent_coef_init: float = 1.0 24 | 25 | @nn.compact 26 | def __call__(self, step) -> jnp.ndarray: 27 | log_ent_coef = self.param("log_ent_coef", init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init))) 28 | return jnp.exp(log_ent_coef) 29 | 30 | 31 | class ConstantEntropyCoef(nn.Module): 32 | ent_coef_init: float = 1.0 33 | 34 | @nn.compact 35 | def __call__(self, step) -> float: 36 | # Hack to not optimize the entropy coefficient while not having to use if/else for the jit 37 | self.param("dummy_param", init_fn=lambda key: jnp.full((), self.ent_coef_init)) 38 | return jax.lax.stop_gradient(self.ent_coef_init) 39 | 40 | 41 | class DIME(OffPolicyAlgorithmJax): 42 | policy_aliases: ClassVar[Dict[str, Type[DiffPol]]] = { # type: ignore[assignment] 43 | "MlpPolicy": DiffPol, 44 | # Minimal dict support using flatten() 45 | "MultiInputPolicy": DiffPol, 46 | } 47 | 48 | policy: DiffPol 49 | action_space: spaces.Box # type: ignore[assignment] 50 | 51 | def __init__(self, 52 | policy, 53 | env: Union[GymEnv, str], 54 | model_save_path: str, 55 | save_every_n_steps: int, 56 | cfg, 57 | train_freq: Union[int, Tuple[int, str]] = 1, 58 | action_noise: Optional[ActionNoise] = None, 59 | replay_buffer_class: Optional[Type[ReplayBuffer]] = None, 60 | replay_buffer_kwargs: Optional[Dict[str, Any]] = None, 61 | use_sde: bool = False, 62 | sde_sample_freq: int = -1, 63 | use_sde_at_warmup: bool = False, 64 | tensorboard_log: Optional[str] = None, 65 | verbose: int = 0, 66 | _init_setup_model: bool = True, 67 | stats_window_size: int = 100, 68 | ) -> None: 69 | super().__init__( 70 | policy=policy, 71 | env=env, 72 | learning_rate=cfg.alg.optimizer.lr_actor, 73 | qf_learning_rate=cfg.alg.optimizer.lr_critic, 74 | buffer_size=cfg.alg.buffer_size, 75 | learning_starts=cfg.alg.learning_starts, 76 | batch_size=cfg.alg.batch_size, 77 | tau=cfg.alg.tau, 78 | gamma=cfg.alg.gamma, 79 | train_freq=train_freq, 80 | gradient_steps=cfg.alg.utd, 81 | action_noise=action_noise, 82 | replay_buffer_class=replay_buffer_class, 83 | replay_buffer_kwargs=replay_buffer_kwargs, 84 | use_sde=use_sde, 85 | sde_sample_freq=sde_sample_freq, 86 | use_sde_at_warmup=use_sde_at_warmup, 87 | policy_kwargs=None, 88 | tensorboard_log=tensorboard_log, 89 | verbose=verbose, 90 | seed=cfg.seed, 91 | 
supported_action_spaces=(spaces.Box,), 92 | support_multi_env=True, 93 | stats_window_size=stats_window_size, 94 | ) 95 | self.cfg = cfg 96 | self.policy_delay = self.cfg.alg.policy_delay 97 | self.ent_coef_params = self.cfg.alg.ent_coef 98 | self.crossq_style = True 99 | self.use_bnstats_from_live_net = False 100 | self.policy_q_reduce_fn = jax.numpy.mean 101 | self.save_every_n_steps = save_every_n_steps 102 | self.model_save_path = model_save_path 103 | self.policy_tau = self.cfg.alg.policy_tau 104 | if _init_setup_model: 105 | self._setup_model() 106 | 107 | def _setup_model(self, reset=False) -> None: 108 | if not reset: 109 | super()._setup_model() 110 | 111 | if not hasattr(self, "policy") or self.policy is None or reset: 112 | # pytype: disable=not-instantiable 113 | self.policy = self.policy_class( # type: ignore[assignment] 114 | self.observation_space, 115 | self.action_space, 116 | self.cfg 117 | ) 118 | # pytype: enable=not-instantiable 119 | 120 | assert isinstance(self.qf_learning_rate, float) 121 | 122 | self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate) 123 | 124 | self.key, ent_key = jax.random.split(self.key, 2) 125 | 126 | self.qf = self.policy.qf # type: ignore[assignment] 127 | 128 | # The entropy coefficient or entropy can be learned automatically 129 | # see Automating Entropy Adjustment for Maximum Entropy RL section 130 | # of https://arxiv.org/abs/1812.05905 131 | if self.ent_coef_params["type"] == "auto": 132 | ent_coef_init = self.ent_coef_params['init'] 133 | # Note: we optimize the log of the entropy coeff which is slightly different from the paper 134 | # as discussed in https://github.com/rail-berkeley/softlearning/issues/37 135 | self.ent_coef = EntropyCoef(ent_coef_init) 136 | elif self.ent_coef_params["type"] == "const": 137 | # This will throw an error if a malformed string (different from 'auto') is passed 138 | assert isinstance( 139 | self.ent_coef_params["init"], float 140 | ), f"Entropy coef must be float when not equal to 'auto', actual: {self.ent_coef_params['init']}" 141 | self.ent_coef = ConstantEntropyCoef(self.ent_coef_params["init"]) # type: ignore[assignment] 142 | else: 143 | raise NotImplementedError(f"Entropy coefficient type {self.ent_coef_params['type']} not supported") 144 | 145 | self.ent_coef_state = TrainState.create( 146 | apply_fn=self.ent_coef.apply, 147 | params=self.ent_coef.init({"params": ent_key}, 0.0)["params"], 148 | tx=optax.adam( 149 | # learning_rate=self.learning_rate, 150 | learning_rate=1.0e-3, 151 | ), 152 | ) 153 | 154 | # automatically set target entropy if needed 155 | self.target_entropy = self.action_space.shape[0] * 4.0 156 | 157 | def learn( 158 | self, 159 | total_timesteps: int, 160 | callback: MaybeCallback = None, 161 | log_interval: int = 1, 162 | tb_log_name: str = "SAC", 163 | reset_num_timesteps: bool = True, 164 | progress_bar: bool = False, 165 | ): 166 | return super().learn( 167 | total_timesteps=total_timesteps, 168 | callback=callback, 169 | log_interval=log_interval, 170 | tb_log_name=tb_log_name, 171 | reset_num_timesteps=reset_num_timesteps, 172 | progress_bar=progress_bar, 173 | ) 174 | 175 | def train(self, batch_size, gradient_steps): 176 | # Sample all at once for efficiency (so we can jit the for loop) 177 | data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) 178 | # Pre-compute the indices where we need to update the actor 179 | # This is a hack in order to jit the train loop 180 | # It will compile once per value of 
policy_delay_indices 181 | policy_delay_indices = {i: True for i in range(gradient_steps) if 182 | ((self._n_updates + i + 1) % self.policy_delay) == 0} 183 | policy_delay_indices = flax.core.FrozenDict(policy_delay_indices) 184 | 185 | if isinstance(data.observations, dict): 186 | keys = list(self.observation_space.keys()) 187 | obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1) 188 | next_obs = np.concatenate([data.next_observations[key].numpy() for key in keys], axis=1) 189 | else: 190 | obs = data.observations.numpy() 191 | next_obs = data.next_observations.numpy() 192 | 193 | # Convert to numpy 194 | data = ReplayBufferSamplesNp( 195 | obs, 196 | data.actions.numpy(), 197 | next_obs, 198 | data.dones.numpy().flatten(), 199 | data.rewards.numpy().flatten(), 200 | ) 201 | 202 | ( 203 | self.policy.qf_state, 204 | self.policy.actor_state, 205 | self.policy.target_actor_state, 206 | self.ent_coef_state, 207 | self.key, 208 | log_metrics, 209 | ) = self._train( 210 | self.crossq_style, 211 | self.use_bnstats_from_live_net, 212 | self.gamma, 213 | self.tau, 214 | self.policy_tau, 215 | self.target_entropy, 216 | gradient_steps, 217 | data, 218 | policy_delay_indices, 219 | self.policy.qf_state, 220 | self.policy.actor_state, 221 | self.policy.target_actor_state, 222 | self.ent_coef_state, 223 | self.key, 224 | self.num_timesteps, 225 | self.policy_q_reduce_fn, 226 | self.policy.sampler, 227 | self.policy.target_sampler, 228 | self.cfg.alg.critic.v_min, 229 | self.cfg.alg.critic.v_max, 230 | self.cfg.alg.critic.entr_coeff, 231 | self.cfg.alg.critic.n_atoms 232 | ) 233 | self._n_updates += gradient_steps 234 | 235 | if self.model_save_path is not None: 236 | if (self.num_timesteps % self.save_every_n_steps == 0) or (self.num_timesteps == (self.learning_starts+1)): 237 | self._save_model() 238 | 239 | self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") 240 | for k, v in log_metrics.items(): 241 | try: 242 | log_val = v.item() 243 | except: 244 | log_val = v 245 | self.logger.record(f"train/{k}", log_val) 246 | 247 | @staticmethod 248 | @partial(jax.jit, static_argnames=["crossq_style", "use_bnstats_from_live_net", "sampler", "num_atoms", "z_atoms", 249 | "v_min", "v_max", "entr_coeff"]) 250 | def update_critic( 251 | crossq_style: bool, 252 | use_bnstats_from_live_net: bool, 253 | gamma: float, 254 | actor_state: TrainState, 255 | qf_state: RLTrainState, 256 | ent_coef_state: TrainState, 257 | observations: np.ndarray, 258 | actions: np.ndarray, 259 | next_observations: np.ndarray, 260 | rewards: np.ndarray, 261 | dones: np.ndarray, 262 | n_env_interacts: int, 263 | num_atoms: int, 264 | z_atoms: jnp.ndarray, 265 | v_min: int, 266 | v_max: int, 267 | entr_coeff: float, 268 | key, 269 | sampler 270 | ): 271 | key, noise_key, dropout_key_target, dropout_key_current, redq_key = jax.random.split(key, 5) 272 | # sample action from the actor 273 | 274 | out = DiffPol.sample_action(actor_state, actor_state.params, next_observations, noise_key, sampler) 275 | all_actions, next_run_costs, next_sto_costs, next_terminal_costs, latents, v_t = out 276 | next_state_actions = jax.lax.stop_gradient(all_actions) 277 | next_run_costs = jax.lax.stop_gradient(next_run_costs) 278 | next_sto_costs = jax.lax.stop_gradient(next_sto_costs) 279 | next_terminal_costs = jax.lax.stop_gradient(next_terminal_costs) 280 | 281 | ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params}, n_env_interacts) 282 | 283 | def ce_loss(params, batch_stats, 
dropout_key): 284 | if not crossq_style: 285 | next_q_values = qf_state.apply_fn( 286 | { 287 | "params": qf_state.target_params, 288 | "batch_stats": qf_state.target_batch_stats if not use_bnstats_from_live_net else batch_stats 289 | }, 290 | next_observations, next_state_actions, 291 | rngs={"dropout": dropout_key_target}, 292 | train=False 293 | ) 294 | 295 | # shape is (n_critics, batch_size, 1) 296 | current_q_values, state_updates = qf_state.apply_fn( 297 | {"params": params, "batch_stats": batch_stats}, 298 | observations, actions, 299 | rngs={"dropout": dropout_key}, 300 | mutable=["batch_stats"], 301 | train=True, 302 | ) 303 | 304 | else: 305 | # ----- CrossQ's One Weird Trick™ ----- 306 | # concatenate current and next observations to double the batch size 307 | # new shape of input is (n_critics, 2*batch_size, obs_dim + act_dim) 308 | # apply critic to this bigger batch 309 | catted_q_values, state_updates = qf_state.apply_fn( 310 | {"params": params, "batch_stats": batch_stats}, 311 | jnp.concatenate([observations, next_observations], axis=0), 312 | jnp.concatenate([actions, next_state_actions], axis=0), 313 | rngs={"dropout": dropout_key}, 314 | mutable=["batch_stats"], 315 | train=True, 316 | ) 317 | current_q_values, next_q_values = jnp.split(catted_q_values, 2, axis=1) 318 | 319 | if next_q_values.shape[0] > 2: # only for REDQ 320 | # REDQ style subsampling of critics. 321 | m_critics = 2 322 | next_q_values = jax.random.choice(redq_key, next_q_values, (m_critics,), replace=False, axis=0) 323 | 324 | next_q_values_q1 = next_q_values[0] 325 | next_q_values_q2 = next_q_values[1] 326 | 327 | current_q1 = current_q_values[0] 328 | current_q2 = current_q_values[1] 329 | 330 | def projection(next_dist, rewards, dones, gamma, v_min, v_max, num_atoms, support): 331 | delta_z = (v_max - v_min) / (num_atoms - 1) 332 | batch_size = rewards.shape[0] 333 | 334 | entr_bon = - (1 - dones[:, None]) * gamma * ent_coef_value * (next_run_costs + next_sto_costs + next_terminal_costs) 335 | 336 | # Compute target_z 337 | target_z = jnp.clip(rewards[:,None] + entr_bon + (1 - dones[:, None]) * gamma * support, a_min=v_min, a_max=v_max) 338 | b = (target_z - v_min) / delta_z 339 | l = jnp.floor(b).astype(jnp.int32) 340 | u = jnp.ceil(b).astype(jnp.int32) 341 | 342 | # Adjust l and u to ensure they remain within valid bounds 343 | l = jnp.where((u > 0) & (l == u), l - 1, l) 344 | u = jnp.where((l < (num_atoms - 1)) & (l == u), u + 1, u) 345 | 346 | # Create the projected distribution 347 | proj_dist = jnp.zeros_like(next_dist) 348 | 349 | # Offset calculation for batch indexing 350 | offset = jnp.arange(batch_size)[:, None] * num_atoms 351 | # offset = jnp.tile(offset, (1, num_atoms)) # Repeat along the second axis 352 | 353 | # Index updates for proj_dist 354 | l_idx = (l + offset).ravel() 355 | u_idx = (u + offset).ravel() 356 | 357 | # Flattened updates 358 | l_update = (next_dist * (u.astype(jnp.float32) - b)).ravel() 359 | u_update = (next_dist * (b - l.astype(jnp.float32))).ravel() 360 | 361 | # Flatten proj_dist for updates 362 | proj_dist_flat = proj_dist.ravel() 363 | 364 | # Add values to proj_dist 365 | proj_dist_flat = proj_dist_flat.at[l_idx].add(l_update) 366 | proj_dist_flat = proj_dist_flat.at[u_idx].add(u_update) 367 | 368 | # Reshape back to [batch_size, num_atoms] 369 | proj_dist = proj_dist_flat.reshape(batch_size, num_atoms) 370 | 371 | return proj_dist 372 | 373 | target_q1_projected = projection(next_dist=next_q_values_q1, rewards=rewards, dones=dones, gamma=gamma, 374 | 
v_min=v_min, v_max=v_max, num_atoms=num_atoms, support=z_atoms) 375 | target_q2_projected = projection(next_dist=next_q_values_q2, rewards=rewards, dones=dones, gamma=gamma, 376 | v_min=v_min, v_max=v_max, num_atoms=num_atoms, support=z_atoms) 377 | 378 | next_q_values = jax.lax.stop_gradient(jnp.mean( 379 | jnp.stack([target_q1_projected, target_q2_projected], axis=0), axis=0)) 380 | 381 | def binary_cross_entropy(pred, target): 382 | return -jnp.mean(jnp.sum(target * jnp.log(pred + 1e-15), axis=-1)) + entr_coeff*jnp.mean(jnp.sum(pred*jnp.log(pred + 1e-15), axis=-1)) #+ (1 - target) * jnp.log(1 - pred + 1e-15)) 383 | 384 | loss = binary_cross_entropy(current_q1, next_q_values) + binary_cross_entropy(current_q2, next_q_values) 385 | qf_pi1 = jnp.sum(current_q1 * z_atoms, axis=-1) 386 | qf_pi2 = jnp.sum(current_q2 * z_atoms, axis=-1) 387 | entr_1 = -jnp.mean(jnp.sum(current_q1 * jnp.log(current_q1 + 1e-15), axis=-1)) 388 | entr_2 = -jnp.mean(jnp.sum(current_q2 * jnp.log(current_q2 + 1e-15), axis=-1)) 389 | min_qf_pi = jax.lax.stop_gradient(jnp.min(jnp.stack([qf_pi1, qf_pi2], axis=0), axis=0).squeeze()) 390 | return loss, (state_updates, min_qf_pi, next_q_values, entr_1, entr_2) 391 | 392 | (qf_loss_value, (state_updates, current_q_values, next_q_values, entr_1, entr_2)), grads = \ 393 | jax.value_and_grad(ce_loss, has_aux=True)(qf_state.params, qf_state.batch_stats, dropout_key_current) 394 | 395 | qf_state = qf_state.apply_gradients(grads=grads) 396 | qf_state = qf_state.replace(batch_stats=state_updates["batch_stats"]) 397 | 398 | metrics = { 399 | 'critic_loss': qf_loss_value, 400 | 'ent_coef': ent_coef_value, 401 | 'current_q_values': current_q_values.mean(), 402 | 'next_q_values': next_q_values.mean(), 403 | 'entrQ_1': entr_1, 404 | 'entrQ_2': entr_2, 405 | } 406 | return qf_state, metrics, key 407 | 408 | @staticmethod 409 | @partial(jax.jit, static_argnames=["q_reduce_fn", "sampler"]) 410 | def update_actor( 411 | actor_state: TrainState, 412 | qf_state: RLTrainState, 413 | ent_coef_state: TrainState, 414 | observations: np.ndarray, 415 | n_env_interacts: int, 416 | key, 417 | z_atoms: jnp.ndarray, 418 | sampler, 419 | q_reduce_fn, 420 | ): 421 | key, dropout_key, noise_key = jax.random.split(key, 3) 422 | 423 | def actor_loss(actor_state_in, actor_params): 424 | out = DiffPol.sample_action(actor_state_in, actor_params, observations, noise_key, sampler) 425 | actions, run_costs, sto_costs, terminal_costs, latents, v_t = out 426 | qf_pi = qf_state.apply_fn( 427 | { 428 | "params": qf_state.params, 429 | "batch_stats": qf_state.batch_stats 430 | }, 431 | observations, 432 | actions, 433 | rngs={"dropout": dropout_key}, train=False 434 | ) 435 | 436 | qf_pi1 = jnp.sum(qf_pi[0] * z_atoms, axis=-1) 437 | qf_pi2 = jnp.sum(qf_pi[1] * z_atoms, axis=-1) 438 | min_qf_pi = q_reduce_fn(jnp.stack([qf_pi1, qf_pi2], axis=0), axis=0).squeeze() 439 | ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params}, n_env_interacts) 440 | actor_loss = (- min_qf_pi + ent_coef_value * (run_costs.squeeze() + sto_costs.squeeze() + terminal_costs.squeeze())).mean() 441 | 442 | max_actions = jnp.max(jnp.max(latents, axis=0), axis=1) 443 | min_actions = jnp.min(jnp.min(latents, axis=0), axis=1) 444 | mean_actions = jnp.mean(jnp.mean(latents, axis=0), axis=1) 445 | 446 | latent_acts = {'max_la': max_actions, 'min_la': min_actions, 'mean_la': mean_actions} 447 | 448 | return actor_loss, (run_costs.mean(), sto_costs.mean(), terminal_costs.mean(), latent_acts) 449 | 450 | outs = 
jax.value_and_grad(actor_loss, has_aux=True, argnums=1)(actor_state, actor_state.params) 451 | (act_loss_value, (run_costs_mean, sto_costs, terminal_costs, latent_acts)), grads = outs 452 | actor_state = actor_state.apply_gradients(grads=grads) 453 | metrics = {"entropy": 0.0, 454 | "run_costs": run_costs_mean, 455 | "sto_costs": sto_costs, 456 | "terminal_costs": terminal_costs, 457 | } 458 | return actor_state, qf_state, act_loss_value, key, [metrics, latent_acts] 459 | 460 | @staticmethod 461 | @jax.jit 462 | def soft_update(tau: float, qf_state: RLTrainState): 463 | qf_state = qf_state.replace( 464 | target_params=optax.incremental_update(qf_state.params, qf_state.target_params, tau)) 465 | qf_state = qf_state.replace( 466 | target_batch_stats=optax.incremental_update(qf_state.batch_stats, qf_state.target_batch_stats, tau)) 467 | return qf_state 468 | 469 | @staticmethod 470 | @jax.jit 471 | def soft_update_target_actor(tau: float, actor_state: TrainState, target_actor_state: TrainState): 472 | target_actor_state = target_actor_state.replace( 473 | params=optax.incremental_update(actor_state.params, target_actor_state.params, tau)) 474 | return target_actor_state 475 | 476 | @staticmethod 477 | @jax.jit 478 | def update_temperature(target_entropy: np.ndarray, ent_coef_state: TrainState, entropy: float): 479 | def temperature_loss(temp_params): 480 | ent_coef_value = ent_coef_state.apply_fn({"params": temp_params}, 0) 481 | ent_coef_loss = -ent_coef_value * (entropy - target_entropy).mean() 482 | return ent_coef_loss 483 | 484 | ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params) 485 | ent_coef_state = ent_coef_state.apply_gradients(grads=grads) 486 | 487 | return ent_coef_state, ent_coef_loss 488 | 489 | @classmethod 490 | @partial(jax.jit, 491 | static_argnames=["cls", "crossq_style", "use_bnstats_from_live_net", "gradient_steps", "q_reduce_fn", 492 | "sampler", "target_sampler", "v_min", "v_max", "num_atoms", "entr_coeff"]) 493 | def _train( 494 | cls, 495 | crossq_style: bool, 496 | use_bnstats_from_live_net: bool, 497 | gamma: float, 498 | tau: float, 499 | policy_tau: float, 500 | target_entropy: np.ndarray, 501 | gradient_steps: int, 502 | data: ReplayBufferSamplesNp, 503 | policy_delay_indices: flax.core.FrozenDict, 504 | qf_state: RLTrainState, 505 | actor_state: TrainState, 506 | target_actor_state: TrainState, 507 | ent_coef_state: TrainState, 508 | key, 509 | n_env_interacts, 510 | q_reduce_fn, 511 | sampler, 512 | target_sampler, 513 | v_min, 514 | v_max, 515 | entr_coeff, 516 | num_atoms 517 | ): 518 | actor_loss_value = jnp.array(0) 519 | actor_metrics = [{}] 520 | for i in range(gradient_steps): 521 | 522 | def slice(x, step=i): 523 | assert x.shape[0] % gradient_steps == 0 524 | batch_size = x.shape[0] // gradient_steps 525 | return x[batch_size * step: batch_size * (step + 1)] 526 | 527 | z_atoms = jnp.linspace(v_min, v_max, num_atoms) 528 | 529 | ( 530 | qf_state, 531 | log_metrics_critic, 532 | key, 533 | ) = cls.update_critic( 534 | crossq_style, 535 | use_bnstats_from_live_net, 536 | gamma, 537 | target_actor_state, 538 | qf_state, 539 | ent_coef_state, 540 | slice(data.observations), 541 | slice(data.actions), 542 | slice(data.next_observations), 543 | slice(data.rewards), 544 | slice(data.dones), 545 | n_env_interacts, 546 | num_atoms, 547 | z_atoms, 548 | v_min, 549 | v_max, 550 | entr_coeff, 551 | key, 552 | target_sampler 553 | ) 554 | qf_state = DIME.soft_update(tau, qf_state) 555 | target_actor_state = target_actor_state 556 | 
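# Illustrative example (values assumed): with policy_delay = 2, gradient_steps = 4 and
# self._n_updates = 0, train() builds policy_delay_indices = {1: True, 3: True}, so the actor
# and the temperature are updated on the 2nd and 4th critic step of this call. The dict keys
# are part of the FrozenDict's pytree structure, hence jax.jit only re-traces _train when the
# pattern of update indices changes.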
# hack to be able to jit (n_updates % policy_delay == 0) 557 | # a = False 558 | if i in policy_delay_indices: # and a: 559 | (actor_state, qf_state, actor_loss_value, key, actor_metrics) = cls.update_actor( 560 | actor_state, 561 | qf_state, 562 | ent_coef_state, 563 | slice(data.observations), 564 | n_env_interacts, 565 | key, 566 | z_atoms, 567 | sampler, 568 | q_reduce_fn, 569 | ) 570 | ent_coef_state, _ = DIME.update_temperature(target_entropy, ent_coef_state, 571 | actor_metrics[0]['run_costs']) 572 | 573 | target_actor_state = DIME.soft_update_target_actor(policy_tau, actor_state, target_actor_state) 574 | log_metrics = {'actor_loss': actor_loss_value, **actor_metrics[0], **log_metrics_critic} 575 | return qf_state, actor_state, target_actor_state, ent_coef_state, key, log_metrics 576 | 577 | def predict_critic(self, observation, action): 578 | return self.policy.predict_critic(observation, action) 579 | 580 | def current_entropy_coeff(self): 581 | return self.ent_coef_state.apply_fn({"params": self.ent_coef_state.params}) 582 | 583 | def _save_model(self): 584 | save_model_state(self.policy.actor_state, self.model_save_path, "actor_state", self.num_timesteps) 585 | save_model_state(self.policy.qf_state, self.model_save_path, "critic_state", self.num_timesteps) 586 | 587 | def load_model(self, path, n_steps_actor, n_steps_critic): 588 | self.policy.actor_state = load_state(path, "actor_state", n_steps_actor, train_state=self.policy.actor_state) 589 | self.policy.qf_state = load_state(path, "critic_state", n_steps_critic, train_state=self.policy.qf_state) 590 | 591 | 592 | # Save and load model 593 | def save_model_state(train_state, path, name, n_steps): 594 | # Serialize the model parameters 595 | serialized_state = flax.serialization.to_bytes(train_state) 596 | os.makedirs(path, exist_ok=True) 597 | extended_path = os.path.join(path, f'{name}_{n_steps}.msgpack') 598 | # Save the serialized parameters to a file 599 | with open(extended_path, 'wb') as f: 600 | f.write(serialized_state) 601 | 602 | 603 | def load_state(path, name, n_steps, train_state=None): 604 | extended_path = os.path.join(path, f'{name}_{n_steps}.msgpack') 605 | # Load the serialized parameters from a file 606 | with open(extended_path, 'rb') as f: 607 | train_state_loaded = f.read() 608 | return flax.serialization.from_bytes(train_state, train_state_loaded) --------------------------------------------------------------------------------
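The distributional critic update in diffusion/dime.py relies on a categorical (C51-style) projection of the bootstrapped return distribution back onto a fixed atom support. The snippet below is a simplified, single-sample sketch of that mechanism for readers who want to see it in isolation; it mirrors the inner `projection` function of `DIME.update_critic` but omits the entropy bonus and the flattened batch indexing, and the function name and toy numbers are illustrative assumptions rather than part of the repository.

import jax.numpy as jnp

def project_categorical(next_dist, reward, done, gamma, v_min, v_max, num_atoms):
    # Fixed support of atoms and their spacing.
    support = jnp.linspace(v_min, v_max, num_atoms)
    delta_z = (v_max - v_min) / (num_atoms - 1)
    # Bellman-backed-up atom locations, clipped to the support range.
    target_z = jnp.clip(reward + (1.0 - done) * gamma * support, v_min, v_max)
    b = (target_z - v_min) / delta_z
    l = jnp.floor(b).astype(jnp.int32)
    u = jnp.ceil(b).astype(jnp.int32)
    # Keep lower and upper neighbours distinct when a target lands exactly on an atom.
    l = jnp.where((u > 0) & (l == u), l - 1, l)
    u = jnp.where((l < num_atoms - 1) & (l == u), u + 1, u)
    # Split each source atom's probability mass between its two neighbouring target atoms.
    proj = jnp.zeros(num_atoms)
    proj = proj.at[l].add(next_dist * (u.astype(jnp.float32) - b))
    proj = proj.at[u].add(next_dist * (b - l.astype(jnp.float32)))
    return proj

# A point mass at value 0, with reward 1 and gamma 0.99, lands halfway between the atoms at
# 0 and 2 (support = linspace(-10, 10, 11)), so its probability is split 50/50 between them.
dist = jnp.zeros(11).at[5].set(1.0)
print(project_categorical(dist, reward=1.0, done=0.0, gamma=0.99, v_min=-10.0, v_max=10.0, num_atoms=11))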