├── assets
│   └── poster.jpg
├── requirement.txt
├── .gitignore
├── main.py
├── stop_docker.sh
├── utils
│   ├── math.py
│   ├── rp_array.py
│   ├── history_construct.py
│   ├── zfilter.py
│   ├── timer.py
│   ├── torch_utils.py
│   └── visualize_repre.py
├── models
│   ├── mlp_base.py
│   ├── transition.py
│   ├── value.py
│   ├── rnn_base.py
│   └── policy.py
├── parameter
│   ├── private_config.py
│   └── Parameter.py
├── LICENSE
├── README.md
├── generate_tmuxp.py
├── algorithms
│   ├── contrastive.py
│   └── RMDM.py
├── log_util
│   ├── logger_base.py
│   └── logger.py
├── envs
│   ├── grid_world.py
│   ├── grid_world_general.py
│   └── nonstationary_env.py
└── agent
    └── Agent.py

/assets/poster.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/ESCP/HEAD/assets/poster.jpg
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | torch
2 | mujoco_py
3 | tensorboard
4 | matplotlib
5 | numpy
6 | ray
7 | gym
8 | scikit-learn
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | .DS_Store
3 | __pycache__
4 | log
5 | run_all.yaml
6 | data
7 | figs
8 | log_file
9 | baselines.tar.gz
10 | 
11 | 
12 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import ray
2 | from algorithms.sac import SAC
3 | 
4 | 
5 | if __name__ == '__main__':
6 |     ray.init(log_to_driver=True)
7 |     sac = SAC()
8 |     sac.logger.log(sac.logger.parameter)
9 |     sac.run()
10 | 
--------------------------------------------------------------------------------
/stop_docker.sh:
--------------------------------------------------------------------------------
1 | #!
/bin/bash 2 | docker container stop $(docker container ps -aq --filter "ancestor=sanluosizhou/selfdl:latest") 3 | docker container stop $(docker container ps -aq --filter "ancestor=sanluosizhou/selfdl:ml") 4 | docker container prune -f -------------------------------------------------------------------------------- /utils/math.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | def normal_entropy(std): 6 | var = std.pow(2) 7 | entropy = 0.5 + 0.5 * torch.log(2 * var * math.pi) 8 | return entropy.sum(-1, keepdim=True) 9 | 10 | 11 | def normal_log_density(x, mean, log_std, std): 12 | var = std.pow(2) 13 | log_density = -(x - mean).pow(2) / (2 * var) - 0.5 * math.log(2 * math.pi) - log_std 14 | return log_density.sum(-1, keepdim=True) 15 | -------------------------------------------------------------------------------- /models/mlp_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os, sys 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | from models.rnn_base import RNNBase 5 | 6 | class MLPBase(RNNBase): 7 | def __init__(self, input_size, output_size, hidden_size_list, activation): 8 | super().__init__(input_size, output_size, hidden_size_list, activation, ['fc'] * len(activation)) 9 | 10 | def meta_forward(self, x, h=None, require_full_hidden=False): 11 | return super(MLPBase, self).meta_forward(x, [], False) 12 | 13 | if __name__ == '__main__': 14 | hidden_layers = [256, 128, 64] 15 | hidden_activates = ['leaky_relu'] * len(hidden_layers) 16 | hidden_activates.append('tanh') 17 | nn = MLPBase(64, 4, hidden_layers, hidden_activates) 18 | for _ in range(5): 19 | x = torch.randn((3, 64)) 20 | y, _ = nn.meta_forward(x) 21 | print(y) 22 | -------------------------------------------------------------------------------- /parameter/private_config.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import os 3 | 4 | 5 | def get_base_path(): 6 | return osp.dirname(osp.dirname(osp.abspath(__file__))) 7 | 8 | 9 | def system(cmd, print_func=None): 10 | if print_func is None: 11 | print(cmd) 12 | else: 13 | print_func(cmd) 14 | os.system(cmd) 15 | 16 | EXPERIMENT_TARGET = "RELEASE" 17 | MAIN_MACHINE_IP = "114.212.22.189" 18 | SKIP_MAX_LEN_DONE = True 19 | FC_MODE = False 20 | ENV_DEFAULT_CHANGE = 3.0 21 | USE_TQDM = False 22 | NON_STATIONARY_PERIOD = 100 23 | NON_STATIONARY_INTERVAL = 10 24 | SHORT_NAME_SUFFIX = 'N' 25 | 26 | 27 | def get_global_configs(things): 28 | res = dict() 29 | for k, v in things: 30 | if not k.startswith('__') and not hasattr(v, '__call__') and 'module' not in str(type(v)): 31 | res[k] = v 32 | return res 33 | 34 | def global_configs(things=[*locals().items()]): 35 | return get_global_configs(things) 36 | 37 | # ALL_CONFIGS = get_global_configs([*locals().items()]) 38 | 39 | -------------------------------------------------------------------------------- /utils/rp_array.py: -------------------------------------------------------------------------------- 1 | from utils.replay_memory import MemoryArray 2 | 3 | 4 | class RPArray: 5 | def __init__(self, env_num, rnn_slice_length=32, max_trajectory_num=1000, max_traj_step=1050, fix_length=0): 6 | self.env_num = env_num 7 | max_trajectory_num_per_env = max_trajectory_num // env_num 8 | self.max_trajectory_num_per_env = max_trajectory_num_per_env 9 | self.replay_buffer_array = [] 10 | for i in 
range(self.env_num): 11 | self.replay_buffer_array.append(MemoryArray(rnn_slice_length, max_trajectory_num_per_env, max_traj_step, fix_length)) 12 | 13 | @property 14 | def size(self): 15 | sizes = [item.size for item in self.replay_buffer_array] 16 | return sum(sizes) 17 | 18 | def sample_transitions(self, batch_size=None): 19 | pass 20 | 21 | def sample_fix_length_sub_trajs(self, batch_size, fix_length): 22 | pass 23 | 24 | def sample_trajs(self, batch_size, max_sample_size=None): 25 | pass 26 | 27 | def mem_push(self, mem): 28 | pass -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Fan-Ming Luo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /models/transition.py: -------------------------------------------------------------------------------- 1 | from models.rnn_base import RNNBase 2 | from models.value import Value 3 | import torch 4 | 5 | 6 | class Trainsition(Value): 7 | def __init__(self, obs_dim, act_dim, up_hidden_size, up_activations, up_layer_type, 8 | ep_hidden_size, ep_activation, ep_layer_type, ep_dim, use_gt_env_feature, 9 | logger=None, freeze_ep=False, enhance_ep=False, stop_pg_for_ep=False): 10 | super(Value, self).__init__() 11 | self.obs_dim = obs_dim 12 | self.act_dim = act_dim 13 | self.use_gt_env_feature = use_gt_env_feature 14 | # aux dim: we add ep to every layer inputs. 
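        # ("ep" is the environment-probe network: it encodes the (obs, last_action) history
        #  into an ep_dim-dimensional embedding; "up" is the transition head that predicts the
        #  next-state delta from (obs, ep, action). When enhance_ep is set, aux_dim = ep_dim,
        #  so RNNBase re-concatenates the ep vector to the input of every layer via aux_state,
        #  not only to the first layer.)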
15 | aux_dim = ep_dim if enhance_ep else 0 16 | self.up = RNNBase(obs_dim + ep_dim + act_dim, obs_dim, up_hidden_size, up_activations, up_layer_type, logger, aux_dim) 17 | self.ep = RNNBase(obs_dim + act_dim, ep_dim, ep_hidden_size, ep_activation, ep_layer_type, logger) 18 | self.ep_rnn_count = self.ep.rnn_num 19 | self.up_rnn_count = self.up.rnn_num 20 | # ep first, up second 21 | self.module_list = torch.nn.ModuleList(self.up.total_module_list + self.ep.total_module_list) 22 | self.min_log_std = -7.0 23 | self.max_log_std = 2.0 24 | self.sample_hidden_state = None 25 | self.ep_tensor = None 26 | self.freeze_ep = freeze_ep 27 | self.enhance_ep = enhance_ep 28 | self.stop_pg_for_ep = stop_pg_for_ep 29 | 30 | 31 | def forward(self, x, lst_a, a, h, ep_out=None): 32 | next_x_delta, h_out = self.meta_forward(x, lst_a, a, h, ep_out=ep_out) 33 | next_x = next_x_delta + x 34 | return next_x, h_out -------------------------------------------------------------------------------- /utils/history_construct.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import deque 3 | import copy 4 | 5 | class HistoryConstructor: 6 | """ 7 | state0, action0, state1, action1, ..., state_{max_len}, action_{max_len}, state_{current} 8 | """ 9 | def __init__(self, max_len, state_dim, action_dim, need_lst_action=False): 10 | self.max_len = max_len 11 | self.buffer = deque(maxlen=max_len * 2 + 1) 12 | self.state_dim = state_dim 13 | self.action_dim = action_dim 14 | self.lst_action = np.zeros((self.action_dim,)) 15 | self.need_lst_action = need_lst_action 16 | self.reset() 17 | 18 | def __call__(self, current_state): 19 | if self.max_len == 0: 20 | if self.need_lst_action: 21 | obs_dim = len(np.shape(current_state)) 22 | if obs_dim == 1: 23 | return np.hstack((self.lst_action.reshape((-1)), current_state)) 24 | else: 25 | return np.hstack((self.lst_action.reshape((1, -1)), current_state)) 26 | else: 27 | return current_state 28 | 29 | self.buffer.append(np.squeeze(current_state)) 30 | return np.hstack(self.buffer) 31 | 32 | def reset(self): 33 | self.lst_action = np.zeros((self.action_dim, )) 34 | for i in range(self.max_len): 35 | self.buffer.append(np.zeros((self.state_dim,))) 36 | self.buffer.append(np.zeros((self.action_dim,))) 37 | 38 | def update_action(self, action): 39 | self.lst_action = copy.deepcopy(action) 40 | if self.max_len == 0: 41 | return 42 | self.buffer.append(np.squeeze(action)) 43 | 44 | 45 | if __name__ == '__main__': 46 | import gym 47 | env = gym.make('Hopper-v2') 48 | constructor = HistoryConstructor(4, state_dim=env.observation_space.shape[0], action_dim=env.action_space.shape[0]) 49 | state = constructor(env.reset()) 50 | for _ in range(100): 51 | np.set_printoptions(suppress=True, threshold=int(1e5), linewidth=150, precision=2) 52 | print('unified state: ', state) 53 | action = env.action_space.sample() 54 | next_state_tmp, reward, done, _ = env.step(action) 55 | print('state: ', next_state_tmp, ', action: ', action) 56 | constructor.update_action(action) 57 | next_state = constructor(next_state_tmp) 58 | state = next_state 59 | -------------------------------------------------------------------------------- /utils/zfilter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # from https://github.com/joschu/modular_rl 4 | # http://www.johndcook.com/blog/standard_deviation/ 5 | import torch 6 | 7 | class RunningStat(object): 8 | def __init__(self, 
shape): 9 | self._n = 0 10 | self._M = np.zeros(shape) 11 | self._S = np.zeros(shape) 12 | 13 | def push(self, x): 14 | x = np.asarray(x) 15 | assert x.shape == self._M.shape 16 | self._n += 1 17 | if self._n == 1: 18 | self._M[...] = x 19 | else: 20 | oldM = self._M.copy() 21 | self._M[...] = oldM + (x - oldM) / self._n 22 | self._S[...] = self._S + (x - oldM) * (x - self._M) 23 | 24 | @property 25 | def n(self): 26 | return self._n 27 | 28 | @property 29 | def mean(self): 30 | return self._M 31 | 32 | @property 33 | def var(self): 34 | return self._S / (self._n - 1) if self._n > 1 else np.square(self._M) 35 | 36 | @property 37 | def std(self): 38 | return np.sqrt(self.var) 39 | 40 | @property 41 | def shape(self): 42 | return self._M.shape 43 | 44 | 45 | class ZFilter: 46 | """ 47 | y = (x-mean)/std 48 | using running estimates of mean,std 49 | """ 50 | 51 | def __init__(self, shape, demean=True, destd=True, clip=10.0): 52 | self.demean = demean 53 | self.destd = destd 54 | self.clip = clip 55 | 56 | self.rs = RunningStat(shape) 57 | self.fix = False 58 | 59 | def __call__(self, x, update=True): 60 | # return x 61 | if update and not self.fix: 62 | self.rs.push(x) 63 | if self.demean: 64 | x = x - self.rs.mean 65 | if self.destd: 66 | x = x / (self.rs.std + 1e-8) 67 | if self.clip: 68 | x = np.clip(x, -self.clip, self.clip) 69 | return x 70 | 71 | def transform_tensor(self, x, mean, std): 72 | if self.demean: 73 | x = x - mean 74 | if self.destd: 75 | x = x / (std + 1e-8) 76 | if self.clip: 77 | x = torch.clamp(x, -self.clip, self.clip) 78 | # x = np.clip(x, -self.clip, self.clip) 79 | return x 80 | 81 | if __name__ == '__main__': 82 | import copy 83 | mean_filter = ZFilter((10, )) 84 | mean_filter.rs._M = 10 85 | mean_filter.rs._S = 20 86 | c = copy.deepcopy(mean_filter) 87 | mean_filter.rs._M = 1 88 | 89 | 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ESCP 2 | Code for [Adapt to Environment Sudden Changes by Learning a Context Sensitive Policy](https://www.aaai.org/AAAI22Papers/AAAI-6573.LuoF.pdf). 3 | ![image](assets/poster.jpg) 4 | ## Installation 5 | ### Install with pip 6 | Install the required python packages in `requirement.txt` by 7 | ```bash 8 | pip install -r ./requirement.txt 9 | ``` 10 | Note: You can follow the instructions at [here](https://github.com/openai/mujoco-py) to properly install `mujoco-py`. 11 | 12 | ### Use a docker image 13 | we have built a docker image, with which we ran all the experiments in the paper. The docker image can be pulled from [DockerHub](https://hub.docker.com/repository/docker/sanluosizhou/selfdl). 14 | ```bash 15 | docker pull sanluosizhou/selfdl:ml 16 | ``` 17 | ## Run 18 | 19 | You can conduct the experiment in `HalfCheetah-v2` with the following command. 
20 | ```bash
21 | python main.py --env_name HalfCheetah-v2 --rnn_fix_length 16 --seed 5 --task_num 40 --max_iter_num 2000 --varying_params dof_damping_1_dim --test_task_num 40 --ep_dim 2 --name_suffix RMDM --rbf_radius 3000 --use_rmdm --stop_pg_for_ep --bottle_neck
22 | ```
23 | 
24 | We also provide the command for running it in Docker:
25 | ```bash
26 | docker run --rm -it --shm-size 50gb --gpus all -v $PWD:/root/policy_adaptation sanluosizhou/selfdl:ml -c "cd /root/policy_adaptation && python main.py --env_name HalfCheetah-v2 --rnn_fix_length 16 --seed 5 --task_num 40 --max_iter_num 2000 --varying_params dof_damping_1_dim --test_task_num 40 --ep_dim 2 --name_suffix RMDM --rbf_radius 3000 --use_rmdm --stop_pg_for_ep --bottle_neck"
27 | ```
28 | 
29 | There are several key parameters:
30 | - `--env_name`: configures the environment you are going to conduct experiments on. The possible environments: `GridWorldPlat-v2, Hopper-v2, HalfCheetah-v2, Walker2d-v2, Ant-v2, Humanoid-v2`.
31 | - `--rnn_fix_length`: configures the memory length (H in the paper).
32 | - `--seed`: configures the random seed.
33 | - `--task_num`: configures how many environments are used for policy training (it should be set to `12` in `GridWorldPlat-v2`).
34 | - `--test_task_num`: configures how many environments are used for policy testing (it should be set to `12` in `GridWorldPlat-v2`).
35 | - `--varying_params`: configures which kinds of environment changes are applied; refer to the [code](envs/nonstationary_env.py) for all supported environment changes.
36 | 
37 | You can conduct the experiment in `HalfCheetah-v2` with both `gravity` and `dof_damping` changed using the following command.
38 | 
39 | ```bash
40 | python main.py --env_name HalfCheetah-v2 --rnn_fix_length 16 --seed 5 --task_num 40 --max_iter_num 2000 --varying_params dof_damping_1_dim gravity --test_task_num 40 --ep_dim 2 --name_suffix RMDM_more_change --kernel_type rbf --rbf_radius 80 --diversity_loss_weight 1.0 --use_rmdm --stop_pg_for_ep --bottle_neck
41 | ```
42 | 
43 | 
--------------------------------------------------------------------------------
/utils/timer.py:
--------------------------------------------------------------------------------
1 | import time
2 | import inspect
3 | import numpy as np
4 | 
5 | class Timer:
6 |     def __init__(self):
7 |         self.check_points = {}
8 |         self.points_time = {}
9 |         self.need_summary = {}
10 |         self.init_time = time.time()
11 | 
12 |     def reset(self):
13 |         self.check_points = {}
14 |         self.points_time = {}
15 |         self.need_summary = {}
16 | 
17 |     @staticmethod
18 |     def file_func_line(stack=1):
19 |         frame = inspect.stack()[stack][0]
20 |         info = inspect.getframeinfo(frame)
21 |         return info.filename, info.function, info.lineno
22 | 
23 |     @staticmethod
24 |     def line(stack=2, short=False):
25 |         file, func, lineo = Timer.file_func_line(stack)
26 |         if short:
27 |             return f"line_{lineo}_func_{func}"
28 |         return f"line: {lineo}, func: {func}, file: {file}"
29 | 
30 |     def register_point(self, tag=None, stack=3, short=True, need_summary=True, level=0):
31 |         if tag is None:
32 |             tag = self.line(stack, short)
33 |         if False and not tag.startswith('__'):
34 |             print(f'arrive {tag}, time: {time.time() - self.init_time}, level: {level}')
35 |         if level not in self.check_points:
36 |             self.check_points[level] = []
37 |             self.points_time[level] = []
38 |             self.need_summary[level] = set()
39 |         self.check_points[level].append(tag)
40 |         self.points_time[level].append(time.time())
41 |         if need_summary:
42 |             self.need_summary[level].add(tag)
43 | 
44 |     def register_end(self, stack=4, 
level=0): 45 | self.register_point('__timer_end_unique', stack, need_summary=False, level=level) 46 | 47 | def summary(self): 48 | if len(self.check_points) == 0: 49 | return dict() 50 | res = {} 51 | for level in self.check_points: 52 | self.register_point('__timer_finale_unique', level=level) 53 | res_tmp = {} 54 | for ind, item in enumerate(self.check_points[level][:-1]): 55 | time_now = self.points_time[level][ind] 56 | time_next = self.points_time[level][ind + 1] 57 | if item in res_tmp: 58 | res_tmp[item].append(time_next - time_now) 59 | else: 60 | res_tmp[item] = [time_next - time_now] 61 | for k, v in res_tmp.items(): 62 | if k in self.need_summary[level]: 63 | res['period_' + k] = np.mean(v) 64 | self.reset() 65 | return res 66 | 67 | 68 | def test_timer(): 69 | timer = Timer() 70 | for i in range(4): 71 | timer.register_point() 72 | time.sleep(1) 73 | for k, v in timer.summary().items(): 74 | print(f'{k}, {v}') 75 | 76 | if __name__ == '__main__': 77 | test_timer() -------------------------------------------------------------------------------- /generate_tmuxp.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from parameter.private_config import * 3 | 4 | config = {"session_name": "run-all-8359y", "windows": []} 5 | base_path = get_base_path() 6 | docker_path = "/root/policy_adaptation" 7 | path = docker_path 8 | tb_port = 6006 9 | docker_template = f'docker run --rm -it --shm-size 50gb --gpus all -v {base_path}:{docker_path} sanluosizhou/selfdl:ml -c ' 10 | docker_template = None 11 | if docker_template is None or len(docker_template) == 0: 12 | path = base_path 13 | template = 'export CUDA_VISIBLE_DEVICES={0} && cd {1} ' \ 14 | '&& python main.py --env_name {3} --rnn_fix_length {4}' \ 15 | ' --seed {5} --task_num {6} --max_iter_num {7} --varying_params {8} --test_task_num {9} --ep_dim {10}' \ 16 | ' --name_suffix {11} --rbf_radius {12} --kernel_type {13} --diversity_loss_weight {14} ' 17 | seeds = [8] 18 | GPUS = [0, 1] 19 | envs = ['HalfCheetah-v2', 'Hopper-v2', 'Walker2d-v2', 'Ant-v2'] 20 | count_it = 0 21 | algs = ['sac'] 22 | task_common_num = 40 23 | test_task_num = 40 24 | max_iter_num = 2000 25 | rnn_fix_length = 16 26 | ep_dim = 2 27 | sac_mini_batch_size = 256 28 | test_dof = False 29 | diversity_loss_weight = 0.004 30 | if test_dof: 31 | varying_params = [ 32 | # ' gravity ', 33 | ' gravity dof_damping_1_dim ', 34 | ] 35 | kernel_type = "rbf" 36 | name_suffix = 'RMDM_REPLAY_DOF' 37 | rbf_kernel_size = 80 38 | else: 39 | varying_params = [ 40 | ' gravity ', 41 | ] 42 | kernel_type = "rbf_element_wise" 43 | name_suffix = 'RMDM_REPLAY' 44 | rbf_kernel_size = 3000 45 | use_rmdm = True 46 | stop_pg_for_ep = True 47 | bottle_neck = True 48 | imitate_update_interval = 50 49 | for seed in seeds: 50 | for env in envs: 51 | panes_list = [] 52 | for varying_param in varying_params: 53 | for alg in algs: 54 | script_it = template.format(GPUS[count_it % len(GPUS)], path, alg, env, rnn_fix_length, 55 | seed, task_common_num, max_iter_num, varying_param, test_task_num, 56 | ep_dim, name_suffix, rbf_kernel_size, kernel_type, diversity_loss_weight) 57 | if use_rmdm: 58 | script_it += ' --use_rmdm ' 59 | if stop_pg_for_ep: 60 | script_it += ' --stop_pg_for_ep ' 61 | if bottle_neck: 62 | script_it += ' --bottle_neck ' 63 | if docker_template is not None and len(docker_template) > 0: 64 | script_it = docker_template + '"{}"'.format(script_it) 65 | 66 | print('{}-{}'.format(env, seed), ': ', script_it) 67 | 68 | 
panes_list.append(script_it) 69 | count_it = count_it + 1 70 | 71 | config["windows"].append({ 72 | "window_name": "{}-{}".format(env, seed), 73 | "panes": panes_list, 74 | "layout": "tiled" 75 | }) 76 | 77 | yaml.dump(config, open("run_all.yaml", "w"), default_flow_style=False) 78 | -------------------------------------------------------------------------------- /algorithms/contrastive.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | import time 4 | import numpy as np 5 | 6 | 7 | class ContrastiveLoss: 8 | def __init__(self, c_dim, max_env_len=40, ep=None): 9 | self.max_env_len = max_env_len 10 | self.query_ep = ep 11 | self.c_dim = c_dim 12 | self.W = torch.rand((c_dim, c_dim), requires_grad=True) 13 | self.w_optim = torch.optim.Adam([self.W], lr=1e-1) 14 | self.loss_func = torch.nn.CrossEntropyLoss() 15 | 16 | def to(self, *arg, **kwargs): 17 | self.query_ep.to(*arg, **kwargs) 18 | self.W = self.W.to(*arg, **kwargs) 19 | 20 | def get_query_tensor(self, state, last_action): 21 | hidden = self.query_ep.make_init_state(state.shape[0], device=state.device) 22 | ep, h, full_hidden = self.query_ep.meta_forward(torch.cat((state, last_action), dim=-1), 23 | hidden, require_full_hidden=True) 24 | return ep 25 | 26 | def get_loss_meta(self, y_key, y_query, need_w_grad=False): 27 | if need_w_grad: 28 | proj_k = self.W.matmul(y_key.t()) 29 | else: 30 | proj_k = self.W.detach().matmul(y_key.t()) 31 | # proj_k = 30 * torch.eye((self.W + self.W.t()).shape[0]).detach().matmul(y.t()) 32 | # proj_k = y.t() 33 | logits = y_query.matmul(proj_k) 34 | # print(logits.max(dim=1, keepdim=True).values) 35 | logits = logits - logits.max(dim=1, keepdim=True).values 36 | # logits = get_rbf_matrix(y, y, alpha=1) 37 | labels = torch.arange(logits.shape[0]).to(device=y_key.device) 38 | # print(logits) 39 | loss = self.loss_func(logits, labels) 40 | return loss 41 | 42 | def update_ep(self, origin_ep): 43 | self.query_ep.copy_weight_from(origin_ep, tau=0.99) 44 | 45 | def contrastive_loss(self, predicted_env_vector, predicted_env_vector_query, tasks): 46 | tasks = tasks[..., -1, 0] # torch.max(tasks[..., 0, 0], ) 47 | tasks_sorted, indices = torch.sort(tasks) 48 | tasks_sorted_np = tasks_sorted.detach().cpu().numpy().reshape((-1)) 49 | task_ind_map = {} 50 | tasks_sorted_np_idx = np.where(np.diff(tasks_sorted_np))[0] + 1 51 | last_ind = 0 52 | for i, item in enumerate(tasks_sorted_np_idx): 53 | task_ind_map[tasks_sorted_np[item - 1]] = [last_ind, item] 54 | last_ind = item 55 | if i == len(tasks_sorted_np_idx) - 1: 56 | task_ind_map[tasks_sorted_np[-1]] = [last_ind, len(tasks_sorted_np)] 57 | predicted_env_vector = predicted_env_vector[indices] 58 | predicted_env_vector_query = predicted_env_vector_query[indices] 59 | if 0 in task_ind_map: 60 | predicted_env_vector = predicted_env_vector[task_ind_map[0][1]:] 61 | predicted_env_vector_query = predicted_env_vector_query[task_ind_map[0][1]:] 62 | start_ind = task_ind_map[0][1] 63 | task_ind_map.pop(0) 64 | for k in task_ind_map: 65 | task_ind_map[k][0] -= start_ind 66 | task_ind_map[k][1] -= start_ind 67 | all_queries_ind = [] 68 | all_key_ind = [] 69 | all_tasks = sorted(list(task_ind_map.keys())) 70 | for ind, item in enumerate(all_tasks): 71 | all_queries_ind.append(task_ind_map[item][0]) 72 | all_key_ind.append(task_ind_map[item][1]-1) 73 | queries = predicted_env_vector_query[all_queries_ind, 0] 74 | key = predicted_env_vector[all_key_ind, 0] 75 | loss_w = self.get_loss_meta(y_key=key.detach(), 
y_query=queries.detach(), need_w_grad=True) 76 | self.w_optim.zero_grad() 77 | loss_w.backward() 78 | self.w_optim.step() 79 | return self.get_loss_meta(y_key=key, y_query=queries.detach(), need_w_grad=False) 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributions as pyd 3 | import numpy as np 4 | import math 5 | from torch.distributions import constraints 6 | import torch.nn.functional as F 7 | 8 | tensor = torch.tensor 9 | DoubleTensor = torch.DoubleTensor 10 | FloatTensor = torch.FloatTensor 11 | LongTensor = torch.LongTensor 12 | ByteTensor = torch.ByteTensor 13 | ones = torch.ones 14 | zeros = torch.zeros 15 | 16 | 17 | def to_device(device, *args): 18 | return [x.to(device) for x in args] 19 | 20 | 21 | def get_flat_params_from(model): 22 | params = [] 23 | for param in model.parameters(): 24 | params.append(param.view(-1)) 25 | 26 | flat_params = torch.cat(params) 27 | return flat_params 28 | 29 | 30 | def set_flat_params_to(model, flat_params): 31 | prev_ind = 0 32 | for param in model.parameters(): 33 | flat_size = int(np.prod(list(param.size()))) 34 | param.data.copy_( 35 | flat_params[prev_ind:prev_ind + flat_size].view(param.size())) 36 | prev_ind += flat_size 37 | 38 | 39 | def get_flat_grad_from(inputs, grad_grad=False): 40 | grads = [] 41 | for param in inputs: 42 | if grad_grad: 43 | grads.append(param.grad.grad.view(-1)) 44 | else: 45 | if param.grad is None: 46 | grads.append(zeros(param.view(-1).shape)) 47 | else: 48 | grads.append(param.grad.view(-1)) 49 | 50 | flat_grad = torch.cat(grads) 51 | return flat_grad 52 | 53 | 54 | def compute_flat_grad(output, inputs, filter_input_ids=set(), retain_graph=False, create_graph=False): 55 | if create_graph: 56 | retain_graph = True 57 | 58 | inputs = list(inputs) 59 | params = [] 60 | for i, param in enumerate(inputs): 61 | if i not in filter_input_ids: 62 | params.append(param) 63 | 64 | grads = torch.autograd.grad(output, params, retain_graph=retain_graph, create_graph=create_graph) 65 | 66 | j = 0 67 | out_grads = [] 68 | for i, param in enumerate(inputs): 69 | if i in filter_input_ids: 70 | out_grads.append(zeros(param.view(-1).shape, device=param.device, dtype=param.dtype)) 71 | else: 72 | out_grads.append(grads[j].view(-1)) 73 | j += 1 74 | grads = torch.cat(out_grads) 75 | 76 | for param in params: 77 | param.grad = None 78 | return grads 79 | 80 | class TanhTransform(pyd.transforms.Transform): 81 | domain = pyd.constraints.real 82 | codomain = pyd.constraints.interval(-1.0, 1.0) 83 | bijective = True 84 | sign = +1 85 | 86 | def __init__(self, cache_size=1): 87 | super().__init__(cache_size=cache_size) 88 | 89 | @staticmethod 90 | def atanh(x): 91 | return 0.5 * (x.log1p() - (-x).log1p()) 92 | 93 | def __eq__(self, other): 94 | return isinstance(other, TanhTransform) 95 | 96 | def _call(self, x): 97 | return x.tanh() 98 | 99 | def _inverse(self, y): 100 | # We do not clamp to the boundary here as it may degrade the performance of certain algorithms. 101 | # one should use `cache_size=1` instead 102 | return self.atanh(y) 103 | 104 | def log_abs_det_jacobian(self, x, y): 105 | # We use a formula that is more numerically stable, see details in the following link 106 | # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7 107 | return 2. * (math.log(2.) 
- x - F.softplus(-2. * x)) 108 | 109 | 110 | class SquashedNormal(pyd.transformed_distribution.TransformedDistribution): 111 | arg_constraints = {'loc': constraints.real, 'scale': constraints.positive} 112 | support = constraints.interval(-1., 1.) 113 | has_rsample = True 114 | 115 | def __init__(self, loc, scale): 116 | self.base_dist = pyd.Normal(loc, scale) 117 | super(SquashedNormal, self).__init__(self.base_dist, TanhTransform()) 118 | # super().__init__(self.base_dist, transforms) 119 | 120 | def expand(self, batch_shape, _instance=None): 121 | new = self._get_checked_instance(SquashedNormal, _instance) 122 | return super(SquashedNormal, self).expand(batch_shape, _instance=new) 123 | 124 | @property 125 | def loc(self): 126 | return self.base_dist.loc 127 | 128 | @property 129 | def scale(self): 130 | return self.base_dist.scale 131 | 132 | @property 133 | def mean(self): 134 | mu = self.loc 135 | return torch.tanh(mu) 136 | 137 | 138 | -------------------------------------------------------------------------------- /log_util/logger_base.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | import os.path as osp, time, atexit, os, copy 3 | from parameter.private_config import * 4 | color2num = dict( 5 | gray=30, 6 | red=31, 7 | green=32, 8 | yellow=33, 9 | blue=34, 10 | magenta=35, 11 | cyan=36, 12 | white=37, 13 | crimson=38 14 | ) 15 | 16 | def colorize(string, color, bold=False, highlight=False): 17 | """ 18 | Colorize a string. 19 | 20 | This function was originally written by John Schulman. 21 | """ 22 | if color is None: 23 | return string 24 | attr = [] 25 | num = color2num[color] 26 | if highlight: num += 10 27 | attr.append(str(num)) 28 | if bold: attr.append('1') 29 | 30 | return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string) 31 | 32 | 33 | class LoggerBase: 34 | def __init__(self, output_dir=None, output_fname='progress.txt', exp_name=None, log_file=None): 35 | self.output_dir = output_dir or "/tmp/experiments/%i" % int(time.time()) 36 | self.base_log_file = log_file 37 | self.tb = None 38 | self.output_file = open(osp.join(self.output_dir, output_fname), 'w') 39 | atexit.register(self.output_file.close) 40 | self.log("Logging data to %s" % self.output_file.name, color='green') 41 | self.first_row = True 42 | self.log_headers = [] 43 | self.log_current_row = {} 44 | self.log_last_row = None 45 | self.exp_name = exp_name 46 | self.tb_x_label = None 47 | self.tb_x = None 48 | self.step = 0 49 | 50 | def init_tb(self): 51 | if osp.exists(self.output_dir): 52 | self.log("Warning: Log dir %s already exists! Storing info there anyway." % self.output_dir, color='blue') 53 | cmd = f'rm -rf {osp.join(self.output_dir, "events.out.tfevents*")}' 54 | system(cmd, lambda x: self.log(x)) 55 | else: 56 | os.makedirs(self.output_dir) 57 | self.tb = SummaryWriter(self.output_dir) 58 | 59 | def set_tb_x_label(self, label): 60 | self.tb_x_label = label 61 | 62 | def register_keys(self, keys): 63 | """ 64 | this keys will may not be added to tabular data at the first step, thus, 65 | they should be manually logged at the first step. 
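        Once the first row has been dumped, log_tabular() asserts that every key already
        appears in log_headers, so keys that only show up in later iterations must be
        registered here up front.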
66 | :param keys: 67 | :return: 68 | """ 69 | for key in keys: 70 | self.log_tabular(key, 0, no_tb=True) 71 | 72 | def set_log_file(self, log_file): 73 | self.base_log_file = log_file 74 | 75 | def log(self, *args, color=None, bold=True): 76 | s = '' 77 | for item in args[:-1]: 78 | s += str(item) + ' ' 79 | s += str(args[-1]) 80 | print(colorize(s, color, bold=bold)) 81 | if self.base_log_file is not None: 82 | print(s, file=self.base_log_file) 83 | self.base_log_file.flush() 84 | 85 | def log_tabular(self, key, val, tb_prefix=None, no_tb=False): 86 | if self.tb_x_label is not None and key == self.tb_x_label: 87 | self.tb_x = val 88 | if self.first_row: 89 | self.log_headers.append(key) 90 | self.log_headers = sorted(self.log_headers) 91 | else: 92 | assert key in self.log_headers, "Trying to introduce a new key %s that you didn't include in the first iteration" % key 93 | assert key not in self.log_current_row, "You already set %s this iteration. Maybe you forgot to call dump_tabular()" % key 94 | self.log_current_row[key] = val 95 | if tb_prefix is None: 96 | tb_prefix = 'tb' 97 | if not no_tb: 98 | if self.tb_x_label is None: 99 | self.tb.add_scalar(f'{tb_prefix}/{key}', val, self.step) 100 | else: 101 | self.tb.add_scalar(f'{tb_prefix}/{key}', val, self.tb_x) 102 | 103 | 104 | def dump_tabular(self): 105 | if self.first_row: 106 | self.log_last_row = self.log_current_row 107 | for k, v in self.log_last_row.items(): 108 | if k not in self.log_current_row: 109 | self.log_current_row[k] = v 110 | vals = [] 111 | key_lens = [len(key) for key in self.log_headers] 112 | if len(key_lens) > 0: 113 | max_key_len = max(15, max(key_lens)) 114 | else: 115 | max_key_len = 15 116 | keystr = '%' + '%d' % max_key_len 117 | fmt = "| " + keystr + "s | %15s |" 118 | n_slashes = 22 + max_key_len 119 | 120 | head_indication = f" iter {self.step} " 121 | # print("-" * n_slashes) 122 | bar_num = n_slashes - len(head_indication) 123 | self.log("-" * (bar_num // 2) + head_indication + "-" * (bar_num // 2)) 124 | for key in self.log_headers: 125 | val = self.log_current_row.get(key, "") 126 | valstr = "%8.3g" % val if hasattr(val, "__float__") else val 127 | self.log(fmt % (key, valstr)) 128 | vals.append(val) 129 | self.log("-" * n_slashes+'\n') 130 | if self.output_file is not None: 131 | if self.first_row: 132 | self.output_file.write("\t".join(self.log_headers) + "\n") 133 | self.output_file.write("\t".join(map(str, vals)) + "\n") 134 | self.output_file.flush() 135 | self.log_last_row = copy.deepcopy(self.log_current_row) 136 | self.log_current_row.clear() 137 | self.first_row = False 138 | self.step += 1 139 | 140 | 141 | if __name__ == '__main__': 142 | logger = LoggerBase(output_dir=os.path.join(get_base_path(), 'log_file', 'log_file', 'log_file')) 143 | logger.log('123123', '2232', color="green") 144 | for i in range(5): 145 | logger.log_tabular('a', i) 146 | logger.dump_tabular() -------------------------------------------------------------------------------- /log_util/logger.py: -------------------------------------------------------------------------------- 1 | from parameter.Parameter import Parameter 2 | from log_util.logger_base import LoggerBase 3 | from parameter.private_config import * 4 | import os 5 | import numpy as np 6 | import time 7 | import copy 8 | 9 | 10 | class Logger(LoggerBase): 11 | def __init__(self, log_to_file=True, parameter=None, force_backup=False): 12 | if parameter: 13 | self.parameter = parameter 14 | else: 15 | self.parameter = Parameter() 16 | self.output_dir = 
os.path.join(get_base_path(), 'log_file', self.parameter.short_name) 17 | if log_to_file: 18 | if not os.path.exists(self.output_dir): 19 | os.makedirs(self.output_dir) 20 | if os.path.exists(os.path.join(self.output_dir, 'log.txt')): 21 | system(f'mv {os.path.join(self.output_dir, "log.txt")} {os.path.join(self.output_dir, "log_back.txt")}') 22 | self.log_file = open(os.path.join(self.output_dir, 'log.txt'), 'w') 23 | else: 24 | self.log_file = None 25 | # super(Logger, self).set_log_file(self.log_file) 26 | super(Logger, self).__init__(self.output_dir, log_file=self.log_file) 27 | # self.parameter.set_log_func(lambda x: self.log(x)) 28 | self.current_data = {} 29 | self.logged_data = set() 30 | 31 | self.model_output_dir = self.get_model_output_path(self.parameter) 32 | self.log(f"my output path is {self.output_dir}") 33 | 34 | self.parameter.set_config_path(self.output_dir) 35 | if not os.path.exists(self.output_dir): 36 | self.log(f'directory {self.output_dir} does not exist, create it...') 37 | else: 38 | self.log(f'directory {self.output_dir} exists, checking identity...') 39 | if (self.parameter.check_identity(need_decription=True) or self.parameter.differences is None) \ 40 | and (not force_backup): 41 | self.log(f'config is completely same, file will be overwrited anyway...') 42 | else: 43 | self.log(f'config is not same, file will backup first...') 44 | diffs = self.parameter.differences 45 | self.log(f'difference appears in {diffs}') 46 | backup_dir = os.path.join(get_base_path(), "log_file", f"backup_{self.parameter.exec_time}") 47 | if not os.path.exists(backup_dir): 48 | os.makedirs(backup_dir) 49 | system(f"cp -r {self.output_dir} {backup_dir}", lambda x: self.log(x)) 50 | self.parameter.save_config() 51 | self.init_tb() 52 | self.backup_code() 53 | self.tb_header_dict = {} 54 | # self.output_dir = os.path.join(get_base_path(), "log_file") 55 | 56 | @staticmethod 57 | def get_model_output_path(parameter): 58 | output_dir = os.path.join(get_base_path(), 'log_file', parameter.short_name) 59 | return os.path.join(output_dir, 'model') 60 | 61 | @staticmethod 62 | def get_replay_buffer_path(parameter): 63 | output_dir = os.path.join(get_base_path(), 'log_file', parameter.short_name) 64 | return os.path.join(output_dir, 'replay_buffer.pkl') 65 | 66 | def backup_code(self): 67 | base_path = get_base_path() 68 | things = [] 69 | for item in os.listdir(base_path): 70 | p = os.path.join(base_path, item) 71 | if not item.startswith('.') and not item.startswith('__') and not item == 'log_file' and not item == 'baselines': 72 | things.append(p) 73 | code_path = os.path.join(self.output_dir, 'codes') 74 | if not os.path.exists(code_path): 75 | os.makedirs(code_path) 76 | for item in things: 77 | system(f'cp -r {item} {code_path}', lambda x: self.log(f'backing up: {x}')) 78 | 79 | def log(self, *args, color=None, bold=True): 80 | super(Logger, self).log(*args, color=color, bold=bold) 81 | 82 | def log_dict(self, color=None, bold=False, **kwargs): 83 | for k, v in kwargs.items(): 84 | super(Logger, self).log('{}: {}'.format(k, v), color=color, bold=bold) 85 | 86 | def log_dict_single(self, data, color=None, bold=False): 87 | for k, v in data.items(): 88 | super(Logger, self).log('{}: {}'.format(k, v), color=color, bold=bold) 89 | 90 | def __call__(self, *args, **kwargs): 91 | self.log(*args, **kwargs) 92 | 93 | def save_config(self): 94 | self.parameter.save_config() 95 | 96 | def log_tabular(self, key, val=None, tb_prefix=None, with_min_and_max=False, average_only=False, no_tb=False): 97 | 
if val is not None: 98 | super(Logger, self).log_tabular(key, val, tb_prefix, no_tb=no_tb) 99 | else: 100 | if key in self.current_data: 101 | self.logged_data.add(key) 102 | super(Logger, self).log_tabular(key if average_only else "Average"+key, np.mean(self.current_data[key]), tb_prefix, no_tb=no_tb) 103 | if not average_only: 104 | super(Logger, self).log_tabular("Std" + key, 105 | np.std(self.current_data[key]), tb_prefix, no_tb=no_tb) 106 | if with_min_and_max: 107 | super(Logger, self).log_tabular("Min" + key, np.min(self.current_data[key]), tb_prefix, no_tb=no_tb) 108 | super(Logger, self).log_tabular('Max' + key, np.max(self.current_data[key]), tb_prefix, no_tb=no_tb) 109 | 110 | def add_tabular_data(self, tb_prefix=None, **kwargs): 111 | for k, v in kwargs.items(): 112 | if tb_prefix is not None and k not in self.tb_header_dict: 113 | self.tb_header_dict[k] = tb_prefix 114 | if k not in self.current_data: 115 | self.current_data[k] = [] 116 | if not isinstance(v, list): 117 | self.current_data[k].append(v) 118 | else: 119 | self.current_data[k] += v 120 | 121 | def update_tb_header_dict(self, tb_header_dict): 122 | self.tb_header_dict.update(tb_header_dict) 123 | 124 | def dump_tabular(self): 125 | for k in self.current_data: 126 | if k not in self.logged_data: 127 | if k in self.tb_header_dict: 128 | self.log_tabular(k, tb_prefix=self.tb_header_dict[k], average_only=True) 129 | else: 130 | self.log_tabular(k, average_only=True) 131 | self.logged_data.clear() 132 | self.current_data.clear() 133 | super(Logger, self).dump_tabular() 134 | 135 | 136 | if __name__ == '__main__': 137 | logger = Logger() 138 | logger.log(122, '22', color='red', bold=False) 139 | data = {'a': 10, 'b': 11, 'c': 13} 140 | for i in range(100): 141 | for _ in range(10): 142 | for k in data: 143 | data[k] += 1 144 | logger.add_tabular_data(**data) 145 | logger.log_tabular('a') 146 | logger.dump_tabular() 147 | time.sleep(1) 148 | 149 | 150 | 151 | 152 | 153 | -------------------------------------------------------------------------------- /envs/grid_world.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import gym 4 | import numpy as np 5 | 6 | 7 | class GridWorld(gym.Env): 8 | def __init__(self, env_flag=2, append_context=False, continuous_action=True): 9 | super(gym.Env).__init__() 10 | self.deterministic = True 11 | # A, B, C, s_0, D 12 | # ------------------------ 13 | # | A, B, C | None | 14 | # ------------------------ 15 | # | s_0 | D | 16 | # ------------------------ 17 | # 0 stay 18 | # 1 up 19 | # 2 right 20 | # 3 left 21 | # 4 down 22 | self.continuous_action = continuous_action 23 | if self.continuous_action: 24 | self.action_space = gym.spaces.Box(low=-1, high=1, shape=(1,)) 25 | else: 26 | self.action_space = gym.spaces.Discrete(5) 27 | 28 | self.observation_space = None 29 | self._grid_escape_time = 0 30 | self._grid_max_time = 1000 31 | self._current_position = 0 32 | self.env_flag = env_flag 33 | self.append_context = append_context 34 | self.middle_state = [2, 3, 4] 35 | assert self.env_flag in self.middle_state, '{} is accepted.'.format(self.middle_state) 36 | self._ind_to_name = { 37 | 0: 's0', 38 | 1: 'D', 39 | 2: 'A', 40 | 3: 'B', 41 | 4: 'C', 42 | 5: 'None' 43 | } 44 | self.reward_setting = { 45 | 0: 0, 46 | 1: 1, 47 | 2: 10, 48 | 3: -10, 49 | 4: 0, 50 | 5: 0 51 | } 52 | for k in self.reward_setting: 53 | self.reward_setting[k] *= 0.1 54 | 55 | self.state_space = len(self.reward_setting) 56 | self._raw_state_length = 
self.state_space 57 | if self.append_context: 58 | self.state_space += len(self.middle_state) 59 | self.diy_env = True 60 | self.observation_space = gym.spaces.Box(0, 1, (self.state_space, )) 61 | 62 | @property 63 | def middle_state_embedding(self): 64 | v = [0] * len(self.middle_state) 65 | v[self.env_flag - 2] = 1 66 | return v 67 | 68 | def make_one_hot(self, state): 69 | vec = [0] * self._raw_state_length 70 | vec[state] = 1 71 | return vec 72 | 73 | def get_next_position_toy(self, action): 74 | if self._current_position == 0: 75 | if action == 0: 76 | # to D 77 | next_state = 1 78 | else: 79 | # to unknown position 80 | next_state = self.env_flag 81 | # elif self._current_position == 1: 82 | # # keep at D 83 | # next_state = 1 84 | elif self._current_position in self.middle_state + [1]: 85 | # to s0 86 | next_state = 0 87 | else: 88 | raise NotImplementedError('current position exceeds range!!!') 89 | return next_state 90 | 91 | def get_next_position(self, action): 92 | # ------------------------ 93 | # | A, B, C | None | 94 | # ------------------------ 95 | # | s_0 | D | 96 | # ------------------------ 97 | # action: 0 stay 98 | # action: 1 up 99 | # action: 2 right 100 | # action: 3 left 101 | # action: 4 down 102 | # self._ind_to_name = { 103 | # 0: 's0', 104 | # 1: 'D', 105 | # 2: 'A', 106 | # 3: 'B', 107 | # 4: 'C', 108 | # 5: 'None' 109 | # } 110 | if not self.deterministic: 111 | if random.random() > 0.5: 112 | action = action 113 | else: 114 | action = random.randint(0, 4) 115 | if action == 0: 116 | if self._current_position in [2, 3, 4]: 117 | return self.env_flag 118 | return self._current_position 119 | left_up_map = { 120 | 4: 0, 121 | # 2: 5 122 | } 123 | action_transition_mapping = \ 124 | { 125 | 0: {1: self.env_flag, 2: 1}, 126 | 1: {1: 5, 3: 0}, 127 | 5: {3: self.env_flag, 4:1}, 128 | 2: left_up_map, 129 | 3: left_up_map, 130 | 4: left_up_map 131 | } 132 | action_to_state = action_transition_mapping[self._current_position] 133 | if action in action_to_state: 134 | return action_to_state[action] 135 | if self._current_position in [2, 3, 4]: 136 | return self.env_flag 137 | return self._current_position 138 | 139 | def step(self, action): 140 | self._grid_escape_time += 1 141 | info = {} 142 | if self.continuous_action: 143 | action_tmp = (action[0] + 1) / 2 144 | action_tmp = int(action_tmp * 5) 145 | if action_tmp >= 5: 146 | action_tmp = 4 147 | next_state = self.get_next_position(action_tmp) 148 | else: 149 | assert isinstance(action, int), 'action should be int type rather than {}'.format(type(action)) 150 | next_state = self.get_next_position(action) 151 | done = False # next_state == 1 152 | if self._grid_escape_time >= self._grid_max_time: 153 | done = True 154 | reward = self.reward_setting[next_state] 155 | info['current_position'] = self._ind_to_name[next_state] 156 | next_state_vector = self.make_one_hot(next_state) 157 | self._current_position = next_state 158 | if self.append_context: 159 | next_state_vector += self.middle_state_embedding 160 | return next_state_vector, reward, done, info 161 | 162 | def reset(self): 163 | self._grid_escape_time = 0 164 | self._current_position = 0 165 | state = self.make_one_hot(self._current_position) 166 | if self.append_context: 167 | state += self.middle_state_embedding 168 | return state 169 | 170 | def seed(self, seed=None): 171 | self.action_space.seed(seed) 172 | 173 | 174 | class RandomGridWorld(GridWorld): 175 | def __init__(self, append_context=False): 176 | self.possible_choice = [2, 3, 4] 177 | 
self.renv_flag = random.choice(self.possible_choice) 178 | self.fix_env = None 179 | super(RandomGridWorld, self).__init__(self.renv_flag, append_context) 180 | 181 | def reset(self): 182 | if self.fix_env is None: 183 | self.renv_flag = random.choice(self.possible_choice) 184 | self.env_flag = self.renv_flag 185 | else: 186 | self.renv_flag = self.env_flag = self.fix_env 187 | return super(RandomGridWorld, self).reset() 188 | 189 | def set_fix_env(self, fix_env): 190 | self.renv_flag = self.env_flag = self.fix_env = fix_env 191 | 192 | def set_task(self, task): 193 | self.set_fix_env(task) 194 | 195 | def sample_tasks(self, n_tasks): 196 | if n_tasks < len(self.possible_choice): 197 | tasks = [random.choice(self.possible_choice) for _ in range(n_tasks)] 198 | else: 199 | tasks = [] 200 | for i in range(n_tasks): 201 | tasks.append(self.possible_choice[i % len(self.possible_choice)]) 202 | 203 | return tasks 204 | 205 | @property 206 | def env_parameter_vector_(self): 207 | return np.array([(self.renv_flag - np.min(self.possible_choice)) / (np.max(self.possible_choice) 208 | - np.min(self.possible_choice))]) 209 | 210 | @property 211 | def env_parameter_length(self): 212 | return 1 213 | 214 | from gym.envs.registration import register 215 | 216 | register( 217 | id='GridWorldNS-v2', entry_point=RandomGridWorld 218 | ) 219 | 220 | if __name__ == '__main__': 221 | import gym 222 | env = gym.make('GridWorldNS-v2') 223 | print('observation space: ', env.observation_space) 224 | print('action space: ', env.action_space) 225 | print(hasattr(env, 'rmdm_env_flag')) -------------------------------------------------------------------------------- /models/value.py: -------------------------------------------------------------------------------- 1 | from models.rnn_base import RNNBase 2 | import torch 3 | from torch.distributions import Normal 4 | import numpy as np 5 | import os 6 | 7 | 8 | class Value(torch.nn.Module): 9 | def __init__(self, obs_dim, act_dim, up_hidden_size, up_activations, up_layer_type, 10 | ep_hidden_size, ep_activation, ep_layer_type, ep_dim, use_gt_env_feature, 11 | logger=None, freeze_ep=False, enhance_ep=False, stop_pg_for_ep=False): 12 | super(Value, self).__init__() 13 | self.obs_dim = obs_dim 14 | self.act_dim = act_dim 15 | self.use_gt_env_feature = use_gt_env_feature 16 | # aux dim: we add ep to every layer inputs. 
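        # (freeze_ep detaches the probe output inside get_ep(), so no gradient ever reaches
        #  the probe through this head; stop_pg_for_ep detaches ep only inside meta_forward(),
        #  so the value loss does not back-propagate into the probe even though the probe
        #  itself remains trainable elsewhere.)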
17 | aux_dim = ep_dim if enhance_ep else 0 18 | self.up = RNNBase(obs_dim + ep_dim + act_dim, 1, up_hidden_size, up_activations, up_layer_type, logger, aux_dim) 19 | self.ep = RNNBase(obs_dim + act_dim, ep_dim, ep_hidden_size, ep_activation, ep_layer_type, logger) 20 | self.ep_rnn_count = self.ep.rnn_num 21 | self.up_rnn_count = self.up.rnn_num 22 | # ep first, up second 23 | self.module_list = torch.nn.ModuleList(self.up.total_module_list + self.ep.total_module_list) 24 | self.min_log_std = -7.0 25 | self.max_log_std = 2.0 26 | self.sample_hidden_state = None 27 | self.ep_tensor = None 28 | self.freeze_ep = freeze_ep 29 | self.enhance_ep = enhance_ep 30 | self.stop_pg_for_ep = stop_pg_for_ep 31 | 32 | def get_ep(self, x, h, require_full_output=False): 33 | if require_full_output: 34 | ep, h, full_hidden = self.ep.meta_forward(x, h, require_full_output) 35 | if self.freeze_ep: 36 | ep = ep.detach() 37 | self.ep_tensor = ep 38 | return ep, h, full_hidden 39 | ep, h = self.ep.meta_forward(x, h) 40 | if self.freeze_ep: 41 | ep = ep.detach() 42 | self.ep_tensor = ep 43 | return ep, h 44 | 45 | def ep_h(self, h): 46 | return h[:self.ep_rnn_count] 47 | 48 | def up_h(self, h): 49 | return h[self.ep_rnn_count:] 50 | 51 | def make_init_state(self, batch_size, device): 52 | ep_h = self.ep.make_init_state(batch_size, device) 53 | up_h = self.up.make_init_state(batch_size, device) 54 | h = ep_h + up_h 55 | return h 56 | 57 | def meta_forward(self, x, lst_a, a, h, require_full_output=False, ep_out=None): 58 | ep_h = h[:self.ep_rnn_count] 59 | up_h = h[self.ep_rnn_count:] 60 | if not require_full_output: 61 | if self.use_gt_env_feature: 62 | up, up_h_out = self.up.meta_forward(torch.cat((x, a), -1), up_h) 63 | ep_h_out = [] 64 | else: 65 | if ep_out is None: 66 | ep, ep_h_out = self.get_ep(torch.cat((x, lst_a), -1), ep_h) 67 | else: 68 | ep = ep_out 69 | ep_h_out = [] 70 | if self.stop_pg_for_ep: 71 | ep = ep.detach() 72 | aux_input = ep if self.enhance_ep else None 73 | up, up_h_out = self.up.meta_forward(torch.cat((x, ep, a), -1), up_h, aux_state=aux_input) 74 | h_out = ep_h_out + up_h_out 75 | return up, h_out 76 | else: 77 | if self.use_gt_env_feature: 78 | up, up_h_out, up_full_hidden = self.up.meta_forward(torch.cat((x, a), -1), up_h, 79 | require_full_output) 80 | ep_h_out, ep_full_hidden = [], [] 81 | else: 82 | if ep_out is None: 83 | ep, ep_h_out, ep_full_hidden = self.get_ep(torch.cat((x, lst_a), -1), ep_h, require_full_output) 84 | else: 85 | ep, ep_h_out = ep_out, [] 86 | if self.stop_pg_for_ep: 87 | ep = ep.detach() 88 | aux_input = ep if self.enhance_ep else None 89 | up, up_h_out, up_full_hidden = self.up.meta_forward(torch.cat((x, ep, a), -1), up_h, 90 | require_full_output, aux_state=aux_input) 91 | h_out = ep_h_out + up_h_out 92 | full_hidden = ep_full_hidden + up_full_hidden 93 | return up, h_out, full_hidden 94 | 95 | def forward(self, x, lst_a, a, h, ep_out=None): 96 | value_out, h_out = self.meta_forward(x, lst_a, a, h, ep_out=ep_out) 97 | return value_out, h_out 98 | 99 | def save(self, path, index=0): 100 | self.up.save(os.path.join(path, f'value_universe_policy{index}.pt')) 101 | self.ep.save(os.path.join(path, f'value_environment_probe{index}.pt')) 102 | 103 | def load(self, path, index=0, **kwargs): 104 | self.up.load(os.path.join(path, f'value_universe_policy{index}.pt'), **kwargs) 105 | self.ep.load(os.path.join(path, f'value_environment_probe{index}.pt'), **kwargs) 106 | 107 | @staticmethod 108 | def make_config_from_param(parameter): 109 | return dict( 110 | 
up_hidden_size=parameter.value_hidden_size, 111 | up_activations=parameter.value_activations, 112 | up_layer_type=parameter.value_layer_type, 113 | ep_hidden_size=parameter.ep_hidden_size, 114 | ep_activation=parameter.ep_activations, 115 | ep_layer_type=parameter.ep_layer_type, 116 | ep_dim=parameter.ep_dim, 117 | use_gt_env_feature=parameter.use_true_parameter, 118 | enhance_ep=parameter.enhance_ep, 119 | stop_pg_for_ep=parameter.stop_pg_for_ep 120 | ) 121 | 122 | def copy_weight_from(self, src, tau): 123 | """ 124 | I am target net, tau ~~ 1 125 | if tau = 0, self <--- src_net 126 | if tau = 1, self <--- self 127 | """ 128 | self.up.copy_weight_from(src.up, tau) 129 | self.ep.copy_weight_from(src.ep, tau) 130 | 131 | def generate_hidden_state_with_slice(self, sliced_state: torch.Tensor, sliced_lst_action: torch.Tensor, sliced_action: torch.Tensor): 132 | """ 133 | :param sliced_state: 0-dim: mini-trajectory index, 1-dim: batch_size, 2-dim: time step, 3-dim: feature index 134 | :param sliced_lst_action: 135 | :param slice_num: 136 | :return: 137 | """ 138 | with torch.no_grad(): 139 | mini_traj_num = sliced_state.shape[0] 140 | batch_size = sliced_state.shape[1] 141 | device = sliced_state.device 142 | hidden_states = [] 143 | hidden_state_now = self.make_init_state(batch_size, device) 144 | for i in range(mini_traj_num): 145 | hidden_states.append(hidden_state_now) 146 | _, hidden_state_now = self.meta_forward(sliced_state[i], sliced_lst_action[i], 147 | sliced_action[i], hidden_state_now) 148 | return hidden_states 149 | 150 | 151 | def generate_hidden_state(self, state: torch.Tensor, lst_action: torch.Tensor, action: torch.Tensor, slice_num): 152 | """ 153 | :param sliced_state: 0-dim: mini-trajectory index, 1-dim: batch_size, 2-dim: time step, 3-dim: feature index 154 | :param sliced_lst_action: 155 | :param slice_num: 156 | :return: 157 | """ 158 | with torch.no_grad(): 159 | batch_size = state.shape[0] 160 | device = state.device 161 | hidden_states = [] 162 | hidden_state_now = self.make_init_state(batch_size, device) 163 | _, _, full_hidden = self.meta_forward(state, lst_action, action, hidden_state_now, require_full_output=True) 164 | for i in range(len(full_hidden)): 165 | full_hidden[i] = torch.cat((hidden_state_now[i].squeeze(0).unsqueeze(1), full_hidden[i]), dim=1) 166 | idx = [i * slice_num for i in range(state.shape[1] // slice_num)] 167 | hidden_states = [item[:, idx].transpose(0, 1) for item in full_hidden] 168 | hidden_states_res = [] 169 | for item in hidden_states: 170 | it_shape = item.shape 171 | hidden_states_res.append(item.reshape((1, it_shape[0] * it_shape[1], it_shape[2]))) 172 | return hidden_states_res -------------------------------------------------------------------------------- /utils/visualize_repre.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import matplotlib.cm as cm 5 | 6 | def merge_data(data, valid): 7 | res = [] 8 | for ind, item in enumerate(data): 9 | sp = item.shape 10 | # print(ind, valid[ind].shape, sp, item[0][valid[ind][0].squeeze()==1], item[0][valid[ind][0].squeeze(-1)==1]) 11 | if sp[1] == 1: 12 | res.append(item.squeeze(1).detach().cpu().numpy()) 13 | else: 14 | data_list = [] 15 | for i in range(sp[0]): 16 | if len(valid) > 0: 17 | valid_it = valid[ind][i] 18 | data_it = item[i][valid_it.squeeze(-1)==1] 19 | else: 20 | data_it = item[i] 21 | data_list.append(data_it) 22 | res.append(torch.cat(data_list, 
dim=0).detach().cpu().numpy()) 23 | # assert False 24 | return res 25 | 26 | def to_scalar(param_vector): 27 | if isinstance(param_vector, float) or isinstance(param_vector, int): 28 | return param_vector 29 | else: 30 | return param_vector[-1] 31 | pass 32 | 33 | 34 | def to_2dim_vector(param_vector): 35 | 36 | return [param_vector[0], param_vector[-1]] 37 | 38 | 39 | def visualize_repre(data, valid, output_file, real_param_dict=None, tasks=None): 40 | data = merge_data(data, valid) 41 | fig = plt.figure(0) 42 | plt.cla() 43 | cmap = plt.get_cmap('Spectral') 44 | min_ = 10000 45 | max_ = -10000 46 | if real_param_dict is not None: 47 | for k, v in real_param_dict.items(): 48 | v = to_scalar(v) 49 | if min_ > v: 50 | min_ = v 51 | if max_ < v: 52 | max_ = v 53 | norm = plt.Normalize(vmin=min_, vmax=max_) 54 | for ind, item in enumerate(data): 55 | x = item[:, 0] 56 | y = item[:, 1] 57 | if real_param_dict is None: 58 | plt.scatter(x, y) 59 | else: 60 | task_num = tasks[ind] 61 | plt.scatter(x, y, c=to_scalar(real_param_dict[task_num]) * np.ones_like(x), cmap=cmap, norm=norm) 62 | plt.colorbar() 63 | plt.savefig(output_file) 64 | plt.xlim(left=-1.1, right=1.1) 65 | plt.ylim(bottom=-1.1, top=1.1) 66 | ####### 67 | fig2 = plt.figure(1) 68 | ax = plt.gca() 69 | circle_list = [] 70 | for item in data: 71 | x = item[:, 0] 72 | y = item[:, 1] 73 | x_mean = np.mean(x) 74 | y_mean = np.mean(y) 75 | std = (np.std(x) + np.std(y)) / 2 76 | plt.scatter(x_mean, y_mean, marker='+', linewidths=2) 77 | circle_list.append(plt.Circle((x_mean, y_mean), std / 2, color='r', fill=False)) 78 | for circle in circle_list: 79 | ax.add_artist(circle) 80 | plt.xlim(left=-1.1, right=1.1) 81 | plt.ylim(bottom=-1.1, top=1.1) 82 | # print(f'saving fig to {output_file}') 83 | return fig, fig2 84 | 85 | def visualize_repre_real_param(data, valid, tasks, real_param_dict): 86 | data = merge_data(data, valid) 87 | cmap = plt.get_cmap('Spectral') 88 | min_ = 10000 89 | max_ = -10000 90 | min1_ = 10000 91 | max1_ = -10000 92 | min2_ = 10000 93 | max2_ = -10000 94 | for k, v in real_param_dict.items(): 95 | v1, v2 = to_2dim_vector(v) 96 | 97 | v = to_scalar(v) 98 | 99 | if min_ > v: 100 | min_ = v 101 | if max_ < v: 102 | max_ = v 103 | min1_ = min1_ if min1_ < v1 else v1 104 | max1_ = max1_ if max1_ > v1 else v1 105 | 106 | min2_ = min2_ if min2_ < v2 else v2 107 | max2_ = max2_ if max2_ > v2 else v2 108 | norm = plt.Normalize(vmin=min_, vmax=max_) 109 | norm1 = plt.Normalize(vmin=min1_, vmax=max1_) 110 | norm2 = plt.Normalize(vmin=min2_, vmax=max2_) 111 | 112 | fig2 = plt.figure(3) 113 | figsize = (3.2 * 2, 2.24 * 2 * 1.5) 114 | f, axarr = plt.subplots(2, 1, sharex=True, squeeze=False, figsize=figsize) 115 | 116 | pts = None 117 | means = [] 118 | colors = [] 119 | for i in range(2): 120 | ax = axarr[i][0] 121 | for ind, item in enumerate(data): 122 | task_num = tasks[ind] 123 | real_param = real_param_dict[task_num] 124 | vector_real_param = to_2dim_vector(real_param) 125 | v1, v2 = to_2dim_vector(real_param) 126 | vs = [v1, v2] 127 | v1_norm = norm1(v1) 128 | v2_norm = norm2(v2) 129 | norms = [norm1, norm2] 130 | normalized_color = cmap(v1_norm) 131 | lightness = (vector_real_param[0] + 1) / 2 132 | # light_color = [lightness * item for item in normalized_color] 133 | normalized_color = [normalized_color[0], 134 | normalized_color[1], 135 | normalized_color[2], 136 | v2_norm] 137 | light_color = normalized_color 138 | # light_color = [0, v2_norm, v1_norm] 139 | x = item[:, 0] 140 | y = item[:, 1] 141 | x_mean = np.mean(x) 142 
| y_mean = np.mean(y) 143 | means.append([x_mean, y_mean]) 144 | colors.append(light_color) 145 | ax.scatter([x_mean], [y_mean], marker='o', linewidths=3, c=[vs[i]], cmap=cmap, norm=norms[i]) 146 | ax.set_xlim(left=-1.1, right=1.1) 147 | ax.set_ylim(bottom=-1.1, top=1.1) 148 | # ax.scatter([x_mean], [y_mean], marker='.', linewidths=1, c=[v1], cmap=cmap, norm=norm1) 149 | # means = np.array(means) 150 | 151 | # plt.scatter(means[:, 0], means[:, 1], marker='+', linewidths=2, c=colors) 152 | mapple = cm.ScalarMappable(norm=norm1, cmap=cmap) 153 | mapple.set_array([]) 154 | plt.colorbar(mapple, ax=[axarr[0][0], axarr[1][0]]) 155 | # plt.xlim(left=-1.1, right=1.1) 156 | # plt.ylim(bottom=-1.1, top=1.1) 157 | # print(f'saving fig to {output_file}') 158 | return f 159 | 160 | 161 | def visualize_repre_real_param_legacy(data, valid, tasks, real_param_dict): 162 | data = merge_data(data, valid) 163 | 164 | fig2 = plt.figure(3) 165 | for ind, item in enumerate(data): 166 | task_num = tasks[ind] 167 | real_param = real_param_dict[task_num] 168 | x = item[:, 0] 169 | y = item[:, 1] 170 | x_mean = np.mean(x) 171 | y_mean = np.mean(y) 172 | plt.scatter(x_mean, y_mean, marker='+', linewidths=2) 173 | plt.text(x_mean, y_mean, '{:.2f}'.format(to_scalar(real_param))) 174 | plt.xlim(left=-1.1, right=1.1) 175 | plt.ylim(bottom=-1.1, top=1.1) 176 | # print(f'saving fig to {output_file}') 177 | return fig2 178 | 179 | def xy_filter(x, y, ratio): 180 | x_mean = np.mean(x) 181 | y_mean = np.mean(y) 182 | distance = np.sqrt(np.square(x - x_mean) + np.square(y - y_mean)) 183 | distance_sorted = np.sort(distance) 184 | distance_threshold = distance_sorted[int(distance_sorted.shape[0] * ratio)] 185 | x_res = x[distance < distance_threshold] 186 | y_res = y[distance < distance_threshold] 187 | return x_res, y_res 188 | 189 | def visualize_repre_filtered(data, valid, ratio=0.8): 190 | data = merge_data(data, valid) 191 | fig = plt.figure(7) 192 | plt.cla() 193 | for item in data: 194 | x = item[:, 0] 195 | y = item[:, 1] 196 | x, y = xy_filter(x, y, ratio) 197 | plt.scatter(x, y) 198 | plt.xlim(left=-1.1, right=1.1) 199 | plt.ylim(bottom=-1.1, top=1.1) 200 | 201 | fig2 = plt.figure(1) 202 | ax = plt.gca() 203 | circle_list = [] 204 | for item in data: 205 | x = item[:, 0] 206 | y = item[:, 1] 207 | x, y = xy_filter(x, y, ratio) 208 | x_mean = np.mean(x) 209 | y_mean = np.mean(y) 210 | std = (np.std(x) + np.std(y)) / 2 211 | plt.scatter(x_mean, y_mean, marker='+', linewidths=2) 212 | circle_list.append(plt.Circle((x_mean, y_mean), std / 2, color='r', fill=False)) 213 | for circle in circle_list: 214 | ax.add_artist(circle) 215 | plt.xlim(left=-1.1, right=1.1) 216 | plt.ylim(bottom=-1.1, top=1.1) 217 | # print(f'saving fig to {output_file}') 218 | return fig, fig2 219 | 220 | 221 | """ 222 | fig = plt.figure(0) 223 | ax = plt.gca() 224 | disk1 = plt.Circle((0, 0), 0.3, color='r', fill=False) 225 | disk2 = plt.Circle((0, 0.5), 0.3, color='r', fill=False) 226 | ax.set_xlim((-1, 1)) 227 | ax.set_ylim((-1, 1)) 228 | ax.add_artist(disk1) 229 | ax.add_artist(disk2) 230 | plt.show() 231 | """ -------------------------------------------------------------------------------- /models/rnn_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | import os 4 | import time 5 | 6 | class RNNBase(torch.nn.Module): 7 | def __init__(self, input_size, output_size, hidden_size_list, activation, layer_type, logger=None, aux_dim=0): 8 | super().__init__() 9 | assert 
len(activation) - 1 == len(hidden_size_list), "number of activation should be " \ 10 | "larger by 1 than size of hidden layers." 11 | assert len(activation) == len(layer_type), "number of layer type should equal to the activate" 12 | activation_dict = { 13 | 'tanh': torch.nn.Tanh, 14 | 'relu': torch.nn.ReLU, 15 | 'sigmoid': torch.nn.Sigmoid, 16 | 'leaky_relu': torch.nn.LeakyReLU, 17 | 'linear': None 18 | } 19 | layer_dict = { 20 | 'fc': torch.nn.Linear, 21 | # 'lstm': torch.nn.LSTM, # two output 22 | 'gru': torch.nn.GRU # one output 23 | } 24 | def decorate(module): 25 | return torch.jit.script(module) 26 | def fc_decorate(module): 27 | return module 28 | def rnn_decorate(module): 29 | return module 30 | lst_nh = input_size + aux_dim 31 | self.layer_list = [] 32 | self.layer_type = copy.deepcopy(layer_type) 33 | self.activation_list = [] 34 | self.rnn_hidden_state_input_size = [] 35 | self.rnn_layer_type = [] 36 | self.rnn_num = 0 37 | self.logger = logger 38 | for ind, item in enumerate(hidden_size_list): 39 | if self.layer_type[ind] == 'fc': 40 | self.layer_list.append(fc_decorate(layer_dict[self.layer_type[ind]](lst_nh, item))) 41 | else: 42 | self.rnn_num += 1 43 | self.layer_list.append(rnn_decorate(layer_dict[self.layer_type[ind]](lst_nh, item, batch_first=True))) 44 | self.rnn_hidden_state_input_size.append(item) 45 | self.rnn_layer_type.append(self.layer_type[ind]) 46 | if activation_dict[activation[ind]] is not None: 47 | self.activation_list.append(activation_dict[activation[ind]]()) 48 | else: 49 | self.activation_list.append(None) 50 | lst_nh = item + aux_dim 51 | if self.layer_type[-1] == 'fc': 52 | self.layer_list.append(fc_decorate(layer_dict[self.layer_type[-1]](lst_nh, output_size))) 53 | else: 54 | self.rnn_num += 1 55 | self.layer_list.append(rnn_decorate(layer_dict[self.layer_type[-1]](lst_nh, output_size, batch_first=True))) 56 | self.rnn_hidden_state_input_size.append(output_size) 57 | self.rnn_layer_type.append(self.layer_type[-1]) 58 | if activation_dict[activation[-1]] is not None: 59 | self.activation_list.append(activation_dict[activation[-1]]()) 60 | else: 61 | self.activation_list.append(None) 62 | # self.layer_list.append(torch.nn.Linear(lst_nh, output_size)) 63 | # self.activation_list.append(activation_dict[activation[-1]]()) 64 | self.total_module_list = self.layer_list + self.activation_list 65 | self._total_modules = torch.nn.ModuleList(self.total_module_list) 66 | self.input_size = input_size 67 | self.cumulative_forward_time = 0 68 | self.cumulative_meta_forward_time = 0 69 | assert len(self.layer_list) == len(self.activation_list), "number of layer should be equal to the number of activation" 70 | 71 | def make_init_state(self, batch_size, device=None): 72 | if device is None: 73 | device = torch.device("cpu") 74 | init_states = [] 75 | for ind, item in enumerate(self.rnn_hidden_state_input_size): 76 | if self.rnn_layer_type[ind] == 'lstm': 77 | init_states.append((torch.zeros((1, batch_size, item), device=device), 78 | torch.zeros((1, batch_size, item), device=device))) 79 | else: 80 | init_states.append(torch.zeros((1, batch_size, item), device=device)) 81 | return init_states 82 | 83 | def meta_forward(self, x, hidden_state=None, require_full_hidden=False, aux_state=None): 84 | _meta_start_time = time.time() 85 | assert x.shape[-1] == self.input_size, f"inputting size does not match!!!! 
input is {x.shape[-1]}, expected: {self.input_size}" 86 | if hidden_state is None: 87 | hidden_state = self.make_init_state(x.shape[0], x.device) 88 | assert len(hidden_state) == self.rnn_num, f"rnn num does not match, input is {len(hidden_state)}, expected: {self.rnn_num}" 89 | x_dim = len(x.shape) 90 | assert x_dim >= 2, f"dim of input is {x_dim}, which < 1" 91 | if x_dim == 2: 92 | x = torch.unsqueeze(x, 0) 93 | aux_dim = -1 94 | if aux_state is not None: 95 | aux_dim = len(aux_state.shape) 96 | if aux_dim == 2: 97 | aux_state = torch.unsqueeze(aux_state, 0) 98 | rnn_count = 0 99 | output_hidden_state = [] 100 | output_rnn = [] 101 | for ind, layer in enumerate(self.layer_list): 102 | if aux_dim > 0: 103 | x = torch.cat((x, aux_state), -1) 104 | activation = self.activation_list[ind] 105 | layer_type = self.layer_type[ind] 106 | if layer_type == 'gru': 107 | _start_time = time.time() 108 | x, h = layer(x, hidden_state[rnn_count]) 109 | _end_time = time.time() 110 | self.cumulative_forward_time += _end_time - _start_time 111 | rnn_count += 1 112 | output_hidden_state.append(h) 113 | if require_full_hidden: 114 | output_rnn.append(x) 115 | else: 116 | _start_time = time.time() 117 | x = layer(x) 118 | _end_time = time.time() 119 | self.cumulative_forward_time += _end_time - _start_time 120 | if activation is not None: 121 | _start_time = time.time() 122 | x = activation(x) 123 | _end_time = time.time() 124 | self.cumulative_forward_time += _end_time - _start_time 125 | if x_dim == 2: 126 | x = torch.squeeze(x, 0) 127 | self.cumulative_meta_forward_time += time.time() - _meta_start_time 128 | if require_full_hidden: 129 | return x, output_hidden_state, output_rnn 130 | return x, output_hidden_state 131 | 132 | def copy_weight_from(self, src_net, tau): 133 | """I am target net, tau ~~ 1 134 | if tau = 0, self <--- src_net 135 | if tau = 1, self <--- self 136 | """ 137 | with torch.no_grad(): 138 | if tau == 0.0: 139 | self.load_state_dict(src_net.state_dict()) 140 | return 141 | elif tau == 1.0: 142 | return 143 | for param, target_param in zip(src_net.parameters(True), self.parameters(True)): 144 | target_param.data.copy_(target_param.data * tau + (1-tau) * param.data) 145 | 146 | def info(self, info): 147 | if self.logger: 148 | self.logger.log(info) 149 | else: 150 | print(info) 151 | 152 | def save(self, path): 153 | if not os.path.exists(os.path.dirname(path)): 154 | os.makedirs(os.path.dirname(path)) 155 | self.info(f'saving model to {path}..') 156 | torch.save(self.state_dict(), path) 157 | 158 | def load(self, path, **kwargs): 159 | self.info(f'loading from {path}..') 160 | map_location = None 161 | if 'map_location' in kwargs: 162 | map_location = kwargs['map_location'] 163 | self.load_state_dict(torch.load(path, map_location=map_location)) 164 | 165 | @staticmethod 166 | def append_hidden_state(hidden_state, data): 167 | for i in range(len(hidden_state)): 168 | if hidden_state[i] is None: 169 | hidden_state[i] = data[i] 170 | elif isinstance(hidden_state[i], tuple): 171 | hidden_state[i] = (torch.cat((hidden_state[i][0], data[i][0]), 1), 172 | torch.cat((hidden_state[i][1], data[i][1]), 1)) 173 | else: 174 | hidden_state[i] = torch.cat((hidden_state[i], data[i]), 1) 175 | return hidden_state 176 | 177 | @staticmethod 178 | def pop_hidden_state(hidden_state): 179 | for i in range(len(hidden_state)): 180 | if hidden_state[i] is not None: 181 | if isinstance(hidden_state[i], tuple): 182 | if hidden_state[i][0].shape[1] == 1: 183 | hidden_state[i] = None 184 | else: 185 | 
hidden_state[i] = (hidden_state[i][0][:, 1:, :], 186 | hidden_state[i][1][:, 1:, :]) 187 | else: 188 | if hidden_state[i].shape[1] == 1: 189 | hidden_state[i] = None 190 | else: 191 | hidden_state[i] = hidden_state[i][:, 1:, :] 192 | return hidden_state 193 | 194 | @staticmethod 195 | def get_hidden_length(hidden_state): 196 | if len(hidden_state) == 0: 197 | length = 0 198 | elif hidden_state[0] is not None: 199 | if isinstance(hidden_state[0], tuple): 200 | length = hidden_state[0][0].shape[1] 201 | else: 202 | length = hidden_state[0].shape[1] 203 | else: 204 | length = 0 205 | return length 206 | 207 | if __name__ == '__main__': 208 | input_size = 32 209 | output_size = 4 210 | hidden_size = [64, 128, 32] 211 | activations = ["relu", "relu", "relu", "tanh"] 212 | layer_type = ["fc", "fc", "fc", "fc"] 213 | nn = RNNBase(input_size, output_size, hidden_size, activations, layer_type) 214 | init_state = nn.make_init_state(1, device=torch.device("cpu")) 215 | print(init_state, len(init_state)) 216 | hidden_state = init_state 217 | for item in hidden_state: 218 | if isinstance(item, tuple): 219 | for item1 in item: 220 | print(item1.shape) 221 | print('\n') 222 | else: 223 | print(item.shape) 224 | print('\n') 225 | for i in range(10): 226 | print(i) 227 | out, hidden_state = nn.meta_forward(torch.randn((1, 1, input_size)), hidden_state) 228 | 229 | 230 | -------------------------------------------------------------------------------- /envs/grid_world_general.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | 4 | import gym 5 | import numpy as np 6 | 7 | 8 | class GridWorldPlat(gym.Env): 9 | """ 10 | map: 11 | --------------------------------- 12 | | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 13 | --------------------------------- 14 | | 1 | 2 | 2 | 2 | 2 | 2 | 2 | 1 | 15 | --------------------------------- 16 | | 1 | 2 | 3 | 9 | 3 | 3 | 2 | 1 | 17 | --------------------------------- 18 | | 1 | 2 | 2 | 2 | 2 | 2 | 2 | 1 | 19 | --------------------------------- 20 | | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 21 | --------------------------------- 22 | action0: left/right moving size [-3, -2, -1, 0, 1, 2, 3] 23 | action1: up/down moving size [-3, -2, -1, 0, 1, 2, 3] 24 | parameter: moving offset (left/right, up/down) 25 | { 26 | 'x': random.randint(-3,3), 27 | 'y': random.randint(-3,3) 28 | } 29 | """ 30 | def __init__(self, env_flag=(0, 0), append_context=False, offset_size=2, width=11, height=11, moving_size=3): 31 | super(gym.Env).__init__() 32 | self.deterministic = True 33 | 34 | self.continuous_action = True 35 | self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,)) 36 | 37 | self.observation_space = None 38 | self._grid_escape_time = 0 39 | self._grid_max_time = 300 40 | self.env_flag = env_flag 41 | self.append_context = append_context 42 | 43 | self.state_space = 2 44 | self._raw_state_length = self.state_space 45 | if self.append_context: 46 | self.state_space += 2 47 | self.diy_env = True 48 | self.observation_space = gym.spaces.Box(-1, 1, (self.state_space, )) 49 | self.width = width 50 | self.height = height 51 | self.max_offset = offset_size 52 | self.max_moving_step_size = moving_size 53 | self._current_position = (self.width - 1, self.height - 1) 54 | self.rewards = [] 55 | self._init_reward() 56 | self.optimal_policy = None 57 | 58 | def _init_reward(self): 59 | self.rewards = [] 60 | center_x = (self.width - 1) / 2 61 | center_y = (self.height - 1) / 2 62 | max_reward = center_x + center_y 63 | self.rewards = 
np.zeros((self.height, self.width)) 64 | for x in range(self.width): 65 | for y in range(self.height): 66 | self.rewards[y, x] = max_reward - (np.abs(center_x - x) + np.abs(center_y - y)) 67 | # reward has been modified 68 | # self.rewards[int(center_y), int(center_x)] *= 10 69 | # self.rewards *= 0.1 70 | 71 | def _init_optimal_policy(self): 72 | self.optimal_policy = {} 73 | center_x = (self.width - 1) / 2 74 | center_y = (self.height - 1) / 2 75 | for i in range(self.width): 76 | for j in range(self.height): 77 | desired_action = [(center_x - i) - self.env_flag[0], (center_y - j) - self.env_flag[1]] 78 | if desired_action[0] > self.max_moving_step_size: 79 | desired_action[0] = self.max_moving_step_size 80 | elif desired_action[0] < -self.max_moving_step_size: 81 | desired_action[0] = -self.max_moving_step_size 82 | if desired_action[1] > self.max_moving_step_size: 83 | desired_action[1] = self.max_moving_step_size 84 | elif desired_action[1] < -self.max_moving_step_size: 85 | desired_action[1] = -self.max_moving_step_size 86 | self.optimal_policy[(i, j)] = (desired_action[0], desired_action[1]) 87 | 88 | @property 89 | def context(self): 90 | return [self.env_flag[0] / self.max_offset, self.env_flag[1] / self.max_offset] 91 | 92 | def embed_state(self, state): 93 | center_x = (self.width - 1) / 2 94 | center_y = (self.height - 1) / 2 95 | 96 | return [(self._current_position[0] - center_x) / center_x, (self._current_position[1] - center_y) / center_y] 97 | 98 | def get_next_position(self, action): 99 | x_action = int(action[0] * (self.max_moving_step_size + 1)) 100 | y_action = int(action[1] * (self.max_moving_step_size + 1)) 101 | if x_action > self.max_moving_step_size: 102 | x_action = self.max_moving_step_size 103 | if x_action < -self.max_moving_step_size: 104 | x_action = -self.max_moving_step_size 105 | if y_action > self.max_moving_step_size: 106 | y_action = self.max_moving_step_size 107 | if y_action < -self.max_moving_step_size: 108 | y_action = -self.max_moving_step_size 109 | x_action_origin, y_action_origin = x_action, y_action 110 | x_action += self.env_flag[0] 111 | y_action += self.env_flag[1] 112 | possible_x = self._current_position[0] + x_action 113 | possible_y = self._current_position[1] + y_action 114 | possible_x = np.clip(possible_x, 0, self.width - 1) 115 | possible_y = np.clip(possible_y, 0, self.height - 1) 116 | self._current_position = (int(possible_x), int(possible_y)) 117 | return self._current_position, (x_action_origin, y_action_origin) 118 | 119 | def step(self, action): 120 | self._grid_escape_time += 1 121 | done = False 122 | info = {} 123 | 124 | if self._grid_escape_time >= self._grid_max_time: 125 | done = True 126 | info['optimal_action'] = None if self.optimal_policy is None else \ 127 | self.optimal_policy[(self._current_position[0], self._current_position[1])] 128 | next_position, action_origin = self.get_next_position(action) 129 | reward = self.rewards[int(next_position[1]), int(next_position[0])] 130 | next_state_vector = self.embed_state(self._current_position) 131 | if self.append_context: 132 | next_state_vector += self.context 133 | info['next_optimal_action'] = None if self.optimal_policy is None else \ 134 | self.optimal_policy[(self._current_position[0], self._current_position[1])] 135 | if self.optimal_policy is not None: 136 | info['action_discrepancy'] = (action_origin[0] - info['optimal_action'][0], 137 | action_origin[1] - info['optimal_action'][1],) 138 | if (self._current_position[0], self._current_position[1]) == ( 139 | 
(self.width - 1) / 2, 140 | (self.height - 1) / 2 141 | ): 142 | info['keep_at_target'] = True 143 | else: 144 | info['keep_at_target'] = False 145 | return next_state_vector, reward, done, info 146 | 147 | def reset(self): 148 | self._grid_escape_time = 0 149 | self._current_position = (random.randint(0, self.width-1), random.randint(0, self.height-1)) 150 | state = self.embed_state(self._current_position) 151 | if self.append_context: 152 | state += self.context 153 | return state 154 | 155 | def seed(self, seed=None): 156 | self.action_space.seed(seed) 157 | 158 | def render(self, mode='human'): 159 | map_rows = [] 160 | nothing_label = 'o' 161 | have_thing_label = '*' 162 | for i in range(self.height): 163 | map_rows.append(nothing_label * self.width + '\n') 164 | map_rows[self._current_position[1]] = nothing_label*self._current_position[0] + have_thing_label\ 165 | + nothing_label*(self.width - self._current_position[0] - 1) + '\n' 166 | map = '' 167 | for i in range(self.height): 168 | map += map_rows[i] 169 | print(map) 170 | 171 | 172 | 173 | class RandomGridWorldPlat(GridWorldPlat): 174 | def __init__(self, append_context=False, offset_size=2, width=11, height=11, moving_size=3): 175 | self.max_offset = offset_size 176 | self.original_max_offset = self.max_offset 177 | self.renv_flag = (random.randint(-self.max_offset, self.max_offset), 178 | random.randint(-self.max_offset, self.max_offset)) 179 | self.fix_env = None 180 | super(RandomGridWorldPlat, self).__init__(self.renv_flag, append_context, offset_size, width, height, moving_size) 181 | 182 | def reset(self): 183 | if self.fix_env is None: 184 | self.renv_flag = (random.randint(-self.max_offset, self.max_offset), 185 | random.randint(-self.max_offset, self.max_offset)) 186 | self.env_flag = self.renv_flag 187 | else: 188 | self.renv_flag = self.env_flag = self.fix_env 189 | self._init_optimal_policy() 190 | return super(RandomGridWorldPlat, self).reset() 191 | 192 | def set_ood(self, is_ood): 193 | if is_ood: 194 | self.max_offset = self.original_max_offset + 1 195 | else: 196 | self.max_offset = self.original_max_offset 197 | def set_fix_env(self, fix_env): 198 | self.renv_flag = self.env_flag = self.fix_env = fix_env 199 | self._init_optimal_policy() 200 | 201 | def set_task(self, task): 202 | self.set_fix_env((task[0], task[0])) 203 | 204 | def sample_tasks(self, n_tasks): 205 | tasks = [] 206 | task_set = set() 207 | while len(task_set) < n_tasks: 208 | task_set.add((random.randint(-self.max_offset, self.max_offset), 209 | random.randint(-self.max_offset, self.max_offset))) 210 | for item in task_set: 211 | tasks.append([item[0], item[1]]) 212 | 213 | return tasks 214 | 215 | @property 216 | def env_parameter_vector_(self): 217 | return self.context 218 | 219 | @property 220 | def env_parameter_length(self): 221 | return 2 222 | 223 | def plot_reward(self): 224 | import matplotlib.patches as patches 225 | import matplotlib.cm as cm 226 | import matplotlib.pyplot as plt 227 | cmap = plt.get_cmap('coolwarm') 228 | def plot_square(ax, x, y, v): 229 | ax.add_patch(patches.Rectangle( 230 | (x , y ), 231 | 1., # width 232 | 1., # height 233 | facecolor=cm.coolwarm(v), 234 | # cmap=cmap, 235 | # c=v, 236 | edgecolor='black' 237 | )) 238 | figure = plt.figure(0, figsize=(4.5, 4/5 * 4.5)) 239 | ax = figure.add_subplot(111) 240 | for x in range(self.width): 241 | for y in range(self.height): 242 | plot_square(ax, x, y, self.rewards[y, x]/np.max(self.rewards)) 243 | if self.rewards[y, x] < 10: 244 | plt.text(x+0.4, y+0.4, 
f'{int(self.rewards[y, x])}') 245 | else: 246 | plt.text(x+0.25, y+0.4, f'{int(self.rewards[y, x])}') 247 | 248 | plt.scatter([-100, -100], [100, 100], c=[0, 1]) 249 | plt.xlim(left=0, right=11) 250 | plt.ylim(bottom=0, top=11) 251 | norm = plt.Normalize(vmin=0, vmax=10) 252 | plt.colorbar(cm.ScalarMappable(cmap=cmap, norm=norm)) 253 | plt.savefig('grid_demon.pdf') 254 | 255 | plt.show() 256 | 257 | from gym.envs.registration import register 258 | 259 | register( 260 | id='GridWorldPlat-v2', entry_point=RandomGridWorldPlat 261 | ) 262 | 263 | def _main(): 264 | import gym 265 | env = gym.make('GridWorldPlat-v2') 266 | print('observation space: ', env.observation_space) 267 | print('action space: ', env.action_space) 268 | print(hasattr(env, 'rmdm_env_flag')) 269 | print(env.rewards) 270 | env.plot_reward() 271 | exit(0) 272 | for i in range(10): 273 | done = False 274 | state = env.reset() 275 | print('---' * 18) 276 | print(env.renv_flag) 277 | act_space = env.action_space 278 | action_to_set = None 279 | while not done: 280 | action = act_space.sample() 281 | action = action if action_to_set is None else [action_to_set[0] / 3, action_to_set[1] / 3] 282 | state, reward, done, info = env.step(action) 283 | action_to_set = info['next_optimal_action'] 284 | print(info, reward) 285 | env.render() 286 | 287 | if __name__ == '__main__': 288 | _main() 289 | 290 | -------------------------------------------------------------------------------- /algorithms/RMDM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | import time 4 | import numpy as np 5 | 6 | class ContrastiveLoss: 7 | def __init__(self, dim, device=torch.device('cpu')): 8 | self.device = device 9 | # self.W = torch.rand((dim, dim), requires_grad=True, device=device) 10 | self.W = torch.eye(dim, requires_grad=True, device=device) 11 | self.w_optim = torch.optim.Adam([self.W], lr=1e-2) 12 | self.loss_func = torch.nn.CrossEntropyLoss() 13 | 14 | def get_loss_meta(self, y, need_w_grad=False): 15 | if need_w_grad: 16 | proj_k = (self.W + self.W.t()).matmul(y.t()) 17 | else: 18 | proj_k = (self.W + self.W.t()).detach().matmul(y.t()) 19 | logits = y.matmul(proj_k) 20 | # print(logits.max(dim=1, keepdim=True).values) 21 | logits = logits - logits.max(dim=1, keepdim=True).values 22 | labels = torch.arange(logits.shape[0]) 23 | # print(logits) 24 | loss = self.loss_func(logits, labels) 25 | return loss 26 | 27 | def get_loss(self, y): 28 | loss_w = self.get_loss_meta(y.detach(), True) 29 | self.w_optim.zero_grad() 30 | loss_w.backward() 31 | self.w_optim.step() 32 | return self.get_loss_meta(y) 33 | 34 | def __call__(self, *args, **kwargs): 35 | return self.get_loss(*args, **kwargs) 36 | 37 | 38 | class constraint: 39 | def __init__(self, alg_name='dpp', w_consistency=1.0, w_diverse=1.0): 40 | self.alg_name = alg_name 41 | self.w_consistency = w_consistency 42 | self.w_diverse = w_diverse 43 | self.contrastive_loss = ContrastiveLoss() if alg_name == 'contrastive' else None 44 | 45 | @staticmethod 46 | def get_loss_dpp(y): 47 | y = y / torch.clamp_min(y.pow(2).mean(dim=1, keepdim=True).sqrt(), 1e-5) 48 | K = (y.matmul(y.t()) - 1).exp() + torch.eye(y.shape[0], device=y.device) * 1e-3 49 | loss = -torch.logdet(K) 50 | if torch.isnan(loss).any(): 51 | print(K) 52 | return loss 53 | 54 | def get_loss_contrastive(self, y): 55 | pass 56 | 57 | def get_loss(self, predicted_env_vector, tasks, valid): 58 | tasks = tasks[..., 0, 0] 59 | all_tasks = 
torch.unique(tasks).detach().cpu().numpy().tolist() 60 | if len(all_tasks) <= 1: 61 | return None, None, None 62 | total_trasition_num = valid.sum() 63 | all_predicted_env_vectors = [] 64 | all_valids = [] 65 | mean_vector = [] 66 | var_vector = [] 67 | valid_num_list = [] 68 | masks_list = [] 69 | real_all_tasks = [] 70 | 71 | for item in all_tasks: 72 | if item == 0: 73 | continue 74 | masks = tasks == item 75 | valid_it = valid[masks] 76 | if valid_it.sum() == 0: 77 | continue 78 | masks_list.append(masks) 79 | all_valids.append(valid_it) 80 | real_all_tasks.append(item) 81 | if len(all_tasks) <= 1: 82 | return None, None, None 83 | all_tasks = real_all_tasks 84 | for ind, item in enumerate(all_tasks): 85 | masks = masks_list[ind] 86 | valid_it = all_valids[ind] 87 | env_vector_it = predicted_env_vector[masks] 88 | all_predicted_env_vectors.append(env_vector_it) 89 | point_num = valid_it.sum() 90 | assert point_num > 0, 'trajectory should not be empty!!!!' 91 | # print(env_vector_it.shape, valid_it.shape) 92 | repre_it = (env_vector_it * valid_it).sum(1, keepdim=True).sum(0, keepdim=True) / point_num 93 | mean_vector.append(repre_it) 94 | var_it = ((env_vector_it - repre_it.detach()) * valid_it).pow(2).sum() / point_num / \ 95 | predicted_env_vector.shape[-1] 96 | var_vector.append(var_it) 97 | valid_num_list.append(point_num) 98 | ##### consistency loss 99 | consistency_loss = sum([a1 * a2 for a1, a2 in zip(var_vector, valid_num_list)]) / total_trasition_num 100 | ##### use DPP loss 101 | repres = [item.reshape(1, -1) for item in mean_vector] 102 | repre_tensor = torch.cat(repres, 0) 103 | if self.alg_name == 'dpp': 104 | diverse_loss = self.get_loss_dpp(repre_tensor) 105 | elif self.alg_name == 'contrastive': 106 | diverse_loss = self.contrastive_loss(repre_tensor) 107 | else: 108 | raise NotImplementedError(f'{self.alg_name} has not been implemented!!!') 109 | 110 | constraint_loss = self.w_consistency * consistency_loss + diverse_loss * self.w_diverse 111 | # print(consistency_loss, dpp_loss) 112 | 113 | return constraint_loss, consistency_loss, diverse_loss 114 | 115 | 116 | def get_rbf_matrix(data, centers, alpha, element_wise_exp=False): 117 | out_shape = torch.Size([data.shape[0], centers.shape[0], data.shape[-1]]) 118 | data = data.unsqueeze(1).expand(out_shape) 119 | centers = centers.unsqueeze(0).expand(out_shape) 120 | if element_wise_exp: 121 | mtx = (-(centers - data).pow(2) * alpha).exp().mean(dim=-1, keepdim=False) 122 | else: 123 | mtx = (-(centers - data).pow(2) * alpha).sum(dim=-1, keepdim=False).exp() 124 | return mtx 125 | 126 | 127 | def get_loss_dpp(y, kernel='rbf', rbf_radius=3000.0): 128 | # K = (y.matmul(y.t()) - 1).exp() + torch.eye(y.shape[0]) * 1e-3 129 | if kernel == 'rbf': 130 | K = get_rbf_matrix(y, y, alpha=rbf_radius, element_wise_exp=False) + torch.eye(y.shape[0], device=y.device) * 1e-3 131 | elif kernel == 'rbf_element_wise': 132 | K = get_rbf_matrix(y, y, alpha=rbf_radius, element_wise_exp=True) + torch.eye(y.shape[0], device=y.device) * 1e-3 133 | elif kernel == 'inner': 134 | # y = y / y.pow(2).sum(dim=-1, keepdim=True).sqrt() 135 | K = y.matmul(y.t()).exp() 136 | # K = torch.softmax(K, dim=0) 137 | K = K + torch.eye(y.shape[0], device=y.device) * 1e-3 138 | print(K) 139 | # print('k shape: ', K.shape, ', y_mtx shape: ', y_mtx.shape) 140 | else: 141 | assert False 142 | loss = -torch.logdet(K) 143 | # loss = -(y.pinverse().t().detach() * y).sum() 144 | return loss 145 | 146 | 147 | def get_loss_cov(y): 148 | cov = (y - y.mean(dim=0, 
keepdim=True)).pow(2).mean() 149 | return -torch.log(cov + 1e-4) 150 | 151 | class RMDMLoss: 152 | def __init__(self, tau=0.995, target_consistency_metric=-4.0, target_diverse_metric=None, max_env_len=40): 153 | self.mean_vector = {} 154 | self.tau = tau 155 | self.target_consistency_metric = target_consistency_metric 156 | self.target_diverse_metric = target_diverse_metric 157 | self.lst_tasks = [] 158 | self.max_env_len = max_env_len 159 | self.current_env_mean = None 160 | self.history_env_mean = None 161 | 162 | def construct_loss(self, consistency_loss, diverse_loss, consis_w, diverse_w, std): 163 | consis_w_loss = None 164 | divers_w_loss = None 165 | if isinstance(consis_w, torch.Tensor): 166 | rmdm_loss_it = consis_w.detach() * consistency_loss + diverse_loss * diverse_w.detach() 167 | if std >= 1e-1: 168 | rmdm_loss_it = consis_w.detach() * consistency_loss 169 | # alpha_loss = (alpha[0] * (target - current).detach()).mean() 170 | 171 | if self.target_consistency_metric is not None: 172 | consis_w_loss = consis_w * ((self.target_consistency_metric - consistency_loss.detach()).detach().mean()) 173 | if self.target_diverse_metric is not None: 174 | divers_w_loss = diverse_w * ((self.target_diverse_metric - diverse_loss.detach()).detach().mean()) 175 | pass 176 | else: 177 | rmdm_loss_it = consis_w * consistency_loss + diverse_loss * diverse_w 178 | if std >= 1e-1: 179 | rmdm_loss_it = consis_w * consistency_loss 180 | return rmdm_loss_it, consis_w_loss, divers_w_loss 181 | 182 | def rmdm_loss(self, predicted_env_vector, tasks, valid, consis_w, diverse_w, need_all_repre=False, 183 | need_parameter_loss=False, rbf_radius=3000.0, 184 | kernel_type='rbf'): 185 | tasks = torch.max(tasks[..., 0, 0], tasks[..., -1, 0]) 186 | all_tasks = torch.unique(tasks).detach().cpu().numpy().tolist() 187 | if len(all_tasks) <= 1: 188 | print(f'current task num: {len(all_tasks)}, {all_tasks}') 189 | return None, None, None, 0 190 | total_trasition_num = valid.sum() 191 | all_predicted_env_vectors = [] 192 | all_valids = [] 193 | mean_vector = [] 194 | var_vector = [] 195 | valid_num_list = [] 196 | masks_list = [] 197 | real_all_tasks = [] 198 | 199 | for item in all_tasks: 200 | if item == 0: 201 | continue 202 | masks = tasks == item 203 | valid_it = valid[masks] 204 | if valid_it.sum() == 0: 205 | continue 206 | masks_list.append(masks) 207 | all_valids.append(valid_it) 208 | real_all_tasks.append(item) 209 | if len(all_tasks) <= 1: 210 | print(f'current task num: {len(all_tasks)}, {all_tasks}') 211 | return None, None, None, 0 212 | # print(f'task num: {len(all_tasks)}, env_vector: {predicted_env_vector.shape}') 213 | all_tasks = real_all_tasks 214 | self.lst_tasks = copy.deepcopy(real_all_tasks) 215 | dpp_inner = [] 216 | use_dpp_inner = False 217 | for ind, item in enumerate(all_tasks): 218 | masks = masks_list[ind] 219 | valid_it = all_valids[ind] 220 | env_vector_it = predicted_env_vector[masks] 221 | all_predicted_env_vectors.append(env_vector_it) 222 | point_num = valid_it.sum() 223 | assert point_num > 0, 'trajectory should not be empty!!!!' 
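            # The loop below takes, for each task, the masked mean of the predicted
            # environment embeddings (repre_it) and tracks it with an exponential moving
            # average in self.mean_vector (rate tau). The valid-weighted standard deviation
            # of the embeddings around that average forms the consistency loss, while the
            # stacked per-task means are later passed to get_loss_dpp, whose -logdet term
            # pushes the means of different tasks apart (diversity).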
224 | repre_it = (env_vector_it * valid_it).sum(1, keepdim=True).sum(0, keepdim=True) / point_num 225 | if item not in self.mean_vector: 226 | self.mean_vector[item] = repre_it.detach() 227 | else: 228 | self.mean_vector[item] = ((repre_it.detach() * (1-self.tau)) + self.mean_vector[item] * self.tau).detach() 229 | mean_vector.append(repre_it) 230 | var_it = ((env_vector_it - self.mean_vector[item]) * valid_it).pow(2).sum() / point_num / predicted_env_vector.shape[-1] 231 | var_vector.append(var_it) 232 | valid_num_list.append(point_num) 233 | ##### consistency loss 234 | var = sum([a1 * a2 for a1, a2 in zip(var_vector, valid_num_list)]) / total_trasition_num 235 | stds = var.sqrt() 236 | consistency_loss = stds # + 1e-4) 237 | if stds < 1e-3: 238 | consistency_loss = consistency_loss.detach() 239 | ##### use DPP loss 240 | repres = [item.reshape(1, -1) for item in mean_vector] 241 | for item in self.mean_vector: 242 | if item not in all_tasks: 243 | repres.append(self.mean_vector[item].reshape(1, -1)) 244 | repre_tensor = torch.cat(repres, 0) 245 | dpp_loss = get_loss_dpp(repre_tensor, rbf_radius=rbf_radius, kernel=kernel_type) 246 | rmdm_loss_it, consis_w_loss, diverse_w_loss = self.construct_loss(consistency_loss, dpp_loss, consis_w, diverse_w, stds.item()) 247 | # rmdm_loss_it = dpp_loss + consistency_loss 248 | # print(consistency_loss, dpp_loss) 249 | if need_parameter_loss: 250 | if need_all_repre: 251 | return rmdm_loss_it, consistency_loss, dpp_loss, len(all_tasks), consis_w_loss, diverse_w_loss, all_predicted_env_vectors, all_valids 252 | return rmdm_loss_it, consistency_loss, dpp_loss, len(all_tasks), consis_w_loss, diverse_w_loss 253 | 254 | if need_all_repre: 255 | return rmdm_loss_it, consistency_loss, dpp_loss, len(all_tasks), all_predicted_env_vectors, all_valids 256 | return rmdm_loss_it, consistency_loss, dpp_loss, len(all_tasks) 257 | 258 | def rmdm_loss_timing(self, predicted_env_vector, tasks, valid, 259 | consis_w, diverse_w, need_all_repre=False, 260 | need_parameter_loss=False, cum_time=[], rbf_radius=3000.0, 261 | kernel_type='rbf'): 262 | if self.current_env_mean is None: 263 | self.current_env_mean = torch.zeros((self.max_env_len, 1, predicted_env_vector.shape[-1]), device=predicted_env_vector.device) 264 | self.history_env_mean = torch.zeros((self.max_env_len, 1, predicted_env_vector.shape[-1]), device=predicted_env_vector.device) 265 | tasks = tasks[..., -1, 0] # torch.max(tasks[..., 0, 0], ) 266 | tasks_sorted, indices = torch.sort(tasks) 267 | tasks_sorted_np = tasks_sorted.detach().cpu().numpy().reshape((-1)) 268 | task_ind_map = {} 269 | tasks_sorted_np_idx = np.where(np.diff(tasks_sorted_np))[0] + 1 270 | last_ind = 0 271 | for i, item in enumerate(tasks_sorted_np_idx): 272 | task_ind_map[tasks_sorted_np[item-1]] = [last_ind, item] 273 | last_ind = item 274 | if i == len(tasks_sorted_np_idx) - 1: 275 | task_ind_map[tasks_sorted_np[-1]] = [last_ind, len(tasks_sorted_np)] 276 | predicted_env_vector = predicted_env_vector[indices] 277 | # remove the invalid data 278 | if 0 in task_ind_map: 279 | predicted_env_vector = predicted_env_vector[task_ind_map[0][1]:] 280 | start_ind = task_ind_map[0][1] 281 | task_ind_map.pop(0) 282 | for k in task_ind_map: 283 | task_ind_map[k][0] -= start_ind 284 | task_ind_map[k][1] -= start_ind 285 | # finish preprocess the data 286 | # def update_cum_time(time_last, time_count, cum_time): 287 | # # if len(cum_time) < time_count + 1: 288 | # # cum_time.append(time.time() - time_last) 289 | # # else: 290 | # # cum_time[time_count] 
+= time.time() - time_last 291 | # # time_last = time.time() 292 | # # time_count += 1 293 | # return time_last, time_count 294 | if len(task_ind_map) <= 1: 295 | print(f'current task num: {len(task_ind_map)}, {task_ind_map}') 296 | return None, None, None, 0 297 | total_trasition_num = predicted_env_vector.shape[0] 298 | all_valids, mean_vector, valid_num_list, all_predicted_env_vectors = [], [], [], [] 299 | real_all_tasks = sorted(list(task_ind_map.keys())) 300 | all_tasks, self.lst_tasks = real_all_tasks, real_all_tasks 301 | use_history_mean = True 302 | for ind, item in enumerate(all_tasks): 303 | env_vector_it = predicted_env_vector[task_ind_map[item][0]:task_ind_map[item][1]] 304 | if need_all_repre: 305 | all_predicted_env_vectors.append(env_vector_it) 306 | point_num = env_vector_it.shape[0] 307 | repre_it = env_vector_it.mean(dim=0, keepdim=True) 308 | if item not in self.mean_vector: 309 | with torch.no_grad(): 310 | self.history_env_mean[int(item-1)] = repre_it 311 | self.current_env_mean[int(item-1)] = repre_it 312 | mean_vector.append(repre_it) 313 | valid_num_list.append(point_num) 314 | valid_num_tensor = torch.from_numpy(np.array(valid_num_list)).to(device=valid.device, 315 | dtype=torch.get_default_dtype()).reshape((-1, 1, 1)) 316 | task_set = set(all_tasks) 317 | with torch.no_grad(): 318 | for k in self.mean_vector: 319 | if k not in task_set: 320 | self.current_env_mean[int(k-1)] = self.history_env_mean[int(k-1)] 321 | self.current_env_mean = self.current_env_mean.detach() 322 | self.history_env_mean = self.history_env_mean * self.tau + (1-self.tau) * self.current_env_mean 323 | for item in all_tasks: 324 | if item not in self.mean_vector: 325 | self.mean_vector[item] = 1 326 | ##### use DPP loss 327 | repres = [item[0] for item in mean_vector] 328 | valid_repres_len = len(repres) 329 | for item in self.mean_vector: 330 | if item not in task_set: 331 | repres.append(self.history_env_mean[int(item-1)]) 332 | repre_tensor = torch.cat(repres, 0) 333 | dpp_loss = get_loss_dpp(repre_tensor, kernel=kernel_type, rbf_radius=rbf_radius) 334 | ##### consistency loss 335 | # total minus outter 336 | if not use_history_mean: 337 | with torch.no_grad(): 338 | total_mean = ((repre_tensor[:valid_repres_len] * valid_num_tensor).sum(dim=0, keepdim=True) / total_trasition_num) 339 | total_outter_var = ((repre_tensor[:valid_repres_len] - total_mean).pow(2) * valid_num_tensor).sum(dim=0, keepdim=True) / total_trasition_num 340 | total_var = (predicted_env_vector - total_mean).pow(2).mean(dim=0, keepdim=True) 341 | var = max(total_var.mean() - total_outter_var.mean(), 0) 342 | ###################### 343 | # summation of inner 344 | else: 345 | total_var = 0 346 | for ind, item in enumerate(all_tasks): 347 | mean_vector = self.history_env_mean[int(item-1)] 348 | if need_all_repre: 349 | env_vector_it = all_predicted_env_vectors[ind] 350 | else: 351 | env_vector_it = predicted_env_vector[task_ind_map[item][0]:task_ind_map[item][1]] 352 | var_it = (env_vector_it - mean_vector.detach()).pow(2).sum(dim=0, keepdim=True).mean() 353 | total_var = total_var + var_it 354 | var = total_var / total_trasition_num 355 | ##################### 356 | 357 | stds = var.sqrt() 358 | consistency_loss = stds # + 1e-4) 359 | if stds < 1e-3: 360 | consistency_loss = consistency_loss.detach() 361 | rmdm_loss_it, consis_w_loss, diverse_w_loss = self.construct_loss(consistency_loss, dpp_loss, consis_w, 362 | diverse_w, stds.item()) 363 | # rmdm_loss_it = consistency_loss + dpp_loss 364 | if need_parameter_loss: 365 
| if need_all_repre: 366 | return rmdm_loss_it, consistency_loss, dpp_loss, len(all_tasks), consis_w_loss, diverse_w_loss, all_predicted_env_vectors, all_valids 367 | return rmdm_loss_it, consistency_loss, dpp_loss, len(all_tasks), consis_w_loss, diverse_w_loss 368 | if need_all_repre: 369 | return rmdm_loss_it, consistency_loss, dpp_loss, len(all_tasks), all_predicted_env_vectors, all_valids 370 | return rmdm_loss_it, consistency_loss, dpp_loss, len(all_tasks) 371 | 372 | 373 | 374 | 375 | -------------------------------------------------------------------------------- /envs/nonstationary_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import Wrapper 3 | import numpy as np 4 | import copy 5 | 6 | # this file is referred to https://github.com/dennisl88/rand_param_envs 7 | 8 | class NonstationaryEnv(Wrapper): 9 | RAND_PARAMS = ['body_mass', 'dof_damping', 'body_inertia', 'geom_friction', 'gravity', 'density', 10 | 'wind', 'geom_friction_1_dim', 'dof_damping_1_dim'] 11 | RAND_PARAMS_EXTENDED = RAND_PARAMS + ['geom_size'] 12 | 13 | def __init__(self, env, rand_params=['gravity'], log_scale_limit=3.0): 14 | super().__init__(env) 15 | self.is_diy_env = hasattr(env, 'diy_env') 16 | if len(rand_params) == 1 and rand_params[0] == 'None': 17 | rand_params = [] 18 | if self.is_diy_env: 19 | rand_params = [] 20 | self.normalize_context = True 21 | self.log_scale_limit = log_scale_limit 22 | self.rand_params = rand_params 23 | self.save_parameters() 24 | self.min_param, self.max_param = self.get_minmax_parameter(log_scale_limit) 25 | self.cur_parameter_vector = self.env_parameter_vector_ 26 | self.cur_step_ind = 0 27 | # for non-stationary changing 28 | self.setted_env_params = None 29 | self.setted_env_changing_period = None 30 | self.setted_env_changing_interval = None 31 | self.min_action = env.action_space.low 32 | self.max_action = env.action_space.high 33 | self.range_action = self.max_action - self.min_action 34 | self._debug_state = None 35 | 36 | def get_minmax_parameter(self, log_scale_limit): 37 | min_param = {} 38 | max_param = {} 39 | bound = lambda x, y: np.array(1.5) ** (np.ones(shape=x) * ((-1 if y == 'low' else 1) * log_scale_limit)) 40 | if 'body_mass' in self.rand_params: 41 | min_multiplyers = bound(self.model.body_mass.shape, 'low') 42 | max_multiplyers = bound(self.model.body_mass.shape, 'high') 43 | min_param['body_mass'] = self.init_params['body_mass'] * min_multiplyers 44 | max_param['body_mass'] = self.init_params['body_mass'] * max_multiplyers 45 | 46 | # body_inertia 47 | if 'body_inertia' in self.rand_params: 48 | min_multiplyers = bound(self.model.body_inertia.shape, 'low') 49 | max_multiplyers = bound(self.model.body_inertia.shape, 'high') 50 | min_param['body_inertia'] = self.init_params['body_inertia'] * min_multiplyers 51 | max_param['body_inertia'] = self.init_params['body_inertia'] * max_multiplyers 52 | 53 | # damping -> different multiplier for different dofs/joints 54 | if 'dof_damping' in self.rand_params: 55 | min_multiplyers = bound(self.model.dof_damping.shape, 'low') 56 | max_multiplyers = bound(self.model.dof_damping.shape, 'high') 57 | min_param['dof_damping'] = self.init_params['dof_damping'] * min_multiplyers 58 | max_param['dof_damping'] = self.init_params['dof_damping'] * max_multiplyers 59 | 60 | # friction at the body components 61 | if 'geom_friction' in self.rand_params: 62 | min_multiplyers = bound(self.model.geom_friction.shape, 'low') 63 | max_multiplyers = 
bound(self.model.geom_friction.shape, 'high') 64 | min_param['geom_friction'] = self.init_params['geom_friction'] * min_multiplyers 65 | max_param['geom_friction'] = self.init_params['geom_friction'] * max_multiplyers 66 | 67 | if 'geom_friction_1_dim' in self.rand_params: 68 | min_multiplyers = bound((1,), 'low') 69 | max_multiplyers = bound((1,), 'high') 70 | min_param['geom_friction_1_dim'] = np.array([min_multiplyers]) 71 | max_param['geom_friction_1_dim'] = np.array([max_multiplyers]) 72 | 73 | if 'dof_damping_1_dim' in self.rand_params: 74 | min_multiplyers = bound((1,), 'low') 75 | max_multiplyers = bound((1,), 'high') 76 | min_param['dof_damping_1_dim'] = np.array([min_multiplyers]) 77 | max_param['dof_damping_1_dim'] = np.array([max_multiplyers]) 78 | 79 | if 'gravity' in self.rand_params: 80 | min_multiplyers = bound(self.model.opt.gravity.shape, 'low') 81 | max_multiplyers = bound(self.model.opt.gravity.shape, 'high') 82 | min_param['gravity'] = self.init_params['gravity'] * min_multiplyers 83 | max_param['gravity'] = self.init_params['gravity'] * max_multiplyers 84 | 85 | if 'gravity_angle' in self.rand_params: 86 | min_param['gravity'][:2] = min_param['gravity'][2] 87 | max_param['gravity'][:2] = max_param['gravity'][2] 88 | 89 | if 'wind' in self.rand_params: 90 | min_param['wind'] = np.array([-log_scale_limit, -log_scale_limit]) 91 | max_param['wind'] = np.array([log_scale_limit, log_scale_limit]) 92 | 93 | if 'density' in self.rand_params: 94 | min_multiplyers = bound((1,), 'low') 95 | max_multiplyers = bound((1,), 'high') 96 | min_param['density'] = self.init_params['density'] * min_multiplyers 97 | max_param['density'] = self.init_params['density'] * max_multiplyers 98 | 99 | for key in min_param: 100 | min_it = min_param[key] 101 | max_it = max_param[key] 102 | min_real = np.min([min_it, max_it], 0) 103 | max_real = np.max([max_it, min_it], 0) 104 | min_param[key] = min_real 105 | max_param[key] = max_real 106 | return min_param, max_param 107 | 108 | def denormalization(self, action): 109 | return (action + 1) / 2 * self.range_action + self.min_action 110 | 111 | def normalization(self, action): 112 | return (action - self.min_action) / self.range_action * 2 - 1 113 | 114 | def step(self, action): 115 | self.cur_step_ind += 1 116 | if self.setted_env_params is not None and self.cur_step_ind % self.setted_env_changing_interval == 0: 117 | assert isinstance(self.setted_env_params, list) 118 | env_to_be = {} 119 | weight_origin = self.cur_step_ind / self.setted_env_changing_period 120 | # weight = min(weight_origin, 1) 121 | weight_in_duration = weight_origin - (weight_origin // 2 * 2) 122 | if weight_in_duration <= 1: 123 | weight = weight_in_duration 124 | else: 125 | weight = 2 - weight_in_duration 126 | ind = int(weight_origin) 127 | if isinstance(self.setted_env_params[0], dict): 128 | for key in self.setted_env_params[0]: 129 | # env_to_be[key] = (1 - weight) * self.setted_env_params[0][key] + weight * self.setted_env_params[1][key] 130 | # env_to_be[key] = self.setted_env_params[ind][key] if weight_in_duration <= 1 else self.setted_env_params[ind][key] 131 | env_to_be[key] = self.setted_env_params[ind][key] 132 | elif isinstance(self.setted_env_params[0], int): 133 | env_to_be = self.setted_env_params[ind] 134 | elif isinstance(self.setted_env_params[0], list): 135 | env_to_be = copy.deepcopy(self.setted_env_params[ind]) 136 | else: 137 | raise NotImplementedError(f'type of {type(self.setted_env_params[ind])} is not implemented.') 138 | self.set_task(env_to_be) 139 
| try: 140 | res = super(NonstationaryEnv, self).step(action) 141 | self._debug_state = res[0] 142 | return res 143 | except Exception as e: 144 | print(e) 145 | print('Inf or NaN found in state or action!!') 146 | print('Action: ', action) 147 | print('Param: ', self.env_parameter_vector_) 148 | print('Current state: ', self._debug_state) 149 | param_variable = getattr(self.unwrapped.model, 'dof_damping') 150 | print('Current parameter: ', self.cur_params) 151 | print('Dof_damping: ', param_variable) 152 | return self._debug_state, 0, True, {} 153 | 154 | # raise e 155 | # return super(NonstationaryEnv, self).step(action) 156 | 157 | def set_nonstationary_para(self, setting_env_params, changine_period, changing_interval): 158 | self.setted_env_changing_period = changine_period 159 | self.setted_env_params = setting_env_params 160 | self.setted_env_changing_interval = changing_interval 161 | 162 | def reset_nonstationary(self): 163 | self.set_nonstationary_para(None, None, None) 164 | 165 | def reset(self, **kwargs): 166 | self.cur_step_ind = 0 167 | return super(NonstationaryEnv, self).reset(**kwargs) 168 | 169 | def sample_tasks(self, n_tasks, dig_range=None, linspace=False): 170 | """ 171 | Generates randomized parameter sets for the mujoco env 172 | Args: 173 | n_tasks (int) : number of different meta-tasks needed 174 | Returns: 175 | tasks (list) : an (n_tasks) length list of tasks 176 | """ 177 | if self.is_diy_env: 178 | return self.env.sample_tasks(n_tasks) 179 | 180 | current_task_count_ = 0 181 | param_sets = [] 182 | if dig_range is None: 183 | if linspace: 184 | def uniform_function(low_, up_, size): 185 | res = [0] * np.prod(size) 186 | interval = (up_ - low_) / (n_tasks - 1) 187 | for i in range(len(res)): 188 | res[i] = interval * current_task_count_ + low_ 189 | res = np.array(res).reshape(size) 190 | return res 191 | uniform = uniform_function 192 | else: 193 | uniform = lambda low_,up_,size: np.random.uniform(low_, up_, size=size) 194 | else: 195 | dig_range = np.abs(dig_range) 196 | def uniform_function(low_, up_, size): 197 | res = [0] * np.prod(size) 198 | for i in range(len(res)): 199 | if linspace: 200 | if current_task_count_ >= n_tasks // 2: 201 | interval = (up_ - dig_range) / (n_tasks // 2) 202 | # res[i] = interval * (current_task_count_ - n_tasks // 2 + 1) + dig_range 203 | res[i] = interval * (current_task_count_ - n_tasks // 2 ) + dig_range 204 | else: 205 | interval = (-dig_range - low_) / (n_tasks // 2) 206 | # res[i] = interval * (n_tasks // 2 - current_task_count_ - 1) + low_ 207 | res[i] = interval * (n_tasks // 2 - current_task_count_) + low_ 208 | else: 209 | while True: 210 | rand = np.random.uniform(low_, up_) 211 | if rand > dig_range or rand < -dig_range: 212 | res[i] = rand 213 | break 214 | res = np.array(res).reshape(size) 215 | return res 216 | uniform = uniform_function 217 | bound = lambda x: np.array(1.5) ** uniform(-self.log_scale_limit, self.log_scale_limit, x) 218 | bound_uniform = lambda x: uniform(-self.log_scale_limit, self.log_scale_limit, x) 219 | for _ in range(n_tasks): 220 | # body mass -> one multiplier for all body parts 221 | new_params = {} 222 | 223 | if 'body_mass' in self.rand_params: 224 | body_mass_multiplyers = bound(self.model.body_mass.shape) 225 | new_params['body_mass'] = self.init_params['body_mass'] * body_mass_multiplyers 226 | 227 | # body_inertia 228 | if 'body_inertia' in self.rand_params: 229 | body_inertia_multiplyers = bound(self.model.body_inertia.shape) 230 | new_params['body_inertia'] = 
body_inertia_multiplyers * self.init_params['body_inertia'] 231 | 232 | # damping -> different multiplier for different dofs/joints 233 | if 'dof_damping' in self.rand_params: 234 | dof_damping_multipliers = bound(self.model.dof_damping.shape) 235 | new_params['dof_damping'] = np.multiply(self.init_params['dof_damping'], dof_damping_multipliers) 236 | 237 | # friction at the body components 238 | if 'geom_friction' in self.rand_params: 239 | dof_damping_multipliers = bound(self.model.geom_friction.shape) 240 | new_params['geom_friction'] = np.multiply(self.init_params['geom_friction'], dof_damping_multipliers) 241 | 242 | if 'geom_friction_1_dim' in self.rand_params: 243 | geom_friction_1_dim_multipliers = bound((1,)) 244 | new_params['geom_friction_1_dim'] = geom_friction_1_dim_multipliers 245 | 246 | if 'dof_damping_1_dim' in self.rand_params: 247 | dof_damping_1_dim_multipliers = bound((1,)) 248 | new_params['dof_damping_1_dim'] = dof_damping_1_dim_multipliers 249 | 250 | if 'gravity' in self.rand_params: 251 | gravity_mutipliers = bound(self.model.opt.gravity.shape) 252 | new_params['gravity'] = np.multiply(self.init_params['gravity'], gravity_mutipliers) 253 | 254 | if 'gravity_angle' in self.rand_params: 255 | min_angle = - self.log_scale_limit * np.array([1, 1]) / 8 256 | max_angle = self.log_scale_limit * np.array([1, 1]) / 8 257 | angle = np.random.uniform(min_angle, max_angle) 258 | new_params['gravity'][0] = new_params['gravity'][2] * np.sin(angle[0]) * np.sin(angle[1]) 259 | new_params['gravity'][1] = new_params['gravity'][2] * np.sin(angle[0]) * np.cos(angle[1]) 260 | new_params['gravity'][2] *= np.cos(angle[0]) 261 | 262 | if 'wind' in self.rand_params: 263 | new_params['wind'] = bound_uniform((2, )) 264 | 265 | if 'density' in self.rand_params: 266 | density_mutipliers = bound((1,)) 267 | new_params['density'] = np.multiply(self.init_params['density'], density_mutipliers) 268 | param_sets.append(new_params) 269 | current_task_count_ += 1 270 | 271 | 272 | return param_sets 273 | 274 | def cross_params(self, param_a, param_b): 275 | param_res = [] 276 | for item_a in param_a: 277 | for item_b in param_b: 278 | r = dict() 279 | for k, v in item_a.items(): 280 | r[k] = v 281 | for k, v in item_b.items(): 282 | r[k] = v 283 | param_res.append(r) 284 | return param_res 285 | 286 | def set_task(self, task): 287 | if self.is_diy_env: 288 | self.env.set_task(task) 289 | self.cur_parameter_vector = self.env_parameter_vector_ 290 | return 291 | for param, param_val in task.items(): 292 | if param == 'gravity_angle': 293 | continue 294 | if param == 'gravity': 295 | param_variable = getattr(self.unwrapped.model.opt, param) 296 | elif param == 'density': 297 | self.unwrapped.model.opt.density = float(param_val[0]) 298 | continue 299 | elif param == 'wind': 300 | param_variable = getattr(self.unwrapped.model.opt, param) 301 | param_variable[:2] = param_val 302 | continue 303 | elif param == 'geom_friction_1_dim': 304 | param_variable = getattr(self.unwrapped.model, 'geom_friction') 305 | param_variable[:] = self.init_params[param][:] * param_val 306 | continue 307 | elif param == 'dof_damping_1_dim': 308 | param_variable = getattr(self.unwrapped.model, 'dof_damping') 309 | param_variable[:] = self.init_params[param][:] * param_val 310 | continue 311 | else: 312 | param_variable = getattr(self.unwrapped.model, param) 313 | assert param_variable.shape == param_val.shape, 'shapes of new parameter value and old one must match' 314 | param_variable[:] = param_val 315 | self.cur_params = task 
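        # Re-cache the (normalized) parameter vector below so that the env_parameter_vector
        # property reflects the task that has just been written into the simulator model.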
316 | self.cur_parameter_vector = self.env_parameter_vector_ 317 | 318 | def get_task(self): 319 | return self.cur_params 320 | 321 | def save_parameters(self): 322 | self.init_params = {} 323 | if 'body_mass' in self.rand_params: 324 | self.init_params['body_mass'] = self.unwrapped.model.body_mass 325 | 326 | # body_inertia 327 | if 'body_inertia' in self.rand_params: 328 | self.init_params['body_inertia'] = self.unwrapped.model.body_inertia 329 | 330 | # damping -> different multiplier for different dofs/joints 331 | if 'dof_damping' in self.rand_params: 332 | self.init_params['dof_damping'] = np.array(self.unwrapped.model.dof_damping).copy() 333 | 334 | # friction at the body components 335 | if 'geom_friction' in self.rand_params: 336 | self.init_params['geom_friction'] = np.array(self.unwrapped.model.geom_friction).copy() 337 | 338 | if 'geom_friction_1_dim' in self.rand_params: 339 | self.init_params['geom_friction_1_dim'] = np.array(self.unwrapped.model.geom_friction).copy() 340 | 341 | if 'dof_damping_1_dim' in self.rand_params: 342 | self.init_params['dof_damping_1_dim'] = np.array(self.unwrapped.model.dof_damping).copy() 343 | 344 | if 'gravity' in self.rand_params: 345 | self.init_params['gravity'] = self.unwrapped.model.opt.gravity 346 | 347 | if 'wind' in self.rand_params: 348 | self.init_params['wind'] = self.unwrapped.model.opt.wind[:2] 349 | 350 | if 'density' in self.rand_params: 351 | self.init_params['density'] = np.array([self.unwrapped.model.opt.density]) 352 | 353 | self.cur_params = copy.deepcopy(self.init_params) 354 | if 'dof_damping_1_dim' in self.cur_params: 355 | self.cur_params['dof_damping_1_dim'] = np.array([1.0]) 356 | if 'geom_friction_1_dim' in self.cur_params: 357 | self.cur_params['geom_friction_1_dim'] = np.array([1.0]) 358 | 359 | 360 | @property 361 | def env_parameter_vector(self): 362 | return self.cur_parameter_vector 363 | 364 | @property 365 | def env_parameter_vector_(self): 366 | if self.is_diy_env: 367 | return self.env.env_parameter_vector_ 368 | keys = [key for key in self.rand_params] 369 | if len(keys) == 0: 370 | return [] 371 | vec_ = [self.cur_params[key].reshape(-1,) for key in keys] 372 | cur_vec = np.hstack(vec_) 373 | if not self.normalize_context: 374 | return cur_vec 375 | vec_range = self.param_max - self.param_min 376 | vec_range[vec_range == 0] = 1.0 377 | cur_vec = (cur_vec - self.param_min) / vec_range 378 | return cur_vec 379 | 380 | @property 381 | def env_parameter_length(self): 382 | if self.is_diy_env: 383 | return self.env.env_parameter_length 384 | length = np.sum([np.shape(self.cur_params[key].reshape(-1, ))[-1] for key in self.cur_params]) 385 | return length 386 | 387 | @property 388 | def param_max(self): 389 | keys = [key for key in self.rand_params] 390 | vec_ = [self.max_param[key].reshape(-1,) for key in keys] 391 | if len(vec_) == 0: 392 | return [] 393 | return np.hstack(vec_) 394 | 395 | @property 396 | def param_min(self): 397 | keys = [key for key in self.rand_params] 398 | vec_ = [self.min_param[key].reshape(-1, ) for key in keys] 399 | if len(vec_) == 0: 400 | return [] 401 | return np.hstack(vec_) 402 | 403 | @property 404 | def _elapsed_steps(self): 405 | return self.cur_step_ind 406 | 407 | @property 408 | def _max_episode_steps(self): 409 | if hasattr(self.env, '_max_episode_steps'): 410 | return self.env._max_episode_steps 411 | return 1000 412 | 413 | 414 | if __name__ == '__main__': 415 | from grid_world import GridWorld 416 | from grid_world_general import RandomGridWorldPlat 417 | env = 
NonstationaryEnv(gym.make('Humanoid-v2'), ['dof_damping_1_dim', 'gravity', 'body_mass', 'geom_friction', 'density']) 418 | # env = NonstationaryEnv(gym.make('GridWorldPlat-v2'), ['dof_damping_1_dim']) 419 | print(env.param_min.shape) 420 | #env2 = 421 | #print(env2.metadata) 422 | env.reset() 423 | tasks = env.sample_tasks(20) 424 | print(tasks[0]) 425 | env.set_task(tasks[0]) 426 | env.set_nonstationary_para(tasks, 100, 10) 427 | print(tasks) 428 | for i in range(10000): 429 | state, reward, done, _ = env.step(env.action_space.sample()) 430 | # print(state, reward) 431 | # print(env._elapsed_steps) 432 | # print(env._max_episode_steps) 433 | if i % 50 == 0: 434 | print(i) 435 | task = tasks[np.random.randint(0, 19)] 436 | print('task: ', task) 437 | # env.set_task(task) 438 | print('length: ', env.env_parameter_length) 439 | #print(env.env.model.opt.gravity) 440 | print(env.unwrapped.model.dof_damping) 441 | print(env.init_params) 442 | print('parameter vec: ', env.env_parameter_vector) 443 | print('param_min: ', env.param_min) 444 | print('param_max: ', env.param_max) 445 | print('\n\n') 446 | if done: 447 | state = env.reset() 448 | 449 | 450 | -------------------------------------------------------------------------------- /models/policy.py: -------------------------------------------------------------------------------- 1 | from models.rnn_base import RNNBase 2 | import torch 3 | from torch.distributions import Normal 4 | import numpy as np 5 | import os 6 | import time 7 | 8 | class Policy(torch.nn.Module): 9 | def __init__(self, obs_dim, act_dim, up_hidden_size, up_activations, up_layer_type, 10 | ep_hidden_size, ep_activation, ep_layer_type, ep_dim, use_gt_env_feature, 11 | rnn_fix_length, use_rmdm, share_ep, 12 | logger=None, freeze_ep=False, enhance_ep=False, stop_pg_for_ep=False, 13 | bottle_neck=False, bottle_sigma=1e-4): 14 | super(Policy, self).__init__() 15 | self.obs_dim = obs_dim 16 | self.act_dim = act_dim 17 | self.use_gt_env_feature = use_gt_env_feature 18 | # stop the gradient from ep when inferring action. 19 | self.stop_pg_for_ep = stop_pg_for_ep 20 | self.enhance_ep = enhance_ep 21 | self.bottle_neck = bottle_neck 22 | self.bottle_sigma = bottle_sigma 23 | # aux dim: we add ep to every layer inputs. 
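        # aux_dim is non-zero only when enhance_ep is True; RNNBase then concatenates the
        # environment embedding to the input of every layer of the `up` network, in addition
        # to the obs+ep concatenation fed to its first layer.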
24 | aux_dim = ep_dim if enhance_ep else 0 25 | self.ep_dim = ep_dim 26 | self.up = RNNBase(obs_dim + ep_dim, act_dim * 2, up_hidden_size, up_activations, up_layer_type, logger, aux_dim) 27 | self.ep = RNNBase(obs_dim + act_dim, ep_dim, ep_hidden_size, ep_activation, ep_layer_type, logger) 28 | self.ep_temp = RNNBase(obs_dim + act_dim, ep_dim, ep_hidden_size, ep_activation, ep_layer_type, logger) 29 | self.ep_rnn_count = self.ep.rnn_num 30 | self.up_rnn_count = self.up.rnn_num 31 | # ep first, up second 32 | self.module_list = torch.nn.ModuleList(self.up.total_module_list + self.ep.total_module_list 33 | + self.ep_temp.total_module_list) 34 | self.soft_plus = torch.nn.Softplus() 35 | self.min_log_std = -7.0 36 | self.max_log_std = 2.0 37 | self.sample_hidden_state = None 38 | self.rnn_fix_length = rnn_fix_length 39 | self.use_rmdm = use_rmdm 40 | self.ep_tensor = None 41 | self.share_ep = share_ep 42 | self.freeze_ep = freeze_ep 43 | self.allow_sample = True 44 | self.device = torch.device('cpu') 45 | 46 | def set_deterministic_ep(self, deterministic): 47 | self.allow_sample = not deterministic 48 | 49 | def to(self, device): 50 | if not device == self.device: 51 | self.device = device 52 | if self.sample_hidden_state is not None: 53 | for i in range(len(self.sample_hidden_state)): 54 | if self.sample_hidden_state[i] is not None: 55 | self.sample_hidden_state[i] = self.sample_hidden_state[i].to(self.device) 56 | super().to(device) 57 | 58 | def get_ep_temp(self, x, h, require_full_output=False): 59 | if require_full_output: 60 | ep, h, full_hidden = self.ep_temp.meta_forward(x, h, require_full_output) 61 | if self.freeze_ep: 62 | ep = ep.detach() 63 | self.ep_tensor = ep 64 | 65 | return ep, h, full_hidden 66 | ep, h = self.ep_temp.meta_forward(x, h, require_full_output) 67 | if self.freeze_ep: 68 | ep = ep.detach() 69 | self.ep_tensor = ep 70 | return ep, h 71 | 72 | def apply_temp_ep(self, tau): 73 | self.ep.copy_weight_from(self.ep_temp, tau) 74 | 75 | def get_ep(self, x, h, require_full_output=False): 76 | # self.ep_tensor = torch.zeros((x.shape[0], x.shape[1], self.ep_dim), device=x.device, dtype=x.dtype) 77 | # return self.ep_tensor, h 78 | if require_full_output: 79 | ep, h, full_hidden = self.ep.meta_forward(x, h, require_full_output) 80 | if self.share_ep or self.freeze_ep: 81 | ep = ep.detach() 82 | self.ep_tensor = ep 83 | # if self.use_rmdm: 84 | # ep = ep.detach() 85 | # ep = ep / torch.clamp_min(ep.pow(2).mean(dim=-1, keepdim=True).sqrt(), 1e-5) 86 | return ep, h, full_hidden 87 | ep, h = self.ep.meta_forward(x, h, require_full_output) 88 | if self.share_ep or self.freeze_ep: 89 | ep = ep.detach() 90 | self.ep_tensor = ep 91 | # if self.use_rmdm: 92 | # ep = ep.detach() 93 | # ep = ep / torch.clamp_min(ep.pow(2).mean(dim=-1, keepdim=True).sqrt(), 1e-5) 94 | return ep, h 95 | 96 | def ep_h(self, h): 97 | return h[:self.ep_rnn_count] 98 | 99 | def up_h(self, h): 100 | return h[self.ep_rnn_count:] 101 | 102 | def make_init_state(self, batch_size, device): 103 | ep_h = self.ep.make_init_state(batch_size, device) 104 | up_h = self.up.make_init_state(batch_size, device) 105 | h = ep_h + up_h 106 | return h 107 | 108 | def make_init_action(self, device=torch.device('cpu')): 109 | return torch.zeros((1, self.act_dim), device=device) 110 | 111 | def tmp_ep_res(self, x, lst_a, h): 112 | ep_h = h[:self.ep_rnn_count] 113 | ep, ep_h_out = self.get_ep_temp(torch.cat((x, lst_a), -1), ep_h) 114 | return ep 115 | 116 | def meta_forward(self, x, lst_a, h, require_full_output=False): 117 | ep_h 
= h[:self.ep_rnn_count] 118 | up_h = h[self.ep_rnn_count:] 119 | if not require_full_output: 120 | if not self.use_gt_env_feature: 121 | ep, ep_h_out = self.get_ep(torch.cat((x, lst_a), -1), ep_h) 122 | if self.bottle_neck and self.allow_sample: 123 | ep = ep + torch.randn_like(ep) * self.bottle_sigma 124 | if self.stop_pg_for_ep: 125 | ep = ep.detach() 126 | aux_input = ep if self.enhance_ep else None 127 | up, up_h_out = self.up.meta_forward(torch.cat((x, ep), -1), up_h, aux_state=aux_input) 128 | else: 129 | up, up_h_out = self.up.meta_forward(x, up_h) 130 | ep_h_out = [] 131 | else: 132 | if not self.use_gt_env_feature: 133 | ep, ep_h_out, ep_full_hidden = self.get_ep(torch.cat((x, lst_a), -1), ep_h, require_full_output) 134 | if self.bottle_neck and self.allow_sample: 135 | ep = ep + torch.randn_like(ep) * self.bottle_sigma 136 | if self.stop_pg_for_ep: 137 | ep = ep.detach() 138 | aux_input = ep if self.enhance_ep else None 139 | up, up_h_out, up_full_hidden = self.up.meta_forward(torch.cat((x, ep), -1), up_h, require_full_output, aux_state=aux_input) 140 | else: 141 | up, up_h_out, up_full_hidden = self.up.meta_forward(x, up_h, require_full_output) 142 | ep_h_out = [] 143 | ep_full_hidden = [] 144 | h_out = ep_h_out + up_h_out 145 | return up, h_out, ep_full_hidden + up_full_hidden 146 | h_out = ep_h_out + up_h_out 147 | return up, h_out 148 | 149 | def forward(self, x, lst_a, h, require_log_std=False): 150 | policy_out, h_out = self.meta_forward(x, lst_a, h) 151 | mu, log_std = policy_out.chunk(2, dim=-1) 152 | log_std = torch.clamp(log_std, self.min_log_std, self.max_log_std) 153 | std = log_std.exp() 154 | if require_log_std: 155 | return mu, std, log_std, h_out 156 | return mu, std, h_out 157 | 158 | def rsample(self, x, lst_a, h): 159 | mu, std, log_std, h_out = self.forward(x, lst_a, h, require_log_std=True) 160 | # sample = torch.randn_like(mu).detach() * std + mu 161 | noise = torch.randn_like(mu).detach() * std 162 | sample = noise + mu 163 | log_prob = (- 0.5 * (noise / std).pow(2) - (log_std + 0.5 * np.log(2 * np.pi))).sum(-1, keepdim=True) 164 | # log_prob = dist.log_prob(sample).sum(-1, keepdim=True) 165 | 166 | log_prob = log_prob - (2 * (- sample - self.soft_plus(-2 * sample) + np.log(2))).sum(-1, keepdim=True) 167 | return torch.tanh(mu), std, torch.tanh(sample), log_prob, h_out 168 | 169 | def save(self, path): 170 | self.up.save(os.path.join(path, 'universe_policy.pt')) 171 | self.ep.save(os.path.join(path, 'environment_probe.pt')) 172 | 173 | def load(self, path, **kwargs): 174 | self.up.load(os.path.join(path, 'universe_policy.pt'), **kwargs) 175 | self.ep.load(os.path.join(path, 'environment_probe.pt'), **kwargs) 176 | 177 | @staticmethod 178 | def make_config_from_param(parameter): 179 | return dict( 180 | up_hidden_size=parameter.up_hidden_size, 181 | up_activations=parameter.up_activations, 182 | up_layer_type=parameter.up_layer_type, 183 | ep_hidden_size=parameter.ep_hidden_size, 184 | ep_activation=parameter.ep_activations, 185 | ep_layer_type=parameter.ep_layer_type, 186 | ep_dim=parameter.ep_dim, 187 | use_gt_env_feature=parameter.use_true_parameter, 188 | rnn_fix_length=parameter.rnn_fix_length, 189 | use_rmdm=parameter.use_rmdm, 190 | share_ep=parameter.share_ep, 191 | enhance_ep=parameter.enhance_ep, 192 | stop_pg_for_ep=parameter.stop_pg_for_ep, 193 | bottle_neck=parameter.bottle_neck, 194 | bottle_sigma=parameter.bottle_sigma 195 | ) 196 | 197 | def inference_init_hidden(self, batch_size, device=torch.device("cpu")): 198 | if self.rnn_fix_length is 
None or self.rnn_fix_length == 0: 199 | self.sample_hidden_state = self.make_init_state(batch_size, device) 200 | else: 201 | self.sample_hidden_state = [None] * len(self.make_init_state(batch_size, device)) 202 | 203 | def inference_check_hidden(self, batch_size): 204 | if self.sample_hidden_state is None: 205 | return False 206 | if len(self.sample_hidden_state) == 0: 207 | return True 208 | if self.rnn_fix_length is not None and self.rnn_fix_length > 0: 209 | return True 210 | if isinstance(self.sample_hidden_state[0], tuple): 211 | return self.sample_hidden_state[0][0].shape[0] == batch_size 212 | else: 213 | return self.sample_hidden_state[0].shape[0] == batch_size 214 | 215 | def inference_rnn_fix_one_action(self, state, lst_action): 216 | if self.use_gt_env_feature: 217 | mu, std, act, logp, self.sample_hidden_state = self.rsample(state, lst_action, self.sample_hidden_state) 218 | return mu, std, act, logp, self.sample_hidden_state 219 | 220 | while RNNBase.get_hidden_length(self.sample_hidden_state) >= self.rnn_fix_length: 221 | self.sample_hidden_state = RNNBase.pop_hidden_state(self.sample_hidden_state) 222 | self.sample_hidden_state = RNNBase.append_hidden_state(self.sample_hidden_state, 223 | self.make_init_state(1, state.device)) 224 | # print('1: ', self.sample_hidden_state) 225 | if len(state.shape) == 2: 226 | state = state.unsqueeze(0) 227 | lst_action = lst_action.unsqueeze(0) 228 | length = RNNBase.get_hidden_length(self.sample_hidden_state) 229 | # length = max(length, 1) 230 | # state = state.repeat_interleave(length, dim=0) 231 | state = torch.cat([state] * length, dim=0) 232 | # print('input: ', state[0]) 233 | # lst_action = lst_action.repeat_interleave(length, dim=0) 234 | lst_action = torch.cat([lst_action] * length, dim=0) 235 | mu, std, act, logp, self.sample_hidden_state = self.rsample(state, lst_action, self.sample_hidden_state) 236 | # print('2: ', self.sample_hidden_state) 237 | 238 | return mu, std, act, logp, self.sample_hidden_state 239 | 240 | def inference_one_step(self, state, deterministic=True): 241 | self.set_deterministic_ep(deterministic) 242 | with torch.no_grad(): 243 | lst_action = state[..., :self.act_dim] 244 | state = state[..., self.act_dim:] 245 | if self.rnn_fix_length is None or self.rnn_fix_length == 0 or len(self.sample_hidden_state) == 0: 246 | mu, std, act, logp, self.sample_hidden_state = self.rsample(state, lst_action, self.sample_hidden_state) 247 | else: 248 | while RNNBase.get_hidden_length(self.sample_hidden_state) < self.rnn_fix_length - 1 and not self.use_gt_env_feature: 249 | _, _, _, _, self.sample_hidden_state = self.inference_rnn_fix_one_action(torch.zeros_like(state), 250 | torch.zeros_like(lst_action)) 251 | mu, std, act, logp, self.sample_hidden_state = self.inference_rnn_fix_one_action(state, lst_action) 252 | mu, std, act, logp = map(lambda x: x[:1].reshape((1, -1)), [mu, std, act, logp]) 253 | # self.ep_tensor = 254 | if deterministic: 255 | return mu 256 | return act 257 | 258 | def inference_reset_one_hidden(self, idx): 259 | if self.rnn_fix_length is not None and self.rnn_fix_length > 0: 260 | raise NotImplementedError('if rnn fix length is set, parallel sampling is not allowed!!!') 261 | for i in range(len(self.sample_hidden_state)): 262 | if isinstance(self.sample_hidden_state[i], tuple): 263 | self.sample_hidden_state[i][0][0, idx] = 0 264 | self.sample_hidden_state[i][1][0, idx] = 0 265 | else: 266 | self.sample_hidden_state[i][0, idx] = 0 267 | 268 | @staticmethod 269 | def slice_tensor(x, slice_num): 270 | 
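        # Split a (batch, traj_len, feature) tensor into mini-trajectories of length slice_num;
        # the result has shape (traj_len // slice_num, batch, slice_num, feature), so dim 0
        # indexes the mini-trajectory within the original trajectory.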
assert len(x.shape) == 3, 'slice operation should be added on 3-dim tensor' 271 | assert x.shape[1] % slice_num == 0, f'cannot reshape length with {x.shape[1]} to {slice_num} slices' 272 | s = x.shape 273 | x = x.reshape([s[0], s[1] // slice_num, slice_num, s[2]]).transpose(0, 1) 274 | return x 275 | 276 | @staticmethod 277 | def slice_tensor_overlap(x, slice_num): 278 | x_shape = x.shape 279 | x = torch.cat((torch.zeros((x_shape[0], slice_num-1, x_shape[2]), device=x.device), x), dim=1) 280 | xs = [] 281 | for i in range(x_shape[1]): 282 | xs.append(x[:, i: i + slice_num, :]) 283 | x = torch.cat(xs, dim=0) 284 | return x 285 | 286 | def generate_hidden_state_with_slice(self, sliced_state: torch.Tensor, sliced_lst_action: torch.Tensor): 287 | """ 288 | :param sliced_state: 0-dim: mini-trajectory index, 1-dim: batch_size, 2-dim: time step, 3-dim: feature index 289 | :param sliced_lst_action: 290 | :param slice_num: 291 | :return: 292 | """ 293 | with torch.no_grad(): 294 | mini_traj_num = sliced_state.shape[0] 295 | batch_size = sliced_state.shape[1] 296 | device = sliced_state.device 297 | hidden_states = [] 298 | hidden_state_now = self.make_init_state(batch_size, device) 299 | for i in range(mini_traj_num): 300 | hidden_states.append(hidden_state_now) 301 | _, hidden_state_now = self.meta_forward(sliced_state[i], sliced_lst_action[i], hidden_state_now) 302 | return hidden_states 303 | 304 | @staticmethod 305 | def reshaping_hidden(full_hidden, init_hidden, slice_num, traj_len): 306 | for i in range(len(full_hidden)): 307 | # print(f'{hidden_state_now[i].shape}, {full_hidden[i].shape}') 308 | full_hidden[i] = torch.cat((init_hidden[i].squeeze(0).unsqueeze(1), full_hidden[i]), dim=1) 309 | # full_hidden[i] = full_hidden[i].unsqueeze(0) 310 | idx = [i * slice_num for i in range(traj_len // slice_num)] 311 | hidden_states = [item[:, idx].transpose(0, 1) for item in full_hidden] 312 | hidden_states_res = [] 313 | for item in hidden_states: 314 | it_shape = item.shape 315 | hidden_states_res.append(item.reshape((1, it_shape[0] * it_shape[1], it_shape[2]))) 316 | return hidden_states_res 317 | 318 | def generate_hidden_state(self, state: torch.Tensor, lst_action: torch.Tensor, slice_num, use_tmp_ep=False): 319 | """ 320 | :param sliced_state: 0-dim: mini-trajectory index, 1-dim: batch_size, 2-dim: time step, 3-dim: feature index 321 | :param sliced_lst_action: 322 | :param slice_num: 323 | :return: 324 | """ 325 | with torch.no_grad(): 326 | batch_size = state.shape[0] 327 | device = state.device 328 | hidden_states = [] 329 | hidden_state_now = self.make_init_state(batch_size, device) 330 | if not use_tmp_ep: 331 | _, _, full_hidden = self.meta_forward(state, lst_action, hidden_state_now, require_full_output=True) 332 | else: 333 | _, _, full_hidden = self.get_ep_temp(torch.cat((state, lst_action), -1), hidden_state_now, 334 | require_full_output=True) 335 | hidden_states_res = self.reshaping_hidden(full_hidden, hidden_state_now, slice_num, state.shape[1]) 336 | return hidden_states_res 337 | 338 | @staticmethod 339 | def merge_slice_tensor(*args): 340 | res = [] 341 | for item in args: 342 | s = item.shape 343 | # print(s) 344 | res.append(item.reshape(s[0] * s[1], s[2], s[3])) 345 | return res 346 | 347 | @staticmethod 348 | def merge_slice_hidden(hidden_states): 349 | """ 350 | usage: [state, lst_action], hidden = self.merge_slice(sliced_state, sliced_lst_action, hidden_states) 351 | :param args: 352 | :param hidden_states: 353 | :return: 354 | """ 355 | res_hidden = [] 356 | len_hidden = 
len(hidden_states[0]) 357 | for i in range(len_hidden): 358 | h = [item[i] for item in hidden_states] 359 | hid = torch.cat(h, dim=1) 360 | res_hidden.append(hid) 361 | return res_hidden 362 | 363 | @staticmethod 364 | def hidden_state_sample(hidden_state, inds): 365 | res_hidden = [] 366 | len_hidden = len(hidden_state) 367 | for i in range(len_hidden): 368 | h = hidden_state[i][:, inds] 369 | hid = h 370 | res_hidden.append(hid) 371 | return res_hidden 372 | 373 | @staticmethod 374 | def hidden_state_slice(hidden_state, start, end): 375 | res_hidden = [] 376 | len_hidden = len(hidden_state) 377 | for i in range(len_hidden): 378 | h = hidden_state[i][:, start: end] 379 | hid = h 380 | res_hidden.append(hid) 381 | return res_hidden 382 | 383 | 384 | @staticmethod 385 | def hidden_state_mask(hidden_state, masks): 386 | res_hidden = [] 387 | len_hidden = len(hidden_state) 388 | for i in range(len_hidden): 389 | h = hidden_state[i].squeeze(0)[masks].unsqueeze(0) 390 | hid = h 391 | res_hidden.append(hid) 392 | return res_hidden 393 | 394 | @staticmethod 395 | def hidden_detach(hidden_state): 396 | res_hidden = [] 397 | len_hidden = len(hidden_state) 398 | for i in range(len_hidden): 399 | res_hidden.append(hidden_state[i].detach()) 400 | return res_hidden 401 | 402 | def _test_forward_time(self, device, num=1000, batch_size=1): 403 | 404 | h = self.make_init_state(batch_size, device) 405 | start_time = time.time() 406 | action_tensor = torch.randn((batch_size, 1, self.act_dim), device=device) 407 | obs_tensor = torch.randn((batch_size, 1, self.obs_dim), device=device) 408 | 409 | for i in range(num): 410 | self.forward(obs_tensor, action_tensor, h=h) 411 | end_time = time.time() 412 | print('pure forward time: {}, pure meta forward time: {}'.format(self.ep.cumulative_forward_time + self.up.cumulative_forward_time, 413 | self.ep.cumulative_meta_forward_time + self.up.cumulative_meta_forward_time)) 414 | print('running for {} times, spending time is {:.2f}, batch size is {}, device: {}'.format(num, end_time - start_time, batch_size, device)) 415 | self.ep.cumulative_forward_time = 0 416 | self.up.cumulative_forward_time = 0 417 | self.ep.cumulative_meta_forward_time = 0 418 | self.up.cumulative_meta_forward_time = 0 419 | 420 | def _test_inference_time(self, device, num=1000, batch_size=1): 421 | if not self.inference_check_hidden(1): 422 | self.inference_init_hidden(1, device) 423 | # h = self.make_init_state(batch_size, device) 424 | start_time = time.time() 425 | obs_act = torch.randn((1, self.obs_dim+self.act_dim), device=device) 426 | for i in range(num): 427 | self.inference_one_step(obs_act) 428 | # self.forward(torch.randn((batch_size, 1, self.obs_dim), device=device), torch.randn((batch_size, 1, self.act_dim), device=device), h=h) 429 | end_time = time.time() 430 | print('pure forward time: {}, pure meta forward time: {}'.format( 431 | self.ep.cumulative_forward_time + self.up.cumulative_forward_time, 432 | self.ep.cumulative_meta_forward_time + self.up.cumulative_meta_forward_time)) 433 | # print('pure forward time: {}'.format(self.ep.cumulative_forward_time + self.up.cumulative_forward_time)) 434 | print('running for {} times, spending time is {:.2f}, batch size is {}, device: {}'.format(num, end_time - start_time, batch_size, device)) 435 | self.ep.cumulative_forward_time = 0 436 | self.up.cumulative_forward_time = 0 437 | self.ep.cumulative_meta_forward_time = 0 438 | self.up.cumulative_meta_forward_time = 0 439 | 
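# A minimal, hedged usage sketch (not part of the original file): it mirrors the
# _test_forward_time / inference patterns above. Layer sizes, activations and layer types
# are assumed from the defaults in parameter/Parameter.py; obs_dim/act_dim are arbitrary,
# and shapes may need adjustment depending on RNNBase's expected input layout.
if __name__ == '__main__':
    obs_dim, act_dim, ep_dim = 11, 3, 2
    policy = Policy(obs_dim, act_dim,
                    up_hidden_size=[128, 64],
                    up_activations=['leaky_relu', 'leaky_relu', 'linear'],
                    up_layer_type=['fc', 'fc', 'fc'],
                    ep_hidden_size=[128, 64],
                    ep_activation=['leaky_relu', 'linear', 'tanh'],
                    ep_layer_type=['fc', 'gru', 'fc'],
                    ep_dim=ep_dim,
                    use_gt_env_feature=False,
                    rnn_fix_length=0,
                    use_rmdm=False,
                    share_ep=False)
    h = policy.make_init_state(batch_size=1, device=torch.device('cpu'))
    obs = torch.randn((1, 1, obs_dim))          # (batch, time, obs_dim)
    lst_act = torch.zeros((1, 1, act_dim))      # previous action; zeros at episode start
    mu, std, act, logp, h = policy.rsample(obs, lst_act, h)
    print(act.shape, logp.shape)                # tanh-squashed action and its log-probability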
-------------------------------------------------------------------------------- /agent/Agent.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import gym 4 | import ray 5 | import torch 6 | import numpy as np 7 | import random 8 | import math 9 | 10 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | from models.policy import Policy 12 | from utils.replay_memory import Memory, MemoryNp 13 | from utils.history_construct import HistoryConstructor 14 | from parameter.Parameter import Parameter 15 | from parameter.private_config import SKIP_MAX_LEN_DONE, NON_STATIONARY_PERIOD, NON_STATIONARY_INTERVAL 16 | from log_util.logger import Logger 17 | from envs.grid_world import RandomGridWorld 18 | from envs.grid_world_general import RandomGridWorldPlat 19 | from parameter.private_config import ENV_DEFAULT_CHANGE 20 | 21 | 22 | class EnvWorker: 23 | def __init__(self, parameter: Parameter, env_name='Hopper-v2', seed=0, policy_type=Policy, 24 | history_len=0, env_decoration=None, env_tasks=None, 25 | use_true_parameter=False, non_stationary=False): 26 | self.env = gym.make(env_name) 27 | self.action_space = self.env.action_space 28 | self.fix_env_setting = False 29 | self.set_global_seed(seed) 30 | self.use_true_parameter = use_true_parameter 31 | self.non_stationary = non_stationary 32 | if env_decoration is not None: 33 | default_change_range = ENV_DEFAULT_CHANGE if not hasattr(parameter, 'env_default_change_range') \ 34 | else parameter.env_default_change_range 35 | if not hasattr(parameter, 'env_default_change_range'): 36 | print('[WARN]: env_default_change_range does not appears in parameter!') 37 | self.env = env_decoration(self.env, log_scale_limit=default_change_range, 38 | rand_params=parameter.varying_params) 39 | self.observation_space = self.env.observation_space 40 | 41 | self.history_constructor = HistoryConstructor(history_len, self.observation_space.shape[0], 42 | self.action_space.shape[0], need_lst_action=True) 43 | self.env_tasks = None 44 | self.task_ind = -1 45 | self.env.reset() 46 | if env_tasks is not None and isinstance(env_tasks, list) and len(env_tasks) > 0: 47 | self.env_tasks = env_tasks 48 | self.task_ind = random.randint(0, len(self.env_tasks) - 1) 49 | self.env.set_task(self.env_tasks[self.task_ind]) 50 | policy_config = Policy.make_config_from_param(parameter) 51 | if use_true_parameter: 52 | policy_config['ep_dim'] = self.env.env_parameter_length 53 | self.policy = policy_type(obs_dim=self.observation_space.shape[0], 54 | act_dim=self.action_space.shape[0], 55 | **policy_config) 56 | self.policy.inference_init_hidden(1) 57 | self.bottle_neck = parameter.bottle_neck 58 | self.ep_len = 0 59 | self.ep_cumrew = 0 60 | self.history_len = history_len 61 | self.ep_len_list = [] 62 | self.ep_cumrew_list = [] 63 | self.ep_rew_list = [] 64 | self.state = self.reset(None) 65 | self.state = self.extend_state(self.state) 66 | 67 | def set_weight(self, state_dict): 68 | self.policy.load_state_dict(state_dict) 69 | 70 | def set_global_seed(self, seed): 71 | import numpy as np 72 | import torch 73 | import random 74 | self.env.seed(seed) 75 | np.random.seed(seed) 76 | torch.manual_seed(seed) 77 | random.seed(seed) 78 | self.action_space.seed(seed) 79 | 80 | def get_weight(self): 81 | return self.policy.state_dict() 82 | 83 | def set_fix_env_setting(self, fix_env_setting=True): 84 | self.fix_env_setting = fix_env_setting 85 | 86 | def change_env_param(self, set_env_ind=None): 87 | if 
self.fix_env_setting: 88 | self.env.set_task(self.env_tasks[self.task_ind]) 89 | return 90 | if self.env_tasks is not None and len(self.env_tasks) > 0: 91 | self.task_ind = random.randint(0, len(self.env_tasks) - 1) if set_env_ind is None or set_env_ind >= \ 92 | len(self.env_tasks) else set_env_ind 93 | self.env.set_task(self.env_tasks[self.task_ind]) 94 | if self.non_stationary: 95 | another_task = random.randint(0, len(self.env_tasks) - 1) 96 | env_param_list = [self.env_tasks[self.task_ind]] + [self.env_tasks[random.randint(0, len(self.env_tasks)-1)] for _ in range(15)] 97 | self.env.set_nonstationary_para(env_param_list, 98 | NON_STATIONARY_PERIOD, NON_STATIONARY_INTERVAL) 99 | 100 | def sample(self, min_batch, deterministic=False, env_ind=None): 101 | step_count = 0 102 | mem = Memory() 103 | log = {'EpRet': [], 104 | 'EpMeanRew': [], 105 | 'EpLen': []} 106 | if deterministic and self.bottle_neck: 107 | self.policy.set_deterministic_ep(True) 108 | elif self.bottle_neck and not deterministic: 109 | self.policy.set_deterministic_ep(False) 110 | with torch.no_grad(): 111 | while step_count < min_batch: 112 | state = self.reset(env_ind) 113 | state = self.extend_state(state) 114 | self.policy.inference_init_hidden(1) 115 | while True: 116 | state_tensor = torch.from_numpy(state).to(torch.get_default_dtype()).unsqueeze(0) 117 | action_tensor = self.policy.inference_one_step(state_tensor, deterministic)[0] 118 | action = action_tensor.numpy() 119 | self.before_apply_action(action) 120 | next_state, reward, done, _ = self.env.step(self.env.denormalization(action)) 121 | if self.non_stationary: 122 | self.env_param_vector = self.env.env_parameter_vector 123 | next_state = self.extend_state(next_state) 124 | mask = 0.0 if done else 1.0 125 | if SKIP_MAX_LEN_DONE and done and self.env._elapsed_steps >= self.env._max_episode_steps: 126 | mask = 1.0 127 | mem.push(state[self.action_space.shape[0]:], action.astype(state.dtype), [mask], 128 | next_state[self.action_space.shape[0]:], [reward], 129 | None, [self.task_ind + 1], self.env_param_vector, 130 | state[:self.action_space.shape[0]], [done], [1]) 131 | self.ep_cumrew += reward 132 | self.ep_len += 1 133 | step_count += 1 134 | if done: 135 | log['EpMeanRew'].append(self.ep_cumrew / self.ep_len) 136 | log['EpLen'].append(self.ep_len) 137 | log['EpRet'].append(self.ep_cumrew) 138 | break 139 | state = next_state 140 | mem.memory = [mem.memory[0]] 141 | return mem, log 142 | 143 | def get_current_state(self): 144 | return self.state 145 | 146 | def extend_state(self, state): 147 | state = self.history_constructor(state) 148 | if self.use_true_parameter: 149 | state = np.hstack([state, self.env_param_vector]) 150 | return state 151 | 152 | def before_apply_action(self, action): 153 | self.history_constructor.update_action(action) 154 | 155 | def reset(self, env_ind=None): 156 | self.history_constructor.reset() 157 | self.change_env_param(env_ind) 158 | state = self.env.reset() 159 | self.env_param_vector = self.env.env_parameter_vector 160 | self.ep_len = 0 161 | self.ep_cumrew = 0 162 | return state 163 | 164 | def step(self, action, env_ind=None, render=False, need_info=False): 165 | self.before_apply_action(action) 166 | next_state, reward, done, info = self.env.step(self.env.denormalization(action)) 167 | if render: 168 | self.env.render() 169 | if self.non_stationary: 170 | self.env_param_vector = self.env.env_parameter_vector 171 | current_env_step = self.env._elapsed_steps 172 | self.state = next_state = self.extend_state(next_state) 173 | 
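        # Episode bookkeeping: on done the worker auto-resets below, so self.state becomes the
        # first (extended) observation of the new episode, while the returned
        # next_state/reward/done still describe the transition that just finished.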
self.ep_len += 1 174 | self.ep_cumrew += reward 175 | cur_task_ind = self.task_ind 176 | cur_env_param = self.env_param_vector 177 | if done: 178 | self.ep_len_list.append(self.ep_len) 179 | self.ep_cumrew_list.append(self.ep_cumrew) 180 | self.ep_rew_list.append(self.ep_cumrew / self.ep_len) 181 | state = self.reset(env_ind) 182 | self.state = self.extend_state(state) 183 | if need_info: 184 | return next_state, reward, done, self.state, cur_task_ind, cur_env_param, current_env_step, info 185 | return next_state, reward, done, self.state, cur_task_ind, cur_env_param, current_env_step 186 | 187 | def collect_result(self): 188 | ep_len_list = self.ep_len_list 189 | self.ep_len_list = [] 190 | ep_cumrew_list = self.ep_cumrew_list 191 | self.ep_cumrew_list = [] 192 | ep_rew_list = self.ep_rew_list 193 | self.ep_rew_list = [] 194 | log = { 195 | 'EpMeanRew': ep_rew_list, 196 | 'EpLen': ep_len_list, 197 | 'EpRet': ep_cumrew_list 198 | } 199 | return log 200 | 201 | 202 | class EnvRemoteArray: 203 | def __init__(self, parameter, env_name, worker_num=2, seed=None, 204 | deterministic=False, use_remote=True, 205 | policy_type=Policy, history_len=0, env_decoration=None, 206 | env_tasks=None, use_true_parameter=False, non_stationary=False): 207 | self.env = gym.make(env_name) 208 | self.obs_dim = self.env.observation_space.shape[0] 209 | self.act_dim = self.env.action_space.shape[0] 210 | self.action_space = self.env.action_space 211 | self.set_seed(seed) 212 | self.non_stationary = non_stationary 213 | self.env_tasks = env_tasks 214 | # if worker_num == 1: 215 | # use_remote = False 216 | RemoteEnvWorker = ray.remote(EnvWorker) if use_remote else EnvWorker 217 | if use_remote: 218 | self.workers = [RemoteEnvWorker.remote(parameter, env_name, random.randint(0, 10000), 219 | policy_type, history_len, env_decoration, env_tasks, 220 | use_true_parameter, non_stationary) for _ in range(worker_num)] 221 | else: 222 | self.workers = [RemoteEnvWorker(parameter, env_name, random.randint(0, 10000), 223 | policy_type, history_len, env_decoration, env_tasks, 224 | use_true_parameter, non_stationary) for _ in range(worker_num)] 225 | 226 | if env_decoration is not None: 227 | default_change_range = ENV_DEFAULT_CHANGE if not hasattr(parameter, 'env_default_change_range') \ 228 | else parameter.env_default_change_range 229 | if not hasattr(parameter, 'env_default_change_range'): 230 | print('[WARN]: env_default_change_range does not appears in parameter!') 231 | self.env = env_decoration(self.env, log_scale_limit=default_change_range, 232 | rand_params=parameter.varying_params) 233 | net_config = Policy.make_config_from_param(parameter) 234 | self.policy = Policy(self.env.observation_space.shape[0], self.env.action_space.shape[0], **net_config) 235 | self.worker_num = worker_num 236 | self.env_name = env_name 237 | 238 | self.env.reset() 239 | if isinstance(env_tasks, list) and len(env_tasks) > 0: 240 | self.env.set_task(random.choice(env_tasks)) 241 | self.env_parameter_len = self.env.env_parameter_length 242 | self.running_state = None 243 | self.deterministic = deterministic 244 | self.use_remote = use_remote 245 | self.total_steps = 0 246 | 247 | def set_seed(self, seed): 248 | if seed is None: 249 | return 250 | import numpy as np 251 | import torch 252 | import random 253 | np.random.seed(seed) 254 | torch.manual_seed(seed) 255 | random.seed(seed) 256 | self.action_space.seed(seed) 257 | self.env.seed(seed) 258 | 259 | def set_fix_env_setting(self, fix_env_setting=True): 260 | if self.use_remote: 261 | 
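            # With use_remote=True the workers are ray actors, so calls go through
            # worker.<method>.remote(...) wrapped in ray.get(...); otherwise the same
            # EnvWorker methods are invoked directly in-process.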
ray.get([worker.set_fix_env_setting.remote(fix_env_setting) for worker in self.workers]) 262 | else: 263 | for worker in self.workers: 264 | worker.set_fix_env_setting(fix_env_setting) 265 | 266 | def submit_task(self, min_batch, policy=None, env_ind=None): 267 | assert not (policy is None and self.policy is None) 268 | cur_policy = policy if policy is not None else self.policy 269 | ray.get([worker.set_weight.remote(cur_policy.state_dict()) for worker in self.workers]) 270 | 271 | min_batch_per_worker = min_batch // self.worker_num + 1 272 | futures = [worker.sample.remote(min_batch_per_worker, self.deterministic, env_ind) 273 | for worker in self.workers] 274 | return futures 275 | 276 | def query_sample(self, futures, need_memory): 277 | mem_list_pre = ray.get(futures) 278 | mem = Memory() 279 | [mem.append(new_mem) for new_mem, _ in mem_list_pre] 280 | logs = {key: [] for key in mem_list_pre[0][1]} 281 | for key in logs: 282 | for _, item in mem_list_pre: 283 | logs[key] += item[key] 284 | logs['TotalSteps'] = len(mem) 285 | batch = self.extract_from_memory(mem) 286 | if need_memory: 287 | return batch, logs, mem 288 | return batch, logs 289 | 290 | # always use remote 291 | def sample(self, min_batch, policy=None, env_ind=None): 292 | assert not (policy is None and self.policy is None) 293 | cur_policy = policy if policy is not None else self.policy 294 | ray.get([worker.set_weight.remote(cur_policy.state_dict()) for worker in self.workers]) 295 | 296 | min_batch_per_worker = min_batch // self.worker_num + 1 297 | futures = [worker.sample.remote(min_batch_per_worker, self.deterministic, env_ind) 298 | for worker in self.workers] 299 | mem_list_pre = ray.get(futures) 300 | mem = Memory() 301 | [mem.append(new_mem) for new_mem, _ in mem_list_pre] 302 | logs = {key: [] for key in mem_list_pre[0][1]} 303 | for key in logs: 304 | for _, item in mem_list_pre: 305 | logs[key] += item[key] 306 | logs['TotalSteps'] = len(mem) 307 | return mem, logs 308 | 309 | def sample_locally(self, min_batch, policy=None, env_ind=None): 310 | assert not (policy is None and self.policy is None) 311 | cur_policy = policy if policy is not None else self.policy 312 | for worker in self.workers: 313 | worker.set_weight(cur_policy.state_dict()) 314 | min_batch_per_worker = min_batch // self.worker_num + 1 315 | mem_list_pre = [worker.sample(min_batch_per_worker, self.deterministic, env_ind) 316 | for worker in self.workers] 317 | mem = Memory() 318 | [mem.append(new_mem) for new_mem, _ in mem_list_pre] 319 | logs = {key: [] for key in mem_list_pre[0][1]} 320 | for key in logs: 321 | for _, item in mem_list_pre: 322 | logs[key] += item[key] 323 | logs['TotalSteps'] = len(mem) 324 | return mem, logs 325 | 326 | def sample1step(self, policy=None, random=False, device=torch.device('cpu'), env_ind=None): 327 | assert not (policy is None and self.policy is None) 328 | cur_policy = policy if policy is not None else self.policy 329 | if (not self.use_remote) and self.worker_num == 1: 330 | return self.sample1step1env(policy, random, device, env_ind) 331 | if not cur_policy.inference_check_hidden(self.worker_num): 332 | cur_policy.inference_init_hidden(self.worker_num, device) 333 | if self.use_remote: 334 | states = ray.get([worker.get_current_state.remote() for worker in self.workers]) 335 | else: 336 | states = [worker.get_current_state() for worker in self.workers] 337 | 338 | states = np.array(states) 339 | with torch.no_grad(): 340 | if random: 341 | actions = [self.env.normalization(self.action_space.sample()) for 
item in states] 342 | else: 343 | states_tensor = torch.from_numpy(states).to(torch.get_default_dtype()).to(device).unsqueeze(1) 344 | actions = cur_policy.inference_one_step(states_tensor, self.deterministic).to(torch.device('cpu')).squeeze(1).numpy() 345 | 346 | if self.use_remote: 347 | srd = ray.get([worker.step.remote(action, env_ind) for action, worker in zip(actions, self.workers)]) 348 | else: 349 | srd = [worker.step(action) for action, worker in zip(actions, self.workers)] 350 | 351 | mem = Memory() 352 | for ind, (next_state, reward, done, _, task_ind, env_param, current_steps) in enumerate(srd): 353 | if done: 354 | cur_policy.inference_reset_one_hidden(ind) 355 | mask = 0.0 if done else 1.0 356 | if SKIP_MAX_LEN_DONE and done and current_steps >= self.env._max_episode_steps: 357 | mask = 1.0 358 | mem.push(states[ind, self.action_space.shape[0]:], actions[ind].astype(states.dtype), [mask], 359 | next_state[self.action_space.shape[0]:], [reward], None, [task_ind + 1], 360 | env_param, states[ind, :self.action_space.shape[0]], [done], [1]) 361 | if self.use_remote: 362 | logs_ = ray.get([worker.collect_result.remote() for worker in self.workers]) 363 | else: 364 | logs_ = [worker.collect_result() for worker in self.workers] 365 | logs = {key: [] for key in logs_[0]} 366 | for key in logs: 367 | for item in logs_: 368 | logs[key] += item[key] 369 | logs['TotalSteps'] = len(mem) 370 | return mem, logs 371 | 372 | def get_action(self, state, cur_policy, random, device=torch.device("cpu")): 373 | with torch.no_grad(): 374 | if random: 375 | action = self.env.normalization(self.action_space.sample()) 376 | else: 377 | action = cur_policy.inference_one_step(torch.from_numpy(state[None]).to(device=device, 378 | dtype=torch.get_default_dtype()), 379 | self.deterministic)[0].to( 380 | torch.device('cpu')).numpy() 381 | return action 382 | 383 | def sample1step1env(self, policy, random=False, device=torch.device('cpu'), env_ind=None, render=False, need_info=False): 384 | if not policy.inference_check_hidden(1): 385 | policy.inference_init_hidden(1, device) 386 | cur_policy = policy 387 | worker = self.workers[0] 388 | state = worker.get_current_state() 389 | action = self.get_action(state, cur_policy, random, device) 390 | if need_info: 391 | next_state, reward, done, _, task_ind, env_param, current_steps, info = worker.step(action, env_ind, render, need_info=True) 392 | else: 393 | next_state, reward, done, _, task_ind, env_param, current_steps = worker.step(action, env_ind, render, need_info=False) 394 | 395 | if done: 396 | policy.inference_init_hidden(1, device) 397 | mem = Memory() 398 | mask = 0.0 if done else 1.0 399 | if SKIP_MAX_LEN_DONE and done and current_steps >= self.env._max_episode_steps: 400 | mask = 1.0 401 | mem.push(state[self.act_dim:], action.astype(state.dtype), [mask], 402 | next_state[self.action_space.shape[0]:], [reward], None, 403 | [task_ind + 1], env_param, state[:self.act_dim], [done], [1]) 404 | logs = worker.collect_result() 405 | self.total_steps += 1 406 | logs['TotalSteps'] = self.total_steps 407 | if need_info: 408 | return mem, logs, info 409 | return mem, logs 410 | 411 | def collect_samples(self, min_batch, policy=None, need_memory=False): 412 | for i in range(10): 413 | try: 414 | mem, logs = self.sample(min_batch, policy) 415 | break 416 | except Exception as e: 417 | print(f'Error occurs while sampling, the error is {e}, tried time: {i}') 418 | batch = self.extract_from_memory(mem) 419 | if need_memory: 420 | return batch, logs, mem 421 | return 
batch, logs 422 | 423 | def update_running_state(self, state): 424 | pass 425 | 426 | @staticmethod 427 | def extract_from_memory(mem): 428 | batch = mem.sample() 429 | state, action, next_state, reward, mask = np.array(batch.state), np.array(batch.action), np.array( 430 | batch.next_state), np.array(batch.reward), np.array(batch.mask) 431 | res = {'state': state, 'action': action, 'next_state': next_state, 'reward': reward, 'mask': mask} 432 | return res 433 | 434 | def make_env_param_dict(self, parameter_name): 435 | res = {} 436 | if self.env_tasks is not None: 437 | for ind, item in enumerate(self.env_tasks): 438 | res[ind + 1] = item 439 | res_interprete = {} 440 | for k, v in res.items(): 441 | if isinstance(v, dict): 442 | res_interprete[k] = [v[parameter_name][-1]] 443 | elif isinstance(v, int): 444 | res_interprete[k] = v 445 | elif isinstance(v, list): 446 | res_interprete[k] = math.sqrt(sum([item**2 for item in v])) 447 | else: 448 | raise NotImplementedError(f'type({type(v)}) is not implemented.') 449 | return res_interprete 450 | 451 | def make_env_param_dict_from_params(self, params): 452 | res_interprete = {} 453 | for param in params: 454 | res_ = self.make_env_param_dict(param) 455 | for k, v in res_.items(): 456 | if k not in res_interprete: 457 | res_interprete[k] = v 458 | else: 459 | res_interprete[k] += v 460 | 461 | return res_interprete 462 | 463 | if __name__ == '__main__': 464 | from envs.nonstationary_env import NonstationaryEnv 465 | env_name = 'Hopper-v2' 466 | logger = Logger() 467 | parameter = logger.parameter 468 | env = NonstationaryEnv(gym.make(env_name), rand_params=parameter.varying_params) 469 | 470 | ray.init() 471 | worker_num = 1 472 | env_array = EnvRemoteArray(parameter=parameter, env_name=env_name, 473 | worker_num=worker_num, seed=None, 474 | policy_type=Policy, env_decoration=NonstationaryEnv, 475 | use_true_parameter=False, 476 | env_tasks=env.sample_tasks(10), history_len=0, use_remote=False) 477 | print(parameter.varying_params) 478 | paramdict = env_array.make_env_param_dict_from_params(parameter.varying_params) 479 | print(paramdict) 480 | configs = Policy.make_config_from_param(parameter) 481 | net = Policy(env.observation_space.shape[0], env.action_space.shape[0], **configs) 482 | device = torch.device('cuda', index=0) if torch.cuda.is_available() else torch.device('cpu') 483 | net.to(device) 484 | import time 485 | total_step = 0 486 | replay_buffer = Memory() 487 | for _ in range(100): 488 | t0 = time.time() 489 | batch, logs, mem = env_array.collect_samples(8000, net, need_memory=True) 490 | logger.add_tabular_data(**logs) 491 | replay_buffer.append(mem) 492 | 493 | # for i in range(4000): 494 | # mem, logs = env_array.sample1step(net, False, device) 495 | # replay_buffer.append(mem) 496 | # logger.add_tabular_data(**logs) 497 | 498 | total_step += 4000 * worker_num 499 | # transitions = env_array.extract_from_memory(memory) 500 | t1 = time.time() 501 | logger.log_tabular('TimeInterval (s)', t1-t0) 502 | logger.log_tabular('EnvSteps', total_step) 503 | logger.log_tabular('ReplayBufferSize', len(replay_buffer)) 504 | logger.dump_tabular() 505 | # print(logs) 506 | 507 | 508 | 509 | -------------------------------------------------------------------------------- /parameter/Parameter.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import json 4 | import socket 5 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 6 | from 
parameter.private_config import * 7 | from datetime import datetime 8 | 9 | class Parameter: 10 | def __init__(self, config_path=None, debug=False, information=None): 11 | self.base_path = self.get_base_path() 12 | self.debug = debug 13 | self.experiment_target = EXPERIMENT_TARGET 14 | self.DEFAULT_CONFIGS = global_configs() 15 | self.arg_names = [] 16 | self.host_name = 'localhost' 17 | self.ip = '127.0.0.1' 18 | self.exec_time = datetime.now().strftime('%Y-%m-%d_%H:%M:%S') 19 | self.commit_id = self.get_commit_id() 20 | self.log_func = None 21 | self.json_name = 'parameter.json' 22 | self.txt_name = 'parameter.txt' 23 | self.information = information 24 | if config_path: 25 | self.config_path = config_path 26 | else: 27 | self.info('use default config path') 28 | self.config_path = osp.join(get_base_path(), 'parameter') 29 | self.info(f'json path is {os.path.join(self.config_path, self.json_name)}, ' 30 | f'txt path is {os.path.join(self.config_path, self.txt_name)}') 31 | if config_path: 32 | self.load_config() 33 | else: 34 | self.args = self.parse() 35 | self.apply_vars(self.args) 36 | 37 | def set_log_func(self, log_func): 38 | self.log_func = log_func 39 | 40 | def info(self, info): 41 | if self.log_func is not None: 42 | self.log_func(info) 43 | else: 44 | print(info) 45 | 46 | @staticmethod 47 | def get_base_path(): 48 | return get_base_path() 49 | 50 | def set_config_path(self, config_path): 51 | self.config_path = config_path 52 | 53 | @staticmethod 54 | def important_configs(): 55 | res = ['env_name', 'use_true_parameter', 'use_rmdm', 56 | 'use_uposi', 'share_ep', "rnn_fix_length", 57 | 'uniform_sample', 'enhance_ep', 'bottle_neck', 'stop_pg_for_ep', "ep_dim", 'seed'] 58 | return res 59 | 60 | def apply_vars(self, args): 61 | for name in self.arg_names: 62 | setattr(self, name, getattr(args, name)) 63 | 64 | def make_dict(self): 65 | res = {} 66 | for name in self.arg_names: 67 | res[name] = getattr(self, name) 68 | res['description'] = self.experiment_target 69 | res['exec_time'] = self.exec_time 70 | res['commit_id'] = self.commit_id 71 | return res 72 | 73 | def parse(self): 74 | parser = argparse.ArgumentParser(description=EXPERIMENT_TARGET) 75 | 76 | self.env_name = "Hopper-v2" 77 | parser.add_argument('--env_name', default=self.env_name, metavar='G', 78 | help='name of the environment to run') 79 | self.register_param('env_name') 80 | 81 | self.model_path = "" 82 | parser.add_argument('--model_path', metavar='G', 83 | help='path of pre-trained model') 84 | self.register_param("model_path") 85 | 86 | self.render = False 87 | parser.add_argument('--render', action='store_true', default=self.render, 88 | help='render the environment') 89 | self.register_param("render") 90 | 91 | self.log_std = -1.5 92 | parser.add_argument('--log_std', type=float, default=self.log_std, metavar='G', 93 | help='log std for the policy (default: -0.0)') 94 | self.register_param('log_std') 95 | 96 | self.gamma = 0.99 97 | parser.add_argument('--gamma', type=float, default=self.gamma, metavar='G', 98 | help='discount factor (default: 0.99)') 99 | self.register_param('gamma') 100 | 101 | self.learning_rate = 3e-4 102 | parser.add_argument('--learning_rate', type=float, default=self.learning_rate, metavar='G', 103 | help='learning rate (default: 3e-4)') 104 | self.register_param('learning_rate') 105 | 106 | self.value_learning_rate = 1e-3 107 | parser.add_argument('--value_learning_rate', type=float, default=self.value_learning_rate, metavar='G', 108 | help='learning rate (default: 1e-3)') 109 | 
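        # Every hyperparameter follows the same three-step pattern: set the default as an
        # attribute, expose it through argparse, then register_param(...) records the name so
        # apply_vars() can copy parsed values back onto this object and make_dict() can serialize them.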
self.register_param('value_learning_rate') 110 | 111 | self.num_threads = 4 112 | parser.add_argument('--num_threads', type=int, default=self.num_threads, metavar='N', 113 | help='number of threads for agent (default: 1)') 114 | self.register_param('num_threads') 115 | 116 | self.seed = 1 117 | parser.add_argument('--seed', type=int, default=self.seed, metavar='N', 118 | help='random seed (default: 1)') 119 | self.register_param('seed') 120 | 121 | self.random_num = 4000 122 | parser.add_argument('--random_num', type=int, default=self.random_num, metavar='N', 123 | help='sample random_num fully random samples,') 124 | self.register_param('random_num') 125 | 126 | self.start_train_num = 20000 127 | parser.add_argument('--start_train_num', type=int, default=self.start_train_num, metavar='N', 128 | help='after reach start_train_num, training start') 129 | self.register_param('start_train_num') 130 | 131 | self.test_sample_num = 4000 132 | parser.add_argument('--test_sample_num', type=int, default=self.test_sample_num, metavar='N', 133 | help='sample num in test phase') 134 | self.register_param('test_sample_num') 135 | 136 | self.sac_update_time = 1000 137 | parser.add_argument('--sac_update_time', type=int, default=self.sac_update_time, metavar='N', 138 | help='update time after sampling a batch data') 139 | self.register_param('sac_update_time') 140 | 141 | self.sac_replay_size = 1e6 142 | parser.add_argument('--sac_replay_size', type=int, default=self.sac_replay_size, metavar='N', 143 | help='update time after sampling a batch data') 144 | self.register_param('sac_replay_size') 145 | 146 | self.min_batch_size = 1200 147 | parser.add_argument('--min_batch_size', type=int, default=self.min_batch_size, metavar='N', 148 | help='minimal sample number per iteration') 149 | self.register_param('min_batch_size') 150 | 151 | if FC_MODE: 152 | self.sac_mini_batch_size = 256 153 | parser.add_argument('--sac_mini_batch_size', type=int, default=self.sac_mini_batch_size, metavar='N', 154 | help='update time after sampling a batch data') 155 | self.register_param('sac_mini_batch_size') 156 | 157 | self.sac_inner_iter_num = 1 158 | parser.add_argument('--sac_inner_iter_num', type=int, default=self.sac_inner_iter_num, metavar='N', 159 | help='after sample several trajectories from replay buffer, ' 160 | 'sac_inner_iter_num mini-batch will be sampled from the batch, ' 161 | 'and model will be optimized for sac_inner_iter_num times.') 162 | self.register_param('sac_inner_iter_num') 163 | else: 164 | self.sac_mini_batch_size = 256 165 | parser.add_argument('--sac_mini_batch_size', type=int, default=self.sac_mini_batch_size, metavar='N', 166 | help='sac_mini_batch_size trajectories will be sampled from the replay buffer.') 167 | self.register_param('sac_mini_batch_size') 168 | 169 | self.sac_inner_iter_num = 1 170 | parser.add_argument('--sac_inner_iter_num', type=int, default=self.sac_inner_iter_num, metavar='N', 171 | help='after sample several trajectories from replay buffer, ' 172 | 'sac_inner_iter_num mini-batch will be sampled from the batch, ' 173 | 'and model will be optimized for sac_inner_iter_num times.') 174 | self.register_param('sac_inner_iter_num') 175 | 176 | self.rnn_sample_max_batch_size = 3e5 177 | parser.add_argument('--rnn_sample_max_batch_size', type=int, default=self.rnn_sample_max_batch_size, metavar='N', 178 | help='max point num sampled from replay buffer per time') 179 | self.register_param('rnn_sample_max_batch_size') 180 | 181 | self.rnn_slice_num = 16 182 | 
parser.add_argument('--rnn_slice_num', type=int, default=self.rnn_slice_num, metavar='N', 183 | help='gradient clip steps') 184 | self.register_param('rnn_slice_num') 185 | 186 | self.sac_tau = 0.995 187 | parser.add_argument('--sac_tau', type=float, default=self.sac_tau, metavar='N', 188 | help='ratio of coping value net to target value net') 189 | self.register_param('sac_tau') 190 | 191 | self.sac_alpha = 0.2 192 | parser.add_argument('--sac_alpha', type=float, default=self.sac_alpha, metavar='N', 193 | help='sac temperature coefficient') 194 | self.register_param('sac_alpha') 195 | 196 | self.reward_scale = 1.0 197 | parser.add_argument('--reward_scale', type=float, default=self.reward_scale, metavar='N', 198 | help='sac temperature coefficient') 199 | self.register_param('reward_scale') 200 | 201 | self.max_iter_num = 10000 202 | parser.add_argument('--max_iter_num', type=int, default=self.max_iter_num, metavar='N', 203 | help='maximal number of main iterations (default: 500)') 204 | self.register_param('max_iter_num') 205 | 206 | self.save_model_interval = 5 207 | parser.add_argument('--save_model_interval', type=int, default=self.save_model_interval, metavar='N', 208 | help="interval between saving model (default: 5, means don't save)") 209 | self.register_param('save_model_interval') 210 | 211 | self.std_learnable = 1 212 | parser.add_argument('--std_learnable', type=int, default=self.std_learnable, metavar='N', 213 | help="standard dev can be learned") 214 | self.register_param('std_learnable') 215 | 216 | self.update_interval = 1 217 | parser.add_argument('--update_interval', type=int, default=self.update_interval, metavar='N', 218 | help="standard dev can be learned") 219 | self.register_param('update_interval') 220 | 221 | self.ep_pretrain_path_suffix = 'None' # '-use_rmdm-rnn_len_32-ep_dim_2-1-debug' 222 | parser.add_argument('--ep_pretrain_path_suffix', type=str, default=self.ep_pretrain_path_suffix, metavar='N', 223 | help="environment probing pretrain model path") 224 | self.register_param('ep_pretrain_path_suffix') 225 | 226 | self.name_suffix = 'None' # '-use_rmdm-rnn_len_32-ep_dim_2-1-debug' 227 | parser.add_argument('--name_suffix', type=str, default=self.name_suffix, metavar='N', 228 | help="name suffix of the experiment") 229 | self.register_param('name_suffix') 230 | 231 | self.ep_apply_tau = 0.99 232 | parser.add_argument('--ep_apply_tau', type=float, default=self.ep_apply_tau, metavar='N', 233 | help="tau used to apply ep") 234 | self.register_param('ep_apply_tau') 235 | 236 | self.target_entropy_ratio = 1.5 237 | parser.add_argument('--target_entropy_ratio', type=float, default=self.target_entropy_ratio, metavar='N', 238 | help="target entropy") 239 | self.register_param('target_entropy_ratio') 240 | 241 | self.history_length = 0 242 | parser.add_argument('--history_length', type=int, default=self.history_length, metavar='N', 243 | help="interval between saving model (default: 0, means don't save)") 244 | self.register_param('history_length') 245 | 246 | self.task_num = 0 247 | parser.add_argument('--task_num', type=int, default=self.task_num, metavar='N', 248 | help="interval between saving model (default: 0, means don't save)") 249 | self.register_param('task_num') 250 | 251 | self.test_task_num = 0 252 | parser.add_argument('--test_task_num', type=int, default=self.test_task_num, metavar='N', 253 | help="number of tasks for testing") 254 | self.register_param('test_task_num') 255 | 256 | self.use_true_parameter = False 257 | 
parser.add_argument('--use_true_parameter', action='store_true') 258 | self.register_param("use_true_parameter") 259 | 260 | self.bottle_neck = False 261 | parser.add_argument('--bottle_neck', action='store_true') 262 | self.register_param("bottle_neck") 263 | 264 | self.transition_learn_aux = False 265 | parser.add_argument('--transition_learn_aux', action='store_true') 266 | self.register_param("transition_learn_aux") 267 | 268 | self.bottle_sigma = 1e-2 269 | parser.add_argument('--bottle_sigma', type=float, default=self.bottle_sigma, metavar='N', 270 | help="std of the noise injected to ep while inference (information bottleneck)") 271 | self.register_param('bottle_sigma') 272 | 273 | self.l2_norm_for_ep = 0.0 274 | parser.add_argument('--l2_norm_for_ep', type=float, default=self.l2_norm_for_ep, metavar='N', 275 | help="L2 norm added to EP module") 276 | self.register_param('l2_norm_for_ep') 277 | 278 | self.policy_max_gradient = 10 279 | parser.add_argument('--policy_max_gradient', type=float, default=self.policy_max_gradient, metavar='N', 280 | help="maximum gradient of policy") 281 | self.register_param('policy_max_gradient') 282 | 283 | 284 | self.use_rmdm = False 285 | parser.add_argument('--use_rmdm', action='store_true', 286 | help="use Relational Matrix Determinant Maximization or not") 287 | self.register_param('use_rmdm') 288 | 289 | self.use_uposi = False 290 | parser.add_argument('--use_uposi', action='store_true') 291 | self.register_param('use_uposi') 292 | 293 | self.uniform_sample = False 294 | parser.add_argument('--uniform_sample', action='store_true') 295 | self.register_param('uniform_sample') 296 | 297 | self.share_ep = False 298 | parser.add_argument('--share_ep', action='store_true') 299 | self.register_param('share_ep') 300 | 301 | self.enhance_ep = False 302 | parser.add_argument('--enhance_ep', action='store_true') 303 | self.register_param('enhance_ep') 304 | 305 | self.stop_pg_for_ep = False 306 | parser.add_argument('--stop_pg_for_ep', action='store_true') 307 | self.register_param('stop_pg_for_ep') 308 | 309 | self.use_contrastive = False 310 | parser.add_argument('--use_contrastive', action='store_true') 311 | self.register_param('use_contrastive') 312 | 313 | self.rmdm_update_interval = -1 314 | parser.add_argument('--rmdm_update_interval', type=int, default=self.rmdm_update_interval, metavar='N', 315 | help="update interval of rmdm") 316 | self.register_param('rmdm_update_interval') 317 | 318 | self.rnn_fix_length = 0 319 | parser.add_argument('--rnn_fix_length', type=int, default=self.rnn_fix_length, metavar='N', 320 | help="fix the rnn memory length to rnn_fix_length") 321 | self.register_param('rnn_fix_length') 322 | 323 | 324 | self.minimal_repre_rp_size = 1e5 325 | parser.add_argument('--minimal_repre_rp_size', type=float, default=self.minimal_repre_rp_size, metavar='N', 326 | help="after minimal_repre_rp_size, start training EP module") 327 | self.register_param('minimal_repre_rp_size') 328 | 329 | 330 | # self.ep_start_num = 150000 331 | self.ep_start_num = 0 332 | parser.add_argument('--ep_start_num', type=int, default=self.ep_start_num, metavar='N', 333 | help="only when the size of the replay buffer is larger than ep_start_num" 334 | ", ep can be learned") 335 | self.register_param('ep_start_num') 336 | 337 | self.kernel_type = 'rbf_element_wise' 338 | parser.add_argument('--kernel_type', default=self.kernel_type, metavar='G', 339 | help='kernel type for DPP loss computing (rbf/rbf_element_wise/inner)') 340 | self.register_param('kernel_type') 
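        # kernel_type (and rbf_radius / diversity_loss_weight further below) are presumably
        # consumed by algorithms/RMDM.py: the chosen kernel builds the similarity matrix whose
        # determinant-style (DPP) diversity term is weighted by diversity_loss_weight.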
341 | 342 | self.rmdm_ratio = 1.0 343 | parser.add_argument('--rmdm_ratio', type=float, default=self.rmdm_ratio, metavar='N', 344 | help="gradient ratio of rmdm") 345 | self.register_param('rmdm_ratio') 346 | 347 | self.test_variable = 1.0 348 | parser.add_argument('--test_variable', type=float, default=self.test_variable, metavar='N', 349 | help="variable for testing variable") 350 | self.register_param('test_variable') 351 | 352 | self.rmdm_tau = 0.995 353 | parser.add_argument('--rmdm_tau', type=float, default=self.rmdm_tau, metavar='N', 354 | help="smoothing ratio of the representation") 355 | self.register_param('rmdm_tau') 356 | 357 | self.repre_loss_factor = 1.0 358 | parser.add_argument('--repre_loss_factor', type=float, default=self.repre_loss_factor, metavar='N', 359 | help="size of the representation loss") 360 | self.register_param('repre_loss_factor') 361 | 362 | self.ep_smooth_factor = 0.0 363 | parser.add_argument('--ep_smooth_factor', type=float, default=self.ep_smooth_factor, metavar='N', 364 | help="smooth factor for ep module, 0.0 for apply concurrently") 365 | self.register_param('ep_smooth_factor') 366 | 367 | self.rbf_radius = 80.0 368 | parser.add_argument('--rbf_radius', type=float, default=self.rbf_radius, metavar='N', 369 | help="radius of the rbf kerel") 370 | self.register_param('rbf_radius') 371 | 372 | self.env_default_change_range = 3.0 373 | parser.add_argument('--env_default_change_range', type=float, default=self.env_default_change_range, metavar='N', 374 | help="environment default change range") 375 | self.register_param('env_default_change_range') 376 | 377 | self.env_ood_change_range = 4.0 378 | parser.add_argument('--env_ood_change_range', type=float, default=self.env_ood_change_range, 379 | metavar='N', 380 | help="environment OOD change range") 381 | self.register_param('env_ood_change_range') 382 | 383 | self.consistency_loss_weight = 50.0 384 | parser.add_argument('--consistency_loss_weight', type=float, default=self.consistency_loss_weight, metavar='N', 385 | help="loss ratio of the consistency loss") 386 | self.register_param('consistency_loss_weight') 387 | 388 | self.diversity_loss_weight = 0.025 389 | parser.add_argument('--diversity_loss_weight', type=float, default=self.diversity_loss_weight, metavar='N', 390 | help="loss ratio of the DPP loss") 391 | self.register_param('diversity_loss_weight') 392 | 393 | self.varying_params = ['gravity', 'body_mass'] 394 | parser.add_argument('--varying_params', nargs='+', type=str, default=self.varying_params) 395 | self.register_param('varying_params') 396 | 397 | self.up_hidden_size = [128, 64] 398 | parser.add_argument('--up_hidden_size', nargs='+', type=int, default=self.up_hidden_size, 399 | help="architecture of the hidden layers of Universe Policy") 400 | self.register_param('up_hidden_size') 401 | 402 | self.up_activations = ['leaky_relu', 'leaky_relu', 'linear'] 403 | parser.add_argument('--up_activations', nargs='+', type=str, 404 | default=self.up_activations, 405 | help="activation of each layer of Universe Policy") 406 | self.register_param('up_activations') 407 | 408 | self.up_layer_type = ['fc', 'fc', 'fc'] 409 | parser.add_argument('--up_layer_type', nargs='+', type=str, 410 | default=self.up_layer_type, 411 | help="net type of Universe Policy") 412 | self.register_param('up_layer_type') 413 | 414 | # self.ep_hidden_size = [128, 64, 32] # [128, 64, 32] 415 | self.ep_hidden_size = [128, 64] # [256, 128] 416 | parser.add_argument('--ep_hidden_size', nargs='+', type=int, 
default=self.ep_hidden_size, 417 | help="architecture of the hidden layers of Environment Probing Net") 418 | self.register_param('ep_hidden_size') 419 | 420 | if FC_MODE: 421 | self.ep_activations = ['leaky_relu', 'leaky_relu', 'leaky_relu', 'tanh'] 422 | parser.add_argument('--ep_activations', nargs='+', type=str, 423 | default=self.ep_activations, 424 | help="activation of each layer of Environment Probing Net") 425 | self.register_param('ep_activations') 426 | 427 | self.ep_layer_type = ['fc', 'fc', 'fc', 'fc'] 428 | parser.add_argument('--ep_layer_type', nargs='+', type=str, 429 | default=self.ep_layer_type, 430 | help="net type of Environment Probing Net") 431 | self.register_param('ep_layer_type') 432 | else: 433 | # original RNN architecture 434 | self.ep_activations = ['leaky_relu', 'linear', 'tanh'] 435 | parser.add_argument('--ep_activations', nargs='+', type=str, 436 | default=self.ep_activations, 437 | help="activation of each layer of Environment Probing Net") 438 | self.register_param('ep_activations') 439 | self.ep_layer_type = ['fc', 'gru', 'fc'] 440 | parser.add_argument('--ep_layer_type', nargs='+', type=str, 441 | default=self.ep_layer_type, 442 | help="net type of Environment Probing Net") 443 | self.register_param('ep_layer_type') 444 | # fc architecture 445 | # self.ep_activations = ['leaky_relu', 'tanh', 'tanh'] 446 | # parser.add_argument('--ep_activations', nargs='+', type=str, 447 | # default=self.ep_activations, 448 | # help="activation of each layer of Environment Probing Net") 449 | # self.register_param('ep_activations') 450 | # self.ep_layer_type = ['fc', 'fc', 'fc'] 451 | # parser.add_argument('--ep_layer_type', nargs='+', type=str, 452 | # default=self.ep_layer_type, 453 | # help="net type of Environment Probing Net") 454 | # self.register_param('ep_layer_type') 455 | 456 | self.ep_dim = 2 457 | parser.add_argument('--ep_dim', type=int, default=self.ep_dim, metavar='N', 458 | help="dimension of environment features") 459 | self.register_param('ep_dim') 460 | 461 | self.value_hidden_size = [128, 64] 462 | parser.add_argument('--value_hidden_size', nargs='+', type=int, default=self.value_hidden_size, 463 | help="architecture of the hidden layers of value") 464 | self.register_param('value_hidden_size') 465 | 466 | self.value_activations = ['leaky_relu', 'leaky_relu', 'linear'] 467 | parser.add_argument('--value_activations', nargs='+', type=str, 468 | default=self.value_activations, 469 | help="activation of each layer of value") 470 | self.register_param('value_activations') 471 | 472 | self.value_layer_type = ['fc', 'fc', 'fc'] 473 | parser.add_argument('--value_layer_type', nargs='+', type=str, 474 | default=self.value_layer_type, 475 | help="net type of value") 476 | self.register_param('value_layer_type') 477 | 478 | return parser.parse_args() 479 | 480 | def register_param(self, name): 481 | self.arg_names.append(name) 482 | 483 | def get_experiment_description(self): 484 | description = f"本机{self.host_name}, ip为{self.ip}\n" 485 | description += f"目前实验目的为{self.experiment_target}\n" 486 | description += f"实验简称: {self.short_name}\n" 487 | description += f"commit id: {self.commit_id}\n" 488 | vars = '' 489 | important_config = self.important_configs() 490 | for name in self.arg_names: 491 | if name in important_config: 492 | vars += f'**{name}**: {getattr(self, name)}\n' 493 | else: 494 | vars += f'{name}: {getattr(self, name)}\n' 495 | for name in self.DEFAULT_CONFIGS: 496 | vars += f'{name}: {self.DEFAULT_CONFIGS[name]}\n' 497 | return description + 
498 | 
499 |     def __str__(self):
500 |         return self.get_experiment_description()
501 | 
502 |     def clear_local_file(self):
503 |         cmd = f'rm -f {os.path.join(self.config_path, self.json_name)} {os.path.join(self.config_path, self.txt_name)}'
504 |         system(cmd)
505 | 
506 |     def save_config(self):
507 |         self.info(f'save json config to {os.path.join(self.config_path, self.json_name)}')
508 |         if not os.path.exists(self.config_path):
509 |             os.makedirs(self.config_path)
510 |         with open(os.path.join(self.config_path, self.json_name), 'w') as f:
511 |             things = self.make_dict()
512 |             ser = json.dumps(things)
513 |             f.write(ser)
514 |         self.info(f'save readable config to {os.path.join(self.config_path, self.txt_name)}')
515 |         with open(os.path.join(self.config_path, self.txt_name), 'w') as f:
516 |             print(self, file=f)
517 | 
518 |     def load_config(self):
519 |         self.info(f'load json config from {os.path.join(self.config_path, self.json_name)}')
520 |         with open(os.path.join(self.config_path, self.json_name), 'r') as f:
521 |             ser = json.load(f)
522 |         for k, v in ser.items():
523 |             if not k == 'description':
524 |                 setattr(self, k, v)
525 |                 self.register_param(k)
526 |         self.experiment_target = ser['description']
527 | 
528 |     @property
529 |     def differences(self):
530 |         if not os.path.exists(os.path.join(self.config_path, self.json_name)):
531 |             return None
532 |         with open(os.path.join(self.config_path, self.json_name), 'r') as f:
533 |             ser = json.load(f)
534 |         differences = []
535 |         for k, v in ser.items():
536 |             if not hasattr(self, k):
537 |                 differences.append(k)
538 |             else:
539 |                 v2 = getattr(self, k)
540 |                 if not v2 == v:
541 |                     differences.append(k)
542 |         return differences
543 | 
544 |     def check_identity(self, need_decription=False, need_exec_time=False):
545 |         if not os.path.exists(os.path.join(self.config_path, self.json_name)):
546 |             self.info(f'{os.path.join(self.config_path, self.json_name)} does not exist')
547 |             return False
548 |         with open(os.path.join(self.config_path, self.json_name), 'r') as f:
549 |             ser = json.load(f)
550 |         flag = True
551 |         for k, v in ser.items():
552 |             if not k == 'description' and not k == 'exec_time':
553 |                 if not hasattr(self, k):
554 |                     flag = False
555 |                     return flag
556 |                 v2 = getattr(self, k)
557 |                 if not v2 == v:
558 |                     flag = False
559 |                     return flag
560 |         if need_decription:
561 |             if not self.experiment_target == ser['description']:
562 |                 flag = False
563 |                 return flag
564 |         if need_exec_time:
565 |             if not self.exec_time == ser['exec_time']:
566 |                 flag = False
567 |                 return flag
568 |         return flag
569 | 
570 |     @property
571 |     def short_name(self):
572 |         name = ''
573 |         for item in self.important_configs():
574 |             value = getattr(self, item)
575 |             if value:
576 |                 if item == 'env_name':
577 |                     name += value
578 |                 elif item == 'seed':
579 |                     name += f'-{value}'
580 |                 elif item == 'rnn_fix_length':
581 |                     name += f"-rnn_len_{value}"
582 |                 elif item == 'ep_dim':
583 |                     name += f"-ep_dim_{value}"
584 |                 else:
585 |                     name += f'-{item}'
586 |         if self.debug:
587 |             name += '-debug'
588 |         if self.information is not None:
589 |             name += '-{}'.format(self.information)
590 |         if hasattr(self, 'name_suffix') and not self.name_suffix == 'None':
591 |             name += f'_{self.name_suffix}'
592 |         elif not len(SHORT_NAME_SUFFIX) == 0:
593 |             name += f'_{SHORT_NAME_SUFFIX}'
594 |         return name
595 | 
596 |     def get_commit_id(self):
597 |         base_path = get_base_path()
598 |         cmd = f'cd {base_path} && git log'
599 |         commit_id = None
600 |         try:
601 |             with os.popen(cmd) as f:
602 |                 line = f.readline()
603 |                 words = line.split(' ')
604 |                 commit_id = words[-1][:-1]  # first line of `git log` is "commit <hash>\n"; strip the newline
605 |         except Exception as e:
606 |             self.info(f'Error occurred while fetching commit id!!! {e}')
607 |         return commit_id
608 | 
609 | 
610 | if __name__ == '__main__':
611 |     parameter = Parameter()
612 |     parameter.get_commit_id()
613 |     print(parameter)
614 | 
615 | 
616 | 
--------------------------------------------------------------------------------
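
Usage note (editor's addition, not a file in the repository): every option above is stored as an attribute and registered with argparse, so the defaults can be overridden from the command line, persisted with save_config, and re-checked with check_identity. The sketch below is illustrative only; the script name and flag values are made up, and it assumes it is executed from the repository root so that the parameter package resolves.

# hypothetical_usage.py -- a minimal sketch under the assumptions above, not part of the repository
from parameter.Parameter import Parameter

if __name__ == '__main__':
    # e.g. python hypothetical_usage.py --ep_dim 4 --rbf_radius 60.0 --varying_params gravity
    param = Parameter()              # the __main__ block above constructs it with no arguments
    print(param.short_name)          # compact run identifier built from the important configs
    param.save_config()              # writes the JSON and human-readable config files
    print(param.check_identity())    # True when the saved JSON matches the current attributes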