├── SimpleSAC ├── __init__.py ├── envs.py ├── replay_buffer.py ├── sampler.py ├── mixed_replay_buffer.py ├── model.py ├── sim2real_sac_main.py ├── utils_h2o.py └── sim2real_sac.py ├── viskit ├── __init__.py ├── static │ ├── css │ │ └── dropdowns-enhancement.css │ └── js │ │ ├── dropdowns-enhancement.js │ │ ├── jquery.loadTemplate-1.5.6.js │ │ └── bootstrap.min.js ├── core.py ├── logging.py └── templates │ └── main.html ├── H2O.png ├── requirements.txt ├── Network └── Weight_net.py ├── README.md └── xml_path ├── real_file └── half_cheetah.xml └── sim_file ├── half_cheetah_gravityx1.xml └── half_cheetah_gravityx2.0.xml /SimpleSAC/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /viskit/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'dementrock' 2 | -------------------------------------------------------------------------------- /H2O.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/t6-thu/H2O/HEAD/H2O.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name <env> --file <this file> 3 | # platform: linux-64 4 | 5 | absl-py==1.0.0 6 | cudatoolkit==10.2.89 7 | d4rl==1.1 8 | gym==0.21.0 9 | h5py==3.6.0 10 | ml-collections==0.1.1 11 | mujoco-py==2.0.2.8 12 | numpy==1.21.5 13 | pandas==1.1.5 14 | python==3.7.0 15 | scipy==1.7.3 16 | seaborn==0.11.2 17 | torch==1.10.1+cu113 18 | torchaudio==0.10.1+cu113 19 | torchvision==0.11.2+cu113 20 | tornado==6.2 21 | tqdm==4.62.3 22 | wandb==0.12.9 23 | -------------------------------------------------------------------------------- /SimpleSAC/envs.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Generate different types of dynamics mismatch.
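Each get_*_env helper below updates the corresponding MuJoCo model via the update_* functions from utils_h2o and then rebuilds the gym environment.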
3 | @python version : 3.6.4 4 | ''' 5 | 6 | import gym 7 | from gym.spaces import Box, Discrete, Tuple 8 | from utils_h2o import update_target_env_gravity, update_target_env_density, update_target_env_friction, update_source_env 9 | 10 | 11 | def get_new_gravity_env(variety, env_name): 12 | update_target_env_gravity(variety, env_name) 13 | env = gym.make(env_name) 14 | 15 | return env 16 | 17 | 18 | def get_source_env(env_name="Walker2d-v2"): 19 | update_source_env(env_name) 20 | env = gym.make(env_name) 21 | 22 | return env 23 | 24 | 25 | def get_new_density_env(variety, env_name): 26 | update_target_env_density(variety, env_name) 27 | env = gym.make(env_name) 28 | 29 | return env 30 | 31 | 32 | def get_new_friction_env(variety, env_name): 33 | update_target_env_friction(variety, env_name) 34 | env = gym.make(env_name) 35 | 36 | return env 37 | 38 | def get_dim(space): 39 | if isinstance(space, Box): 40 | return space.low.size 41 | elif isinstance(space, Discrete): 42 | return space.n 43 | elif isinstance(space, Tuple): 44 | return sum(get_dim(subspace) for subspace in space.spaces) 45 | elif hasattr(space, 'flat_dim'): 46 | return space.flat_dim 47 | else: 48 | raise TypeError("Unknown space: {}".format(space)) 49 | -------------------------------------------------------------------------------- /Network/Weight_net.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import numpy as np 6 | 7 | class Discriminator(nn.Module): 8 | def __init__(self, num_input, num_hidden, num_output=2, device="cuda", dropout=False): 9 | super().__init__() 10 | self.device = device 11 | self.fc1 = nn.Linear(num_input, num_hidden) 12 | self.fc2 = nn.Linear(num_hidden, num_hidden) 13 | self.fc3 = nn.Linear(num_hidden, num_output) 14 | self.dropout = dropout 15 | self.dropout_layer = nn.Dropout(p=0.2) 16 | 17 | def forward(self, x): 18 | if isinstance(x, np.ndarray): 19 | x = torch.tensor(x, dtype=torch.float).to(self.device) 20 | if self.dropout: 21 | x = F.relu(self.dropout_layer(self.fc1(x))) 22 | x = F.relu(self.dropout_layer(self.fc2(x))) 23 | else: 24 | x = F.relu(self.fc1(x)) 25 | x = F.relu(self.fc2(x)) 26 | output = 2 * torch.tanh(self.fc3(x)) 27 | return output 28 | 29 | class ConcatDiscriminator(Discriminator): 30 | """ 31 | Concatenate inputs along a given dimension and then pass them through an MLP. 32 | """ 33 | def __init__(self, *args, dim=1, **kwargs): 34 | super().__init__(*args, **kwargs) 35 | self.dim = dim 36 | 37 | def forward(self, *inputs, **kwargs): 38 | flat_inputs = torch.cat(inputs, dim=self.dim) 39 | return super().forward(flat_inputs, **kwargs) 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # H2O 2 | 3 | H2O ([https://arxiv.org/abs/2206.13464](https://arxiv.org/abs/2206.13464)) is the first Hybrid Offline-and-Online Reinforcement Learning framework that enables simultaneous policy learning with offline real-world datasets and simulation rollouts, while also addressing the sim-to-real dynamics gaps of an imperfect simulator. H2O introduces a dynamics-aware policy evaluation scheme, which adaptively penalizes the Q-values and corrects the Bellman errors on simulated samples with large dynamics gaps.
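
The dynamics gap of a simulated transition `(s, a, s')` is estimated with the pair of binary discriminators defined in `Network/Weight_net.py`. The sketch below shows the standard classifier-based density-ratio estimate this relies on; the helper name and the class-index convention (index 1 = "real") are illustrative assumptions, and the exact implementation lives in `SimpleSAC/sim2real_sac.py`:

```python
import torch.nn.functional as F

def dynamics_ratio(d_sa, d_sas, s, a, s_next):
    """Estimate p_real(s'|s,a) / p_sim(s'|s,a) from two binary discriminators."""
    p_sa = F.softmax(d_sa(s, a), dim=-1)             # p(real | s, a) vs. p(sim | s, a)
    p_sas = F.softmax(d_sas(s, a, s_next), dim=-1)   # p(real | s, a, s') vs. p(sim | s, a, s')
    # For Bayes-optimal classifiers, the transition density ratio factorizes as
    # [p(real | s, a, s') / p(sim | s, a, s')] * [p(sim | s, a) / p(real | s, a)].
    return (p_sas[:, 1] / p_sas[:, 0]) * (p_sa[:, 0] / p_sa[:, 1])
```

Simulated samples whose estimated ratio indicates a large dynamics gap then receive a larger Q-value penalty during policy evaluation.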
Through extensive simulation and real-world tasks, as well as theoretical analysis, we demonstrate the superior performance of H2O against other cross-domain online and offline RL algorithms. This repository provides the codebase on which we benchmark H2O and baselines in MuJoCo environments. 4 | 5 | ![pipeline](H2O.png) 6 | 7 | ## Installation and Setup 8 | To install the dependencies, run the command: 9 | ```bash 10 | pip install -r requirements.txt 11 | ``` 12 | Add this repo directory to your `PYTHONPATH` environment variable: 13 | ```bash 14 | export PYTHONPATH="$PYTHONPATH:$(pwd)" 15 | ``` 16 | 17 | ## Run Benchmark Experiments 18 | We benchmark H2O and its baselines in MuJoCo simulation environments with D4RL datasets. To begin, enter the folder `SimpleSAC`: 19 | ```bash 20 | cd SimpleSAC 21 | ``` 22 | Then you can run H2O experiments using the following example commands. 23 | ### Simulated in HalfCheetah-v2 with 2x gravity and Medium Replay dataset 24 | ```bash 25 | python sim2real_sac_main.py \ 26 | --env_list HalfCheetah-v2 \ 27 | --data_source medium_replay \ 28 | --unreal_dynamics gravity \ 29 | --variety_list 2.0 30 | ``` 31 | ### Simulated in Walker2d-v2 with 0.3x friction and Medium Replay dataset 32 | ```bash 33 | python sim2real_sac_main.py \ 34 | --env_list Walker2d-v2 \ 35 | --data_source medium_replay \ 36 | --unreal_dynamics friction \ 37 | --variety_list 0.3 38 | ``` 39 | ### Simulated in HalfCheetah-v2 with joint noise N(0,1) and Medium dataset 40 | ```bash 41 | python sim2real_sac_main.py \ 42 | --env_list HalfCheetah-v2 \ 43 | --data_source medium \ 44 | --variety_list 1.0 \ 45 | --joint_noise_std 1.0 46 | ``` 47 | 48 | ## Visualization of Learning Curves 49 | You can use [wandb](https://wandb.ai/site) to visualize the learning curves: log in to your personal account with your wandb API key, 50 | ```bash 51 | export WANDB_API_KEY=YOUR_WANDB_API_KEY 52 | ``` 53 | and run `wandb online` to turn on the online synchronization. 54 | 55 | ## Citation 56 | If you are using the H2O framework or code for your project development, please cite the following paper: 57 | ``` 58 | @inproceedings{ 59 | niu2022when, 60 | title={When to Trust Your Simulator: Dynamics-Aware Hybrid Offline-and-Online Reinforcement Learning}, 61 | author={Haoyi Niu and Shubham Sharma and Yiwen Qiu and Ming Li and Guyue Zhou and Jianming HU and Xianyuan Zhan}, 62 | booktitle={Advances in Neural Information Processing Systems}, 63 | year={2022}, 64 | url={https://openreview.net/forum?id=zXE8iFOZKw} 65 | } 66 | ``` 67 | -------------------------------------------------------------------------------- /SimpleSAC/replay_buffer.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Replay buffer for storing samples.
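Also provides helpers for loading D4RL datasets and for indexing, splitting and concatenating sampled batches.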
3 | @python version : 3.6.8 4 | ''' 5 | import torch 6 | import d4rl 7 | import numpy as np 8 | from gym.spaces import Box, Discrete, Tuple 9 | 10 | from envs import get_dim 11 | 12 | 13 | class Buffer(object): 14 | 15 | def append(self, *args): 16 | pass 17 | 18 | def sample(self, *args): 19 | pass 20 | 21 | class ReplayBuffer(Buffer): 22 | def __init__(self, state_dim, action_dim, max_size=int(1e6), device='cuda'): 23 | self.max_size = max_size 24 | self.ptr = 0 25 | self.size = 0 26 | 27 | self.state = np.zeros((max_size, state_dim)) 28 | self.action = np.zeros((max_size, action_dim)) 29 | self.next_state = np.zeros((max_size, state_dim)) 30 | self.next_action = np.zeros((max_size, action_dim)) 31 | self.reward = np.zeros((max_size, 1)) 32 | self.done = np.zeros((max_size, 1)) 33 | 34 | self.device = torch.device(device) 35 | 36 | def append(self, state, action, reward, next_state, done): 37 | self.state[self.ptr] = state 38 | self.action[self.ptr] = action 39 | self.next_state[self.ptr] = next_state 40 | self.reward[self.ptr] = reward 41 | self.done[self.ptr] = done 42 | 43 | self.ptr = (self.ptr + 1) % self.max_size 44 | self.size = min(self.size + 1, self.max_size) 45 | 46 | def sample(self, batch_size): 47 | ind = np.random.randint(0, self.size, size=batch_size) 48 | 49 | return { 50 | 'observations': torch.FloatTensor(self.state[ind]).to(self.device), 51 | 'actions': torch.FloatTensor(self.action[ind]).to(self.device), 52 | 'rewards': torch.FloatTensor(self.reward[ind]).to(self.device), 53 | 'next_observations': torch.FloatTensor(self.next_state[ind]).to(self.device), 54 | 'dones': torch.FloatTensor(self.done[ind]).to(self.device) 55 | } 56 | 57 | def batch_to_torch(batch, device): 58 | return { 59 | k: torch.from_numpy(v).to(device=device, non_blocking=True) 60 | for k, v in batch.items() 61 | } 62 | 63 | 64 | def get_d4rl_dataset(env): 65 | dataset = d4rl.qlearning_dataset(env) 66 | return dict( 67 | observations=dataset['observations'], 68 | actions=dataset['actions'], 69 | next_observations=dataset['next_observations'], 70 | rewards=dataset['rewards'], 71 | dones=dataset['terminals'].astype(np.float32), 72 | ) 73 | 74 | 75 | def index_batch(batch, indices): 76 | indexed = {} 77 | for key in batch.keys(): 78 | indexed[key] = batch[key][indices, ...] 79 | return indexed 80 | 81 | 82 | def parition_batch_train_test(batch, train_ratio): 83 | train_indices = np.random.rand(batch['observations'].shape[0]) < train_ratio 84 | train_batch = index_batch(batch, train_indices) 85 | test_batch = index_batch(batch, ~train_indices) 86 | return train_batch, test_batch 87 | 88 | 89 | def subsample_batch(batch, size): 90 | indices = np.random.randint(batch['observations'].shape[0], size=size) 91 | return index_batch(batch, indices) 92 | 93 | 94 | def concatenate_batches(batches): 95 | concatenated = {} 96 | for key in batches[0].keys(): 97 | concatenated[key] = np.concatenate([batch[key] for batch in batches], axis=0).astype(np.float32) 98 | return concatenated 99 | 100 | 101 | def split_batch(batch, batch_size): 102 | batches = [] 103 | length = batch['observations'].shape[0] 104 | keys = batch.keys() 105 | for start in range(0, length, batch_size): 106 | end = min(start + batch_size, length) 107 | batches.append({key: batch[key][start:end, ...] 
for key in keys}) 108 | return batches 109 | 110 | 111 | def split_data_by_traj(data, max_traj_length): 112 | dones = data['dones'].astype(bool) 113 | start = 0 114 | splits = [] 115 | for i, done in enumerate(dones): 116 | if i - start + 1 >= max_traj_length or done: 117 | splits.append(index_batch(data, slice(start, i + 1))) 118 | start = i + 1 119 | 120 | if start < len(dones): 121 | splits.append(index_batch(data, slice(start, None))) 122 | 123 | return splits 124 | 125 | -------------------------------------------------------------------------------- /SimpleSAC/sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pdb 4 | 5 | 6 | class StepSampler(object): 7 | 8 | def __init__(self, env, max_traj_length=1000): 9 | self.max_traj_length = max_traj_length 10 | self._env = env 11 | self._traj_steps = 0 12 | self._current_observation = self.env.reset() 13 | 14 | def sample(self, policy, n_steps, deterministic=False, replay_buffer=None, joint_noise_std=0.): 15 | observations = [] 16 | actions = [] 17 | rewards = [] 18 | next_observations = [] 19 | dones = [] 20 | 21 | for _ in range(n_steps): 22 | self._traj_steps += 1 23 | observation = self._current_observation 24 | 25 | #TODO sample actions from current policy 26 | action = policy( 27 | np.expand_dims(observation, 0), deterministic=deterministic 28 | )[0, :] 29 | 30 | if joint_noise_std > 0.: 31 | # normal distribution 32 | next_observation, reward, done, _ = self.env.step(action + np.random.randn(action.shape[0],) * joint_noise_std) 33 | else: 34 | next_observation, reward, done, _ = self.env.step(action) 35 | 36 | observations.append(observation) 37 | actions.append(action) 38 | rewards.append(reward) 39 | dones.append(done) 40 | next_observations.append(next_observation) 41 | 42 | # add samples derived from current policy to replay buffer 43 | if replay_buffer is not None: 44 | replay_buffer.append( 45 | observation, action, reward, next_observation, done 46 | ) 47 | 48 | self._current_observation = next_observation 49 | 50 | if done or self._traj_steps >= self.max_traj_length: 51 | self._traj_steps = 0 52 | self._current_observation = self.env.reset() 53 | 54 | return dict( 55 | observations=np.array(observations, dtype=np.float32), 56 | actions=np.array(actions, dtype=np.float32), 57 | rewards=np.array(rewards, dtype=np.float32), 58 | next_observations=np.array(next_observations, dtype=np.float32), 59 | dones=np.array(dones, dtype=np.float32), 60 | ) 61 | 62 | @property 63 | def env(self): 64 | return self._env 65 | 66 | # with dones as a trajectory end indicator, we can use this sampler to sample trajectories 67 | class TrajSampler(object): 68 | 69 | def __init__(self, env, max_traj_length=1000): 70 | self.max_traj_length = max_traj_length 71 | self._env = env 72 | 73 | def sample(self, policy, n_trajs, deterministic=False, replay_buffer=None): 74 | trajs = [] 75 | for _ in range(n_trajs): 76 | observations = [] 77 | actions = [] 78 | rewards = [] 79 | next_observations = [] 80 | dones = [] 81 | 82 | observation = self.env.reset() 83 | 84 | for _ in range(self.max_traj_length): 85 | action = policy( 86 | np.expand_dims(observation, 0), deterministic=deterministic 87 | )[0, :] 88 | next_observation, reward, done, _ = self.env.step(action) 89 | 90 | observations.append(observation) 91 | actions.append(action) 92 | rewards.append(reward) 93 | dones.append(done) 94 | next_observations.append(next_observation) 95 | 96 | if replay_buffer is not 
None: 97 | replay_buffer.append( 98 | observation, action, reward, next_observation, done 99 | ) 100 | 101 | observation = next_observation 102 | 103 | if done: 104 | break 105 | 106 | trajs.append(dict( 107 | observations=np.array(observations, dtype=np.float32), 108 | actions=np.array(actions, dtype=np.float32), 109 | rewards=np.array(rewards, dtype=np.float32), 110 | next_observations=np.array(next_observations, dtype=np.float32), 111 | dones=np.array(dones, dtype=np.float32), 112 | )) 113 | 114 | return trajs 115 | 116 | @property 117 | def env(self): 118 | return self._env 119 | -------------------------------------------------------------------------------- /SimpleSAC/mixed_replay_buffer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import torch 4 | import random 5 | import numpy as np 6 | from gym.spaces import Box, Discrete, Tuple 7 | 8 | from envs import get_dim 9 | from replay_buffer import ReplayBuffer 10 | 11 | 12 | class MixedReplayBuffer(ReplayBuffer): 13 | def __init__(self, reward_scale, reward_bias, clip_action, state_dim, action_dim, task="halfcheetah", data_source="medium_replay", device="cuda", scale_rewards=True, scale_state=False, buffer_ratio=1, residual_ratio=0.1): 14 | super().__init__(state_dim, action_dim, device=device) 15 | 16 | self.scale_rewards = scale_rewards 17 | self.scale_state = scale_state 18 | self.buffer_ratio = buffer_ratio 19 | self.residual_ratio = residual_ratio 20 | 21 | # load the offline dataset into the replay buffer 22 | path = os.path.join("../../d4rl_mujoco_dataset", "{}_{}-v2.hdf5".format(task, data_source)) 23 | with h5py.File(path, "r") as dataset: 24 | total_num = dataset['observations'].shape[0] 25 | # idx = random.sample(range(total_num), int(total_num * self.residual_ratio)) 26 | idx = np.random.choice(range(total_num), int(total_num * self.residual_ratio), replace=False) 27 | s = np.vstack(np.array(dataset['observations'])).astype(np.float32)[idx, :] # An (N, dim_observation)-dimensional numpy array of observations 28 | a = np.vstack(np.array(dataset['actions'])).astype(np.float32)[idx, :] # An (N, dim_action)-dimensional numpy array of actions 29 | r = np.vstack(np.array(dataset['rewards'])).astype(np.float32)[idx, :] # An (N,)-dimensional numpy array of rewards 30 | s_ = np.vstack(np.array(dataset['next_observations'])).astype(np.float32)[idx, :] # An (N, dim_observation)-dimensional numpy array of next observations 31 | done = np.vstack(np.array(dataset['terminals']))[idx, :] # An (N,)-dimensional numpy array of terminal flags 32 | 33 | # whether to bias the reward 34 | r = r * reward_scale + reward_bias 35 | # whether to clip actions 36 | a = np.clip(a, -clip_action, clip_action) 37 | 38 | fixed_dataset_size = r.shape[0] 39 | self.fixed_dataset_size = fixed_dataset_size 40 | self.ptr = fixed_dataset_size 41 | self.size = fixed_dataset_size 42 | self.max_size = (self.buffer_ratio + 1) * fixed_dataset_size 43 | 44 | self.state = np.vstack((s, np.zeros((self.max_size - self.fixed_dataset_size, state_dim)))) 45 | self.action = np.vstack((a, np.zeros((self.max_size - self.fixed_dataset_size, action_dim)))) 46 | self.next_state = np.vstack((s_, np.zeros((self.max_size - self.fixed_dataset_size, state_dim)))) 47 | self.reward = np.vstack((r, np.zeros((self.max_size - self.fixed_dataset_size, 1)))) 48 | self.done = np.vstack((done, np.zeros((self.max_size - self.fixed_dataset_size, 1)))) 49 | self.device = torch.device(device) 50 | 51 | # State normalization 52 |
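        # The statistics below are computed from the offline data just loaded;
        # states are actually rescaled only when scale_state=True.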
self.normalize_states() 53 | 54 | 55 | 56 | def normalize_states(self, eps=1e-3): 57 | # STATE: standard normalization (statistics over the filled portion of the buffer only) 58 | self.state_mean = self.state[:self.size].mean(0, keepdims=True) 59 | self.state_std = self.state[:self.size].std(0, keepdims=True) + eps 60 | if self.scale_state: 61 | self.state = (self.state - self.state_mean) / self.state_std 62 | self.next_state = (self.next_state - self.state_mean) / self.state_std 63 | 64 | def append(self, s, a, r, s_, done): 65 | 66 | self.state[self.ptr] = s 67 | self.action[self.ptr] = a 68 | self.next_state[self.ptr] = s_ 69 | self.reward[self.ptr] = r 70 | self.done[self.ptr] = done 71 | 72 | # keep the offline dataset fixed and cyclically overwrite only the simulated part 73 | self.ptr = (self.ptr + 1 - self.fixed_dataset_size) % (self.max_size - self.fixed_dataset_size) + self.fixed_dataset_size 74 | self.size = min(self.size + 1, self.max_size) 75 | 76 | def append_traj(self, observations, actions, rewards, next_observations, dones): 77 | for o, a, r, no, d in zip(observations, actions, rewards, next_observations, dones): 78 | self.append(o, a, r, no, d) 79 | 80 | def sample(self, batch_size, scope=None, type=None): 81 | if scope is None: 82 | ind = np.random.randint(0, self.size, size=batch_size) 83 | elif scope == "real": 84 | ind = np.random.randint(0, self.fixed_dataset_size, size=batch_size) 85 | elif scope == "sim": 86 | ind = np.random.randint(self.fixed_dataset_size, self.size, size=batch_size) 87 | else: 88 | raise RuntimeError("Misspecified range for replay buffer sampling") 89 | 90 | if type is None: 91 | return { 92 | 'observations': torch.FloatTensor(self.state[ind]).to(self.device), 93 | 'actions': torch.FloatTensor(self.action[ind]).to(self.device), 94 | 'rewards': torch.FloatTensor(self.reward[ind]).to(self.device), 95 | 'next_observations': torch.FloatTensor(self.next_state[ind]).to(self.device), 96 | 'dones': torch.FloatTensor(self.done[ind]).to(self.device) 97 | } 98 | elif type == "sas": 99 | return { 100 | 'observations': torch.FloatTensor(self.state[ind]).to(self.device), 101 | 'actions': torch.FloatTensor(self.action[ind]).to(self.device), 102 | 'next_observations': torch.FloatTensor(self.next_state[ind]).to(self.device) 103 | } 104 | elif type == "sa": 105 | return { 106 | 'observations': torch.FloatTensor(self.state[ind]).to(self.device), 107 | 'actions': torch.FloatTensor(self.action[ind]).to(self.device) 108 | } 109 | else: 110 | raise RuntimeError("Misspecified return data types for replay buffer sampling") 111 | 112 | def get_mean_std(self): 113 | return torch.FloatTensor(self.state_mean).to(self.device), torch.FloatTensor(self.state_std).to(self.device) -------------------------------------------------------------------------------- /xml_path/real_file/half_cheetah.xml: -------------------------------------------------------------------------------- (MuJoCo half-cheetah model; XML markup not preserved in this dump) -------------------------------------------------------------------------------- /xml_path/sim_file/half_cheetah_gravityx1.xml: -------------------------------------------------------------------------------- (simulation variant with 1x gravity; XML markup not preserved in this dump) -------------------------------------------------------------------------------- /xml_path/sim_file/half_cheetah_gravityx2.0.xml: -------------------------------------------------------------------------------- (simulation variant with 2.0x gravity; XML markup not preserved in this dump) -------------------------------------------------------------------------------- /SimpleSAC/model.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.distributions import Normal 6 | from torch.distributions.transformed_distribution import TransformedDistribution 7 | from torch.distributions.transforms import TanhTransform 8 | 9 | 10 | def extend_and_repeat(tensor, dim, repeat): 11 | # Insert a new axis at dim and repeat the tensor `repeat` times along it 12 | ones_shape = [1 for _ in range(tensor.ndim + 1)] 13 | ones_shape[dim] = repeat 14 | return torch.unsqueeze(tensor, dim) * tensor.new_ones(ones_shape) 15 | 16 | 17 | def soft_target_update(network, target_network, soft_target_update_rate): 18 | target_network_params = {k: v for k, v in target_network.named_parameters()} 19 | for k, v in network.named_parameters(): 20 | target_network_params[k].data = ( 21 | (1 - soft_target_update_rate) * target_network_params[k].data 22 | + soft_target_update_rate * v.data 23 | ) 24 | 25 | 26 | def multiple_action_q_function(forward): 27 | # Forward the q function with multiple actions on each state, to be used as a decorator 28 | def wrapped(self, observations, actions, **kwargs): 29 | multiple_actions = False 30 | batch_size = observations.shape[0] 31 | if actions.ndim == 3 and observations.ndim == 2: 32 | multiple_actions = True 33 | observations = extend_and_repeat(observations, 1, actions.shape[1]).reshape(-1, observations.shape[-1]) 34 | actions = actions.reshape(-1, actions.shape[-1]) 35 | q_values = forward(self, observations, actions, **kwargs) 36 | if multiple_actions: 37 | q_values = q_values.reshape(batch_size, -1) 38 | return q_values 39 | return wrapped 40 | 41 | 42 | class FullyConnectedNetwork(nn.Module): 43 | 44 | def __init__(self, input_dim, output_dim, arch='256-256', orthogonal_init=False): 45 | super().__init__() 46 | self.input_dim = input_dim 47 | self.output_dim = output_dim 48 | self.arch = arch 49 | self.orthogonal_init = orthogonal_init 50 | 51 | d = input_dim 52 | modules = [] 53 | hidden_sizes = [int(h) for h in arch.split('-')] 54 | 55 | for hidden_size in hidden_sizes: 56 | fc = nn.Linear(d, hidden_size) 57 | if orthogonal_init: 58 | nn.init.orthogonal_(fc.weight, gain=np.sqrt(2)) 59 | nn.init.constant_(fc.bias, 0.0) 60 | modules.append(fc) 61 | modules.append(nn.ReLU()) 62 | d = hidden_size 63 | 64 | last_fc = nn.Linear(d, output_dim) 65 | if orthogonal_init: 66 | nn.init.orthogonal_(last_fc.weight, gain=1e-2) 67 | else: 68 | nn.init.xavier_uniform_(last_fc.weight, gain=1e-2) 69 | 70 | nn.init.constant_(last_fc.bias, 0.0) 71 | modules.append(last_fc) 72 | 73 | self.network = nn.Sequential(*modules) 74 | 75 | def forward(self, input_tensor): 76 | return self.network(input_tensor) 77 | 78 | 79 | class ReparameterizedTanhGaussian(nn.Module): 80 | 81 | def __init__(self, log_std_min=-20.0, log_std_max=2.0, no_tanh=False): 82 | super().__init__() 83 | self.log_std_min = log_std_min 84 | self.log_std_max = log_std_max 85 | self.no_tanh = no_tanh 86 | 87 | def log_prob(self, mean, log_std, sample): 88 | log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) 89 | std = torch.exp(log_std) 90 | if self.no_tanh: 91 | action_distribution = Normal(mean, std) 92 | else: 93 | action_distribution = TransformedDistribution( 94 | Normal(mean, std), TanhTransform(cache_size=1) 95 | ) 96 | return torch.sum(action_distribution.log_prob(sample), dim=-1) 97 | 98 | def forward(self, mean, log_std, deterministic=False): 99 | log_std =
torch.clamp(log_std, self.log_std_min, self.log_std_max) 100 | std = torch.exp(log_std) 101 | 102 | 103 | if self.no_tanh: 104 | action_distribution = Normal(mean, std) 105 | else: 106 | action_distribution = TransformedDistribution( 107 | Normal(mean, std), TanhTransform(cache_size=1) 108 | ) 109 | 110 | if deterministic: 111 | action_sample = torch.tanh(mean) 112 | else: 113 | action_sample = action_distribution.rsample() 114 | 115 | log_prob = torch.sum( 116 | action_distribution.log_prob(action_sample), dim=-1 117 | ) 118 | 119 | return action_sample, log_prob 120 | 121 | 122 | class TanhGaussianPolicy(nn.Module): 123 | 124 | def __init__(self, observation_dim, action_dim, arch='256-256', 125 | log_std_multiplier=1.0, log_std_offset=-1.0, 126 | orthogonal_init=False, no_tanh=False): 127 | super().__init__() 128 | self.observation_dim = observation_dim 129 | self.action_dim = action_dim 130 | self.arch = arch 131 | self.orthogonal_init = orthogonal_init 132 | self.no_tanh = no_tanh 133 | 134 | self.base_network = FullyConnectedNetwork( 135 | observation_dim, 2 * action_dim, arch, orthogonal_init 136 | ) 137 | self.log_std_multiplier = Scalar(log_std_multiplier) 138 | self.log_std_offset = Scalar(log_std_offset) 139 | self.tanh_gaussian = ReparameterizedTanhGaussian(no_tanh=no_tanh) 140 | 141 | def log_prob(self, observations, actions): 142 | if actions.ndim == 3: 143 | observations = extend_and_repeat(observations, 1, actions.shape[1]) 144 | base_network_output = self.base_network(observations) 145 | mean, log_std = torch.split(base_network_output, self.action_dim, dim=-1) 146 | log_std = self.log_std_multiplier() * log_std + self.log_std_offset() 147 | return self.tanh_gaussian.log_prob(mean, log_std, actions) 148 | 149 | def forward(self, observations, deterministic=False, repeat=None): 150 | if repeat is not None: 151 | observations = extend_and_repeat(observations, 1, repeat) 152 | assert torch.isnan(observations).sum() == 0, print(observations) 153 | base_network_output = self.base_network(observations) 154 | mean, log_std = torch.split(base_network_output, self.action_dim, dim=-1) 155 | log_std = self.log_std_multiplier() * log_std + self.log_std_offset() 156 | assert torch.isnan(mean).sum() == 0, print(mean) 157 | assert torch.isnan(log_std).sum() == 0, print(log_std) 158 | return self.tanh_gaussian(mean, log_std, deterministic) 159 | 160 | 161 | class SamplerPolicy(object): 162 | 163 | def __init__(self, policy, device): 164 | self.policy = policy 165 | self.device = device 166 | 167 | def __call__(self, observations, deterministic=False): 168 | with torch.no_grad(): 169 | observations = torch.tensor( 170 | observations, dtype=torch.float32, device=self.device 171 | ) 172 | actions, _ = self.policy(observations, deterministic) 173 | actions = actions.cpu().numpy() 174 | return actions 175 | 176 | 177 | class FullyConnectedQFunction(nn.Module): 178 | 179 | def __init__(self, observation_dim, action_dim, arch='256-256', orthogonal_init=False): 180 | super().__init__() 181 | self.observation_dim = observation_dim 182 | self.action_dim = action_dim 183 | self.arch = arch 184 | self.orthogonal_init = orthogonal_init 185 | self.network = FullyConnectedNetwork( 186 | observation_dim + action_dim, 1, arch, orthogonal_init 187 | ) 188 | 189 | @multiple_action_q_function 190 | def forward(self, observations, actions): 191 | input_tensor = torch.cat([observations, actions], dim=-1) 192 | return torch.squeeze(self.network(input_tensor), dim=-1) 193 | 194 | 195 | class Scalar(nn.Module): 196 
| def __init__(self, init_value): 197 | super().__init__() 198 | self.constant = nn.Parameter( 199 | torch.tensor(init_value, dtype=torch.float32) 200 | ) 201 | 202 | def forward(self): 203 | return self.constant 204 | -------------------------------------------------------------------------------- /viskit/static/css/dropdowns-enhancement.css: -------------------------------------------------------------------------------- 1 | .dropdown-menu > li > label { 2 | display: block; 3 | padding: 3px 20px; 4 | clear: both; 5 | font-weight: normal; 6 | line-height: 1.42857143; 7 | color: #333333; 8 | white-space: nowrap; 9 | } 10 | .dropdown-menu > li > label:hover, 11 | .dropdown-menu > li > label:focus { 12 | text-decoration: none; 13 | color: #262626; 14 | background-color: #f5f5f5; 15 | } 16 | .dropdown-menu > li > input:checked ~ label, 17 | .dropdown-menu > li > input:checked ~ label:hover, 18 | .dropdown-menu > li > input:checked ~ label:focus, 19 | .dropdown-menu > .active > label, 20 | .dropdown-menu > .active > label:hover, 21 | .dropdown-menu > .active > label:focus { 22 | color: #ffffff; 23 | text-decoration: none; 24 | outline: 0; 25 | background-color: #428bca; 26 | } 27 | .dropdown-menu > li > input[disabled] ~ label, 28 | .dropdown-menu > li > input[disabled] ~ label:hover, 29 | .dropdown-menu > li > input[disabled] ~ label:focus, 30 | .dropdown-menu > .disabled > label, 31 | .dropdown-menu > .disabled > label:hover, 32 | .dropdown-menu > .disabled > label:focus { 33 | color: #999999; 34 | } 35 | .dropdown-menu > li > input[disabled] ~ label:hover, 36 | .dropdown-menu > li > input[disabled] ~ label:focus, 37 | .dropdown-menu > .disabled > label:hover, 38 | .dropdown-menu > .disabled > label:focus { 39 | text-decoration: none; 40 | background-color: transparent; 41 | background-image: none; 42 | filter: progid:DXImageTransform.Microsoft.gradient(enabled = false); 43 | cursor: not-allowed; 44 | } 45 | .dropdown-menu > li > label { 46 | margin-bottom: 0; 47 | cursor: pointer; 48 | } 49 | .dropdown-menu > li > input[type="radio"], 50 | .dropdown-menu > li > input[type="checkbox"] { 51 | display: none; 52 | position: absolute; 53 | top: -9999em; 54 | left: -9999em; 55 | } 56 | .dropdown-menu > li > label:focus, 57 | .dropdown-menu > li > input:focus ~ label { 58 | outline: thin dotted; 59 | outline: 5px auto -webkit-focus-ring-color; 60 | outline-offset: -2px; 61 | } 62 | .dropdown-menu.pull-right { 63 | right: 0; 64 | left: auto; 65 | } 66 | .dropdown-menu.pull-top { 67 | bottom: 100%; 68 | top: auto; 69 | margin: 0 0 2px; 70 | -webkit-box-shadow: 0 -6px 12px rgba(0, 0, 0, 0.175); 71 | box-shadow: 0 -6px 12px rgba(0, 0, 0, 0.175); 72 | } 73 | .dropdown-menu.pull-center { 74 | right: 50%; 75 | left: auto; 76 | } 77 | .dropdown-menu.pull-middle { 78 | right: 100%; 79 | margin: 0 2px 0 0; 80 | box-shadow: -5px 0 10px rgba(0, 0, 0, 0.2); 81 | left: auto; 82 | } 83 | .dropdown-menu.pull-middle.pull-right { 84 | right: auto; 85 | left: 100%; 86 | margin: 0 0 0 2px; 87 | box-shadow: 5px 0 10px rgba(0, 0, 0, 0.2); 88 | } 89 | .dropdown-menu.pull-middle.pull-center { 90 | right: 50%; 91 | margin: 0; 92 | box-shadow: 0 0 10px rgba(0, 0, 0, 0.2); 93 | } 94 | .dropdown-menu.bullet { 95 | margin-top: 8px; 96 | } 97 | .dropdown-menu.bullet:before { 98 | width: 0; 99 | height: 0; 100 | content: ''; 101 | display: inline-block; 102 | position: absolute; 103 | border-color: transparent; 104 | border-style: solid; 105 | -webkit-transform: rotate(360deg); 106 | border-width: 0 7px 7px; 107 | 
border-bottom-color: #cccccc; 108 | border-bottom-color: rgba(0, 0, 0, 0.15); 109 | top: -7px; 110 | left: 9px; 111 | } 112 | .dropdown-menu.bullet:after { 113 | width: 0; 114 | height: 0; 115 | content: ''; 116 | display: inline-block; 117 | position: absolute; 118 | border-color: transparent; 119 | border-style: solid; 120 | -webkit-transform: rotate(360deg); 121 | border-width: 0 6px 6px; 122 | border-bottom-color: #ffffff; 123 | top: -6px; 124 | left: 10px; 125 | } 126 | .dropdown-menu.bullet.pull-right:before { 127 | left: auto; 128 | right: 9px; 129 | } 130 | .dropdown-menu.bullet.pull-right:after { 131 | left: auto; 132 | right: 10px; 133 | } 134 | .dropdown-menu.bullet.pull-top { 135 | margin-top: 0; 136 | margin-bottom: 8px; 137 | } 138 | .dropdown-menu.bullet.pull-top:before { 139 | top: auto; 140 | bottom: -7px; 141 | border-bottom-width: 0; 142 | border-top-width: 7px; 143 | border-top-color: #cccccc; 144 | border-top-color: rgba(0, 0, 0, 0.15); 145 | } 146 | .dropdown-menu.bullet.pull-top:after { 147 | top: auto; 148 | bottom: -6px; 149 | border-bottom: none; 150 | border-top-width: 6px; 151 | border-top-color: #ffffff; 152 | } 153 | .dropdown-menu.bullet.pull-center:before { 154 | left: auto; 155 | right: 50%; 156 | margin-right: -7px; 157 | } 158 | .dropdown-menu.bullet.pull-center:after { 159 | left: auto; 160 | right: 50%; 161 | margin-right: -6px; 162 | } 163 | .dropdown-menu.bullet.pull-middle { 164 | margin-right: 8px; 165 | } 166 | .dropdown-menu.bullet.pull-middle:before { 167 | top: 50%; 168 | left: 100%; 169 | right: auto; 170 | margin-top: -7px; 171 | border-right-width: 0; 172 | border-bottom-color: transparent; 173 | border-top-width: 7px; 174 | border-left-color: #cccccc; 175 | border-left-color: rgba(0, 0, 0, 0.15); 176 | } 177 | .dropdown-menu.bullet.pull-middle:after { 178 | top: 50%; 179 | left: 100%; 180 | right: auto; 181 | margin-top: -6px; 182 | border-right-width: 0; 183 | border-bottom-color: transparent; 184 | border-top-width: 6px; 185 | border-left-color: #ffffff; 186 | } 187 | .dropdown-menu.bullet.pull-middle.pull-right { 188 | margin-right: 0; 189 | margin-left: 8px; 190 | } 191 | .dropdown-menu.bullet.pull-middle.pull-right:before { 192 | left: -7px; 193 | border-left-width: 0; 194 | border-right-width: 7px; 195 | border-right-color: #cccccc; 196 | border-right-color: rgba(0, 0, 0, 0.15); 197 | } 198 | .dropdown-menu.bullet.pull-middle.pull-right:after { 199 | left: -6px; 200 | border-left-width: 0; 201 | border-right-width: 6px; 202 | border-right-color: #ffffff; 203 | } 204 | .dropdown-menu.bullet.pull-middle.pull-center { 205 | margin-left: 0; 206 | margin-right: 0; 207 | } 208 | .dropdown-menu.bullet.pull-middle.pull-center:before { 209 | border: none; 210 | display: none; 211 | } 212 | .dropdown-menu.bullet.pull-middle.pull-center:after { 213 | border: none; 214 | display: none; 215 | } 216 | .dropdown-submenu { 217 | position: relative; 218 | } 219 | .dropdown-submenu > .dropdown-menu { 220 | top: 0; 221 | left: 100%; 222 | margin-top: -6px; 223 | margin-left: -1px; 224 | border-top-left-radius: 0; 225 | } 226 | .dropdown-submenu > a:before { 227 | display: block; 228 | float: right; 229 | width: 0; 230 | height: 0; 231 | content: ""; 232 | margin-top: 6px; 233 | margin-right: -8px; 234 | border-width: 4px 0 4px 4px; 235 | border-style: solid; 236 | border-left-style: dashed; 237 | border-top-color: transparent; 238 | border-bottom-color: transparent; 239 | } 240 | @media (max-width: 767px) { 241 | .navbar-nav .dropdown-submenu > a:before 
{ 242 | margin-top: 8px; 243 | border-color: inherit; 244 | border-style: solid; 245 | border-width: 4px 4px 0; 246 | border-left-color: transparent; 247 | border-right-color: transparent; 248 | } 249 | .navbar-nav .dropdown-submenu > a { 250 | padding-left: 40px; 251 | } 252 | .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > a, 253 | .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > label { 254 | padding-left: 35px; 255 | } 256 | .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > a, 257 | .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > label { 258 | padding-left: 45px; 259 | } 260 | .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > a, 261 | .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > label { 262 | padding-left: 55px; 263 | } 264 | .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > a, 265 | .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > label { 266 | padding-left: 65px; 267 | } 268 | .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > a, 269 | .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > label { 270 | padding-left: 75px; 271 | } 272 | } 273 | .navbar-default .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a, 274 | .navbar-default .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:hover, 275 | .navbar-default .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:focus { 276 | background-color: #e7e7e7; 277 | color: #555555; 278 | } 279 | @media (max-width: 767px) { 280 | .navbar-default .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:before { 281 | border-top-color: #555555; 282 | } 283 | } 284 | .navbar-inverse .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a, 285 | .navbar-inverse .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:hover, 286 | .navbar-inverse .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:focus { 287 | background-color: #080808; 288 | color: #ffffff; 289 | } 290 | @media (max-width: 767px) { 291 | .navbar-inverse .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:before { 292 | border-top-color: #ffffff; 293 | } 294 | } 295 | -------------------------------------------------------------------------------- /viskit/core.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import math 3 | import os 4 | import numpy as np 5 | import json 6 | import itertools 7 | 8 | 9 | class AttrDict(dict): 10 | def __init__(self, *args, **kwargs): 11 | super(AttrDict, self).__init__(*args, **kwargs) 12 | self.__dict__ = self 13 | 14 | 15 | 16 | def unique(l): 17 | return list(set(l)) 18 | 19 | 20 | def flatten(l): 21 | return [item for sublist in l for item in sublist] 22 | 23 | 24 | def load_progress(progress_csv_path): 25 | print("Reading %s" % progress_csv_path) 26 | entries = 
dict() 27 | if progress_csv_path.split('.')[-1] == "csv": 28 | delimiter = ',' 29 | else: 30 | delimiter = '\t' 31 | with open(progress_csv_path, 'r') as csvfile: 32 | reader = csv.DictReader(csvfile, delimiter=delimiter) 33 | for row in reader: 34 | for k, v in row.items(): 35 | if k not in entries: 36 | entries[k] = [] 37 | try: 38 | entries[k].append(float(v)) 39 | except: 40 | entries[k].append(0.) 41 | entries = dict([(k, np.array(v)) for k, v in entries.items()]) 42 | return entries 43 | 44 | 45 | def to_json(stub_object): 46 | from rllab.misc.instrument import StubObject 47 | from rllab.misc.instrument import StubAttr 48 | if isinstance(stub_object, StubObject): 49 | assert len(stub_object.args) == 0 50 | data = dict() 51 | for k, v in stub_object.kwargs.items(): 52 | data[k] = to_json(v) 53 | data["_name"] = stub_object.proxy_class.__module__ + \ 54 | "." + stub_object.proxy_class.__name__ 55 | return data 56 | elif isinstance(stub_object, StubAttr): 57 | return dict( 58 | obj=to_json(stub_object.obj), 59 | attr=to_json(stub_object.attr_name) 60 | ) 61 | return stub_object 62 | 63 | 64 | def flatten_dict(d): 65 | flat_params = dict() 66 | for k, v in d.items(): 67 | if isinstance(v, dict): 68 | v = flatten_dict(v) 69 | for subk, subv in flatten_dict(v).items(): 70 | flat_params[k + "." + subk] = subv 71 | else: 72 | flat_params[k] = v 73 | return flat_params 74 | 75 | 76 | def load_params(params_json_path): 77 | with open(params_json_path, 'r') as f: 78 | data = json.loads(f.read()) 79 | if "args_data" in data: 80 | del data["args_data"] 81 | if "exp_name" not in data: 82 | data["exp_name"] = params_json_path.split("/")[-2] 83 | return data 84 | 85 | 86 | def lookup(d, keys): 87 | if not isinstance(keys, list): 88 | keys = keys.split(".") 89 | for k in keys: 90 | if hasattr(d, "__getitem__"): 91 | if k in d: 92 | d = d[k] 93 | else: 94 | return None 95 | else: 96 | return None 97 | return d 98 | 99 | 100 | def load_exps_data( 101 | exp_folder_paths, 102 | data_filename='progress.csv', 103 | params_filename='params.json', 104 | disable_variant=False, 105 | ): 106 | exps = [] 107 | for exp_folder_path in exp_folder_paths: 108 | exps += [x[0] for x in os.walk(exp_folder_path)] 109 | exps_data = [] 110 | for exp in exps: 111 | try: 112 | exp_path = exp 113 | params_json_path = os.path.join(exp_path, params_filename) 114 | variant_json_path = os.path.join(exp_path, "variant.json") 115 | progress_csv_path = os.path.join(exp_path, data_filename) 116 | if os.stat(progress_csv_path).st_size == 0: 117 | progress_csv_path = os.path.join(exp_path, "log.txt") 118 | progress = load_progress(progress_csv_path) 119 | if disable_variant: 120 | params = load_params(params_json_path) 121 | else: 122 | try: 123 | params = load_params(variant_json_path) 124 | except IOError: 125 | params = load_params(params_json_path) 126 | exps_data.append(AttrDict( 127 | progress=progress, 128 | params=params, 129 | flat_params=flatten_dict(params))) 130 | except IOError as e: 131 | print(e) 132 | return exps_data 133 | 134 | 135 | def smart_repr(x): 136 | if isinstance(x, tuple): 137 | if len(x) == 0: 138 | return "tuple()" 139 | elif len(x) == 1: 140 | return "(%s,)" % smart_repr(x[0]) 141 | else: 142 | return "(" + ",".join(map(smart_repr, x)) + ")" 143 | elif isinstance(x, list): 144 | if len(x) == 0: 145 | return "[]" 146 | elif len(x) == 1: 147 | return "[%s,]" % smart_repr(x[0]) 148 | else: 149 | return "[" + ",".join(map(smart_repr, x)) + "]" 150 | else: 151 | if hasattr(x, "__call__"): 152 | return 
"__import__('pydoc').locate('%s')" % (x.__module__ + "." + x.__name__) 153 | elif isinstance(x, float) and math.isnan(x): 154 | return 'float("nan")' 155 | else: 156 | return repr(x) 157 | 158 | 159 | def smart_eval(string): 160 | string = string.replace(',inf)', ',"inf")') 161 | return eval(string) 162 | 163 | 164 | 165 | def extract_distinct_params(exps_data, excluded_params=('seed', 'log_dir'), l=1): 166 | # all_pairs = unique(flatten([d.flat_params.items() for d in exps_data])) 167 | # if logger: 168 | # logger("(Excluding {excluded})".format(excluded=', '.join(excluded_params))) 169 | # def cmp(x,y): 170 | # if x < y: 171 | # return -1 172 | # elif x > y: 173 | # return 1 174 | # else: 175 | # return 0 176 | 177 | try: 178 | params_as_evalable_strings = [ 179 | list( 180 | map( 181 | smart_repr, 182 | list(d.flat_params.items()) 183 | ) 184 | ) 185 | for d in exps_data 186 | ] 187 | unique_params = unique( 188 | flatten( 189 | params_as_evalable_strings 190 | ) 191 | ) 192 | stringified_pairs = sorted( 193 | map( 194 | smart_eval, 195 | unique_params 196 | ), 197 | key=lambda x: ( 198 | tuple(smart_repr(i) for i in x) 199 | # tuple(0. if it is None else it for it in x), 200 | ) 201 | ) 202 | except Exception as e: 203 | print(e) 204 | import ipdb; ipdb.set_trace() 205 | proposals = [(k, [x[1] for x in v]) 206 | for k, v in itertools.groupby(stringified_pairs, lambda x: x[0])] 207 | filtered = [ 208 | (k, v) for (k, v) in proposals 209 | if k == 'version' or ( 210 | len(v) > l and all( 211 | [k.find(excluded_param) != 0 212 | for excluded_param in excluded_params] 213 | ) 214 | ) 215 | ] 216 | return filtered 217 | 218 | def exp_has_key_value(exp, k, v): 219 | return ( 220 | str(exp.flat_params.get(k, None)) == str(v) 221 | # TODO: include this? 
222 | or (k not in exp.flat_params) 223 | ) 224 | 225 | 226 | class Selector(object): 227 | def __init__(self, exps_data, filters=None, custom_filters=None): 228 | self._exps_data = exps_data 229 | if filters is None: 230 | self._filters = tuple() 231 | else: 232 | self._filters = tuple(filters) 233 | if custom_filters is None: 234 | self._custom_filters = [] 235 | else: 236 | self._custom_filters = custom_filters 237 | 238 | def where(self, k, v): 239 | return Selector( 240 | self._exps_data, 241 | self._filters + ((k, v),), 242 | self._custom_filters, 243 | ) 244 | 245 | def where_not(self, k, v): 246 | return Selector( 247 | self._exps_data, 248 | self._filters, 249 | self._custom_filters + [ 250 | lambda exp: not exp_has_key_value(exp, k, v) 251 | ], 252 | ) 253 | 254 | def custom_filter(self, filter): 255 | return Selector(self._exps_data, self._filters, self._custom_filters + [filter]) 256 | 257 | def _check_exp(self, exp): 258 | # or exp.flat_params.get(k, None) is None 259 | return all( 260 | ( 261 | exp_has_key_value(exp, k, v) 262 | for k, v in self._filters 263 | ) 264 | ) and all(custom_filter(exp) for custom_filter in self._custom_filters) 265 | 266 | def extract(self): 267 | return list(filter(self._check_exp, self._exps_data)) 268 | 269 | def iextract(self): 270 | return filter(self._check_exp, self._exps_data) 271 | 272 | 273 | # Taken from plot.ly 274 | color_defaults = [ 275 | '#1f77b4', # muted blue 276 | '#ff7f0e', # safety orange 277 | '#2ca02c', # cooked asparagus green 278 | '#d62728', # brick red 279 | '#9467bd', # muted purple 280 | '#8c564b', # chestnut brown 281 | '#e377c2', # raspberry yogurt pink 282 | '#7f7f7f', # middle gray 283 | '#bcbd22', # curry yellow-green 284 | '#17becf' # blue-teal 285 | ] 286 | 287 | 288 | def hex_to_rgb(hex, opacity=1.0): 289 | if hex[0] == '#': 290 | hex = hex[1:] 291 | assert (len(hex) == 6) 292 | return "rgba({0},{1},{2},{3})".format(int(hex[:2], 16), int(hex[2:4], 16), int(hex[4:6], 16), opacity) 293 | -------------------------------------------------------------------------------- /SimpleSAC/sim2real_sac_main.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import pprint 4 | import re 5 | import sys 6 | import time 7 | import uuid 8 | from copy import deepcopy 9 | 10 | 11 | import absl.app 12 | import absl.flags 13 | import d4rl 14 | import gym 15 | import numpy as np 16 | import torch 17 | import wandb 18 | from tqdm import trange 19 | 20 | from envs import get_new_density_env, get_new_friction_env, get_new_gravity_env 21 | from mixed_replay_buffer import MixedReplayBuffer 22 | from model import FullyConnectedQFunction, SamplerPolicy, TanhGaussianPolicy 23 | from sampler import StepSampler, TrajSampler 24 | from sim2real_sac import Sim2realSAC 25 | from utils_h2o import (Timer, WandBLogger, define_flags_with_default, 26 | get_user_flags, prefix_metrics, print_flags, 27 | set_random_seed) 28 | 29 | sys.path.append("..") 30 | 31 | from Network.Weight_net import ConcatDiscriminator 32 | from viskit.logging import logger, setup_logger 33 | 34 | nowTime = datetime.datetime.now().strftime('%y-%m-%d-%H-%M-%S') 35 | 36 | FLAGS_DEF = define_flags_with_default( 37 | current_time=nowTime, 38 | name_str='', 39 | env_list='HalfCheetah-v2', 40 | data_source='medium_replay', 41 | unreal_dynamics="gravity", 42 | variety_list="2.0", 43 | batch_ratio=0.5, 44 | replaybuffer_ratio=10, 45 | real_residual_ratio=1.0, 46 | dis_dropout=False, 47 |
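    # generic experiment settings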
max_traj_length=1000, 48 | seed=42, 49 | device='cuda', 50 | save_model=False, 51 | batch_size=256, 52 | 53 | reward_scale=1.0, 54 | reward_bias=0.0, 55 | clip_action=1.0, 56 | joint_noise_std=0.0, 57 | 58 | policy_arch='256-256', 59 | qf_arch='256-256', 60 | orthogonal_init=False, 61 | policy_log_std_multiplier=1.0, 62 | policy_log_std_offset=-1.0, 63 | 64 | # train and evaluate policy 65 | n_epochs=1000, 66 | bc_epochs=0, 67 | n_rollout_steps_per_epoch=1000, 68 | n_train_step_per_epoch=1000, 69 | eval_period=10, 70 | eval_n_trajs=5, 71 | 72 | cql=Sim2realSAC.get_default_config(), 73 | logging=WandBLogger.get_default_config() 74 | ) 75 | 76 | 77 | def main(argv): 78 | FLAGS = absl.flags.FLAGS 79 | 80 | # define logged variables for wandb 81 | variant = get_user_flags(FLAGS, FLAGS_DEF) 82 | wandb_logger = WandBLogger(config=FLAGS.logging, variant=variant) 83 | wandb.run.name = f"{FLAGS.name_str}_{FLAGS.env_list}_{FLAGS.data_source}_{FLAGS.unreal_dynamics}x{FLAGS.variety_list}_{FLAGS.current_time}" 84 | 85 | setup_logger( 86 | variant=variant, 87 | exp_id=wandb_logger.experiment_id, 88 | seed=FLAGS.seed, 89 | base_log_dir=FLAGS.logging.output_dir, 90 | include_exp_prefix_sub_dir=False 91 | ) 92 | 93 | set_random_seed(FLAGS.seed) 94 | 95 | # different unreal dynamics properties 96 | for unreal_dynamics in FLAGS.unreal_dynamics.split(";"): 97 | # different environments 98 | for env_name in FLAGS.env_list.split(";"): 99 | # different varieties 100 | for variety_degree in FLAGS.variety_list.split(";"): 101 | variety_degree = float(variety_degree) 102 | 103 | off_env_name = "{}-{}-v2".format(env_name.split("-")[0].lower(), FLAGS.data_source).replace('_',"-") 104 | if unreal_dynamics == "gravity": 105 | real_env = get_new_gravity_env(1, off_env_name) 106 | sim_env = get_new_gravity_env(variety_degree, off_env_name) 107 | elif unreal_dynamics == "density": 108 | real_env = get_new_density_env(1, off_env_name) 109 | sim_env = get_new_density_env(variety_degree, off_env_name) 110 | elif unreal_dynamics == "friction": 111 | real_env = get_new_friction_env(1, off_env_name) 112 | sim_env = get_new_friction_env(variety_degree, off_env_name) 113 | else: 114 | raise RuntimeError("Got erroneous unreal dynamics %s" % unreal_dynamics) 115 | 116 | print("\n-------------Env name: {}, variety: {}, unreal_dynamics: {}-------------".format(env_name, variety_degree, unreal_dynamics)) 117 | 118 | # a step sampler for "simulated" training 119 | train_sampler = StepSampler(sim_env.unwrapped, FLAGS.max_traj_length) 120 | # a trajectory sampler for "real-world" evaluation 121 | eval_sampler = TrajSampler(real_env.unwrapped, FLAGS.max_traj_length) 122 | 123 | # replay buffer 124 | num_state = real_env.observation_space.shape[0] 125 | num_action = real_env.action_space.shape[0] 126 | replay_buffer = MixedReplayBuffer(FLAGS.reward_scale, FLAGS.reward_bias, FLAGS.clip_action, num_state, num_action, task=env_name.split("-")[0].lower(), data_source=FLAGS.data_source, device=FLAGS.device, buffer_ratio=FLAGS.replaybuffer_ratio, residual_ratio=FLAGS.real_residual_ratio) 127 | 128 | # discriminators 129 | d_sa = ConcatDiscriminator(num_state + num_action, 256, 2, FLAGS.device, dropout=FLAGS.dis_dropout).float().to(FLAGS.device) 130 | d_sas = ConcatDiscriminator(2 * num_state + num_action, 256, 2, FLAGS.device, dropout=FLAGS.dis_dropout).float().to(FLAGS.device) 131 | 132 | # agent 133 | policy = TanhGaussianPolicy( 134 | eval_sampler.env.observation_space.shape[0], 135 | eval_sampler.env.action_space.shape[0], 136 |
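                    # the base MLP outputs both the mean and the log-std of the Tanh-Gaussian head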
arch=FLAGS.policy_arch, 137 | log_std_multiplier=FLAGS.policy_log_std_multiplier, 138 | log_std_offset=FLAGS.policy_log_std_offset, 139 | orthogonal_init=FLAGS.orthogonal_init, 140 | ) 141 | 142 | qf1 = FullyConnectedQFunction( 143 | eval_sampler.env.observation_space.shape[0], 144 | eval_sampler.env.action_space.shape[0], 145 | arch=FLAGS.qf_arch, 146 | orthogonal_init=FLAGS.orthogonal_init, 147 | ) 148 | target_qf1 = deepcopy(qf1) 149 | 150 | qf2 = FullyConnectedQFunction( 151 | eval_sampler.env.observation_space.shape[0], 152 | eval_sampler.env.action_space.shape[0], 153 | arch=FLAGS.qf_arch, 154 | orthogonal_init=FLAGS.orthogonal_init, 155 | ) 156 | target_qf2 = deepcopy(qf2) 157 | 158 | if FLAGS.cql.target_entropy >= 0.0: 159 | FLAGS.cql.target_entropy = -np.prod(eval_sampler.env.action_space.shape).item() 160 | 161 | sac = Sim2realSAC(FLAGS.cql, policy, qf1, qf2, target_qf1, target_qf2, d_sa, d_sas, replay_buffer, dynamics_model=None) 162 | sac.torch_to_device(FLAGS.device) 163 | 164 | # sampling policy is always the current policy: \pi 165 | sampler_policy = SamplerPolicy(policy, FLAGS.device) 166 | 167 | viskit_metrics = {} 168 | 169 | # train and evaluate for n_epochs 170 | for epoch in trange(FLAGS.n_epochs): 171 | metrics = {} 172 | 173 | # TODO rollout from the simulator 174 | with Timer() as rollout_timer: 175 | # rollout and append simulated trajectories to the replay buffer 176 | train_sampler.sample( 177 | sampler_policy, FLAGS.n_rollout_steps_per_epoch, 178 | deterministic=False, replay_buffer=replay_buffer, joint_noise_std=FLAGS.joint_noise_std 179 | ) 180 | metrics['epoch'] = epoch 181 | 182 | # TODO Train from the mixed data 183 | with Timer() as train_timer: 184 | for batch_idx in trange(FLAGS.n_train_step_per_epoch): 185 | real_batch_size = int(FLAGS.batch_size * (1 - FLAGS.batch_ratio)) 186 | sim_batch_size = int(FLAGS.batch_size * FLAGS.batch_ratio) 187 | if batch_idx + 1 == FLAGS.n_train_step_per_epoch: 188 | metrics.update( 189 | prefix_metrics(sac.train(real_batch_size, sim_batch_size), 'sac') 190 | ) 191 | else: 192 | sac.train(real_batch_size, sim_batch_size) 193 | 194 | # TODO Evaluate in the real world 195 | with Timer() as eval_timer: 196 | if epoch == 0 or (epoch + 1) % FLAGS.eval_period == 0: 197 | trajs = eval_sampler.sample( 198 | sampler_policy, FLAGS.eval_n_trajs, deterministic=True 199 | ) 200 | 201 | eval_dsa_loss, eval_dsas_loss = sac.discriminator_evaluate() 202 | metrics['eval_dsa_loss'] = eval_dsa_loss 203 | metrics['eval_dsas_loss'] = eval_dsas_loss 204 | metrics['average_return'] = np.mean([np.sum(t['rewards']) for t in trajs]) 205 | metrics['average_traj_length'] = np.mean([len(t['rewards']) for t in trajs]) 206 | metrics['average_normalized_return'] = np.mean( 207 | [eval_sampler.env.get_normalized_score(np.sum(t['rewards'])) for t in trajs] 208 | ) 209 | 210 | if FLAGS.save_model: 211 | save_data = {'sac': sac, 'variant': variant, 'epoch': epoch} 212 | wandb_logger.save_pickle(save_data, 'model.pkl') 213 | 214 | metrics['rollout_time'] = rollout_timer() 215 | metrics['train_time'] = train_timer() 216 | metrics['eval_time'] = eval_timer() 217 | metrics['epoch_time'] = rollout_timer() + train_timer() + eval_timer() 218 | wandb_logger.log(metrics) 219 | viskit_metrics.update(metrics) 220 | logger.record_dict(viskit_metrics) 221 | logger.dump_tabular(with_prefix=False, with_timestamp=False) 222 | 223 | if FLAGS.save_model: 224 | save_data = {'sac': sac, 'variant': variant, 'epoch': epoch} 225 | wandb_logger.save_pickle(save_data, 'model.pkl') 226
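# absl.app.run parses the FLAGS defined above and then calls main()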
| 227 | if __name__ == '__main__': 228 | absl.app.run(main) 229 | -------------------------------------------------------------------------------- /viskit/static/js/dropdowns-enhancement.js: -------------------------------------------------------------------------------- 1 | /* ======================================================================== 2 | * Bootstrap Dropdowns Enhancement: dropdowns-enhancement.js v3.1.1 (Beta 1) 3 | * http://behigh.github.io/bootstrap_dropdowns_enhancement/ 4 | * ======================================================================== 5 | * Licensed under MIT (https://github.com/twbs/bootstrap/blob/master/LICENSE) 6 | * ======================================================================== */ 7 | 8 | (function($) { 9 | "use strict"; 10 | 11 | var toggle = '[data-toggle="dropdown"]', 12 | disabled = '.disabled, :disabled', 13 | backdrop = '.dropdown-backdrop', 14 | menuClass = 'dropdown-menu', 15 | subMenuClass = 'dropdown-submenu', 16 | namespace = '.bs.dropdown.data-api', 17 | eventNamespace = '.bs.dropdown', 18 | openClass = 'open', 19 | touchSupport = 'ontouchstart' in document.documentElement, 20 | opened; 21 | 22 | 23 | function Dropdown(element) { 24 | $(element).on('click' + eventNamespace, this.toggle) 25 | } 26 | 27 | var proto = Dropdown.prototype; 28 | 29 | proto.toggle = function(event) { 30 | var $element = $(this); 31 | 32 | if ($element.is(disabled)) return; 33 | 34 | var $parent = getParent($element); 35 | var isActive = $parent.hasClass(openClass); 36 | var isSubMenu = $parent.hasClass(subMenuClass); 37 | var menuTree = isSubMenu ? getSubMenuParents($parent) : null; 38 | 39 | closeOpened(event, menuTree); 40 | 41 | if (!isActive) { 42 | if (!menuTree) 43 | menuTree = [$parent]; 44 | 45 | if (touchSupport && !$parent.closest('.navbar-nav').length && !menuTree[0].find(backdrop).length) { 46 | // if mobile we use a backdrop because click events don't delegate 47 | $('
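<div class="dropdown-backdrop"/>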
').appendTo(menuTree[0]).on('click', closeOpened) 48 | } 49 | 50 | for (var i = 0, s = menuTree.length; i < s; i++) { 51 | if (!menuTree[i].hasClass(openClass)) { 52 | menuTree[i].addClass(openClass); 53 | positioning(menuTree[i].children('.' + menuClass), menuTree[i]); 54 | } 55 | } 56 | opened = menuTree[0]; 57 | } 58 | 59 | return false; 60 | }; 61 | 62 | proto.keydown = function (e) { 63 | if (!/(38|40|27)/.test(e.keyCode)) return; 64 | 65 | var $this = $(this); 66 | 67 | e.preventDefault(); 68 | e.stopPropagation(); 69 | 70 | if ($this.is('.disabled, :disabled')) return; 71 | 72 | var $parent = getParent($this); 73 | var isActive = $parent.hasClass('open'); 74 | 75 | if (!isActive || (isActive && e.keyCode == 27)) { 76 | if (e.which == 27) $parent.find(toggle).trigger('focus'); 77 | return $this.trigger('click') 78 | } 79 | 80 | var desc = ' li:not(.divider):visible a'; 81 | var desc1 = 'li:not(.divider):visible > input:not(disabled) ~ label'; 82 | var $items = $parent.find(desc1 + ', ' + '[role="menu"]' + desc + ', [role="listbox"]' + desc); 83 | 84 | if (!$items.length) return; 85 | 86 | var index = $items.index($items.filter(':focus')); 87 | 88 | if (e.keyCode == 38 && index > 0) index--; // up 89 | if (e.keyCode == 40 && index < $items.length - 1) index++; // down 90 | if (!~index) index = 0; 91 | 92 | $items.eq(index).trigger('focus') 93 | }; 94 | 95 | proto.change = function (e) { 96 | 97 | var 98 | $parent, 99 | $menu, 100 | $toggle, 101 | selector, 102 | text = '', 103 | $items; 104 | 105 | $menu = $(this).closest('.' + menuClass); 106 | 107 | $toggle = $menu.parent().find('[data-label-placement]'); 108 | 109 | if (!$toggle || !$toggle.length) { 110 | $toggle = $menu.parent().find(toggle); 111 | } 112 | 113 | if (!$toggle || !$toggle.length || $toggle.data('placeholder') === false) 114 | return; // do nothing, no control 115 | 116 | ($toggle.data('placeholder') == undefined && $toggle.data('placeholder', $.trim($toggle.text()))); 117 | text = $.data($toggle[0], 'placeholder'); 118 | 119 | $items = $menu.find('li > input:checked'); 120 | 121 | if ($items.length) { 122 | text = []; 123 | $items.each(function () { 124 | var str = $(this).parent().find('label').eq(0), 125 | label = str.find('.data-label'); 126 | 127 | if (label.length) { 128 | var p = $('
<p></p>
'); 129 | p.append(label.clone()); 130 | str = p.html(); 131 | } 132 | else { 133 | str = str.html(); 134 | } 135 | 136 | 137 | str && text.push($.trim(str)); 138 | }); 139 | 140 | text = text.length < 4 ? text.join(', ') : text.length + ' selected'; 141 | } 142 | 143 | var caret = $toggle.find('.caret'); 144 | 145 | $toggle.html(text || ' '); 146 | if (caret.length) 147 | $toggle.append(' ') && caret.appendTo($toggle); 148 | 149 | }; 150 | 151 | function positioning($menu, $control) { 152 | if ($menu.hasClass('pull-center')) { 153 | $menu.css('margin-right', $menu.outerWidth() / -2); 154 | } 155 | 156 | if ($menu.hasClass('pull-middle')) { 157 | $menu.css('margin-top', ($menu.outerHeight() / -2) - ($control.outerHeight() / 2)); 158 | } 159 | } 160 | 161 | function closeOpened(event, menuTree) { 162 | if (opened) { 163 | 164 | if (!menuTree) { 165 | menuTree = [opened]; 166 | } 167 | 168 | var parent; 169 | 170 | if (opened[0] !== menuTree[0][0]) { 171 | parent = opened; 172 | } else { 173 | parent = menuTree[menuTree.length - 1]; 174 | if (parent.parent().hasClass(menuClass)) { 175 | parent = parent.parent(); 176 | } 177 | } 178 | 179 | parent.find('.' + openClass).removeClass(openClass); 180 | 181 | if (parent.hasClass(openClass)) 182 | parent.removeClass(openClass); 183 | 184 | if (parent === opened) { 185 | opened = null; 186 | $(backdrop).remove(); 187 | } 188 | } 189 | } 190 | 191 | function getSubMenuParents($submenu) { 192 | var result = [$submenu]; 193 | var $parent; 194 | while (!$parent || $parent.hasClass(subMenuClass)) { 195 | $parent = ($parent || $submenu).parent(); 196 | if ($parent.hasClass(menuClass)) { 197 | $parent = $parent.parent(); 198 | } 199 | if ($parent.children(toggle)) { 200 | result.unshift($parent); 201 | } 202 | } 203 | return result; 204 | } 205 | 206 | function getParent($this) { 207 | var selector = $this.attr('data-target'); 208 | 209 | if (!selector) { 210 | selector = $this.attr('href'); 211 | selector = selector && /#[A-Za-z]/.test(selector) && selector.replace(/.*(?=#[^\s]*$)/, ''); //strip for ie7 212 | } 213 | 214 | var $parent = selector && $(selector); 215 | 216 | return $parent && $parent.length ? $parent : $this.parent() 217 | } 218 | 219 | // DROPDOWN PLUGIN DEFINITION 220 | // ========================== 221 | 222 | var old = $.fn.dropdown; 223 | 224 | $.fn.dropdown = function (option) { 225 | return this.each(function () { 226 | var $this = $(this); 227 | var data = $this.data('bs.dropdown'); 228 | 229 | if (!data) $this.data('bs.dropdown', (data = new Dropdown(this))); 230 | if (typeof option == 'string') data[option].call($this); 231 | }) 232 | }; 233 | 234 | $.fn.dropdown.Constructor = Dropdown; 235 | 236 | $.fn.dropdown.clearMenus = function(e) { 237 | $(backdrop).remove(); 238 | $('.' 
+ openClass + ' ' + toggle).each(function () { 239 | var $parent = getParent($(this)); 240 | var relatedTarget = { relatedTarget: this }; 241 | if (!$parent.hasClass('open')) return; 242 | $parent.trigger(e = $.Event('hide' + eventNamespace, relatedTarget)); 243 | if (e.isDefaultPrevented()) return; 244 | $parent.removeClass('open').trigger('hidden' + eventNamespace, relatedTarget); 245 | }); 246 | return this; 247 | }; 248 | 249 | 250 | // DROPDOWN NO CONFLICT 251 | // ==================== 252 | 253 | $.fn.dropdown.noConflict = function () { 254 | $.fn.dropdown = old; 255 | return this 256 | }; 257 | 258 | 259 | $(document).off(namespace) 260 | .on('click' + namespace, closeOpened) 261 | .on('click' + namespace, toggle, proto.toggle) 262 | .on('click' + namespace, '.dropdown-menu > li > input[type="checkbox"] ~ label, .dropdown-menu > li > input[type="checkbox"], .dropdown-menu.noclose > li', function (e) { 263 | e.stopPropagation() 264 | }) 265 | .on('change' + namespace, '.dropdown-menu > li > input[type="checkbox"], .dropdown-menu > li > input[type="radio"]', proto.change) 266 | .on('keydown' + namespace, toggle + ', [role="menu"], [role="listbox"]', proto.keydown) 267 | }(jQuery)); -------------------------------------------------------------------------------- /SimpleSAC/utils_h2o.py: -------------------------------------------------------------------------------- 1 | import random 2 | import pprint 3 | import time 4 | import uuid 5 | import tempfile 6 | import os 7 | import re 8 | from copy import copy 9 | from socket import gethostname 10 | import pickle 11 | 12 | import numpy as np 13 | 14 | import absl.flags 15 | from absl import logging 16 | from ml_collections import ConfigDict 17 | from ml_collections.config_flags import config_flags 18 | from ml_collections.config_dict import config_dict 19 | 20 | import wandb 21 | 22 | import torch 23 | 24 | 25 | class Timer(object): 26 | 27 | def __init__(self): 28 | self._time = None 29 | 30 | def __enter__(self): 31 | self._start_time = time.time() 32 | return self 33 | 34 | def __exit__(self, exc_type, exc_value, exc_tb): 35 | self._time = time.time() - self._start_time 36 | 37 | def __call__(self): 38 | return self._time 39 | 40 | 41 | class WandBLogger(object): 42 | 43 | @staticmethod 44 | def get_default_config(updates=None): 45 | config = ConfigDict() 46 | config.online = True 47 | config.prefix = '' 48 | config.project = 'when-to-trust-your-simulator' 49 | config.entity = 't6-thu' 50 | config.output_dir = './experiment_output' 51 | config.random_delay = 0.0 52 | config.experiment_id = config_dict.placeholder(str) 53 | config.anonymous = config_dict.placeholder(str) 54 | config.notes = config_dict.placeholder(str) 55 | 56 | if updates is not None: 57 | config.update(ConfigDict(updates).copy_and_resolve_references()) 58 | return config 59 | 60 | def __init__(self, config, variant): 61 | self.config = self.get_default_config(config) 62 | if self.config.experiment_id is None: 63 | self.config.experiment_id = uuid.uuid4().hex 64 | 65 | if self.config.prefix != '': 66 | self.config.project = '{}--{}'.format(self.config.prefix, self.config.project) 67 | 68 | if self.config.output_dir == '': 69 | self.config.output_dir = tempfile.mkdtemp() 70 | else: 71 | self.config.output_dir = os.path.join(self.config.output_dir, self.config.experiment_id) 72 | os.makedirs(self.config.output_dir, exist_ok=True) 73 | 74 | self._variant = copy(variant) 75 | 76 | if 'hostname' not in self._variant: 77 | self._variant['hostname'] = gethostname() 78 | 79 | if 
self.config.random_delay > 0:
80 |             time.sleep(np.random.uniform(0, self.config.random_delay))
81 | 
82 |         self.run = wandb.init(
83 |             reinit=True,
84 |             config=self._variant,
85 |             project=self.config.project,
86 |             entity=self.config.entity,
87 |             dir=self.config.output_dir,
88 |             id=self.config.experiment_id,
89 |             anonymous=self.config.anonymous,
90 |             notes=self.config.notes,
91 |             settings=wandb.Settings(
92 |                 start_method="thread",
93 |                 _disable_stats=True,
94 |             ),
95 |             mode='online' if self.config.online else 'offline',
96 |         )
97 | 
98 |     def log(self, *args, **kwargs):
99 |         self.run.log(*args, **kwargs)
100 | 
101 |     def save_pickle(self, obj, filename):
102 |         with open(os.path.join(self.config.output_dir, filename), 'wb') as fout:
103 |             pickle.dump(obj, fout)
104 | 
105 |     @property
106 |     def experiment_id(self):
107 |         return self.config.experiment_id
108 | 
109 |     @property
110 |     def variant(self):
111 |         return self._variant  # fixed: the variant is stored on the instance, not on the config
112 | 
113 |     @property
114 |     def output_dir(self):
115 |         return self.config.output_dir
116 | 
117 | 
118 | def define_flags_with_default(**kwargs):
119 |     for key, val in kwargs.items():
120 |         if isinstance(val, ConfigDict):
121 |             config_flags.DEFINE_config_dict(key, val)
122 |         elif isinstance(val, bool):
123 |             # Note that True and False are instances of int.
124 |             absl.flags.DEFINE_bool(key, val, 'automatically defined flag')
125 |         elif isinstance(val, int):
126 |             absl.flags.DEFINE_integer(key, val, 'automatically defined flag')
127 |         elif isinstance(val, float):
128 |             absl.flags.DEFINE_float(key, val, 'automatically defined flag')
129 |         elif isinstance(val, str):
130 |             absl.flags.DEFINE_string(key, val, 'automatically defined flag')
131 |         else:
132 |             raise ValueError('Incorrect value type')
133 |     return kwargs
134 | 
135 | 
136 | def set_random_seed(seed):
137 |     np.random.seed(seed)
138 |     torch.cuda.manual_seed_all(seed)
139 |     torch.manual_seed(seed)
140 |     random.seed(seed)
141 | 
142 | 
143 | def print_flags(flags, flags_def):
144 |     logging.info(
145 |         'Running training with hyperparameters: \n{}'.format(
146 |             pprint.pformat(
147 |                 ['{}: {}'.format(key, val) for key, val in get_user_flags(flags, flags_def).items()]
148 |             )
149 |         )
150 |     )
151 | 
152 | # collect the values of the user flags defined in flags_def
153 | def get_user_flags(flags, flags_def):
154 |     output = {}
155 |     for key in flags_def:
156 |         val = getattr(flags, key)
157 |         if isinstance(val, ConfigDict):
158 |             output.update(flatten_config_dict(val, prefix=key))
159 |         else:
160 |             output[key] = val
161 | 
162 |     return output
163 | 
164 | 
165 | def flatten_config_dict(config, prefix=None):
166 |     output = {}
167 |     for key, val in config.items():
168 |         if prefix is not None:
169 |             next_prefix = '{}.{}'.format(prefix, key)
170 |         else:
171 |             next_prefix = key
172 |         if isinstance(val, ConfigDict):
173 |             output.update(flatten_config_dict(val, prefix=next_prefix))
174 |         else:
175 |             output[next_prefix] = val
176 |     return output
177 | 
178 | 
179 | 
180 | def prefix_metrics(metrics, prefix):
181 |     return {
182 |         '{}/{}'.format(prefix, key): value for key, value in metrics.items()
183 |     }
184 | 
185 | # generate xml assets path: gym_xml_path
186 | def generate_xml_path():
187 |     import gym, os
188 |     xml_path = os.path.join(gym.__file__[:-11], 'envs/mujoco/assets')
189 | 
190 |     assert os.path.exists(xml_path)
191 |     print("gym_xml_path: ", xml_path)
192 | 
193 |     return xml_path
194 | 
195 | 
196 | gym_xml_path = generate_xml_path()
197 | 
198 | 
199 | def record_data(file, content):
200 |     with open(file, 'a+') as f:
201 |         f.write('{}\n'.format(content))
202 | 
203 | 
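# --- Editor's sketch (not part of the original module): how Timer and
# --- prefix_metrics compose in sim2real_sac_main.py; the values are hypothetical.
#
#     with Timer() as train_timer:
#         ...  # one epoch of sac.train(...) steps
#     metrics = prefix_metrics({'qf1_loss': 0.5}, 'sac')   # -> {'sac/qf1_loss': 0.5}
#     metrics['train_time'] = train_timer()                # calling a Timer returns elapsed seconds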
204 | def check_path(path):
205 |     try:
206 |         if not os.path.exists(path):
207 |             os.mkdir(path)
208 |     except FileExistsError:
209 |         pass
210 | 
211 |     return path
212 | 
213 | 
214 | def update_xml(index, env_name):
215 |     xml_name = parse_xml_name(env_name)
216 |     os.system('cp ./xml_path/{0}/{1} {2}/{1}'.format(index, xml_name, gym_xml_path))
217 | 
218 |     time.sleep(0.2)
219 | 
220 | 
221 | def parse_xml_name(env_name):
222 |     if 'walker' in env_name.lower():
223 |         xml_name = "walker2d.xml"
224 |     elif 'hopper' in env_name.lower():
225 |         xml_name = "hopper.xml"
226 |     elif 'halfcheetah' in env_name.lower():
227 |         xml_name = "half_cheetah.xml"
228 |     elif "ant" in env_name.lower():
229 |         xml_name = "ant.xml"
230 |     else:
231 |         raise RuntimeError("No available environment named '%s'" % env_name)
232 | 
233 |     return xml_name
234 | 
235 | 
236 | def update_source_env(env_name):
237 |     xml_name = parse_xml_name(env_name)
238 | 
239 |     os.system(
240 |         'cp ./xml_path/real_file/{0} {1}/{0}'.format(xml_name, gym_xml_path))
241 | 
242 |     time.sleep(0.2)
243 | 
244 | # gravity
245 | def update_target_env_gravity(variety_degree, env_name):
246 |     old_xml_name = parse_xml_name(env_name)
247 |     # create new xml
248 |     xml_name = "{}_gravityx{}.xml".format(old_xml_name.split(".")[0], variety_degree)
249 | 
250 |     with open('../xml_path/real_file/{}'.format(old_xml_name), "r+") as f:
251 | 
252 |         new_f = open('../xml_path/sim_file/{}'.format(xml_name), "w+")
253 |         for line in f.readlines():
254 |             if "gravity" in line:
255 |                 pattern = re.compile(r"gravity=\"(.*?)\"")
256 |                 a = pattern.findall(line)
257 |                 gravity_list = a[0].split(" ")
258 |                 new_gravity_list = []
259 |                 for num in gravity_list:
260 |                     new_gravity_list.append(variety_degree * float(num))
261 | 
262 |                 replace_num = " ".join(str(i) for i in new_gravity_list)
263 |                 replace_num = "gravity=\"" + replace_num + "\""
264 |                 sub_result = re.sub(pattern, str(replace_num), line)
265 | 
266 |                 new_f.write(sub_result)
267 |             else:
268 |                 new_f.write(line)
269 | 
270 |         new_f.close()
271 | 
272 |     # replace the default gym env with the newly-revised env
273 |     os.system(
274 |         'cp ../xml_path/sim_file/{0} {1}/{2}'.format(xml_name, gym_xml_path, old_xml_name))
275 | 
276 |     time.sleep(0.2)
277 | 
278 | # density
279 | def update_target_env_density(variety_degree, env_name):
280 |     old_xml_name = parse_xml_name(env_name)
281 |     # create new xml
282 |     xml_name = "{}_densityx{}.xml".format(old_xml_name.split(".")[0], variety_degree)
283 | 
284 |     with open('../xml_path/real_file/{}'.format(old_xml_name), "r+") as f:
285 | 
286 |         new_f = open('../xml_path/sim_file/{}'.format(xml_name), "w")
287 |         for line in f.readlines():
288 |             if "density" in line:
289 |                 pattern = re.compile(r'(?<=density=")\d+\.?\d*')
290 |                 a = pattern.findall(line)
291 |                 current_num = float(a[0])
292 |                 replace_num = current_num * variety_degree
293 |                 sub_result = re.sub(pattern, str(replace_num), line)
294 | 
295 |                 new_f.write(sub_result)
296 |             else:
297 |                 new_f.write(line)
298 | 
299 |         new_f.close()
300 | 
301 |     # replace the default gym env with the newly-revised env
302 |     os.system(
303 |         'cp ../xml_path/sim_file/{0} {1}/{2}'.format(xml_name, gym_xml_path, old_xml_name))
304 | 
305 |     time.sleep(0.2)
306 | 
307 | # friction
308 | def update_target_env_friction(variety_degree, env_name):
309 |     old_xml_name = parse_xml_name(env_name)
310 |     # create new xml
311 |     xml_name = "{}_frictionx{}.xml".format(old_xml_name.split(".")[0], variety_degree)
312 | 
313 |     with open('../xml_path/real_file/{}'.format(old_xml_name), "r+") as f:
314 | 
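        # (editor's note) mirrors update_target_env_gravity/_density above:
        # every friction attribute is scaled in place, e.g. friction="1.0 0.1"
        # with a hypothetical variety_degree=2.0 becomes friction="2.0 0.2".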
315 | new_f = open('../xml_path/sim_file/{}'.format(xml_name), "w") 316 | for line in f.readlines(): 317 | if "friction" in line: 318 | pattern = re.compile(r"friction=\"(.*?)\"") 319 | a = pattern.findall(line) 320 | friction_list = a[0].split(" ") 321 | new_friction_list = [] 322 | for num in friction_list: 323 | new_friction_list.append(variety_degree * float(num)) 324 | 325 | replace_num = " ".join(str(i) for i in new_friction_list) 326 | replace_num = "friction=\"" + replace_num + "\"" 327 | sub_result = re.sub(pattern, str(replace_num), line) 328 | 329 | new_f.write(sub_result) 330 | else: 331 | new_f.write(line) 332 | 333 | new_f.close() 334 | 335 | # replace the default gym env with newly-revised env 336 | os.system( 337 | 'cp ../xml_path/sim_file/{0} {1}/{2}'.format(xml_name, gym_xml_path, old_xml_name)) 338 | 339 | time.sleep(0.2) 340 | 341 | 342 | # def generate_log(extra=None): 343 | # print(extra) 344 | # record_data('../documents/log_{}.txt'.format(args.log_index), "{}".format(extra)) -------------------------------------------------------------------------------- /viskit/logging.py: -------------------------------------------------------------------------------- 1 | """ 2 | File taken from RLKit (https://github.com/vitchyr/rlkit). 3 | Based on rllab's logger. 4 | 5 | https://github.com/rll/rllab 6 | """ 7 | from enum import Enum 8 | from contextlib import contextmanager 9 | import numpy as np 10 | import os 11 | import os.path as osp 12 | import sys 13 | import datetime 14 | import dateutil.tz 15 | import csv 16 | import json 17 | import pickle 18 | import errno 19 | import time 20 | import torch 21 | 22 | import tempfile 23 | 24 | from viskit.tabulate import tabulate 25 | 26 | 27 | class TerminalTablePrinter(object): 28 | def __init__(self): 29 | self.headers = None 30 | self.tabulars = [] 31 | 32 | def print_tabular(self, new_tabular): 33 | if self.headers is None: 34 | self.headers = [x[0] for x in new_tabular] 35 | else: 36 | assert len(self.headers) == len(new_tabular) 37 | self.tabulars.append([x[1] for x in new_tabular]) 38 | self.refresh() 39 | 40 | def refresh(self): 41 | import os 42 | rows, columns = os.popen('stty size', 'r').read().split() 43 | tabulars = self.tabulars[-(int(rows) - 3):] 44 | sys.stdout.write("\x1b[2J\x1b[H") 45 | sys.stdout.write(tabulate(tabulars, self.headers)) 46 | sys.stdout.write("\n") 47 | 48 | 49 | class MyEncoder(json.JSONEncoder): 50 | def default(self, o): 51 | if isinstance(o, type): 52 | return {'$class': o.__module__ + "." + o.__name__} 53 | elif isinstance(o, Enum): 54 | return { 55 | '$enum': o.__module__ + "." + o.__class__.__name__ + '.' + o.name 56 | } 57 | elif callable(o): 58 | return { 59 | '$function': o.__module__ + "." 
+ o.__name__ 60 | } 61 | return json.JSONEncoder.default(self, o) 62 | 63 | 64 | def mkdir_p(path): 65 | try: 66 | os.makedirs(path) 67 | except OSError as exc: # Python >2.5 68 | if exc.errno == errno.EEXIST and os.path.isdir(path): 69 | pass 70 | else: 71 | raise 72 | 73 | 74 | class Logger(object): 75 | def __init__(self): 76 | self._prefixes = [] 77 | self._prefix_str = '' 78 | 79 | self._tabular_prefixes = [] 80 | self._tabular_prefix_str = '' 81 | 82 | self._tabular = [] 83 | 84 | self._text_outputs = [] 85 | self._tabular_outputs = [] 86 | 87 | self._text_fds = {} 88 | self._tabular_fds = {} 89 | self._tabular_header_written = set() 90 | 91 | self._snapshot_dir = None 92 | self._snapshot_mode = 'all' 93 | self._snapshot_gap = 1 94 | 95 | self._log_tabular_only = False 96 | self._header_printed = False 97 | self.table_printer = TerminalTablePrinter() 98 | 99 | def reset(self): 100 | self.__init__() 101 | 102 | def _add_output(self, file_name, arr, fds, mode='a'): 103 | if file_name not in arr: 104 | mkdir_p(os.path.dirname(file_name)) 105 | arr.append(file_name) 106 | fds[file_name] = open(file_name, mode) 107 | 108 | def _remove_output(self, file_name, arr, fds): 109 | if file_name in arr: 110 | fds[file_name].close() 111 | del fds[file_name] 112 | arr.remove(file_name) 113 | 114 | def push_prefix(self, prefix): 115 | self._prefixes.append(prefix) 116 | self._prefix_str = ''.join(self._prefixes) 117 | 118 | def add_text_output(self, file_name): 119 | self._add_output(file_name, self._text_outputs, self._text_fds, 120 | mode='a') 121 | 122 | def remove_text_output(self, file_name): 123 | self._remove_output(file_name, self._text_outputs, self._text_fds) 124 | 125 | def add_tabular_output(self, file_name, relative_to_snapshot_dir=False): 126 | if relative_to_snapshot_dir: 127 | file_name = osp.join(self._snapshot_dir, file_name) 128 | self._add_output(file_name, self._tabular_outputs, self._tabular_fds, 129 | mode='w') 130 | 131 | def remove_tabular_output(self, file_name, relative_to_snapshot_dir=False): 132 | if relative_to_snapshot_dir: 133 | file_name = osp.join(self._snapshot_dir, file_name) 134 | if self._tabular_fds[file_name] in self._tabular_header_written: 135 | self._tabular_header_written.remove(self._tabular_fds[file_name]) 136 | self._remove_output(file_name, self._tabular_outputs, self._tabular_fds) 137 | 138 | def set_snapshot_dir(self, dir_name): 139 | self._snapshot_dir = dir_name 140 | 141 | def get_snapshot_dir(self, ): 142 | return self._snapshot_dir 143 | 144 | def get_snapshot_mode(self, ): 145 | return self._snapshot_mode 146 | 147 | def set_snapshot_mode(self, mode): 148 | self._snapshot_mode = mode 149 | 150 | def get_snapshot_gap(self, ): 151 | return self._snapshot_gap 152 | 153 | def set_snapshot_gap(self, gap): 154 | self._snapshot_gap = gap 155 | 156 | def set_log_tabular_only(self, log_tabular_only): 157 | self._log_tabular_only = log_tabular_only 158 | 159 | def get_log_tabular_only(self, ): 160 | return self._log_tabular_only 161 | 162 | def log(self, s, with_prefix=True, with_timestamp=True): 163 | out = s 164 | if with_prefix: 165 | out = self._prefix_str + out 166 | if with_timestamp: 167 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 168 | timestamp = now.strftime('%Y-%m-%d %H:%M:%S.%f %Z') 169 | out = "%s | %s" % (timestamp, out) 170 | if not self._log_tabular_only: 171 | # Also log to stdout 172 | print(out) 173 | for fd in list(self._text_fds.values()): 174 | fd.write(out + '\n') 175 | fd.flush() 176 | sys.stdout.flush() 177 | 178 | def 
record_tabular(self, key, val): 179 | self._tabular.append((self._tabular_prefix_str + str(key), str(val))) 180 | 181 | def record_dict(self, d, prefix=None): 182 | if prefix is not None: 183 | self.push_tabular_prefix(prefix) 184 | for k, v in d.items(): 185 | self.record_tabular(k, v) 186 | if prefix is not None: 187 | self.pop_tabular_prefix() 188 | 189 | def push_tabular_prefix(self, key): 190 | self._tabular_prefixes.append(key) 191 | self._tabular_prefix_str = ''.join(self._tabular_prefixes) 192 | 193 | def pop_tabular_prefix(self, ): 194 | del self._tabular_prefixes[-1] 195 | self._tabular_prefix_str = ''.join(self._tabular_prefixes) 196 | 197 | def save_extra_data(self, data, file_name='extra_data.pkl', mode='joblib'): 198 | """ 199 | Data saved here will always override the last entry 200 | 201 | :param data: Something pickle'able. 202 | """ 203 | file_name = osp.join(self._snapshot_dir, file_name) 204 | if mode == 'joblib': 205 | import joblib 206 | joblib.dump(data, file_name, compress=3) 207 | elif mode == 'pickle': 208 | pickle.dump(data, open(file_name, "wb")) 209 | else: 210 | raise ValueError("Invalid mode: {}".format(mode)) 211 | return file_name 212 | 213 | def get_table_dict(self, ): 214 | return dict(self._tabular) 215 | 216 | def get_table_key_set(self, ): 217 | return set(key for key, value in self._tabular) 218 | 219 | @contextmanager 220 | def prefix(self, key): 221 | self.push_prefix(key) 222 | try: 223 | yield 224 | finally: 225 | self.pop_prefix() 226 | 227 | @contextmanager 228 | def tabular_prefix(self, key): 229 | self.push_tabular_prefix(key) 230 | yield 231 | self.pop_tabular_prefix() 232 | 233 | def log_variant(self, log_file, variant_data): 234 | mkdir_p(os.path.dirname(log_file)) 235 | with open(log_file, "w") as f: 236 | json.dump(variant_data, f, indent=2, sort_keys=True, cls=MyEncoder) 237 | 238 | def record_tabular_misc_stat(self, key, values, placement='back'): 239 | if placement == 'front': 240 | prefix = "" 241 | suffix = key 242 | else: 243 | prefix = key 244 | suffix = "" 245 | if len(values) > 0: 246 | self.record_tabular(prefix + "Average" + suffix, np.average(values)) 247 | self.record_tabular(prefix + "Std" + suffix, np.std(values)) 248 | self.record_tabular(prefix + "Median" + suffix, np.median(values)) 249 | self.record_tabular(prefix + "Min" + suffix, np.min(values)) 250 | self.record_tabular(prefix + "Max" + suffix, np.max(values)) 251 | else: 252 | self.record_tabular(prefix + "Average" + suffix, np.nan) 253 | self.record_tabular(prefix + "Std" + suffix, np.nan) 254 | self.record_tabular(prefix + "Median" + suffix, np.nan) 255 | self.record_tabular(prefix + "Min" + suffix, np.nan) 256 | self.record_tabular(prefix + "Max" + suffix, np.nan) 257 | 258 | def dump_tabular(self, *args, **kwargs): 259 | wh = kwargs.pop("write_header", None) 260 | if len(self._tabular) > 0: 261 | if self._log_tabular_only: 262 | self.table_printer.print_tabular(self._tabular) 263 | else: 264 | for line in tabulate(self._tabular).split('\n'): 265 | self.log(line, *args, **kwargs) 266 | tabular_dict = dict(self._tabular) 267 | # Also write to the csv files 268 | # This assumes that the keys in each iteration won't change! 
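            # (editor's note) the CSV header is only written once per file, so
            # keys introduced in later iterations would silently misalign columns.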
269 | for tabular_fd in list(self._tabular_fds.values()): 270 | writer = csv.DictWriter(tabular_fd, 271 | fieldnames=list(tabular_dict.keys())) 272 | if wh or ( 273 | wh is None and tabular_fd not in self._tabular_header_written): 274 | writer.writeheader() 275 | self._tabular_header_written.add(tabular_fd) 276 | writer.writerow(tabular_dict) 277 | tabular_fd.flush() 278 | del self._tabular[:] 279 | 280 | def pop_prefix(self, ): 281 | del self._prefixes[-1] 282 | self._prefix_str = ''.join(self._prefixes) 283 | 284 | def save_itr_params(self, itr, params): 285 | if self._snapshot_dir: 286 | if self._snapshot_mode == 'all': 287 | file_name = osp.join(self._snapshot_dir, 'itr_%d.pkl' % itr) 288 | torch.save(params, file_name) 289 | elif self._snapshot_mode == 'last': 290 | # override previous params 291 | file_name = osp.join(self._snapshot_dir, 'params.pkl') 292 | torch.save(params, file_name) 293 | elif self._snapshot_mode == "gap": 294 | if itr % self._snapshot_gap == 0: 295 | file_name = osp.join(self._snapshot_dir, 'itr_%d.pkl' % itr) 296 | torch.save(params, file_name) 297 | elif self._snapshot_mode == "gap_and_last": 298 | if itr % self._snapshot_gap == 0: 299 | file_name = osp.join(self._snapshot_dir, 'itr_%d.pkl' % itr) 300 | torch.save(params, file_name) 301 | file_name = osp.join(self._snapshot_dir, 'params.pkl') 302 | torch.save(params, file_name) 303 | elif self._snapshot_mode == 'none': 304 | pass 305 | else: 306 | raise NotImplementedError 307 | 308 | 309 | def safe_json(data): 310 | if data is None: 311 | return True 312 | elif isinstance(data, (bool, int, float)): 313 | return True 314 | elif isinstance(data, (tuple, list)): 315 | return all(safe_json(x) for x in data) 316 | elif isinstance(data, dict): 317 | return all(isinstance(k, str) and safe_json(v) for k, v in data.items()) 318 | return False 319 | 320 | 321 | def dict_to_safe_json(d): 322 | """ 323 | Convert each value in the dictionary into a JSON'able primitive. 324 | :param d: 325 | :return: 326 | """ 327 | new_d = {} 328 | for key, item in d.items(): 329 | if safe_json(item): 330 | new_d[key] = item 331 | else: 332 | if isinstance(item, dict): 333 | new_d[key] = dict_to_safe_json(item) 334 | else: 335 | new_d[key] = str(item) 336 | return new_d 337 | 338 | 339 | def create_exp_name(exp_prefix, exp_id=0, seed=0): 340 | """ 341 | Create a semi-unique experiment name that has a timestamp 342 | :param exp_prefix: 343 | :param exp_id: 344 | :return: 345 | """ 346 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 347 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 348 | return "%s_%s-s-%d--%s" % (exp_prefix, timestamp, seed, str(exp_id)) 349 | 350 | 351 | def create_log_dir( 352 | exp_prefix, 353 | exp_id=0, 354 | seed=0, 355 | base_log_dir=None, 356 | include_exp_prefix_sub_dir=True, 357 | ): 358 | """ 359 | Creates and returns a unique log directory. 360 | 361 | :param exp_prefix: All experiments with this prefix will have log 362 | directories be under this directory. 363 | :param exp_id: The number of the specific experiment run within this 364 | experiment. 365 | :param base_log_dir: The directory where all log should be saved. 
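    :param seed: Seed recorded in the auto-generated experiment name.
    :param include_exp_prefix_sub_dir: If True, nest the run directory under a
        sub-directory named after exp_prefix.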
366 |     :return:
367 |     """
368 |     exp_name = create_exp_name(exp_prefix, exp_id=exp_id,
369 |                                seed=seed)
370 |     if base_log_dir is None:
371 |         base_log_dir = osp.join(os.getcwd(), 'experiment_output')  # editor's fix: was conf.LOCAL_LOG_DIR, but no `conf` module is imported in this file
372 |     if include_exp_prefix_sub_dir:
373 |         log_dir = osp.join(base_log_dir, exp_prefix.replace("_", "-"), exp_name)
374 |     else:
375 |         log_dir = osp.join(base_log_dir, exp_name)
376 |     if osp.exists(log_dir):
377 |         print("WARNING: Log directory already exists {}".format(log_dir))
378 |     os.makedirs(log_dir, exist_ok=True)
379 |     return log_dir
380 | 
381 | 
382 | def setup_logger(
383 |         exp_prefix="default",
384 |         variant=None,
385 |         text_log_file="debug.log",
386 |         variant_log_file="variant.json",
387 |         tabular_log_file="progress.csv",
388 |         snapshot_mode="last",
389 |         snapshot_gap=1,
390 |         log_tabular_only=False,
391 |         base_log_dir=None,
392 |         **create_log_dir_kwargs
393 | ):
394 |     """
395 |     Set up the logger to have some reasonable default settings.
396 | 
397 |     Will save log output to
398 | 
399 |         base_log_dir/exp_prefix/exp_name.
400 | 
401 |     exp_name will be auto-generated to be unique.
402 | 
403 |     If base_log_dir is specified, that directory is used as the root of the output dir.
404 | 
405 |     :param exp_prefix: The sub-directory for this specific experiment.
406 |     :param variant:
407 |     :param text_log_file:
408 |     :param variant_log_file:
409 |     :param tabular_log_file:
410 |     :param snapshot_mode:
411 |     :param log_tabular_only:
412 |     :param snapshot_gap:
413 |     :param base_log_dir:
414 |     :return:
415 |     """
416 |     log_dir = create_log_dir(
417 |         exp_prefix, base_log_dir=base_log_dir, **create_log_dir_kwargs
418 |     )
419 | 
420 |     if variant is not None:
421 |         logger.log("Variant:")
422 |         logger.log(json.dumps(dict_to_safe_json(variant), indent=2))
423 |         variant_log_path = osp.join(log_dir, variant_log_file)
424 |         logger.log_variant(variant_log_path, variant)
425 | 
426 |     tabular_log_path = osp.join(log_dir, tabular_log_file)
427 |     text_log_path = osp.join(log_dir, text_log_file)
428 | 
429 |     logger.add_text_output(text_log_path)
430 |     logger.add_tabular_output(tabular_log_path)
431 |     logger.set_snapshot_dir(log_dir)
432 |     logger.set_snapshot_mode(snapshot_mode)
433 |     logger.set_snapshot_gap(snapshot_gap)
434 |     logger.set_log_tabular_only(log_tabular_only)
435 |     exp_name = log_dir.split("/")[-1]
436 |     logger.push_prefix("[%s] " % exp_name)
437 | 
438 |     return log_dir
439 | 
440 | 
441 | logger = Logger()
442 | 
--------------------------------------------------------------------------------
/viskit/static/js/jquery.loadTemplate-1.5.6.js:
--------------------------------------------------------------------------------
1 | (function ($) {
2 |     "use strict";
3 |     var templates = {},
4 |         queue = {},
5 |         formatters = {},
6 |         isArray;
7 | 
8 |     function loadTemplate(template, data, options) {
9 |         var $that = this,
10 |             $template,
11 |             isFile,
12 |             settings;
13 | 
14 |         data = data || {};
15 | 
16 |         settings = $.extend(true, {
17 |             // These are the defaults.
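            // (editor's note) user-supplied options are deep-merged over these
            // defaults via $.extend(true, ...), so callers may override any subset.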
18 | async: true, 19 | overwriteCache: false, 20 | complete: null, 21 | success: null, 22 | error: function () { 23 | $(this).each(function () { 24 | $(this).html(settings.errorMessage); 25 | }); 26 | }, 27 | errorMessage: "There was an error loading the template.", 28 | paged: false, 29 | pageNo: 1, 30 | elemPerPage: 10, 31 | append: false, 32 | prepend: false, 33 | beforeInsert: null, 34 | afterInsert: null, 35 | bindingOptions: { 36 | ignoreUndefined: false, 37 | ignoreNull: false, 38 | ignoreEmptyString: false 39 | } 40 | }, options); 41 | 42 | if ($.type(data) === "array") { 43 | isArray = true; 44 | return processArray.call(this, template, data, settings); 45 | } 46 | 47 | if (!containsSlashes(template)) { 48 | $template = $(template); 49 | if (typeof template === 'string' && template.indexOf('#') === 0) { 50 | settings.isFile = false; 51 | } 52 | } 53 | 54 | isFile = settings.isFile || (typeof settings.isFile === "undefined" && (typeof $template === "undefined" || $template.length === 0)); 55 | 56 | if (isFile && !settings.overwriteCache && templates[template]) { 57 | prepareTemplateFromCache(template, $that, data, settings); 58 | } else if (isFile && !settings.overwriteCache && templates.hasOwnProperty(template)) { 59 | addToQueue(template, $that, data, settings); 60 | } else if (isFile) { 61 | loadAndPrepareTemplate(template, $that, data, settings); 62 | } else { 63 | loadTemplateFromDocument($template, $that, data, settings); 64 | } 65 | return this; 66 | } 67 | 68 | function addTemplateFormatter(key, formatter) { 69 | if (formatter) { 70 | formatters[key] = formatter; 71 | } else { 72 | formatters = $.extend(formatters, key); 73 | } 74 | } 75 | 76 | function containsSlashes(str) { 77 | return typeof str === "string" && str.indexOf("/") > -1; 78 | } 79 | 80 | function processArray(template, data, settings) { 81 | settings = settings || {}; 82 | var $that = this, 83 | todo = data.length, 84 | doPrepend = settings.prepend && !settings.append, 85 | done = 0, 86 | success = 0, 87 | errored = false, 88 | errorObjects = [], 89 | newOptions; 90 | 91 | if (settings.paged) { 92 | var startNo = (settings.pageNo - 1) * settings.elemPerPage; 93 | data = data.slice(startNo, startNo + settings.elemPerPage); 94 | todo = data.length; 95 | } 96 | 97 | newOptions = $.extend( 98 | {}, 99 | settings, 100 | { 101 | async: false, 102 | complete: function (data) { 103 | if (this.html) { 104 | var insertedElement; 105 | if (doPrepend) { 106 | insertedElement = $(this.html()).prependTo($that); 107 | } else { 108 | insertedElement = $(this.html()).appendTo($that); 109 | } 110 | if (settings.afterInsert && data) { 111 | settings.afterInsert(insertedElement, data); 112 | } 113 | } 114 | done++; 115 | if (done === todo || errored) { 116 | if (errored && settings && typeof settings.error === "function") { 117 | settings.error.call($that, errorObjects); 118 | } 119 | if (settings && typeof settings.complete === "function") { 120 | settings.complete(); 121 | } 122 | } 123 | }, 124 | success: function () { 125 | success++; 126 | if (success === todo) { 127 | if (settings && typeof settings.success === "function") { 128 | settings.success(); 129 | } 130 | } 131 | }, 132 | error: function (e) { 133 | errored = true; 134 | errorObjects.push(e); 135 | } 136 | } 137 | ); 138 | 139 | if (!settings.append && !settings.prepend) { 140 | $that.html(""); 141 | } 142 | 143 | if (doPrepend) data.reverse(); 144 | $(data).each(function () { 145 | var $div = $("
"); 146 | loadTemplate.call($div, template, this, newOptions); 147 | if (errored) { 148 | return false; 149 | } 150 | }); 151 | 152 | return this; 153 | } 154 | 155 | function addToQueue(template, selection, data, settings) { 156 | if (queue[template]) { 157 | queue[template].push({ data: data, selection: selection, settings: settings }); 158 | } else { 159 | queue[template] = [{ data: data, selection: selection, settings: settings}]; 160 | } 161 | } 162 | 163 | function prepareTemplateFromCache(template, selection, data, settings) { 164 | var $templateContainer = templates[template].clone(); 165 | 166 | prepareTemplate.call(selection, $templateContainer, data, settings); 167 | if (typeof settings.success === "function") { 168 | settings.success(); 169 | } 170 | } 171 | 172 | function uniqueId() { 173 | return new Date().getTime(); 174 | } 175 | 176 | function urlAvoidCache(url) { 177 | if (url.indexOf('?') !== -1) { 178 | return url + "&_=" + uniqueId(); 179 | } 180 | else { 181 | return url + "?_=" + uniqueId(); 182 | } 183 | } 184 | 185 | function loadAndPrepareTemplate(template, selection, data, settings) { 186 | var $templateContainer = $("
"); 187 | 188 | templates[template] = null; 189 | var templateUrl = template; 190 | if (settings.overwriteCache) { 191 | templateUrl = urlAvoidCache(templateUrl); 192 | } 193 | $.ajax({ 194 | url: templateUrl, 195 | async: settings.async, 196 | success: function (templateContent) { 197 | $templateContainer.html(templateContent); 198 | handleTemplateLoadingSuccess($templateContainer, template, selection, data, settings); 199 | }, 200 | error: function (e) { 201 | handleTemplateLoadingError(template, selection, data, settings, e); 202 | } 203 | }); 204 | } 205 | 206 | function loadTemplateFromDocument($template, selection, data, settings) { 207 | var $templateContainer = $("
"); 208 | 209 | if ($template.is("script") || $template.is("template")) { 210 | $template = $.parseHTML($.trim($template.html())); 211 | } 212 | 213 | $templateContainer.html($template); 214 | prepareTemplate.call(selection, $templateContainer, data, settings); 215 | 216 | if (typeof settings.success === "function") { 217 | settings.success(); 218 | } 219 | } 220 | 221 | function prepareTemplate(template, data, settings) { 222 | bindData(template, data, settings); 223 | 224 | $(this).each(function () { 225 | var $templateHtml = $(template.html()); 226 | if (settings.beforeInsert) { 227 | settings.beforeInsert($templateHtml, data); 228 | } 229 | if (settings.append) { 230 | 231 | $(this).append($templateHtml); 232 | } else if (settings.prepend) { 233 | $(this).prepend($templateHtml); 234 | } else { 235 | $(this).html($templateHtml); 236 | } 237 | if (settings.afterInsert && !isArray) { 238 | settings.afterInsert($templateHtml, data); 239 | } 240 | }); 241 | 242 | if (typeof settings.complete === "function") { 243 | settings.complete.call($(this), data); 244 | } 245 | } 246 | 247 | function handleTemplateLoadingError(template, selection, data, settings, error) { 248 | var value; 249 | 250 | if (typeof settings.error === "function") { 251 | settings.error.call(selection, error); 252 | } 253 | 254 | $(queue[template]).each(function (key, value) { 255 | if (typeof value.settings.error === "function") { 256 | value.settings.error.call(value.selection, error); 257 | } 258 | }); 259 | 260 | if (typeof settings.complete === "function") { 261 | settings.complete.call(selection); 262 | } 263 | 264 | while (queue[template] && (value = queue[template].shift())) { 265 | if (typeof value.settings.complete === "function") { 266 | value.settings.complete.call(value.selection); 267 | } 268 | } 269 | 270 | if (typeof queue[template] !== 'undefined' && queue[template].length > 0) { 271 | queue[template] = []; 272 | } 273 | } 274 | 275 | function handleTemplateLoadingSuccess($templateContainer, template, selection, data, settings) { 276 | var value; 277 | 278 | templates[template] = $templateContainer.clone(); 279 | prepareTemplate.call(selection, $templateContainer, data, settings); 280 | 281 | if (typeof settings.success === "function") { 282 | settings.success.call(selection); 283 | } 284 | 285 | while (queue[template] && (value = queue[template].shift())) { 286 | prepareTemplate.call(value.selection, templates[template].clone(), value.data, value.settings); 287 | if (typeof value.settings.success === "function") { 288 | value.settings.success.call(value.selection); 289 | } 290 | } 291 | } 292 | 293 | function bindData(template, data, settings) { 294 | data = data || {}; 295 | 296 | processElements("data-content", template, data, settings, function ($elem, value) { 297 | $elem.html(applyFormatters($elem, value, "content", settings)); 298 | }); 299 | 300 | processElements("data-content-append", template, data, settings, function ($elem, value) { 301 | $elem.append(applyFormatters($elem, value, "content", settings)); 302 | }); 303 | 304 | processElements("data-content-prepend", template, data, settings, function ($elem, value) { 305 | $elem.prepend(applyFormatters($elem, value, "content", settings)); 306 | }); 307 | 308 | processElements("data-content-text", template, data, settings, function ($elem, value) { 309 | $elem.text(applyFormatters($elem, value, "content", settings)); 310 | }); 311 | 312 | processElements("data-innerHTML", template, data, settings, function ($elem, value) { 313 | 
$elem.html(applyFormatters($elem, value, "content", settings));
314 |     });
315 | 
316 |     processElements("data-src", template, data, settings, function ($elem, value) {
317 |         $elem.attr("src", applyFormatters($elem, value, "src", settings));
318 |     }, function ($elem) {
319 |         $elem.remove();
320 |     });
321 | 
322 |     processElements("data-href", template, data, settings, function ($elem, value) {
323 |         $elem.attr("href", applyFormatters($elem, value, "href", settings));
324 |     }, function ($elem) {
325 |         $elem.remove();
326 |     });
327 | 
328 |     processElements("data-alt", template, data, settings, function ($elem, value) {
329 |         $elem.attr("alt", applyFormatters($elem, value, "alt", settings));
330 |     });
331 | 
332 |     processElements("data-id", template, data, settings, function ($elem, value) {
333 |         $elem.attr("id", applyFormatters($elem, value, "id", settings));
334 |     });
335 | 
336 |     processElements("data-value", template, data, settings, function ($elem, value) {
337 |         $elem.attr("value", applyFormatters($elem, value, "value", settings));
338 |     });
339 | 
340 |     processElements("data-class", template, data, settings, function ($elem, value) {
341 |         $elem.addClass(applyFormatters($elem, value, "class", settings));
342 |     });
343 | 
344 |     processElements("data-link", template, data, settings, function ($elem, value) {
345 |         var $linkElem = $("<a></a>");
346 |         $linkElem.attr("href", applyFormatters($elem, value, "link", settings));
347 |         $linkElem.html($elem.html());
348 |         $elem.html($linkElem);
349 |     });
350 | 
351 |     processElements("data-link-wrap", template, data, settings, function ($elem, value) {
352 |         var $linkElem = $("<a></a>");
353 |         $linkElem.attr("href", applyFormatters($elem, value, "link-wrap", settings));
354 |         $elem.wrap($linkElem);
355 |     });
356 | 
357 |     processElements("data-options", template, data, settings, function ($elem, value) {
358 |         $(value).each(function () {
359 |             var $option = $("