├── agents ├── __init__.py ├── helpers.py ├── model.py ├── dtql.py └── dql_kl.py ├── utils ├── __init__.py ├── data_sampler.py ├── utils.py └── logger.py ├── diffusion ├── __init__.py ├── utils.py ├── mlps.py └── karras.py ├── assets ├── DTQL_toy.pdf └── DTQL_toy.png ├── install_env.sh ├── requirements.txt ├── toy_tasks ├── run_toy.sh ├── data_generator.py └── toy_main.py ├── README.md └── main.py /agents/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/DTQL_toy.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TianyuCodings/Diffusion_Trusted_Q_Learning/HEAD/assets/DTQL_toy.pdf -------------------------------------------------------------------------------- /assets/DTQL_toy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TianyuCodings/Diffusion_Trusted_Q_Learning/HEAD/assets/DTQL_toy.png -------------------------------------------------------------------------------- /install_env.sh: -------------------------------------------------------------------------------- 1 | conda create -n rl python=3.9 2 | conda activate rl 3 | 4 | # make sure the cuda verion fits the local machine 5 | pip install torch==1.13.0+cu116 torchvision==0.14.0+cu116 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu116 6 | pip install mujoco 7 | pip install git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl 8 | pip install "cython<3" 9 | pip install python-dateutil 10 | 11 | 12 | wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz 13 | 14 | # do this in the home directory, it will extract mujoco210 into /home/username/.mujoco/mujoco210... 
15 | mkdir -p .mujoco/ 16 | tar -xvf mujoco210-linux-x86_64.tar.gz -C .mujoco/ 17 | 18 | # add this to .bashrc 19 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/tychen/.mujoco/mujoco210/bin 20 | source .bashrc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.1.0 2 | certifi==2022.5.18.1 3 | cffi==1.15.0 4 | charset-normalizer==2.0.12 5 | click==8.1.3 6 | cloudpickle==2.1.0 7 | cycler==0.11.0 8 | Cython==0.29.30 9 | d4rl @ git+https://github.com/rail-berkeley/d4rl@d842aa194b416e564e54b0730d9f934e3e32f854 10 | dm-control @ git+https://github.com/deepmind/dm_control@41d0c7383153f9ca6c12f8e865ef5e73a98759bd 11 | dm-env==1.5 12 | dm-tree==0.1.7 13 | fasteners==0.17.3 14 | fonttools==4.33.3 15 | glfw==2.5.3 16 | gym==0.24.1 17 | gym-notices==0.0.7 18 | h5py==3.7.0 19 | idna==3.3 20 | imageio==2.19.3 21 | importlib-metadata==4.11.4 22 | kiwisolver==1.4.2 23 | labmaze==1.0.5 24 | lxml==4.9.0 25 | matplotlib==3.5.2 26 | mjrl @ git+https://github.com/aravindr93/mjrl@3871d93763d3b49c4741e6daeaebbc605fe140dc 27 | mujoco==2.2.0 28 | mujoco-py==2.0.2.13 29 | numpy==1.22.4 30 | packaging==21.3 31 | pandas==1.4.2 32 | Pillow==9.1.1 33 | protobuf==4.21.1 34 | pybullet==3.2.5 35 | pycparser==2.21 36 | PyOpenGL==3.1.6 37 | pyparsing==2.4.7 38 | python-dateutil==2.8.2 39 | pytz==2022.1 40 | requests==2.28.0 41 | scipy==1.8.1 42 | six==1.16.0 43 | termcolor==1.1.0 44 | torch==1.10.1+cu111 45 | torchaudio==0.10.1+rocm4.1 46 | torchvision==0.11.2+cu111 47 | tqdm==4.64.0 48 | typing_extensions==4.2.0 49 | urllib3==1.26.9 50 | zipp==3.8.0 -------------------------------------------------------------------------------- /toy_tasks/run_toy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | source activate torch-rl 3 | python toy_main.py --device=5 --env_name=25_gaussian --distill_loss=diffusion --reward_type=near --actor=sac & 4 | python toy_main.py --device=1 --env_name=25_gaussian --distill_loss=diffusion --reward_type=far --actor=sac & 5 | python toy_main.py --device=2 --env_name=25_gaussian --distill_loss=diffusion --reward_type=hard --actor=sac --seed=2& 6 | 7 | 8 | python toy_main.py --device=2 --env_name=25_gaussian --distill_loss=diffusion --reward_type=near --actor=sac --gamma=0.005& 9 | python toy_main.py --device=3 --env_name=25_gaussian --distill_loss=diffusion --reward_type=far --actor=sac --gamma=0.005& 10 | python toy_main.py --device=6 --env_name=25_gaussian --distill_loss=diffusion --reward_type=hard --actor=sac --gamma=0.005 --seed=2& 11 | 12 | python toy_main.py --device=3 --env_name=25_gaussian --distill_loss=dmd --reward_type=near --actor=implicit & 13 | python toy_main.py --device=7 --env_name=25_gaussian --distill_loss=dmd --reward_type=far --actor=implicit --train_epochs=2000& 14 | python toy_main.py --device=1 --env_name=25_gaussian --distill_loss=dmd --reward_type=hard --actor=implicit --train_epochs=2500& 15 | 16 | python toy_main.py --device=2 --env_name=swiss_roll_2D --distill_loss=diffusion --reward_type=near --actor=sac --seed=2 & 17 | python toy_main.py --device=3 --env_name=swiss_roll_2D --distill_loss=diffusion --reward_type=far --actor=sac --eta=5 --seed=2 & 18 | 19 | python toy_main.py --device=4 --env_name=swiss_roll_2D --distill_loss=dmd --reward_type=near --actor=implicit --train_epochs=1500& 20 | python toy_main.py --device=1 --env_name=swiss_roll_2D --distill_loss=dmd 
--reward_type=far --actor=implicit & 21 | -------------------------------------------------------------------------------- /utils/data_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Data_Sampler(object): 5 | def __init__(self, data, device, reward_tune='no'): 6 | 7 | self.device = device 8 | 9 | self.state = torch.from_numpy(data['observations']).float().to(device) 10 | self.action = torch.from_numpy(data['actions']).float().to(device) 11 | self.next_state = torch.from_numpy(data['next_observations']).float().to(device) 12 | reward = torch.from_numpy(data['rewards']).view(-1, 1).float().to(device) 13 | self.not_done = 1. - torch.from_numpy(data['terminals']).view(-1, 1).float().to(device) 14 | 15 | self.size = self.state.shape[0] 16 | self.state_dim = self.state.shape[1] 17 | self.action_dim = self.action.shape[1] 18 | 19 | if reward_tune == 'normalize': 20 | reward = (reward - reward.mean()) / reward.std() 21 | elif reward_tune == 'iql_antmaze': 22 | reward = reward - 1.0 23 | elif reward_tune == 'iql_locomotion': 24 | reward = iql_normalize(reward, self.not_done) 25 | elif reward_tune == 'cql_antmaze': 26 | reward = (reward - 0.5) * 4.0 27 | elif reward_tune == 'antmaze': 28 | reward = (reward - 0.25) * 2.0 29 | self.reward = reward 30 | 31 | def sample(self, batch_size): 32 | ind = torch.randint(0, self.size, size=(batch_size,)) 33 | 34 | return ( 35 | self.state[ind], 36 | self.action[ind], 37 | self.next_state[ind], 38 | self.reward[ind], 39 | self.not_done[ind], 40 | ) 41 | 42 | def iql_normalize(reward, not_done): 43 | trajs_rt = [] 44 | episode_return = 0.0 45 | for i in range(len(reward)): 46 | episode_return += reward[i] 47 | if not not_done[i]: 48 | trajs_rt.append(episode_return) 49 | episode_return = 0.0 50 | rt_max, rt_min = torch.max(torch.tensor(trajs_rt)), torch.min(torch.tensor(trajs_rt)) 51 | reward /= (rt_max - rt_min) 52 | reward *= 1000. 
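    # The two lines above rescale rewards so that the gap between the best and worst episode return
    # equals 1000 (reward <- 1000 * reward / (max_return - min_return)), mirroring the reward
    # normalization that IQL applies to D4RL locomotion tasks.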
53 | return reward 54 | -------------------------------------------------------------------------------- /agents/helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | class EMA(): 4 | ''' 5 | empirical moving average 6 | ''' 7 | def __init__(self, beta): 8 | super().__init__() 9 | self.beta = beta 10 | 11 | def update_model_average(self, ma_model, current_model): 12 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): 13 | old_weight, up_weight = ma_params.data, current_params.data 14 | ma_params.data = self.update_average(old_weight, up_weight) 15 | 16 | def update_average(self, old, new): 17 | if old is None: 18 | return new 19 | return old * self.beta + (1 - self.beta) * new 20 | 21 | 22 | def get_dmd_loss(diffusion_model, true_score_model, fake_score_model, fake_action_data, state_data): 23 | noise = torch.randn_like(fake_action_data) 24 | fake_score_model.eval() 25 | true_score_model.eval() 26 | with torch.no_grad(): 27 | pred_real_action, _, t_chosen = diffusion_model.diffusion_train_step(model=true_score_model, 28 | x=fake_action_data, cond=state_data, 29 | noise=noise, t_chosen=None, 30 | return_denoised=True) 31 | 32 | pred_fake_action, _, t_chosen = diffusion_model.diffusion_train_step(model=fake_score_model, 33 | x=fake_action_data, cond=state_data, 34 | noise=noise, t_chosen=t_chosen, 35 | return_denoised=True) 36 | weighting_factor = (fake_action_data - pred_real_action).abs().mean(axis=1).reshape(-1, 1) 37 | grad = (pred_fake_action - pred_real_action) / weighting_factor 38 | distill_loss = 0.5 * F.mse_loss(fake_action_data, (fake_action_data - grad).detach()) 39 | return distill_loss -------------------------------------------------------------------------------- /diffusion/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | 5 | 6 | def append_dims(x, target_dims): 7 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def append_zero(action): 15 | return torch.cat([action, action.new_zeros([1])]) 16 | 17 | 18 | def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32): 19 | """Draws samples from a lognormal distribution.""" 20 | return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp() 21 | 22 | 23 | def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32): 24 | """Draws samples from an optionally truncated log-logistic distribution.""" 25 | min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64) 26 | max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64) 27 | min_cdf = min_value.log().sub(loc).div(scale).sigmoid() 28 | max_cdf = max_value.log().sub(loc).div(scale).sigmoid() 29 | u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf 30 | return u.logit().mul(scale).add(loc).exp().to(dtype) 31 | 32 | 33 | def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32): 34 | """Draws samples from an log-uniform distribution.""" 35 | min_value = math.log(min_value) 36 | max_value = 
math.log(max_value)
37 |     return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp()
38 | 
39 | 
40 | def rand_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
41 |     """Draws samples from a uniform distribution."""
42 |     return torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value
43 | 
44 | 
45 | def rand_discrete(shape, values, device='cpu', dtype=torch.float32):
46 |     probs = [1 / len(values)] * len(values)  # set equal probability for all values
47 |     return torch.tensor(np.random.choice(values, size=shape, p=probs), device=device, dtype=dtype)
48 | 
49 | 
50 | def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
51 |     """Draws samples from a truncated v-diffusion training timestep distribution."""
52 |     min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi
53 |     max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi
54 |     u = torch.rand(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf
55 |     return torch.tan(u * math.pi / 2) * sigma_data
56 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Diffusion Trusted Q-Learning for Offline RL — Official PyTorch Implementation
2 | 
3 | **Diffusion Policies creating a Trust Region for Offline Reinforcement Learning**<br>
4 | [Tianyu Chen](https://scholar.google.com/citations?user=8Aum3V8AAAAJ&hl=en), [Zhendong Wang](https://zhendong-wang.github.io/), and [Mingyuan Zhou](https://mingyuanzhou.github.io/)
https://arxiv.org/abs/2405.19690 5 |
6 | 
7 | Abstract: *Offline reinforcement learning (RL) leverages pre-collected 
8 | datasets to train optimal policies. Diffusion Q-Learning (DQL), introducing 
9 | diffusion models as a powerful and expressive policy class, significantly 
10 | boosts the performance of offline RL. However, its reliance on iterative 
11 | denoising sampling to generate actions slows down both training and 
12 | inference. While several recent attempts have tried to accelerate 
13 | diffusion-QL, the improvement in training and/or inference speed often 
14 | results in degraded performance. In this paper, we introduce a dual 
15 | policy approach, Diffusion Trusted Q-Learning (DTQL), which comprises a 
16 | diffusion policy for pure behavior cloning and a practical one-step policy. 
17 | We bridge the two policies by a newly introduced diffusion trust region 
18 | loss. The diffusion policy maintains expressiveness, while the trust 
19 | region loss directs the one-step policy to explore freely and seek 
20 | modes within the region defined by the diffusion policy. DTQL 
21 | eliminates the need for iterative denoising sampling during both training 
22 | and inference, making it remarkably computationally efficient. We 
23 | evaluate its effectiveness and algorithmic characteristics against 
24 | popular Kullback-Leibler (KL) based distillation methods in 2D bandit 
25 | scenarios and gym tasks. We then show that DTQL could not only outperform 
26 | other methods on the majority of the D4RL benchmark tasks but also 
27 | demonstrate efficiency in training and inference speeds.* 
28 | 
29 | ## Introduction 
30 | 
31 | We introduce a dual-policy approach, Diffusion Trusted Q-Learning (DTQL): a diffusion policy for pure behavior cloning and a one-step policy for actual deployment. 
32 | We bridge the two policies through our newly introduced diffusion trust region loss. 
33 | The loss ensures that the generated action lies within the in-sample datasets' action manifold. 
34 | Guided by the gradient of the Q-function, actions can move freely within the in-sample data manifold and gravitate towards high-reward regions. 
35 | 
36 | We compare our behavior regularization loss (the diffusion trust region loss) with a Kullback–Leibler (KL) based behavior regularization loss, testing their differential impact on behavior regularization when a trained Q-function is used for policy improvement. Red points represent actions generated from the one-step policy. 
37 | 
38 | ![DTQL](./assets/DTQL_toy.png) 
39 | 
40 | 
41 | ## Experiments 
42 | 
43 | ### Requirements 
44 | Installations of [PyTorch](https://pytorch.org/), [MuJoCo](https://github.com/deepmind/mujoco), and [D4RL](https://github.com/Farama-Foundation/D4RL) are needed. Please see the ``requirements.txt`` and ``install_env.sh`` for environment setup details. 
45 | 
46 | ### Running 
47 | Running experiments with our code is straightforward; below we use the `halfcheetah-medium-v2` dataset as an example. 
48 | 
49 | ```.bash 
50 | python main.py --device=0 --env_name=halfcheetah-medium-v2 --seed=1 --dir=results 
51 | ``` 
52 | 
53 | Hyperparameters for DTQL have been hard-coded in `main.py` to make our reported results easy to reproduce. 
54 | Better hyperparameter settings may well exist; feel free to make your own modifications.
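
The 2D bandit (toy) experiments can be launched in the same way from the `toy_tasks` directory. The flags below are taken from `toy_tasks/run_toy.sh`; adjust the device index to match your machine:

```.bash
cd toy_tasks
python toy_main.py --device=0 --env_name=25_gaussian --distill_loss=diffusion --reward_type=near --actor=sac
```

For reference, the one-step policy is trained by combining the diffusion trust region (behavior-cloning distillation) loss with Q-value maximization. The snippet below is a minimal sketch of that update, paraphrasing `DTQL.train` in `agents/dtql.py` (`alpha` and `gamma` are the corresponding loss weights passed to `DTQL`, and `gamma_loss` is the negative log-likelihood of dataset actions under the one-step policy):

```.python
new_action = distill_actor.sample(state)                                    # action proposed by the one-step policy
distill_loss = diffusion.diffusion_train_step(bc_actor, new_action, state)  # diffusion trust region loss
q_loss = -critic.q_min(state, new_action).mean()                            # push actions toward high-Q regions
gamma_loss = -distill_actor.log_prob(state, action).mean()                  # optional in-sample log-likelihood term
actor_loss = alpha * distill_loss + q_loss + gamma * gamma_loss
```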
55 | 56 | ## Citation 57 | 58 | If you find this open source release useful, please cite in your paper: 59 | ``` 60 | @misc{chen2024diffusion, 61 | title={Diffusion Policies creating a Trust Region for Offline Reinforcement Learning}, 62 | author={Tianyu Chen and Zhendong Wang and Mingyuan Zhou}, 63 | year={2024}, 64 | eprint={2405.19690}, 65 | archivePrefix={arXiv}, 66 | primaryClass={cs.LG} 67 | } 68 | ``` 69 | 70 | ## Acknowledgement 71 | This repo is heavily built upon [DQL](https://github.com/Zhendong-Wang/Diffusion-Policies-for-Offline-RL). We thank the authors for their excellent work. 72 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | import math 4 | 5 | 6 | def print_banner(s, separator="-", num_star=60): 7 | print(separator * num_star, flush=True) 8 | print(s, flush=True) 9 | print(separator * num_star, flush=True) 10 | 11 | 12 | class Progress: 13 | 14 | def __init__(self, total, name='Progress', ncol=3, max_length=20, indent=0, line_width=100, speed_update_freq=100): 15 | self.total = total 16 | self.name = name 17 | self.ncol = ncol 18 | self.max_length = max_length 19 | self.indent = indent 20 | self.line_width = line_width 21 | self._speed_update_freq = speed_update_freq 22 | 23 | self._step = 0 24 | self._prev_line = '\033[F' 25 | self._clear_line = ' ' * self.line_width 26 | 27 | self._pbar_size = self.ncol * self.max_length 28 | self._complete_pbar = '#' * self._pbar_size 29 | self._incomplete_pbar = ' ' * self._pbar_size 30 | 31 | self.lines = [''] 32 | self.fraction = '{} / {}'.format(0, self.total) 33 | 34 | self.resume() 35 | 36 | def update(self, description, n=1): 37 | self._step += n 38 | if self._step % self._speed_update_freq == 0: 39 | self._time0 = time.time() 40 | self._step0 = self._step 41 | self.set_description(description) 42 | 43 | def resume(self): 44 | self._skip_lines = 1 45 | print('\n', end='') 46 | self._time0 = time.time() 47 | self._step0 = self._step 48 | 49 | def pause(self): 50 | self._clear() 51 | self._skip_lines = 1 52 | 53 | def set_description(self, params=[]): 54 | 55 | if type(params) == dict: 56 | params = sorted([ 57 | (key, val) 58 | for key, val in params.items() 59 | ]) 60 | 61 | ############ 62 | # Position # 63 | ############ 64 | self._clear() 65 | 66 | ########### 67 | # Percent # 68 | ########### 69 | percent, fraction = self._format_percent(self._step, self.total) 70 | self.fraction = fraction 71 | 72 | ######### 73 | # Speed # 74 | ######### 75 | speed = self._format_speed(self._step) 76 | 77 | ########## 78 | # Params # 79 | ########## 80 | num_params = len(params) 81 | nrow = math.ceil(num_params / self.ncol) 82 | params_split = self._chunk(params, self.ncol) 83 | params_string, lines = self._format(params_split) 84 | self.lines = lines 85 | 86 | description = '{} | {}{}'.format(percent, speed, params_string) 87 | print(description) 88 | self._skip_lines = nrow + 1 89 | 90 | def append_description(self, descr): 91 | self.lines.append(descr) 92 | 93 | def _clear(self): 94 | position = self._prev_line * self._skip_lines 95 | empty = '\n'.join([self._clear_line for _ in range(self._skip_lines)]) 96 | print(position, end='') 97 | print(empty) 98 | print(position, end='') 99 | 100 | def _format_percent(self, n, total): 101 | if total: 102 | percent = n / float(total) 103 | 104 | complete_entries = int(percent * self._pbar_size) 105 | incomplete_entries = self._pbar_size - 
complete_entries 106 | 107 | pbar = self._complete_pbar[:complete_entries] + self._incomplete_pbar[:incomplete_entries] 108 | fraction = '{} / {}'.format(n, total) 109 | string = '{} [{}] {:3d}%'.format(fraction, pbar, int(percent * 100)) 110 | else: 111 | fraction = '{}'.format(n) 112 | string = '{} iterations'.format(n) 113 | return string, fraction 114 | 115 | def _format_speed(self, n): 116 | num_steps = n - self._step0 117 | t = time.time() - self._time0 118 | speed = num_steps / t 119 | string = '{:.1f} Hz'.format(speed) 120 | if num_steps > 0: 121 | self._speed = string 122 | return string 123 | 124 | def _chunk(self, l, n): 125 | return [l[i:i + n] for i in range(0, len(l), n)] 126 | 127 | def _format(self, chunks): 128 | lines = [self._format_chunk(chunk) for chunk in chunks] 129 | lines.insert(0, '') 130 | padding = '\n' + ' ' * self.indent 131 | string = padding.join(lines) 132 | return string, lines 133 | 134 | def _format_chunk(self, chunk): 135 | line = ' | '.join([self._format_param(param) for param in chunk]) 136 | return line 137 | 138 | def _format_param(self, param): 139 | k, v = param 140 | return '{} : {}'.format(k, v)[:self.max_length] 141 | 142 | def stamp(self): 143 | if self.lines != ['']: 144 | params = ' | '.join(self.lines) 145 | string = '[ {} ] {}{} | {}'.format(self.name, self.fraction, params, self._speed) 146 | self._clear() 147 | print(string, end='\n') 148 | self._skip_lines = 1 149 | else: 150 | self._clear() 151 | self._skip_lines = 0 152 | 153 | def close(self): 154 | self.pause() 155 | 156 | 157 | class Silent: 158 | 159 | def __init__(self, *args, **kwargs): 160 | pass 161 | 162 | def __getattr__(self, attr): 163 | return lambda *args: None 164 | 165 | 166 | class EarlyStopping(object): 167 | def __init__(self, tolerance=5, min_delta=0): 168 | self.tolerance = tolerance 169 | self.min_delta = min_delta 170 | self.counter = 0 171 | self.early_stop = False 172 | 173 | def __call__(self, train_loss, validation_loss): 174 | if (validation_loss - train_loss) > self.min_delta: 175 | self.counter += 1 176 | if self.counter >= self.tolerance: 177 | return True 178 | else: 179 | self.counter = 0 180 | return False 181 | -------------------------------------------------------------------------------- /agents/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | ########################################### Critic net ################################ 5 | class Critic(nn.Module): 6 | def __init__(self, state_dim, action_dim, hidden_dim=256): 7 | super(Critic, self).__init__() 8 | self.q1_model = nn.Sequential(nn.Linear(state_dim + action_dim, hidden_dim), 9 | nn.LayerNorm(hidden_dim), 10 | nn.Mish(), 11 | nn.Linear(hidden_dim, hidden_dim), 12 | nn.LayerNorm(hidden_dim), 13 | nn.Mish(), 14 | nn.Linear(hidden_dim, hidden_dim), 15 | nn.LayerNorm(hidden_dim), 16 | nn.Mish(), 17 | nn.Linear(hidden_dim, 1)) 18 | 19 | self.q2_model = nn.Sequential(nn.Linear(state_dim + action_dim, hidden_dim), 20 | nn.LayerNorm(hidden_dim), 21 | nn.Mish(), 22 | nn.Linear(hidden_dim, hidden_dim), 23 | nn.LayerNorm(hidden_dim), 24 | nn.Mish(), 25 | nn.Linear(hidden_dim, hidden_dim), 26 | nn.LayerNorm(hidden_dim), 27 | nn.Mish(), 28 | nn.Linear(hidden_dim, 1)) 29 | 30 | self.v_model = nn.Sequential(nn.Linear(state_dim, hidden_dim), 31 | nn.Mish(), 32 | nn.Linear(hidden_dim, hidden_dim), 33 | nn.Mish(), 34 | nn.Linear(hidden_dim, hidden_dim), 35 | nn.Mish(), 36 | nn.Linear(hidden_dim, 1)) 37 | 38 | def 
forward(self, state, action): 39 | x = torch.cat([state, action], dim=-1) 40 | return self.q1_model(x), self.q2_model(x) 41 | 42 | def q1(self, state, action): 43 | x = torch.cat([state, action], dim=-1) 44 | return self.q1_model(x) 45 | 46 | def q_min(self, state, action): 47 | q1, q2 = self.forward(state, action) 48 | return torch.min(q1, q2) 49 | 50 | def v(self, state): 51 | return self.v_model(state) 52 | 53 | 54 | ############################################# SAC ################################################# 55 | def weight_init(m): 56 | """Custom weight init for Conv2D and Linear layers.""" 57 | if isinstance(m, nn.Linear): 58 | nn.init.orthogonal_(m.weight.data) 59 | if hasattr(m.bias, 'data'): 60 | m.bias.data.fill_(0.0) 61 | 62 | 63 | def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None): 64 | if hidden_depth == 0: 65 | mods = [nn.Linear(input_dim, output_dim)] 66 | else: 67 | mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)] 68 | for i in range(hidden_depth - 1): 69 | mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)] 70 | mods.append(nn.Linear(hidden_dim, output_dim)) 71 | if output_mod is not None: 72 | mods.append(output_mod) 73 | trunk = nn.Sequential(*mods) 74 | return trunk 75 | 76 | 77 | class DiagGaussianActorTanhAction(nn.Module): 78 | """torch.distributions implementation of an diagonal Gaussian policy.""" 79 | def __init__(self, state_dim, action_dim, max_action, 80 | hidden_dim=256, hidden_depth=3, 81 | log_std_bounds=[-5, 2]): 82 | super().__init__() 83 | 84 | self.log_std_bounds = log_std_bounds 85 | self.net = mlp(state_dim, hidden_dim, 2 * action_dim, hidden_depth) 86 | self.apply(weight_init) 87 | self.action_scale = max_action 88 | self.action_dim = action_dim 89 | 90 | def forward(self, state): 91 | mu, log_std = self.net(state).chunk(2, dim=-1) 92 | log_std_min, log_std_max = self.log_std_bounds 93 | log_std = log_std.clamp(log_std_min, log_std_max) 94 | 95 | std = log_std.exp() 96 | actor_dist = torch.distributions.Normal(mu, std) 97 | return actor_dist 98 | 99 | def sample(self, state): 100 | actor_dist = self(state) 101 | z = actor_dist.rsample() 102 | action = torch.tanh(z) 103 | 104 | action = action * self.action_scale 105 | action = action.clamp(-self.action_scale, self.action_scale) 106 | return action 107 | 108 | def log_prob(self, state, action): 109 | actor_dist = self(state) 110 | pre_tanh_value = torch.arctanh(action / (self.action_scale + 1e-3)) 111 | log_prob = actor_dist.log_prob(pre_tanh_value) 112 | return log_prob.sum(-1, keepdim=True) 113 | 114 | def get_entropy(self,state): 115 | with torch.no_grad(): 116 | mu, log_std = self.net(state).chunk(2, dim=-1) 117 | return log_std.sum(-1).mean() 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /diffusion/mlps.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import math 3 | import torch 4 | from torch import nn 5 | import numpy as np 6 | 7 | class SinusoidalPosEmb(nn.Module): 8 | def __init__(self, dim): 9 | super().__init__() 10 | self.dim = dim 11 | 12 | def forward(self, x): 13 | device = x.device 14 | half_dim = self.dim // 2 15 | emb = math.log(10000) / (half_dim - 1) 16 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 17 | emb = x[:, None] * emb[None, :] 18 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 19 | return emb 20 | 21 | 22 | class MLPNetwork(nn.Module): 23 | """ 24 | Simple multi layer perceptron 
network which can be generated with different 25 | activation functions with and without spectral normalization of the weights 26 | """ 27 | 28 | def __init__( 29 | self, 30 | input_dim: int, 31 | hidden_dim: int = 100, 32 | num_hidden_layers: int = 1, 33 | output_dim=1, 34 | device: str = 'cuda' 35 | ): 36 | super(MLPNetwork, self).__init__() 37 | self.network_type = "mlp" 38 | # define number of variables in an input sequence 39 | self.input_dim = input_dim 40 | # the dimension of neurons in the hidden layer 41 | self.hidden_dim = hidden_dim 42 | self.num_hidden_layers = num_hidden_layers 43 | # number of samples per batch 44 | self.output_dim = output_dim 45 | # set up the network 46 | self.layers = nn.ModuleList([nn.Linear(self.input_dim, self.hidden_dim)]) 47 | for i in range(1, self.num_hidden_layers): 48 | self.layers.extend([ 49 | nn.Linear(self.hidden_dim, self.hidden_dim), 50 | nn.Mish() 51 | ]) 52 | self.layers.append(nn.Linear(self.hidden_dim, self.output_dim)) 53 | 54 | self._device = device 55 | self.layers.to(self._device) 56 | 57 | def forward(self, x): 58 | for layer in self.layers: 59 | x = layer(x) 60 | return x 61 | 62 | def get_device(self, device: torch.device): 63 | self._device = device 64 | self.layers.to(device) 65 | 66 | def get_params(self): 67 | return self.layers.parameters() 68 | 69 | 70 | class ScoreNetwork(nn.Module): 71 | def __init__( 72 | self, 73 | action_dim: int, 74 | hidden_dim: int, 75 | time_embed_dim: int, 76 | cond_dim: int, 77 | cond_mask_prob: float, 78 | num_hidden_layers: int = 1, 79 | output_dim=1, 80 | device: str = 'cuda', 81 | cond_conditional: bool = True 82 | ): 83 | super(ScoreNetwork, self).__init__() 84 | # Gaussian random feature embedding layer for time 85 | # self.embed = GaussianFourierProjection(time_embed_dim).to(device) 86 | self.embed = nn.Sequential( 87 | SinusoidalPosEmb(time_embed_dim), 88 | nn.Linear(time_embed_dim, time_embed_dim * 2), 89 | nn.Mish(), 90 | nn.Linear(time_embed_dim * 2, time_embed_dim), 91 | ).to(device) 92 | self.time_embed_dim = time_embed_dim 93 | self.cond_mask_prob = cond_mask_prob 94 | self.cond_conditional = cond_conditional 95 | if self.cond_conditional: 96 | input_dim = self.time_embed_dim + action_dim + cond_dim 97 | else: 98 | input_dim = self.time_embed_dim + action_dim 99 | # set up the network 100 | self.layers = MLPNetwork( 101 | input_dim, 102 | hidden_dim, 103 | num_hidden_layers, 104 | output_dim, 105 | device 106 | ).to(device) 107 | 108 | # build the activation layer 109 | self.act = nn.Mish() 110 | self.device = device 111 | self.sigma = None 112 | self.training = True 113 | 114 | def forward(self, x, cond, sigma, uncond=False): 115 | # Obtain the feature embedding for t 116 | if len(sigma.shape) == 0: 117 | sigma = einops.rearrange(sigma, ' -> 1') 118 | sigma = sigma.unsqueeze(1) 119 | elif len(sigma.shape) == 1: 120 | sigma = sigma.unsqueeze(1) 121 | embed = self.embed(sigma) 122 | embed.squeeze_(1) 123 | if embed.shape[0] != x.shape[0]: 124 | embed = einops.repeat(embed, '1 d -> (1 b) d', b=x.shape[0]) 125 | # during training randomly mask out the cond 126 | # to train the conditional model with classifier-free guidance wen need 127 | # to 0 out some of the conditional during training with a desrired probability 128 | # it is usually in the range of 0,1 to 0.2 129 | if self.training and cond is not None: 130 | cond = self.mask_cond(cond) 131 | # we want to use unconditional sampling during classifier free guidance 132 | if uncond: 133 | cond = torch.zeros_like(cond) # cond 134 | if 
self.cond_conditional: 135 | x = torch.cat([x, cond, embed], dim=-1) 136 | else: 137 | x = torch.cat([x, embed], dim=-1) 138 | x = self.layers(x) 139 | return x # / marginal_prob_std(t, self.sigma, self.device)[:, None] 140 | 141 | def mask_cond(self, cond, force_mask=False): 142 | bs, d = cond.shape 143 | if force_mask: 144 | return torch.zeros_like(cond) 145 | elif self.training and self.cond_mask_prob > 0.: 146 | mask = torch.bernoulli(torch.ones((bs, d), 147 | device=cond.device) * self.cond_mask_prob) # .view(bs, 1) # 1-> use null_cond, 0-> use real cond 148 | return cond * (1. - mask) 149 | else: 150 | return cond 151 | 152 | def get_params(self): 153 | return self.parameters() -------------------------------------------------------------------------------- /diffusion/karras.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch 3 | from .utils import * 4 | from torch import nn 5 | class DiffusionModel(nn.Module): 6 | def __init__( 7 | self, 8 | sigma_data: float, 9 | sigma_min: float, 10 | sigma_max: float, 11 | device: str, 12 | sigma_sample_density_type: str = 'loglogistic', 13 | clip_denoised=False, 14 | max_action=1.0, 15 | ) -> None: 16 | super().__init__() 17 | 18 | self.device = device 19 | # use the score wrapper 20 | self.sigma_data = sigma_data 21 | self.sigma_min = sigma_min 22 | self.sigma_max = sigma_max 23 | self.sigma_sample_density_type = sigma_sample_density_type 24 | self.epochs = 0 25 | self.clip_denoised = clip_denoised 26 | self.max_action = max_action 27 | 28 | def get_diffusion_scalings(self, sigma): 29 | """ 30 | Computes the scaling factors for diffusion training at a given time step sigma. 31 | 32 | Args: 33 | - self: the object instance of the model 34 | - sigma (float or torch.Tensor): the time step at which to compute the scaling factors 35 | 36 | , where self.sigma_data: the data noise level of the diffusion process, set during initialization of the model 37 | 38 | Returns: 39 | - c_skip (torch.Tensor): the scaling factor for skipping the diffusion model for the given time step sigma 40 | - c_out (torch.Tensor): the scaling factor for the output of the diffusion model for the given time step sigma 41 | - c_in (torch.Tensor): the scaling factor for the input of the diffusion model for the given time step sigma 42 | 43 | """ 44 | c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) 45 | c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 46 | c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 47 | return c_skip, c_out, c_in 48 | 49 | def diffusion_train_step(self, model, x, cond, noise=None, t_chosen=None, return_denoised=False): 50 | """ 51 | Computes the training loss and performs a single update step for the score-based model. 
52 | 53 | Args: 54 | - self: the object instance of the model 55 | - x (torch.Tensor): the input tensor of shape (batch_size, dim) 56 | - cond (torch.Tensor): the conditional input tensor of shape (batch_size, cond_dim) 57 | 58 | Returns: 59 | - loss.item() (float): the scalar value of the training loss for this batch 60 | 61 | """ 62 | model.train() 63 | x = x.to(self.device) 64 | cond = cond.to(self.device) 65 | if t_chosen is None: 66 | t_chosen = self.make_sample_density()(shape=(len(x),), device=self.device) 67 | 68 | if return_denoised: 69 | denoised_x, loss = self.diffusion_loss(model,x, cond, t_chosen, noise, return_denoised) 70 | return denoised_x, loss, t_chosen 71 | else: 72 | loss = self.diffusion_loss(model, x, cond, t_chosen, noise, return_denoised) 73 | return loss 74 | 75 | def diffusion_loss(self, model, x, cond, t, noise, return_denoised): 76 | """ 77 | Computes the diffusion training loss for the given model, input, condition, and time. 78 | 79 | Args: 80 | - self: the object instance of the model 81 | - x (torch.Tensor): the input tensor of shape (batch_size, channels, height, width) 82 | - cond (torch.Tensor): the conditional input tensor of shape (batch_size, cond_dim) 83 | - t (torch.Tensor): the time step tensor of shape (batch_size,) 84 | 85 | Returns: 86 | - loss (torch.Tensor): the diffusion training loss tensor of shape () 87 | 88 | The diffusion training loss is computed based on the following equation from Karras et al. 2022: 89 | loss = (model_output - target)^2.mean() 90 | where, 91 | - noise: a tensor of the same shape as x, containing randomly sampled noise 92 | - x_1: a tensor of the same shape as x, obtained by adding the noise tensor to x 93 | - c_skip, c_out, c_in: scaling tensors obtained from the diffusion scalings for the given time step 94 | - t: a tensor of the same shape as t, obtained by taking the natural logarithm of t and dividing it by 4 95 | - model_output: the output tensor of the model for the input x_1, condition cond, and time t 96 | - target: the target tensor for the given input x, scaling tensors c_skip, c_out, c_in, and time t 97 | """ 98 | if noise is None: 99 | noise = torch.randn_like(x) 100 | x_1 = x + noise * append_dims(t, x.ndim) 101 | c_skip, c_out, c_in = [append_dims(x, 2) for x in self.get_diffusion_scalings(t)] 102 | t = torch.log(t) / 4 103 | model_output = model(x_1 * c_in, cond, t) 104 | 105 | if self.clip_denoised: 106 | denoised_x = c_out * model_output + c_skip * x_1 107 | denoised_x = denoised_x.clamp(-self.max_action,self.max_action) 108 | loss = ((denoised_x - x)/c_out).pow(2).mean() 109 | else: 110 | denoised_x = c_out * model_output + c_skip * x_1 111 | target = (x - c_skip * x_1) / c_out 112 | loss = (model_output - target).pow(2).mean() 113 | 114 | if return_denoised: 115 | return denoised_x, loss 116 | else: 117 | return loss 118 | 119 | def make_sample_density(self): 120 | """ 121 | Returns a function that generates random timesteps based on the chosen sample density. 122 | 123 | Args: 124 | - self: the object instance of the model 125 | 126 | Returns: 127 | - sample_density_fn (callable): a function that generates random timesteps 128 | 129 | The method returns a callable function that generates random timesteps based on the chosen sample density. 130 | The available sample densities are: 131 | - 'lognormal': generates random timesteps from a log-normal distribution with mean and standard deviation set 132 | during initialization of the model also used in Karras et al. 
(2022) 133 | - 'loglogistic': generates random timesteps from a log-logistic distribution with location parameter set to the 134 | natural logarithm of the sigma_data parameter and scale and range parameters set during initialization 135 | of the model 136 | - 'loguniform': generates random timesteps from a log-uniform distribution with range parameters set during 137 | initialization of the model 138 | - 'uniform': generates random timesteps from a uniform distribution with range parameters set during initialization 139 | of the model 140 | - 'v-diffusion': generates random timesteps using the Variational Diffusion sampler with range parameters set during 141 | initialization of the model 142 | - 'discrete': generates random timesteps from the noise schedule using the exponential density 143 | - 'split-lognormal': generates random timesteps from a split log-normal distribution with mean and standard deviation 144 | set during initialization of the model 145 | """ 146 | sd_config = [] 147 | 148 | if self.sigma_sample_density_type == 'lognormal': 149 | loc = self.sigma_sample_density_mean # if 'mean' in sd_config else sd_config['loc'] 150 | scale = self.sigma_sample_density_std # if 'std' in sd_config else sd_config['scale'] 151 | return partial(rand_log_normal, loc=loc, scale=scale) 152 | 153 | if self.sigma_sample_density_type == 'loglogistic': 154 | loc = sd_config['loc'] if 'loc' in sd_config else math.log(self.sigma_data) 155 | scale = sd_config['scale'] if 'scale' in sd_config else 0.5 156 | min_value = sd_config['min_value'] if 'min_value' in sd_config else self.sigma_min 157 | max_value = sd_config['max_value'] if 'max_value' in sd_config else self.sigma_max 158 | return partial(rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value) 159 | 160 | if self.sigma_sample_density_type == 'loguniform': 161 | min_value = sd_config['min_value'] if 'min_value' in sd_config else self.sigma_min 162 | max_value = sd_config['max_value'] if 'max_value' in sd_config else self.sigma_max 163 | return partial(rand_log_uniform, min_value=min_value, max_value=max_value) 164 | if self.sigma_sample_density_type == 'uniform': 165 | return partial(rand_uniform, min_value=self.sigma_min, max_value=self.sigma_max) 166 | 167 | if self.sigma_sample_density_type == 'v-diffusion': 168 | min_value = self.min_value if 'min_value' in sd_config else self.sigma_min 169 | max_value = sd_config['max_value'] if 'max_value' in sd_config else self.sigma_max 170 | return partial(rand_v_diffusion, sigma_data=self.sigma_data, min_value=min_value, max_value=max_value) 171 | if self.sigma_sample_density_type == 'discrete': 172 | sigmas = self.get_noise_schedule(self.n_sampling_steps, 'exponential') 173 | return partial(rand_discrete, values=sigmas) 174 | else: 175 | raise ValueError('Unknown sample density type') -------------------------------------------------------------------------------- /agents/dtql.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import time 3 | 4 | from torch.optim.lr_scheduler import CosineAnnealingLR 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from pathlib import Path 9 | from diffusion.karras import DiffusionModel 10 | from diffusion.mlps import ScoreNetwork 11 | from agents.model import Critic, DiagGaussianActorTanhAction 12 | from agents.helpers import EMA 13 | 14 | class DTQL(object): 15 | def __init__(self, 16 | device, 17 | state_dim, 18 | action_dim, 19 | action_space=None, 20 | discount=0.99, 21 
| alpha=1.0, 22 | ema_decay=0.995, 23 | step_start_ema=1000, 24 | update_ema_every=5, 25 | lr=3e-4, 26 | lr_decay=False, 27 | lr_maxt=1000, 28 | sigma_max=80., 29 | sigma_min=0.002, 30 | sigma_data=0.5, 31 | expectile=0.7, 32 | tau = 0.005, 33 | gamma=1, 34 | repeats=1024 35 | ): 36 | """Init critic networks""" 37 | self.critic = Critic(state_dim, action_dim).to(device) 38 | self.critic_target = copy.deepcopy(self.critic) 39 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4) 40 | 41 | """Init bc actor""" 42 | self.bc_actor = ScoreNetwork( 43 | action_dim=action_dim, 44 | hidden_dim=256, 45 | time_embed_dim=16, 46 | cond_dim=state_dim, 47 | cond_mask_prob=0.0, 48 | num_hidden_layers=4, 49 | output_dim=action_dim, 50 | device=device, 51 | cond_conditional=True 52 | ).to(device) 53 | self.bc_actor_target = copy.deepcopy(self.bc_actor) 54 | self.bc_actor_optimizer = torch.optim.Adam(self.bc_actor.parameters(), lr=lr) 55 | 56 | """Init diffusion training schedule""" 57 | self.diffusion = DiffusionModel( 58 | sigma_data=sigma_data, 59 | sigma_min=sigma_min, 60 | sigma_max=sigma_max, 61 | device=device, 62 | clip_denoised=True, 63 | max_action=float(action_space.high[0])) 64 | 65 | """Init sac""" 66 | self.distill_actor = DiagGaussianActorTanhAction(state_dim=state_dim, action_dim=action_dim, 67 | max_action=float(action_space.high[0])).to(device) 68 | self.distill_actor_target = copy.deepcopy(self.distill_actor) 69 | self.distill_actor_optimizer = torch.optim.Adam(self.distill_actor.parameters(), lr=lr) 70 | 71 | 72 | """Back up training parameters""" 73 | self.tau = tau 74 | self.lr_decay = lr_decay 75 | self.gamma = gamma 76 | self.repeats = repeats 77 | 78 | self.step = 0 79 | self.step_start_ema = step_start_ema 80 | self.ema = EMA(ema_decay) 81 | self.update_ema_every = update_ema_every 82 | 83 | if lr_decay: 84 | self.critic_lr_scheduler = CosineAnnealingLR(self.critic_optimizer, T_max=lr_maxt, eta_min=0.) 85 | self.bc_actor_lr_scheduler = CosineAnnealingLR(self.bc_actor_optimizer, T_max=lr_maxt, eta_min=0.) 86 | self.distill_actor_lr_scheduler = CosineAnnealingLR(self.distill_actor_optimizer, T_max=lr_maxt, eta_min=0.) 
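        # Target-network bookkeeping used later in this class: the critic target is refreshed by
        # Polyak averaging with rate `tau` inside train(), while the one-step (distill) actor keeps
        # an EMA copy that step_ema() updates every `update_ema_every` steps once `step_start_ema`
        # steps have passed.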
87 | 88 | self.state_dim = state_dim 89 | self.action_dim = action_dim 90 | self.discount = discount 91 | self.alpha = alpha # bc weight 92 | self.expectile = expectile 93 | self.device = device 94 | 95 | 96 | def step_ema(self): 97 | if self.step < self.step_start_ema: 98 | return 99 | self.ema.update_model_average(self.distill_actor_target, self.distill_actor) 100 | 101 | def pretrain(self,replay_buffer, batch_size=256,pretrain_steps=50000): 102 | self.bc_actor.train() 103 | for _ in range(pretrain_steps): 104 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 105 | loss = self.diffusion.diffusion_train_step(self.bc_actor, action, state) 106 | self.bc_actor_optimizer.zero_grad() 107 | loss.backward() 108 | self.bc_actor_optimizer.step() 109 | self.bc_loss = loss 110 | 111 | critic_loss = self.q_v_critic_loss(state,action, next_state, reward, not_done) 112 | self.critic_optimizer.zero_grad() 113 | critic_loss.backward() 114 | self.critic_optimizer.step() 115 | 116 | 117 | def train(self, replay_buffer, batch_size=256): 118 | # initialize 119 | self.bc_loss = torch.tensor([0.]).to(self.device) 120 | self.critic_loss = torch.tensor([0.]).to(self.device) 121 | metric = {'bc_loss': [], 'distill_loss':[], 'ql_loss': [], 'actor_loss': [], 'critic_loss': [], 'gamma_loss': []} 122 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 123 | 124 | """ Q Training """ 125 | critic_loss = self.q_v_critic_loss(state, action, next_state, reward, not_done) 126 | 127 | self.critic_loss = critic_loss 128 | self.critic_optimizer.zero_grad() 129 | critic_loss.backward() 130 | self.critic_optimizer.step() 131 | 132 | 133 | """ Diffusion Policy Training """ 134 | bc_loss = self.diffusion.diffusion_train_step(self.bc_actor, action, state) 135 | self.bc_actor_optimizer.zero_grad() 136 | bc_loss.backward() 137 | self.bc_actor_optimizer.step() 138 | self.bc_loss = bc_loss 139 | 140 | 141 | """Distill Policy Training""" 142 | new_action = self.distill_actor.sample(state) 143 | distill_loss = self.diffusion.diffusion_train_step(self.bc_actor, new_action, state) 144 | q_loss = -self.critic.q_min(state, new_action).mean() 145 | 146 | if self.gamma == 0.: 147 | gamma_loss = torch.tensor([0.]).to(self.device) 148 | else: 149 | gamma_loss = -self.distill_actor.log_prob(state, action).mean() 150 | actor_loss = (self.alpha * distill_loss + q_loss + 151 | self.gamma * gamma_loss) 152 | self.distill_actor_optimizer.zero_grad() 153 | actor_loss.backward() 154 | self.distill_actor_optimizer.step() 155 | 156 | """ Step Target network """ 157 | if self.step % self.update_ema_every == 0: 158 | self.step_ema() 159 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 160 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 161 | 162 | """ Record loss """ 163 | self.step += 1 164 | 165 | metric['actor_loss'].append(actor_loss.item()) 166 | metric['bc_loss'].append(self.bc_loss.item()) 167 | metric['ql_loss'].append(q_loss.item()) 168 | metric['critic_loss'].append(self.critic_loss.item()) 169 | metric['distill_loss'].append(distill_loss.item()) 170 | metric['gamma_loss'].append(gamma_loss.item()) 171 | 172 | """ Lr decay""" 173 | if self.lr_decay: 174 | self.bc_actor_lr_scheduler.step() 175 | self.distill_actor_lr_scheduler.step() 176 | self.critic_lr_scheduler.step() 177 | return metric 178 | 179 | def sample_action(self, state): 180 | state = torch.FloatTensor(state.reshape(1, -1)).to(self.device) 181 | 
state_rpt = torch.repeat_interleave(state, repeats=self.repeats, dim=0) 182 | with torch.no_grad(): 183 | action = self.distill_actor.sample(state_rpt) 184 | q_value = self.critic_target.q_min(state_rpt, action).flatten() 185 | idx = torch.multinomial(F.softmax(q_value), 1) 186 | action = action[idx].cpu().data.numpy().flatten() 187 | return action 188 | 189 | def save_model(self, dir, id=None): 190 | if id is not None: 191 | torch.save(self.bc_actor.state_dict(), f'{dir}/bc_actor_{id}.pth') 192 | torch.save(self.critic.state_dict(), f'{dir}/critic_{id}.pth') 193 | torch.save(self.distill_actor.state_dict(), f'{dir}/distill_actor_{id}.pth') 194 | else: 195 | torch.save(self.bc_actor.state_dict(), f'{dir}/actor.pth') 196 | torch.save(self.critic.state_dict(), f'{dir}/critic.pth') 197 | torch.save(self.distill_actor.state_dict(), f'{dir}/distill_actor.pth') 198 | 199 | def load_model(self, dir, id=None): 200 | if id is not None: 201 | self.bc_actor.load_state_dict(torch.load(f'{dir}/bc_actor_{id}.pth')) 202 | self.critic.load_state_dict(torch.load(f'{dir}/critic_{id}.pth')) 203 | print(f"Successfully load critic from {dir}/critic_{id}.pth") 204 | self.distill_actor.load_state_dict(torch.load(f'{dir}/distill_actor_{id}.pth')) 205 | print(f"Successfully load distill actor from {dir}/distill_actor_{id}.pth") 206 | else: 207 | self.bc_actor.load_state_dict(torch.load(f'{dir}/bc_actor.pth')) 208 | self.critic.load_state_dict(torch.load(f'{dir}/critic.pth')) 209 | self.distill_actor.load_state_dict(torch.load(f'{dir}/distill_actor.pth')) 210 | print(f"Models loaded successfully from {dir}") 211 | 212 | def load_or_pretrain_models(self, dir, replay_buffer, batch_size, pretrain_steps,num_steps_per_epoch): 213 | # Paths for the models 214 | actor_path = Path(dir) / f'diffusion_pretrained_{pretrain_steps // num_steps_per_epoch}.pth' 215 | critic_path = Path(dir) / f'critic_pretrained_{pretrain_steps // num_steps_per_epoch}.pth' 216 | 217 | # Check if both models exist 218 | if actor_path.exists() and critic_path.exists(): 219 | try: 220 | # Load the models 221 | self.bc_actor.load_state_dict(torch.load(actor_path, map_location=self.device)) 222 | self.critic.load_state_dict(torch.load(critic_path, map_location=self.device)) 223 | except Exception as e: 224 | print(f"Failed to load models: {e}") 225 | else: 226 | # Begin pretraining if the models do not exist 227 | print("Models not found, starting pretraining...") 228 | self.pretrain(replay_buffer, batch_size, pretrain_steps) 229 | torch.save(self.bc_actor.state_dict(), actor_path) 230 | torch.save(self.critic.state_dict(), critic_path) 231 | print(f"Saved successfully to {dir}") 232 | 233 | 234 | def q_v_critic_loss(self,state,action, next_state, reward, not_done): 235 | def expectile_loss(diff, expectile=0.8): 236 | weight = torch.where(diff > 0, expectile, (1 - expectile)) 237 | return weight * (diff ** 2) 238 | 239 | with torch.no_grad(): 240 | q = self.critic.q_min(state, action) 241 | v = self.critic.v(state) 242 | value_loss = expectile_loss(q - v, self.expectile).mean() 243 | 244 | current_q1, current_q2 = self.critic(state, action) 245 | with torch.no_grad(): 246 | next_v = self.critic.v(next_state) 247 | target_q = (reward + not_done * self.discount * next_v).detach() 248 | 249 | critic_loss = value_loss + F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q) 250 | return critic_loss 251 | -------------------------------------------------------------------------------- /toy_tasks/data_generator.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.distributions import Normal 4 | import torch 5 | from torch.distributions.normal import Normal 6 | 7 | 8 | class DataGenerator: 9 | def __init__(self, dist_type: str): 10 | self.dist_type = dist_type 11 | self.func_mapping = { 12 | "two_gmm_1D": (self.two_gmm_1D, self.two_gmm_1D_log_prob), 13 | "uneven_two_gmm_1D": (self.uneven_two_gmm_1D, self.uneven_two_gmm_1D_log_prob), 14 | "three_gmm_1D": (self.three_gmm_1D, self.three_gmm_1D_log_prob), 15 | "single_gaussian_1D": (self.single_gaussian_1D, self.single_gaussian_1D_log_prob), 16 | "swiss_roll_2D": (self.sample_swiss_roll, None), 17 | "25_gaussian": (self.sample_25_gaussian, None) 18 | } 19 | if self.dist_type not in self.func_mapping: 20 | raise ValueError("Invalid distribution type") 21 | self.sample_func, self.log_prob_func = self.func_mapping[self.dist_type] 22 | 23 | def generate_samples(self, num_samples: int): 24 | """ 25 | Generate `num_samples` samples and labels using the `sample_func`. 26 | 27 | Args: 28 | num_samples (int): Number of samples to generate. 29 | 30 | Returns: 31 | Tuple[np.ndarray, np.ndarray]: A tuple of two numpy arrays containing the generated samples and labels. 32 | """ 33 | samples, labels = self.sample_func(num_samples) 34 | return samples, labels 35 | 36 | def compute_log_prob(self, samples, exp: bool = False): 37 | """ 38 | Compute the logarithm of probability density function (pdf) of the given `samples` 39 | using the `log_prob_func`. If `exp` is True, return exponentiated log probability. 40 | 41 | Args: 42 | samples (np.ndarray): Samples for which pdf is to be computed. 43 | exp (bool, optional): If True, return exponentiated log probability. 44 | Default is False. 45 | 46 | Returns: 47 | np.ndarray: Logarithm of probability density function (pdf) of the given `samples`. 48 | If `exp` is True, exponentiated log probability is returned. 49 | """ 50 | return self.log_prob_func(samples, exp=exp) 51 | 52 | @staticmethod 53 | def two_gmm_1D(num_samples, ): 54 | """ 55 | Generates `num_samples` samples from a 1D mixture of two Gaussians with equal weights. 56 | 57 | Args: 58 | num_samples (int): Number of samples to generate. 59 | 60 | Returns: 61 | Tuple[torch.Tensor, torch.Tensor]: A tuple of two torch tensors containing the generated 62 | samples and binary labels indicating which Gaussian component the sample is from. 63 | """ 64 | g1 = Normal(loc=-1.5, scale=0.3) 65 | g2 = Normal(loc=1.5, scale=0.3) 66 | mixture_probs = torch.ones(num_samples) * 0.5 67 | is_from_g1 = torch.bernoulli(mixture_probs).bool() 68 | samples = torch.where(is_from_g1, g1.sample((num_samples,)), g2.sample((num_samples,))) 69 | return samples, is_from_g1.int() 70 | 71 | @staticmethod 72 | def uneven_two_gmm_1D(num_samples, w1=0.7): 73 | """ 74 | Generates `num_samples` samples from a 1D mixture of two Gaussians with weights `w1` and `w2`. 75 | 76 | Args: 77 | num_samples (int): Number of samples to generate. 78 | w1 (float, optional): Weight of first Gaussian component. Default is 0.7. 79 | 80 | Returns: 81 | Tuple[torch.Tensor, torch.Tensor]: A tuple of two torch tensors containing the generated 82 | samples and binary labels indicating which Gaussian component the sample is from. 
83 | """ 84 | g1 = Normal(loc=-1.5, scale=0.3) 85 | g2 = Normal(loc=1.5, scale=0.2) 86 | mixture_probs = torch.tensor([w1, 1 - w1]) 87 | is_from_g1 = torch.bernoulli(mixture_probs.repeat(num_samples, 1)).view(num_samples, -1).bool().squeeze() 88 | 89 | samples_g1 = g1.sample((num_samples, 1)) 90 | samples_g2 = g2.sample((num_samples, 1)) 91 | samples = torch.where(is_from_g1, samples_g1, samples_g2).squeeze() 92 | 93 | return samples, is_from_g1.int() 94 | 95 | @staticmethod 96 | def single_gaussian_1D(num_samples): 97 | """ 98 | Generates `num_samples` samples from a 1D Gaussian distribution. 99 | 100 | Args: 101 | num_samples (int): Number of samples to generate. 102 | 103 | Returns: 104 | Tuple[torch.Tensor, torch.Tensor]: A tuple of two torch tensors containing the generated 105 | samples and binary labels indicating which Gaussian component the sample is from. 106 | Since there is only one Gaussian component, all labels will be zero. 107 | """ 108 | g1 = Normal(loc=1, scale=0.2) 109 | samples = g1.sample((num_samples, 1)) 110 | return samples, torch.zeros(num_samples).int() 111 | 112 | @staticmethod 113 | def three_gmm_1D(num_samples): 114 | """ 115 | Generates `num_samples` samples from a 1D mixture of three Gaussians with equal weights. 116 | 117 | Args: 118 | num_samples (int): Number of samples to generate. 119 | exp (bool, optional): If True, return exponentiated log probability. Default is False. 120 | 121 | Returns: 122 | Tuple[torch.Tensor, torch.Tensor]: A tuple of two torch tensors containing the generated 123 | samples and integer labels indicating which Gaussian component the sample is from. 124 | """ 125 | g1 = Normal(loc=-1.5, scale=0.2) 126 | g2 = Normal(loc=0, scale=0.2) 127 | g3 = Normal(loc=1.5, scale=0.2) 128 | mixture_probs = torch.ones(3) / 3 129 | component_assignments = torch.multinomial(mixture_probs, num_samples, replacement=True) 130 | samples = torch.zeros(num_samples, 1) 131 | 132 | g1_mask = (component_assignments == 0) 133 | g2_mask = (component_assignments == 1) 134 | g3_mask = (component_assignments == 2) 135 | 136 | samples[g1_mask] = g1.sample((g1_mask.sum(),)).view(-1, 1) 137 | samples[g2_mask] = g2.sample((g2_mask.sum(),)).view(-1, 1) 138 | samples[g3_mask] = g3.sample((g3_mask.sum(),)).view(-1, 1) 139 | 140 | return samples, component_assignments.int() 141 | 142 | @staticmethod 143 | def two_gmm_1D_log_prob(z, exp=False): 144 | """ 145 | Computes the logarithm of the probability density function (pdf) of a 1D mixture of two Gaussians 146 | with equal weights at the given points `z`. 147 | 148 | Args: 149 | z (torch.Tensor): Points at which to compute the pdf. 150 | exp (bool, optional): If True, return exponentiated log probability. Default is False. 151 | 152 | Returns: 153 | torch.Tensor: Logarithm of probability density function (pdf) of a 1D mixture of two Gaussians 154 | with equal weights at the given points `z`. If `exp` is True, exponentiated log probability 155 | is returned. 156 | """ 157 | g1 = Normal(loc=-1.5, scale=0.3) 158 | g2 = Normal(loc=1.5, scale=0.3) 159 | f = torch.log(0.5 * (g1.log_prob(z).exp() + g2.log_prob(z).exp())) 160 | if exp: 161 | return torch.exp(f) 162 | else: 163 | return f 164 | 165 | @staticmethod 166 | def uneven_two_gmm_1D_log_prob(z, w1=0.7, exp=False): 167 | """ 168 | Computes the logarithm of the probability density function (pdf) of a 1D mixture of two Gaussians 169 | with weights `w1` and `w2` at the given points `z`. 170 | 171 | Args: 172 | z (torch.Tensor): Points at which to compute the pdf. 
173 | w1 (float, optional): Weight of first Gaussian component. Default is 0.7. 174 | exp (bool, optional): If True, return exponentiated log probability. Default is False. 175 | 176 | Returns: 177 | torch.Tensor: Logarithm of probability density function (pdf) of a 1D mixture of two Gaussians 178 | with weights `w1` and `w2` at the given points `z`. If `exp` is True, exponentiated log probability 179 | is returned. 180 | """ 181 | g1 = Normal(loc=-1.5, scale=0.3) 182 | g2 = Normal(loc=1.5, scale=0.2) 183 | f = torch.log(w1 * g1.log_prob(z).exp() + (1 - w1) * g2.log_prob(z).exp()) 184 | if exp: 185 | return torch.exp(f) 186 | else: 187 | return f 188 | 189 | @staticmethod 190 | def three_gmm_1D_log_prob(z, exp=False): 191 | """ 192 | Computes the logarithm of the probability density function (pdf) of a 1D mixture of three Gaussians 193 | with equal weights at the given points `z`. 194 | 195 | Args: 196 | z (torch.Tensor): Points at which to compute the pdf. 197 | exp (bool, optional): If True, return exponentiated log probability. Default is False. 198 | 199 | Returns: 200 | torch.Tensor: Logarithm of probability density function (pdf) of a 1D mixture of three Gaussians 201 | with equal weights at the given points `z`. If `exp` is True, exponentiated log probability 202 | is returned. 203 | """ 204 | g1 = Normal(loc=-1.5, scale=0.2) 205 | g2 = Normal(loc=0, scale=0.2) 206 | g3 = Normal(loc=1.5, scale=0.2) 207 | f = torch.log(1 / 3 * (g1.log_prob(z).exp() + g2.log_prob(z).exp() + g3.log_prob(z).exp())) 208 | if exp: 209 | return torch.exp(f) 210 | else: 211 | return f 212 | 213 | @staticmethod 214 | def single_gaussian_1D_log_prob(z, exp=False): 215 | """ 216 | Computes the logarithm of the probability density function (pdf) of a 1D Gaussian 217 | distribution at the given points `z`. 218 | 219 | Args: 220 | z (torch.Tensor): Points at which to compute the pdf. 221 | exp (bool, optional): If True, return exponentiated log probability. Default is False. 222 | 223 | Returns: 224 | torch.Tensor: Logarithm of probability density function (pdf) of a 1D Gaussian 225 | distribution at the given points `z`. If `exp` is True, exponentiated log probability 226 | is returned. 
227 | """ 228 | g = Normal(loc=1, scale=0.2) 229 | f = g.log_prob(z) 230 | if exp: 231 | return torch.exp(f) 232 | else: 233 | return f 234 | 235 | @staticmethod 236 | def sample_swiss_roll(num_samples): 237 | from sklearn.datasets import make_swiss_roll 238 | samples = make_swiss_roll(num_samples, noise=0.1) 239 | samples = torch.tensor(samples[0][:, [0, 2]], dtype=torch.float32) 240 | # None is the placeholder for label 241 | return samples, torch.zeros(num_samples).int() 242 | 243 | @staticmethod 244 | def sample_25_gaussian(num_samples): 245 | num_modes = 25 # Number of Gaussian modes 246 | grid_size = int(np.sqrt(num_modes)) # Determining the grid size (5x5 for 25 modes) 247 | 248 | # Creating a grid of means 249 | x_means = np.linspace(-10, 10, grid_size) 250 | y_means = np.linspace(-10, 10, grid_size) 251 | means = np.array(np.meshgrid(x_means, y_means)).T.reshape(-1, 2) 252 | 253 | # Standard deviation for each mode (can be adjusted as needed) 254 | # Standard deviation for each mode (can be adjusted as needed) 255 | std_dev = 0.3 256 | covariance_matrix = np.array([[std_dev ** 2, 0], [0, std_dev ** 2]]) # Diagonal covariance matrix 257 | 258 | # Generating one sample from each mode 259 | samples = np.array( 260 | [np.random.multivariate_normal(mean, covariance_matrix, num_samples // num_modes) for mean in means]) 261 | samples = samples.reshape(-1, samples.shape[-1]) 262 | samples = torch.from_numpy(samples).type(torch.float32) 263 | return samples, torch.zeros(num_samples).int() -------------------------------------------------------------------------------- /toy_tasks/toy_main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from tqdm import tqdm 5 | import sys 6 | sys.path.append(("../")) 7 | 8 | from diffusion.karras import DiffusionModel 9 | from diffusion.mlps import ScoreNetwork 10 | from data_generator import DataGenerator 11 | from agents.model import DiagGaussianActorTanhAction, Critic 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | import os 15 | import copy 16 | import torch.nn.functional as F 17 | import random 18 | 19 | 20 | def q_v_critic_loss(critic, state, action, next_state, reward, not_done, expectile, discount): 21 | def expectile_loss(diff, expectile=0.8): 22 | weight = torch.where(diff > 0, expectile, (1 - expectile)) 23 | return weight * (diff ** 2) 24 | 25 | with torch.no_grad(): 26 | q = critic.q_min(state, action) 27 | v = critic.v(state) 28 | value_loss = expectile_loss(q - v, expectile).mean() 29 | 30 | current_q1, current_q2 = critic(state, action) 31 | with torch.no_grad(): 32 | next_v = critic.v(next_state) 33 | target_q = (reward + not_done * discount * next_v).detach() 34 | 35 | critic_loss = value_loss + F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q) 36 | return critic_loss 37 | 38 | def get_dmd_loss(diffusion_model, true_score_model, fake_score_model, fake_action_data, state_data): 39 | noise = torch.randn_like(fake_action_data) 40 | with torch.no_grad(): 41 | pred_real_action, _, t_chosen = diffusion_model.diffusion_train_step(model=true_score_model, x=fake_action_data, cond=state_data, 42 | noise=noise, t_chosen=None, return_denoised=True) 43 | 44 | pred_fake_action, _, t_chosen = diffusion_model.diffusion_train_step(model=fake_score_model, x=fake_action_data, cond=state_data, 45 | noise=noise, t_chosen=t_chosen, 46 | return_denoised=True) 47 | weighting_factor = (fake_action_data - 
pred_real_action).abs().mean(axis=1).reshape(-1, 1) 48 | grad = (pred_fake_action - pred_real_action) / weighting_factor 49 | distill_loss = 0.5 * F.mse_loss(fake_action_data, (fake_action_data - grad).detach()) 50 | return distill_loss 51 | 52 | def train_toy_task(args,file_name): 53 | data_manager = DataGenerator(args.env_name) 54 | action, state = data_manager.generate_samples(10000) 55 | device = args.device 56 | 57 | """Data Prepare""" 58 | state = state.to(torch.float32).to(device) 59 | state = state.reshape(-1, 1) 60 | next_state = state.reshape(-1, 1).to(device) 61 | reward = np.linalg.norm(action, axis=1, keepdims=True) 62 | 63 | if args.reward_type == "far": 64 | # farther from (0,0), higher reward 65 | reward = torch.from_numpy(reward).to(torch.float32).to(device) 66 | elif args.reward_type == "near": 67 | # closer to (0,0), higher reward 68 | reward = torch.from_numpy(np.max(reward) - reward).to(torch.float32).to(device) 69 | elif args.reward_type == "hard": # reward type is hard 70 | reward[(action[:, 0] < -7.5) & (action[:, 1] > 7.5)] = 2 * reward[(action[:, 0] < -7.5) & (action[:, 1] > 7.5)] 71 | reward = torch.from_numpy(reward).to(torch.float32).to(device) 72 | elif args.reward_type == "same": 73 | reward = torch.from_numpy(np.zeros_like(reward)).to(torch.float32).to(device) 74 | 75 | action = action.to(torch.float32).to(device) 76 | not_done = torch.ones_like(reward).to(torch.float32) 77 | 78 | """Init bc actor""" 79 | bc_actor = ScoreNetwork( 80 | action_dim=2, 81 | hidden_dim=128, 82 | time_embed_dim=4, 83 | cond_dim=1, 84 | cond_mask_prob=0.0, 85 | num_hidden_layers=4, 86 | output_dim=2, 87 | device=device, 88 | cond_conditional=True 89 | ).to(device) 90 | bc_actor_optimizer = torch.optim.Adam(bc_actor.parameters(), lr=3e-3) 91 | 92 | diffusion = DiffusionModel( 93 | sigma_data=args.sigma_data, 94 | sigma_min=args.sigma_min, 95 | sigma_max=args.sigma_max, 96 | device=device, 97 | ) 98 | 99 | if args.actor == "sac": 100 | distill_actor = DiagGaussianActorTanhAction(state_dim=1, action_dim=2, 101 | max_action=action.abs().max()).to(device) 102 | elif args.actor == "implicit": 103 | distill_actor = ScoreNetwork( 104 | action_dim=2, 105 | hidden_dim=128, 106 | time_embed_dim=4, 107 | cond_dim=1, 108 | cond_mask_prob=0.0, 109 | num_hidden_layers=4, 110 | output_dim=2, 111 | device=device, 112 | cond_conditional=True 113 | ).to(device) 114 | 115 | if args.pretrain_diffusion: 116 | for _ in tqdm(range(args.pretrain_epochs)): 117 | loss = diffusion.diffusion_train_step(bc_actor, action, state) 118 | bc_actor_optimizer.zero_grad() 119 | loss.backward() 120 | bc_actor_optimizer.step() 121 | bc_actor_state_dict = bc_actor.state_dict() 122 | distill_actor.load_state_dict(bc_actor_state_dict) 123 | 124 | else: 125 | raise ValueError("Actor type can only be sac or implicit") 126 | distill_actor_optimizer = torch.optim.Adam(distill_actor.parameters(), lr=3e-3) 127 | 128 | critic = Critic(state_dim=1, action_dim=2).to(device) 129 | critic_target = copy.deepcopy(critic) 130 | critic_optimizer = torch.optim.Adam(critic.parameters(), lr=3e-3) 131 | 132 | if args.distill_loss == "dmd": 133 | distill_score = ScoreNetwork( 134 | action_dim=2, 135 | hidden_dim=128, 136 | time_embed_dim=4, 137 | cond_dim=1, 138 | cond_mask_prob=0.0, 139 | num_hidden_layers=4, 140 | output_dim=2, 141 | device=device, 142 | cond_conditional=True 143 | ).to(device) 144 | distill_score_optimizer = torch.optim.Adam(distill_score.parameters(), lr=3e-3) 145 | 146 | def get_action(given_state, action_dim, 
generation_sigma=2.5): 147 | if args.actor == "sac": 148 | action = distill_actor.sample(state=given_state) 149 | return action 150 | elif args.actor == "implicit": 151 | noise = torch.randn((given_state.shape[0], action_dim)) * generation_sigma 152 | noise = noise.to(given_state.device) 153 | action = distill_actor(noise, given_state, torch.tensor([generation_sigma]).to(given_state.device)) 154 | return action 155 | else: 156 | raise ValueError("Actor not correct.") 157 | 158 | pbar = tqdm(range(args.train_epochs)) 159 | for i in range(args.train_epochs): 160 | """Q policy""" 161 | critic_loss = q_v_critic_loss(critic,state,action, next_state, reward, not_done, args.expectile, args.discount) 162 | critic_optimizer.zero_grad() 163 | critic_loss.backward() 164 | critic_optimizer.step() 165 | 166 | """BC policy""" 167 | loss = diffusion.diffusion_train_step(bc_actor, action, state) 168 | bc_actor_optimizer.zero_grad() 169 | loss.backward() 170 | bc_actor_optimizer.step() 171 | 172 | """Distill policy""" 173 | new_action = get_action(given_state=state, action_dim=2, 174 | generation_sigma=args.generation_sigma) 175 | q_loss = -critic.q_min(state, new_action).mean() 176 | 177 | if args.distill_loss == "diffusion": 178 | distill_loss = diffusion.diffusion_train_step(bc_actor, new_action, state) 179 | elif args.distill_loss == "dmd": 180 | distill_loss = get_dmd_loss(diffusion, bc_actor, distill_score, new_action, state) 181 | else: 182 | distill_loss = 0 183 | 184 | if args.gamma == 0.: 185 | gamma_loss = 0 186 | else: 187 | gamma_loss = -distill_actor.log_prob(state,action).mean() 188 | 189 | actor_loss = distill_loss + args.eta * q_loss + args.gamma * gamma_loss 190 | distill_actor_optimizer.zero_grad() 191 | actor_loss.backward() 192 | distill_actor_optimizer.step() 193 | 194 | """Train fake score""" 195 | if args.distill_loss == "dmd": 196 | fake_loss = diffusion.diffusion_train_step(distill_score, new_action.detach(), state) 197 | distill_score_optimizer.zero_grad() 198 | fake_loss.backward() 199 | distill_score_optimizer.step() 200 | 201 | """Update critic target""" 202 | for param, target_param in zip(critic.parameters(), critic_target.parameters()): 203 | tau = 0.005 204 | target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) 205 | 206 | pbar.set_description(f"Step {i}, BC Loss: {loss:.4f}, Critic Loss: {critic_loss:.4f}, Q loss:{q_loss:.4f}") 207 | pbar.update(1) 208 | 209 | 210 | state = torch.zeros(1000).int().reshape(-1, 1).to(device).to(torch.float32) 211 | with torch.no_grad(): 212 | plt.figure(figsize=(8, 6)) 213 | scatter = plt.scatter(action.cpu().numpy()[:, 0], action.cpu().numpy()[:, 1], c=reward.cpu().numpy()) 214 | plt.colorbar(scatter, label='Reward values') 215 | plt.title(args.env_name) 216 | plt.xlabel('Action Dimension 1') 217 | plt.ylabel('Action Dimension 2') 218 | plt.grid(True) 219 | 220 | new_action = get_action(given_state=state,action_dim=2, 221 | generation_sigma=args.generation_sigma) 222 | new_action = new_action.cpu().numpy() 223 | plt.scatter(new_action[:, 0], new_action[:, 1], c='red',alpha=0.5) 224 | plot_path = file_name + ".png" 225 | plt.savefig(plot_path) 226 | 227 | if __name__ == "__main__": 228 | parser = argparse.ArgumentParser() 229 | ### Experimental Setups ### 230 | parser.add_argument('--device', default=1, type=int) 231 | parser.add_argument('--actor', default="sac", type=str,help="sac or implicit") 232 | parser.add_argument('--distill_loss', default="dmd", type=str, help="diffusion or dmd") 233 | 
parser.add_argument("--env_name", default="25_gaussian", type=str, help="swiss_roll_2D or 25_gaussian") 234 | parser.add_argument("--pretrain_diffusion", action="store_true") 235 | parser.add_argument("--train_epochs", default=3500, type=int) 236 | parser.add_argument("--pretrain_epochs", default=500, type=int) 237 | parser.add_argument("--reward_type", default="hard", type=str, help="far, near, or hard") 238 | parser.add_argument("--seed", default=1, type=int) 239 | parser.add_argument("--discount", default=0.99, type=float) 240 | parser.add_argument("--eta", default=1, type=float) 241 | parser.add_argument("--generation_sigma", default=2.5, type=float) 242 | parser.add_argument("--gamma", default=0.,type=float,help="weight of sac entropy") 243 | parser.add_argument("--expectile", default=0.95, type=float) 244 | 245 | parser.add_argument("--sigma_max", default=80, type=float) 246 | parser.add_argument("--sigma_min", default=0.002, type=float) 247 | parser.add_argument("--sigma_data", default=0.5, type=float) 248 | 249 | args = parser.parse_args() 250 | output_dir = f"results/{args.env_name}" 251 | if not os.path.exists(output_dir): 252 | os.makedirs(output_dir) 253 | file_name = f"actor={args.actor}|distill={args.distill_loss}|seed={args.seed}|reward={args.reward_type}" 254 | file_name = os.path.join(output_dir,file_name) 255 | if args.actor == "implicit": 256 | file_name += f"|pretrain={args.pretrain_diffusion}" 257 | file_name += f"|eta={args.eta}" 258 | file_name += f"|gamma={args.gamma}" 259 | 260 | def set_seed(seed): 261 | torch.manual_seed(seed) 262 | torch.cuda.manual_seed_all(seed) # For multi-GPU setups 263 | np.random.seed(seed) 264 | random.seed(seed) 265 | torch.backends.cudnn.deterministic = True 266 | torch.backends.cudnn.benchmark = False # May impact performance 267 | 268 | set_seed(args.seed) 269 | train_toy_task(args, file_name) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gym 3 | import numpy as np 4 | import os 5 | import torch 6 | from pathlib import Path 7 | 8 | import d4rl 9 | from utils import utils 10 | from utils.data_sampler import Data_Sampler 11 | from utils.logger import logger, setup_logger 12 | from agents.dtql import DTQL as Agent 13 | 14 | """If you are using DQL-KL, you can specific generation_sigma when init agent""" 15 | #from agents.dql_kl import DQL_KL as Agent 16 | import random 17 | 18 | offline_hyperparameters = { 19 | 'halfcheetah-medium-v2': {'lr': 3e-4, 'alpha': 1.0, 'gamma': 0.0, 'lr_decay': False, 'num_epochs': 1000, 'batch_size': 256, 'expectile': 0.7}, 20 | 'halfcheetah-medium-replay-v2': {'lr': 3e-4, 'alpha': 5.0, 'gamma': 0.0, 'lr_decay': False, 'num_epochs': 1000, 'batch_size': 256, 'expectile': 0.7}, 21 | 'halfcheetah-medium-expert-v2': {'lr': 3e-4, 'alpha': 50.0, 'gamma': 0.0, 'lr_decay': False, 'num_epochs': 1000, 'batch_size': 256, 'expectile': 0.7}, 22 | 'hopper-medium-v2': {'lr': 1e-4, 'alpha': 5.0, 'gamma': 0.0, 'lr_decay': True, 'num_epochs': 1000, 'batch_size': 256, 'expectile': 0.7}, 23 | 'hopper-medium-replay-v2': {'lr': 3e-4, 'alpha': 5.0, 'gamma': 0.0, 'lr_decay': False, 'num_epochs': 1000, 'batch_size': 256, 'expectile': 0.7}, 24 | 'hopper-medium-expert-v2': {'lr': 3e-4, 'alpha': 20.0, 'gamma': 0.0, 'lr_decay': False, 'num_epochs': 1000, 'batch_size': 256, 'expectile': 0.7}, 25 | 'walker2d-medium-v2': {'lr': 3e-4, 'alpha': 5.0, 'gamma': 0.0, 'lr_decay': True, 
'num_epochs': 1000, 'batch_size': 256, 'expectile': 0.7}, 26 | 'walker2d-medium-replay-v2': {'lr': 3e-4, 'alpha': 5.0, 'gamma': 0.0, 'lr_decay': True, 'num_epochs': 1000, 'batch_size': 256, 'expectile': 0.7}, 27 | 'walker2d-medium-expert-v2': {'lr': 3e-4, 'alpha': 5.0, 'gamma': 0.0, 'lr_decay': True, 'num_epochs': 1000, 'batch_size': 256, 'expectile': 0.7}, 28 | 'antmaze-umaze-v0': {'lr': 3e-4, 'alpha': 1.0, 'gamma': 1.0, 'lr_decay': False, 'num_epochs': 500, 'batch_size': 2048, 'expectile': 0.9}, 29 | 'antmaze-umaze-diverse-v0': {'lr': 3e-5, 'alpha': 1.0, 'gamma': 1.0, 'lr_decay': True, 'num_epochs': 500, 'batch_size': 2048, 'expectile': 0.9}, 30 | 'antmaze-medium-play-v0': {'lr': 3e-4, 'alpha': 1.0, 'gamma': 1.0, 'lr_decay': False, 'num_epochs': 400, 'batch_size': 2048, 'expectile': 0.9}, 31 | 'antmaze-medium-diverse-v0': {'lr': 3e-4, 'alpha': 1.0, 'gamma': 1.0, 'lr_decay': False, 'num_epochs': 400, 'batch_size': 2048, 'expectile': 0.9}, 32 | 'antmaze-large-play-v0': {'lr': 3e-4, 'alpha': 1.0, 'gamma': 1.0, 'lr_decay': False, 'num_epochs': 350, 'batch_size': 2048, 'expectile': 0.9}, 33 | 'antmaze-large-diverse-v0': {'lr': 3e-4, 'alpha': 0.5, 'gamma': 1.0, 'lr_decay': False, 'num_epochs': 300, 'batch_size': 2048, 'expectile': 0.9}, 34 | 'antmaze-umaze-v2': {'lr': 3e-4, 'alpha': 1.0, 'gamma': 1.0, 'lr_decay': False, 'num_epochs': 500, 'batch_size': 2048, 'expectile': 0.9}, 35 | 'antmaze-umaze-diverse-v2': {'lr': 3e-5, 'alpha': 1.0, 'gamma': 1.0, 'lr_decay': True, 'num_epochs': 500, 'batch_size': 2048, 'expectile': 0.9}, 36 | 'antmaze-medium-play-v2': {'lr': 3e-4, 'alpha': 1.0, 'gamma': 1.0, 'lr_decay': False, 'num_epochs': 400, 'batch_size': 2048, 'expectile': 0.9}, 37 | 'antmaze-medium-diverse-v2': {'lr': 3e-4, 'alpha': 1.0, 'gamma': 1.0, 'lr_decay': False, 'num_epochs': 400, 'batch_size': 2048, 'expectile': 0.9}, 38 | 'antmaze-large-play-v2': {'lr': 3e-4, 'alpha': 1.0, 'gamma': 1.0, 'lr_decay': False, 'num_epochs': 350, 'batch_size': 2048, 'expectile': 0.9}, 39 | 'antmaze-large-diverse-v2': {'lr': 3e-4, 'alpha': 0.5, 'gamma': 1.0, 'lr_decay': False, 'num_epochs': 300, 'batch_size': 2048, 'expectile': 0.9}, 40 | 'pen-human-v1': {'lr': 3e-5, 'alpha': 1500.0, 'gamma': 0.0, 'lr_decay': True, 'num_epochs': 300, 'batch_size': 256, 'expectile': 0.9}, 41 | 'pen-cloned-v1': {'lr': 1e-5, 'alpha': 1500.0, 'gamma': 0.0, 'lr_decay': False, 'num_epochs': 200, 'batch_size': 256, 'expectile': 0.7}, 42 | 'kitchen-complete-v0': {'lr': 1e-4, 'alpha': 200.0, 'gamma': 0.0, 'lr_decay': True, 'num_epochs': 500, 'batch_size': 256, 'expectile': 0.7}, 43 | 'kitchen-partial-v0': {'lr': 1e-4, 'alpha': 100.0, 'gamma': 0.0, 'lr_decay': True, 'num_epochs': 1000, 'batch_size': 256, 'expectile': 0.7}, 44 | 'kitchen-mixed-v0': {'lr': 3e-4, 'alpha': 200.0, 'gamma': 0.0, 'lr_decay': True, 'num_epochs': 500, 'batch_size': 256, 'expectile': 0.7}, 45 | } 46 | 47 | def train_agent(env, state_dim, action_dim, device, output_dir, args): 48 | dataset = d4rl.qlearning_dataset(env) 49 | data_sampler = Data_Sampler(dataset, device, args.reward_tune) 50 | utils.print_banner('Loaded buffer') 51 | 52 | agent = Agent(state_dim=state_dim, 53 | action_dim=action_dim, 54 | action_space=env.action_space, 55 | device=device, 56 | discount=args.discount, 57 | lr=args.lr, 58 | alpha=args.alpha, 59 | lr_decay=args.lr_decay, 60 | lr_maxt=args.num_epochs*args.num_steps_per_epoch, 61 | expectile=args.expectile, 62 | sigma_data=args.sigma_data, 63 | sigma_max=args.sigma_max, 64 | sigma_min=args.sigma_min, 65 | tau=args.tau, 66 | gamma=args.gamma, 
67 | repeats=args.repeats) 68 | if args.pretrain_epochs is not None: 69 | agent.load_or_pretrain_models( 70 | dir=str(Path(output_dir)), 71 | replay_buffer=data_sampler, 72 | batch_size=args.batch_size, 73 | pretrain_steps=args.pretrain_epochs*args.num_steps_per_epoch, 74 | num_steps_per_epoch=args.num_steps_per_epoch) 75 | 76 | training_iters = 0 77 | max_timesteps = args.num_epochs * args.num_steps_per_epoch 78 | log_interval = int(args.eval_freq * args.num_steps_per_epoch) 79 | 80 | utils.print_banner(f"Training Start", separator="*", num_star=90) 81 | while (training_iters < max_timesteps + 1): 82 | curr_epoch = int(training_iters // int(args.num_steps_per_epoch)) 83 | env.reset() 84 | loss_metric = agent.train(replay_buffer=data_sampler, 85 | batch_size=args.batch_size) 86 | training_iters += 1 87 | # Logging 88 | if training_iters % log_interval == 0: 89 | if loss_metric is not None: 90 | utils.print_banner(f"Train step: {training_iters}", separator="*", num_star=90) 91 | logger.record_tabular('Trained Epochs', curr_epoch) 92 | logger.record_tabular('BC Loss', np.mean(loss_metric['bc_loss'])) 93 | logger.record_tabular('QL Loss', np.mean(loss_metric['ql_loss'])) 94 | logger.record_tabular('Distill Loss', np.mean(loss_metric['distill_loss'])) 95 | logger.record_tabular('Actor Loss', np.mean(loss_metric['actor_loss'])) 96 | logger.record_tabular('Critic Loss', np.mean(loss_metric['critic_loss'])) 97 | logger.record_tabular('Gamma Loss', np.mean(loss_metric['gamma_loss'])) 98 | 99 | # Evaluating 100 | eval_res, eval_res_std, eval_norm_res, eval_norm_res_std = eval_policy(agent, 101 | args.env_name, 102 | args.seed, 103 | eval_episodes=args.eval_episodes) 104 | logger.record_tabular('Average Episodic Reward', eval_res) 105 | logger.record_tabular('Average Episodic N-Reward', eval_norm_res) 106 | logger.record_tabular('Average Episodic N-Reward Std', eval_norm_res_std) 107 | logger.dump_tabular() 108 | 109 | if args.save_checkpoints: 110 | agent.save_model(output_dir, curr_epoch) 111 | agent.save_model(output_dir, curr_epoch) 112 | 113 | 114 | # Runs policy for [eval_episodes] episodes and returns average reward 115 | # A fixed seed is used for the eval environment 116 | def eval_policy(policy, env_name, seed, eval_episodes=10): 117 | eval_env = gym.make(env_name) 118 | eval_env.seed(seed + 100) 119 | 120 | scores = [] 121 | for _ in range(eval_episodes): 122 | traj_return = 0. 
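# Note: the rollout below relies on the pre-0.26 gym API, where reset() returns only the observation and step() returns (obs, reward, done, info).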
123 | state, done = eval_env.reset(), False 124 | while not done: 125 | action = policy.sample_action(np.array(state)) 126 | state, reward, done, _ = eval_env.step(action) 127 | traj_return += reward 128 | scores.append(traj_return) 129 | 130 | avg_reward = np.mean(scores) 131 | std_reward = np.std(scores) 132 | 133 | normalized_scores = [eval_env.get_normalized_score(s) for s in scores] 134 | avg_norm_score = eval_env.get_normalized_score(avg_reward) 135 | std_norm_score = np.std(normalized_scores) 136 | 137 | utils.print_banner(f"Evaluation over {eval_episodes} episodes: {avg_reward:.2f} {avg_norm_score:.2f}") 138 | return avg_reward, std_reward, avg_norm_score, std_norm_score 139 | 140 | 141 | if __name__ == "__main__": 142 | parser = argparse.ArgumentParser() 143 | ### Experimental Setups ### 144 | parser.add_argument('--device', default=1, type=int) 145 | parser.add_argument("--env_name", default="antmaze-large-diverse-v0", type=str, help='Mujoco Gym environment') 146 | parser.add_argument("--seed", default=1, type=int, help='random seed (default: 1)') 147 | parser.add_argument("--eval_freq", default=50, type=int) 148 | parser.add_argument("--dir", default="results", type=str) 149 | 150 | parser.add_argument("--pretrain_epochs", default=50, type=int) 151 | parser.add_argument("--repeats", default=1024, type=int) 152 | parser.add_argument("--tau", default=0.005, type=float) 153 | 154 | parser.add_argument("--sigma_max", default=80, type=float) 155 | parser.add_argument("--sigma_min", default=0.002, type=float) 156 | parser.add_argument("--sigma_data", default=0.5, type=float) 157 | parser.add_argument('--save_checkpoints', action='store_true') 158 | 159 | parser.add_argument("--num_steps_per_epoch", default=1000, type=int) 160 | parser.add_argument("--discount", default=0.99, type=float, help='discount factor for reward (default: 0.99)') 161 | 162 | args = parser.parse_args() 163 | args.device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu" 164 | args.output_dir = f'{args.dir}' 165 | 166 | if 'antmaze' in args.env_name: 167 | args.reward_tune = 'iql_antmaze' 168 | args.eval_episodes = 100 169 | else: 170 | args.reward_tune = 'no' 171 | args.eval_episodes = 10 if 'v2' in args.env_name else 100 172 | 173 | args.num_epochs = offline_hyperparameters[args.env_name]["num_epochs"] 174 | args.lr = offline_hyperparameters[args.env_name]["lr"] 175 | args.lr_decay = offline_hyperparameters[args.env_name]["lr_decay"] 176 | args.batch_size = offline_hyperparameters[args.env_name]["batch_size"] 177 | args.alpha = offline_hyperparameters[args.env_name]["alpha"] 178 | args.gamma = offline_hyperparameters[args.env_name]["gamma"] 179 | args.expectile = offline_hyperparameters[args.env_name]["expectile"] 180 | 181 | file_name = f'|expect-{args.expectile}' 182 | file_name += f"|alpha-{args.alpha}|gamma-{args.gamma}" 183 | file_name += f'|seed={args.seed}' 184 | file_name += f'|lr={args.lr}' 185 | if args.lr_decay: 186 | file_name += f'|lr_decay' 187 | if args.pretrain_epochs is not None: 188 | file_name += f'|pretrain={args.pretrain_epochs}' 189 | #file_name += f'|{args.env_name}' 190 | 191 | results_dir = os.path.join(args.output_dir, args.env_name, file_name) 192 | if not os.path.exists(results_dir): 193 | os.makedirs(results_dir) 194 | utils.print_banner(f"Saving location: {results_dir}") 195 | 196 | variant = vars(args) 197 | variant.update(version=f"DTQL") 198 | 199 | env = gym.make(args.env_name) 200 | 201 | env.seed(args.seed) 202 | torch.manual_seed(args.seed) 203 | 
np.random.seed(args.seed) 204 | random.seed(args.seed) 205 | torch.cuda.manual_seed_all(args.seed) 206 | torch.backends.cudnn.deterministic = True 207 | torch.backends.cudnn.benchmark = False 208 | 209 | 210 | state_dim = env.observation_space.shape[0] 211 | action_dim = env.action_space.shape[0] 212 | 213 | variant.update(state_dim=state_dim) 214 | variant.update(action_dim=action_dim) 215 | setup_logger(os.path.basename(results_dir), variant=variant, log_dir=results_dir) 216 | utils.print_banner(f"Env: {args.env_name}, state_dim: {state_dim}, action_dim: {action_dim}") 217 | 218 | train_agent(env, 219 | state_dim, 220 | action_dim, 221 | args.device, 222 | results_dir, 223 | args) 224 | -------------------------------------------------------------------------------- /agents/dql_kl.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from torch.optim.lr_scheduler import CosineAnnealingLR 4 | import torch.nn.functional as F 5 | from pathlib import Path 6 | from diffusion.karras import DiffusionModel 7 | from diffusion.mlps import ScoreNetwork 8 | from agents.model import Critic 9 | from agents.helpers import EMA, get_dmd_loss 10 | import numpy as np 11 | class DQL_KL(object): 12 | def __init__(self, 13 | device, 14 | state_dim, 15 | action_dim, 16 | action_space=None, 17 | discount=0.99, 18 | alpha=1.0, 19 | ema_decay=0.995, 20 | step_start_ema=1000, 21 | update_ema_every=5, 22 | lr=3e-4, 23 | lr_decay=False, 24 | lr_maxt=1000, 25 | sigma_max=80., 26 | sigma_min=0.002, 27 | sigma_data=0.5, 28 | generation_sigma=2.5, 29 | expectile=0.7, 30 | tau = 0.005, 31 | gamma=0, 32 | repeats=1024 33 | ): 34 | """Init critic networks""" 35 | self.critic = Critic(state_dim, action_dim).to(device) 36 | self.critic_target = copy.deepcopy(self.critic) 37 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4) 38 | 39 | """"Init behaviour cloning network""" 40 | self.bc_actor = ScoreNetwork( 41 | action_dim=action_dim, 42 | hidden_dim=256, 43 | time_embed_dim=16, 44 | cond_dim=state_dim, 45 | cond_mask_prob=0.0, 46 | num_hidden_layers=4, 47 | output_dim=action_dim, 48 | device=device, 49 | cond_conditional=True 50 | ).to(device) 51 | self.bc_actor_target = copy.deepcopy(self.bc_actor) 52 | self.bc_actor_optimizer = torch.optim.Adam(self.bc_actor.parameters(), lr=lr) 53 | 54 | """Init diffusion schedule""" 55 | self.diffusion = DiffusionModel( 56 | sigma_data=sigma_data, 57 | sigma_min=sigma_min, 58 | sigma_max=sigma_max, 59 | device=device, 60 | clip_denoised=True, 61 | max_action=float(action_space.high[0])) 62 | 63 | """Init one-step policy""" 64 | self.distill_actor = ScoreNetwork( 65 | action_dim=action_dim, 66 | hidden_dim=256, 67 | time_embed_dim=16, 68 | cond_dim=state_dim, 69 | cond_mask_prob=0.0, 70 | num_hidden_layers=4, 71 | output_dim=action_dim, 72 | device=device, 73 | cond_conditional=True 74 | ).to(device) 75 | self.distill_actor_target = copy.deepcopy(self.distill_actor) 76 | self.distill_actor_optimizer = torch.optim.Adam(self.distill_actor.parameters(), lr=lr) 77 | 78 | """Init fake score network""" 79 | self.fake_score = ScoreNetwork( 80 | action_dim=action_dim, 81 | hidden_dim=256, 82 | time_embed_dim=16, 83 | cond_dim=state_dim, 84 | cond_mask_prob=0.0, 85 | num_hidden_layers=4, 86 | output_dim=action_dim, 87 | device=device, 88 | cond_conditional=True 89 | ).to(device) 90 | self.fake_score_optimizer = torch.optim.Adam(self.fake_score.parameters(), lr=lr) 91 | 92 | """Back up training 
parameters""" 93 | self.generation_sigma = generation_sigma 94 | self.tau = tau 95 | self.lr_decay = lr_decay 96 | self.gamma = gamma 97 | self.repeats = repeats 98 | 99 | self.step = 0 100 | self.step_start_ema = step_start_ema 101 | self.ema = EMA(ema_decay) 102 | self.update_ema_every = update_ema_every 103 | 104 | if lr_decay: 105 | self.critic_lr_scheduler = CosineAnnealingLR(self.critic_optimizer, T_max=lr_maxt, eta_min=0.) 106 | self.bc_actor_lr_scheduler = CosineAnnealingLR(self.bc_actor_optimizer, T_max=lr_maxt, eta_min=0.) 107 | self.distill_actor_lr_scheduler = CosineAnnealingLR(self.distill_actor_optimizer, T_max=lr_maxt, eta_min=0.) 108 | self.fake_score_lr_scheduler = CosineAnnealingLR(self.fake_score_optimizer, T_max=lr_maxt, eta_min=0.) 109 | 110 | self.state_dim = state_dim 111 | self.action_dim = action_dim 112 | self.discount = discount 113 | self.alpha = alpha # bc weight 114 | self.expectile = expectile 115 | self.device = device 116 | 117 | def step_ema(self): 118 | if self.step < self.step_start_ema: 119 | return 120 | self.ema.update_model_average(self.distill_actor_target, self.distill_actor) 121 | 122 | def pretrain(self,replay_buffer, batch_size=256,pretrain_steps=50000): 123 | for _ in range(pretrain_steps): 124 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 125 | loss = self.diffusion.diffusion_train_step(self.bc_actor, action, state) 126 | self.bc_actor_optimizer.zero_grad() 127 | loss.backward() 128 | self.bc_actor_optimizer.step() 129 | self.bc_loss = loss 130 | 131 | critic_loss = self.q_v_critic_loss(state,action, next_state, reward, not_done) 132 | self.critic_optimizer.zero_grad() 133 | critic_loss.backward() 134 | self.critic_optimizer.step() 135 | 136 | bc_actor_state_dict = self.bc_actor.state_dict() 137 | self.distill_actor.load_state_dict(bc_actor_state_dict) 138 | self.fake_score.load_state_dict(bc_actor_state_dict) 139 | def load_or_pretrain_models(self, dir, replay_buffer, batch_size, pretrain_steps,num_steps_per_epoch): 140 | # Paths for the models 141 | actor_path = Path(dir) / f'diffusion_pretrained_{pretrain_steps // num_steps_per_epoch}.pth' 142 | critic_path = Path(dir) / f'critic_pretrained_{pretrain_steps // num_steps_per_epoch}.pth' 143 | 144 | # Check if both models exist 145 | if actor_path.exists() and critic_path.exists(): 146 | try: 147 | # Load the models 148 | self.bc_actor.load_state_dict(torch.load(actor_path, map_location=self.device)) 149 | self.critic.load_state_dict(torch.load(critic_path, map_location=self.device)) 150 | bc_actor_state_dict = self.bc_actor.state_dict() 151 | self.distill_actor.load_state_dict(bc_actor_state_dict) 152 | self.fake_score.load_state_dict(torch.load(actor_path)) 153 | except Exception as e: 154 | print(f"Failed to load models: {e}") 155 | else: 156 | # Begin pretraining if the models do not exist 157 | print("Models not found, starting pretraining...") 158 | self.pretrain(replay_buffer, batch_size, pretrain_steps) 159 | torch.save(self.bc_actor.state_dict(), actor_path) 160 | torch.save(self.critic.state_dict(), critic_path) 161 | print(f"Saved successfully to {dir}") 162 | 163 | def train(self, replay_buffer, batch_size=256): 164 | # initialize 165 | self.bc_loss = torch.tensor([0.]).to(self.device) 166 | self.critic_loss = torch.tensor([0.]).to(self.device) 167 | metric = {'bc_loss': [], 'distill_loss':[], 'ql_loss': [], 'actor_loss': [], 'critic_loss': [], 'gamma_loss': []} 168 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 
169 | 170 | """ Q Training """ 171 | critic_loss = self.q_v_critic_loss(state, action, next_state, reward, not_done) 172 | 173 | self.critic_loss = critic_loss 174 | self.critic_optimizer.zero_grad() 175 | critic_loss.backward() 176 | self.critic_optimizer.step() 177 | 178 | """ Diffusion Policy Training """ 179 | bc_loss = self.diffusion.diffusion_train_step(self.bc_actor, action, state) 180 | self.bc_actor_optimizer.zero_grad() 181 | bc_loss.backward() 182 | self.bc_actor_optimizer.step() 183 | self.bc_loss = bc_loss 184 | 185 | """Distill Policy Training""" 186 | new_action = self.get_agent_sample(self.distill_actor, given_state=state, 187 | generation_sigma=self.generation_sigma) 188 | distill_loss = get_dmd_loss(self.diffusion,self.bc_actor,self.fake_score,new_action,state) 189 | q_loss = -self.critic.q_min(state, new_action).mean() 190 | 191 | actor_loss = self.alpha * distill_loss + q_loss 192 | self.distill_actor_optimizer.zero_grad() 193 | actor_loss.backward() 194 | self.distill_actor_optimizer.step() 195 | 196 | """Training fake score""" 197 | fake_score_loss = self.diffusion.diffusion_train_step(self.fake_score, new_action.detach(), state) 198 | self.fake_score_optimizer.zero_grad() 199 | fake_score_loss.backward() 200 | self.fake_score_optimizer.step() 201 | 202 | """ Step Target network """ 203 | if self.step % self.update_ema_every == 0: 204 | self.step_ema() 205 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 206 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 207 | 208 | self.step += 1 209 | 210 | metric['actor_loss'].append(actor_loss.item()) 211 | metric['bc_loss'].append(self.bc_loss.item()) 212 | metric['ql_loss'].append(q_loss.item()) 213 | metric['critic_loss'].append(self.critic_loss.item()) 214 | metric['distill_loss'].append(distill_loss.item()) 215 | metric['gamma_loss'].append(np.nan) 216 | 217 | if self.lr_decay: 218 | self.bc_actor_lr_scheduler.step() 219 | self.distill_actor_lr_scheduler.step() 220 | self.critic_lr_scheduler.step() 221 | self.fake_score_lr_scheduler.step() 222 | return metric 223 | 224 | def sample_action(self, state): 225 | state = torch.FloatTensor(state.reshape(1, -1)).to(self.device) 226 | state_rpt = torch.repeat_interleave(state, repeats=self.repeats, dim=0) 227 | with torch.no_grad(): 228 | action = self.get_agent_sample(self.distill_actor, given_state=state_rpt, 229 | generation_sigma=self.generation_sigma 230 | ) 231 | q_value = self.critic_target.q_min(state_rpt, action).flatten() 232 | idx = torch.multinomial(F.softmax(q_value), 1) 233 | action = action[idx].cpu().data.numpy().flatten() 234 | return action 235 | 236 | def save_model(self, dir, id=None): 237 | if id is not None: 238 | torch.save(self.bc_actor.state_dict(), f'{dir}/bc_actor_{id}.pth') 239 | torch.save(self.critic.state_dict(), f'{dir}/critic_{id}.pth') 240 | torch.save(self.distill_actor.state_dict(), f'{dir}/distill_actor_{id}.pth') 241 | torch.save(self.fake_score.state_dict(), f'{dir}/fake_score_{id}.pth') 242 | else: 243 | torch.save(self.bc_actor.state_dict(), f'{dir}/actor.pth') 244 | torch.save(self.critic.state_dict(), f'{dir}/critic.pth') 245 | torch.save(self.distill_actor.state_dict(), f'{dir}/distill_actor.pth') 246 | torch.save(self.fake_score.state_dict(), f'{dir}/fake_score.pth') 247 | 248 | def load_model(self, dir, id=None): 249 | if id is not None: 250 | self.bc_actor.load_state_dict(torch.load(f'{dir}/bc_actor_{id}.pth')) 251 | 
self.critic.load_state_dict(torch.load(f'{dir}/critic_{id}.pth')) 252 | self.distill_actor.load_state_dict(torch.load(f'{dir}/distill_actor_{id}.pth')) 253 | self.fake_score.load_state_dict(torch.load(f'{dir}/fake_score_{id}.pth')) 254 | else: 255 | self.bc_actor.load_state_dict(torch.load(f'{dir}/bc_actor.pth')) 256 | self.critic.load_state_dict(torch.load(f'{dir}/critic.pth')) 257 | self.distill_actor.load_state_dict(torch.load(f'{dir}/distill_actor.pth')) 258 | self.fake_score.load_state_dict(torch.load(f'{dir}/fake_score.pth')) 259 | print(f"Models loaded successfully from {dir}") 260 | 261 | 262 | def get_agent_sample(self,model,given_state,generation_sigma): 263 | noise = torch.randn((given_state.shape[0],self.action_dim)).to(given_state.device) * generation_sigma 264 | action = model(noise, given_state, torch.tensor([generation_sigma]).to(given_state.device)) 265 | return action 266 | 267 | def q_v_critic_loss(self,state,action, next_state, reward, not_done): 268 | def expectile_loss(diff, expectile=0.8): 269 | weight = torch.where(diff > 0, expectile, (1 - expectile)) 270 | return weight * (diff ** 2) 271 | 272 | with torch.no_grad(): 273 | q = self.critic.q_min(state, action) 274 | v = self.critic.v(state) 275 | value_loss = expectile_loss(q - v, self.expectile).mean() 276 | 277 | current_q1, current_q2 = self.critic(state, action) 278 | with torch.no_grad(): 279 | next_v = self.critic.v(next_state) 280 | target_q = (reward + not_done * self.discount * next_v).detach() 281 | 282 | critic_loss = value_loss + F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q) 283 | return critic_loss 284 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from contextlib import contextmanager 3 | import numpy as np 4 | import os.path as osp 5 | import sys 6 | import datetime 7 | import dateutil.tz 8 | import csv 9 | import json 10 | import pickle 11 | import errno 12 | from collections import OrderedDict 13 | from numbers import Number 14 | import os 15 | 16 | from tabulate import tabulate 17 | 18 | def dict_to_safe_json(d): 19 | """ 20 | Convert each value in the dictionary into a JSON'able primitive. 21 | :param d: 22 | :return: 23 | """ 24 | new_d = {} 25 | for key, item in d.items(): 26 | if safe_json(item): 27 | new_d[key] = item 28 | else: 29 | if isinstance(item, dict): 30 | new_d[key] = dict_to_safe_json(item) 31 | else: 32 | new_d[key] = str(item) 33 | return new_d 34 | 35 | 36 | def safe_json(data): 37 | if data is None: 38 | return True 39 | elif isinstance(data, (bool, int, float)): 40 | return True 41 | elif isinstance(data, (tuple, list)): 42 | return all(safe_json(x) for x in data) 43 | elif isinstance(data, dict): 44 | return all(isinstance(k, str) and safe_json(v) for k, v in data.items()) 45 | return False 46 | 47 | def create_exp_name(exp_prefix, exp_id=0, seed=0): 48 | """ 49 | Create a semi-unique experiment name that has a timestamp 50 | :param exp_prefix: 51 | :param exp_id: 52 | :return: 53 | """ 54 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 55 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 56 | return "%s_%s_%04d--s-%d" % (exp_prefix, timestamp, exp_id, seed) 57 | 58 | def create_log_dir( 59 | exp_prefix, 60 | exp_id=0, 61 | seed=0, 62 | base_log_dir=None, 63 | include_exp_prefix_sub_dir=True, 64 | ): 65 | """ 66 | Creates and returns a unique log directory. 
67 | :param exp_prefix: All experiments with this prefix will have log 68 | directories be under this directory. 69 | :param exp_id: The number of the specific experiment run within this 70 | experiment. 71 | :param base_log_dir: The directory where all log should be saved. 72 | :return: 73 | """ 74 | exp_name = create_exp_name(exp_prefix, exp_id=exp_id, 75 | seed=seed) 76 | if base_log_dir is None: 77 | base_log_dir = './data' 78 | if include_exp_prefix_sub_dir: 79 | log_dir = osp.join(base_log_dir, exp_prefix.replace("_", "-"), exp_name) 80 | else: 81 | log_dir = osp.join(base_log_dir, exp_name) 82 | if osp.exists(log_dir): 83 | print("WARNING: Log directory already exists {}".format(log_dir), flush=True) 84 | os.makedirs(log_dir, exist_ok=True) 85 | return log_dir 86 | 87 | 88 | def setup_logger( 89 | exp_prefix="default", 90 | variant=None, 91 | text_log_file="debug.log", 92 | variant_log_file="variant.json", 93 | tabular_log_file="progress.csv", 94 | snapshot_mode="last", 95 | snapshot_gap=1, 96 | log_tabular_only=False, 97 | log_dir=None, 98 | git_infos=None, 99 | script_name=None, 100 | **create_log_dir_kwargs 101 | ): 102 | """ 103 | Set up logger to have some reasonable default settings. 104 | Will save log output to 105 | based_log_dir/exp_prefix/exp_name. 106 | exp_name will be auto-generated to be unique. 107 | If log_dir is specified, then that directory is used as the output dir. 108 | :param exp_prefix: The sub-directory for this specific experiment. 109 | :param variant: 110 | :param text_log_file: 111 | :param variant_log_file: 112 | :param tabular_log_file: 113 | :param snapshot_mode: 114 | :param log_tabular_only: 115 | :param snapshot_gap: 116 | :param log_dir: 117 | :param git_infos: 118 | :param script_name: If set, save the script name to this. 
119 | :return: 120 | """ 121 | first_time = log_dir is None 122 | if first_time: 123 | log_dir = create_log_dir(exp_prefix, **create_log_dir_kwargs) 124 | 125 | if variant is not None: 126 | logger.log("Variant:") 127 | logger.log(json.dumps(dict_to_safe_json(variant), indent=2)) 128 | variant_log_path = osp.join(log_dir, variant_log_file) 129 | logger.log_variant(variant_log_path, variant) 130 | 131 | tabular_log_path = osp.join(log_dir, tabular_log_file) 132 | text_log_path = osp.join(log_dir, text_log_file) 133 | 134 | logger.add_text_output(text_log_path) 135 | if first_time: 136 | logger.add_tabular_output(tabular_log_path) 137 | else: 138 | logger._add_output(tabular_log_path, logger._tabular_outputs, 139 | logger._tabular_fds, mode='a') 140 | for tabular_fd in logger._tabular_fds: 141 | logger._tabular_header_written.add(tabular_fd) 142 | logger.set_snapshot_dir(log_dir) 143 | logger.set_snapshot_mode(snapshot_mode) 144 | logger.set_snapshot_gap(snapshot_gap) 145 | logger.set_log_tabular_only(log_tabular_only) 146 | exp_name = log_dir.split("/")[-1] 147 | logger.push_prefix("[%s] " % exp_name) 148 | 149 | if script_name is not None: 150 | with open(osp.join(log_dir, "script_name.txt"), "w") as f: 151 | f.write(script_name) 152 | return log_dir 153 | 154 | 155 | def create_stats_ordered_dict( 156 | name, 157 | data, 158 | stat_prefix=None, 159 | always_show_all_stats=True, 160 | exclude_max_min=False, 161 | ): 162 | if stat_prefix is not None: 163 | name = "{}{}".format(stat_prefix, name) 164 | if isinstance(data, Number): 165 | return OrderedDict({name: data}) 166 | 167 | if len(data) == 0: 168 | return OrderedDict() 169 | 170 | if isinstance(data, tuple): 171 | ordered_dict = OrderedDict() 172 | for number, d in enumerate(data): 173 | sub_dict = create_stats_ordered_dict( 174 | "{0}_{1}".format(name, number), 175 | d, 176 | ) 177 | ordered_dict.update(sub_dict) 178 | return ordered_dict 179 | 180 | if isinstance(data, list): 181 | try: 182 | iter(data[0]) 183 | except TypeError: 184 | pass 185 | else: 186 | data = np.concatenate(data) 187 | 188 | if (isinstance(data, np.ndarray) and data.size == 1 189 | and not always_show_all_stats): 190 | return OrderedDict({name: float(data)}) 191 | 192 | stats = OrderedDict([ 193 | (name + ' Mean', np.mean(data)), 194 | (name + ' Std', np.std(data)), 195 | ]) 196 | if not exclude_max_min: 197 | stats[name + ' Max'] = np.max(data) 198 | stats[name + ' Min'] = np.min(data) 199 | return stats 200 | 201 | 202 | class TerminalTablePrinter(object): 203 | def __init__(self): 204 | self.headers = None 205 | self.tabulars = [] 206 | 207 | def print_tabular(self, new_tabular): 208 | if self.headers is None: 209 | self.headers = [x[0] for x in new_tabular] 210 | else: 211 | assert len(self.headers) == len(new_tabular) 212 | self.tabulars.append([x[1] for x in new_tabular]) 213 | self.refresh() 214 | 215 | def refresh(self): 216 | import os 217 | rows, columns = os.popen('stty size', 'r').read().split() 218 | tabulars = self.tabulars[-(int(rows) - 3):] 219 | sys.stdout.write("\x1b[2J\x1b[H") 220 | sys.stdout.write(tabulate(tabulars, self.headers)) 221 | sys.stdout.write("\n") 222 | 223 | 224 | class MyEncoder(json.JSONEncoder): 225 | def default(self, o): 226 | if isinstance(o, type): 227 | return {'$class': o.__module__ + "." + o.__name__} 228 | elif isinstance(o, Enum): 229 | return { 230 | '$enum': o.__module__ + "." + o.__class__.__name__ + '.' + o.name 231 | } 232 | elif callable(o): 233 | return { 234 | '$function': o.__module__ + "." 
+ o.__name__ 235 | } 236 | return json.JSONEncoder.default(self, o) 237 | 238 | 239 | def mkdir_p(path): 240 | try: 241 | os.makedirs(path) 242 | except OSError as exc: # Python >2.5 243 | if exc.errno == errno.EEXIST and os.path.isdir(path): 244 | pass 245 | else: 246 | raise 247 | 248 | 249 | class Logger(object): 250 | def __init__(self): 251 | self._prefixes = [] 252 | self._prefix_str = '' 253 | 254 | self._tabular_prefixes = [] 255 | self._tabular_prefix_str = '' 256 | 257 | self._tabular = [] 258 | 259 | self._text_outputs = [] 260 | self._tabular_outputs = [] 261 | 262 | self._text_fds = {} 263 | self._tabular_fds = {} 264 | self._tabular_header_written = set() 265 | 266 | self._snapshot_dir = None 267 | self._snapshot_mode = 'all' 268 | self._snapshot_gap = 1 269 | 270 | self._log_tabular_only = False 271 | self._header_printed = False 272 | self.table_printer = TerminalTablePrinter() 273 | 274 | def reset(self): 275 | self.__init__() 276 | 277 | def _add_output(self, file_name, arr, fds, mode='a'): 278 | if file_name not in arr: 279 | mkdir_p(os.path.dirname(file_name)) 280 | arr.append(file_name) 281 | fds[file_name] = open(file_name, mode) 282 | 283 | def _remove_output(self, file_name, arr, fds): 284 | if file_name in arr: 285 | fds[file_name].close() 286 | del fds[file_name] 287 | arr.remove(file_name) 288 | 289 | def push_prefix(self, prefix): 290 | self._prefixes.append(prefix) 291 | self._prefix_str = ''.join(self._prefixes) 292 | 293 | def add_text_output(self, file_name): 294 | self._add_output(file_name, self._text_outputs, self._text_fds, 295 | mode='a') 296 | 297 | def remove_text_output(self, file_name): 298 | self._remove_output(file_name, self._text_outputs, self._text_fds) 299 | 300 | def add_tabular_output(self, file_name, relative_to_snapshot_dir=False): 301 | if relative_to_snapshot_dir: 302 | file_name = osp.join(self._snapshot_dir, file_name) 303 | self._add_output(file_name, self._tabular_outputs, self._tabular_fds, 304 | mode='w') 305 | 306 | def remove_tabular_output(self, file_name, relative_to_snapshot_dir=False): 307 | if relative_to_snapshot_dir: 308 | file_name = osp.join(self._snapshot_dir, file_name) 309 | if self._tabular_fds[file_name] in self._tabular_header_written: 310 | self._tabular_header_written.remove(self._tabular_fds[file_name]) 311 | self._remove_output(file_name, self._tabular_outputs, self._tabular_fds) 312 | 313 | def set_snapshot_dir(self, dir_name): 314 | self._snapshot_dir = dir_name 315 | 316 | def get_snapshot_dir(self, ): 317 | return self._snapshot_dir 318 | 319 | def get_snapshot_mode(self, ): 320 | return self._snapshot_mode 321 | 322 | def set_snapshot_mode(self, mode): 323 | self._snapshot_mode = mode 324 | 325 | def get_snapshot_gap(self, ): 326 | return self._snapshot_gap 327 | 328 | def set_snapshot_gap(self, gap): 329 | self._snapshot_gap = gap 330 | 331 | def set_log_tabular_only(self, log_tabular_only): 332 | self._log_tabular_only = log_tabular_only 333 | 334 | def get_log_tabular_only(self, ): 335 | return self._log_tabular_only 336 | 337 | def log(self, s, with_prefix=True, with_timestamp=True): 338 | out = s 339 | if with_prefix: 340 | out = self._prefix_str + out 341 | if with_timestamp: 342 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 343 | timestamp = now.strftime('%y-%m-%d.%H:%M') # :%S 344 | out = "%s|%s" % (timestamp, out) 345 | if not self._log_tabular_only: 346 | # Also log to stdout 347 | print(out, flush=True) 348 | for fd in list(self._text_fds.values()): 349 | fd.write(out + '\n') 350 | 
fd.flush() 351 | sys.stdout.flush() 352 | 353 | def record_tabular(self, key, val): 354 | self._tabular.append((self._tabular_prefix_str + str(key), str(val))) 355 | 356 | def record_dict(self, d, prefix=None): 357 | if prefix is not None: 358 | self.push_tabular_prefix(prefix) 359 | for k, v in d.items(): 360 | self.record_tabular(k, v) 361 | if prefix is not None: 362 | self.pop_tabular_prefix() 363 | 364 | def push_tabular_prefix(self, key): 365 | self._tabular_prefixes.append(key) 366 | self._tabular_prefix_str = ''.join(self._tabular_prefixes) 367 | 368 | def pop_tabular_prefix(self, ): 369 | del self._tabular_prefixes[-1] 370 | self._tabular_prefix_str = ''.join(self._tabular_prefixes) 371 | 372 | def save_extra_data(self, data, file_name='extra_data.pkl', mode='joblib'): 373 | """ 374 | Data saved here will always override the last entry 375 | 376 | :param data: Something pickle'able. 377 | """ 378 | file_name = osp.join(self._snapshot_dir, file_name) 379 | if mode == 'joblib': 380 | import joblib 381 | joblib.dump(data, file_name, compress=3) 382 | elif mode == 'pickle': 383 | pickle.dump(data, open(file_name, "wb")) 384 | else: 385 | raise ValueError("Invalid mode: {}".format(mode)) 386 | return file_name 387 | 388 | def get_table_dict(self, ): 389 | return dict(self._tabular) 390 | 391 | def get_table_key_set(self, ): 392 | return set(key for key, value in self._tabular) 393 | 394 | @contextmanager 395 | def prefix(self, key): 396 | self.push_prefix(key) 397 | try: 398 | yield 399 | finally: 400 | self.pop_prefix() 401 | 402 | @contextmanager 403 | def tabular_prefix(self, key): 404 | self.push_tabular_prefix(key) 405 | yield 406 | self.pop_tabular_prefix() 407 | 408 | def log_variant(self, log_file, variant_data): 409 | mkdir_p(os.path.dirname(log_file)) 410 | with open(log_file, "w") as f: 411 | json.dump(variant_data, f, indent=2, sort_keys=True, cls=MyEncoder) 412 | 413 | def record_tabular_misc_stat(self, key, values, placement='back'): 414 | if placement == 'front': 415 | prefix = "" 416 | suffix = key 417 | else: 418 | prefix = key 419 | suffix = "" 420 | if len(values) > 0: 421 | self.record_tabular(prefix + "Average" + suffix, np.average(values)) 422 | self.record_tabular(prefix + "Std" + suffix, np.std(values)) 423 | self.record_tabular(prefix + "Median" + suffix, np.median(values)) 424 | self.record_tabular(prefix + "Min" + suffix, np.min(values)) 425 | self.record_tabular(prefix + "Max" + suffix, np.max(values)) 426 | else: 427 | self.record_tabular(prefix + "Average" + suffix, np.nan) 428 | self.record_tabular(prefix + "Std" + suffix, np.nan) 429 | self.record_tabular(prefix + "Median" + suffix, np.nan) 430 | self.record_tabular(prefix + "Min" + suffix, np.nan) 431 | self.record_tabular(prefix + "Max" + suffix, np.nan) 432 | 433 | def dump_tabular(self, *args, **kwargs): 434 | wh = kwargs.pop("write_header", None) 435 | if len(self._tabular) > 0: 436 | if self._log_tabular_only: 437 | self.table_printer.print_tabular(self._tabular) 438 | else: 439 | for line in tabulate(self._tabular).split('\n'): 440 | self.log(line, *args, **kwargs) 441 | tabular_dict = dict(self._tabular) 442 | # Also write to the csv files 443 | # This assumes that the keys in each iteration won't change! 
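# Write the same row to every registered CSV output; the header is emitted once per file
# descriptor (tracked in _tabular_header_written) unless write_header overrides that behaviour.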
444 | for tabular_fd in list(self._tabular_fds.values()): 445 | writer = csv.DictWriter(tabular_fd, 446 | fieldnames=list(tabular_dict.keys())) 447 | if wh or ( 448 | wh is None and tabular_fd not in self._tabular_header_written): 449 | writer.writeheader() 450 | self._tabular_header_written.add(tabular_fd) 451 | writer.writerow(tabular_dict) 452 | tabular_fd.flush() 453 | del self._tabular[:] 454 | 455 | def pop_prefix(self, ): 456 | del self._prefixes[-1] 457 | self._prefix_str = ''.join(self._prefixes) 458 | 459 | def save_itr_params(self, itr, params): 460 | if self._snapshot_dir: 461 | if self._snapshot_mode == 'all': 462 | file_name = osp.join(self._snapshot_dir, 'itr_%d.pkl' % itr) 463 | pickle.dump(params, open(file_name, "wb")) 464 | elif self._snapshot_mode == 'last': 465 | # override previous params 466 | file_name = osp.join(self._snapshot_dir, 'params.pkl') 467 | pickle.dump(params, open(file_name, "wb")) 468 | elif self._snapshot_mode == "gap": 469 | if itr % self._snapshot_gap == 0: 470 | file_name = osp.join(self._snapshot_dir, 'itr_%d.pkl' % itr) 471 | pickle.dump(params, open(file_name, "wb")) 472 | elif self._snapshot_mode == "gap_and_last": 473 | if itr % self._snapshot_gap == 0: 474 | file_name = osp.join(self._snapshot_dir, 'itr_%d.pkl' % itr) 475 | pickle.dump(params, open(file_name, "wb")) 476 | file_name = osp.join(self._snapshot_dir, 'params.pkl') 477 | pickle.dump(params, open(file_name, "wb")) 478 | elif self._snapshot_mode == 'none': 479 | pass 480 | else: 481 | raise NotImplementedError 482 | 483 | 484 | logger = Logger() 485 | 486 | --------------------------------------------------------------------------------
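A note on the one-step sampling used throughout: both the toy script's get_action (implicit actor) and DQL_KL.get_agent_sample query the distilled network exactly once, feeding pure Gaussian noise scaled by generation_sigma as the "noisy action" input, so no iterative denoising loop is needed at inference time. A minimal sketch, assuming a trained distill_actor with the ScoreNetwork call signature used above; the variable names below are illustrative placeholders, not part of the repository:

    import torch
    # assumed to exist: distill_actor (ScoreNetwork), state (batch of conditioning states), action_dim
    generation_sigma = 2.5  # default used by both toy_main.py and dql_kl.py
    noise = torch.randn(state.shape[0], action_dim, device=state.device) * generation_sigma
    sigma = torch.tensor([generation_sigma], device=state.device)
    action = distill_actor(noise, state, sigma)  # single forward pass produces the action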