├── agents
│   ├── __init__.py
│   ├── helpers.py
│   ├── model.py
│   ├── dtql.py
│   └── dql_kl.py
├── utils
│   ├── __init__.py
│   ├── data_sampler.py
│   ├── utils.py
│   └── logger.py
├── diffusion
│   ├── __init__.py
│   ├── utils.py
│   ├── mlps.py
│   └── karras.py
├── assets
│   ├── DTQL_toy.pdf
│   └── DTQL_toy.png
├── install_env.sh
├── requirements.txt
├── toy_tasks
│   ├── run_toy.sh
│   ├── data_generator.py
│   └── toy_main.py
├── README.md
└── main.py
/agents/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/DTQL_toy.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TianyuCodings/Diffusion_Trusted_Q_Learning/HEAD/assets/DTQL_toy.pdf
--------------------------------------------------------------------------------
/assets/DTQL_toy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TianyuCodings/Diffusion_Trusted_Q_Learning/HEAD/assets/DTQL_toy.png
--------------------------------------------------------------------------------
/install_env.sh:
--------------------------------------------------------------------------------
1 | conda create -n rl python=3.9
2 | conda activate rl
3 |
4 | # make sure the CUDA version fits the local machine
5 | pip install torch==1.13.0+cu116 torchvision==0.14.0+cu116 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu116
6 | pip install mujoco
7 | pip install git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl
8 | pip install "cython<3"
9 | pip install python-dateutil
10 |
11 |
12 | wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz
13 |
14 | # run this from the home directory; it extracts mujoco210 into /home/username/.mujoco/mujoco210
15 | mkdir -p .mujoco/
16 | tar -xvf mujoco210-linux-x86_64.tar.gz -C .mujoco/
17 |
18 | # add this to .bashrc
19 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin
20 | source ~/.bashrc
--------------------------------------------------------------------------------
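As a quick sanity check that MuJoCo and D4RL were installed correctly, the following minimal sketch (assuming the `rl` environment created above is active) loads one of the offline datasets used later:

```python
# Minimal install check; assumes the conda environment created above is active.
import gym
import d4rl  # importing d4rl registers the offline environments with gym

env = gym.make('halfcheetah-medium-v2')
dataset = d4rl.qlearning_dataset(env)  # dict of observations/actions/next_observations/rewards/terminals
print({k: v.shape for k, v in dataset.items()})
```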
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.1.0
2 | certifi==2022.5.18.1
3 | cffi==1.15.0
4 | charset-normalizer==2.0.12
5 | click==8.1.3
6 | cloudpickle==2.1.0
7 | cycler==0.11.0
8 | Cython==0.29.30
9 | d4rl @ git+https://github.com/rail-berkeley/d4rl@d842aa194b416e564e54b0730d9f934e3e32f854
10 | dm-control @ git+https://github.com/deepmind/dm_control@41d0c7383153f9ca6c12f8e865ef5e73a98759bd
11 | dm-env==1.5
12 | dm-tree==0.1.7
13 | fasteners==0.17.3
14 | fonttools==4.33.3
15 | glfw==2.5.3
16 | gym==0.24.1
17 | gym-notices==0.0.7
18 | h5py==3.7.0
19 | idna==3.3
20 | imageio==2.19.3
21 | importlib-metadata==4.11.4
22 | kiwisolver==1.4.2
23 | labmaze==1.0.5
24 | lxml==4.9.0
25 | matplotlib==3.5.2
26 | mjrl @ git+https://github.com/aravindr93/mjrl@3871d93763d3b49c4741e6daeaebbc605fe140dc
27 | mujoco==2.2.0
28 | mujoco-py==2.0.2.13
29 | numpy==1.22.4
30 | packaging==21.3
31 | pandas==1.4.2
32 | Pillow==9.1.1
33 | protobuf==4.21.1
34 | pybullet==3.2.5
35 | pycparser==2.21
36 | PyOpenGL==3.1.6
37 | pyparsing==2.4.7
38 | python-dateutil==2.8.2
39 | pytz==2022.1
40 | requests==2.28.0
41 | scipy==1.8.1
42 | six==1.16.0
43 | termcolor==1.1.0
44 | torch==1.10.1+cu111
45 | torchaudio==0.10.1+rocm4.1
46 | torchvision==0.11.2+cu111
47 | tqdm==4.64.0
48 | typing_extensions==4.2.0
49 | urllib3==1.26.9
50 | zipp==3.8.0
--------------------------------------------------------------------------------
/toy_tasks/run_toy.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | source activate torch-rl
3 | python toy_main.py --device=5 --env_name=25_gaussian --distill_loss=diffusion --reward_type=near --actor=sac &
4 | python toy_main.py --device=1 --env_name=25_gaussian --distill_loss=diffusion --reward_type=far --actor=sac &
5 | python toy_main.py --device=2 --env_name=25_gaussian --distill_loss=diffusion --reward_type=hard --actor=sac --seed=2&
6 |
7 |
8 | python toy_main.py --device=2 --env_name=25_gaussian --distill_loss=diffusion --reward_type=near --actor=sac --gamma=0.005&
9 | python toy_main.py --device=3 --env_name=25_gaussian --distill_loss=diffusion --reward_type=far --actor=sac --gamma=0.005&
10 | python toy_main.py --device=6 --env_name=25_gaussian --distill_loss=diffusion --reward_type=hard --actor=sac --gamma=0.005 --seed=2&
11 |
12 | python toy_main.py --device=3 --env_name=25_gaussian --distill_loss=dmd --reward_type=near --actor=implicit &
13 | python toy_main.py --device=7 --env_name=25_gaussian --distill_loss=dmd --reward_type=far --actor=implicit --train_epochs=2000&
14 | python toy_main.py --device=1 --env_name=25_gaussian --distill_loss=dmd --reward_type=hard --actor=implicit --train_epochs=2500&
15 |
16 | python toy_main.py --device=2 --env_name=swiss_roll_2D --distill_loss=diffusion --reward_type=near --actor=sac --seed=2 &
17 | python toy_main.py --device=3 --env_name=swiss_roll_2D --distill_loss=diffusion --reward_type=far --actor=sac --eta=5 --seed=2 &
18 |
19 | python toy_main.py --device=4 --env_name=swiss_roll_2D --distill_loss=dmd --reward_type=near --actor=implicit --train_epochs=1500&
20 | python toy_main.py --device=1 --env_name=swiss_roll_2D --distill_loss=dmd --reward_type=far --actor=implicit &
21 |
--------------------------------------------------------------------------------
/utils/data_sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Data_Sampler(object):
5 | def __init__(self, data, device, reward_tune='no'):
6 |
7 | self.device = device
8 |
9 | self.state = torch.from_numpy(data['observations']).float().to(device)
10 | self.action = torch.from_numpy(data['actions']).float().to(device)
11 | self.next_state = torch.from_numpy(data['next_observations']).float().to(device)
12 | reward = torch.from_numpy(data['rewards']).view(-1, 1).float().to(device)
13 | self.not_done = 1. - torch.from_numpy(data['terminals']).view(-1, 1).float().to(device)
14 |
15 | self.size = self.state.shape[0]
16 | self.state_dim = self.state.shape[1]
17 | self.action_dim = self.action.shape[1]
18 |
19 | if reward_tune == 'normalize':
20 | reward = (reward - reward.mean()) / reward.std()
21 | elif reward_tune == 'iql_antmaze':
22 | reward = reward - 1.0
23 | elif reward_tune == 'iql_locomotion':
24 | reward = iql_normalize(reward, self.not_done)
25 | elif reward_tune == 'cql_antmaze':
26 | reward = (reward - 0.5) * 4.0
27 | elif reward_tune == 'antmaze':
28 | reward = (reward - 0.25) * 2.0
29 | self.reward = reward
30 |
31 | def sample(self, batch_size):
32 | ind = torch.randint(0, self.size, size=(batch_size,))
33 |
34 | return (
35 | self.state[ind],
36 | self.action[ind],
37 | self.next_state[ind],
38 | self.reward[ind],
39 | self.not_done[ind],
40 | )
41 |
42 | def iql_normalize(reward, not_done):
43 | trajs_rt = []
44 | episode_return = 0.0
45 | for i in range(len(reward)):
46 | episode_return += reward[i]
47 | if not not_done[i]:
48 | trajs_rt.append(episode_return)
49 | episode_return = 0.0
50 | rt_max, rt_min = torch.max(torch.tensor(trajs_rt)), torch.min(torch.tensor(trajs_rt))
51 | reward /= (rt_max - rt_min)
52 | reward *= 1000.
53 | return reward
54 |
--------------------------------------------------------------------------------
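For orientation, a hypothetical sketch of how `Data_Sampler` is fed from a D4RL dataset; the field names match `d4rl.qlearning_dataset`, though `main.py` may wire this up slightly differently:

```python
# Hypothetical usage sketch, not taken verbatim from main.py.
import gym
import d4rl
from utils.data_sampler import Data_Sampler

env = gym.make('walker2d-medium-v2')
data = d4rl.qlearning_dataset(env)            # keys match the fields Data_Sampler expects
sampler = Data_Sampler(data, device='cuda', reward_tune='no')
state, action, next_state, reward, not_done = sampler.sample(batch_size=256)
```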
/agents/helpers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | class EMA():
4 | '''
5 | exponential moving average (EMA)
6 | '''
7 | def __init__(self, beta):
8 | super().__init__()
9 | self.beta = beta
10 |
11 | def update_model_average(self, ma_model, current_model):
12 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
13 | old_weight, up_weight = ma_params.data, current_params.data
14 | ma_params.data = self.update_average(old_weight, up_weight)
15 |
16 | def update_average(self, old, new):
17 | if old is None:
18 | return new
19 | return old * self.beta + (1 - self.beta) * new
20 |
21 |
22 | def get_dmd_loss(diffusion_model, true_score_model, fake_score_model, fake_action_data, state_data):
23 | noise = torch.randn_like(fake_action_data)
24 | fake_score_model.eval()
25 | true_score_model.eval()
26 | with torch.no_grad():
27 | pred_real_action, _, t_chosen = diffusion_model.diffusion_train_step(model=true_score_model,
28 | x=fake_action_data, cond=state_data,
29 | noise=noise, t_chosen=None,
30 | return_denoised=True)
31 |
32 | pred_fake_action, _, t_chosen = diffusion_model.diffusion_train_step(model=fake_score_model,
33 | x=fake_action_data, cond=state_data,
34 | noise=noise, t_chosen=t_chosen,
35 | return_denoised=True)
36 | weighting_factor = (fake_action_data - pred_real_action).abs().mean(axis=1).reshape(-1, 1)
37 | grad = (pred_fake_action - pred_real_action) / weighting_factor
38 | distill_loss = 0.5 * F.mse_loss(fake_action_data, (fake_action_data - grad).detach())
39 | return distill_loss
--------------------------------------------------------------------------------
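A minimal sketch of how the `EMA` helper maintains a slowly moving target copy of a network; this mirrors `step_ema` in `agents/dtql.py`:

```python
import copy
import torch.nn as nn
from agents.helpers import EMA

policy = nn.Linear(4, 2)               # stand-in for an actor network
policy_target = copy.deepcopy(policy)  # target starts as an exact copy
ema = EMA(beta=0.995)

# After each optimizer step on `policy`:
ema.update_model_average(policy_target, policy)  # target <- beta * target + (1 - beta) * policy
```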
/diffusion/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import numpy as np
4 |
5 |
6 | def append_dims(x, target_dims):
7 | """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
8 | dims_to_append = target_dims - x.ndim
9 | if dims_to_append < 0:
10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
11 | return x[(...,) + (None,) * dims_to_append]
12 |
13 |
14 | def append_zero(action):
15 | return torch.cat([action, action.new_zeros([1])])
16 |
17 |
18 | def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
19 | """Draws samples from a lognormal distribution."""
20 | return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp()
21 |
22 |
23 | def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
24 | """Draws samples from an optionally truncated log-logistic distribution."""
25 | min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64)
26 | max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64)
27 | min_cdf = min_value.log().sub(loc).div(scale).sigmoid()
28 | max_cdf = max_value.log().sub(loc).div(scale).sigmoid()
29 | u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf
30 | return u.logit().mul(scale).add(loc).exp().to(dtype)
31 |
32 |
33 | def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
34 | """Draws samples from an log-uniform distribution."""
35 | min_value = math.log(min_value)
36 | max_value = math.log(max_value)
37 | return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp()
38 |
39 |
40 | def rand_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
41 | """Draws samples from an uniform distribution."""
42 | return torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value
43 |
44 |
45 | def rand_discrete(shape, values, device='cpu', dtype=torch.float32):
46 | probs = [1 / len(values)] * len(values) # set equal probability for all values
47 | return torch.tensor(np.random.choice(values, size=shape, p=probs), device=device, dtype=dtype)
48 |
49 |
50 | def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
51 | """Draws samples from a truncated v-diffusion training timestep distribution."""
52 | min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi
53 | max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi
54 | u = torch.rand(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf
55 | return torch.tan(u * math.pi / 2) * sigma_data
56 |
--------------------------------------------------------------------------------
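A small sketch of `append_dims`, which these samplers rely on to broadcast a per-sample noise level against a batch of actions:

```python
import torch
from diffusion.utils import append_dims

x = torch.randn(8, 3)   # batch of 8 three-dimensional actions
sigma = torch.rand(8)   # one noise level per sample, shape (8,)
noised = x + torch.randn_like(x) * append_dims(sigma, x.ndim)  # sigma broadcast as (8, 1)
```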
/README.md:
--------------------------------------------------------------------------------
1 | ## Diffusion Trusted Q-Learning for Offline RL — Official PyTorch Implementation
2 |
3 | **Diffusion Policies creating a Trust Region for Offline Reinforcement Learning**
4 | [Tianyu Chen](https://scholar.google.com/citations?user=8Aum3V8AAAAJ&hl=en), [Zhendong Wang](https://zhendong-wang.github.io/), and [Mingyuan Zhou](https://mingyuanzhou.github.io/) <br> https://arxiv.org/abs/2405.19690
5 |
6 |
7 | Abstract: *Offline reinforcement learning (RL) leverages pre-collected
8 | datasets to train optimal policies. Diffusion Q-Learning (DQL), introducing
9 | diffusion models as a powerful and expressive policy class, significantly
10 | boosts the performance of offline RL. However, its reliance on iterative
11 | denoising sampling to generate actions slows down both training and
12 | inference. While several recent attempts have tried to accelerate
13 | diffusion-QL, the improvement in training and/or inference speed often
14 | results in degraded performance. In this paper, we introduce a dual
15 | policy approach, Diffusion Trusted Q-Learning (DTQL), which comprises a
16 | diffusion policy for pure behavior cloning and a practical one-step policy.
17 | We bridge the two policies by a newly introduced diffusion trust region
18 | loss. The diffusion policy maintains expressiveness, while the trust
19 | region loss directs the one-step policy to explore freely and seek
20 | modes within the region defined by the diffusion policy. DTQL
21 | eliminates the need for iterative denoising sampling during both training
22 | and inference, making it remarkably computationally efficient. We
23 | evaluate its effectiveness and algorithmic characteristics against
24 | popular Kullback-Leibler (KL) based distillation methods in 2D bandit
25 | scenarios and gym tasks. We then show that DTQL could not only outperform
26 | other methods on the majority of the D4RL benchmark tasks but also
27 | demonstrate efficiency in training and inference speeds.*
28 |
29 | ## Introduction
30 |
31 | We introduce a dual-policy approach, Diffusion Trusted Q-Learning (DTQL): a diffusion policy for pure behavior cloning and a one-step policy for actual deployment.
32 | We bridge the two policies through our newly introduced diffusion trust region loss.
33 | The loss ensures that generated actions lie within the action manifold of the in-sample dataset.
34 | Guided by the gradient of the Q-function, actions can then move freely within this manifold and gravitate toward high-reward regions.
35 |
36 | We compare our behavior regularization loss (the diffusion trust region loss) with a Kullback–Leibler based behavior regularization loss, testing their differential impact on behavior regularization when a trained Q-function is used for policy improvement. Red points represent actions generated by the one-step policy.
37 |
38 | 
39 |
40 |
41 | ## Experiments
42 |
43 | ### Requirements
44 | Installations of [PyTorch](https://pytorch.org/), [MuJoCo](https://github.com/deepmind/mujoco), and [D4RL](https://github.com/Farama-Foundation/D4RL) are needed. Please see ``requirements.txt`` and ``install_env.sh`` for environment setup details.
45 |
46 | ### Running
47 | Running experiments with our code is straightforward; below we use the `halfcheetah-medium-v2` dataset as an example.
48 |
49 | ```.bash
50 | python main.py --device=0 --env_name=halfcheetah-medium-v2 --seed=1 --dir=results
51 | ```
52 |
53 | Hyperparameters for DTQL are hard-coded in `main.py` to make it easy to reproduce our reported results.
54 | Better hyperparameter settings may well exist; feel free to make your own modifications.
55 |
56 | ## Citation
57 |
58 | If you find this open-source release useful, please cite it in your paper:
59 | ```
60 | @misc{chen2024diffusion,
61 | title={Diffusion Policies creating a Trust Region for Offline Reinforcement Learning},
62 | author={Tianyu Chen and Zhendong Wang and Mingyuan Zhou},
63 | year={2024},
64 | eprint={2405.19690},
65 | archivePrefix={arXiv},
66 | primaryClass={cs.LG}
67 | }
68 | ```
69 |
70 | ## Acknowledgement
71 | This repo is heavily built upon [DQL](https://github.com/Zhendong-Wang/Diffusion-Policies-for-Offline-RL). We thank the authors for their excellent work.
72 |
--------------------------------------------------------------------------------
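The diffusion trust region loss described in the README is implemented in `agents/dtql.py` as the frozen diffusion (behavior-cloning) policy's denoising loss evaluated at actions drawn from the one-step policy. A sketch in our own notation, ignoring the sigma-dependent weighting of the Karras-style loss:

```latex
% Our notation, not verbatim from the paper: D_\theta is the frozen diffusion
% (behavior-cloning) policy, \pi_\phi the one-step policy, Q_1 and Q_2 the twin critics.
\[
\begin{aligned}
\mathcal{L}_{\mathrm{TR}}(\phi) &= \mathbb{E}_{s\sim\mathcal{D},\,\sigma,\,\epsilon}
   \left\| D_\theta\!\left(\pi_\phi(s)+\sigma\epsilon,\; s,\; \sigma\right)-\pi_\phi(s)\right\|^2, \\
\mathcal{L}_{\mathrm{actor}}(\phi) &= \alpha\,\mathcal{L}_{\mathrm{TR}}(\phi)
   \;-\;\mathbb{E}_{s\sim\mathcal{D}}\!\left[\min(Q_1,Q_2)\big(s,\pi_\phi(s)\big)\right]
   \;-\;\gamma\,\mathbb{E}_{(s,a)\sim\mathcal{D}}\!\left[\log\pi_\phi(a\mid s)\right].
\end{aligned}
\]
```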
/utils/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import time
3 | import math
4 |
5 |
6 | def print_banner(s, separator="-", num_star=60):
7 | print(separator * num_star, flush=True)
8 | print(s, flush=True)
9 | print(separator * num_star, flush=True)
10 |
11 |
12 | class Progress:
13 |
14 | def __init__(self, total, name='Progress', ncol=3, max_length=20, indent=0, line_width=100, speed_update_freq=100):
15 | self.total = total
16 | self.name = name
17 | self.ncol = ncol
18 | self.max_length = max_length
19 | self.indent = indent
20 | self.line_width = line_width
21 | self._speed_update_freq = speed_update_freq
22 |
23 | self._step = 0
24 | self._prev_line = '\033[F'
25 | self._clear_line = ' ' * self.line_width
26 |
27 | self._pbar_size = self.ncol * self.max_length
28 | self._complete_pbar = '#' * self._pbar_size
29 | self._incomplete_pbar = ' ' * self._pbar_size
30 |
31 | self.lines = ['']
32 | self.fraction = '{} / {}'.format(0, self.total)
33 |
34 | self.resume()
35 |
36 | def update(self, description, n=1):
37 | self._step += n
38 | if self._step % self._speed_update_freq == 0:
39 | self._time0 = time.time()
40 | self._step0 = self._step
41 | self.set_description(description)
42 |
43 | def resume(self):
44 | self._skip_lines = 1
45 | print('\n', end='')
46 | self._time0 = time.time()
47 | self._step0 = self._step
48 |
49 | def pause(self):
50 | self._clear()
51 | self._skip_lines = 1
52 |
53 | def set_description(self, params=[]):
54 |
55 | if type(params) == dict:
56 | params = sorted([
57 | (key, val)
58 | for key, val in params.items()
59 | ])
60 |
61 | ############
62 | # Position #
63 | ############
64 | self._clear()
65 |
66 | ###########
67 | # Percent #
68 | ###########
69 | percent, fraction = self._format_percent(self._step, self.total)
70 | self.fraction = fraction
71 |
72 | #########
73 | # Speed #
74 | #########
75 | speed = self._format_speed(self._step)
76 |
77 | ##########
78 | # Params #
79 | ##########
80 | num_params = len(params)
81 | nrow = math.ceil(num_params / self.ncol)
82 | params_split = self._chunk(params, self.ncol)
83 | params_string, lines = self._format(params_split)
84 | self.lines = lines
85 |
86 | description = '{} | {}{}'.format(percent, speed, params_string)
87 | print(description)
88 | self._skip_lines = nrow + 1
89 |
90 | def append_description(self, descr):
91 | self.lines.append(descr)
92 |
93 | def _clear(self):
94 | position = self._prev_line * self._skip_lines
95 | empty = '\n'.join([self._clear_line for _ in range(self._skip_lines)])
96 | print(position, end='')
97 | print(empty)
98 | print(position, end='')
99 |
100 | def _format_percent(self, n, total):
101 | if total:
102 | percent = n / float(total)
103 |
104 | complete_entries = int(percent * self._pbar_size)
105 | incomplete_entries = self._pbar_size - complete_entries
106 |
107 | pbar = self._complete_pbar[:complete_entries] + self._incomplete_pbar[:incomplete_entries]
108 | fraction = '{} / {}'.format(n, total)
109 | string = '{} [{}] {:3d}%'.format(fraction, pbar, int(percent * 100))
110 | else:
111 | fraction = '{}'.format(n)
112 | string = '{} iterations'.format(n)
113 | return string, fraction
114 |
115 | def _format_speed(self, n):
116 | num_steps = n - self._step0
117 | t = time.time() - self._time0
118 | speed = num_steps / t
119 | string = '{:.1f} Hz'.format(speed)
120 | if num_steps > 0:
121 | self._speed = string
122 | return string
123 |
124 | def _chunk(self, l, n):
125 | return [l[i:i + n] for i in range(0, len(l), n)]
126 |
127 | def _format(self, chunks):
128 | lines = [self._format_chunk(chunk) for chunk in chunks]
129 | lines.insert(0, '')
130 | padding = '\n' + ' ' * self.indent
131 | string = padding.join(lines)
132 | return string, lines
133 |
134 | def _format_chunk(self, chunk):
135 | line = ' | '.join([self._format_param(param) for param in chunk])
136 | return line
137 |
138 | def _format_param(self, param):
139 | k, v = param
140 | return '{} : {}'.format(k, v)[:self.max_length]
141 |
142 | def stamp(self):
143 | if self.lines != ['']:
144 | params = ' | '.join(self.lines)
145 | string = '[ {} ] {}{} | {}'.format(self.name, self.fraction, params, self._speed)
146 | self._clear()
147 | print(string, end='\n')
148 | self._skip_lines = 1
149 | else:
150 | self._clear()
151 | self._skip_lines = 0
152 |
153 | def close(self):
154 | self.pause()
155 |
156 |
157 | class Silent:
158 |
159 | def __init__(self, *args, **kwargs):
160 | pass
161 |
162 | def __getattr__(self, attr):
163 | return lambda *args: None
164 |
165 |
166 | class EarlyStopping(object):
167 | def __init__(self, tolerance=5, min_delta=0):
168 | self.tolerance = tolerance
169 | self.min_delta = min_delta
170 | self.counter = 0
171 | self.early_stop = False
172 |
173 | def __call__(self, train_loss, validation_loss):
174 | if (validation_loss - train_loss) > self.min_delta:
175 | self.counter += 1
176 | if self.counter >= self.tolerance:
177 | return True
178 | else:
179 | self.counter = 0
180 | return False
181 |
--------------------------------------------------------------------------------
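A brief usage sketch of the `Progress` logger (assumed usage; the training scripts may drive it differently):

```python
from utils.utils import Progress

pbar = Progress(total=1000, name='Training')
for step in range(1000):
    loss = 0.0  # placeholder for one training step
    pbar.update({'step': step, 'loss': loss})  # accepts a dict or a list of (key, value) pairs
pbar.close()
```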
/agents/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | ########################################### Critic net ################################
5 | class Critic(nn.Module):
6 | def __init__(self, state_dim, action_dim, hidden_dim=256):
7 | super(Critic, self).__init__()
8 | self.q1_model = nn.Sequential(nn.Linear(state_dim + action_dim, hidden_dim),
9 | nn.LayerNorm(hidden_dim),
10 | nn.Mish(),
11 | nn.Linear(hidden_dim, hidden_dim),
12 | nn.LayerNorm(hidden_dim),
13 | nn.Mish(),
14 | nn.Linear(hidden_dim, hidden_dim),
15 | nn.LayerNorm(hidden_dim),
16 | nn.Mish(),
17 | nn.Linear(hidden_dim, 1))
18 |
19 | self.q2_model = nn.Sequential(nn.Linear(state_dim + action_dim, hidden_dim),
20 | nn.LayerNorm(hidden_dim),
21 | nn.Mish(),
22 | nn.Linear(hidden_dim, hidden_dim),
23 | nn.LayerNorm(hidden_dim),
24 | nn.Mish(),
25 | nn.Linear(hidden_dim, hidden_dim),
26 | nn.LayerNorm(hidden_dim),
27 | nn.Mish(),
28 | nn.Linear(hidden_dim, 1))
29 |
30 | self.v_model = nn.Sequential(nn.Linear(state_dim, hidden_dim),
31 | nn.Mish(),
32 | nn.Linear(hidden_dim, hidden_dim),
33 | nn.Mish(),
34 | nn.Linear(hidden_dim, hidden_dim),
35 | nn.Mish(),
36 | nn.Linear(hidden_dim, 1))
37 |
38 | def forward(self, state, action):
39 | x = torch.cat([state, action], dim=-1)
40 | return self.q1_model(x), self.q2_model(x)
41 |
42 | def q1(self, state, action):
43 | x = torch.cat([state, action], dim=-1)
44 | return self.q1_model(x)
45 |
46 | def q_min(self, state, action):
47 | q1, q2 = self.forward(state, action)
48 | return torch.min(q1, q2)
49 |
50 | def v(self, state):
51 | return self.v_model(state)
52 |
53 |
54 | ############################################# SAC #################################################
55 | def weight_init(m):
56 | """Custom weight init for Conv2D and Linear layers."""
57 | if isinstance(m, nn.Linear):
58 | nn.init.orthogonal_(m.weight.data)
59 | if hasattr(m.bias, 'data'):
60 | m.bias.data.fill_(0.0)
61 |
62 |
63 | def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None):
64 | if hidden_depth == 0:
65 | mods = [nn.Linear(input_dim, output_dim)]
66 | else:
67 | mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
68 | for i in range(hidden_depth - 1):
69 | mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
70 | mods.append(nn.Linear(hidden_dim, output_dim))
71 | if output_mod is not None:
72 | mods.append(output_mod)
73 | trunk = nn.Sequential(*mods)
74 | return trunk
75 |
76 |
77 | class DiagGaussianActorTanhAction(nn.Module):
78 | """torch.distributions implementation of an diagonal Gaussian policy."""
79 | def __init__(self, state_dim, action_dim, max_action,
80 | hidden_dim=256, hidden_depth=3,
81 | log_std_bounds=[-5, 2]):
82 | super().__init__()
83 |
84 | self.log_std_bounds = log_std_bounds
85 | self.net = mlp(state_dim, hidden_dim, 2 * action_dim, hidden_depth)
86 | self.apply(weight_init)
87 | self.action_scale = max_action
88 | self.action_dim = action_dim
89 |
90 | def forward(self, state):
91 | mu, log_std = self.net(state).chunk(2, dim=-1)
92 | log_std_min, log_std_max = self.log_std_bounds
93 | log_std = log_std.clamp(log_std_min, log_std_max)
94 |
95 | std = log_std.exp()
96 | actor_dist = torch.distributions.Normal(mu, std)
97 | return actor_dist
98 |
99 | def sample(self, state):
100 | actor_dist = self(state)
101 | z = actor_dist.rsample()
102 | action = torch.tanh(z)
103 |
104 | action = action * self.action_scale
105 | action = action.clamp(-self.action_scale, self.action_scale)
106 | return action
107 |
108 | def log_prob(self, state, action):
109 | actor_dist = self(state)
110 | pre_tanh_value = torch.arctanh(action / (self.action_scale + 1e-3))
111 | log_prob = actor_dist.log_prob(pre_tanh_value)
112 | return log_prob.sum(-1, keepdim=True)
113 |
114 | def get_entropy(self,state):
115 | with torch.no_grad():
116 | mu, log_std = self.net(state).chunk(2, dim=-1)
117 | return log_std.sum(-1).mean()
118 |
119 |
120 |
121 |
--------------------------------------------------------------------------------
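A minimal sketch of the critic and one-step actor defined above; the dimensions are illustrative only:

```python
import torch
from agents.model import Critic, DiagGaussianActorTanhAction

state_dim, action_dim = 17, 6                  # illustrative dimensions
critic = Critic(state_dim, action_dim)
actor = DiagGaussianActorTanhAction(state_dim, action_dim, max_action=1.0)

state = torch.randn(32, state_dim)
action = actor.sample(state)                   # tanh-squashed, scaled to [-max_action, max_action]
q = critic.q_min(state, action)                # conservative estimate min(Q1, Q2), shape (32, 1)
v = critic.v(state)                            # state value used for the expectile regression target
```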
/diffusion/mlps.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import math
3 | import torch
4 | from torch import nn
5 | import numpy as np
6 |
7 | class SinusoidalPosEmb(nn.Module):
8 | def __init__(self, dim):
9 | super().__init__()
10 | self.dim = dim
11 |
12 | def forward(self, x):
13 | device = x.device
14 | half_dim = self.dim // 2
15 | emb = math.log(10000) / (half_dim - 1)
16 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
17 | emb = x[:, None] * emb[None, :]
18 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
19 | return emb
20 |
21 |
22 | class MLPNetwork(nn.Module):
23 | """
24 | Simple multi-layer perceptron used as the backbone of the score network;
25 | hidden layers use Mish activations.
26 | """
27 |
28 | def __init__(
29 | self,
30 | input_dim: int,
31 | hidden_dim: int = 100,
32 | num_hidden_layers: int = 1,
33 | output_dim=1,
34 | device: str = 'cuda'
35 | ):
36 | super(MLPNetwork, self).__init__()
37 | self.network_type = "mlp"
38 | # define number of variables in an input sequence
39 | self.input_dim = input_dim
40 | # the dimension of neurons in the hidden layer
41 | self.hidden_dim = hidden_dim
42 | self.num_hidden_layers = num_hidden_layers
43 | # the dimension of the network output
44 | self.output_dim = output_dim
45 | # set up the network
46 | self.layers = nn.ModuleList([nn.Linear(self.input_dim, self.hidden_dim)])
47 | for i in range(1, self.num_hidden_layers):
48 | self.layers.extend([
49 | nn.Linear(self.hidden_dim, self.hidden_dim),
50 | nn.Mish()
51 | ])
52 | self.layers.append(nn.Linear(self.hidden_dim, self.output_dim))
53 |
54 | self._device = device
55 | self.layers.to(self._device)
56 |
57 | def forward(self, x):
58 | for layer in self.layers:
59 | x = layer(x)
60 | return x
61 |
62 | def get_device(self, device: torch.device):
63 | self._device = device
64 | self.layers.to(device)
65 |
66 | def get_params(self):
67 | return self.layers.parameters()
68 |
69 |
70 | class ScoreNetwork(nn.Module):
71 | def __init__(
72 | self,
73 | action_dim: int,
74 | hidden_dim: int,
75 | time_embed_dim: int,
76 | cond_dim: int,
77 | cond_mask_prob: float,
78 | num_hidden_layers: int = 1,
79 | output_dim=1,
80 | device: str = 'cuda',
81 | cond_conditional: bool = True
82 | ):
83 | super(ScoreNetwork, self).__init__()
84 | # Gaussian random feature embedding layer for time
85 | # self.embed = GaussianFourierProjection(time_embed_dim).to(device)
86 | self.embed = nn.Sequential(
87 | SinusoidalPosEmb(time_embed_dim),
88 | nn.Linear(time_embed_dim, time_embed_dim * 2),
89 | nn.Mish(),
90 | nn.Linear(time_embed_dim * 2, time_embed_dim),
91 | ).to(device)
92 | self.time_embed_dim = time_embed_dim
93 | self.cond_mask_prob = cond_mask_prob
94 | self.cond_conditional = cond_conditional
95 | if self.cond_conditional:
96 | input_dim = self.time_embed_dim + action_dim + cond_dim
97 | else:
98 | input_dim = self.time_embed_dim + action_dim
99 | # set up the network
100 | self.layers = MLPNetwork(
101 | input_dim,
102 | hidden_dim,
103 | num_hidden_layers,
104 | output_dim,
105 | device
106 | ).to(device)
107 |
108 | # build the activation layer
109 | self.act = nn.Mish()
110 | self.device = device
111 | self.sigma = None
112 | self.training = True
113 |
114 | def forward(self, x, cond, sigma, uncond=False):
115 | # Obtain the feature embedding for t
116 | if len(sigma.shape) == 0:
117 | sigma = einops.rearrange(sigma, ' -> 1')
118 | sigma = sigma.unsqueeze(1)
119 | elif len(sigma.shape) == 1:
120 | sigma = sigma.unsqueeze(1)
121 | embed = self.embed(sigma)
122 | embed.squeeze_(1)
123 | if embed.shape[0] != x.shape[0]:
124 | embed = einops.repeat(embed, '1 d -> (1 b) d', b=x.shape[0])
125 | # during training, randomly mask out the cond:
126 | # to train the conditional model with classifier-free guidance we need
127 | # to zero out some of the conditioning during training with a desired probability,
128 | # usually in the range of 0.1 to 0.2
129 | if self.training and cond is not None:
130 | cond = self.mask_cond(cond)
131 | # we want to use unconditional sampling during classifier free guidance
132 | if uncond:
133 | cond = torch.zeros_like(cond) # cond
134 | if self.cond_conditional:
135 | x = torch.cat([x, cond, embed], dim=-1)
136 | else:
137 | x = torch.cat([x, embed], dim=-1)
138 | x = self.layers(x)
139 | return x # / marginal_prob_std(t, self.sigma, self.device)[:, None]
140 |
141 | def mask_cond(self, cond, force_mask=False):
142 | bs, d = cond.shape
143 | if force_mask:
144 | return torch.zeros_like(cond)
145 | elif self.training and self.cond_mask_prob > 0.:
146 | mask = torch.bernoulli(torch.ones((bs, d),
147 | device=cond.device) * self.cond_mask_prob) # .view(bs, 1) # 1-> use null_cond, 0-> use real cond
148 | return cond * (1. - mask)
149 | else:
150 | return cond
151 |
152 | def get_params(self):
153 | return self.parameters()
--------------------------------------------------------------------------------
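A sketch of a forward pass through `ScoreNetwork` with random tensors; shapes are illustrative, and during training `DiffusionModel` passes log(sigma)/4 as the last argument:

```python
import torch
from diffusion.mlps import ScoreNetwork

net = ScoreNetwork(action_dim=6, hidden_dim=256, time_embed_dim=16,
                   cond_dim=17, cond_mask_prob=0.0, num_hidden_layers=4,
                   output_dim=6, device='cpu')

x = torch.randn(32, 6)       # noised actions
cond = torch.randn(32, 17)   # states used as conditioning
sigma = torch.rand(32)       # per-sample noise levels (DiffusionModel passes log(sigma)/4 here)
out = net(x, cond, sigma)    # denoiser output, shape (32, 6)
```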
/diffusion/karras.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import torch
3 | from .utils import *
4 | from torch import nn
5 | class DiffusionModel(nn.Module):
6 | def __init__(
7 | self,
8 | sigma_data: float,
9 | sigma_min: float,
10 | sigma_max: float,
11 | device: str,
12 | sigma_sample_density_type: str = 'loglogistic',
13 | clip_denoised=False,
14 | max_action=1.0,
15 | ) -> None:
16 | super().__init__()
17 |
18 | self.device = device
19 | # use the score wrapper
20 | self.sigma_data = sigma_data
21 | self.sigma_min = sigma_min
22 | self.sigma_max = sigma_max
23 | self.sigma_sample_density_type = sigma_sample_density_type
24 | self.epochs = 0
25 | self.clip_denoised = clip_denoised
26 | self.max_action = max_action
27 |
28 | def get_diffusion_scalings(self, sigma):
29 | """
30 | Computes the scaling factors for diffusion training at a given time step sigma.
31 |
32 | Args:
33 | - self: the object instance of the model
34 | - sigma (float or torch.Tensor): the time step at which to compute the scaling factors
35 |
36 | Here self.sigma_data is the data noise level of the diffusion process, set during initialization of the model.
37 |
38 | Returns:
39 | - c_skip (torch.Tensor): the scaling factor for skipping the diffusion model for the given time step sigma
40 | - c_out (torch.Tensor): the scaling factor for the output of the diffusion model for the given time step sigma
41 | - c_in (torch.Tensor): the scaling factor for the input of the diffusion model for the given time step sigma
42 |
43 | """
44 | c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
45 | c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
46 | c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
47 | return c_skip, c_out, c_in
48 |
49 | def diffusion_train_step(self, model, x, cond, noise=None, t_chosen=None, return_denoised=False):
50 | """
51 | Computes the diffusion training loss for one batch; no optimizer step is taken here.
52 |
53 | Args:
54 | - self: the object instance of the model
55 | - x (torch.Tensor): the input tensor of shape (batch_size, dim)
56 | - cond (torch.Tensor): the conditional input tensor of shape (batch_size, cond_dim)
57 |
58 | Returns:
59 | - loss (torch.Tensor): the scalar training loss, or (denoised_x, loss, t_chosen) if return_denoised is True
60 |
61 | """
62 | model.train()
63 | x = x.to(self.device)
64 | cond = cond.to(self.device)
65 | if t_chosen is None:
66 | t_chosen = self.make_sample_density()(shape=(len(x),), device=self.device)
67 |
68 | if return_denoised:
69 | denoised_x, loss = self.diffusion_loss(model,x, cond, t_chosen, noise, return_denoised)
70 | return denoised_x, loss, t_chosen
71 | else:
72 | loss = self.diffusion_loss(model, x, cond, t_chosen, noise, return_denoised)
73 | return loss
74 |
75 | def diffusion_loss(self, model, x, cond, t, noise, return_denoised):
76 | """
77 | Computes the diffusion training loss for the given model, input, condition, and time.
78 |
79 | Args:
80 | - self: the object instance of the model
81 | - x (torch.Tensor): the input tensor of shape (batch_size, action_dim)
82 | - cond (torch.Tensor): the conditional input tensor of shape (batch_size, cond_dim)
83 | - t (torch.Tensor): the time step tensor of shape (batch_size,)
84 |
85 | Returns:
86 | - loss (torch.Tensor): the diffusion training loss tensor of shape ()
87 |
88 | The diffusion training loss is computed based on the following equation from Karras et al. 2022:
89 | loss = (model_output - target)^2.mean()
90 | where,
91 | - noise: a tensor of the same shape as x, containing randomly sampled noise
92 | - x_1: a tensor of the same shape as x, obtained by adding the noise tensor to x
93 | - c_skip, c_out, c_in: scaling tensors obtained from the diffusion scalings for the given time step
94 | - t: a tensor of the same shape as t, obtained by taking the natural logarithm of t and dividing it by 4
95 | - model_output: the output tensor of the model for the input x_1, condition cond, and time t
96 | - target: the target tensor for the given input x, scaling tensors c_skip, c_out, c_in, and time t
97 | """
98 | if noise is None:
99 | noise = torch.randn_like(x)
100 | x_1 = x + noise * append_dims(t, x.ndim)
101 | c_skip, c_out, c_in = [append_dims(x, 2) for x in self.get_diffusion_scalings(t)]
102 | t = torch.log(t) / 4
103 | model_output = model(x_1 * c_in, cond, t)
104 |
105 | if self.clip_denoised:
106 | denoised_x = c_out * model_output + c_skip * x_1
107 | denoised_x = denoised_x.clamp(-self.max_action,self.max_action)
108 | loss = ((denoised_x - x)/c_out).pow(2).mean()
109 | else:
110 | denoised_x = c_out * model_output + c_skip * x_1
111 | target = (x - c_skip * x_1) / c_out
112 | loss = (model_output - target).pow(2).mean()
113 |
114 | if return_denoised:
115 | return denoised_x, loss
116 | else:
117 | return loss
118 |
119 | def make_sample_density(self):
120 | """
121 | Returns a function that generates random timesteps based on the chosen sample density.
122 |
123 | Args:
124 | - self: the object instance of the model
125 |
126 | Returns:
127 | - sample_density_fn (callable): a function that generates random timesteps
128 |
129 | The method returns a callable function that generates random timesteps based on the chosen sample density.
130 | The available sample densities are:
131 | - 'lognormal': generates random timesteps from a log-normal distribution with mean and standard deviation set
132 | during initialization of the model also used in Karras et al. (2022)
133 | - 'loglogistic': generates random timesteps from a log-logistic distribution with location parameter set to the
134 | natural logarithm of the sigma_data parameter and scale and range parameters set during initialization
135 | of the model
136 | - 'loguniform': generates random timesteps from a log-uniform distribution with range parameters set during
137 | initialization of the model
138 | - 'uniform': generates random timesteps from a uniform distribution with range parameters set during initialization
139 | of the model
140 | - 'v-diffusion': generates random timesteps using the Variational Diffusion sampler with range parameters set during
141 | initialization of the model
142 | - 'discrete': generates random timesteps from the noise schedule using the exponential density
143 | - 'split-lognormal': generates random timesteps from a split log-normal distribution with mean and standard deviation
144 | set during initialization of the model
145 | """
146 | sd_config = []
147 |
148 | if self.sigma_sample_density_type == 'lognormal':
149 | loc = self.sigma_sample_density_mean # if 'mean' in sd_config else sd_config['loc']
150 | scale = self.sigma_sample_density_std # if 'std' in sd_config else sd_config['scale']
151 | return partial(rand_log_normal, loc=loc, scale=scale)
152 |
153 | if self.sigma_sample_density_type == 'loglogistic':
154 | loc = sd_config['loc'] if 'loc' in sd_config else math.log(self.sigma_data)
155 | scale = sd_config['scale'] if 'scale' in sd_config else 0.5
156 | min_value = sd_config['min_value'] if 'min_value' in sd_config else self.sigma_min
157 | max_value = sd_config['max_value'] if 'max_value' in sd_config else self.sigma_max
158 | return partial(rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value)
159 |
160 | if self.sigma_sample_density_type == 'loguniform':
161 | min_value = sd_config['min_value'] if 'min_value' in sd_config else self.sigma_min
162 | max_value = sd_config['max_value'] if 'max_value' in sd_config else self.sigma_max
163 | return partial(rand_log_uniform, min_value=min_value, max_value=max_value)
164 | if self.sigma_sample_density_type == 'uniform':
165 | return partial(rand_uniform, min_value=self.sigma_min, max_value=self.sigma_max)
166 |
167 | if self.sigma_sample_density_type == 'v-diffusion':
168 | min_value = self.min_value if 'min_value' in sd_config else self.sigma_min
169 | max_value = sd_config['max_value'] if 'max_value' in sd_config else self.sigma_max
170 | return partial(rand_v_diffusion, sigma_data=self.sigma_data, min_value=min_value, max_value=max_value)
171 | if self.sigma_sample_density_type == 'discrete':
172 | sigmas = self.get_noise_schedule(self.n_sampling_steps, 'exponential')
173 | return partial(rand_discrete, values=sigmas)
174 | else:
175 | raise ValueError('Unknown sample density type')
--------------------------------------------------------------------------------
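A minimal sketch of one behavior-cloning diffusion step with the classes above; the hyperparameters mirror the defaults used in `agents/dtql.py`, and the shapes are illustrative:

```python
import torch
from diffusion.karras import DiffusionModel
from diffusion.mlps import ScoreNetwork

device = 'cpu'
score_net = ScoreNetwork(action_dim=6, hidden_dim=256, time_embed_dim=16,
                         cond_dim=17, cond_mask_prob=0.0, num_hidden_layers=4,
                         output_dim=6, device=device)
diffusion = DiffusionModel(sigma_data=0.5, sigma_min=0.002, sigma_max=80.,
                           device=device, clip_denoised=True, max_action=1.0)

actions = torch.rand(32, 6) * 2 - 1   # dataset actions in [-1, 1]
states = torch.randn(32, 17)
loss = diffusion.diffusion_train_step(score_net, actions, states)  # scalar denoising loss
loss.backward()
```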
/agents/dtql.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import time
3 |
4 | from torch.optim.lr_scheduler import CosineAnnealingLR
5 | import torch
6 | import torch.nn.functional as F
7 |
8 | from pathlib import Path
9 | from diffusion.karras import DiffusionModel
10 | from diffusion.mlps import ScoreNetwork
11 | from agents.model import Critic, DiagGaussianActorTanhAction
12 | from agents.helpers import EMA
13 |
14 | class DTQL(object):
15 | def __init__(self,
16 | device,
17 | state_dim,
18 | action_dim,
19 | action_space=None,
20 | discount=0.99,
21 | alpha=1.0,
22 | ema_decay=0.995,
23 | step_start_ema=1000,
24 | update_ema_every=5,
25 | lr=3e-4,
26 | lr_decay=False,
27 | lr_maxt=1000,
28 | sigma_max=80.,
29 | sigma_min=0.002,
30 | sigma_data=0.5,
31 | expectile=0.7,
32 | tau = 0.005,
33 | gamma=1,
34 | repeats=1024
35 | ):
36 | """Init critic networks"""
37 | self.critic = Critic(state_dim, action_dim).to(device)
38 | self.critic_target = copy.deepcopy(self.critic)
39 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)
40 |
41 | """Init bc actor"""
42 | self.bc_actor = ScoreNetwork(
43 | action_dim=action_dim,
44 | hidden_dim=256,
45 | time_embed_dim=16,
46 | cond_dim=state_dim,
47 | cond_mask_prob=0.0,
48 | num_hidden_layers=4,
49 | output_dim=action_dim,
50 | device=device,
51 | cond_conditional=True
52 | ).to(device)
53 | self.bc_actor_target = copy.deepcopy(self.bc_actor)
54 | self.bc_actor_optimizer = torch.optim.Adam(self.bc_actor.parameters(), lr=lr)
55 |
56 | """Init diffusion training schedule"""
57 | self.diffusion = DiffusionModel(
58 | sigma_data=sigma_data,
59 | sigma_min=sigma_min,
60 | sigma_max=sigma_max,
61 | device=device,
62 | clip_denoised=True,
63 | max_action=float(action_space.high[0]))
64 |
65 | """Init sac"""
66 | self.distill_actor = DiagGaussianActorTanhAction(state_dim=state_dim, action_dim=action_dim,
67 | max_action=float(action_space.high[0])).to(device)
68 | self.distill_actor_target = copy.deepcopy(self.distill_actor)
69 | self.distill_actor_optimizer = torch.optim.Adam(self.distill_actor.parameters(), lr=lr)
70 |
71 |
72 | """Back up training parameters"""
73 | self.tau = tau
74 | self.lr_decay = lr_decay
75 | self.gamma = gamma
76 | self.repeats = repeats
77 |
78 | self.step = 0
79 | self.step_start_ema = step_start_ema
80 | self.ema = EMA(ema_decay)
81 | self.update_ema_every = update_ema_every
82 |
83 | if lr_decay:
84 | self.critic_lr_scheduler = CosineAnnealingLR(self.critic_optimizer, T_max=lr_maxt, eta_min=0.)
85 | self.bc_actor_lr_scheduler = CosineAnnealingLR(self.bc_actor_optimizer, T_max=lr_maxt, eta_min=0.)
86 | self.distill_actor_lr_scheduler = CosineAnnealingLR(self.distill_actor_optimizer, T_max=lr_maxt, eta_min=0.)
87 |
88 | self.state_dim = state_dim
89 | self.action_dim = action_dim
90 | self.discount = discount
91 | self.alpha = alpha # bc weight
92 | self.expectile = expectile
93 | self.device = device
94 |
95 |
96 | def step_ema(self):
97 | if self.step < self.step_start_ema:
98 | return
99 | self.ema.update_model_average(self.distill_actor_target, self.distill_actor)
100 |
101 | def pretrain(self,replay_buffer, batch_size=256,pretrain_steps=50000):
102 | self.bc_actor.train()
103 | for _ in range(pretrain_steps):
104 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)
105 | loss = self.diffusion.diffusion_train_step(self.bc_actor, action, state)
106 | self.bc_actor_optimizer.zero_grad()
107 | loss.backward()
108 | self.bc_actor_optimizer.step()
109 | self.bc_loss = loss
110 |
111 | critic_loss = self.q_v_critic_loss(state,action, next_state, reward, not_done)
112 | self.critic_optimizer.zero_grad()
113 | critic_loss.backward()
114 | self.critic_optimizer.step()
115 |
116 |
117 | def train(self, replay_buffer, batch_size=256):
118 | # initialize
119 | self.bc_loss = torch.tensor([0.]).to(self.device)
120 | self.critic_loss = torch.tensor([0.]).to(self.device)
121 | metric = {'bc_loss': [], 'distill_loss':[], 'ql_loss': [], 'actor_loss': [], 'critic_loss': [], 'gamma_loss': []}
122 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)
123 |
124 | """ Q Training """
125 | critic_loss = self.q_v_critic_loss(state, action, next_state, reward, not_done)
126 |
127 | self.critic_loss = critic_loss
128 | self.critic_optimizer.zero_grad()
129 | critic_loss.backward()
130 | self.critic_optimizer.step()
131 |
132 |
133 | """ Diffusion Policy Training """
134 | bc_loss = self.diffusion.diffusion_train_step(self.bc_actor, action, state)
135 | self.bc_actor_optimizer.zero_grad()
136 | bc_loss.backward()
137 | self.bc_actor_optimizer.step()
138 | self.bc_loss = bc_loss
139 |
140 |
141 | """Distill Policy Training"""
142 | new_action = self.distill_actor.sample(state)
143 | distill_loss = self.diffusion.diffusion_train_step(self.bc_actor, new_action, state)
144 | q_loss = -self.critic.q_min(state, new_action).mean()
145 |
146 | if self.gamma == 0.:
147 | gamma_loss = torch.tensor([0.]).to(self.device)
148 | else:
149 | gamma_loss = -self.distill_actor.log_prob(state, action).mean()
150 | actor_loss = (self.alpha * distill_loss + q_loss +
151 | self.gamma * gamma_loss)
152 | self.distill_actor_optimizer.zero_grad()
153 | actor_loss.backward()
154 | self.distill_actor_optimizer.step()
155 |
156 | """ Step Target network """
157 | if self.step % self.update_ema_every == 0:
158 | self.step_ema()
159 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
160 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
161 |
162 | """ Record loss """
163 | self.step += 1
164 |
165 | metric['actor_loss'].append(actor_loss.item())
166 | metric['bc_loss'].append(self.bc_loss.item())
167 | metric['ql_loss'].append(q_loss.item())
168 | metric['critic_loss'].append(self.critic_loss.item())
169 | metric['distill_loss'].append(distill_loss.item())
170 | metric['gamma_loss'].append(gamma_loss.item())
171 |
172 | """ Lr decay"""
173 | if self.lr_decay:
174 | self.bc_actor_lr_scheduler.step()
175 | self.distill_actor_lr_scheduler.step()
176 | self.critic_lr_scheduler.step()
177 | return metric
178 |
179 | def sample_action(self, state):
180 | state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
181 | state_rpt = torch.repeat_interleave(state, repeats=self.repeats, dim=0)
182 | with torch.no_grad():
183 | action = self.distill_actor.sample(state_rpt)
184 | q_value = self.critic_target.q_min(state_rpt, action).flatten()
185 | idx = torch.multinomial(F.softmax(q_value, dim=0), 1)
186 | action = action[idx].cpu().data.numpy().flatten()
187 | return action
188 |
189 | def save_model(self, dir, id=None):
190 | if id is not None:
191 | torch.save(self.bc_actor.state_dict(), f'{dir}/bc_actor_{id}.pth')
192 | torch.save(self.critic.state_dict(), f'{dir}/critic_{id}.pth')
193 | torch.save(self.distill_actor.state_dict(), f'{dir}/distill_actor_{id}.pth')
194 | else:
195 | torch.save(self.bc_actor.state_dict(), f'{dir}/bc_actor.pth')
196 | torch.save(self.critic.state_dict(), f'{dir}/critic.pth')
197 | torch.save(self.distill_actor.state_dict(), f'{dir}/distill_actor.pth')
198 |
199 | def load_model(self, dir, id=None):
200 | if id is not None:
201 | self.bc_actor.load_state_dict(torch.load(f'{dir}/bc_actor_{id}.pth'))
202 | self.critic.load_state_dict(torch.load(f'{dir}/critic_{id}.pth'))
203 | print(f"Successfully load critic from {dir}/critic_{id}.pth")
204 | self.distill_actor.load_state_dict(torch.load(f'{dir}/distill_actor_{id}.pth'))
205 | print(f"Successfully load distill actor from {dir}/distill_actor_{id}.pth")
206 | else:
207 | self.bc_actor.load_state_dict(torch.load(f'{dir}/bc_actor.pth'))
208 | self.critic.load_state_dict(torch.load(f'{dir}/critic.pth'))
209 | self.distill_actor.load_state_dict(torch.load(f'{dir}/distill_actor.pth'))
210 | print(f"Models loaded successfully from {dir}")
211 |
212 | def load_or_pretrain_models(self, dir, replay_buffer, batch_size, pretrain_steps,num_steps_per_epoch):
213 | # Paths for the models
214 | actor_path = Path(dir) / f'diffusion_pretrained_{pretrain_steps // num_steps_per_epoch}.pth'
215 | critic_path = Path(dir) / f'critic_pretrained_{pretrain_steps // num_steps_per_epoch}.pth'
216 |
217 | # Check if both models exist
218 | if actor_path.exists() and critic_path.exists():
219 | try:
220 | # Load the models
221 | self.bc_actor.load_state_dict(torch.load(actor_path, map_location=self.device))
222 | self.critic.load_state_dict(torch.load(critic_path, map_location=self.device))
223 | except Exception as e:
224 | print(f"Failed to load models: {e}")
225 | else:
226 | # Begin pretraining if the models do not exist
227 | print("Models not found, starting pretraining...")
228 | self.pretrain(replay_buffer, batch_size, pretrain_steps)
229 | torch.save(self.bc_actor.state_dict(), actor_path)
230 | torch.save(self.critic.state_dict(), critic_path)
231 | print(f"Saved successfully to {dir}")
232 |
233 |
234 | def q_v_critic_loss(self,state,action, next_state, reward, not_done):
235 | def expectile_loss(diff, expectile=0.8):
236 | weight = torch.where(diff > 0, expectile, (1 - expectile))
237 | return weight * (diff ** 2)
238 |
239 | with torch.no_grad():
240 | q = self.critic.q_min(state, action)
241 | v = self.critic.v(state)
242 | value_loss = expectile_loss(q - v, self.expectile).mean()
243 |
244 | current_q1, current_q2 = self.critic(state, action)
245 | with torch.no_grad():
246 | next_v = self.critic.v(next_state)
247 | target_q = (reward + not_done * self.discount * next_v).detach()
248 |
249 | critic_loss = value_loss + F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
250 | return critic_loss
251 |
--------------------------------------------------------------------------------
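A hypothetical end-to-end training sketch for `DTQL`; the actual entry point is `main.py`, whose argument handling and defaults may differ:

```python
# Hypothetical usage sketch, assuming a CUDA machine with D4RL installed
# and that it is run from the repository root.
import gym
import d4rl
from agents.dtql import DTQL
from utils.data_sampler import Data_Sampler

env = gym.make('halfcheetah-medium-v2')
buffer = Data_Sampler(d4rl.qlearning_dataset(env), device='cuda')

agent = DTQL(device='cuda',
             state_dim=env.observation_space.shape[0],
             action_dim=env.action_space.shape[0],
             action_space=env.action_space)

agent.pretrain(buffer, batch_size=256, pretrain_steps=50000)  # behavior-cloning warm-up
for _ in range(1000):                                          # shortened loop for illustration
    metric = agent.train(buffer, batch_size=256)

action = agent.sample_action(env.reset())  # Q-weighted resampling over the one-step policy
```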
/toy_tasks/data_generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.distributions import Normal
4 |
5 |
6 |
7 |
8 | class DataGenerator:
9 | def __init__(self, dist_type: str):
10 | self.dist_type = dist_type
11 | self.func_mapping = {
12 | "two_gmm_1D": (self.two_gmm_1D, self.two_gmm_1D_log_prob),
13 | "uneven_two_gmm_1D": (self.uneven_two_gmm_1D, self.uneven_two_gmm_1D_log_prob),
14 | "three_gmm_1D": (self.three_gmm_1D, self.three_gmm_1D_log_prob),
15 | "single_gaussian_1D": (self.single_gaussian_1D, self.single_gaussian_1D_log_prob),
16 | "swiss_roll_2D": (self.sample_swiss_roll, None),
17 | "25_gaussian": (self.sample_25_gaussian, None)
18 | }
19 | if self.dist_type not in self.func_mapping:
20 | raise ValueError("Invalid distribution type")
21 | self.sample_func, self.log_prob_func = self.func_mapping[self.dist_type]
22 |
23 | def generate_samples(self, num_samples: int):
24 | """
25 | Generate `num_samples` samples and labels using the `sample_func`.
26 |
27 | Args:
28 | num_samples (int): Number of samples to generate.
29 |
30 | Returns:
31 | Tuple[np.ndarray, np.ndarray]: A tuple of two numpy arrays containing the generated samples and labels.
32 | """
33 | samples, labels = self.sample_func(num_samples)
34 | return samples, labels
35 |
36 | def compute_log_prob(self, samples, exp: bool = False):
37 | """
38 | Compute the logarithm of probability density function (pdf) of the given `samples`
39 | using the `log_prob_func`. If `exp` is True, return exponentiated log probability.
40 |
41 | Args:
42 | samples (np.ndarray): Samples for which pdf is to be computed.
43 | exp (bool, optional): If True, return exponentiated log probability.
44 | Default is False.
45 |
46 | Returns:
47 | np.ndarray: Logarithm of probability density function (pdf) of the given `samples`.
48 | If `exp` is True, exponentiated log probability is returned.
49 | """
50 | return self.log_prob_func(samples, exp=exp)
51 |
52 | @staticmethod
53 | def two_gmm_1D(num_samples, ):
54 | """
55 | Generates `num_samples` samples from a 1D mixture of two Gaussians with equal weights.
56 |
57 | Args:
58 | num_samples (int): Number of samples to generate.
59 |
60 | Returns:
61 | Tuple[torch.Tensor, torch.Tensor]: A tuple of two torch tensors containing the generated
62 | samples and binary labels indicating which Gaussian component the sample is from.
63 | """
64 | g1 = Normal(loc=-1.5, scale=0.3)
65 | g2 = Normal(loc=1.5, scale=0.3)
66 | mixture_probs = torch.ones(num_samples) * 0.5
67 | is_from_g1 = torch.bernoulli(mixture_probs).bool()
68 | samples = torch.where(is_from_g1, g1.sample((num_samples,)), g2.sample((num_samples,)))
69 | return samples, is_from_g1.int()
70 |
71 | @staticmethod
72 | def uneven_two_gmm_1D(num_samples, w1=0.7):
73 | """
74 | Generates `num_samples` samples from a 1D mixture of two Gaussians with weights `w1` and `w2`.
75 |
76 | Args:
77 | num_samples (int): Number of samples to generate.
78 | w1 (float, optional): Weight of first Gaussian component. Default is 0.7.
79 |
80 | Returns:
81 | Tuple[torch.Tensor, torch.Tensor]: A tuple of two torch tensors containing the generated
82 | samples and binary labels indicating which Gaussian component the sample is from.
83 | """
84 | g1 = Normal(loc=-1.5, scale=0.3)
85 | g2 = Normal(loc=1.5, scale=0.2)
86 | # each sample comes from g1 with probability w1
87 | is_from_g1 = torch.bernoulli(torch.full((num_samples, 1), w1)).bool()
88 |
89 | samples_g1 = g1.sample((num_samples, 1))
90 | samples_g2 = g2.sample((num_samples, 1))
91 | samples = torch.where(is_from_g1, samples_g1, samples_g2).squeeze()
92 |
93 | return samples, is_from_g1.squeeze().int()
94 |
95 | @staticmethod
96 | def single_gaussian_1D(num_samples):
97 | """
98 | Generates `num_samples` samples from a 1D Gaussian distribution.
99 |
100 | Args:
101 | num_samples (int): Number of samples to generate.
102 |
103 | Returns:
104 | Tuple[torch.Tensor, torch.Tensor]: A tuple of two torch tensors containing the generated
105 | samples and binary labels indicating which Gaussian component the sample is from.
106 | Since there is only one Gaussian component, all labels will be zero.
107 | """
108 | g1 = Normal(loc=1, scale=0.2)
109 | samples = g1.sample((num_samples, 1))
110 | return samples, torch.zeros(num_samples).int()
111 |
112 | @staticmethod
113 | def three_gmm_1D(num_samples):
114 | """
115 | Generates `num_samples` samples from a 1D mixture of three Gaussians with equal weights.
116 |
117 | Args:
118 | num_samples (int): Number of samples to generate.
120 |
121 | Returns:
122 | Tuple[torch.Tensor, torch.Tensor]: A tuple of two torch tensors containing the generated
123 | samples and integer labels indicating which Gaussian component the sample is from.
124 | """
125 | g1 = Normal(loc=-1.5, scale=0.2)
126 | g2 = Normal(loc=0, scale=0.2)
127 | g3 = Normal(loc=1.5, scale=0.2)
128 | mixture_probs = torch.ones(3) / 3
129 | component_assignments = torch.multinomial(mixture_probs, num_samples, replacement=True)
130 | samples = torch.zeros(num_samples, 1)
131 |
132 | g1_mask = (component_assignments == 0)
133 | g2_mask = (component_assignments == 1)
134 | g3_mask = (component_assignments == 2)
135 |
136 | samples[g1_mask] = g1.sample((g1_mask.sum(),)).view(-1, 1)
137 | samples[g2_mask] = g2.sample((g2_mask.sum(),)).view(-1, 1)
138 | samples[g3_mask] = g3.sample((g3_mask.sum(),)).view(-1, 1)
139 |
140 | return samples, component_assignments.int()
141 |
142 | @staticmethod
143 | def two_gmm_1D_log_prob(z, exp=False):
144 | """
145 | Computes the logarithm of the probability density function (pdf) of a 1D mixture of two Gaussians
146 | with equal weights at the given points `z`.
147 |
148 | Args:
149 | z (torch.Tensor): Points at which to compute the pdf.
150 | exp (bool, optional): If True, return exponentiated log probability. Default is False.
151 |
152 | Returns:
153 | torch.Tensor: Logarithm of probability density function (pdf) of a 1D mixture of two Gaussians
154 | with equal weights at the given points `z`. If `exp` is True, exponentiated log probability
155 | is returned.
156 | """
157 | g1 = Normal(loc=-1.5, scale=0.3)
158 | g2 = Normal(loc=1.5, scale=0.3)
159 | f = torch.log(0.5 * (g1.log_prob(z).exp() + g2.log_prob(z).exp()))
160 | if exp:
161 | return torch.exp(f)
162 | else:
163 | return f
164 |
165 | @staticmethod
166 | def uneven_two_gmm_1D_log_prob(z, w1=0.7, exp=False):
167 | """
168 | Computes the logarithm of the probability density function (pdf) of a 1D mixture of two Gaussians
169 | with weights `w1` and `w2` at the given points `z`.
170 |
171 | Args:
172 | z (torch.Tensor): Points at which to compute the pdf.
173 | w1 (float, optional): Weight of first Gaussian component. Default is 0.7.
174 | exp (bool, optional): If True, return exponentiated log probability. Default is False.
175 |
176 | Returns:
177 | torch.Tensor: Logarithm of probability density function (pdf) of a 1D mixture of two Gaussians
178 | with weights `w1` and `w2` at the given points `z`. If `exp` is True, exponentiated log probability
179 | is returned.
180 | """
181 | g1 = Normal(loc=-1.5, scale=0.3)
182 | g2 = Normal(loc=1.5, scale=0.2)
183 | f = torch.log(w1 * g1.log_prob(z).exp() + (1 - w1) * g2.log_prob(z).exp())
184 | if exp:
185 | return torch.exp(f)
186 | else:
187 | return f
188 |
189 | @staticmethod
190 | def three_gmm_1D_log_prob(z, exp=False):
191 | """
192 | Computes the logarithm of the probability density function (pdf) of a 1D mixture of three Gaussians
193 | with equal weights at the given points `z`.
194 |
195 | Args:
196 | z (torch.Tensor): Points at which to compute the pdf.
197 | exp (bool, optional): If True, return exponentiated log probability. Default is False.
198 |
199 | Returns:
200 | torch.Tensor: Logarithm of probability density function (pdf) of a 1D mixture of three Gaussians
201 | with equal weights at the given points `z`. If `exp` is True, exponentiated log probability
202 | is returned.
203 | """
204 | g1 = Normal(loc=-1.5, scale=0.2)
205 | g2 = Normal(loc=0, scale=0.2)
206 | g3 = Normal(loc=1.5, scale=0.2)
207 | f = torch.log(1 / 3 * (g1.log_prob(z).exp() + g2.log_prob(z).exp() + g3.log_prob(z).exp()))
208 | if exp:
209 | return torch.exp(f)
210 | else:
211 | return f
212 |
213 | @staticmethod
214 | def single_gaussian_1D_log_prob(z, exp=False):
215 | """
216 | Computes the logarithm of the probability density function (pdf) of a 1D Gaussian
217 | distribution at the given points `z`.
218 |
219 | Args:
220 | z (torch.Tensor): Points at which to compute the pdf.
221 | exp (bool, optional): If True, return exponentiated log probability. Default is False.
222 |
223 | Returns:
224 | torch.Tensor: Logarithm of probability density function (pdf) of a 1D Gaussian
225 | distribution at the given points `z`. If `exp` is True, exponentiated log probability
226 | is returned.
227 | """
228 | g = Normal(loc=1, scale=0.2)
229 | f = g.log_prob(z)
230 | if exp:
231 | return torch.exp(f)
232 | else:
233 | return f
234 |
235 | @staticmethod
236 | def sample_swiss_roll(num_samples):
237 | from sklearn.datasets import make_swiss_roll
238 | samples = make_swiss_roll(num_samples, noise=0.1)
239 | samples = torch.tensor(samples[0][:, [0, 2]], dtype=torch.float32)
240 |         # labels are just a placeholder (all zeros); the swiss roll has no mixture components
241 | return samples, torch.zeros(num_samples).int()
242 |
243 | @staticmethod
244 | def sample_25_gaussian(num_samples):
245 | num_modes = 25 # Number of Gaussian modes
246 | grid_size = int(np.sqrt(num_modes)) # Determining the grid size (5x5 for 25 modes)
247 |
248 | # Creating a grid of means
249 | x_means = np.linspace(-10, 10, grid_size)
250 | y_means = np.linspace(-10, 10, grid_size)
251 | means = np.array(np.meshgrid(x_means, y_means)).T.reshape(-1, 2)
252 |
253 | # Standard deviation for each mode (can be adjusted as needed)
254 |
255 | std_dev = 0.3
256 | covariance_matrix = np.array([[std_dev ** 2, 0], [0, std_dev ** 2]]) # Diagonal covariance matrix
257 |
258 |         # Generate num_samples // num_modes samples from each mode
259 | samples = np.array(
260 | [np.random.multivariate_normal(mean, covariance_matrix, num_samples // num_modes) for mean in means])
261 | samples = samples.reshape(-1, samples.shape[-1])
262 | samples = torch.from_numpy(samples).type(torch.float32)
263 | return samples, torch.zeros(num_samples).int()
--------------------------------------------------------------------------------
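The samplers above return (samples, labels) pairs, and the matching *_log_prob helpers give the analytic densities of the same mixtures. A minimal sketch of how the two fit together, assuming it is run from toy_tasks/ so that data_generator is importable (the grid bounds below are illustrative):

import torch
from data_generator import DataGenerator

samples, labels = DataGenerator.three_gmm_1D(5000)             # (5000, 1) draws and integer component ids
grid = torch.linspace(-3, 3, 200).reshape(-1, 1)               # evaluation grid for the density
density = DataGenerator.three_gmm_1D_log_prob(grid, exp=True)  # analytic pdf of the same mixture
print(samples.shape, density.max().item())
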
/toy_tasks/toy_main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | from tqdm import tqdm
5 | import sys
6 | sys.path.append("../")
7 |
8 | from diffusion.karras import DiffusionModel
9 | from diffusion.mlps import ScoreNetwork
10 | from data_generator import DataGenerator
11 | from agents.model import DiagGaussianActorTanhAction, Critic
12 | import matplotlib.pyplot as plt
13 | import numpy as np
14 |
15 | import copy
16 | import torch.nn.functional as F
17 | import random
18 |
19 |
20 | def q_v_critic_loss(critic, state, action, next_state, reward, not_done, expectile, discount):
21 | def expectile_loss(diff, expectile=0.8):
22 | weight = torch.where(diff > 0, expectile, (1 - expectile))
23 | return weight * (diff ** 2)
24 |
25 | with torch.no_grad():
26 | q = critic.q_min(state, action)
27 | v = critic.v(state)
28 | value_loss = expectile_loss(q - v, expectile).mean()
29 |
30 | current_q1, current_q2 = critic(state, action)
31 | with torch.no_grad():
32 | next_v = critic.v(next_state)
33 | target_q = (reward + not_done * discount * next_v).detach()
34 |
35 | critic_loss = value_loss + F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
36 | return critic_loss
37 |
38 | def get_dmd_loss(diffusion_model, true_score_model, fake_score_model, fake_action_data, state_data):
39 | noise = torch.randn_like(fake_action_data)
40 | with torch.no_grad():
41 | pred_real_action, _, t_chosen = diffusion_model.diffusion_train_step(model=true_score_model, x=fake_action_data, cond=state_data,
42 | noise=noise, t_chosen=None, return_denoised=True)
43 |
44 | pred_fake_action, _, t_chosen = diffusion_model.diffusion_train_step(model=fake_score_model, x=fake_action_data, cond=state_data,
45 | noise=noise, t_chosen=t_chosen,
46 | return_denoised=True)
47 | weighting_factor = (fake_action_data - pred_real_action).abs().mean(axis=1).reshape(-1, 1)
48 | grad = (pred_fake_action - pred_real_action) / weighting_factor
49 | distill_loss = 0.5 * F.mse_loss(fake_action_data, (fake_action_data - grad).detach())
50 | return distill_loss
51 |
52 | def train_toy_task(args,file_name):
53 | data_manager = DataGenerator(args.env_name)
54 | action, state = data_manager.generate_samples(10000)
55 | device = args.device
56 |
57 |     """Prepare data"""
58 | state = state.to(torch.float32).to(device)
59 | state = state.reshape(-1, 1)
60 | next_state = state.reshape(-1, 1).to(device)
61 | reward = np.linalg.norm(action, axis=1, keepdims=True)
62 |
63 | if args.reward_type == "far":
64 |         # the farther from (0,0), the higher the reward
65 | reward = torch.from_numpy(reward).to(torch.float32).to(device)
66 | elif args.reward_type == "near":
67 |         # the closer to (0,0), the higher the reward
68 | reward = torch.from_numpy(np.max(reward) - reward).to(torch.float32).to(device)
69 |     elif args.reward_type == "hard":  # double the reward in the upper-left corner region
70 | reward[(action[:, 0] < -7.5) & (action[:, 1] > 7.5)] = 2 * reward[(action[:, 0] < -7.5) & (action[:, 1] > 7.5)]
71 | reward = torch.from_numpy(reward).to(torch.float32).to(device)
72 | elif args.reward_type == "same":
73 | reward = torch.from_numpy(np.zeros_like(reward)).to(torch.float32).to(device)
74 |
75 | action = action.to(torch.float32).to(device)
76 | not_done = torch.ones_like(reward).to(torch.float32)
77 |
78 | """Init bc actor"""
79 | bc_actor = ScoreNetwork(
80 | action_dim=2,
81 | hidden_dim=128,
82 | time_embed_dim=4,
83 | cond_dim=1,
84 | cond_mask_prob=0.0,
85 | num_hidden_layers=4,
86 | output_dim=2,
87 | device=device,
88 | cond_conditional=True
89 | ).to(device)
90 | bc_actor_optimizer = torch.optim.Adam(bc_actor.parameters(), lr=3e-3)
91 |
92 | diffusion = DiffusionModel(
93 | sigma_data=args.sigma_data,
94 | sigma_min=args.sigma_min,
95 | sigma_max=args.sigma_max,
96 | device=device,
97 | )
98 |
99 | if args.actor == "sac":
100 | distill_actor = DiagGaussianActorTanhAction(state_dim=1, action_dim=2,
101 | max_action=action.abs().max()).to(device)
102 | elif args.actor == "implicit":
103 | distill_actor = ScoreNetwork(
104 | action_dim=2,
105 | hidden_dim=128,
106 | time_embed_dim=4,
107 | cond_dim=1,
108 | cond_mask_prob=0.0,
109 | num_hidden_layers=4,
110 | output_dim=2,
111 | device=device,
112 | cond_conditional=True
113 | ).to(device)
114 |
115 | if args.pretrain_diffusion:
116 | for _ in tqdm(range(args.pretrain_epochs)):
117 | loss = diffusion.diffusion_train_step(bc_actor, action, state)
118 | bc_actor_optimizer.zero_grad()
119 | loss.backward()
120 | bc_actor_optimizer.step()
121 | bc_actor_state_dict = bc_actor.state_dict()
122 | distill_actor.load_state_dict(bc_actor_state_dict)
123 |
124 | else:
125 |         raise ValueError("Actor type can only be sac or implicit")
126 | distill_actor_optimizer = torch.optim.Adam(distill_actor.parameters(), lr=3e-3)
127 |
128 | critic = Critic(state_dim=1, action_dim=2).to(device)
129 | critic_target = copy.deepcopy(critic)
130 | critic_optimizer = torch.optim.Adam(critic.parameters(), lr=3e-3)
131 |
132 | if args.distill_loss == "dmd":
133 | distill_score = ScoreNetwork(
134 | action_dim=2,
135 | hidden_dim=128,
136 | time_embed_dim=4,
137 | cond_dim=1,
138 | cond_mask_prob=0.0,
139 | num_hidden_layers=4,
140 | output_dim=2,
141 | device=device,
142 | cond_conditional=True
143 | ).to(device)
144 | distill_score_optimizer = torch.optim.Adam(distill_score.parameters(), lr=3e-3)
145 |
146 | def get_action(given_state, action_dim, generation_sigma=2.5):
147 | if args.actor == "sac":
148 | action = distill_actor.sample(state=given_state)
149 | return action
150 | elif args.actor == "implicit":
151 | noise = torch.randn((given_state.shape[0], action_dim)) * generation_sigma
152 | noise = noise.to(given_state.device)
153 | action = distill_actor(noise, given_state, torch.tensor([generation_sigma]).to(given_state.device))
154 | return action
155 | else:
156 | raise ValueError("Actor not correct.")
157 |
158 | pbar = tqdm(range(args.train_epochs))
159 | for i in range(args.train_epochs):
160 | """Q policy"""
161 | critic_loss = q_v_critic_loss(critic,state,action, next_state, reward, not_done, args.expectile, args.discount)
162 | critic_optimizer.zero_grad()
163 | critic_loss.backward()
164 | critic_optimizer.step()
165 |
166 | """BC policy"""
167 | loss = diffusion.diffusion_train_step(bc_actor, action, state)
168 | bc_actor_optimizer.zero_grad()
169 | loss.backward()
170 | bc_actor_optimizer.step()
171 |
172 | """Distill policy"""
173 | new_action = get_action(given_state=state, action_dim=2,
174 | generation_sigma=args.generation_sigma)
175 | q_loss = -critic.q_min(state, new_action).mean()
176 |
177 | if args.distill_loss == "diffusion":
178 | distill_loss = diffusion.diffusion_train_step(bc_actor, new_action, state)
179 | elif args.distill_loss == "dmd":
180 | distill_loss = get_dmd_loss(diffusion, bc_actor, distill_score, new_action, state)
181 | else:
182 | distill_loss = 0
183 |
184 | if args.gamma == 0.:
185 | gamma_loss = 0
186 | else:
187 | gamma_loss = -distill_actor.log_prob(state,action).mean()
188 |
189 | actor_loss = distill_loss + args.eta * q_loss + args.gamma * gamma_loss
190 | distill_actor_optimizer.zero_grad()
191 | actor_loss.backward()
192 | distill_actor_optimizer.step()
193 |
194 | """Train fake score"""
195 | if args.distill_loss == "dmd":
196 | fake_loss = diffusion.diffusion_train_step(distill_score, new_action.detach(), state)
197 | distill_score_optimizer.zero_grad()
198 | fake_loss.backward()
199 | distill_score_optimizer.step()
200 |
201 | """Update critic target"""
202 | for param, target_param in zip(critic.parameters(), critic_target.parameters()):
203 | tau = 0.005
204 | target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
205 |
206 | pbar.set_description(f"Step {i}, BC Loss: {loss:.4f}, Critic Loss: {critic_loss:.4f}, Q loss:{q_loss:.4f}")
207 | pbar.update(1)
208 |
209 |
210 | state = torch.zeros(1000).int().reshape(-1, 1).to(device).to(torch.float32)
211 | with torch.no_grad():
212 | plt.figure(figsize=(8, 6))
213 | scatter = plt.scatter(action.cpu().numpy()[:, 0], action.cpu().numpy()[:, 1], c=reward.cpu().numpy())
214 | plt.colorbar(scatter, label='Reward values')
215 | plt.title(args.env_name)
216 | plt.xlabel('Action Dimension 1')
217 | plt.ylabel('Action Dimension 2')
218 | plt.grid(True)
219 |
220 | new_action = get_action(given_state=state,action_dim=2,
221 | generation_sigma=args.generation_sigma)
222 | new_action = new_action.cpu().numpy()
223 | plt.scatter(new_action[:, 0], new_action[:, 1], c='red',alpha=0.5)
224 | plot_path = file_name + ".png"
225 | plt.savefig(plot_path)
226 |
227 | if __name__ == "__main__":
228 | parser = argparse.ArgumentParser()
229 | ### Experimental Setups ###
230 | parser.add_argument('--device', default=1, type=int)
231 | parser.add_argument('--actor', default="sac", type=str,help="sac or implicit")
232 | parser.add_argument('--distill_loss', default="dmd", type=str, help="diffusion or dmd")
233 | parser.add_argument("--env_name", default="25_gaussian", type=str, help="swiss_roll_2D or 25_gaussian")
234 | parser.add_argument("--pretrain_diffusion", action="store_true")
235 | parser.add_argument("--train_epochs", default=3500, type=int)
236 | parser.add_argument("--pretrain_epochs", default=500, type=int)
237 | parser.add_argument("--reward_type", default="hard", type=str, help="far, near, or hard")
238 | parser.add_argument("--seed", default=1, type=int)
239 | parser.add_argument("--discount", default=0.99, type=float)
240 | parser.add_argument("--eta", default=1, type=float)
241 | parser.add_argument("--generation_sigma", default=2.5, type=float)
242 | parser.add_argument("--gamma", default=0.,type=float,help="weight of sac entropy")
243 | parser.add_argument("--expectile", default=0.95, type=float)
244 |
245 | parser.add_argument("--sigma_max", default=80, type=float)
246 | parser.add_argument("--sigma_min", default=0.002, type=float)
247 | parser.add_argument("--sigma_data", default=0.5, type=float)
248 |
249 | args = parser.parse_args()
250 | output_dir = f"results/{args.env_name}"
251 | if not os.path.exists(output_dir):
252 | os.makedirs(output_dir)
253 | file_name = f"actor={args.actor}|distill={args.distill_loss}|seed={args.seed}|reward={args.reward_type}"
254 | file_name = os.path.join(output_dir,file_name)
255 | if args.actor == "implicit":
256 | file_name += f"|pretrain={args.pretrain_diffusion}"
257 | file_name += f"|eta={args.eta}"
258 | file_name += f"|gamma={args.gamma}"
259 |
260 | def set_seed(seed):
261 | torch.manual_seed(seed)
262 | torch.cuda.manual_seed_all(seed) # For multi-GPU setups
263 | np.random.seed(seed)
264 | random.seed(seed)
265 | torch.backends.cudnn.deterministic = True
266 | torch.backends.cudnn.benchmark = False # May impact performance
267 |
268 | set_seed(args.seed)
269 | train_toy_task(args, file_name)
--------------------------------------------------------------------------------
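q_v_critic_loss above implements an IQL-style value update: residuals where q exceeds v are weighted by `expectile` and the rest by `1 - expectile`, so with expectile > 0.5 the value network is pushed toward an upper expectile of the Q values. A tiny numeric sketch of just that weighting (the residuals below are made up):

import torch

def expectile_loss(diff, expectile=0.8):
    # positive residuals (q > v) get weight `expectile`, negative ones get `1 - expectile`
    weight = torch.where(diff > 0, expectile, 1 - expectile)
    return weight * diff ** 2

diff = torch.tensor([-1.0, 0.5, 2.0])
print(expectile_loss(diff))  # tensor([0.2000, 0.2000, 3.2000])
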
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gym
3 | import numpy as np
4 | import os
5 | import torch
6 | from pathlib import Path
7 |
8 | import d4rl
9 | from utils import utils
10 | from utils.data_sampler import Data_Sampler
11 | from utils.logger import logger, setup_logger
12 | from agents.dtql import DTQL as Agent
13 |
14 | """If you are using DQL-KL, you can specific generation_sigma when init agent"""
15 | #from agents.dql_kl import DQL_KL as Agent
16 | import random
17 |
18 | offline_hyperparameters = {
19 | 'halfcheetah-medium-v2': {'lr': 3e-4, 'alpha': 1.0, 'gamma': 0.0, 'lr_decay': False, 'num_epochs': 1000, 'batch_size': 256, 'expectile': 0.7},
20 | 'halfcheetah-medium-replay-v2': {'lr': 3e-4, 'alpha': 5.0, 'gamma': 0.0, 'lr_decay': False, 'num_epochs': 1000, 'batch_size': 256, 'expectile': 0.7},
21 | 'halfcheetah-medium-expert-v2': {'lr': 3e-4, 'alpha': 50.0, 'gamma': 0.0, 'lr_decay': False, 'num_epochs': 1000, 'batch_size': 256, 'expectile': 0.7},
22 | 'hopper-medium-v2': {'lr': 1e-4, 'alpha': 5.0, 'gamma': 0.0, 'lr_decay': True, 'num_epochs': 1000, 'batch_size': 256, 'expectile': 0.7},
23 | 'hopper-medium-replay-v2': {'lr': 3e-4, 'alpha': 5.0, 'gamma': 0.0, 'lr_decay': False, 'num_epochs': 1000, 'batch_size': 256, 'expectile': 0.7},
24 | 'hopper-medium-expert-v2': {'lr': 3e-4, 'alpha': 20.0, 'gamma': 0.0, 'lr_decay': False, 'num_epochs': 1000, 'batch_size': 256, 'expectile': 0.7},
25 | 'walker2d-medium-v2': {'lr': 3e-4, 'alpha': 5.0, 'gamma': 0.0, 'lr_decay': True, 'num_epochs': 1000, 'batch_size': 256, 'expectile': 0.7},
26 | 'walker2d-medium-replay-v2': {'lr': 3e-4, 'alpha': 5.0, 'gamma': 0.0, 'lr_decay': True, 'num_epochs': 1000, 'batch_size': 256, 'expectile': 0.7},
27 | 'walker2d-medium-expert-v2': {'lr': 3e-4, 'alpha': 5.0, 'gamma': 0.0, 'lr_decay': True, 'num_epochs': 1000, 'batch_size': 256, 'expectile': 0.7},
28 | 'antmaze-umaze-v0': {'lr': 3e-4, 'alpha': 1.0, 'gamma': 1.0, 'lr_decay': False, 'num_epochs': 500, 'batch_size': 2048, 'expectile': 0.9},
29 | 'antmaze-umaze-diverse-v0': {'lr': 3e-5, 'alpha': 1.0, 'gamma': 1.0, 'lr_decay': True, 'num_epochs': 500, 'batch_size': 2048, 'expectile': 0.9},
30 | 'antmaze-medium-play-v0': {'lr': 3e-4, 'alpha': 1.0, 'gamma': 1.0, 'lr_decay': False, 'num_epochs': 400, 'batch_size': 2048, 'expectile': 0.9},
31 | 'antmaze-medium-diverse-v0': {'lr': 3e-4, 'alpha': 1.0, 'gamma': 1.0, 'lr_decay': False, 'num_epochs': 400, 'batch_size': 2048, 'expectile': 0.9},
32 | 'antmaze-large-play-v0': {'lr': 3e-4, 'alpha': 1.0, 'gamma': 1.0, 'lr_decay': False, 'num_epochs': 350, 'batch_size': 2048, 'expectile': 0.9},
33 | 'antmaze-large-diverse-v0': {'lr': 3e-4, 'alpha': 0.5, 'gamma': 1.0, 'lr_decay': False, 'num_epochs': 300, 'batch_size': 2048, 'expectile': 0.9},
34 | 'antmaze-umaze-v2': {'lr': 3e-4, 'alpha': 1.0, 'gamma': 1.0, 'lr_decay': False, 'num_epochs': 500, 'batch_size': 2048, 'expectile': 0.9},
35 | 'antmaze-umaze-diverse-v2': {'lr': 3e-5, 'alpha': 1.0, 'gamma': 1.0, 'lr_decay': True, 'num_epochs': 500, 'batch_size': 2048, 'expectile': 0.9},
36 | 'antmaze-medium-play-v2': {'lr': 3e-4, 'alpha': 1.0, 'gamma': 1.0, 'lr_decay': False, 'num_epochs': 400, 'batch_size': 2048, 'expectile': 0.9},
37 | 'antmaze-medium-diverse-v2': {'lr': 3e-4, 'alpha': 1.0, 'gamma': 1.0, 'lr_decay': False, 'num_epochs': 400, 'batch_size': 2048, 'expectile': 0.9},
38 | 'antmaze-large-play-v2': {'lr': 3e-4, 'alpha': 1.0, 'gamma': 1.0, 'lr_decay': False, 'num_epochs': 350, 'batch_size': 2048, 'expectile': 0.9},
39 | 'antmaze-large-diverse-v2': {'lr': 3e-4, 'alpha': 0.5, 'gamma': 1.0, 'lr_decay': False, 'num_epochs': 300, 'batch_size': 2048, 'expectile': 0.9},
40 | 'pen-human-v1': {'lr': 3e-5, 'alpha': 1500.0, 'gamma': 0.0, 'lr_decay': True, 'num_epochs': 300, 'batch_size': 256, 'expectile': 0.9},
41 | 'pen-cloned-v1': {'lr': 1e-5, 'alpha': 1500.0, 'gamma': 0.0, 'lr_decay': False, 'num_epochs': 200, 'batch_size': 256, 'expectile': 0.7},
42 | 'kitchen-complete-v0': {'lr': 1e-4, 'alpha': 200.0, 'gamma': 0.0, 'lr_decay': True, 'num_epochs': 500, 'batch_size': 256, 'expectile': 0.7},
43 | 'kitchen-partial-v0': {'lr': 1e-4, 'alpha': 100.0, 'gamma': 0.0, 'lr_decay': True, 'num_epochs': 1000, 'batch_size': 256, 'expectile': 0.7},
44 | 'kitchen-mixed-v0': {'lr': 3e-4, 'alpha': 200.0, 'gamma': 0.0, 'lr_decay': True, 'num_epochs': 500, 'batch_size': 256, 'expectile': 0.7},
45 | }
46 |
47 | def train_agent(env, state_dim, action_dim, device, output_dir, args):
48 | dataset = d4rl.qlearning_dataset(env)
49 | data_sampler = Data_Sampler(dataset, device, args.reward_tune)
50 | utils.print_banner('Loaded buffer')
51 |
52 | agent = Agent(state_dim=state_dim,
53 | action_dim=action_dim,
54 | action_space=env.action_space,
55 | device=device,
56 | discount=args.discount,
57 | lr=args.lr,
58 | alpha=args.alpha,
59 | lr_decay=args.lr_decay,
60 | lr_maxt=args.num_epochs*args.num_steps_per_epoch,
61 | expectile=args.expectile,
62 | sigma_data=args.sigma_data,
63 | sigma_max=args.sigma_max,
64 | sigma_min=args.sigma_min,
65 | tau=args.tau,
66 | gamma=args.gamma,
67 | repeats=args.repeats)
68 | if args.pretrain_epochs is not None:
69 | agent.load_or_pretrain_models(
70 | dir=str(Path(output_dir)),
71 | replay_buffer=data_sampler,
72 | batch_size=args.batch_size,
73 | pretrain_steps=args.pretrain_epochs*args.num_steps_per_epoch,
74 | num_steps_per_epoch=args.num_steps_per_epoch)
75 |
76 | training_iters = 0
77 | max_timesteps = args.num_epochs * args.num_steps_per_epoch
78 | log_interval = int(args.eval_freq * args.num_steps_per_epoch)
79 |
80 | utils.print_banner(f"Training Start", separator="*", num_star=90)
81 | while (training_iters < max_timesteps + 1):
82 | curr_epoch = int(training_iters // int(args.num_steps_per_epoch))
83 | env.reset()
84 | loss_metric = agent.train(replay_buffer=data_sampler,
85 | batch_size=args.batch_size)
86 | training_iters += 1
87 | # Logging
88 | if training_iters % log_interval == 0:
89 | if loss_metric is not None:
90 | utils.print_banner(f"Train step: {training_iters}", separator="*", num_star=90)
91 | logger.record_tabular('Trained Epochs', curr_epoch)
92 | logger.record_tabular('BC Loss', np.mean(loss_metric['bc_loss']))
93 | logger.record_tabular('QL Loss', np.mean(loss_metric['ql_loss']))
94 | logger.record_tabular('Distill Loss', np.mean(loss_metric['distill_loss']))
95 | logger.record_tabular('Actor Loss', np.mean(loss_metric['actor_loss']))
96 | logger.record_tabular('Critic Loss', np.mean(loss_metric['critic_loss']))
97 | logger.record_tabular('Gamma Loss', np.mean(loss_metric['gamma_loss']))
98 |
99 | # Evaluating
100 | eval_res, eval_res_std, eval_norm_res, eval_norm_res_std = eval_policy(agent,
101 | args.env_name,
102 | args.seed,
103 | eval_episodes=args.eval_episodes)
104 | logger.record_tabular('Average Episodic Reward', eval_res)
105 | logger.record_tabular('Average Episodic N-Reward', eval_norm_res)
106 | logger.record_tabular('Average Episodic N-Reward Std', eval_norm_res_std)
107 | logger.dump_tabular()
108 |
109 | if args.save_checkpoints:
110 | agent.save_model(output_dir, curr_epoch)
111 | agent.save_model(output_dir, curr_epoch)
112 |
113 |
114 | # Runs policy for [eval_episodes] episodes and returns average reward
115 | # A fixed seed is used for the eval environment
116 | def eval_policy(policy, env_name, seed, eval_episodes=10):
117 | eval_env = gym.make(env_name)
118 | eval_env.seed(seed + 100)
119 |
120 | scores = []
121 | for _ in range(eval_episodes):
122 | traj_return = 0.
123 | state, done = eval_env.reset(), False
124 | while not done:
125 | action = policy.sample_action(np.array(state))
126 | state, reward, done, _ = eval_env.step(action)
127 | traj_return += reward
128 | scores.append(traj_return)
129 |
130 | avg_reward = np.mean(scores)
131 | std_reward = np.std(scores)
132 |
133 | normalized_scores = [eval_env.get_normalized_score(s) for s in scores]
134 | avg_norm_score = eval_env.get_normalized_score(avg_reward)
135 | std_norm_score = np.std(normalized_scores)
136 |
137 | utils.print_banner(f"Evaluation over {eval_episodes} episodes: {avg_reward:.2f} {avg_norm_score:.2f}")
138 | return avg_reward, std_reward, avg_norm_score, std_norm_score
139 |
140 |
141 | if __name__ == "__main__":
142 | parser = argparse.ArgumentParser()
143 | ### Experimental Setups ###
144 | parser.add_argument('--device', default=1, type=int)
145 | parser.add_argument("--env_name", default="antmaze-large-diverse-v0", type=str, help='Mujoco Gym environment')
146 |     parser.add_argument("--seed", default=1, type=int, help='random seed (default: 1)')
147 | parser.add_argument("--eval_freq", default=50, type=int)
148 | parser.add_argument("--dir", default="results", type=str)
149 |
150 | parser.add_argument("--pretrain_epochs", default=50, type=int)
151 | parser.add_argument("--repeats", default=1024, type=int)
152 | parser.add_argument("--tau", default=0.005, type=float)
153 |
154 |     parser.add_argument("--sigma_max", default=80, type=float)
155 |     parser.add_argument("--sigma_min", default=0.002, type=float)
156 |     parser.add_argument("--sigma_data", default=0.5, type=float)
157 | parser.add_argument('--save_checkpoints', action='store_true')
158 |
159 | parser.add_argument("--num_steps_per_epoch", default=1000, type=int)
160 | parser.add_argument("--discount", default=0.99, type=float, help='discount factor for reward (default: 0.99)')
161 |
162 | args = parser.parse_args()
163 | args.device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"
164 | args.output_dir = f'{args.dir}'
165 |
166 | if 'antmaze' in args.env_name:
167 | args.reward_tune = 'iql_antmaze'
168 | args.eval_episodes = 100
169 | else:
170 | args.reward_tune = 'no'
171 | args.eval_episodes = 10 if 'v2' in args.env_name else 100
172 |
173 | args.num_epochs = offline_hyperparameters[args.env_name]["num_epochs"]
174 | args.lr = offline_hyperparameters[args.env_name]["lr"]
175 | args.lr_decay = offline_hyperparameters[args.env_name]["lr_decay"]
176 | args.batch_size = offline_hyperparameters[args.env_name]["batch_size"]
177 | args.alpha = offline_hyperparameters[args.env_name]["alpha"]
178 | args.gamma = offline_hyperparameters[args.env_name]["gamma"]
179 | args.expectile = offline_hyperparameters[args.env_name]["expectile"]
180 |
181 | file_name = f'|expect-{args.expectile}'
182 | file_name += f"|alpha-{args.alpha}|gamma-{args.gamma}"
183 | file_name += f'|seed={args.seed}'
184 | file_name += f'|lr={args.lr}'
185 | if args.lr_decay:
186 | file_name += f'|lr_decay'
187 | if args.pretrain_epochs is not None:
188 | file_name += f'|pretrain={args.pretrain_epochs}'
189 | #file_name += f'|{args.env_name}'
190 |
191 | results_dir = os.path.join(args.output_dir, args.env_name, file_name)
192 | if not os.path.exists(results_dir):
193 | os.makedirs(results_dir)
194 | utils.print_banner(f"Saving location: {results_dir}")
195 |
196 | variant = vars(args)
197 | variant.update(version=f"DTQL")
198 |
199 | env = gym.make(args.env_name)
200 |
201 | env.seed(args.seed)
202 | torch.manual_seed(args.seed)
203 | np.random.seed(args.seed)
204 | random.seed(args.seed)
205 | torch.cuda.manual_seed_all(args.seed)
206 | torch.backends.cudnn.deterministic = True
207 | torch.backends.cudnn.benchmark = False
208 |
209 |
210 | state_dim = env.observation_space.shape[0]
211 | action_dim = env.action_space.shape[0]
212 |
213 | variant.update(state_dim=state_dim)
214 | variant.update(action_dim=action_dim)
215 | setup_logger(os.path.basename(results_dir), variant=variant, log_dir=results_dir)
216 | utils.print_banner(f"Env: {args.env_name}, state_dim: {state_dim}, action_dim: {action_dim}")
217 |
218 | train_agent(env,
219 | state_dim,
220 | action_dim,
221 | args.device,
222 | results_dir,
223 | args)
224 |
--------------------------------------------------------------------------------
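Only a handful of generic flags are exposed on the command line above; everything environment-specific is looked up in offline_hyperparameters and copied onto args. A minimal sketch of that lookup (importing main is safe because the CLI code is guarded by __main__, though it still pulls in gym/d4rl and therefore needs the full environment installed):

from main import offline_hyperparameters

hp = offline_hyperparameters['antmaze-large-diverse-v2']
print(hp['lr'], hp['alpha'], hp['gamma'], hp['expectile'])  # 0.0003 0.5 1.0 0.9
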
/agents/dql_kl.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | from torch.optim.lr_scheduler import CosineAnnealingLR
4 | import torch.nn.functional as F
5 | from pathlib import Path
6 | from diffusion.karras import DiffusionModel
7 | from diffusion.mlps import ScoreNetwork
8 | from agents.model import Critic
9 | from agents.helpers import EMA, get_dmd_loss
10 | import numpy as np
11 | class DQL_KL(object):
12 | def __init__(self,
13 | device,
14 | state_dim,
15 | action_dim,
16 | action_space=None,
17 | discount=0.99,
18 | alpha=1.0,
19 | ema_decay=0.995,
20 | step_start_ema=1000,
21 | update_ema_every=5,
22 | lr=3e-4,
23 | lr_decay=False,
24 | lr_maxt=1000,
25 | sigma_max=80.,
26 | sigma_min=0.002,
27 | sigma_data=0.5,
28 | generation_sigma=2.5,
29 | expectile=0.7,
30 | tau = 0.005,
31 | gamma=0,
32 | repeats=1024
33 | ):
34 | """Init critic networks"""
35 | self.critic = Critic(state_dim, action_dim).to(device)
36 | self.critic_target = copy.deepcopy(self.critic)
37 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)
38 |
39 |         """Init behaviour cloning network"""
40 | self.bc_actor = ScoreNetwork(
41 | action_dim=action_dim,
42 | hidden_dim=256,
43 | time_embed_dim=16,
44 | cond_dim=state_dim,
45 | cond_mask_prob=0.0,
46 | num_hidden_layers=4,
47 | output_dim=action_dim,
48 | device=device,
49 | cond_conditional=True
50 | ).to(device)
51 | self.bc_actor_target = copy.deepcopy(self.bc_actor)
52 | self.bc_actor_optimizer = torch.optim.Adam(self.bc_actor.parameters(), lr=lr)
53 |
54 | """Init diffusion schedule"""
55 | self.diffusion = DiffusionModel(
56 | sigma_data=sigma_data,
57 | sigma_min=sigma_min,
58 | sigma_max=sigma_max,
59 | device=device,
60 | clip_denoised=True,
61 | max_action=float(action_space.high[0]))
62 |
63 | """Init one-step policy"""
64 | self.distill_actor = ScoreNetwork(
65 | action_dim=action_dim,
66 | hidden_dim=256,
67 | time_embed_dim=16,
68 | cond_dim=state_dim,
69 | cond_mask_prob=0.0,
70 | num_hidden_layers=4,
71 | output_dim=action_dim,
72 | device=device,
73 | cond_conditional=True
74 | ).to(device)
75 | self.distill_actor_target = copy.deepcopy(self.distill_actor)
76 | self.distill_actor_optimizer = torch.optim.Adam(self.distill_actor.parameters(), lr=lr)
77 |
78 | """Init fake score network"""
79 | self.fake_score = ScoreNetwork(
80 | action_dim=action_dim,
81 | hidden_dim=256,
82 | time_embed_dim=16,
83 | cond_dim=state_dim,
84 | cond_mask_prob=0.0,
85 | num_hidden_layers=4,
86 | output_dim=action_dim,
87 | device=device,
88 | cond_conditional=True
89 | ).to(device)
90 | self.fake_score_optimizer = torch.optim.Adam(self.fake_score.parameters(), lr=lr)
91 |
92 | """Back up training parameters"""
93 | self.generation_sigma = generation_sigma
94 | self.tau = tau
95 | self.lr_decay = lr_decay
96 | self.gamma = gamma
97 | self.repeats = repeats
98 |
99 | self.step = 0
100 | self.step_start_ema = step_start_ema
101 | self.ema = EMA(ema_decay)
102 | self.update_ema_every = update_ema_every
103 |
104 | if lr_decay:
105 | self.critic_lr_scheduler = CosineAnnealingLR(self.critic_optimizer, T_max=lr_maxt, eta_min=0.)
106 | self.bc_actor_lr_scheduler = CosineAnnealingLR(self.bc_actor_optimizer, T_max=lr_maxt, eta_min=0.)
107 | self.distill_actor_lr_scheduler = CosineAnnealingLR(self.distill_actor_optimizer, T_max=lr_maxt, eta_min=0.)
108 | self.fake_score_lr_scheduler = CosineAnnealingLR(self.fake_score_optimizer, T_max=lr_maxt, eta_min=0.)
109 |
110 | self.state_dim = state_dim
111 | self.action_dim = action_dim
112 | self.discount = discount
113 | self.alpha = alpha # bc weight
114 | self.expectile = expectile
115 | self.device = device
116 |
117 | def step_ema(self):
118 | if self.step < self.step_start_ema:
119 | return
120 | self.ema.update_model_average(self.distill_actor_target, self.distill_actor)
121 |
122 |     def pretrain(self, replay_buffer, batch_size=256, pretrain_steps=50000):
123 | for _ in range(pretrain_steps):
124 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)
125 | loss = self.diffusion.diffusion_train_step(self.bc_actor, action, state)
126 | self.bc_actor_optimizer.zero_grad()
127 | loss.backward()
128 | self.bc_actor_optimizer.step()
129 | self.bc_loss = loss
130 |
131 | critic_loss = self.q_v_critic_loss(state,action, next_state, reward, not_done)
132 | self.critic_optimizer.zero_grad()
133 | critic_loss.backward()
134 | self.critic_optimizer.step()
135 |
136 | bc_actor_state_dict = self.bc_actor.state_dict()
137 | self.distill_actor.load_state_dict(bc_actor_state_dict)
138 | self.fake_score.load_state_dict(bc_actor_state_dict)
139 |     def load_or_pretrain_models(self, dir, replay_buffer, batch_size, pretrain_steps, num_steps_per_epoch):
140 | # Paths for the models
141 | actor_path = Path(dir) / f'diffusion_pretrained_{pretrain_steps // num_steps_per_epoch}.pth'
142 | critic_path = Path(dir) / f'critic_pretrained_{pretrain_steps // num_steps_per_epoch}.pth'
143 |
144 | # Check if both models exist
145 | if actor_path.exists() and critic_path.exists():
146 | try:
147 | # Load the models
148 | self.bc_actor.load_state_dict(torch.load(actor_path, map_location=self.device))
149 | self.critic.load_state_dict(torch.load(critic_path, map_location=self.device))
150 | bc_actor_state_dict = self.bc_actor.state_dict()
151 | self.distill_actor.load_state_dict(bc_actor_state_dict)
152 |                 self.fake_score.load_state_dict(torch.load(actor_path, map_location=self.device))
153 | except Exception as e:
154 | print(f"Failed to load models: {e}")
155 | else:
156 | # Begin pretraining if the models do not exist
157 | print("Models not found, starting pretraining...")
158 | self.pretrain(replay_buffer, batch_size, pretrain_steps)
159 | torch.save(self.bc_actor.state_dict(), actor_path)
160 | torch.save(self.critic.state_dict(), critic_path)
161 | print(f"Saved successfully to {dir}")
162 |
163 | def train(self, replay_buffer, batch_size=256):
164 | # initialize
165 | self.bc_loss = torch.tensor([0.]).to(self.device)
166 | self.critic_loss = torch.tensor([0.]).to(self.device)
167 | metric = {'bc_loss': [], 'distill_loss':[], 'ql_loss': [], 'actor_loss': [], 'critic_loss': [], 'gamma_loss': []}
168 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)
169 |
170 | """ Q Training """
171 | critic_loss = self.q_v_critic_loss(state, action, next_state, reward, not_done)
172 |
173 | self.critic_loss = critic_loss
174 | self.critic_optimizer.zero_grad()
175 | critic_loss.backward()
176 | self.critic_optimizer.step()
177 |
178 | """ Diffusion Policy Training """
179 | bc_loss = self.diffusion.diffusion_train_step(self.bc_actor, action, state)
180 | self.bc_actor_optimizer.zero_grad()
181 | bc_loss.backward()
182 | self.bc_actor_optimizer.step()
183 | self.bc_loss = bc_loss
184 |
185 | """Distill Policy Training"""
186 | new_action = self.get_agent_sample(self.distill_actor, given_state=state,
187 | generation_sigma=self.generation_sigma)
188 | distill_loss = get_dmd_loss(self.diffusion,self.bc_actor,self.fake_score,new_action,state)
189 | q_loss = -self.critic.q_min(state, new_action).mean()
190 |
191 | actor_loss = self.alpha * distill_loss + q_loss
192 | self.distill_actor_optimizer.zero_grad()
193 | actor_loss.backward()
194 | self.distill_actor_optimizer.step()
195 |
196 | """Training fake score"""
197 | fake_score_loss = self.diffusion.diffusion_train_step(self.fake_score, new_action.detach(), state)
198 | self.fake_score_optimizer.zero_grad()
199 | fake_score_loss.backward()
200 | self.fake_score_optimizer.step()
201 |
202 | """ Step Target network """
203 | if self.step % self.update_ema_every == 0:
204 | self.step_ema()
205 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
206 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
207 |
208 | self.step += 1
209 |
210 | metric['actor_loss'].append(actor_loss.item())
211 | metric['bc_loss'].append(self.bc_loss.item())
212 | metric['ql_loss'].append(q_loss.item())
213 | metric['critic_loss'].append(self.critic_loss.item())
214 | metric['distill_loss'].append(distill_loss.item())
215 | metric['gamma_loss'].append(np.nan)
216 |
217 | if self.lr_decay:
218 | self.bc_actor_lr_scheduler.step()
219 | self.distill_actor_lr_scheduler.step()
220 | self.critic_lr_scheduler.step()
221 | self.fake_score_lr_scheduler.step()
222 | return metric
223 |
224 | def sample_action(self, state):
225 | state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
226 | state_rpt = torch.repeat_interleave(state, repeats=self.repeats, dim=0)
227 | with torch.no_grad():
228 | action = self.get_agent_sample(self.distill_actor, given_state=state_rpt,
229 | generation_sigma=self.generation_sigma
230 | )
231 | q_value = self.critic_target.q_min(state_rpt, action).flatten()
232 |             idx = torch.multinomial(F.softmax(q_value, dim=0), 1)
233 | action = action[idx].cpu().data.numpy().flatten()
234 | return action
235 |
236 | def save_model(self, dir, id=None):
237 | if id is not None:
238 | torch.save(self.bc_actor.state_dict(), f'{dir}/bc_actor_{id}.pth')
239 | torch.save(self.critic.state_dict(), f'{dir}/critic_{id}.pth')
240 | torch.save(self.distill_actor.state_dict(), f'{dir}/distill_actor_{id}.pth')
241 | torch.save(self.fake_score.state_dict(), f'{dir}/fake_score_{id}.pth')
242 | else:
243 | torch.save(self.bc_actor.state_dict(), f'{dir}/actor.pth')
244 | torch.save(self.critic.state_dict(), f'{dir}/critic.pth')
245 | torch.save(self.distill_actor.state_dict(), f'{dir}/distill_actor.pth')
246 | torch.save(self.fake_score.state_dict(), f'{dir}/fake_score.pth')
247 |
248 | def load_model(self, dir, id=None):
249 | if id is not None:
250 | self.bc_actor.load_state_dict(torch.load(f'{dir}/bc_actor_{id}.pth'))
251 | self.critic.load_state_dict(torch.load(f'{dir}/critic_{id}.pth'))
252 | self.distill_actor.load_state_dict(torch.load(f'{dir}/distill_actor_{id}.pth'))
253 | self.fake_score.load_state_dict(torch.load(f'{dir}/fake_score_{id}.pth'))
254 | else:
255 | self.bc_actor.load_state_dict(torch.load(f'{dir}/bc_actor.pth'))
256 | self.critic.load_state_dict(torch.load(f'{dir}/critic.pth'))
257 | self.distill_actor.load_state_dict(torch.load(f'{dir}/distill_actor.pth'))
258 | self.fake_score.load_state_dict(torch.load(f'{dir}/fake_score.pth'))
259 | print(f"Models loaded successfully from {dir}")
260 |
261 |
262 | def get_agent_sample(self,model,given_state,generation_sigma):
263 | noise = torch.randn((given_state.shape[0],self.action_dim)).to(given_state.device) * generation_sigma
264 | action = model(noise, given_state, torch.tensor([generation_sigma]).to(given_state.device))
265 | return action
266 |
267 | def q_v_critic_loss(self,state,action, next_state, reward, not_done):
268 | def expectile_loss(diff, expectile=0.8):
269 | weight = torch.where(diff > 0, expectile, (1 - expectile))
270 | return weight * (diff ** 2)
271 |
272 | with torch.no_grad():
273 | q = self.critic.q_min(state, action)
274 | v = self.critic.v(state)
275 | value_loss = expectile_loss(q - v, self.expectile).mean()
276 |
277 | current_q1, current_q2 = self.critic(state, action)
278 | with torch.no_grad():
279 | next_v = self.critic.v(next_state)
280 | target_q = (reward + not_done * self.discount * next_v).detach()
281 |
282 | critic_loss = value_loss + F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
283 | return critic_loss
284 |
--------------------------------------------------------------------------------
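At evaluation time, DQL_KL.sample_action draws `repeats` candidate actions from the one-step policy and then keeps a single one, sampled with probability proportional to the softmax of the target critic's minimum Q values. A small sketch of just that resampling step, with made-up Q values:

import torch
import torch.nn.functional as F

q_value = torch.tensor([0.1, 2.0, -1.0, 0.5])  # illustrative q_min values for 4 candidate actions
probs = F.softmax(q_value, dim=0)              # higher-Q candidates become more likely
idx = torch.multinomial(probs, 1)              # sample one candidate index
print(idx.item(), probs.tolist())
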
/utils/logger.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from contextlib import contextmanager
3 | import numpy as np
4 | import os.path as osp
5 | import sys
6 | import datetime
7 | import dateutil.tz
8 | import csv
9 | import json
10 | import pickle
11 | import errno
12 | from collections import OrderedDict
13 | from numbers import Number
14 | import os
15 |
16 | from tabulate import tabulate
17 |
18 | def dict_to_safe_json(d):
19 | """
20 | Convert each value in the dictionary into a JSON'able primitive.
21 | :param d:
22 | :return:
23 | """
24 | new_d = {}
25 | for key, item in d.items():
26 | if safe_json(item):
27 | new_d[key] = item
28 | else:
29 | if isinstance(item, dict):
30 | new_d[key] = dict_to_safe_json(item)
31 | else:
32 | new_d[key] = str(item)
33 | return new_d
34 |
35 |
36 | def safe_json(data):
37 | if data is None:
38 | return True
39 | elif isinstance(data, (bool, int, float)):
40 | return True
41 | elif isinstance(data, (tuple, list)):
42 | return all(safe_json(x) for x in data)
43 | elif isinstance(data, dict):
44 | return all(isinstance(k, str) and safe_json(v) for k, v in data.items())
45 | return False
46 |
47 | def create_exp_name(exp_prefix, exp_id=0, seed=0):
48 | """
49 | Create a semi-unique experiment name that has a timestamp
50 | :param exp_prefix:
51 | :param exp_id:
52 | :return:
53 | """
54 | now = datetime.datetime.now(dateutil.tz.tzlocal())
55 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
56 | return "%s_%s_%04d--s-%d" % (exp_prefix, timestamp, exp_id, seed)
57 |
58 | def create_log_dir(
59 | exp_prefix,
60 | exp_id=0,
61 | seed=0,
62 | base_log_dir=None,
63 | include_exp_prefix_sub_dir=True,
64 | ):
65 | """
66 | Creates and returns a unique log directory.
67 | :param exp_prefix: All experiments with this prefix will have log
68 | directories be under this directory.
69 | :param exp_id: The number of the specific experiment run within this
70 | experiment.
71 | :param base_log_dir: The directory where all log should be saved.
72 | :return:
73 | """
74 | exp_name = create_exp_name(exp_prefix, exp_id=exp_id,
75 | seed=seed)
76 | if base_log_dir is None:
77 | base_log_dir = './data'
78 | if include_exp_prefix_sub_dir:
79 | log_dir = osp.join(base_log_dir, exp_prefix.replace("_", "-"), exp_name)
80 | else:
81 | log_dir = osp.join(base_log_dir, exp_name)
82 | if osp.exists(log_dir):
83 | print("WARNING: Log directory already exists {}".format(log_dir), flush=True)
84 | os.makedirs(log_dir, exist_ok=True)
85 | return log_dir
86 |
87 |
88 | def setup_logger(
89 | exp_prefix="default",
90 | variant=None,
91 | text_log_file="debug.log",
92 | variant_log_file="variant.json",
93 | tabular_log_file="progress.csv",
94 | snapshot_mode="last",
95 | snapshot_gap=1,
96 | log_tabular_only=False,
97 | log_dir=None,
98 | git_infos=None,
99 | script_name=None,
100 | **create_log_dir_kwargs
101 | ):
102 | """
103 | Set up logger to have some reasonable default settings.
104 | Will save log output to
105 |     base_log_dir/exp_prefix/exp_name.
106 | exp_name will be auto-generated to be unique.
107 | If log_dir is specified, then that directory is used as the output dir.
108 | :param exp_prefix: The sub-directory for this specific experiment.
109 | :param variant:
110 | :param text_log_file:
111 | :param variant_log_file:
112 | :param tabular_log_file:
113 | :param snapshot_mode:
114 | :param log_tabular_only:
115 | :param snapshot_gap:
116 | :param log_dir:
117 | :param git_infos:
118 | :param script_name: If set, save the script name to this.
119 | :return:
120 | """
121 | first_time = log_dir is None
122 | if first_time:
123 | log_dir = create_log_dir(exp_prefix, **create_log_dir_kwargs)
124 |
125 | if variant is not None:
126 | logger.log("Variant:")
127 | logger.log(json.dumps(dict_to_safe_json(variant), indent=2))
128 | variant_log_path = osp.join(log_dir, variant_log_file)
129 | logger.log_variant(variant_log_path, variant)
130 |
131 | tabular_log_path = osp.join(log_dir, tabular_log_file)
132 | text_log_path = osp.join(log_dir, text_log_file)
133 |
134 | logger.add_text_output(text_log_path)
135 | if first_time:
136 | logger.add_tabular_output(tabular_log_path)
137 | else:
138 | logger._add_output(tabular_log_path, logger._tabular_outputs,
139 | logger._tabular_fds, mode='a')
140 | for tabular_fd in logger._tabular_fds:
141 | logger._tabular_header_written.add(tabular_fd)
142 | logger.set_snapshot_dir(log_dir)
143 | logger.set_snapshot_mode(snapshot_mode)
144 | logger.set_snapshot_gap(snapshot_gap)
145 | logger.set_log_tabular_only(log_tabular_only)
146 | exp_name = log_dir.split("/")[-1]
147 | logger.push_prefix("[%s] " % exp_name)
148 |
149 | if script_name is not None:
150 | with open(osp.join(log_dir, "script_name.txt"), "w") as f:
151 | f.write(script_name)
152 | return log_dir
153 |
154 |
155 | def create_stats_ordered_dict(
156 | name,
157 | data,
158 | stat_prefix=None,
159 | always_show_all_stats=True,
160 | exclude_max_min=False,
161 | ):
162 | if stat_prefix is not None:
163 | name = "{}{}".format(stat_prefix, name)
164 | if isinstance(data, Number):
165 | return OrderedDict({name: data})
166 |
167 | if len(data) == 0:
168 | return OrderedDict()
169 |
170 | if isinstance(data, tuple):
171 | ordered_dict = OrderedDict()
172 | for number, d in enumerate(data):
173 | sub_dict = create_stats_ordered_dict(
174 | "{0}_{1}".format(name, number),
175 | d,
176 | )
177 | ordered_dict.update(sub_dict)
178 | return ordered_dict
179 |
180 | if isinstance(data, list):
181 | try:
182 | iter(data[0])
183 | except TypeError:
184 | pass
185 | else:
186 | data = np.concatenate(data)
187 |
188 | if (isinstance(data, np.ndarray) and data.size == 1
189 | and not always_show_all_stats):
190 | return OrderedDict({name: float(data)})
191 |
192 | stats = OrderedDict([
193 | (name + ' Mean', np.mean(data)),
194 | (name + ' Std', np.std(data)),
195 | ])
196 | if not exclude_max_min:
197 | stats[name + ' Max'] = np.max(data)
198 | stats[name + ' Min'] = np.min(data)
199 | return stats
200 |
201 |
202 | class TerminalTablePrinter(object):
203 | def __init__(self):
204 | self.headers = None
205 | self.tabulars = []
206 |
207 | def print_tabular(self, new_tabular):
208 | if self.headers is None:
209 | self.headers = [x[0] for x in new_tabular]
210 | else:
211 | assert len(self.headers) == len(new_tabular)
212 | self.tabulars.append([x[1] for x in new_tabular])
213 | self.refresh()
214 |
215 | def refresh(self):
216 | import os
217 | rows, columns = os.popen('stty size', 'r').read().split()
218 | tabulars = self.tabulars[-(int(rows) - 3):]
219 | sys.stdout.write("\x1b[2J\x1b[H")
220 | sys.stdout.write(tabulate(tabulars, self.headers))
221 | sys.stdout.write("\n")
222 |
223 |
224 | class MyEncoder(json.JSONEncoder):
225 | def default(self, o):
226 | if isinstance(o, type):
227 | return {'$class': o.__module__ + "." + o.__name__}
228 | elif isinstance(o, Enum):
229 | return {
230 | '$enum': o.__module__ + "." + o.__class__.__name__ + '.' + o.name
231 | }
232 | elif callable(o):
233 | return {
234 | '$function': o.__module__ + "." + o.__name__
235 | }
236 | return json.JSONEncoder.default(self, o)
237 |
238 |
239 | def mkdir_p(path):
240 | try:
241 | os.makedirs(path)
242 | except OSError as exc: # Python >2.5
243 | if exc.errno == errno.EEXIST and os.path.isdir(path):
244 | pass
245 | else:
246 | raise
247 |
248 |
249 | class Logger(object):
250 | def __init__(self):
251 | self._prefixes = []
252 | self._prefix_str = ''
253 |
254 | self._tabular_prefixes = []
255 | self._tabular_prefix_str = ''
256 |
257 | self._tabular = []
258 |
259 | self._text_outputs = []
260 | self._tabular_outputs = []
261 |
262 | self._text_fds = {}
263 | self._tabular_fds = {}
264 | self._tabular_header_written = set()
265 |
266 | self._snapshot_dir = None
267 | self._snapshot_mode = 'all'
268 | self._snapshot_gap = 1
269 |
270 | self._log_tabular_only = False
271 | self._header_printed = False
272 | self.table_printer = TerminalTablePrinter()
273 |
274 | def reset(self):
275 | self.__init__()
276 |
277 | def _add_output(self, file_name, arr, fds, mode='a'):
278 | if file_name not in arr:
279 | mkdir_p(os.path.dirname(file_name))
280 | arr.append(file_name)
281 | fds[file_name] = open(file_name, mode)
282 |
283 | def _remove_output(self, file_name, arr, fds):
284 | if file_name in arr:
285 | fds[file_name].close()
286 | del fds[file_name]
287 | arr.remove(file_name)
288 |
289 | def push_prefix(self, prefix):
290 | self._prefixes.append(prefix)
291 | self._prefix_str = ''.join(self._prefixes)
292 |
293 | def add_text_output(self, file_name):
294 | self._add_output(file_name, self._text_outputs, self._text_fds,
295 | mode='a')
296 |
297 | def remove_text_output(self, file_name):
298 | self._remove_output(file_name, self._text_outputs, self._text_fds)
299 |
300 | def add_tabular_output(self, file_name, relative_to_snapshot_dir=False):
301 | if relative_to_snapshot_dir:
302 | file_name = osp.join(self._snapshot_dir, file_name)
303 | self._add_output(file_name, self._tabular_outputs, self._tabular_fds,
304 | mode='w')
305 |
306 | def remove_tabular_output(self, file_name, relative_to_snapshot_dir=False):
307 | if relative_to_snapshot_dir:
308 | file_name = osp.join(self._snapshot_dir, file_name)
309 | if self._tabular_fds[file_name] in self._tabular_header_written:
310 | self._tabular_header_written.remove(self._tabular_fds[file_name])
311 | self._remove_output(file_name, self._tabular_outputs, self._tabular_fds)
312 |
313 | def set_snapshot_dir(self, dir_name):
314 | self._snapshot_dir = dir_name
315 |
316 | def get_snapshot_dir(self, ):
317 | return self._snapshot_dir
318 |
319 | def get_snapshot_mode(self, ):
320 | return self._snapshot_mode
321 |
322 | def set_snapshot_mode(self, mode):
323 | self._snapshot_mode = mode
324 |
325 | def get_snapshot_gap(self, ):
326 | return self._snapshot_gap
327 |
328 | def set_snapshot_gap(self, gap):
329 | self._snapshot_gap = gap
330 |
331 | def set_log_tabular_only(self, log_tabular_only):
332 | self._log_tabular_only = log_tabular_only
333 |
334 | def get_log_tabular_only(self, ):
335 | return self._log_tabular_only
336 |
337 | def log(self, s, with_prefix=True, with_timestamp=True):
338 | out = s
339 | if with_prefix:
340 | out = self._prefix_str + out
341 | if with_timestamp:
342 | now = datetime.datetime.now(dateutil.tz.tzlocal())
343 | timestamp = now.strftime('%y-%m-%d.%H:%M') # :%S
344 | out = "%s|%s" % (timestamp, out)
345 | if not self._log_tabular_only:
346 | # Also log to stdout
347 | print(out, flush=True)
348 | for fd in list(self._text_fds.values()):
349 | fd.write(out + '\n')
350 | fd.flush()
351 | sys.stdout.flush()
352 |
353 | def record_tabular(self, key, val):
354 | self._tabular.append((self._tabular_prefix_str + str(key), str(val)))
355 |
356 | def record_dict(self, d, prefix=None):
357 | if prefix is not None:
358 | self.push_tabular_prefix(prefix)
359 | for k, v in d.items():
360 | self.record_tabular(k, v)
361 | if prefix is not None:
362 | self.pop_tabular_prefix()
363 |
364 | def push_tabular_prefix(self, key):
365 | self._tabular_prefixes.append(key)
366 | self._tabular_prefix_str = ''.join(self._tabular_prefixes)
367 |
368 | def pop_tabular_prefix(self, ):
369 | del self._tabular_prefixes[-1]
370 | self._tabular_prefix_str = ''.join(self._tabular_prefixes)
371 |
372 | def save_extra_data(self, data, file_name='extra_data.pkl', mode='joblib'):
373 | """
374 | Data saved here will always override the last entry
375 |
376 | :param data: Something pickle'able.
377 | """
378 | file_name = osp.join(self._snapshot_dir, file_name)
379 | if mode == 'joblib':
380 | import joblib
381 | joblib.dump(data, file_name, compress=3)
382 | elif mode == 'pickle':
383 | pickle.dump(data, open(file_name, "wb"))
384 | else:
385 | raise ValueError("Invalid mode: {}".format(mode))
386 | return file_name
387 |
388 | def get_table_dict(self, ):
389 | return dict(self._tabular)
390 |
391 | def get_table_key_set(self, ):
392 | return set(key for key, value in self._tabular)
393 |
394 | @contextmanager
395 | def prefix(self, key):
396 | self.push_prefix(key)
397 | try:
398 | yield
399 | finally:
400 | self.pop_prefix()
401 |
402 | @contextmanager
403 | def tabular_prefix(self, key):
404 | self.push_tabular_prefix(key)
405 | yield
406 | self.pop_tabular_prefix()
407 |
408 | def log_variant(self, log_file, variant_data):
409 | mkdir_p(os.path.dirname(log_file))
410 | with open(log_file, "w") as f:
411 | json.dump(variant_data, f, indent=2, sort_keys=True, cls=MyEncoder)
412 |
413 | def record_tabular_misc_stat(self, key, values, placement='back'):
414 | if placement == 'front':
415 | prefix = ""
416 | suffix = key
417 | else:
418 | prefix = key
419 | suffix = ""
420 | if len(values) > 0:
421 | self.record_tabular(prefix + "Average" + suffix, np.average(values))
422 | self.record_tabular(prefix + "Std" + suffix, np.std(values))
423 | self.record_tabular(prefix + "Median" + suffix, np.median(values))
424 | self.record_tabular(prefix + "Min" + suffix, np.min(values))
425 | self.record_tabular(prefix + "Max" + suffix, np.max(values))
426 | else:
427 | self.record_tabular(prefix + "Average" + suffix, np.nan)
428 | self.record_tabular(prefix + "Std" + suffix, np.nan)
429 | self.record_tabular(prefix + "Median" + suffix, np.nan)
430 | self.record_tabular(prefix + "Min" + suffix, np.nan)
431 | self.record_tabular(prefix + "Max" + suffix, np.nan)
432 |
433 | def dump_tabular(self, *args, **kwargs):
434 | wh = kwargs.pop("write_header", None)
435 | if len(self._tabular) > 0:
436 | if self._log_tabular_only:
437 | self.table_printer.print_tabular(self._tabular)
438 | else:
439 | for line in tabulate(self._tabular).split('\n'):
440 | self.log(line, *args, **kwargs)
441 | tabular_dict = dict(self._tabular)
442 | # Also write to the csv files
443 | # This assumes that the keys in each iteration won't change!
444 | for tabular_fd in list(self._tabular_fds.values()):
445 | writer = csv.DictWriter(tabular_fd,
446 | fieldnames=list(tabular_dict.keys()))
447 | if wh or (
448 | wh is None and tabular_fd not in self._tabular_header_written):
449 | writer.writeheader()
450 | self._tabular_header_written.add(tabular_fd)
451 | writer.writerow(tabular_dict)
452 | tabular_fd.flush()
453 | del self._tabular[:]
454 |
455 | def pop_prefix(self, ):
456 | del self._prefixes[-1]
457 | self._prefix_str = ''.join(self._prefixes)
458 |
459 | def save_itr_params(self, itr, params):
460 | if self._snapshot_dir:
461 | if self._snapshot_mode == 'all':
462 | file_name = osp.join(self._snapshot_dir, 'itr_%d.pkl' % itr)
463 | pickle.dump(params, open(file_name, "wb"))
464 | elif self._snapshot_mode == 'last':
465 | # override previous params
466 | file_name = osp.join(self._snapshot_dir, 'params.pkl')
467 | pickle.dump(params, open(file_name, "wb"))
468 | elif self._snapshot_mode == "gap":
469 | if itr % self._snapshot_gap == 0:
470 | file_name = osp.join(self._snapshot_dir, 'itr_%d.pkl' % itr)
471 | pickle.dump(params, open(file_name, "wb"))
472 | elif self._snapshot_mode == "gap_and_last":
473 | if itr % self._snapshot_gap == 0:
474 | file_name = osp.join(self._snapshot_dir, 'itr_%d.pkl' % itr)
475 | pickle.dump(params, open(file_name, "wb"))
476 | file_name = osp.join(self._snapshot_dir, 'params.pkl')
477 | pickle.dump(params, open(file_name, "wb"))
478 | elif self._snapshot_mode == 'none':
479 | pass
480 | else:
481 | raise NotImplementedError
482 |
483 |
484 | logger = Logger()
485 |
486 |
--------------------------------------------------------------------------------
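main.py drives this logger through setup_logger, record_tabular, and dump_tabular. A standalone sketch of the same flow (run from the repository root; the experiment name and values are illustrative):

from utils.logger import logger, setup_logger

setup_logger('demo-run', variant={'seed': 0}, log_dir='results/demo')
logger.record_tabular('Trained Epochs', 1)
logger.record_tabular('BC Loss', 0.123)
logger.dump_tabular()  # prints a table and appends a row to results/demo/progress.csv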