├── utils ├── __init__.py ├── pytorch_util.py ├── data_sampler.py ├── utils.py └── logger.py ├── agents ├── __init__.py ├── bc_diffusion.py ├── helpers.py ├── diffusion.py ├── sdes.py ├── eql_diffusion.py └── model.py ├── .gitignore ├── README.md ├── main.py └── tabulate.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /agents/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__ 3 | */__pycache__ 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Entropy-regularized Diffusion Policy with Q-Ensembles for Offline Reinforcement Learning 2 | 3 | 4 | ## Dependenices 5 | 6 | * OS: Ubuntu 20.04 7 | * nvidia : 8 | - cuda: 11.7 9 | - cudnn: 8.5.0 10 | * python3 11 | * pytorch >= 1.13.0 12 | 13 | ## How to run the code 14 | 15 | ```.bash 16 | python main.py --env_name walker2d-medium-expert-v2 --device 0 --lr_decay 17 | ``` 18 | 19 | All the hyperparameters are fixed in the `main.py`. -------------------------------------------------------------------------------- /utils/pytorch_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def soft_update_from_to(source, target, tau): 6 | for target_param, param in zip(target.parameters(), source.parameters()): 7 | target_param.data.copy_( 8 | target_param.data * (1.0 - tau) + param.data * tau 9 | ) 10 | 11 | 12 | def copy_model_params_from_to(source, target): 13 | for target_param, param in zip(target.parameters(), source.parameters()): 14 | target_param.data.copy_(param.data) 15 | 16 | 17 | def fanin_init(tensor, scale=1): 18 | size = tensor.size() 19 | if len(size) == 2: 20 | fan_in = size[0] 21 | elif len(size) > 2: 22 | fan_in = np.prod(size[1:]) 23 | else: 24 | raise Exception("Shape must be have dimension at least 2.") 25 | bound = scale / np.sqrt(fan_in) 26 | return tensor.data.uniform_(-bound, bound) 27 | 28 | 29 | def orthogonal_init(tensor, gain=0.01): 30 | torch.nn.init.orthogonal_(tensor, gain=gain) 31 | 32 | 33 | def fanin_init_weights_like(tensor): 34 | size = tensor.size() 35 | if len(size) == 2: 36 | fan_in = size[0] 37 | elif len(size) > 2: 38 | fan_in = np.prod(size[1:]) 39 | else: 40 | raise Exception("Shape must be have dimension at least 2.") 41 | bound = 1. 
/ np.sqrt(fan_in) 42 | new_tensor = torch.FloatTensor(tensor.size()) 43 | new_tensor.uniform_(-bound, bound) 44 | return new_tensor 45 | 46 | -------------------------------------------------------------------------------- /utils/data_sampler.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | import torch 4 | import numpy as np 5 | 6 | 7 | class Data_Sampler(object): 8 | def __init__(self, data, device, reward_tune='no'): 9 | 10 | self.device = device 11 | self.state = torch.from_numpy(data['observations']).float().to(self.device) 12 | self.action = torch.from_numpy(data['actions']).float().to(self.device) 13 | self.next_state = torch.from_numpy(data['next_observations']).float().to(self.device) 14 | reward = torch.from_numpy(data['rewards']).view(-1, 1).float().to(self.device) 15 | self.not_done = 1. - torch.from_numpy(data['terminals']).view(-1, 1).float().to(self.device) 16 | 17 | self.size = self.state.shape[0] 18 | self.state_dim = self.state.shape[1] 19 | self.action_dim = self.action.shape[1] 20 | 21 | if reward_tune == 'normalize': 22 | reward = (reward - reward.mean()) / reward.std() 23 | elif reward_tune == 'iql_antmaze': 24 | reward = reward - 1.0 25 | elif reward_tune == 'iql_locomotion': 26 | reward = iql_normalize(reward, self.not_done) 27 | elif reward_tune == 'cql_antmaze': 28 | reward = (reward - 0.5) * 4.0 29 | elif reward_tune == 'antmaze': 30 | reward = (reward - 0.25) * 2.0 31 | self.reward = reward 32 | 33 | def sample(self, batch_size): 34 | ind = torch.randint(0, self.size, size=(batch_size,)) 35 | return ( 36 | self.state[ind], self.action[ind], self.next_state[ind], self.reward[ind], self.not_done[ind] 37 | ) 38 | 39 | 40 | def iql_normalize(reward, not_done): 41 | trajs_rt = [] 42 | episode_return = 0.0 43 | for i in range(len(reward)): 44 | episode_return += reward[i] 45 | if not not_done[i]: 46 | trajs_rt.append(episode_return) 47 | episode_return = 0.0 48 | rt_max, rt_min = torch.max(torch.tensor(trajs_rt)), torch.min(torch.tensor(trajs_rt)) 49 | reward /= (rt_max - rt_min) 50 | reward *= 1000. 
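    # IQL-style locomotion reward scaling (reward_tune='iql_locomotion'): episode returns are
    # accumulated with the not_done mask, and each reward is rescaled by
    # 1000 / (max episode return - min episode return), i.e. r' = 1000 * r / (R_max - R_min).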
51 | return reward 52 | -------------------------------------------------------------------------------- /agents/bc_diffusion.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from utils.logger import logger 7 | 8 | from agents.sdes import TractableSDE 9 | from agents.diffusion import DiffusionPolicy 10 | from agents.model import Critic 11 | 12 | from tqdm import tqdm 13 | 14 | 15 | class Diffusion_BC(object): 16 | def __init__(self, 17 | state_dim, 18 | action_dim, 19 | max_action, 20 | device, 21 | discount, 22 | tau, 23 | schedule='cosine', 24 | n_timesteps=100, 25 | lr=2e-4, 26 | loss_type='MLL', 27 | action_clip=False 28 | ): 29 | 30 | self.sde = TractableSDE(n_timesteps, schedule, action_clip, device=device) 31 | self.actor = DiffusionPolicy(state_dim=state_dim, action_dim=action_dim, max_action=max_action, 32 | sde=self.sde, n_timesteps=n_timesteps, loss_type=loss_type).to(device) 33 | 34 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr) 35 | 36 | self.max_action = max_action 37 | self.action_dim = action_dim 38 | self.discount = discount 39 | self.tau = tau 40 | self.device = device 41 | 42 | def train(self, replay_buffer, iterations, batch_size=100, log_writer=None): 43 | 44 | metric = {'bc_loss': [], 'ql_loss': [], 'actor_loss': [], 'critic_loss': []} 45 | for _ in tqdm(range(iterations)): 46 | # Sample replay buffer / batch 47 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 48 | 49 | loss, _ = self.actor.loss(state, action) 50 | 51 | self.actor_optimizer.zero_grad() 52 | loss.backward() 53 | self.actor_optimizer.step() 54 | 55 | metric['actor_loss'].append(0.) 56 | metric['bc_loss'].append(loss.item()) 57 | metric['ql_loss'].append(0.) 58 | metric['critic_loss'].append(0.) 
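            # Behavior cloning only trains the diffusion actor, so the actor/QL/critic entries
            # are logged as zeros purely to keep the metric keys aligned with the Q-learning
            # agent in eql_diffusion.py.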
59 | 60 | return metric 61 | 62 | def sample_action(self, state): 63 | state = torch.FloatTensor(state.reshape(1, -1)).to(self.device) 64 | with torch.no_grad(): 65 | action = self.actor(state) 66 | return action.cpu().data.numpy().flatten() 67 | 68 | def save_model(self, dir, id=None): 69 | if id is not None: 70 | torch.save(self.actor.state_dict(), f'{dir}/actor_{id}.pth') 71 | else: 72 | torch.save(self.actor.state_dict(), f'{dir}/actor.pth') 73 | 74 | def load_model(self, dir, id=None): 75 | if id is not None: 76 | self.actor.load_state_dict(torch.load(f'{dir}/actor_{id}.pth')) 77 | else: 78 | self.actor.load_state_dict(torch.load(f'{dir}/actor.pth')) 79 | 80 | -------------------------------------------------------------------------------- /agents/helpers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class SinusoidalPosEmb(nn.Module): 10 | def __init__(self, dim): 11 | super().__init__() 12 | self.dim = dim 13 | 14 | def forward(self, x): 15 | device = x.device 16 | half_dim = self.dim // 2 17 | emb = math.log(10000) / (half_dim - 1) 18 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 19 | emb = x[:, None] * emb[None, :] 20 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 21 | return emb 22 | 23 | #-----------------------------------------------------------------------------# 24 | #------------------------------ theta schedules ------------------------------# 25 | #-----------------------------------------------------------------------------# 26 | 27 | 28 | def constant_theta_schedule(timesteps, v=1.): 29 | """ 30 | constant schedule 31 | """ 32 | return torch.ones(timesteps, dtype=torch.float32) 33 | 34 | def linear_theta_schedule(timesteps): 35 | """ 36 | linear schedule 37 | """ 38 | scale = 1000 / timesteps 39 | beta_start = scale * 0.0001 40 | beta_end = scale * 0.02 41 | return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32) 42 | 43 | def cosine_theta_schedule(timesteps, s=0.008): 44 | """ 45 | cosine schedule 46 | """ 47 | timesteps = timesteps + 1 # for truncating from 1 to -1 48 | steps = timesteps + 1 49 | x = torch.linspace(0, timesteps, steps, dtype=torch.float32) 50 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2 51 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 52 | betas = 1 - alphas_cumprod[1:-1] 53 | return betas 54 | 55 | def vp_beta_schedule(timesteps, dtype=torch.float32): 56 | t = np.arange(1, timesteps + 1) 57 | T = timesteps 58 | b_max = 10. 
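    # Variance-preserving (VP) schedule: alpha_t = exp(-b_min/T - (b_max - b_min)*(2t - 1)/(2*T^2))
    # and beta_t = 1 - alpha_t, so the betas grow roughly linearly from about b_min/T to b_max/T.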
59 | b_min = 0.1 60 | alpha = np.exp(-b_min / T - 0.5 * (b_max - b_min) * (2 * t - 1) / T ** 2) 61 | betas = 1 - alpha 62 | return torch.tensor(betas, dtype=dtype) 63 | 64 | #-----------------------------------------------------------------------------# 65 | #---------------------------------- losses -----------------------------------# 66 | #-----------------------------------------------------------------------------# 67 | 68 | class WeightedLoss(nn.Module): 69 | 70 | def __init__(self): 71 | super().__init__() 72 | 73 | def forward(self, pred, targ, weights=1.0): 74 | ''' 75 | pred, targ : tensor [ batch_size x action_dim ] 76 | ''' 77 | loss = self._loss(pred, targ) 78 | weighted_loss = (loss * weights).mean() 79 | return weighted_loss 80 | 81 | class WeightedL1(WeightedLoss): 82 | 83 | def _loss(self, pred, targ): 84 | return torch.abs(pred - targ) 85 | 86 | class WeightedL2(WeightedLoss): 87 | 88 | def _loss(self, pred, targ): 89 | return F.mse_loss(pred, targ, reduction='none') 90 | 91 | 92 | Losses = { 93 | 'l1': WeightedL1, 94 | 'l2': WeightedL2, 95 | } 96 | 97 | 98 | class EMA(): 99 | ''' 100 | empirical moving average 101 | ''' 102 | def __init__(self, beta): 103 | super().__init__() 104 | self.beta = beta 105 | 106 | def update_model_average(self, ma_model, current_model): 107 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): 108 | old_weight, up_weight = ma_params.data, current_params.data 109 | ma_params.data = self.update_average(old_weight, up_weight) 110 | 111 | def update_average(self, old, new): 112 | if old is None: 113 | return new 114 | return old * self.beta + (1 - self.beta) * new 115 | -------------------------------------------------------------------------------- /agents/diffusion.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from agents.sdes import TractableSDE 8 | from agents.model import DiffusionMLP 9 | 10 | # loss type is MLL(maximum likelihood loss) or NML(noise-matching loss) 11 | class DiffusionPolicy(nn.Module): 12 | def __init__(self, state_dim, action_dim, max_action, sde=None, 13 | n_timesteps=100, loss_type='MLL', predict_epsilon=True): 14 | super().__init__() 15 | 16 | self.state_dim = state_dim 17 | self.action_dim = action_dim 18 | self.max_action = max_action 19 | self.n_timesteps = n_timesteps 20 | 21 | self.predict_epsilon = predict_epsilon 22 | self.loss_type = loss_type 23 | 24 | self.model = DiffusionMLP(state_dim=state_dim, action_dim=action_dim) 25 | self.sde = sde 26 | 27 | # ------------------------------------------ sampling ------------------------------------------# 28 | 29 | def clip_action(self, action): 30 | return action.clamp(-self.max_action, self.max_action) 31 | 32 | def predict_start_from_score(self, a_t, t, score): 33 | action = self.sde.predict_start_from_score(a_t, t, score) 34 | return action 35 | 36 | def previous_diffusion_action(self, a_t, t=0): 37 | # action = self.sde.forward_step(a_t, t) 38 | action = a_t + self.sde.drift(a_t, t) * self.sde.dt 39 | return action 40 | 41 | def sample_noise(self, tensor): 42 | batch_size = tensor.shape[0] 43 | return torch.randn(batch_size, self.action_dim).to(tensor.device) 44 | 45 | def sample_action(self, state, mode): 46 | score_fn = self.model if mode == 'posterior' else self.score_fn 47 | noise = self.sample_noise(state) 48 | action = self.sde.reverse(noise, score_fn, 
mode=mode, clip_value=self.max_action, state=state) 49 | return self.clip_action(action) 50 | 51 | def score_fn(self, a_t, t, state): 52 | noise = self.model(a_t, t, state=state) 53 | return self.sde.compute_score_from_noise(noise, t) 54 | 55 | def entropy(self, a_0): 56 | a_1 = self.previous_diffusion_action(a_0, t=0) 57 | log_p = self.sde.log_reverse_transition(a_0, a_1, self.sample_noise(a_0), 1, self.n_timesteps-1) 58 | return -log_p.sum(dim=1, keepdim=True) 59 | 60 | # ------------------------------------------ training ------------------------------------------# 61 | 62 | def random_action_states(self, a_0): 63 | a_t, t, noise = self.sde.generate_random_states(a_0) 64 | return a_t, t, noise 65 | 66 | # noise matching loss 67 | def NML_loss(self, state, a_0): 68 | a_t, t, noise = self.random_action_states(a_0) 69 | noise_pred = self.model(a_t, t.squeeze(1), state) 70 | start_pred = self.sde.predict_start_from_noise(a_t, t, noise_pred) 71 | 72 | return F.mse_loss(noise_pred, noise), start_pred 73 | 74 | # maximum likelihood loss 75 | def MLL_loss(self, state, a_0): 76 | a_t, t, noise = self.random_action_states(a_0) 77 | noise_pred = self.model(a_t, t.squeeze(1), state) 78 | score_pred = self.sde.compute_score_from_noise(noise_pred, t) 79 | start_pred = self.sde.predict_start_from_noise(a_t, t, noise_pred) 80 | 81 | a_pre_pred = a_t - self.sde.sde_reverse_drift(a_t, t, score_pred) * self.sde.dt 82 | a_pre_target = self.sde.reverse_optimum_step(a_t, a_0, t) 83 | return F.mse_loss(a_pre_pred, a_pre_target), start_pred 84 | 85 | def loss(self, state, a_0): 86 | if self.loss_type == 'MLL': 87 | return self.MLL_loss(state, a_0) 88 | elif self.loss_type == 'NML': 89 | return self.NML_loss(state, a_0) 90 | else: 91 | print('Now only support MLL and NML loss') 92 | 93 | def forward(self, state, mode='posterior'): # posterior 94 | return self.sample_action(state, mode=mode) 95 | 96 | 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def print_banner(s, separator="-", num_star=60): 8 | print(separator * num_star, flush=True) 9 | print(s, flush=True) 10 | print(separator * num_star, flush=True) 11 | 12 | 13 | class Progress: 14 | 15 | def __init__(self, total, name='Progress', ncol=3, max_length=20, indent=0, line_width=100, speed_update_freq=100): 16 | self.total = total 17 | self.name = name 18 | self.ncol = ncol 19 | self.max_length = max_length 20 | self.indent = indent 21 | self.line_width = line_width 22 | self._speed_update_freq = speed_update_freq 23 | 24 | self._step = 0 25 | self._prev_line = '\033[F' 26 | self._clear_line = ' ' * self.line_width 27 | 28 | self._pbar_size = self.ncol * self.max_length 29 | self._complete_pbar = '#' * self._pbar_size 30 | self._incomplete_pbar = ' ' * self._pbar_size 31 | 32 | self.lines = [''] 33 | self.fraction = '{} / {}'.format(0, self.total) 34 | 35 | self.resume() 36 | 37 | def update(self, description, n=1): 38 | self._step += n 39 | if self._step % self._speed_update_freq == 0: 40 | self._time0 = time.time() 41 | self._step0 = self._step 42 | self.set_description(description) 43 | 44 | def resume(self): 45 | self._skip_lines = 1 46 | print('\n', end='') 47 | self._time0 = time.time() 48 | self._step0 = self._step 49 | 50 | def pause(self): 51 | self._clear() 52 | self._skip_lines = 1 53 | 54 | def 
set_description(self, params=[]): 55 | 56 | if type(params) == dict: 57 | params = sorted([ 58 | (key, val) 59 | for key, val in params.items() 60 | ]) 61 | 62 | ############ 63 | # Position # 64 | ############ 65 | self._clear() 66 | 67 | ########### 68 | # Percent # 69 | ########### 70 | percent, fraction = self._format_percent(self._step, self.total) 71 | self.fraction = fraction 72 | 73 | ######### 74 | # Speed # 75 | ######### 76 | speed = self._format_speed(self._step) 77 | 78 | ########## 79 | # Params # 80 | ########## 81 | num_params = len(params) 82 | nrow = math.ceil(num_params / self.ncol) 83 | params_split = self._chunk(params, self.ncol) 84 | params_string, lines = self._format(params_split) 85 | self.lines = lines 86 | 87 | description = '{} | {}{}'.format(percent, speed, params_string) 88 | print(description) 89 | self._skip_lines = nrow + 1 90 | 91 | def append_description(self, descr): 92 | self.lines.append(descr) 93 | 94 | def _clear(self): 95 | position = self._prev_line * self._skip_lines 96 | empty = '\n'.join([self._clear_line for _ in range(self._skip_lines)]) 97 | print(position, end='') 98 | print(empty) 99 | print(position, end='') 100 | 101 | def _format_percent(self, n, total): 102 | if total: 103 | percent = n / float(total) 104 | 105 | complete_entries = int(percent * self._pbar_size) 106 | incomplete_entries = self._pbar_size - complete_entries 107 | 108 | pbar = self._complete_pbar[:complete_entries] + self._incomplete_pbar[:incomplete_entries] 109 | fraction = '{} / {}'.format(n, total) 110 | string = '{} [{}] {:3d}%'.format(fraction, pbar, int(percent * 100)) 111 | else: 112 | fraction = '{}'.format(n) 113 | string = '{} iterations'.format(n) 114 | return string, fraction 115 | 116 | def _format_speed(self, n): 117 | num_steps = n - self._step0 118 | t = time.time() - self._time0 119 | speed = num_steps / t 120 | string = '{:.1f} Hz'.format(speed) 121 | if num_steps > 0: 122 | self._speed = string 123 | return string 124 | 125 | def _chunk(self, l, n): 126 | return [l[i:i + n] for i in range(0, len(l), n)] 127 | 128 | def _format(self, chunks): 129 | lines = [self._format_chunk(chunk) for chunk in chunks] 130 | lines.insert(0, '') 131 | padding = '\n' + ' ' * self.indent 132 | string = padding.join(lines) 133 | return string, lines 134 | 135 | def _format_chunk(self, chunk): 136 | line = ' | '.join([self._format_param(param) for param in chunk]) 137 | return line 138 | 139 | def _format_param(self, param): 140 | k, v = param 141 | return '{} : {}'.format(k, v)[:self.max_length] 142 | 143 | def stamp(self): 144 | if self.lines != ['']: 145 | params = ' | '.join(self.lines) 146 | string = '[ {} ] {}{} | {}'.format(self.name, self.fraction, params, self._speed) 147 | self._clear() 148 | print(string, end='\n') 149 | self._skip_lines = 1 150 | else: 151 | self._clear() 152 | self._skip_lines = 0 153 | 154 | def close(self): 155 | self.pause() 156 | 157 | 158 | class Silent: 159 | 160 | def __init__(self, *args, **kwargs): 161 | pass 162 | 163 | def __getattr__(self, attr): 164 | return lambda *args: None 165 | 166 | 167 | class EarlyStopping(object): 168 | def __init__(self, tolerance=5, min_delta=0): 169 | self.tolerance = tolerance 170 | self.min_delta = min_delta 171 | self.counter = 0 172 | self.early_stop = False 173 | 174 | def __call__(self, train_loss, validation_loss): 175 | if (validation_loss - train_loss) > self.min_delta: 176 | self.counter += 1 177 | if self.counter >= self.tolerance: 178 | return True 179 | else: 180 | self.counter = 0 181 | 
return False 182 | -------------------------------------------------------------------------------- /agents/sdes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import abc 3 | import math 4 | 5 | from agents.helpers import (constant_theta_schedule, 6 | linear_theta_schedule, 7 | cosine_theta_schedule, 8 | vp_beta_schedule, 9 | ) 10 | 11 | class SDEBase(abc.ABC): 12 | def __init__(self, T, device): 13 | self.T = T 14 | self.dt = 1 / T 15 | self.device = device 16 | 17 | @abc.abstractmethod 18 | def drift(self, x_t, t): 19 | pass 20 | 21 | @abc.abstractmethod 22 | def dispersion(self, x_t, t): 23 | pass 24 | 25 | ################################################################################ 26 | 27 | def sde_reverse_drift(self, x_t, t, score): 28 | return self.drift(x_t, t) - self.dispersion(x_t, t)**2 * score 29 | 30 | def ode_reverse_drift(self, x_t, t, score): 31 | return self.drift(x_t, t) - 0.5 * self.dispersion(x_t, t)**2 * score 32 | 33 | def dw(self, x_t): # Wiener 34 | return torch.randn_like(x_t) * math.sqrt(self.dt) 35 | 36 | def forward_step(self, x_t, t): 37 | dx = self.drift(x_t, t) * self.dt + self.dispersion(x_t, t) * self.dw(x_t) 38 | return x_t + dx 39 | 40 | def reverse_sde_step(self, x_t, t, score): 41 | dx = self.sde_reverse_drift(x_t, t, score) * self.dt \ 42 | + self.dispersion(x_t, t) * self.dw(x_t) * (t > 0) 43 | return x_t - dx 44 | 45 | def reverse_ode_step(self, x_t, t, score): 46 | return x_t - self.ode_reverse_drift(x_t, t, score) * self.dt 47 | 48 | def forward(self, x_0): 49 | x_t = x_0 50 | for t in range(self.T): 51 | x_t = self.forward_step(x_t, t) 52 | return x_t 53 | 54 | def reverse(self, x_t, score_fn, mode='sde', **kwargs): 55 | for t in reversed(range(self.T)): 56 | score = score_fn(x_t, t, **kwargs) 57 | if mode == 'sde': 58 | x_t = self.reverse_sde_step(x_t, t, score) 59 | elif mode == 'ode': 60 | x_t = self.reverse_ode_step(x_t, t, score) 61 | else: 62 | print('the mode should be sde or ode') 63 | break 64 | return x_t 65 | 66 | 67 | #-----------------------------------------------------------------------------# 68 | #------------------------------- Tractable SDE -------------------------------# 69 | #-----------------------------------------------------------------------------# 70 | 71 | # mean-reverting SDE with 'mu=0' 72 | class TractableSDE(SDEBase): 73 | def __init__(self, T=100, schedule='cosine', action_clip=False, device=None): 74 | super().__init__(T=T, device=device) 75 | self.action_clip = action_clip 76 | 77 | # beta and sigma for the SDE 78 | if schedule == 'cosine': 79 | self.thetas = cosine_theta_schedule(T).to(device) 80 | elif schedule == 'linear': 81 | self.thetas = linear_theta_schedule(T).to(device) 82 | elif schedule == 'constant': 83 | self.thetas = constant_theta_schedule(T).to(device) 84 | elif schedule == 'vp': 85 | self.thetas = vp_beta_schedule(T).to(device) 86 | else: 87 | print('Not implemented such schedule yet!!!') 88 | 89 | self.sigmas = torch.sqrt(2 * self.thetas) 90 | 91 | # recompute dt to make sure the SDE converges to a Gaussian(0, 1) 92 | thetas_cumsum = torch.cumsum(self.thetas, dim=0) 93 | self.dt = -math.log(1e-3) / thetas_cumsum[-1] 94 | self.thetas_cumsum = torch.cat([torch.zeros(1).to(device), thetas_cumsum]) 95 | 96 | # compute theta_bar and SDE's variance/standard 97 | self.thetas_bar = thetas_cumsum * self.dt 98 | self.vars = 1 - torch.exp(-2 * self.thetas_bar) 99 | self.stds = torch.sqrt(self.vars) 100 | 101 | self.posterior_vars = (1 - 
torch.exp(-2 * self.thetas * self.dt)) \ 102 | * (1 - torch.exp(-2 * self.thetas_cumsum[:-1] * self.dt)) \ 103 | / (1 - torch.exp(-2 * self.thetas_cumsum[1:] * self.dt)) 104 | 105 | self.log_posterior_vars = torch.log(torch.clamp(self.posterior_vars, min=1e-20)) 106 | 107 | 108 | def mean(self, x_0, t): 109 | return x_0 * torch.exp(-self.thetas_bar[t]) 110 | 111 | def variance(self, t): 112 | return self.vars[t] 113 | 114 | def drift(self, x_t, t): 115 | return -self.thetas[t] * x_t 116 | 117 | def dispersion(self, x_t, t): 118 | return self.sigmas[t] 119 | 120 | def forward_state(self, x_0, t, noise): 121 | return self.mean(x_0, t) + self.stds[t] * noise 122 | 123 | def ground_truth_score(self, x_t, t, x_0): 124 | return -(x_t - self.mean(x_0, t)) / self.variance(t) 125 | 126 | def compute_score_from_noise(self, noise, t): 127 | return -noise / self.stds[t] 128 | 129 | def predict_start_from_score(self, x_t, t, score): 130 | return (x_t + self.variance(t) * score) * torch.exp(self.thetas_bar[t]) 131 | 132 | def predict_start_from_noise(self, x_t, t, noise): 133 | return (x_t - self.stds[t] * noise) * self.thetas_bar[t].exp() 134 | 135 | def log_forward_transition(self, x1, x2, t1, t2): # t1 < t2 136 | tb = self.thetas_bar[t2] - self.thetas_bar[t1] 137 | log_p = torch.log(2*math.pi * (1 - torch.exp(-2 * tb))) \ 138 | + (x1 - x2 * torch.exp(-tb))**2 / (1 - torch.exp(-2 * tb)) 139 | return -0.5 * log_p 140 | 141 | def log_reverse_transition(self, x_0, x1, x2, t1, t2): # t1 < t2 142 | return self.log_forward_transition(x1, x2, t1, t2) \ 143 | + self.log_forward_transition(x_0, x1, 0, t1) \ 144 | - self.log_forward_transition(x_0, x2, 0, t2) 145 | 146 | # sample states for training 147 | def generate_random_states(self, x_0): 148 | noise = torch.randn_like(x_0) 149 | t = torch.randint(0, self.T, (x_0.shape[0], 1), device=x_0.device).long() 150 | xt = self.forward_state(x_0, t, noise) 151 | return xt, t, noise 152 | 153 | # optimum x_{t-1} 154 | def reverse_optimum_step(self, x_t, x_0, t): 155 | A = torch.exp(-self.thetas[t] * self.dt) 156 | # self.thetas_cumsum has length T+1 157 | B = torch.exp(-self.thetas_cumsum[t+1] * self.dt) 158 | C = torch.exp(-self.thetas_cumsum[t] * self.dt) 159 | 160 | term1 = A * (1 - C**2) / (1 - B**2) 161 | term2 = C * (1 - A**2) / (1 - B**2) 162 | return term1 * x_t + term2 * x_0 163 | 164 | def optimal_reverse(self, x_t, x_0): 165 | x = x_t.clone() 166 | for t in reversed(range(self.T)): 167 | x = self.reverse_optimum_step(x, x_0, t) 168 | 169 | return x 170 | 171 | def posterior_step(self, x_t, t, noise, clip_value=1.0): 172 | x_0 = self.predict_start_from_noise(x_t, t, noise) 173 | if self.action_clip and clip_value > 0: 174 | x_0.clamp_(-clip_value, clip_value) 175 | mean_t = self.reverse_optimum_step(x_t, x_0, t) 176 | std_t = (0.5 * self.log_posterior_vars[t]).exp() 177 | noise = torch.randn_like(x_t) 178 | return mean_t + std_t * noise * (t > 0) 179 | 180 | def reverse(self, x_t, score_fn, mode='posterior', clip_value=1.0, **kwargs): 181 | for t in reversed(range(self.T)): 182 | score = score_fn(x_t, t, **kwargs) 183 | if mode == 'sde': 184 | x_t = self.reverse_sde_step(x_t, t, score) 185 | elif mode == 'ode': 186 | x_t = self.reverse_ode_step(x_t, t, score) 187 | elif mode == 'posterior': 188 | x_t = self.posterior_step(x_t, t, score, clip_value) 189 | else: 190 | print('the mode should be sde or ode') 191 | break 192 | return x_t 193 | 194 | -------------------------------------------------------------------------------- /agents/eql_diffusion.py: 
-------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.optim.lr_scheduler import CosineAnnealingLR 7 | from utils.logger import logger 8 | 9 | from agents.sdes import TractableSDE 10 | from agents.diffusion import DiffusionPolicy 11 | from agents.model import EnsembleCritic 12 | 13 | from agents.helpers import EMA 14 | from tqdm import tqdm 15 | 16 | 17 | class Diffusion_EQL(object): 18 | def __init__(self, 19 | state_dim, 20 | action_dim, 21 | max_action, 22 | device, 23 | discount, 24 | tau, 25 | max_q_backup=False, 26 | eta=1.0, 27 | schedule='cosine', 28 | n_timesteps=100, 29 | ema_decay=0.995, 30 | step_start_ema=1000, 31 | update_ema_every=5, 32 | lr=3e-4, 33 | lr_decay=False, 34 | lr_maxt=1000, 35 | grad_norm=1.0, 36 | ent_coef=0.2, 37 | num_critics=4, 38 | pess_method='lcb', # ['min', 'lcb'] 39 | lcb_coef=4., # [4, 8] 40 | loss_type='MLL', 41 | action_clip=False 42 | ): 43 | 44 | self.sde = TractableSDE(n_timesteps, schedule, action_clip, device=device) 45 | self.actor = DiffusionPolicy(state_dim=state_dim, action_dim=action_dim, max_action=max_action, 46 | sde=self.sde, n_timesteps=n_timesteps, loss_type=loss_type).to(device) 47 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr) 48 | 49 | self.lr_decay = lr_decay 50 | self.grad_norm = grad_norm 51 | 52 | self.pess_method = pess_method 53 | self.lcb_coef = lcb_coef 54 | 55 | 56 | self.step = 0 57 | self.step_start_ema = step_start_ema 58 | self.ema = EMA(ema_decay) 59 | self.ema_model = copy.deepcopy(self.actor) 60 | self.update_ema_every = update_ema_every 61 | self.num_critics = num_critics 62 | 63 | self.critic = EnsembleCritic(state_dim, action_dim, num_critics=num_critics).to(device) 64 | self.critic_target = copy.deepcopy(self.critic) 65 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4) 66 | 67 | if lr_decay: 68 | self.actor_lr_scheduler = CosineAnnealingLR(self.actor_optimizer, T_max=lr_maxt, eta_min=0.) 69 | self.critic_lr_scheduler = CosineAnnealingLR(self.critic_optimizer, T_max=lr_maxt, eta_min=0.) 
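        # With lr_decay enabled, both optimizers are cosine-annealed toward zero over
        # T_max = lr_maxt scheduler steps (main.py passes lr_maxt=num_epochs); train()
        # advances both schedulers once per call.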
70 | 71 | self.state_dim = state_dim 72 | self.max_action = max_action 73 | self.action_dim = action_dim 74 | self.discount = discount 75 | self.tau = tau 76 | self.eta = eta # q_learning weight 77 | self.device = device 78 | self.max_q_backup = max_q_backup 79 | self.ent_coef = torch.tensor(ent_coef).to(self.device) 80 | 81 | 82 | def step_ema(self): 83 | if self.step < self.step_start_ema: 84 | return 85 | self.ema.update_model_average(self.ema_model, self.actor) 86 | 87 | def train(self, replay_buffer, iterations, batch_size=100, log_writer=None): 88 | 89 | metric = {'bc_loss': [], 'ql_loss': [], 'actor_loss': [], 'critic_loss': [], 'entropy': []} 90 | 91 | for _ in tqdm(range(iterations)): 92 | # Sample replay buffer / batch 93 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 94 | 95 | ################## 96 | """ Q Training """ 97 | ################## 98 | current_q_values = self.critic(state, action) 99 | 100 | if self.max_q_backup: 101 | next_state_rpt = torch.repeat_interleave(next_state, repeats=10, dim=0) 102 | next_action_rpt = self.ema_model(next_state_rpt) 103 | q_next_rpt = self.critic_target(next_state_rpt, next_action_rpt) 104 | q_next = q_next_rpt.view(batch_size, 10, -1).max(dim=1)[0] 105 | else: 106 | next_action = self.ema_model(next_state) 107 | q_next = self.critic_target(next_state, next_action) # shape: batch_siz, num_critic 108 | 109 | target_q = (reward + not_done * self.discount * q_next).detach() 110 | critic_loss = F.mse_loss(current_q_values, target_q) 111 | 112 | self.critic_optimizer.zero_grad() 113 | critic_loss.backward() 114 | if self.grad_norm > 0: 115 | critic_grad_norms = nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=self.grad_norm, norm_type=2) 116 | self.critic_optimizer.step() 117 | 118 | ####################### 119 | """ Policy Training """ 120 | ####################### 121 | bc_loss, new_action = self.actor.loss(state, action) 122 | entropy = self.actor.entropy(new_action) 123 | 124 | q_values_new_action_ensembles = self.critic(state, self.actor.clip_action(new_action)) 125 | if self.pess_method == 'min': 126 | q_values_new_action = q_values_new_action_ensembles.min(dim=1, keepdim=True)[0] 127 | elif self.pess_method == 'lcb': 128 | mu = q_values_new_action_ensembles.mean(dim=1, keepdim=True) 129 | std = q_values_new_action_ensembles.std(dim=1, keepdim=True) 130 | q_values_new_action = mu - self.lcb_coef * std 131 | 132 | q_loss = -q_values_new_action.mean() / q_values_new_action_ensembles.abs().mean().detach() 133 | entropy_loss = -entropy.mean() / entropy.abs().mean().detach() 134 | 135 | actor_loss = bc_loss + self.eta * q_loss + self.ent_coef * entropy_loss 136 | 137 | self.actor_optimizer.zero_grad() 138 | actor_loss.backward() 139 | if self.grad_norm > 0: 140 | actor_grad_norms = nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=self.grad_norm, norm_type=2) 141 | self.actor_optimizer.step() 142 | 143 | """ Step Target network """ 144 | if self.step % self.update_ema_every == 0: 145 | self.step_ema() 146 | 147 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 148 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 149 | 150 | self.step += 1 151 | 152 | """ Log """ 153 | if log_writer is not None: 154 | if self.grad_norm > 0: 155 | log_writer.add_scalar('Actor Grad Norm', actor_grad_norms.max().item(), self.step) 156 | log_writer.add_scalar('Critic Grad Norm', critic_grad_norms.max().item(), self.step) 157 | 
log_writer.add_scalar('BC Loss', bc_loss.item(), self.step) 158 | log_writer.add_scalar('QL Loss', q_loss.item(), self.step) 159 | log_writer.add_scalar('Critic Loss', critic_loss.item(), self.step) 160 | log_writer.add_scalar('Target_Q Mean', target_q.mean().item(), self.step) 161 | 162 | metric['actor_loss'].append(actor_loss.item()) 163 | metric['bc_loss'].append(bc_loss.item()) 164 | metric['ql_loss'].append(q_loss.item()) 165 | metric['critic_loss'].append(critic_loss.item()) 166 | metric['entropy'].append(entropy.mean().item()) 167 | 168 | if self.lr_decay: 169 | self.actor_lr_scheduler.step() 170 | self.critic_lr_scheduler.step() 171 | 172 | return metric 173 | 174 | 175 | def _sample_action(self, state): 176 | state = torch.FloatTensor(state.reshape(1, -1)).to(self.device) 177 | with torch.no_grad(): 178 | action = self.actor(state).squeeze() 179 | return action.cpu().numpy() 180 | 181 | def sample_action(self, state): 182 | state = torch.FloatTensor(state.reshape(1, -1)).to(self.device) 183 | state_rpt = torch.repeat_interleave(state, repeats=50, dim=0) 184 | with torch.no_grad(): 185 | action = self.actor(state_rpt) 186 | q_value = self.critic_target(state_rpt, action) 187 | q_mean = q_value.mean(dim=1,keepdim=True).flatten() 188 | q_std = q_value.std(dim=1,keepdim=True).flatten() 189 | q_lcb = q_mean - self.lcb_coef * q_std 190 | idx = torch.multinomial(F.softmax(q_lcb, dim=0), 1) 191 | return action[idx].cpu().data.numpy().flatten() 192 | 193 | def save_model(self, dir, id=None): 194 | if id is not None: 195 | torch.save(self.actor.state_dict(), f'{dir}/actor_{id}.pth') 196 | torch.save(self.critic.state_dict(), f'{dir}/critic_{id}.pth') 197 | else: 198 | torch.save(self.actor.state_dict(), f'{dir}/actor.pth') 199 | torch.save(self.critic.state_dict(), f'{dir}/critic.pth') 200 | 201 | def load_model(self, dir, id=None): 202 | if id is not None: 203 | self.actor.load_state_dict(torch.load(f'{dir}/actor_{id}.pth')) 204 | self.critic.load_state_dict(torch.load(f'{dir}/critic_{id}.pth')) 205 | else: 206 | self.actor.load_state_dict(torch.load(f'{dir}/actor.pth')) 207 | self.critic.load_state_dict(torch.load(f'{dir}/critic.pth')) 208 | 209 | 210 | 211 | -------------------------------------------------------------------------------- /agents/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from agents.helpers import SinusoidalPosEmb 7 | 8 | from typing import Tuple 9 | import math 10 | 11 | 12 | class VectorizedLinear(nn.Module): 13 | def __init__(self, in_features: int, out_features: int, ensemble_size: int): 14 | super().__init__() 15 | self.in_features = in_features 16 | self.out_features = out_features 17 | self.ensemble_size = ensemble_size 18 | 19 | self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features)) 20 | self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features)) 21 | 22 | self.reset_parameters() 23 | 24 | def reset_parameters(self): 25 | for layer in range(self.ensemble_size): 26 | nn.init.kaiming_uniform_(self.weight[layer], a=math.sqrt(5)) 27 | 28 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0]) 29 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 30 | nn.init.uniform_(self.bias, -bound, bound) 31 | 32 | def forward(self, input: torch.Tensor) -> torch.Tensor: 33 | assert len(input.shape) == 3, "shape should be [num_models, batch_size, in_features]" 34 | return 
torch.bmm(input, self.weight) + self.bias 35 | 36 | def extra_repr(self) -> str: 37 | return f'in_features={self.in_features}, out_features={self.out_features}, ensemble_size={self.ensemble_size}' 38 | 39 | class EnsembleCritic(nn.Module): 40 | def __init__( 41 | self, 42 | state_dim: int, 43 | action_dim: int, 44 | hidden_dim: int = 256, 45 | num_critics: int = 100, 46 | layernorm: bool = False, 47 | edac_init: bool = True 48 | ): 49 | super().__init__() 50 | self.ensemble = nn.Sequential( 51 | VectorizedLinear(state_dim + action_dim, hidden_dim, num_critics), 52 | nn.LayerNorm(hidden_dim) if layernorm else nn.Identity(), 53 | nn.Mish(), 54 | VectorizedLinear(hidden_dim, hidden_dim, num_critics), 55 | nn.LayerNorm(hidden_dim) if layernorm else nn.Identity(), 56 | nn.Mish(), 57 | VectorizedLinear(hidden_dim, hidden_dim, num_critics), 58 | nn.LayerNorm(hidden_dim) if layernorm else nn.Identity(), 59 | nn.Mish(), 60 | VectorizedLinear(hidden_dim, 1, num_critics) 61 | ) 62 | if edac_init: 63 | for layer in self.ensemble[::3]: 64 | torch.nn.init.constant_(layer.bias, 0.1) 65 | 66 | torch.nn.init.uniform_(self.ensemble[-1].weight, -3e-3, 3e-3) 67 | torch.nn.init.uniform_(self.ensemble[-1].bias, -3e-3, 3e-3) 68 | 69 | self.num_critics = num_critics 70 | 71 | def forward(self, state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 72 | # [batch_size, state_dim + action_dim] 73 | state_action = torch.cat([state, action], dim=-1) 74 | if state_action.dim() != 3: 75 | assert state_action.dim() == 2 76 | # [num_critics, batch_size, state_dim + action_dim] 77 | state_action = state_action.unsqueeze(0).repeat_interleave(self.num_critics, dim=0) 78 | assert state_action.dim() == 3 79 | assert state_action.shape[0] == self.num_critics 80 | # [num_critics, batch_size] 81 | out = self.ensemble(state_action).squeeze(-1) 82 | return out.permute(1,0) 83 | 84 | def q_min(self, state, action): 85 | # we dont have two to avoid over-estimation 86 | out = self.forward(state, action) 87 | return out 88 | 89 | 90 | class EnsembleDoubleCritic(nn.Module): 91 | def __init__( 92 | self, 93 | state_dim: int, 94 | action_dim: int, 95 | hidden_dim: int = 256, 96 | num_critics: int = 100, 97 | layernorm: bool = False, 98 | edac_init: bool = True 99 | ): 100 | super().__init__() 101 | self.ensemble1 = nn.Sequential( 102 | VectorizedLinear(state_dim + action_dim, hidden_dim, num_critics), 103 | nn.LayerNorm(hidden_dim) if layernorm else nn.Identity(), 104 | nn.Mish(), 105 | VectorizedLinear(hidden_dim, hidden_dim, num_critics), 106 | nn.LayerNorm(hidden_dim) if layernorm else nn.Identity(), 107 | nn.Mish(), 108 | VectorizedLinear(hidden_dim, hidden_dim, num_critics), 109 | nn.LayerNorm(hidden_dim) if layernorm else nn.Identity(), 110 | nn.Mish(), 111 | VectorizedLinear(hidden_dim, 1, num_critics) 112 | ) 113 | self.ensemble2 = nn.Sequential( 114 | VectorizedLinear(state_dim + action_dim, hidden_dim, num_critics), 115 | nn.LayerNorm(hidden_dim) if layernorm else nn.Identity(), 116 | nn.Mish(), 117 | VectorizedLinear(hidden_dim, hidden_dim, num_critics), 118 | nn.LayerNorm(hidden_dim) if layernorm else nn.Identity(), 119 | nn.Mish(), 120 | VectorizedLinear(hidden_dim, hidden_dim, num_critics), 121 | nn.LayerNorm(hidden_dim) if layernorm else nn.Identity(), 122 | nn.Mish(), 123 | VectorizedLinear(hidden_dim, 1, num_critics) 124 | ) 125 | 126 | if edac_init: 127 | for layer in self.ensemble1[::3]: 128 | torch.nn.init.constant_(layer.bias, 0.1) 129 | 130 | torch.nn.init.uniform_(self.ensemble1[-1].weight, 
-3e-3, 3e-3) 131 | torch.nn.init.uniform_(self.ensemble1[-1].bias, -3e-3, 3e-3) 132 | 133 | for layer in self.ensemble2[::3]: 134 | torch.nn.init.constant_(layer.bias, 0.1) 135 | 136 | torch.nn.init.uniform_(self.ensemble2[-1].weight, -3e-3, 3e-3) 137 | torch.nn.init.uniform_(self.ensemble2[-1].bias, -3e-3, 3e-3) 138 | 139 | self.num_critics = num_critics 140 | 141 | def forward(self, state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 142 | # [batch_size, state_dim + action_dim] 143 | state_action = torch.cat([state, action], dim=-1) 144 | if state_action.dim() != 3: 145 | assert state_action.dim() == 2 146 | # [num_critics, batch_size, state_dim + action_dim] 147 | state_action = state_action.unsqueeze(0).repeat_interleave(self.num_critics, dim=0) 148 | assert state_action.dim() == 3 149 | assert state_action.shape[0] == self.num_critics 150 | # [num_critics, batch_size] 151 | out1 = self.ensemble1(state_action).squeeze(-1) 152 | out2 = self.ensemble2(state_action).squeeze(-1) 153 | return out1.permute(1,0), out2.permute(1,0) # [batch_size, num_critic] 154 | 155 | def q_min(self, state, action): 156 | out1, out2 = self.forward(state, action) 157 | return torch.min(out1, out2) 158 | 159 | 160 | class Critic(nn.Module): 161 | def __init__(self, state_dim, action_dim, hidden_dim=256): 162 | super(Critic, self).__init__() 163 | self.q1_model = nn.Sequential(nn.Linear(state_dim + action_dim, hidden_dim), 164 | nn.Mish(), 165 | nn.Linear(hidden_dim, hidden_dim), 166 | nn.Mish(), 167 | nn.Linear(hidden_dim, hidden_dim), 168 | nn.Mish(), 169 | nn.Linear(hidden_dim, 1)) 170 | 171 | self.q2_model = nn.Sequential(nn.Linear(state_dim + action_dim, hidden_dim), 172 | nn.Mish(), 173 | nn.Linear(hidden_dim, hidden_dim), 174 | nn.Mish(), 175 | nn.Linear(hidden_dim, hidden_dim), 176 | nn.Mish(), 177 | nn.Linear(hidden_dim, 1)) 178 | 179 | def forward(self, state, action): 180 | x = torch.cat([state, action], dim=-1) 181 | return self.q1_model(x), self.q2_model(x) 182 | 183 | def q1(self, state, action): 184 | x = torch.cat([state, action], dim=-1) 185 | return self.q1_model(x) 186 | 187 | def q_min(self, state, action): 188 | q1, q2 = self.forward(state, action) 189 | return torch.min(q1, q2) 190 | 191 | 192 | class DiffusionMLP(nn.Module): 193 | """ 194 | MLP Model 195 | """ 196 | def __init__(self, state_dim, action_dim, t_dim=16): 197 | super().__init__() 198 | 199 | self.time_mlp = nn.Sequential( 200 | SinusoidalPosEmb(t_dim), 201 | nn.Linear(t_dim, t_dim * 2), 202 | nn.Mish(), 203 | nn.Linear(t_dim * 2, t_dim), 204 | ) 205 | 206 | input_dim = state_dim + action_dim + t_dim 207 | self.mid_layer = nn.Sequential(nn.Linear(input_dim, 256), 208 | nn.Mish(), 209 | nn.Linear(256, 256), 210 | nn.Mish(), 211 | nn.Linear(256, 256), 212 | nn.Mish()) 213 | 214 | self.final_layer = nn.Linear(256, action_dim) 215 | 216 | def forward(self, x, time, state): 217 | if isinstance(time, int): 218 | batch_size = x.shape[0] 219 | time = torch.full((batch_size,), time, device=x.device, dtype=torch.long) 220 | 221 | t = self.time_mlp(time) 222 | x = torch.cat([x, t, state], dim=1) 223 | x = self.mid_layer(x) 224 | 225 | return self.final_layer(x) 226 | 227 | 228 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gym 3 | import numpy as np 4 | import os 5 | import torch 6 | import json 7 | 8 | import d4rl 9 | from utils import utils 10 | from 
utils.data_sampler import Data_Sampler 11 | from utils.logger import logger, setup_logger 12 | from torch.utils.tensorboard import SummaryWriter 13 | 14 | hyperparameters = { 15 | 'halfcheetah-medium-v2': {'lr': 3e-4, 'eta': 1.0, 'max_q_backup': False, 'reward_tune': 'no', 'num_epochs': 2000, 'gn': 4.0, 'loss_type': 'MLL', 'action_clip': False}, 16 | 'hopper-medium-v2': {'lr': 3e-4, 'eta': 1.0, 'max_q_backup': False, 'reward_tune': 'no', 'num_epochs': 2000, 'gn': 4.0, 'loss_type': 'MLL', 'action_clip': False}, 17 | 'walker2d-medium-v2': {'lr': 3e-4, 'eta': 1.0, 'max_q_backup': False, 'reward_tune': 'no', 'num_epochs': 2000, 'gn': 4.0, 'loss_type': 'MLL', 'action_clip': False}, 18 | 'halfcheetah-medium-replay-v2': {'lr': 3e-4, 'eta': 1.0, 'max_q_backup': False, 'reward_tune': 'no', 'num_epochs': 2000, 'gn': 4.0, 'loss_type': 'MLL', 'action_clip': False}, 19 | 'hopper-medium-replay-v2': {'lr': 3e-4, 'eta': 1.0, 'max_q_backup': False, 'reward_tune': 'no', 'num_epochs': 2000, 'gn': 4.0, 'loss_type': 'MLL', 'action_clip': False}, 20 | 'walker2d-medium-replay-v2': {'lr': 3e-4, 'eta': 1.0, 'max_q_backup': False, 'reward_tune': 'no', 'num_epochs': 2000, 'gn': 4.0, 'loss_type': 'MLL', 'action_clip': False}, 21 | 'halfcheetah-medium-expert-v2': {'lr': 3e-4, 'eta': 1.0, 'max_q_backup': False, 'reward_tune': 'no', 'num_epochs': 2000, 'gn': 4.0, 'loss_type': 'MLL', 'action_clip': False}, 22 | 'hopper-medium-expert-v2': {'lr': 3e-4, 'eta': 1.0, 'max_q_backup': False, 'reward_tune': 'no', 'num_epochs': 2000, 'gn': 4.0, 'loss_type': 'MLL', 'action_clip': False}, 23 | 'walker2d-medium-expert-v2': {'lr': 3e-4, 'eta': 1.0, 'max_q_backup': False, 'reward_tune': 'no', 'num_epochs': 2000, 'gn': 4.0, 'loss_type': 'MLL', 'action_clip': False}, 24 | 25 | 'antmaze-umaze-v0': {'lr': 3e-4, 'eta': 2.0, 'max_q_backup': False, 'reward_tune': 'cql_antmaze', 'num_epochs': 1000, 'gn': 4.0, 'loss_type': 'NML', 'action_clip': True}, 26 | 'antmaze-umaze-diverse-v0': {'lr': 3e-4, 'eta': 2.0, 'max_q_backup': True, 'reward_tune': 'cql_antmaze', 'num_epochs': 1000, 'gn': 4.0, 'loss_type': 'NML', 'action_clip': True}, 27 | 'antmaze-medium-play-v0': {'lr': 3e-4, 'eta': 2.0, 'max_q_backup': True, 'reward_tune': 'cql_antmaze', 'num_epochs': 1000, 'gn': 4.0, 'loss_type': 'NML', 'action_clip': True}, 28 | 'antmaze-medium-diverse-v0': {'lr': 3e-4, 'eta': 2.0, 'max_q_backup': True, 'reward_tune': 'cql_antmaze', 'num_epochs': 1000, 'gn': 4.0, 'loss_type': 'NML', 'action_clip': True}, 29 | 'antmaze-large-play-v0': {'lr': 3e-4, 'eta': 2.0, 'max_q_backup': True, 'reward_tune': 'cql_antmaze', 'num_epochs': 1000, 'gn': 4.0, 'loss_type': 'NML', 'action_clip': True}, 30 | 'antmaze-large-diverse-v0': {'lr': 3e-4, 'eta': 2.0, 'max_q_backup': True, 'reward_tune': 'cql_antmaze', 'num_epochs': 1000, 'gn': 4.0, 'loss_type': 'NML', 'action_clip': True}, 31 | 32 | 'pen-human-v1': {'lr': 3e-5, 'eta': 0.1, 'max_q_backup': False, 'reward_tune': 'normalize', 'num_epochs': 1000, 'gn': 8.0, 'loss_type': 'NML', 'action_clip': True}, 33 | 'pen-cloned-v1': {'lr': 3e-5, 'eta': 0.1, 'max_q_backup': False, 'reward_tune': 'normalize', 'num_epochs': 1000, 'gn': 8.0, 'loss_type': 'NML', 'action_clip': True}, 34 | 35 | 'kitchen-complete-v0': {'lr': 3e-4, 'eta': 0.005, 'max_q_backup': False, 'reward_tune': 'no', 'num_epochs': 1000, 'gn': 10., 'loss_type': 'MLL', 'action_clip': False}, 36 | 'kitchen-partial-v0': {'lr': 3e-4, 'eta': 0.005, 'max_q_backup': False, 'reward_tune': 'no', 'num_epochs': 1000, 'gn': 10., 'loss_type': 'MLL', 'action_clip': False}, 37 | 
'kitchen-mixed-v0': {'lr': 3e-4, 'eta': 0.005, 'max_q_backup': False, 'reward_tune': 'no', 'num_epochs': 1000, 'gn': 10., 'loss_type': 'MLL', 'action_clip': False}, 38 | } 39 | 40 | def train_agent(env, state_dim, action_dim, max_action, output_dir, args): 41 | # Load buffer 42 | dataset = d4rl.qlearning_dataset(env) 43 | data_sampler = Data_Sampler(dataset, args.device, args.reward_tune) 44 | utils.print_banner('Loaded buffer') 45 | 46 | if args.algo == 'eql': 47 | from agents.eql_diffusion import Diffusion_EQL as Agent 48 | agent = Agent(state_dim=state_dim, 49 | action_dim=action_dim, 50 | max_action=max_action, 51 | device=args.device, 52 | discount=args.discount, 53 | tau=args.tau, 54 | max_q_backup=args.max_q_backup, 55 | schedule=args.schedule, 56 | n_timesteps=args.T, 57 | eta=args.eta, 58 | lr=args.lr, 59 | lr_decay=args.lr_decay, 60 | lr_maxt=args.num_epochs, 61 | grad_norm=args.gn, 62 | num_critics=args.num_critics, 63 | pess_method=args.pess, 64 | ent_coef=args.ent_coef, 65 | lcb_coef=args.lcb_coef, 66 | loss_type=args.loss_type, 67 | action_clip=args.action_clip) 68 | elif args.algo == 'bc': 69 | from agents.bc_diffusion import Diffusion_BC as Agent 70 | agent = Agent(state_dim=state_dim, 71 | action_dim=action_dim, 72 | max_action=max_action, 73 | device=args.device, 74 | discount=args.discount, 75 | tau=args.tau, 76 | schedule=args.schedule, 77 | n_timesteps=args.T, 78 | lr=args.lr, 79 | loss_type=args.loss_type, 80 | action_clip=args.action_clip) 81 | 82 | early_stop = False 83 | stop_check = utils.EarlyStopping(tolerance=1, min_delta=0.) 84 | writer = None # SummaryWriter(output_dir) 85 | 86 | evaluations = [] 87 | training_iters = 0 88 | max_timesteps = args.num_epochs * args.num_steps_per_epoch 89 | metric = 100. 90 | utils.print_banner(f"Training Start", separator="*", num_star=90) 91 | while (training_iters < max_timesteps) and (not early_stop): 92 | iterations = int(args.eval_freq * args.num_steps_per_epoch) 93 | loss_metric = agent.train(data_sampler, 94 | iterations=iterations, 95 | batch_size=args.batch_size, 96 | log_writer=writer) 97 | training_iters += iterations 98 | curr_epoch = int(training_iters // int(args.num_steps_per_epoch)) 99 | 100 | # Logging 101 | utils.print_banner(f"Train step: {training_iters}", separator="*", num_star=90) 102 | logger.record_tabular('Trained Epochs', curr_epoch) 103 | logger.record_tabular('BC Loss', np.mean(loss_metric['bc_loss'])) 104 | logger.record_tabular('QL Loss', np.mean(loss_metric['ql_loss'])) 105 | logger.record_tabular('Actor Loss', np.mean(loss_metric['actor_loss'])) 106 | logger.record_tabular('Critic Loss', np.mean(loss_metric['critic_loss'])) 107 | logger.record_tabular('Entropy', np.mean(loss_metric['entropy'])) 108 | if 'ent_coef' in loss_metric.keys(): 109 | logger.record_tabular('Entropy Coef', np.mean(loss_metric['ent_coef'])) 110 | if 'ent_coef_loss' in loss_metric.keys(): 111 | logger.record_tabular('Entropy Coef Loss', np.mean(loss_metric['ent_coef_loss'])) 112 | logger.dump_tabular() 113 | 114 | 115 | # Evaluation 116 | eval_res, eval_res_std, eval_norm_res, eval_norm_res_std = eval_policy(agent, args.env_name, args.seed, 117 | eval_episodes=args.eval_episodes) 118 | evaluations.append([eval_res, eval_res_std, eval_norm_res, eval_norm_res_std, 119 | np.mean(loss_metric['bc_loss']), np.mean(loss_metric['ql_loss']), 120 | np.mean(loss_metric['actor_loss']), np.mean(loss_metric['critic_loss']), 121 | curr_epoch]) 122 | np.save(os.path.join(output_dir, "eval"), evaluations) 123 | 
logger.record_tabular('Average Episodic Reward', eval_res) 124 | logger.record_tabular('Average Episodic N-Reward', eval_norm_res) 125 | logger.dump_tabular() 126 | 127 | bc_loss = np.mean(loss_metric['bc_loss']) 128 | if args.early_stop: 129 | early_stop = stop_check(metric, bc_loss) 130 | 131 | metric = bc_loss 132 | 133 | if args.save_best_model: 134 | agent.save_model(output_dir, curr_epoch) 135 | 136 | # Model Selection: online or offline 137 | scores = np.array(evaluations) 138 | 139 | best_id = np.argmax(scores[:, 2]) 140 | best_res = {'model selection': args.algo, 'epoch': scores[best_id, -1], 141 | 'best normalized score avg': scores[best_id, 2], 142 | 'best normalized score std': scores[best_id, 3], 143 | 'best raw score avg': scores[best_id, 0], 144 | 'best raw score std': scores[best_id, 1]} 145 | 146 | with open(os.path.join(output_dir, f"best_score_{args.algo}.txt"), 'w') as f: 147 | f.write(json.dumps(best_res)) 148 | 149 | # writer.close() 150 | 151 | 152 | # Runs policy for X episodes and returns average reward 153 | # A fixed seed is used for the eval environment 154 | def eval_policy(policy, env_name, seed, eval_episodes=10): 155 | eval_env = gym.make(env_name) 156 | eval_env.seed(seed + 100) 157 | 158 | scores = [] 159 | for _ in range(eval_episodes): 160 | traj_return = 0. 161 | state, done = eval_env.reset(), False 162 | while not done: 163 | action = policy.sample_action(np.array(state)) 164 | state, reward, done, _ = eval_env.step(action) 165 | traj_return += reward 166 | scores.append(traj_return) 167 | 168 | avg_reward = np.mean(scores) 169 | std_reward = np.std(scores) 170 | 171 | normalized_scores = [eval_env.get_normalized_score(s) for s in scores] 172 | avg_norm_score = eval_env.get_normalized_score(avg_reward) 173 | std_norm_score = np.std(normalized_scores) 174 | 175 | utils.print_banner(f"Evaluation over {eval_episodes} episodes: {avg_reward:.2f} {avg_norm_score:.2f}") 176 | return avg_reward, std_reward, avg_norm_score, std_norm_score 177 | 178 | 179 | if __name__ == "__main__": 180 | parser = argparse.ArgumentParser() 181 | ### Experimental Setups ### 182 | parser.add_argument("--exp", default='exp_1', type=str) # Experiment ID 183 | parser.add_argument('--device', default=0, type=int) # device, {"cpu", "cuda", "cuda:0", "cuda:1"}, etc 184 | parser.add_argument("--env_name", default="walker2d-medium-expert-v2", type=str) # OpenAI gym environment name 185 | parser.add_argument("--dir", default="results", type=str) # Logging directory 186 | parser.add_argument("--seed", default=0, type=int) # Sets Gym, PyTorch and Numpy seeds 187 | parser.add_argument("--num_steps_per_epoch", default=1000, type=int) 188 | parser.add_argument("--eval_freq", default=50, type=int) 189 | 190 | ### Optimization Setups ### 191 | parser.add_argument("--batch_size", default=256, type=int) 192 | parser.add_argument("--lr_decay", action='store_true') 193 | parser.add_argument('--early_stop', action='store_true') 194 | parser.add_argument('--save_best_model', action='store_true') 195 | 196 | ### RL Parameters ### 197 | parser.add_argument("--discount", default=0.99, type=float) 198 | parser.add_argument("--tau", default=0.005, type=float) 199 | 200 | ### Diffusion Setting ### 201 | parser.add_argument("--T", default=5, type=int) 202 | parser.add_argument("--schedule", default='cosine', type=str) 203 | ### Algo Choice ### 204 | parser.add_argument("--algo", default="eql", type=str) # ['bc', 'eql'] 205 | 206 | parser.add_argument("--pess", default='lcb', type=str, help="['min', 'lcb']") 
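    # Q-ensemble and entropy-regularization settings: with --pess 'lcb' the agent uses the
    # pessimistic estimate mean(Q_i) - lcb_coef * std(Q_i) over the num_critics ensemble
    # members, and ent_coef weights the entropy bonus in the actor loss.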
207 | parser.add_argument("--num_critics", default=64, type=int) 208 | parser.add_argument("--ent_coef", default=0.01, type=float) 209 | parser.add_argument("--lcb_coef", default=4, type=float) 210 | 211 | # parser.add_argument("--lr", default=3e-4, type=float) 212 | # parser.add_argument("--eta", default=1.0, type=float) 213 | # parser.add_argument("--reward_tune", default='no', type=str) 214 | # parser.add_argument("--gn", default=-1.0, type=float) 215 | 216 | args = parser.parse_args() 217 | args.device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu" 218 | args.output_dir = f'{args.dir}/{args.algo}' 219 | 220 | args.num_epochs = hyperparameters[args.env_name]['num_epochs'] 221 | args.eval_episodes = 10 if 'v2' in args.env_name else 100 222 | 223 | args.lr = hyperparameters[args.env_name]['lr'] 224 | args.eta = hyperparameters[args.env_name]['eta'] 225 | args.max_q_backup = hyperparameters[args.env_name]['max_q_backup'] 226 | args.reward_tune = hyperparameters[args.env_name]['reward_tune'] 227 | args.gn = hyperparameters[args.env_name]['gn'] 228 | args.loss_type = hyperparameters[args.env_name]['loss_type'] 229 | args.action_clip = hyperparameters[args.env_name]['action_clip'] 230 | 231 | # Setup Logging 232 | file_name_prex = f"{args.env_name}|{args.exp}|" 233 | file_name = f"diffusion-{args.algo}|T-{args.T}" 234 | if args.lr_decay: file_name += '|lr_decay' 235 | file_name += f'|{args.seed}' 236 | file_name += f'|ent_coef-{args.ent_coef}' 237 | 238 | if 'eql' in args.algo: 239 | file_name += f'|pess-{args.pess}-{args.lcb_coef}' 240 | file_name += f'|num_critics-{args.num_critics}' 241 | 242 | results_dir = os.path.join(args.output_dir, file_name_prex+file_name) 243 | if not os.path.exists(results_dir): 244 | os.makedirs(results_dir) 245 | utils.print_banner(f"Saving location: {results_dir}") 246 | # if os.path.exists(os.path.join(results_dir, 'variant.json')): 247 | # raise AssertionError("Experiment under this setting has been done!") 248 | variant = vars(args) 249 | variant.update(version=f"EnsembleQ-Diffusion-Policies-RL") 250 | 251 | env = gym.make(args.env_name) 252 | 253 | env.seed(args.seed) 254 | torch.manual_seed(args.seed) 255 | np.random.seed(args.seed) 256 | 257 | state_dim = env.observation_space.shape[0] 258 | action_dim = env.action_space.shape[0] 259 | max_action = float(env.action_space.high[0]) 260 | 261 | variant.update(state_dim=state_dim) 262 | variant.update(action_dim=action_dim) 263 | variant.update(max_action=max_action) 264 | setup_logger(os.path.basename(results_dir), variant=variant, log_dir=results_dir) 265 | utils.print_banner(f"Env: {args.env_name}, state_dim: {state_dim}, action_dim: {action_dim}") 266 | 267 | 268 | train_agent(env, state_dim, action_dim, max_action, results_dir, args) 269 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on rllab's logger. 
3 | 4 | https://github.com/rll/rllab 5 | """ 6 | from enum import Enum 7 | from contextlib import contextmanager 8 | import numpy as np 9 | import os 10 | import os.path as osp 11 | import sys 12 | import datetime 13 | import dateutil.tz 14 | import csv 15 | import json 16 | import pickle 17 | import errno 18 | from collections import OrderedDict 19 | from numbers import Number 20 | import os 21 | 22 | from tabulate import tabulate 23 | import os.path as osp 24 | 25 | def dict_to_safe_json(d): 26 | """ 27 | Convert each value in the dictionary into a JSON'able primitive. 28 | :param d: 29 | :return: 30 | """ 31 | new_d = {} 32 | for key, item in d.items(): 33 | if safe_json(item): 34 | new_d[key] = item 35 | else: 36 | if isinstance(item, dict): 37 | new_d[key] = dict_to_safe_json(item) 38 | else: 39 | new_d[key] = str(item) 40 | return new_d 41 | 42 | 43 | def safe_json(data): 44 | if data is None: 45 | return True 46 | elif isinstance(data, (bool, int, float)): 47 | return True 48 | elif isinstance(data, (tuple, list)): 49 | return all(safe_json(x) for x in data) 50 | elif isinstance(data, dict): 51 | return all(isinstance(k, str) and safe_json(v) for k, v in data.items()) 52 | return False 53 | 54 | def create_exp_name(exp_prefix, exp_id=0, seed=0): 55 | """ 56 | Create a semi-unique experiment name that has a timestamp 57 | :param exp_prefix: 58 | :param exp_id: 59 | :return: 60 | """ 61 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 62 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 63 | return "%s_%s_%04d--s-%d" % (exp_prefix, timestamp, exp_id, seed) 64 | 65 | def create_log_dir( 66 | exp_prefix, 67 | exp_id=0, 68 | seed=0, 69 | base_log_dir=None, 70 | include_exp_prefix_sub_dir=True, 71 | ): 72 | """ 73 | Creates and returns a unique log directory. 74 | :param exp_prefix: All experiments with this prefix will have log 75 | directories be under this directory. 76 | :param exp_id: The number of the specific experiment run within this 77 | experiment. 78 | :param base_log_dir: The directory where all log should be saved. 79 | :return: 80 | """ 81 | exp_name = create_exp_name(exp_prefix, exp_id=exp_id, 82 | seed=seed) 83 | if base_log_dir is None: 84 | base_log_dir = './data' 85 | if include_exp_prefix_sub_dir: 86 | log_dir = osp.join(base_log_dir, exp_prefix.replace("_", "-"), exp_name) 87 | else: 88 | log_dir = osp.join(base_log_dir, exp_name) 89 | if osp.exists(log_dir): 90 | print("WARNING: Log directory already exists {}".format(log_dir), flush=True) 91 | os.makedirs(log_dir, exist_ok=True) 92 | return log_dir 93 | 94 | 95 | def setup_logger( 96 | exp_prefix="default", 97 | variant=None, 98 | text_log_file="debug.log", 99 | variant_log_file="variant.json", 100 | tabular_log_file="progress.csv", 101 | snapshot_mode="last", 102 | snapshot_gap=1, 103 | log_tabular_only=False, 104 | log_dir=None, 105 | git_infos=None, 106 | script_name=None, 107 | **create_log_dir_kwargs 108 | ): 109 | """ 110 | Set up logger to have some reasonable default settings. 111 | Will save log output to 112 | based_log_dir/exp_prefix/exp_name. 113 | exp_name will be auto-generated to be unique. 114 | If log_dir is specified, then that directory is used as the output dir. 115 | :param exp_prefix: The sub-directory for this specific experiment. 
116 | :param variant: 117 | :param text_log_file: 118 | :param variant_log_file: 119 | :param tabular_log_file: 120 | :param snapshot_mode: 121 | :param log_tabular_only: 122 | :param snapshot_gap: 123 | :param log_dir: 124 | :param git_infos: 125 | :param script_name: If set, save the script name to this. 126 | :return: 127 | """ 128 | first_time = log_dir is None 129 | if first_time: 130 | log_dir = create_log_dir(exp_prefix, **create_log_dir_kwargs) 131 | 132 | if variant is not None: 133 | logger.log("Variant:") 134 | logger.log(json.dumps(dict_to_safe_json(variant), indent=2)) 135 | variant_log_path = osp.join(log_dir, variant_log_file) 136 | logger.log_variant(variant_log_path, variant) 137 | 138 | tabular_log_path = osp.join(log_dir, tabular_log_file) 139 | text_log_path = osp.join(log_dir, text_log_file) 140 | 141 | logger.add_text_output(text_log_path) 142 | if first_time: 143 | logger.add_tabular_output(tabular_log_path) 144 | else: 145 | logger._add_output(tabular_log_path, logger._tabular_outputs, 146 | logger._tabular_fds, mode='a') 147 | for tabular_fd in logger._tabular_fds: 148 | logger._tabular_header_written.add(tabular_fd) 149 | logger.set_snapshot_dir(log_dir) 150 | logger.set_snapshot_mode(snapshot_mode) 151 | logger.set_snapshot_gap(snapshot_gap) 152 | logger.set_log_tabular_only(log_tabular_only) 153 | exp_name = log_dir.split("/")[-1] 154 | logger.push_prefix("[%s] " % exp_name) 155 | 156 | if script_name is not None: 157 | with open(osp.join(log_dir, "script_name.txt"), "w") as f: 158 | f.write(script_name) 159 | return log_dir 160 | 161 | 162 | def create_stats_ordered_dict( 163 | name, 164 | data, 165 | stat_prefix=None, 166 | always_show_all_stats=True, 167 | exclude_max_min=False, 168 | ): 169 | if stat_prefix is not None: 170 | name = "{}{}".format(stat_prefix, name) 171 | if isinstance(data, Number): 172 | return OrderedDict({name: data}) 173 | 174 | if len(data) == 0: 175 | return OrderedDict() 176 | 177 | if isinstance(data, tuple): 178 | ordered_dict = OrderedDict() 179 | for number, d in enumerate(data): 180 | sub_dict = create_stats_ordered_dict( 181 | "{0}_{1}".format(name, number), 182 | d, 183 | ) 184 | ordered_dict.update(sub_dict) 185 | return ordered_dict 186 | 187 | if isinstance(data, list): 188 | try: 189 | iter(data[0]) 190 | except TypeError: 191 | pass 192 | else: 193 | data = np.concatenate(data) 194 | 195 | if (isinstance(data, np.ndarray) and data.size == 1 196 | and not always_show_all_stats): 197 | return OrderedDict({name: float(data)}) 198 | 199 | stats = OrderedDict([ 200 | (name + ' Mean', np.mean(data)), 201 | (name + ' Std', np.std(data)), 202 | ]) 203 | if not exclude_max_min: 204 | stats[name + ' Max'] = np.max(data) 205 | stats[name + ' Min'] = np.min(data) 206 | return stats 207 | 208 | 209 | class TerminalTablePrinter(object): 210 | def __init__(self): 211 | self.headers = None 212 | self.tabulars = [] 213 | 214 | def print_tabular(self, new_tabular): 215 | if self.headers is None: 216 | self.headers = [x[0] for x in new_tabular] 217 | else: 218 | assert len(self.headers) == len(new_tabular) 219 | self.tabulars.append([x[1] for x in new_tabular]) 220 | self.refresh() 221 | 222 | def refresh(self): 223 | import os 224 | rows, columns = os.popen('stty size', 'r').read().split() 225 | tabulars = self.tabulars[-(int(rows) - 3):] 226 | sys.stdout.write("\x1b[2J\x1b[H") 227 | sys.stdout.write(tabulate(tabulars, self.headers)) 228 | sys.stdout.write("\n") 229 | 230 | 231 | class MyEncoder(json.JSONEncoder): 232 | def default(self, 
o): 233 | if isinstance(o, type): 234 | return {'$class': o.__module__ + "." + o.__name__} 235 | elif isinstance(o, Enum): 236 | return { 237 | '$enum': o.__module__ + "." + o.__class__.__name__ + '.' + o.name 238 | } 239 | elif callable(o): 240 | return { 241 | '$function': o.__module__ + "." + o.__name__ 242 | } 243 | return json.JSONEncoder.default(self, o) 244 | 245 | 246 | def mkdir_p(path): 247 | try: 248 | os.makedirs(path) 249 | except OSError as exc: # Python >2.5 250 | if exc.errno == errno.EEXIST and os.path.isdir(path): 251 | pass 252 | else: 253 | raise 254 | 255 | 256 | class Logger(object): 257 | def __init__(self): 258 | self._prefixes = [] 259 | self._prefix_str = '' 260 | 261 | self._tabular_prefixes = [] 262 | self._tabular_prefix_str = '' 263 | 264 | self._tabular = [] 265 | 266 | self._text_outputs = [] 267 | self._tabular_outputs = [] 268 | 269 | self._text_fds = {} 270 | self._tabular_fds = {} 271 | self._tabular_header_written = set() 272 | 273 | self._snapshot_dir = None 274 | self._snapshot_mode = 'all' 275 | self._snapshot_gap = 1 276 | 277 | self._log_tabular_only = False 278 | self._header_printed = False 279 | self.table_printer = TerminalTablePrinter() 280 | 281 | def reset(self): 282 | self.__init__() 283 | 284 | def _add_output(self, file_name, arr, fds, mode='a'): 285 | if file_name not in arr: 286 | mkdir_p(os.path.dirname(file_name)) 287 | arr.append(file_name) 288 | fds[file_name] = open(file_name, mode) 289 | 290 | def _remove_output(self, file_name, arr, fds): 291 | if file_name in arr: 292 | fds[file_name].close() 293 | del fds[file_name] 294 | arr.remove(file_name) 295 | 296 | def push_prefix(self, prefix): 297 | self._prefixes.append(prefix) 298 | self._prefix_str = ''.join(self._prefixes) 299 | 300 | def add_text_output(self, file_name): 301 | self._add_output(file_name, self._text_outputs, self._text_fds, 302 | mode='a') 303 | 304 | def remove_text_output(self, file_name): 305 | self._remove_output(file_name, self._text_outputs, self._text_fds) 306 | 307 | def add_tabular_output(self, file_name, relative_to_snapshot_dir=False): 308 | if relative_to_snapshot_dir: 309 | file_name = osp.join(self._snapshot_dir, file_name) 310 | self._add_output(file_name, self._tabular_outputs, self._tabular_fds, 311 | mode='w') 312 | 313 | def remove_tabular_output(self, file_name, relative_to_snapshot_dir=False): 314 | if relative_to_snapshot_dir: 315 | file_name = osp.join(self._snapshot_dir, file_name) 316 | if self._tabular_fds[file_name] in self._tabular_header_written: 317 | self._tabular_header_written.remove(self._tabular_fds[file_name]) 318 | self._remove_output(file_name, self._tabular_outputs, self._tabular_fds) 319 | 320 | def set_snapshot_dir(self, dir_name): 321 | self._snapshot_dir = dir_name 322 | 323 | def get_snapshot_dir(self, ): 324 | return self._snapshot_dir 325 | 326 | def get_snapshot_mode(self, ): 327 | return self._snapshot_mode 328 | 329 | def set_snapshot_mode(self, mode): 330 | self._snapshot_mode = mode 331 | 332 | def get_snapshot_gap(self, ): 333 | return self._snapshot_gap 334 | 335 | def set_snapshot_gap(self, gap): 336 | self._snapshot_gap = gap 337 | 338 | def set_log_tabular_only(self, log_tabular_only): 339 | self._log_tabular_only = log_tabular_only 340 | 341 | def get_log_tabular_only(self, ): 342 | return self._log_tabular_only 343 | 344 | def log(self, s, with_prefix=True, with_timestamp=True): 345 | out = s 346 | if with_prefix: 347 | out = self._prefix_str + out 348 | if with_timestamp: 349 | now = 
datetime.datetime.now(dateutil.tz.tzlocal()) 350 | timestamp = now.strftime('%y-%m-%d.%H:%M') # :%S 351 | out = "%s|%s" % (timestamp, out) 352 | if not self._log_tabular_only: 353 | # Also log to stdout 354 | print(out, flush=True) 355 | for fd in list(self._text_fds.values()): 356 | fd.write(out + '\n') 357 | fd.flush() 358 | sys.stdout.flush() 359 | 360 | def record_tabular(self, key, val): 361 | self._tabular.append((self._tabular_prefix_str + str(key), str(val))) 362 | 363 | def record_dict(self, d, prefix=None): 364 | if prefix is not None: 365 | self.push_tabular_prefix(prefix) 366 | for k, v in d.items(): 367 | self.record_tabular(k, v) 368 | if prefix is not None: 369 | self.pop_tabular_prefix() 370 | 371 | def push_tabular_prefix(self, key): 372 | self._tabular_prefixes.append(key) 373 | self._tabular_prefix_str = ''.join(self._tabular_prefixes) 374 | 375 | def pop_tabular_prefix(self, ): 376 | del self._tabular_prefixes[-1] 377 | self._tabular_prefix_str = ''.join(self._tabular_prefixes) 378 | 379 | def save_extra_data(self, data, file_name='extra_data.pkl', mode='joblib'): 380 | """ 381 | Data saved here will always override the last entry 382 | 383 | :param data: Something pickle'able. 384 | """ 385 | file_name = osp.join(self._snapshot_dir, file_name) 386 | if mode == 'joblib': 387 | import joblib 388 | joblib.dump(data, file_name, compress=3) 389 | elif mode == 'pickle': 390 | pickle.dump(data, open(file_name, "wb")) 391 | else: 392 | raise ValueError("Invalid mode: {}".format(mode)) 393 | return file_name 394 | 395 | def get_table_dict(self, ): 396 | return dict(self._tabular) 397 | 398 | def get_table_key_set(self, ): 399 | return set(key for key, value in self._tabular) 400 | 401 | @contextmanager 402 | def prefix(self, key): 403 | self.push_prefix(key) 404 | try: 405 | yield 406 | finally: 407 | self.pop_prefix() 408 | 409 | @contextmanager 410 | def tabular_prefix(self, key): 411 | self.push_tabular_prefix(key) 412 | yield 413 | self.pop_tabular_prefix() 414 | 415 | def log_variant(self, log_file, variant_data): 416 | mkdir_p(os.path.dirname(log_file)) 417 | with open(log_file, "w") as f: 418 | json.dump(variant_data, f, indent=2, sort_keys=True, cls=MyEncoder) 419 | 420 | def record_tabular_misc_stat(self, key, values, placement='back'): 421 | if placement == 'front': 422 | prefix = "" 423 | suffix = key 424 | else: 425 | prefix = key 426 | suffix = "" 427 | if len(values) > 0: 428 | self.record_tabular(prefix + "Average" + suffix, np.average(values)) 429 | self.record_tabular(prefix + "Std" + suffix, np.std(values)) 430 | self.record_tabular(prefix + "Median" + suffix, np.median(values)) 431 | self.record_tabular(prefix + "Min" + suffix, np.min(values)) 432 | self.record_tabular(prefix + "Max" + suffix, np.max(values)) 433 | else: 434 | self.record_tabular(prefix + "Average" + suffix, np.nan) 435 | self.record_tabular(prefix + "Std" + suffix, np.nan) 436 | self.record_tabular(prefix + "Median" + suffix, np.nan) 437 | self.record_tabular(prefix + "Min" + suffix, np.nan) 438 | self.record_tabular(prefix + "Max" + suffix, np.nan) 439 | 440 | def dump_tabular(self, *args, **kwargs): 441 | wh = kwargs.pop("write_header", None) 442 | if len(self._tabular) > 0: 443 | if self._log_tabular_only: 444 | self.table_printer.print_tabular(self._tabular) 445 | else: 446 | for line in tabulate(self._tabular).split('\n'): 447 | self.log(line, *args, **kwargs) 448 | tabular_dict = dict(self._tabular) 449 | # Also write to the csv files 450 | # This assumes that the keys in each 
iteration won't change! 451 | for tabular_fd in list(self._tabular_fds.values()): 452 | writer = csv.DictWriter(tabular_fd, 453 | fieldnames=list(tabular_dict.keys())) 454 | if wh or ( 455 | wh is None and tabular_fd not in self._tabular_header_written): 456 | writer.writeheader() 457 | self._tabular_header_written.add(tabular_fd) 458 | writer.writerow(tabular_dict) 459 | tabular_fd.flush() 460 | del self._tabular[:] 461 | 462 | def pop_prefix(self, ): 463 | del self._prefixes[-1] 464 | self._prefix_str = ''.join(self._prefixes) 465 | 466 | def save_itr_params(self, itr, params): 467 | if self._snapshot_dir: 468 | if self._snapshot_mode == 'all': 469 | file_name = osp.join(self._snapshot_dir, 'itr_%d.pkl' % itr) 470 | pickle.dump(params, open(file_name, "wb")) 471 | elif self._snapshot_mode == 'last': 472 | # override previous params 473 | file_name = osp.join(self._snapshot_dir, 'params.pkl') 474 | pickle.dump(params, open(file_name, "wb")) 475 | elif self._snapshot_mode == "gap": 476 | if itr % self._snapshot_gap == 0: 477 | file_name = osp.join(self._snapshot_dir, 'itr_%d.pkl' % itr) 478 | pickle.dump(params, open(file_name, "wb")) 479 | elif self._snapshot_mode == "gap_and_last": 480 | if itr % self._snapshot_gap == 0: 481 | file_name = osp.join(self._snapshot_dir, 'itr_%d.pkl' % itr) 482 | pickle.dump(params, open(file_name, "wb")) 483 | file_name = osp.join(self._snapshot_dir, 'params.pkl') 484 | pickle.dump(params, open(file_name, "wb")) 485 | elif self._snapshot_mode == 'none': 486 | pass 487 | else: 488 | raise NotImplementedError 489 | 490 | 491 | logger = Logger() 492 | 493 | -------------------------------------------------------------------------------- /tabulate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Code taken from https://github.com/gregbanks/python-tabulate 3 | 4 | """Pretty-print tabular data.""" 5 | 6 | 7 | 8 | from collections import namedtuple 9 | from platform import python_version_tuple 10 | import re 11 | 12 | 13 | if python_version_tuple()[0] < "3": 14 | from itertools import izip_longest 15 | from functools import partial 16 | _none_type = type(None) 17 | _int_type = int 18 | _float_type = float 19 | _text_type = str 20 | _binary_type = str 21 | else: 22 | from itertools import zip_longest as izip_longest 23 | from functools import reduce, partial 24 | _none_type = type(None) 25 | _int_type = int 26 | _float_type = float 27 | _text_type = str 28 | _binary_type = bytes 29 | 30 | 31 | __all__ = ["tabulate", "tabulate_formats", "simple_separated_format"] 32 | __version__ = "0.7.2" 33 | 34 | 35 | Line = namedtuple("Line", ["begin", "hline", "sep", "end"]) 36 | 37 | 38 | DataRow = namedtuple("DataRow", ["begin", "sep", "end"]) 39 | 40 | 41 | # A table structure is suppposed to be: 42 | # 43 | # --- lineabove --------- 44 | # headerrow 45 | # --- linebelowheader --- 46 | # datarow 47 | # --- linebewteenrows --- 48 | # ... (more datarows) ... 49 | # --- linebewteenrows --- 50 | # last datarow 51 | # --- linebelow --------- 52 | # 53 | # TableFormat's line* elements can be 54 | # 55 | # - either None, if the element is not used, 56 | # - or a Line tuple, 57 | # - or a function: [col_widths], [col_alignments] -> string. 58 | # 59 | # TableFormat's *row elements can be 60 | # 61 | # - either None, if the element is not used, 62 | # - or a DataRow tuple, 63 | # - or a function: [cell_values], [col_widths], [col_alignments] -> string. 
64 | # 65 | # padding (an integer) is the amount of white space around data values. 66 | # 67 | # with_header_hide: 68 | # 69 | # - either None, to display all table elements unconditionally, 70 | # - or a list of elements not to be displayed if the table has column headers. 71 | # 72 | TableFormat = namedtuple("TableFormat", ["lineabove", "linebelowheader", 73 | "linebetweenrows", "linebelow", 74 | "headerrow", "datarow", 75 | "padding", "with_header_hide"]) 76 | 77 | 78 | def _pipe_segment_with_colons(align, colwidth): 79 | """Return a segment of a horizontal line with optional colons which 80 | indicate column's alignment (as in `pipe` output format).""" 81 | w = colwidth 82 | if align in ["right", "decimal"]: 83 | return ('-' * (w - 1)) + ":" 84 | elif align == "center": 85 | return ":" + ('-' * (w - 2)) + ":" 86 | elif align == "left": 87 | return ":" + ('-' * (w - 1)) 88 | else: 89 | return '-' * w 90 | 91 | 92 | def _pipe_line_with_colons(colwidths, colaligns): 93 | """Return a horizontal line with optional colons to indicate column's 94 | alignment (as in `pipe` output format).""" 95 | segments = [_pipe_segment_with_colons(a, w) for a, w in zip(colaligns, colwidths)] 96 | return "|" + "|".join(segments) + "|" 97 | 98 | 99 | def _mediawiki_row_with_attrs(separator, cell_values, colwidths, colaligns): 100 | alignment = { "left": '', 101 | "right": 'align="right"| ', 102 | "center": 'align="center"| ', 103 | "decimal": 'align="right"| ' } 104 | # hard-coded padding _around_ align attribute and value together 105 | # rather than padding parameter which affects only the value 106 | values_with_attrs = [' ' + alignment.get(a, '') + c + ' ' 107 | for c, a in zip(cell_values, colaligns)] 108 | colsep = separator*2 109 | return (separator + colsep.join(values_with_attrs)).rstrip() 110 | 111 | 112 | def _latex_line_begin_tabular(colwidths, colaligns): 113 | alignment = { "left": "l", "right": "r", "center": "c", "decimal": "r" } 114 | tabular_columns_fmt = "".join([alignment.get(a, "l") for a in colaligns]) 115 | return "\\begin{tabular}{" + tabular_columns_fmt + "}\n\hline" 116 | 117 | 118 | _table_formats = {"simple": 119 | TableFormat(lineabove=Line("", "-", " ", ""), 120 | linebelowheader=Line("", "-", " ", ""), 121 | linebetweenrows=None, 122 | linebelow=Line("", "-", " ", ""), 123 | headerrow=DataRow("", " ", ""), 124 | datarow=DataRow("", " ", ""), 125 | padding=0, 126 | with_header_hide=["lineabove", "linebelow"]), 127 | "plain": 128 | TableFormat(lineabove=None, linebelowheader=None, 129 | linebetweenrows=None, linebelow=None, 130 | headerrow=DataRow("", " ", ""), 131 | datarow=DataRow("", " ", ""), 132 | padding=0, with_header_hide=None), 133 | "grid": 134 | TableFormat(lineabove=Line("+", "-", "+", "+"), 135 | linebelowheader=Line("+", "=", "+", "+"), 136 | linebetweenrows=Line("+", "-", "+", "+"), 137 | linebelow=Line("+", "-", "+", "+"), 138 | headerrow=DataRow("|", "|", "|"), 139 | datarow=DataRow("|", "|", "|"), 140 | padding=1, with_header_hide=None), 141 | "pipe": 142 | TableFormat(lineabove=_pipe_line_with_colons, 143 | linebelowheader=_pipe_line_with_colons, 144 | linebetweenrows=None, 145 | linebelow=None, 146 | headerrow=DataRow("|", "|", "|"), 147 | datarow=DataRow("|", "|", "|"), 148 | padding=1, 149 | with_header_hide=["lineabove"]), 150 | "orgtbl": 151 | TableFormat(lineabove=None, 152 | linebelowheader=Line("|", "-", "+", "|"), 153 | linebetweenrows=None, 154 | linebelow=None, 155 | headerrow=DataRow("|", "|", "|"), 156 | datarow=DataRow("|", "|", "|"), 157 | 
padding=1, with_header_hide=None), 158 | "rst": 159 | TableFormat(lineabove=Line("", "=", " ", ""), 160 | linebelowheader=Line("", "=", " ", ""), 161 | linebetweenrows=None, 162 | linebelow=Line("", "=", " ", ""), 163 | headerrow=DataRow("", " ", ""), 164 | datarow=DataRow("", " ", ""), 165 | padding=0, with_header_hide=None), 166 | "mediawiki": 167 | TableFormat(lineabove=Line("{| class=\"wikitable\" style=\"text-align: left;\"", 168 | "", "", "\n|+ \n|-"), 169 | linebelowheader=Line("|-", "", "", ""), 170 | linebetweenrows=Line("|-", "", "", ""), 171 | linebelow=Line("|}", "", "", ""), 172 | headerrow=partial(_mediawiki_row_with_attrs, "!"), 173 | datarow=partial(_mediawiki_row_with_attrs, "|"), 174 | padding=0, with_header_hide=None), 175 | "latex": 176 | TableFormat(lineabove=_latex_line_begin_tabular, 177 | linebelowheader=Line("\\hline", "", "", ""), 178 | linebetweenrows=None, 179 | linebelow=Line("\\hline\n\\end{tabular}", "", "", ""), 180 | headerrow=DataRow("", "&", "\\\\"), 181 | datarow=DataRow("", "&", "\\\\"), 182 | padding=1, with_header_hide=None), 183 | "tsv": 184 | TableFormat(lineabove=None, linebelowheader=None, 185 | linebetweenrows=None, linebelow=None, 186 | headerrow=DataRow("", "\t", ""), 187 | datarow=DataRow("", "\t", ""), 188 | padding=0, with_header_hide=None)} 189 | 190 | 191 | tabulate_formats = list(sorted(_table_formats.keys())) 192 | 193 | 194 | _invisible_codes = re.compile("\x1b\[\d*m") # ANSI color codes 195 | _invisible_codes_bytes = re.compile(b"\x1b\[\d*m") # ANSI color codes 196 | 197 | 198 | def simple_separated_format(separator): 199 | """Construct a simple TableFormat with columns separated by a separator. 200 | 201 | >>> tsv = simple_separated_format("\\t") ; \ 202 | tabulate([["foo", 1], ["spam", 23]], tablefmt=tsv) == 'foo \\t 1\\nspam\\t23' 203 | True 204 | 205 | """ 206 | return TableFormat(None, None, None, None, 207 | headerrow=DataRow('', separator, ''), 208 | datarow=DataRow('', separator, ''), 209 | padding=0, with_header_hide=None) 210 | 211 | 212 | def _isconvertible(conv, string): 213 | try: 214 | n = conv(string) 215 | return True 216 | except ValueError: 217 | return False 218 | 219 | 220 | def _isnumber(string): 221 | """ 222 | >>> _isnumber("123.45") 223 | True 224 | >>> _isnumber("123") 225 | True 226 | >>> _isnumber("spam") 227 | False 228 | """ 229 | return _isconvertible(float, string) 230 | 231 | 232 | def _isint(string): 233 | """ 234 | >>> _isint("123") 235 | True 236 | >>> _isint("123.45") 237 | False 238 | """ 239 | return type(string) is int or \ 240 | (isinstance(string, _binary_type) or isinstance(string, _text_type)) and \ 241 | _isconvertible(int, string) 242 | 243 | 244 | def _type(string, has_invisible=True): 245 | """The least generic type (type(None), int, float, str, unicode). 
246 | 247 | >>> _type(None) is type(None) 248 | True 249 | >>> _type("foo") is type("") 250 | True 251 | >>> _type("1") is type(1) 252 | True 253 | >>> _type('\x1b[31m42\x1b[0m') is type(42) 254 | True 255 | >>> _type('\x1b[31m42\x1b[0m') is type(42) 256 | True 257 | 258 | """ 259 | 260 | if has_invisible and \ 261 | (isinstance(string, _text_type) or isinstance(string, _binary_type)): 262 | string = _strip_invisible(string) 263 | 264 | if string is None: 265 | return _none_type 266 | elif hasattr(string, "isoformat"): # datetime.datetime, date, and time 267 | return _text_type 268 | elif _isint(string): 269 | return int 270 | elif _isnumber(string): 271 | return float 272 | elif isinstance(string, _binary_type): 273 | return _binary_type 274 | else: 275 | return _text_type 276 | 277 | 278 | def _afterpoint(string): 279 | """Symbols after a decimal point, -1 if the string lacks the decimal point. 280 | 281 | >>> _afterpoint("123.45") 282 | 2 283 | >>> _afterpoint("1001") 284 | -1 285 | >>> _afterpoint("eggs") 286 | -1 287 | >>> _afterpoint("123e45") 288 | 2 289 | 290 | """ 291 | if _isnumber(string): 292 | if _isint(string): 293 | return -1 294 | else: 295 | pos = string.rfind(".") 296 | pos = string.lower().rfind("e") if pos < 0 else pos 297 | if pos >= 0: 298 | return len(string) - pos - 1 299 | else: 300 | return -1 # no point 301 | else: 302 | return -1 # not a number 303 | 304 | 305 | def _padleft(width, s, has_invisible=True): 306 | """Flush right. 307 | 308 | >>> _padleft(6, '\u044f\u0439\u0446\u0430') == ' \u044f\u0439\u0446\u0430' 309 | True 310 | 311 | """ 312 | iwidth = width + len(s) - len(_strip_invisible(s)) if has_invisible else width 313 | fmt = "{0:>%ds}" % iwidth 314 | return fmt.format(s) 315 | 316 | 317 | def _padright(width, s, has_invisible=True): 318 | """Flush left. 319 | 320 | >>> _padright(6, '\u044f\u0439\u0446\u0430') == '\u044f\u0439\u0446\u0430 ' 321 | True 322 | 323 | """ 324 | iwidth = width + len(s) - len(_strip_invisible(s)) if has_invisible else width 325 | fmt = "{0:<%ds}" % iwidth 326 | return fmt.format(s) 327 | 328 | 329 | def _padboth(width, s, has_invisible=True): 330 | """Center string. 331 | 332 | >>> _padboth(6, '\u044f\u0439\u0446\u0430') == ' \u044f\u0439\u0446\u0430 ' 333 | True 334 | 335 | """ 336 | iwidth = width + len(s) - len(_strip_invisible(s)) if has_invisible else width 337 | fmt = "{0:^%ds}" % iwidth 338 | return fmt.format(s) 339 | 340 | 341 | def _strip_invisible(s): 342 | "Remove invisible ANSI color codes." 343 | if isinstance(s, _text_type): 344 | return re.sub(_invisible_codes, "", s) 345 | else: # a bytestring 346 | return re.sub(_invisible_codes_bytes, "", s) 347 | 348 | 349 | def _visible_width(s): 350 | """Visible width of a printed string. ANSI color codes are removed. 
351 | 352 | >>> _visible_width('\x1b[31mhello\x1b[0m'), _visible_width("world") 353 | (5, 5) 354 | 355 | """ 356 | if isinstance(s, _text_type) or isinstance(s, _binary_type): 357 | return len(_strip_invisible(s)) 358 | else: 359 | return len(_text_type(s)) 360 | 361 | 362 | def _align_column(strings, alignment, minwidth=0, has_invisible=True): 363 | """[string] -> [padded_string] 364 | 365 | >>> list(map(str,_align_column(["12.345", "-1234.5", "1.23", "1234.5", "1e+234", "1.0e234"], "decimal"))) 366 | [' 12.345 ', '-1234.5 ', ' 1.23 ', ' 1234.5 ', ' 1e+234 ', ' 1.0e234'] 367 | 368 | >>> list(map(str,_align_column(['123.4', '56.7890'], None))) 369 | ['123.4', '56.7890'] 370 | 371 | """ 372 | if alignment == "right": 373 | strings = [s.strip() for s in strings] 374 | padfn = _padleft 375 | elif alignment == "center": 376 | strings = [s.strip() for s in strings] 377 | padfn = _padboth 378 | elif alignment == "decimal": 379 | decimals = [_afterpoint(s) for s in strings] 380 | maxdecimals = max(decimals) 381 | strings = [s + (maxdecimals - decs) * " " 382 | for s, decs in zip(strings, decimals)] 383 | padfn = _padleft 384 | elif not alignment: 385 | return strings 386 | else: 387 | strings = [s.strip() for s in strings] 388 | padfn = _padright 389 | 390 | if has_invisible: 391 | width_fn = _visible_width 392 | else: 393 | width_fn = len 394 | 395 | maxwidth = max(max(list(map(width_fn, strings))), minwidth) 396 | padded_strings = [padfn(maxwidth, s, has_invisible) for s in strings] 397 | return padded_strings 398 | 399 | 400 | def _more_generic(type1, type2): 401 | types = { _none_type: 0, int: 1, float: 2, _binary_type: 3, _text_type: 4 } 402 | invtypes = { 4: _text_type, 3: _binary_type, 2: float, 1: int, 0: _none_type } 403 | moregeneric = max(types.get(type1, 4), types.get(type2, 4)) 404 | return invtypes[moregeneric] 405 | 406 | 407 | def _column_type(strings, has_invisible=True): 408 | """The least generic type all column values are convertible to. 409 | 410 | >>> _column_type(["1", "2"]) is _int_type 411 | True 412 | >>> _column_type(["1", "2.3"]) is _float_type 413 | True 414 | >>> _column_type(["1", "2.3", "four"]) is _text_type 415 | True 416 | >>> _column_type(["four", '\u043f\u044f\u0442\u044c']) is _text_type 417 | True 418 | >>> _column_type([None, "brux"]) is _text_type 419 | True 420 | >>> _column_type([1, 2, None]) is _int_type 421 | True 422 | >>> import datetime as dt 423 | >>> _column_type([dt.datetime(1991,2,19), dt.time(17,35)]) is _text_type 424 | True 425 | 426 | """ 427 | types = [_type(s, has_invisible) for s in strings ] 428 | return reduce(_more_generic, types, int) 429 | 430 | 431 | def _format(val, valtype, floatfmt, missingval=""): 432 | """Format a value accoding to its type. 
433 | 434 | Unicode is supported: 435 | 436 | >>> hrow = ['\u0431\u0443\u043a\u0432\u0430', '\u0446\u0438\u0444\u0440\u0430'] ; \ 437 | tbl = [['\u0430\u0437', 2], ['\u0431\u0443\u043a\u0438', 4]] ; \ 438 | good_result = '\\u0431\\u0443\\u043a\\u0432\\u0430 \\u0446\\u0438\\u0444\\u0440\\u0430\\n------- -------\\n\\u0430\\u0437 2\\n\\u0431\\u0443\\u043a\\u0438 4' ; \ 439 | tabulate(tbl, headers=hrow) == good_result 440 | True 441 | 442 | """ 443 | if val is None: 444 | return missingval 445 | 446 | if valtype in [int, _text_type]: 447 | return "{0}".format(val) 448 | elif valtype is _binary_type: 449 | return _text_type(val, "ascii") 450 | elif valtype is float: 451 | return format(float(val), floatfmt) 452 | else: 453 | return "{0}".format(val) 454 | 455 | 456 | def _align_header(header, alignment, width): 457 | if alignment == "left": 458 | return _padright(width, header) 459 | elif alignment == "center": 460 | return _padboth(width, header) 461 | elif not alignment: 462 | return "{0}".format(header) 463 | else: 464 | return _padleft(width, header) 465 | 466 | 467 | def _normalize_tabular_data(tabular_data, headers): 468 | """Transform a supported data type to a list of lists, and a list of headers. 469 | 470 | Supported tabular data types: 471 | 472 | * list-of-lists or another iterable of iterables 473 | 474 | * list of named tuples (usually used with headers="keys") 475 | 476 | * 2D NumPy arrays 477 | 478 | * NumPy record arrays (usually used with headers="keys") 479 | 480 | * dict of iterables (usually used with headers="keys") 481 | 482 | * pandas.DataFrame (usually used with headers="keys") 483 | 484 | The first row can be used as headers if headers="firstrow", 485 | column indices can be used as headers if headers="keys". 486 | 487 | """ 488 | 489 | if hasattr(tabular_data, "keys") and hasattr(tabular_data, "values"): 490 | # dict-like and pandas.DataFrame? 
491 | if hasattr(tabular_data.values, "__call__"): 492 | # likely a conventional dict 493 | keys = list(tabular_data.keys()) 494 | rows = list(zip_longest(*list(tabular_data.values()))) # columns have to be transposed 495 | elif hasattr(tabular_data, "index"): 496 | # values is a property, has .index => it's likely a pandas.DataFrame (pandas 0.11.0) 497 | keys = list(tabular_data.keys()) 498 | vals = tabular_data.values # values matrix doesn't need to be transposed 499 | names = tabular_data.index 500 | rows = [[v]+list(row) for v,row in zip(names, vals)] 501 | else: 502 | raise ValueError("tabular data doesn't appear to be a dict or a DataFrame") 503 | 504 | if headers == "keys": 505 | headers = list(map(_text_type,keys)) # headers should be strings 506 | 507 | else: # it's a usual an iterable of iterables, or a NumPy array 508 | rows = list(tabular_data) 509 | 510 | if (headers == "keys" and 511 | hasattr(tabular_data, "dtype") and 512 | getattr(tabular_data.dtype, "names")): 513 | # numpy record array 514 | headers = tabular_data.dtype.names 515 | elif (headers == "keys" 516 | and len(rows) > 0 517 | and isinstance(rows[0], tuple) 518 | and hasattr(rows[0], "_fields")): # namedtuple 519 | headers = list(map(_text_type, rows[0]._fields)) 520 | elif headers == "keys" and len(rows) > 0: # keys are column indices 521 | headers = list(map(_text_type, list(range(len(rows[0]))))) 522 | 523 | # take headers from the first row if necessary 524 | if headers == "firstrow" and len(rows) > 0: 525 | headers = list(map(_text_type, rows[0])) # headers should be strings 526 | rows = rows[1:] 527 | 528 | headers = list(headers) 529 | rows = list(map(list,rows)) 530 | 531 | # pad with empty headers for initial columns if necessary 532 | if headers and len(rows) > 0: 533 | nhs = len(headers) 534 | ncols = len(rows[0]) 535 | if nhs < ncols: 536 | headers = [""]*(ncols - nhs) + headers 537 | 538 | return rows, headers 539 | 540 | 541 | def tabulate(tabular_data, headers=[], tablefmt="simple", 542 | floatfmt="g", numalign="decimal", stralign="left", 543 | missingval=""): 544 | """Format a fixed width table for pretty printing. 545 | 546 | >>> print(tabulate([[1, 2.34], [-56, "8.999"], ["2", "10001"]])) 547 | --- --------- 548 | 1 2.34 549 | -56 8.999 550 | 2 10001 551 | --- --------- 552 | 553 | The first required argument (`tabular_data`) can be a 554 | list-of-lists (or another iterable of iterables), a list of named 555 | tuples, a dictionary of iterables, a two-dimensional NumPy array, 556 | NumPy record array, or a Pandas' dataframe. 557 | 558 | 559 | Table headers 560 | ------------- 561 | 562 | To print nice column headers, supply the second argument (`headers`): 563 | 564 | - `headers` can be an explicit list of column headers 565 | - if `headers="firstrow"`, then the first row of data is used 566 | - if `headers="keys"`, then dictionary keys or column indices are used 567 | 568 | Otherwise a headerless table is produced. 569 | 570 | If the number of headers is less than the number of columns, they 571 | are supposed to be names of the last columns. This is consistent 572 | with the plain-text format of R and Pandas' dataframes. 573 | 574 | >>> print(tabulate([["sex","age"],["Alice","F",24],["Bob","M",19]], 575 | ... headers="firstrow")) 576 | sex age 577 | ----- ----- ----- 578 | Alice F 24 579 | Bob M 19 580 | 581 | 582 | Column alignment 583 | ---------------- 584 | 585 | `tabulate` tries to detect column types automatically, and aligns 586 | the values properly. 
By default it aligns decimal points of the 587 | numbers (or flushes integer numbers to the right), and flushes 588 | everything else to the left. Possible column alignments 589 | (`numalign`, `stralign`) are: "right", "center", "left", "decimal" 590 | (only for `numalign`), and None (to disable alignment). 591 | 592 | 593 | Table formats 594 | ------------- 595 | 596 | `floatfmt` is a format specification used for columns which 597 | contain numeric data with a decimal point. 598 | 599 | `None` values are replaced with a `missingval` string: 600 | 601 | >>> print(tabulate([["spam", 1, None], 602 | ... ["eggs", 42, 3.14], 603 | ... ["other", None, 2.7]], missingval="?")) 604 | ----- -- ---- 605 | spam 1 ? 606 | eggs 42 3.14 607 | other ? 2.7 608 | ----- -- ---- 609 | 610 | Various plain-text table formats (`tablefmt`) are supported: 611 | 'plain', 'simple', 'grid', 'pipe', 'orgtbl', 'rst', 'mediawiki', 612 | and 'latex'. Variable `tabulate_formats` contains the list of 613 | currently supported formats. 614 | 615 | "plain" format doesn't use any pseudographics to draw tables, 616 | it separates columns with a double space: 617 | 618 | >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], 619 | ... ["strings", "numbers"], "plain")) 620 | strings numbers 621 | spam 41.9999 622 | eggs 451 623 | 624 | >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="plain")) 625 | spam 41.9999 626 | eggs 451 627 | 628 | "simple" format is like Pandoc simple_tables: 629 | 630 | >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], 631 | ... ["strings", "numbers"], "simple")) 632 | strings numbers 633 | --------- --------- 634 | spam 41.9999 635 | eggs 451 636 | 637 | >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="simple")) 638 | ---- -------- 639 | spam 41.9999 640 | eggs 451 641 | ---- -------- 642 | 643 | "grid" is similar to tables produced by Emacs table.el package or 644 | Pandoc grid_tables: 645 | 646 | >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], 647 | ... ["strings", "numbers"], "grid")) 648 | +-----------+-----------+ 649 | | strings | numbers | 650 | +===========+===========+ 651 | | spam | 41.9999 | 652 | +-----------+-----------+ 653 | | eggs | 451 | 654 | +-----------+-----------+ 655 | 656 | >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="grid")) 657 | +------+----------+ 658 | | spam | 41.9999 | 659 | +------+----------+ 660 | | eggs | 451 | 661 | +------+----------+ 662 | 663 | "pipe" is like tables in PHP Markdown Extra extension or Pandoc 664 | pipe_tables: 665 | 666 | >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], 667 | ... ["strings", "numbers"], "pipe")) 668 | | strings | numbers | 669 | |:----------|----------:| 670 | | spam | 41.9999 | 671 | | eggs | 451 | 672 | 673 | >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="pipe")) 674 | |:-----|---------:| 675 | | spam | 41.9999 | 676 | | eggs | 451 | 677 | 678 | "orgtbl" is like tables in Emacs org-mode and orgtbl-mode. They 679 | are slightly different from "pipe" format by not using colons to 680 | define column alignment, and using a "+" sign to indicate line 681 | intersections: 682 | 683 | >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], 684 | ... 
["strings", "numbers"], "orgtbl")) 685 | | strings | numbers | 686 | |-----------+-----------| 687 | | spam | 41.9999 | 688 | | eggs | 451 | 689 | 690 | 691 | >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="orgtbl")) 692 | | spam | 41.9999 | 693 | | eggs | 451 | 694 | 695 | "rst" is like a simple table format from reStructuredText; please 696 | note that reStructuredText accepts also "grid" tables: 697 | 698 | >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], 699 | ... ["strings", "numbers"], "rst")) 700 | ========= ========= 701 | strings numbers 702 | ========= ========= 703 | spam 41.9999 704 | eggs 451 705 | ========= ========= 706 | 707 | >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="rst")) 708 | ==== ======== 709 | spam 41.9999 710 | eggs 451 711 | ==== ======== 712 | 713 | "mediawiki" produces a table markup used in Wikipedia and on other 714 | MediaWiki-based sites: 715 | 716 | >>> print(tabulate([["strings", "numbers"], ["spam", 41.9999], ["eggs", "451.0"]], 717 | ... headers="firstrow", tablefmt="mediawiki")) 718 | {| class="wikitable" style="text-align: left;" 719 | |+ 720 | |- 721 | ! strings !! align="right"| numbers 722 | |- 723 | | spam || align="right"| 41.9999 724 | |- 725 | | eggs || align="right"| 451 726 | |} 727 | 728 | "latex" produces a tabular environment of LaTeX document markup: 729 | 730 | >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="latex")) 731 | \\begin{tabular}{lr} 732 | \\hline 733 | spam & 41.9999 \\\\ 734 | eggs & 451 \\\\ 735 | \\hline 736 | \\end{tabular} 737 | 738 | """ 739 | 740 | list_of_lists, headers = _normalize_tabular_data(tabular_data, headers) 741 | 742 | # optimization: look for ANSI control codes once, 743 | # enable smart width functions only if a control code is found 744 | plain_text = '\n'.join(['\t'.join(map(_text_type, headers))] + \ 745 | ['\t'.join(map(_text_type, row)) for row in list_of_lists]) 746 | has_invisible = re.search(_invisible_codes, plain_text) 747 | if has_invisible: 748 | width_fn = _visible_width 749 | else: 750 | width_fn = len 751 | 752 | # format rows and columns, convert numeric values to strings 753 | cols = list(zip(*list_of_lists)) 754 | coltypes = list(map(_column_type, cols)) 755 | cols = [[_format(v, ct, floatfmt, missingval) for v in c] 756 | for c,ct in zip(cols, coltypes)] 757 | 758 | # align columns 759 | aligns = [numalign if ct in [int,float] else stralign for ct in coltypes] 760 | minwidths = [width_fn(h)+2 for h in headers] if headers else [0]*len(cols) 761 | cols = [_align_column(c, a, minw, has_invisible) 762 | for c, a, minw in zip(cols, aligns, minwidths)] 763 | 764 | if headers: 765 | # align headers and add headers 766 | minwidths = [max(minw, width_fn(c[0])) for minw, c in zip(minwidths, cols)] 767 | headers = [_align_header(h, a, minw) 768 | for h, a, minw in zip(headers, aligns, minwidths)] 769 | rows = list(zip(*cols)) 770 | else: 771 | minwidths = [width_fn(c[0]) for c in cols] 772 | rows = list(zip(*cols)) 773 | 774 | if not isinstance(tablefmt, TableFormat): 775 | tablefmt = _table_formats.get(tablefmt, _table_formats["simple"]) 776 | 777 | return _format_table(tablefmt, headers, rows, minwidths, aligns) 778 | 779 | 780 | def _build_simple_row(padded_cells, rowfmt): 781 | "Format row according to DataRow format without padding." 
782 | begin, sep, end = rowfmt 783 | return (begin + sep.join(padded_cells) + end).rstrip() 784 | 785 | 786 | def _build_row(padded_cells, colwidths, colaligns, rowfmt): 787 | "Return a string which represents a row of data cells." 788 | if not rowfmt: 789 | return None 790 | if hasattr(rowfmt, "__call__"): 791 | return rowfmt(padded_cells, colwidths, colaligns) 792 | else: 793 | return _build_simple_row(padded_cells, rowfmt) 794 | 795 | 796 | def _build_line(colwidths, colaligns, linefmt): 797 | "Return a string which represents a horizontal line." 798 | if not linefmt: 799 | return None 800 | if hasattr(linefmt, "__call__"): 801 | return linefmt(colwidths, colaligns) 802 | else: 803 | begin, fill, sep, end = linefmt 804 | cells = [fill*w for w in colwidths] 805 | return _build_simple_row(cells, (begin, sep, end)) 806 | 807 | 808 | def _pad_row(cells, padding): 809 | if cells: 810 | pad = " "*padding 811 | padded_cells = [pad + cell + pad for cell in cells] 812 | return padded_cells 813 | else: 814 | return cells 815 | 816 | 817 | def _format_table(fmt, headers, rows, colwidths, colaligns): 818 | """Produce a plain-text representation of the table.""" 819 | lines = [] 820 | hidden = fmt.with_header_hide if (headers and fmt.with_header_hide) else [] 821 | pad = fmt.padding 822 | headerrow = fmt.headerrow 823 | 824 | padded_widths = [(w + 2*pad) for w in colwidths] 825 | padded_headers = _pad_row(headers, pad) 826 | padded_rows = [_pad_row(row, pad) for row in rows] 827 | 828 | if fmt.lineabove and "lineabove" not in hidden: 829 | lines.append(_build_line(padded_widths, colaligns, fmt.lineabove)) 830 | 831 | if padded_headers: 832 | lines.append(_build_row(padded_headers, padded_widths, colaligns, headerrow)) 833 | if fmt.linebelowheader and "linebelowheader" not in hidden: 834 | lines.append(_build_line(padded_widths, colaligns, fmt.linebelowheader)) 835 | 836 | if padded_rows and fmt.linebetweenrows and "linebetweenrows" not in hidden: 837 | # initial rows with a line below 838 | for row in padded_rows[:-1]: 839 | lines.append(_build_row(row, padded_widths, colaligns, fmt.datarow)) 840 | lines.append(_build_line(padded_widths, colaligns, fmt.linebetweenrows)) 841 | # the last row without a line below 842 | lines.append(_build_row(padded_rows[-1], padded_widths, colaligns, fmt.datarow)) 843 | else: 844 | for row in padded_rows: 845 | lines.append(_build_row(row, padded_widths, colaligns, fmt.datarow)) 846 | 847 | if fmt.linebelow and "linebelow" not in hidden: 848 | lines.append(_build_line(padded_widths, colaligns, fmt.linebelow)) 849 | 850 | return "\n".join(lines) 851 | --------------------------------------------------------------------------------
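A minimal sketch of how the vendored `utils/logger.py` and `tabulate.py` above fit together, mirroring the `record_tabular` / `dump_tabular` calls made in `main.py`; the experiment name and logged keys here are illustrative only:

```python
from utils.logger import logger, setup_logger

# With no log_dir, setup_logger creates a timestamped directory under ./data/
# and registers debug.log, variant.json and progress.csv outputs there.
log_dir = setup_logger('demo-run', variant={'algo': 'eql', 'seed': 0})

for epoch in range(3):
    logger.record_tabular('Trained Epochs', epoch)
    logger.record_tabular('Average Episodic Reward', 100.0 * epoch)
    # dump_tabular() pretty-prints the accumulated row via tabulate() and
    # appends it to progress.csv; the CSV header is emitted on the first dump.
    logger.dump_tabular()
```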