├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── data ├── gridworld.10s10t │ ├── train.10.test.90.split │ ├── train.20.test.80.split │ ├── train.30.test.70.split │ ├── train.40.test.60.split │ ├── train.50.test.50.split │ ├── train.60.test.40.split │ ├── train.70.test.30.split │ └── train.80.test.20.split └── gridworld.20s20t │ ├── hard.split │ └── hard_extend.split ├── extend_gridworld.py ├── requirements.txt ├── synpo ├── __init__.py ├── agent │ ├── Grid_agent.py │ └── __init__.py ├── component │ ├── __init__.py │ ├── policy.py │ ├── replay.py │ └── task.py ├── network │ ├── __init__.py │ ├── base_network.py │ ├── grid_network.py │ └── operator.py └── utils │ ├── __init__.py │ ├── config.py │ ├── tf_logger.py │ ├── trainer.py │ └── utils.py ├── tools ├── generate_gridworld_extend.py └── generate_gridworld_split.py └── train_gridworld.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | **/__pycache__ 6 | log 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "synpo/gridworld"] 2 | path = synpo/gridworld 3 | url = git@github.com:Sha-Lab/gridworld.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Hexiang Hu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the 
Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [**Syn**thesized **Po**licies (SynPo)](https://sites.google.com/view/neurips2018-synpo/home) 2 | 3 | This repository implements the SynPo algorithms presented in: 4 | 5 | - Hu, Hexiang, Liyu Chen, Boqing Gong, and Fei Sha. "Synthesized Policies for Transfer and Adaptation across Tasks and Environments." In Advances in Neural Information Processing Systems, pp. 1176-1185. 2018. 6 | 7 | ## Requirements 8 | 9 | - Python 3+ 10 | - Numpy 1.10.0+ 11 | - Pytorch 0.4.0 12 | 13 | Please see [requirements.txt](https://github.com/Sha-Lab/SynPo/blob/master/requirements.txt) for complete details 14 | 15 | ## Usage 16 | 17 | We release our source code based on the gridworld environment in this [repo](https://github.com/Sha-Lab/gridworld). The usage is listed as following: 18 | 19 | ```bash 20 | usage: train_gridworld.py [-h] [--gpu_id GPU_ID] [--batch_size BATCH_SIZE] 21 | [--weight WEIGHT] [--scene SCENE] [--task TASK] 22 | [--embedding_dim EMBEDDING_DIM] 23 | [--scene_embedding_dim SCENE_EMBEDDING_DIM] 24 | [--task_embedding_dim TASK_EMBEDDING_DIM] 25 | [--num_obj_types NUM_OBJ_TYPES] 26 | [--task_length TASK_LENGTH] 27 | [--update_interval UPDATE_INTERVAL] 28 | [--scene_num SCENE_NUM] [--task_num TASK_NUM] 29 | [--reward_prediction REWARD_PREDICTION] 30 | [--scene_disentanglement SCENE_DISENTANGLEMENT] 31 | [--task_disentanglement TASK_DISENTANGLEMENT] 32 | --split_filepath SPLIT_FILEPATH [--lr LR] [--wd] 33 | [--mode {cloning}] [--network {mlp,mtl,synpo}] 34 | [--postfix POSTFIX] [--repeat REPEAT] [--evaluate] 35 | [--visualize] [--random_seed RANDOM_SEED] 36 | [--logger_name LOGGER_NAME] [--norm] 37 | 38 | optional arguments: 39 | -h, --help show this help message and exit 40 | --gpu_id GPU_ID 41 | --batch_size BATCH_SIZE 42 | --weight WEIGHT 43 | --scene SCENE 44 | --task TASK 45 | --embedding_dim EMBEDDING_DIM 46 | --scene_embedding_dim SCENE_EMBEDDING_DIM 47 | --task_embedding_dim TASK_EMBEDDING_DIM 48 | --num_obj_types NUM_OBJ_TYPES 49 | --task_length TASK_LENGTH 50 | --update_interval UPDATE_INTERVAL 51 | --scene_num SCENE_NUM 52 | --task_num TASK_NUM 53 | --reward_prediction REWARD_PREDICTION 54 | loss weight of reward prediction objective 55 | --scene_disentanglement SCENE_DISENTANGLEMENT 56 | loss weight of scene disentanglement prediction 57 | objective 58 | --task_disentanglement TASK_DISENTANGLEMENT 59 | loss weight of task disentanglement prediction 60 | objective 61 | --split_filepath SPLIT_FILEPATH 62 | train/test split filepath 63 | --lr LR base learning rate 64 | --wd enable weight decay 65 | --mode {cloning} training mode [only behavior cloing available 
for now] 66 | --network {mlp,mtl,synpo} 67 | select model architecture 68 | --postfix POSTFIX postfix to the log file 69 | --repeat REPEAT number of test run 70 | --evaluate evaluation mode 71 | --visualize visualize policy [only in evaluation mode] 72 | --random_seed RANDOM_SEED 73 | random seed value 74 | --logger_name LOGGER_NAME 75 | logger name format [must have for slots to fill] 76 | --norm whether normalize the scene/task embedding 77 | ``` 78 | 79 | ## References 80 | 81 | If you are using any resources within this repo for your research, please cite: 82 | 83 | ``` 84 | @inproceedings{hu2018synthesize, 85 | title={Synthesized Policies for Transfer and Adaptation across Tasks and Environments}, 86 | author={Hu, Hexiang and Chen, Liyu and Gong, Boqing and Sha, Fei}, 87 | booktitle={Advances in Neural Information Processing Systems}, 88 | pages={1176--1185}, 89 | year={2018} 90 | } 91 | ``` 92 | 93 | ## Acknolwedgement 94 | Part of the source code is modified based on the pytorch [DeepRL](https://github.com/ShangtongZhang/DeepRL) repo. We thank the original author for open source their implementation. 95 | 96 | ## License 97 | SynPo is MIT licensed, as found in the LICENSE file. 98 | 99 | 100 | -------------------------------------------------------------------------------- /data/gridworld.10s10t/train.10.test.90.split: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sha-Lab/SynPo/8ac35a01d2c810187b9c14b914bcb792ed73caa9/data/gridworld.10s10t/train.10.test.90.split -------------------------------------------------------------------------------- /data/gridworld.10s10t/train.20.test.80.split: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sha-Lab/SynPo/8ac35a01d2c810187b9c14b914bcb792ed73caa9/data/gridworld.10s10t/train.20.test.80.split -------------------------------------------------------------------------------- /data/gridworld.10s10t/train.30.test.70.split: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sha-Lab/SynPo/8ac35a01d2c810187b9c14b914bcb792ed73caa9/data/gridworld.10s10t/train.30.test.70.split -------------------------------------------------------------------------------- /data/gridworld.10s10t/train.40.test.60.split: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sha-Lab/SynPo/8ac35a01d2c810187b9c14b914bcb792ed73caa9/data/gridworld.10s10t/train.40.test.60.split -------------------------------------------------------------------------------- /data/gridworld.10s10t/train.50.test.50.split: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sha-Lab/SynPo/8ac35a01d2c810187b9c14b914bcb792ed73caa9/data/gridworld.10s10t/train.50.test.50.split -------------------------------------------------------------------------------- /data/gridworld.10s10t/train.60.test.40.split: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sha-Lab/SynPo/8ac35a01d2c810187b9c14b914bcb792ed73caa9/data/gridworld.10s10t/train.60.test.40.split -------------------------------------------------------------------------------- /data/gridworld.10s10t/train.70.test.30.split: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Sha-Lab/SynPo/8ac35a01d2c810187b9c14b914bcb792ed73caa9/data/gridworld.10s10t/train.70.test.30.split -------------------------------------------------------------------------------- /data/gridworld.10s10t/train.80.test.20.split: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sha-Lab/SynPo/8ac35a01d2c810187b9c14b914bcb792ed73caa9/data/gridworld.10s10t/train.80.test.20.split -------------------------------------------------------------------------------- /data/gridworld.20s20t/hard.split: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sha-Lab/SynPo/8ac35a01d2c810187b9c14b914bcb792ed73caa9/data/gridworld.20s20t/hard.split -------------------------------------------------------------------------------- /data/gridworld.20s20t/hard_extend.split: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sha-Lab/SynPo/8ac35a01d2c810187b9c14b914bcb792ed73caa9/data/gridworld.20s20t/hard_extend.split -------------------------------------------------------------------------------- /extend_gridworld.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | import ipdb 4 | from datetime import datetime 5 | from itertools import product 6 | from tqdm import tqdm 7 | import numpy as np 8 | import pickle 9 | from IPython import embed 10 | from ipdb import slaunch_ipdb_on_exception 11 | 12 | from synpo.agent import * 13 | from synpo.component import * 14 | from synpo.utils import * 15 | import synpo.gridworld as gridworld 16 | 17 | from synpo.utils import mkdir, set_seed 18 | 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--gpu_id', default=0, type=int) 22 | parser.add_argument('--batch_size', default=128, type=int) 23 | parser.add_argument('--weight', default=None, type=str) 24 | parser.add_argument('--scene', default=None, type=int) # use for evaluate 25 | parser.add_argument('--task', default=None, type=int) 26 | parser.add_argument('--embedding_dim', default=128, type=int) 27 | parser.add_argument('--scene_embedding_dim', default=128, type=int) 28 | parser.add_argument('--task_embedding_dim', default=128, type=int) 29 | parser.add_argument('--num_obj_types', default=5, type=int) 30 | parser.add_argument('--task_length', default=2, type=int) 31 | parser.add_argument('--update_interval', default=1, type=int) 32 | parser.add_argument('--scene_num', default=20, type=int) 33 | parser.add_argument('--task_num', default=20, type=int) 34 | parser.add_argument('--reward_prediction', action='store_false') 35 | parser.add_argument('--split_filepath', default=None, type=str) 36 | parser.add_argument('--wd', action='store_true') 37 | parser.add_argument('--scene_prediction', default=1, type=int) 38 | parser.add_argument('--task_prediction', default=1, type=int) 39 | parser.add_argument('--option', default='normal', choices=['less', 'normal', 'more']) 40 | parser.add_argument('--normalize_embedding', action='store_true') 41 | 42 | parser.add_argument('--network', default='synpo', choices=['mlp', 'mtl', 'synpo']) 43 | parser.add_argument('--postfix', default=0, type=str) 44 | parser.add_argument('--repeat', default=10, type=int) 45 | parser.add_argument('--evaluate', action='store_true') 46 | parser.add_argument('--visualize', action='store_true') 47 | parser.add_argument('--random_seed', default=0, 
type=int) 48 | parser.add_argument('--logger_name', default='log/synpo_{}_{}_{}_{}.log', type=str) 49 | parser.add_argument('--norm', action='store_true') 50 | parser.add_argument('--y_norm', action='store_true') 51 | parser.add_argument('--setting', type=int, default=2) 52 | parser.add_argument('--one_traj', action='store_true') 53 | args = parser.parse_args() 54 | 55 | def get_network(task): 56 | arg_dim = task.env.observation_space.spaces[1].shape[0] 57 | grid_dim = task.env.observation_space.spaces[0].shape[0] 58 | action_dim = task.env.action_space.n 59 | if args.network == 'mlp': 60 | network = GridWorldMLP(grid_dim, action_dim, arg_dim, 61 | scene_num=args.scene_num, 62 | task_num=args.task_num, 63 | embed_dim=args.embedding_dim, 64 | scene_dim=args.scene_embedding_dim, 65 | task_dim=args.task_embedding_dim, 66 | gpu=args.gpu_id, 67 | scene_disentanglement=args.scene_disentanglement, 68 | task_disentanglement=args.task_disentanglement, 69 | norm=args.norm) 70 | elif args.network == 'mtl': 71 | network = GridWorldMTL(grid_dim, action_dim, arg_dim, 72 | scene_num=args.scene_num, 73 | task_num=args.task_num, 74 | embed_dim=args.embedding_dim, 75 | scene_dim=args.scene_embedding_dim, 76 | task_dim=args.task_embedding_dim, 77 | gpu=args.gpu_id, 78 | scene_disentanglement=args.scene_disentanglement, 79 | task_disentanglement=args.task_disentanglement, 80 | norm=args.norm) 81 | elif args.network == 'synpo': 82 | network = GridWorldSynPo(grid_dim, action_dim, arg_dim, 83 | scene_num=args.scene_num, 84 | task_num=args.task_num, 85 | embed_dim=args.embedding_dim, 86 | scene_dim=args.scene_embedding_dim, 87 | task_dim=args.task_embedding_dim, 88 | gpu=args.gpu_id, 89 | norm=args.norm) 90 | else: 91 | raise ValueError('Non-supported Network') 92 | return network 93 | 94 | def gridworld_behaviour_cloning(args, layouts, train_combos, test_combos): 95 | config = Config() 96 | grid_world_task = GridWorldTask(layouts, 97 | num_obj_types=args.num_obj_types, 98 | task_length=args.task_length, 99 | history_length= config.history_length, 100 | train_combos=train_combos, 101 | test_combos=test_combos) 102 | config.task_fn = lambda: grid_world_task 103 | if args.wd: 104 | print('with weight decay!') 105 | config.optimizer_fn = lambda params: torch.optim.Adam(params, lr=0.001, weight_decay=10e-5) 106 | else: 107 | print('without weight decay!') 108 | config.optimizer_fn = lambda params: torch.optim.Adam(params, lr=0.001) 109 | network = get_network(grid_world_task) 110 | if args.weight is not None: 111 | weight_to_resume = torch.load(args.weight, map_location=lambda storage, loc: storage)['best_model_weight'] 112 | config.extend = copy.deepcopy(weight_to_resume) 113 | if args.normalize_embedding: 114 | for k, v in weight_to_resume.items(): 115 | if 'embed' in k: 116 | weight_to_resume[k] = F.normalize(v) 117 | network.load_state_dict(weight_to_resume) 118 | for k, v in network.named_parameters(): 119 | if args.option == 'more': 120 | if 'embed' not in k and 'refc' not in k and 'reward_fc' not in k and 'policy_fc' not in k: v.requires_grad = False 121 | elif args.option == 'normal': 122 | if 'embed' not in k and 'refc' not in k: v.requires_grad = False 123 | elif args.option == 'less': 124 | if 'embed' not in k: v.requires_grad = False 125 | else: print('with grad: {}'.format(k)) 126 | else: 127 | raise Exception('unsupported option') 128 | print(network) 129 | 130 | config.network_fn = lambda: network 131 | config.replay_fn = lambda: TrajectoryReplay(memory_size=10000, max_length=200, batch_size=64) 132 | 
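# The Config object holds zero-argument factories (task_fn, network_fn, replay_fn,
# policy_fn) rather than instances; GridAgent.__init__ invokes each of them to build
# its own task, networks, replay buffer and exploration policy, and GridBehaviourCloning
# passes only the parameters with requires_grad=True to optimizer_fn.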
config.policy_fn = lambda: GreedyPolicy(epsilon=0.1, final_step=500000, min_epsilon=0.0) 133 | config.logger = Logger('./log', logger) 134 | config.test_interval = 2000 135 | config.max_eps = 10000 # 6000 136 | config.exploration_steps = 50000 137 | config.postfix = args.postfix 138 | config.tag = network.__class__.__name__ 139 | config.update_interval = 1 # preset 140 | config.one_traj = args.one_traj 141 | return GridBehaviourCloning(config) 142 | 143 | if __name__ == '__main__': 144 | mkdir('data') 145 | mkdir('data/video') 146 | mkdir('log') 147 | os.system('export OMP_NUM_THREADS=1') 148 | 149 | set_seed(args.random_seed, c=args.random_seed) 150 | layouts = ['map{}'.format(i) for i in range(1, 21) ] 151 | 152 | if args.setting == 2: 153 | train_combos = [(i, j) for i, j in product(range(10, args.scene_num), range(10, args.task_num))] 154 | elif args.setting == 3: 155 | train_combos = [(i, j) for i, j in product(range(10), range(10, args.task_num))] + [(i, j) for i, j in product(range(10, args.scene_num), range(10))] 156 | else: 157 | raise Exception('error') 158 | # train_combos = [(i, j) for i, j in product(range(args.scene_num), range(args.task_num))] 159 | test_combos = [(i, j) for i, j in product(range(10, args.scene_num), range(10, args.task_num))] 160 | 161 | agent = gridworld_behaviour_cloning(args, layouts, train_combos, test_combos) 162 | 163 | agent.reward_prediction = args.reward_prediction 164 | if args.split_filepath is None: # Default Multi-task Setting 165 | agent.split_name = 'MTL' 166 | else: 167 | agent.split_name = "-".join(args.split_filepath.split('/')[-2:]) 168 | if args.evaluate: 169 | if args.scene is not None and args.task is not None: 170 | for _ in tqdm(range(args.repeat)): 171 | success, traj_len = agent.evaluate(visualize=args.visualize, index=(args.scene, args.task)) # main program 172 | else: 173 | rates = [] 174 | # for combo in test_combos: 175 | for combo in test_combos: 176 | success_list = [] 177 | trajectory_list = [] 178 | for _ in tqdm(range(args.repeat)): 179 | success, traj_len, _ = agent.evaluate(visualize=args.visualize, index=combo) # main program 180 | success_list.append(success) 181 | trajectory_list.append(traj_len) 182 | success_rate = sum(success_list) / len(success_list) 183 | rates.append(success_rate) 184 | print('* [Task={}, # of Tests={}] Average success rate: {:.4f}, Average trajectory length: {}'.format( combo, args.repeat, 185 | success_rate, sum(trajectory_list) / len(trajectory_list) )) 186 | print('average success rate: {:.4f}'.format(np.mean(rates))) 187 | else: 188 | logger.setLevel(logging.INFO) 189 | handler = logging.FileHandler(args.logger_name.format(agent.__class__.__name__, 190 | agent.learning_network.__class__.__name__, 191 | datetime.now().strftime("%Y-%m-%d_%H:%M:%S"), 192 | args.postfix)) 193 | logger.addHandler(handler) 194 | with slaunch_ipdb_on_exception(): 195 | train_agent(agent) # main program 196 | 197 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv_python 2 | networkx==2.1 3 | scipy==1.0.0 4 | gym==0.10.4 5 | ipdb==0.11 6 | h5py==2.7.1 7 | readchar==0.7 8 | torch==0.3.1.post3 9 | pygame==1.9.3 10 | numpy==1.14.0 11 | tqdm==4.19.9 12 | ipython==6.5.0 13 | Pillow==5.2.0 14 | imageio==2.3.0 15 | tensorboardX==1.4 16 | torchvision==0.2.1 17 | -------------------------------------------------------------------------------- /synpo/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sha-Lab/SynPo/8ac35a01d2c810187b9c14b914bcb792ed73caa9/synpo/__init__.py -------------------------------------------------------------------------------- /synpo/agent/Grid_agent.py: -------------------------------------------------------------------------------- 1 | from synpo.network import * 2 | from synpo.component import * 3 | from synpo.utils import * 4 | import numpy as np 5 | import time 6 | import os 7 | import pickle 8 | import torch 9 | from torch.nn.utils.clip_grad import clip_grad_norm 10 | import random 11 | import copy 12 | 13 | from synpo.utils import argmax1d, extract 14 | 15 | # Information profiling 16 | def profile_action_value(action, _action, value, q_value): 17 | print('[{}] prediction: {}, ground truth: {} ({:.4f}/{:.4f})'.format( \ 18 | 'FAIL' if action != _action else 'BINGO', \ 19 | action, _action, float(value[_action]), float(q_value))) 20 | 21 | if torch.cuda.is_available(): 22 | log_softmax = nn.LogSoftmax(dim=1).cuda() 23 | else: 24 | log_softmax = nn.LogSoftmax(dim=1) 25 | mse_criterion = nn.MSELoss() 26 | cross_entropy_criterion = lambda x, y: ( -log_softmax(x) * y ).sum(dim=1).mean() 27 | 28 | # also use qs 29 | def imitation_loss(agent, experiences): 30 | states, actions, rewards, qs, scene_ids, task_ids = extract(experiences, 31 | 'states', 'actions', 'rewards', 'qs', 'scene_ids', 'task_ids') 32 | states = agent.task.normalize_state(states) 33 | # qs = np.asarray([ q for q in qs ]) 34 | # act_id = np.argmax(qs, axis=1) 35 | action = agent.learning_network.variable(np.eye(5)[actions]) 36 | agent.cached_value = agent.cached_value or agent.learning_network.predict(states, scene_ids, task_ids, False) 37 | 38 | loss = 0 39 | if isinstance(agent.cached_value, tuple) and ( isinstance(agent.learning_network, GridWorldSynPo) or \ 40 | isinstance(agent.learning_network, GridWorldMTL) or \ 41 | isinstance(agent.learning_network, GridWorldMLP) ): 42 | act, _, scene_scores, task_scores = agent.cached_value 43 | actions = agent.learning_network.variable(actions, torch.LongTensor).view(-1, 1, 1) 44 | 45 | if scene_scores is not None: 46 | if len(scene_scores.size()) > 2: scene_scores = scene_scores.gather(1, actions.expand(-1, 1, scene_scores.size(2))).squeeze(1) 47 | scene_gt = agent.learning_network.variable(np.eye(scene_scores.size(1))[scene_ids]) 48 | 49 | loss += agent.config.scene_disentanglement_coeff * cross_entropy_criterion(scene_scores, scene_gt) 50 | 51 | if task_scores is not None: 52 | if len(task_scores.size()) > 2: task_scores = task_scores.gather(1, actions.expand(-1, 1, task_scores.size(2))).squeeze(1) 53 | task_gt = agent.learning_network.variable(np.eye(task_scores.size(1))[task_ids]) 54 | 55 | loss += agent.config.task_disentanglement_coeff * cross_entropy_criterion(task_scores, task_gt) 56 | elif isinstance(agent.cached_value, tuple): 57 | act = agent.cached_value[0] 58 | else: 59 | act = agent.cached_value 60 | loss += cross_entropy_criterion(act, action) 61 | return loss 62 | 63 | def reward_prediction_loss(agent, experiences): 64 | assert isinstance(agent.learning_network, ValueNet) 65 | states, actions, rewards, qs, scene_ids, task_ids = extract(experiences, 66 | 'states', 'actions', 'rewards', 'qs', 'scene_ids', 'task_ids') 67 | states = agent.task.normalize_state(states) 68 | agent.cached_value = agent.cached_value or agent.learning_network.predict(states, scene_ids, task_ids, False) 69 | 70 | r = agent.cached_value[1] 71 | 
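# cached_value is the tuple returned by the network's forward pass
# (policy, reward, scene_score, task_score), so index 1 is the per-action reward
# prediction; it is gathered at the taken actions below and regressed against the
# observed rewards with an MSE loss.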
rewards = agent.learning_network.variable(rewards) 72 | r = r[np.arange(len(actions)), actions] 73 | loss = mse_criterion(r, rewards) 74 | return loss 75 | 76 | class GridAgent: 77 | def __init__(self, config): 78 | self.config = config 79 | self.learning_network = config.network_fn() 80 | self.target_network = config.network_fn() 81 | self.target_network.load_state_dict(self.learning_network.state_dict()) 82 | self.task = config.task_fn() 83 | self.replay = config.replay_fn() 84 | self.policy = config.policy_fn() 85 | self.total_steps = 0 86 | self.update_interval = config.update_interval 87 | # add reward loss or not 88 | self.reward_prediction = False 89 | # cache calculation 90 | self.cached_value = None 91 | 92 | def episode(self, train=True, env_kwargs={}): 93 | raise NotImplementedError('Re-Write this method.') 94 | 95 | def close(self): 96 | pass 97 | 98 | def evaluate(self, visualize=False, step_time=0.1, seed=None, index=None, optimal=False): 99 | assert index is not None, 'just because I set default to None does not mean that you can leave it None' 100 | if seed is not None: 101 | self.task.seed(seed) 102 | heat_map = np.zeros((16, 16)) # does not count initialized position 103 | rng = copy.deepcopy(self.task.env.unwrapped.random) 104 | actions = [] 105 | state = self.task.reset(index, sample_pos=True) 106 | trajectory = [] 107 | accum_rewards = [] 108 | while True: 109 | if visualize: 110 | self.task.env.unwrapped.render() # change 111 | time.sleep(step_time) 112 | if optimal: 113 | value = self.task.get_qs(self.config.discount) 114 | else: 115 | value = self.learning_network.predict(self.task.normalize_state([state]), 116 | np.asarray([index[0]]), np.asarray([index[1]]), 117 | to_numpy=True, evaluate=True).flatten() 118 | action = np.argmax(value) 119 | actions.append(action) 120 | #action = self.task.get_opt_action() 121 | state, reward, done, _ = self.task.step(action) 122 | heat_map[self.task.pos()] = 1 123 | trajectory.append(action) 124 | accum_rewards.append(reward) 125 | 126 | if done: break 127 | return (reward > 10), len(trajectory), sum(accum_rewards), heat_map, rng, actions 128 | 129 | def evaluate2image(self, index, seed=None): 130 | if seed is not None: 131 | self.task.seed(seed) 132 | state = self.task.reset(index, sample_pos=True) 133 | trajectory = [] 134 | while True: 135 | trajectory.append(self.task.env.pretty_render().astype(np.uint8)) 136 | value = self.learning_network.predict(self.task.normalize_state([state]), 137 | np.asarray([index[0]]), np.asarray([index[1]]), 138 | to_numpy=True, evaluate=True).flatten() 139 | action = np.argmax(value) 140 | state, reward, done, _ = self.task.step(action) 141 | 142 | if done: break 143 | return (reward > 0), len(trajectory), trajectory 144 | 145 | class GridBehaviourCloning(GridAgent): 146 | def __init__(self, config): 147 | super(GridBehaviourCloning, self).__init__(config) 148 | keys = self.learning_network.state_dict().keys() 149 | self.optimizer = config.optimizer_fn( [ v for v in self.learning_network.parameters() if v.requires_grad ]) 150 | self.grad_clip = config.grad_clip 151 | self.one_traj = config.one_traj 152 | 153 | def episode(self, train=True, env_kwargs={}): 154 | self.cached_value = None 155 | episode_start_time = time.time() 156 | state = self.task.reset(**env_kwargs) 157 | scene_id, task_id = self.task.env.unwrapped.index() 158 | total_reward = 0.0 159 | steps = 0 160 | total_loss = [] 161 | while True: 162 | if not train: 163 | value = 
self.learning_network.predict(self.task.normalize_state([state]), 164 | np.asarray([scene_id]), np.asarray([task_id]), 165 | True, evaluate=True).flatten() 166 | qs = self.task.get_qs(discount=self.config.discount) 167 | if self.total_steps < self.config.exploration_steps: 168 | #action = self.policy.sample(qs, train=train) 169 | action = argmax1d(qs) 170 | else: 171 | if train: 172 | #action = self.policy.sample(qs, train=train) 173 | action = argmax1d(qs) 174 | else: 175 | action = self.policy.sample(value, train=train) 176 | next_state, reward, done, _ = self.task.step(action) 177 | total_reward += reward 178 | if train: 179 | if not self.one_traj or (scene_id, task_id) not in self.replay.combo_ids: 180 | self.replay.feed([state, action, reward, None, qs, next_state, scene_id, task_id, int(done)]) 181 | self.total_steps += 1 182 | steps += 1 183 | state = next_state 184 | if done and train and self.total_steps > self.config.exploration_steps: 185 | experiences = self.replay.sample() 186 | if isinstance(self.learning_network, ValueNet): 187 | if self.reward_prediction: 188 | loss = imitation_loss(self, experiences) + 0.01 * reward_prediction_loss(self, experiences) 189 | else: 190 | loss = imitation_loss(self, experiences) 191 | else: 192 | raise NotImplementedError('Not supported network') 193 | self.optimizer.zero_grad() 194 | loss.backward() 195 | total_loss.append( loss.data.cpu().item() ) 196 | if self.grad_clip > 0: 197 | clip_grad_norm(self.learning_network.parameters(), self.grad_clip) 198 | self.optimizer.step() 199 | if self.config.extend is not None: 200 | self.learning_network.scene_embed.weight.data[:10, :].copy_(self.config.extend['scene_embed.weight'][:10, :]) 201 | self.learning_network.task_embed.weight.data[:10, :].copy_(self.config.extend['task_embed.weight'][:10, :]) 202 | 203 | if train and self.total_steps > self.config.exploration_steps: 204 | self.policy.update_epsilon() 205 | 206 | if done: break 207 | 208 | episode_time = time.time() - episode_start_time 209 | self.config.logger.debug('episode steps %d, episode time %f, time per step %f' % 210 | (steps, episode_time, episode_time / float(steps))) 211 | return total_reward, total_loss, steps, (scene_id, task_id) 212 | -------------------------------------------------------------------------------- /synpo/agent/__init__.py: -------------------------------------------------------------------------------- 1 | from .Grid_agent import * 2 | -------------------------------------------------------------------------------- /synpo/component/__init__.py: -------------------------------------------------------------------------------- 1 | from .policy import * 2 | from .replay import * 3 | from .task import * 4 | -------------------------------------------------------------------------------- /synpo/component/policy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from synpo.utils import argmax1d 4 | 5 | class GreedyPolicy: 6 | def __init__(self, epsilon, final_step, min_epsilon): 7 | self.init_epsilon = self.epsilon = epsilon 8 | self.current_steps = 0 9 | self.min_epsilon = min_epsilon 10 | self.final_step = final_step 11 | 12 | def sample(self, action_value, train=True): 13 | if train: 14 | if np.random.rand() < self.epsilon: 15 | return np.random.randint(0, len(action_value)) 16 | return argmax1d(action_value, True) 17 | return np.argmax(action_value) 18 | 19 | def update_epsilon(self): 20 | diff = float(self.current_steps) / self.final_step * 
(self.init_epsilon - self.min_epsilon) 21 | self.epsilon = self.init_epsilon - diff 22 | self.epsilon = max(self.epsilon, self.min_epsilon) 23 | self.current_steps += 1 24 | -------------------------------------------------------------------------------- /synpo/component/replay.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | import torch.multiprocessing as mp 5 | from collections import defaultdict 6 | from IPython import embed 7 | 8 | from synpo.utils import discount_cumsum 9 | 10 | class TrajectoryReplay: 11 | def __init__(self, memory_size, max_length, batch_size, discount=0.99): 12 | self.memory_size = memory_size 13 | self.batch_size = batch_size 14 | self.max_length = max_length 15 | self.discount = discount 16 | 17 | self.clear() 18 | 19 | def feed(self, experience): 20 | state, action, reward, _, qs, _, scene_id, task_id, done = experience 21 | 22 | self.cur_states[self.cur_pos] = state 23 | self.cur_actions[self.cur_pos] = action 24 | self.cur_rewards[self.cur_pos] = reward 25 | self.cur_qs[self.cur_pos] = qs 26 | self.cur_scene_ids[self.cur_pos] = scene_id 27 | self.cur_task_ids[self.cur_pos] = task_id 28 | 29 | self.cur_pos += 1 30 | if done: 31 | self.feed_traj() 32 | 33 | def feed_traj(self): 34 | if self.full: 35 | _sid, _tid = self.scene_ids[self.pos][0], self.task_ids[self.pos][0] 36 | self.combo_ids[_sid, _tid].remove(self.pos) 37 | 38 | actual_q = np.asarray(discount_cumsum(self.cur_rewards[:self.cur_pos], self.discount)) 39 | 40 | self.states[self.pos, :self.cur_pos] = self.cur_states[:self.cur_pos] 41 | self.actions[self.pos, :self.cur_pos] = self.cur_actions[:self.cur_pos] 42 | self.rewards[self.pos, :self.cur_pos] = self.cur_rewards[:self.cur_pos] 43 | self.qs[self.pos, :self.cur_pos] = self.cur_qs[:self.cur_pos] 44 | self.actual_q[self.pos, :self.cur_pos] = actual_q 45 | self.scene_ids[self.pos, :self.cur_pos] = self.cur_scene_ids[:self.cur_pos] 46 | self.task_ids[self.pos, :self.cur_pos] = self.cur_task_ids[:self.cur_pos] 47 | self.t_pos[self.pos] = self.cur_pos 48 | 49 | self.combo_ids[self.scene_ids[self.pos][0], self.task_ids[self.pos][0]].add(self.pos) 50 | self.pos += 1 51 | self.cur_pos = 0 52 | 53 | if self.pos == self.memory_size: 54 | self.full = True 55 | self.pos = 0 56 | 57 | def _sample(self, sampled_indices): 58 | return { 59 | 'states': np.concatenate([self.states[i, :self.t_pos[i]] for i in sampled_indices]), 60 | 'actions': np.concatenate([self.actions[i, :self.t_pos[i]] for i in sampled_indices]), 61 | 'rewards': np.concatenate([self.rewards[i, :self.t_pos[i]] for i in sampled_indices]), 62 | 'qs': np.concatenate([self.qs[i, :self.t_pos[i]] for i in sampled_indices]), 63 | 'actual_q': np.concatenate([self.actual_q[i, :self.t_pos[i]] for i in sampled_indices]), 64 | 'scene_ids': np.concatenate([self.scene_ids[i, :self.t_pos[i]] for i in sampled_indices]), 65 | 'task_ids': np.concatenate([self.task_ids[i, :self.t_pos[i]] for i in sampled_indices]), 66 | } 67 | 68 | def sample(self): 69 | upper_bound = self.memory_size if self.full else self.pos 70 | sampled_indices = np.random.randint(0, upper_bound, size=self.batch_size) 71 | return self._sample(sampled_indices) 72 | 73 | def stratified_sample(self, combo_id=None): 74 | sampled_indices = [] 75 | if combo_id is None: #Default: multi-task sampling 76 | sampling_cands = [] 77 | sampling_tasks = { k[1]: [] for k in self.combo_ids.keys() } 78 | num_tasks = len(sampling_tasks.keys()) 79 | for combo_id, val 
in self.combo_ids.items(): sampling_tasks[combo_id[1]].append(val) 80 | for k in sampling_tasks.keys(): 81 | sampling_cands.extend( np.random.choice(sampling_tasks[k], round(self.batch_size / num_tasks ))) 82 | for v in sampling_cands: 83 | sampled_indices.extend( random.sample(v, min(len(v), 1)) ) 84 | else: 85 | sid, tid = combo_id 86 | if sid is not None and tid is None: # Marginalized sampling according to scene 87 | sampling_cands = [ v for k, v in self.combo_ids.items() if k[0] == sid ] 88 | sampling_cands = np.random.choice(sampling_cands, self.batch_size, replace=True) 89 | for v in sampling_cands: 90 | sampled_indices.extend( random.sample(v, min(len(v), 1)) ) 91 | elif sid is None and tid is not None: # Marginalized sampling according to tasks 92 | sampling_cands = [ v for k, v in self.combo_ids.items() if k[1] == tid ] 93 | sampling_cands = np.random.choice(sampling_cands, self.batch_size, replace=True) 94 | for v in sampling_cands: 95 | sampled_indices.extend( random.sample(v, min(len(v), 1)) ) 96 | else: # Specified sampling 97 | sampled_indices = random.sample(self.combo_ids[combo_id], self.batch_size) 98 | 99 | sampled_indices = np.asarray(sampled_indices) 100 | return self._sample(sampled_indices) 101 | 102 | def get_all(self): 103 | upper_bound = self.memory_size if self.full else self.pos 104 | sampled_indices = np.arange(upper_bound) 105 | return self._sample(sampled_indices) 106 | 107 | def clear(self): 108 | self.states = np.array([[None] * self.max_length] * self.memory_size) 109 | self.actions = np.empty((self.memory_size, self.max_length), dtype=np.uint8) 110 | self.rewards = np.empty((self.memory_size, self.max_length)) 111 | self.qs = np.array([[None] * self.max_length] * self.memory_size) 112 | self.actual_q = np.empty((self.memory_size, self.max_length)) 113 | self.scene_ids = np.empty((self.memory_size, self.max_length), dtype=np.uint8) 114 | self.task_ids = np.empty((self.memory_size, self.max_length), dtype=np.uint8) 115 | 116 | self.cur_states = np.array([None] * self.max_length) 117 | self.cur_actions = np.array([None] * self.max_length) 118 | self.cur_rewards = np.empty(self.max_length) 119 | self.cur_qs = np.array([None] * self.max_length) 120 | self.cur_scene_ids = np.empty(self.max_length, dtype=np.uint8) 121 | self.cur_task_ids = np.empty(self.max_length, dtype=np.uint8) 122 | self.cur_pos = 0 123 | 124 | self.combo_ids = defaultdict(set) 125 | 126 | self.pos = 0 # trajectory 127 | self.t_pos = np.zeros(self.memory_size, dtype=np.uint16) # within trajectory 128 | self.full = False 129 | 130 | def total_size(self): 131 | upper_bound = self.memory_size if self.full else self.pos 132 | return np.sum(self.t_pos[:upper_bound]) 133 | -------------------------------------------------------------------------------- /synpo/component/task.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import sys 3 | import numpy as np 4 | import os.path as osp 5 | import multiprocessing as mp 6 | import sys 7 | import torch 8 | 9 | from IPython import embed 10 | 11 | 12 | class BasicTask: 13 | def __init__(self, max_steps=sys.maxsize): 14 | self.steps = 0 15 | self.max_steps = max_steps 16 | 17 | def reset(self, *args): 18 | self.steps = 0 19 | state = self.env.reset(*args) 20 | return state 21 | 22 | def normalize_state(self, state): 23 | return state 24 | 25 | def step(self, action): 26 | next_state, reward, done, info = self.env.step(action) 27 | self.steps += 1 28 | done = (done or self.steps >= self.max_steps) 29 | 
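# An episode terminates either when the environment reports done or once the
# step budget (max_steps) is exhausted; GridWorldTask.step below applies the same
# cap, with max_steps defaulting to 300.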
return next_state, reward, done, info 30 | 31 | def random_action(self): 32 | return self.env.action_space.sample() 33 | 34 | class GridWorldTask(BasicTask): 35 | def __init__( 36 | self, 37 | layouts=['map{}'.format(i) for i in range(11, 31)], 38 | num_obj_types=5, 39 | task_length=2, 40 | history_length=4, 41 | max_steps=300, 42 | train_combos=None, 43 | test_combos=None, 44 | gaussian_img=True, 45 | record=False, 46 | ): 47 | from synpo.gridworld.env import GridWorld, read_map, ComboEnv, PORGBEnv 48 | self.train_combos = train_combos 49 | self.test_combos = test_combos 50 | self.num_combos = len(train_combos) + len(test_combos) 51 | self.env = PORGBEnv(ComboEnv(GridWorld( 52 | layouts, 53 | window=history_length, 54 | task_length=task_length, 55 | num_obj_types=num_obj_types, 56 | train_combos=train_combos, 57 | test_combos=test_combos, 58 | gaussian_img=gaussian_img)), record=record) 59 | self.action_dim = self.env.action_space.n 60 | self.max_steps = max_steps 61 | self.name = 'gridworld' 62 | 63 | def save_config(self): 64 | return self.__dict__ 65 | 66 | def reset(self, index=None, sample_pos=True, train=True): 67 | self.steps = 0 68 | state = self.env.reset(index, sample_pos=sample_pos, train=train) 69 | return state[0] 70 | 71 | def step(self, action): 72 | next_state, reward, done, info = self.env.step(action) 73 | self.steps += 1 74 | done = (done or self.steps >= self.max_steps) 75 | return next_state[0], reward, done, info 76 | 77 | def normalize_state(self, state): 78 | return np.asarray([np.asarray(s) for s in state]) 79 | 80 | def get_opt_action(self): 81 | return self.env.get_opt_action() 82 | 83 | def get_random_opt_action(self, discount): 84 | return self.env.get_random_opt_action(discount) 85 | 86 | def get_q(self, *args, **kwargs): 87 | return self.env.get_q(*args, **kwargs) 88 | 89 | def get_qs(self, *args, **kwargs): 90 | return self.env.get_qs(*args, **kwargs) 91 | 92 | def index(self): 93 | return self.env.index() 94 | 95 | def seed(self, *args, **kwargs): 96 | return self.env.seed(*args, **kwargs) 97 | 98 | def pos(self): 99 | return self.env.unwrapped.x, self.env.unwrapped.y 100 | 101 | def sub_task(parent_pipe, pipe, task_fn): 102 | parent_pipe.close() 103 | task = task_fn() 104 | task.env.seed(np.random.randint(0, sys.maxsize)) 105 | while True: 106 | op, data = pipe.recv() 107 | if op == 'step': 108 | pipe.send(task.step(data)) 109 | elif op == 'reset': 110 | pipe.send(task.reset()) 111 | elif op == 'exit': 112 | pipe.close() 113 | return 114 | else: 115 | assert False, 'Unknown Operation' 116 | 117 | class ParallelizedTask: 118 | def __init__(self, task_fn, num_workers): 119 | self.task_fn = task_fn 120 | self.task = task_fn() 121 | self.name = self.task.name 122 | self.pipes, worker_pipes = zip(*[mp.Pipe() for _ in range(num_workers)]) 123 | args = [(p, wp, task_fn) for p, wp in zip(self.pipes, worker_pipes)] 124 | self.workers = [mp.Process(target=sub_task, args=arg) for arg in args] 125 | for p in self.workers: p.start() 126 | for p in worker_pipes: p.close() 127 | self.observation_space = self.task.env.observation_space 128 | self.action_space = self.task.env.action_space 129 | 130 | def step(self, actions): 131 | for pipe, action in zip(self.pipes, actions): 132 | pipe.send(('step', action)) 133 | results = [p.recv() for p in self.pipes] 134 | results = map(lambda x: np.stack(x), zip(*results)) 135 | return results 136 | 137 | def reset(self, i=None): 138 | if i is None: 139 | for pipe in self.pipes: 140 | pipe.send(('reset', None)) 141 | results = 
[p.recv() for p in self.pipes] 142 | else: 143 | self.pipes[i].send(('reset', None)) 144 | results = self.pipes[i].recv() 145 | return np.stack(results) 146 | 147 | def close(self): 148 | for pipe in self.pipes: 149 | pipe.send(('exit', None)) 150 | for p in self.workers: p.join() 151 | -------------------------------------------------------------------------------- /synpo/network/__init__.py: -------------------------------------------------------------------------------- 1 | from .grid_network import * 2 | -------------------------------------------------------------------------------- /synpo/network/base_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | # Base class for all kinds of network 8 | class BasicNet: 9 | def __init__(self, gpu, LSTM=False, stochastic=False): 10 | if not torch.cuda.is_available(): 11 | gpu = -1 12 | self.gpu = gpu 13 | self.stochastic = stochastic 14 | self.LSTM = LSTM 15 | self.init_weights() 16 | if self.gpu >= 0: 17 | self.cuda(self.gpu) 18 | 19 | def supported_dtype(self, x, torch_type): 20 | if torch_type == torch.FloatTensor: 21 | return np.asarray(x, dtype=np.float32) 22 | if torch_type == torch.LongTensor: 23 | return np.asarray(x, dtype=np.int64) 24 | 25 | def variable(self, x, dtype=torch.FloatTensor, requires_grad=False): 26 | if isinstance(x, Variable): 27 | return x 28 | x = dtype(torch.from_numpy(self.supported_dtype(x, dtype))) 29 | if self.gpu >= 0: 30 | x = x.cuda(self.gpu) 31 | return Variable(x, requires_grad=requires_grad) 32 | 33 | def tensor(self, x, dtype=torch.FloatTensor): 34 | x = dtype(torch.from_numpy(self.supported_dtype(x, dtype))) 35 | if self.gpu >= 0: 36 | x = x.cuda(self.gpu) 37 | return x 38 | 39 | def reset_noise(self): 40 | raise NotImplementedError('Not Supported') 41 | 42 | def reset(self, terminal): 43 | if not self.LSTM: 44 | return 45 | if terminal: 46 | self.h.data.zero_() 47 | self.c.data.zero_() 48 | self.h = Variable(self.h.data) 49 | self.c = Variable(self.c.data) 50 | 51 | def init_weights(self): 52 | for layer in self.children(): 53 | relu_gain = nn.init.calculate_gain('relu') 54 | if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear): 55 | nn.init.xavier_uniform(layer.weight.data) 56 | nn.init.constant(layer.bias.data, 0) 57 | if isinstance(layer, nn.Embedding): 58 | nn.init.xavier_uniform(layer.weight.data) 59 | 60 | # Pass both scene and task 61 | class ValueNet(BasicNet): 62 | def predict(self, x, _scene_ids=None, _task_ids=None, to_numpy=False, evaluate=False): 63 | if evaluate: self.eval() 64 | else: self.train() 65 | y = self.forward(x, _scene_ids, _task_ids) 66 | if to_numpy: 67 | y = y[0] 68 | if type(y) is list: 69 | y = [y_.cpu().data.numpy() for y_ in y] 70 | else: 71 | y = y.cpu().data.numpy() 72 | return y 73 | else: 74 | return y 75 | -------------------------------------------------------------------------------- /synpo/network/grid_network.py: -------------------------------------------------------------------------------- 1 | from .base_network import * 2 | from .operator import * 3 | 4 | import torch 5 | import itertools 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from IPython import embed 9 | 10 | 11 | # Policy network for ST2 12 | class GridWorldSynPo(nn.Module, ValueNet): 13 | def __init__(self, in_channels, n_actions, arg_dim, embed_dim, scene_num, task_num, \ 14 | gpu=-1, 
feat_dim=1024, task_dim=16, scene_dim=16, scene_disentanglement=0.1, task_disentanglement=0.1, norm=True): 15 | super(GridWorldSynPo, self).__init__() 16 | self.action_space = n_actions 17 | self.task_dim = task_dim 18 | self.scene_dim = scene_dim 19 | self.feat_dim = feat_dim 20 | self.embed_dim = embed_dim 21 | self.scene_num = scene_num 22 | self.task_num = task_num 23 | 24 | self.scene_disentangle = scene_disentanglement > 0.0 25 | self.task_disentangle = task_disentanglement > 0.0 26 | self.norm = norm 27 | 28 | out_channels = feat_dim // (2*2) 29 | self.state_func = GridWorldResNet18(in_channels, out_channels) 30 | self.state_fc = nn.Linear(feat_dim, embed_dim) 31 | self.policy_basis = PolicyBasis(n_actions, embed_dim, embed_dim) 32 | 33 | self.task_embed = nn.Embedding(task_num, self.task_dim) 34 | self.scene_embed = nn.Embedding(scene_num, self.scene_dim) 35 | 36 | self.policy_fc1 = nn.Linear(self.task_dim + self.scene_dim, self.embed_dim*4) 37 | self.policy_fc2 = nn.Linear(self.embed_dim*4, self.embed_dim) 38 | 39 | self.reward_fc1 = nn.Linear(self.task_dim + self.scene_dim, self.embed_dim*4) 40 | self.reward_fc2 = nn.Linear(self.embed_dim*4, self.embed_dim) 41 | 42 | self.scene_refc1 = nn.Linear(embed_dim, embed_dim*4) 43 | self.scene_refc2 = nn.Linear(embed_dim*4, scene_dim) 44 | 45 | self.task_refc1 = nn.Linear(embed_dim, embed_dim*4) 46 | self.task_refc2 = nn.Linear(embed_dim*4, task_dim) 47 | 48 | BasicNet.__init__(self, gpu) 49 | 50 | def forward(self, xs, _scene_ids=None, _task_ids=None): 51 | xs = self.variable(xs) 52 | 53 | # inference the state feature 54 | state_feat = self.state_func(xs) 55 | state_feat = self.state_fc(state_feat.view(state_feat.size(0), -1)) 56 | 57 | N = state_feat.size(0) 58 | # Prepare scene/task weight 59 | task_ids = self.variable(_task_ids, torch.LongTensor) 60 | task_emb = self.task_embed(task_ids) # normalize embedding! 
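# Scene and task indices are looked up in separate embedding tables and, when
# self.norm is set (the --norm flag), L2-normalized with F.normalize; the two
# embeddings are then concatenated and passed through policy_fc1/2 and reward_fc1/2
# to synthesize the embeddings consumed by PolicyBasis.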
61 |     if self.norm: task_emb = F.normalize(task_emb)
62 |
63 |     scene_ids = self.variable(_scene_ids, torch.LongTensor)
64 |     scene_emb = self.scene_embed(scene_ids)
65 |     if self.norm: scene_emb = F.normalize(scene_emb)
66 |
67 |     policy_emb = F.relu(self.policy_fc1(torch.cat([task_emb, scene_emb], 1)))
68 |     policy_emb = self.policy_fc2(policy_emb)
69 |
70 |     reward_emb = F.relu(self.reward_fc1(torch.cat([task_emb, scene_emb], 1)))
71 |     reward_emb = self.reward_fc2(reward_emb)
72 |
73 |     # Generate Task-specific action weight
74 |     policy, reward, state_action_feat = self.policy_basis(state_feat, policy_emb, reward_emb)
75 |
76 |     if self.scene_disentangle:
77 |       reproject_scene = F.relu(self.scene_refc1(state_action_feat)) # ( batch_size, num_action, state_action_dim ) -> ( batch_size*num_action, state_action_dim )
78 |       reproject_scene = self.scene_refc2(reproject_scene).view(N*self.action_space, self.scene_dim)
79 |       #if self.norm: reproject_scene = F.normalize(reproject_scene)
80 |
81 |       scene_score = torch.mm(reproject_scene, self.scene_embed(self.variable(range(self.scene_num), torch.LongTensor)).t()).view(N, self.action_space, self.scene_num)
82 |     else:
83 |       scene_score = None
84 |
85 |     if self.task_disentangle:
86 |       reproject_task = F.relu(self.task_refc1(state_action_feat))
87 |       reproject_task = self.task_refc2(reproject_task).view(N*self.action_space, self.task_dim)
88 |       #if self.norm: reproject_task = F.normalize(reproject_task)
89 |
90 |       task_score = torch.mm(reproject_task, self.task_embed(self.variable(range(self.task_num), torch.LongTensor)).t()).view(N, self.action_space, self.task_num)
91 |     else:
92 |       task_score = None
93 |
94 |     return policy, reward, scene_score, task_score
95 |
96 | # Task single tensor ST2
97 | class GridWorldMTL(nn.Module, ValueNet):
98 |   def __init__(self, in_channels, n_actions, arg_dim, embed_dim, scene_num, task_num, \
99 |                gpu=-1, feat_dim=1024, task_dim=16, scene_dim=16, norm=True):
100 |     super(GridWorldMTL, self).__init__()
101 |     self.action_space = n_actions
102 |     self.task_dim = task_dim
103 |     self.scene_dim = scene_dim
104 |     self.feat_dim = feat_dim
105 |     self.embed_dim = embed_dim
106 |     self.scene_num = scene_num
107 |     self.task_num = task_num
108 |     self.norm = norm
109 |
110 |     out_channels = feat_dim // (2*2)
111 |     self.state_func = GridWorldResNet18(in_channels, out_channels)
112 |     self.state_fc = nn.Linear(feat_dim, embed_dim)
113 |     self.policy_basis = PolicyBasis(n_actions, embed_dim, embed_dim)
114 |
115 |     self.task_embed = nn.Embedding(task_num, self.task_dim)
116 |
117 |     self.policy_fc1 = nn.Linear(self.task_dim, self.embed_dim*4)
118 |     self.policy_fc2 = nn.Linear(self.embed_dim*4, self.embed_dim)
119 |
120 |     self.reward_fc1 = nn.Linear(self.task_dim, self.embed_dim*4)
121 |     self.reward_fc2 = nn.Linear(self.embed_dim*4, self.embed_dim)
122 |
123 |     self.task_refc1 = nn.Linear(embed_dim, embed_dim*4)
124 |     self.task_refc2 = nn.Linear(embed_dim*4, task_dim)
125 |
126 |     BasicNet.__init__(self, gpu)
127 |
128 |   def forward(self, xs, _scene_ids=None, _task_ids=None):
129 |     xs = self.variable(xs)
130 |
131 |     # inference the state feature
132 |     state_feat = self.state_func(xs)
133 |     state_feat = self.state_fc(state_feat.view(state_feat.size(0), -1))
134 |
135 |     N = state_feat.size(0)
136 |     # Prepare scene/task weight
137 |     task_ids = self.variable(_task_ids, torch.LongTensor)
138 |     task_emb = self.task_embed(task_ids)
139 |     if self.norm: task_emb = F.normalize(task_emb)
140 |
141 |     policy_emb = F.relu(self.policy_fc1(task_emb))
142 |     policy_emb =
self.policy_fc2(policy_emb) 143 | 144 | reward_emb = F.relu(self.reward_fc1(task_emb)) 145 | reward_emb = self.reward_fc2(reward_emb) 146 | 147 | # Generate Task-specific action weight 148 | policy, reward, state_action_feat = self.policy_basis(state_feat, policy_emb, reward_emb) 149 | 150 | reproject_task = F.relu(self.task_refc1(state_action_feat)) 151 | reproject_task = self.task_refc2(reproject_task).view(N*self.action_space, self.task_dim) 152 | #if self.norm: reproject_task = F.normalize(reproject_task) 153 | 154 | task_score = torch.mm(reproject_task, self.task_embed( self.variable(range(self.task_num), torch.LongTensor)).t()).view(N, self.action_space, self.task_num) 155 | 156 | return policy, reward, None, task_score 157 | 158 | class GridWorldMLP(nn.Module, ValueNet): 159 | def __init__(self, in_channels, n_actions, arg_dim, embed_dim, scene_num, task_num, \ 160 | feat_dim=1024, scene_dim=16, task_dim=16, gpu=-1, scene_disentanglement=0.1, task_disentanglement=0.1, norm=True, y_norm=True): 161 | super(GridWorldMLP, self).__init__() 162 | out_channels = feat_dim // (2*2) 163 | self.scene_num = scene_num 164 | self.task_num = task_num 165 | self.norm = norm 166 | self.y_norm = y_norm 167 | 168 | self.state_func = GridWorldResNet18(in_channels, out_channels) 169 | self.state_fc = nn.Linear(feat_dim, embed_dim*2) 170 | self.scene_embed = nn.Embedding(scene_num, scene_dim) 171 | self.task_embed = nn.Embedding(task_num, task_dim) 172 | self.policy_fc1 = nn.Linear(embed_dim*2 + scene_dim + task_dim, feat_dim) 173 | self.policy_fc2 = nn.Linear(feat_dim, n_actions) 174 | 175 | self.reward_fc1 = nn.Linear(embed_dim*2 + scene_dim + task_dim, feat_dim) 176 | self.reward_fc2 = nn.Linear(feat_dim, n_actions) 177 | 178 | self.scene_refc1 = nn.Linear(embed_dim*2, embed_dim*4) 179 | self.scene_refc2 = nn.Linear(embed_dim*4, embed_dim) 180 | 181 | self.task_refc1 = nn.Linear(embed_dim*2, embed_dim*4) 182 | self.task_refc2 = nn.Linear(embed_dim*4, embed_dim) 183 | self.scene_disentangle = scene_disentanglement > 0.0 184 | self.task_disentangle = task_disentanglement > 0.0 185 | 186 | BasicNet.__init__(self, gpu) 187 | 188 | def forward(self, xs, _scene_ids=None, _task_ids=None): 189 | xs = self.variable(xs) 190 | 191 | task_ids = self.variable(_task_ids, torch.LongTensor) 192 | task_emb = self.task_embed(task_ids) 193 | if self.norm: task_emb = F.normalize(task_emb) 194 | 195 | scene_ids = self.variable(_scene_ids, torch.LongTensor) 196 | scene_emb = self.scene_embed(scene_ids) 197 | if self.norm: scene_emb = F.normalize(scene_emb) 198 | 199 | y = self.state_func(xs) 200 | y = self.state_fc(y) 201 | y = y.view(y.shape[0], -1) 202 | if self.y_norm: y = F.normalize(y) 203 | in_feat = torch.cat([y, scene_emb, task_emb], 1) 204 | policy_embed = F.relu(self.policy_fc1(in_feat)) 205 | reward_embed = F.relu(self.reward_fc1(in_feat)) 206 | 207 | N = y.size(0) 208 | 209 | reproject_scene = F.relu(self.scene_refc1(y)) 210 | # ( batch_size, num_action, state_action_dim ) -> ( batch_size*num_action, state_action_dim ) 211 | reproject_scene = self.scene_refc2(reproject_scene) 212 | 213 | reproject_task = F.relu(self.task_refc1(y)) 214 | reproject_task = self.task_refc2(reproject_task) 215 | 216 | if self.scene_disentangle: 217 | scene_score = torch.mm(reproject_scene, self.scene_embed(self.variable(range(self.scene_num), torch.LongTensor)).t()).view(N, self.scene_num) 218 | else: 219 | scene_score = None 220 | 221 | if self.task_disentangle: 222 | task_score = torch.mm(reproject_task, self.task_embed( 
self.variable(range(self.task_num), torch.LongTensor)).t()).view(N, self.task_num) 223 | else: 224 | task_score = None 225 | 226 | return self.policy_fc2(policy_embed), self.reward_fc2(reward_embed), scene_score, task_score 227 | -------------------------------------------------------------------------------- /synpo/network/operator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn 4 | from torch.autograd import Variable 5 | from torch.nn import functional as F 6 | from IPython import embed 7 | 8 | # Warpped Layer for global average pooling 9 | class GlobalAveragePool(nn.Module): 10 | def forward(self, x): 11 | N, C = x.size(0), x.size(1) 12 | return x.view(N, C, -1).mean(2) 13 | 14 | def init_weights(module): 15 | for layer in module.children(): 16 | if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear): 17 | nn.init.xavier_uniform(layer.weight.data) 18 | nn.init.constant(layer.bias.data, 0) 19 | 20 | def conv3x3(in_planes, out_planes, stride=1, groups=1): 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, groups=groups, stride=stride, padding=1, bias=False) 22 | 23 | class BasicBlock(nn.Module): 24 | expansion = 1 25 | def __init__(self, in_planes, planes, stride=1): 26 | super(BasicBlock, self).__init__() 27 | self.conv1 = conv3x3(in_planes, planes, stride) 28 | self.bn1 = nn.BatchNorm2d(planes) 29 | self.conv2 = conv3x3(planes, planes) 30 | self.bn2 = nn.BatchNorm2d(planes) 31 | 32 | self.shortcut = nn.Sequential() 33 | if stride != 1 or in_planes != self.expansion * planes: 34 | self.shortcut = nn.Sequential( 35 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, 36 | stride=stride, bias=False), 37 | nn.BatchNorm2d(self.expansion * planes) 38 | ) 39 | 40 | def forward(self, x): 41 | out = F.relu(self.bn1(self.conv1(x))) 42 | out = self.bn2(self.conv2(out)) 43 | out += self.shortcut(x) 44 | out = F.relu(out) 45 | return out 46 | 47 | class GridResNet(nn.Module): 48 | def __init__(self, block, num_blocks, in_channels, nf): 49 | super(GridResNet, self).__init__() 50 | self.in_planes = nf 51 | self.in_channels = in_channels 52 | 53 | self.datanorm = nn.BatchNorm2d(in_channels, affine=False) 54 | self.conv1 = conv3x3(in_channels, nf * 1, groups=4) 55 | self.bn1 = nn.BatchNorm2d(nf * 1) 56 | self.layer1 = self._make_layer(block, nf * 1, num_blocks[0], stride=1) 57 | self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2) 58 | self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2) 59 | self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2) 60 | 61 | def _make_layer(self, block, planes, num_blocks, stride): 62 | strides = [stride] + [1] * (num_blocks - 1) 63 | layers = [] 64 | for stride in strides: 65 | layers.append(block(self.in_planes, planes, stride)) 66 | self.in_planes = planes * block.expansion 67 | return nn.Sequential(*layers) 68 | 69 | def forward(self, x): 70 | bsz = x.size(0) 71 | out = F.relu(self.bn1(self.conv1(self.datanorm(x)))) 72 | out = self.layer1(out) 73 | out = self.layer2(out) 74 | out = self.layer3(out) 75 | out = self.layer4(out) 76 | return out.view(out.size(0), -1) 77 | 78 | class PolicyBasis(nn.Module): 79 | def __init__(self, action_num, state_dim, task_dim): 80 | super(PolicyBasis, self).__init__() 81 | self.state_dim = state_dim 82 | self.task_dim = task_dim 83 | self.action_num = action_num 84 | 85 | self.weight_mu = nn.Parameter(torch.Tensor(action_num, state_dim, task_dim)) 86 | 
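# weight_mu is a per-action bilinear tensor of shape (action_num, state_dim, task_dim):
# in forward(), the state feature is first contracted with it to form state_action_feat,
# which is then multiplied with the synthesized policy and reward embeddings to produce
# the per-action policy scores and reward predictions (plus the biases below).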
self.policy_bias_mu = nn.Parameter(torch.Tensor(action_num)) 87 | self.reward_bias_mu = nn.Parameter(torch.Tensor(action_num)) 88 | 89 | self.reset_parameters() 90 | 91 | def forward(self, input1, input2, input3): 92 | N = input1.size(0) 93 | state_action_feat = torch.mm(input1, self.weight_mu.transpose(1, 0).contiguous().view( 94 | self.state_dim, self.action_num*self.task_dim)).view(N, self.action_num, self.task_dim) 95 | 96 | output1 = torch.bmm(state_action_feat, input2.unsqueeze(2)).squeeze(2) 97 | output2 = torch.bmm(state_action_feat, input3.unsqueeze(2)).squeeze(2) 98 | 99 | return output1 + self.policy_bias_mu, output2 + self.reward_bias_mu, state_action_feat 100 | 101 | def reset_parameters(self): 102 | mu_range = 1 / np.sqrt(self.state_dim*self.task_dim*self.action_num) 103 | self.weight_mu.data.uniform_(-mu_range, mu_range) 104 | self.policy_bias_mu.data.fill_(0) 105 | self.reward_bias_mu.data.fill_(0) 106 | 107 | def __repr__(self): 108 | return self.__class__.__name__ + \ 109 | '(state_featurs={}, task_features={}, action_num={})'.format( 110 | self.state_dim, self.task_dim, self.action_num) 111 | 112 | def GridWorldResNet18(in_channels, nf=128): 113 | return GridResNet(BasicBlock, [2, 2, 2, 2], in_channels, nf // 8) 114 | -------------------------------------------------------------------------------- /synpo/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import * 2 | from .trainer import * 3 | from .tf_logger import Logger 4 | from .utils import * 5 | import logging 6 | logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s: %(message)s') 7 | logger = logging.getLogger('MAIN') 8 | logger.setLevel(logging.INFO) 9 | -------------------------------------------------------------------------------- /synpo/utils/config.py: -------------------------------------------------------------------------------- 1 | class Config: 2 | def __init__(self): 3 | self.task_fn = None 4 | self.optimizer_fn = None 5 | self.network_fn = None 6 | self.policy_fn = None 7 | self.replay_fn = None 8 | self.discount = 0.99 9 | self.target_network_update_freq = 100 10 | self.exploration_steps = 50000 11 | self.logger = None 12 | self.history_length = 4 13 | self.test_interval = 100 14 | self.test_repetitions = 10 15 | self.double_q = False 16 | self.tag = 'vanilla' 17 | self.update_interval = 1 18 | self.action_shift_fn = lambda a: a 19 | self.reward_shift_fn = lambda r: r 20 | self.episode_limit = 0 21 | self.save_interval = 0 22 | self.max_steps = 0 23 | self.max_eps = 200000 24 | self.grad_clip = 0 25 | self.n_test_samples = 100 26 | self.value_loss_weight = 0.5 27 | self.one_traj = False 28 | self.extend = None 29 | -------------------------------------------------------------------------------- /synpo/utils/tf_logger.py: -------------------------------------------------------------------------------- 1 | from tensorboardX import SummaryWriter 2 | import os 3 | import numpy as np 4 | 5 | class Logger(object): 6 | def __init__(self, log_dir, vanilla_logger, skip=False): 7 | try: 8 | for f in os.listdir(log_dir): 9 | if not f.startswith('events'): 10 | continue 11 | os.remove('%s/%s' % (log_dir, f)) 12 | except IOError: 13 | os.mkdir(log_dir) 14 | if not skip: 15 | self.writer = SummaryWriter(log_dir) 16 | self.info = vanilla_logger.info 17 | self.debug = vanilla_logger.debug 18 | self.warning = vanilla_logger.warning 19 | self.skip = skip 20 | self.step = 0 21 | 22 | def scalar_summary(self, tag, value, step=None): 23 | if 
self.skip:
24 |       return
25 |     if step is None:
26 |       step = self.step
27 |       self.step += 1
28 |     if np.isscalar(value):
29 |       value = np.asarray([value])
30 |     self.writer.add_scalar(tag, value, step)
31 | 
32 |   def histo_summary(self, tag, values, step=None):
33 |     if self.skip:
34 |       return
35 |     if step is None:
36 |       step = self.step
37 |       self.step += 1
38 |     self.writer.add_histogram(tag, values, step, bins=1000)
39 | 
--------------------------------------------------------------------------------
/synpo/utils/trainer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import os
4 | import os.path as osp
5 | import copy
6 | from collections import defaultdict
7 | import random
8 | 
9 | from IPython import embed
10 | import torch
11 | from .utils import mkdir
12 | 
13 | def success(r):
14 |   return r > 0
15 | 
16 | def test_agent(agent, test_on_train=False):
17 |   config = agent.config
18 |   combo_test_rewards = defaultdict(list)
19 |   combo_test_success = defaultdict(list)
20 | 
21 |   test_samples = agent.task.train_combos if test_on_train else agent.task.test_combos
22 |   if config.n_test_samples:
23 |     sampled_combos = random.choices(test_samples, k=config.n_test_samples)
24 |   else:
25 |     sampled_combos = test_samples
26 |   for combo in sampled_combos:
27 |     test_reward, _, _, _ = agent.episode(train=False, env_kwargs=dict(index=combo, sample_pos=True))
28 |     combo_test_rewards[combo].append(test_reward)
29 |     combo_test_success[combo].append(success(test_reward))
30 |   config.logger.info('average test success rate and rewards of each task')
31 |   for k, v in combo_test_rewards.items():
32 |     config.logger.info('{} {}: {} {}'.format(k, agent.task.env.task_desc[k[-1]], np.mean(v), np.mean(combo_test_success[k])))
33 |   avg_test_reward = np.mean(np.concatenate(list(combo_test_rewards.values())))
34 |   avg_test_success_rate = np.mean(np.concatenate(list(combo_test_success.values())))
35 |   return avg_test_reward, avg_test_success_rate
36 | 
37 | def train_agent(agent):
38 |   config = agent.config
39 |   window_size = 100
40 |   ep = 0
41 |   rewards = []
42 |   steps = []
43 |   avg_test_rewards = []
44 |   avg_test_success_rates = []
45 |   agent_type = agent.__class__.__name__
46 |   best_model = None
47 |   best_sr = 0
48 |   combo_success = defaultdict(list)
49 |   model_bank = []
50 |   # plot
51 |   name = '{}-{}-{}-{}-{}'.format(agent_type, agent.task.name, config.tag, agent.split_name, agent.config.postfix)
52 |   while True:
53 |     ep += 1
54 |     reward, loss, step, index = agent.episode()
55 |     combo_success[index].append(success(reward))
56 |     rewards.append(reward)
57 |     steps.append(step)
58 |     config.logger.info('episode %d, reward %.3f, idx (%d, %d), avg loss %.3f, total steps %d, episode step %d, epsilon %.2f' % (
59 |       ep, reward, index[0], index[1], sum(loss) / (len(loss) + 1e-12), agent.total_steps, step, agent.policy.epsilon))
60 | 
61 |     if config.episode_limit and ep > config.episode_limit: break
62 | 
63 |     if config.test_interval and ep % config.test_interval == 0 and agent.total_steps > config.exploration_steps:
64 |       config.logger.info('average success rate of each task:')
65 |       for k, v in combo_success.items():
66 |         config.logger.info('{} {}: {}'.format(k, agent.task.env.task_desc[k[-1]], np.mean(v)))
67 |       combo_success.clear()
68 | 
69 |       config.logger.info('Testing on train...')
70 |       avg_train_reward, avg_train_success_rate = test_agent(agent, test_on_train=True)
71 |       config.logger.info('Avg train success rate %f, Avg train reward %f' % (avg_train_success_rate, avg_train_reward))
72 | 
73 |       config.logger.info('Testing on test...')
74 |       avg_test_reward, avg_test_success_rate = test_agent(agent)
75 |       avg_test_rewards.append(avg_test_reward)
76 |       avg_test_success_rates.append(avg_test_success_rate)
77 |       if best_sr <= avg_test_success_rate:
78 |         best_sr = avg_test_success_rate
79 |         best_model = copy.deepcopy(agent.learning_network.state_dict())
80 |       config.logger.info('Avg test success rate %f, Avg test reward %f' % (avg_test_success_rate, avg_test_reward))
81 | 
82 |       #==============================================
83 |       # Unwrapped Model Saving Routine
84 |       #==============================================
85 |       snapshot_filepath = osp.join('data', 'outputs', '{}-{}-{}-{}-{}'.format(agent_type, agent.task.name, config.tag, agent.split_name, agent.config.postfix))
86 |       mkdir(snapshot_filepath)
87 |       torch.save({'best_model_weight': agent.learning_network.state_dict() }, osp.join(snapshot_filepath,
88 |                 'episode.{}.train-sr.{:.3f}.test-sr.{:.3f}.train-rw.{:.3f}.test-rw.{:.3f}.model'.format(ep, avg_train_success_rate, avg_test_success_rate, avg_train_reward, avg_test_reward)))
89 |       torch.save({'task': agent.task.save_config(), 'best_sr': best_sr, 'best_model_weight': best_model, 'rewards': rewards,
90 |                   'steps': steps, 'avg_test_rewards': avg_test_rewards, 'avg_test_success_rates': avg_test_success_rates},
91 |                  osp.join(snapshot_filepath, 'train.record'))
92 | 
93 |     if (config.max_steps and agent.total_steps > config.max_steps) or (config.max_eps and ep > config.max_eps):
94 |       config.logger.info('Testing on train Before Finishing...')
95 |       avg_train_reward, avg_train_success_rate = test_agent(agent, test_on_train=True)
96 |       config.logger.info('Avg train success rate %f, Avg train reward %f' % (avg_train_success_rate, avg_train_reward))
97 | 
98 |       config.logger.info('Testing on test Before Finishing...')
99 |       avg_test_reward, avg_test_success_rate = test_agent(agent)
100 |       avg_test_rewards.append(avg_test_reward)
101 |       avg_test_success_rates.append(avg_test_success_rate)
102 |       if best_sr <= avg_test_success_rate:
103 |         best_sr = avg_test_success_rate
104 |         best_model = copy.deepcopy(agent.learning_network.state_dict())
105 |       config.logger.info('Avg test success rate %f, Avg test reward %f' % (avg_test_success_rate, avg_test_reward))
106 |       #==============================================
107 |       # Unwrapped Model Saving Routine
108 |       #==============================================
109 |       snapshot_filepath = osp.join('data', 'outputs', '{}-{}-{}-{}-{}'.format(agent_type, agent.task.name, config.tag, agent.split_name, agent.config.postfix))
110 |       mkdir(snapshot_filepath)
111 |       torch.save({'best_model_weight': agent.learning_network.state_dict() }, osp.join(snapshot_filepath,
112 |                 'episode.{}.train-sr.{:.3f}.test-sr.{:.3f}.train-rw.{:.3f}.test-rw.{:.3f}.model'.format(ep, avg_train_success_rate, avg_test_success_rate, avg_train_reward, avg_test_reward)))
113 |       torch.save({'task': agent.task.save_config(), 'best_sr': best_sr, 'best_model_weight': best_model, 'rewards': rewards,
114 |                   'steps': steps, 'avg_test_rewards': avg_test_rewards, 'avg_test_success_rates': avg_test_success_rates},
115 |                  osp.join(snapshot_filepath, 'train.record'))
116 |       break
117 | 
118 |   os.system("ls {} > {}/ls.log".format(snapshot_filepath, snapshot_filepath)) # automatically log
119 |   agent.close()
120 |   return steps, rewards, avg_test_rewards, avg_test_success_rates
121 | 
--------------------------------------------------------------------------------
/synpo/utils/utils.py:
-------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import torch 4 | import numpy as np 5 | import random 6 | 7 | 8 | def mkdir(path, rm=False): 9 | if not os.path.exists(path): 10 | os.makedirs(path) 11 | else: 12 | if rm: 13 | shutil.rmtree(path) 14 | os.makedirs(path) 15 | 16 | def set_seed(t, r=None, p=None, c=None): 17 | if r is None: 18 | r = t 19 | if p is None: 20 | p = r 21 | torch.manual_seed(t) 22 | random.seed(r) 23 | np.random.seed(p) 24 | if c is not None: 25 | torch.cuda.manual_seed(c) 26 | 27 | def extract(d, *args): 28 | ret = [] 29 | for k in args: 30 | ret.append(d[k]) 31 | return ret 32 | 33 | def argmax1d(a, random_tie=False): 34 | a = np.asarray(a) 35 | if random_tie: 36 | return np.random.choice(np.flatnonzero(a == a.max())) 37 | else: 38 | return np.argmax(a) 39 | 40 | def discount_cumsum(xs, discount=0.99): 41 | r = 0.0 42 | res = [] 43 | for x in xs[::-1]: 44 | r = r * discount + x 45 | res.append(r) 46 | return res[::-1] 47 | 48 | -------------------------------------------------------------------------------- /tools/generate_gridworld_extend.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | import argparse 3 | import pickle 4 | import numpy as np 5 | from pprint import pprint 6 | 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--seed', default=9487, type=int) 10 | parser.add_argument('--scene_num', default=20, type=int) 11 | parser.add_argument('--task_num', default=20, type=int) 12 | parser.add_argument('--extend_num', default=10, type=int) 13 | parser.add_argument('--split_filepath', default='data/hard_extend.split', type=str) 14 | 15 | args = parser.parse_args() 16 | np.random.seed(args.seed) 17 | 18 | def main(args): 19 | layouts = ['map{}'.format(i) for i in range(0, 20) ] 20 | train_combos = list(product(range(args.extend_num), range(args.extend_num))) 21 | test_combos = list(set(product(range(args.scene_num), range(args.task_num))) - set(train_combos)) 22 | 23 | print('Training combos') 24 | pprint(train_combos) 25 | print('Testing combos') 26 | pprint(test_combos) 27 | 28 | table = np.zeros((args.scene_num, args.task_num)) 29 | for i, j in train_combos: 30 | table[i][j] = 1 31 | 32 | print('Dumping splits to {}'.format(args.split_filepath)) 33 | with open(args.split_filepath, 'wb') as handle: 34 | pickle.dump({'train_combos': train_combos, 'test_combos': test_combos, 'scene_num': args.scene_num, 'task_num': args.task_num, 'layouts': layouts}, handle) 35 | 36 | if __name__ == '__main__': 37 | main(args) 38 | -------------------------------------------------------------------------------- /tools/generate_gridworld_split.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | import argparse 3 | import pickle 4 | import numpy as np 5 | from pprint import pprint 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--seed', default=9487, type=int) 9 | parser.add_argument('--scene_num', default=20, type=int) 10 | parser.add_argument('--task_num', default=20, type=int) 11 | parser.add_argument('--cover', default=4, type=int) 12 | parser.add_argument('--num_train', default=144, type=int) 13 | parser.add_argument('--split_filepath', default='data/hard.split', type=str) 14 | 15 | args = parser.parse_args() 16 | np.random.seed(args.seed) 17 | 18 | def main(args): 19 | def sanity_check(train_combos, scene_num, task_num, num=4): 20 
| from collections import defaultdict 21 | row = defaultdict(list) 22 | col = defaultdict(list) 23 | for i in range(scene_num): 24 | row[i] = [] 25 | for i in range(task_num): 26 | col[i] = [] 27 | for combo in train_combos: 28 | row[combo[0]].append(combo) 29 | col[combo[1]].append(combo) 30 | for k, v in row.items(): 31 | if len(v) < num: return False 32 | for k, v in col.items(): 33 | if len(v) < num: return False 34 | 35 | return True 36 | 37 | layouts = ['map{}'.format(i) for i in range(0, 20) ] 38 | total_combos = [(i, j) for i, j in product(range(args.scene_num), range(args.task_num))] 39 | if args.num_train == -1: 40 | args.num_train = args.scene_num*args.cover+args.task_num*args.cover-args.cover*args.cover 41 | while True: 42 | train_split = np.random.choice(range(args.scene_num*args.task_num), 43 | args.num_train, 44 | replace=False).tolist() 45 | train_combos = [ total_combos[idx] for idx in train_split ] 46 | if sanity_check(train_combos, args.scene_num, args.task_num, num=args.cover): break 47 | 48 | test_combos = list(set(total_combos) - set(train_combos)) 49 | 50 | print('Training combos') 51 | pprint(train_combos) 52 | print('Testing combos') 53 | pprint(test_combos) 54 | 55 | table = np.zeros((args.scene_num, args.task_num)) 56 | for i, j in train_combos: 57 | table[i][j] = 1 58 | 59 | print('Dumping splits to {}'.format(args.split_filepath)) 60 | with open(args.split_filepath, 'wb') as handle: 61 | pickle.dump({'train_combos': train_combos, 'test_combos': test_combos, 'scene_num': args.scene_num, 'task_num': args.task_num, 'layouts': layouts}, handle) 62 | 63 | if __name__ == '__main__': 64 | main(args) 65 | -------------------------------------------------------------------------------- /train_gridworld.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | import ipdb 4 | import random 5 | from datetime import datetime 6 | from itertools import product 7 | from tqdm import tqdm 8 | import numpy as np 9 | import pickle 10 | from IPython import embed 11 | from ipdb import slaunch_ipdb_on_exception 12 | 13 | from synpo.agent import * 14 | from synpo.component import * 15 | from synpo.utils import * 16 | import synpo.gridworld as gridworld 17 | 18 | from synpo.utils import mkdir, set_seed 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--gpu_id', default=0, type=int) 22 | parser.add_argument('--batch_size', default=128, type=int) 23 | parser.add_argument('--weight', default=None, type=str) 24 | parser.add_argument('--scene', default=None, type=int) 25 | parser.add_argument('--task', default=None, type=int) 26 | parser.add_argument('--embedding_dim', default=128, type=int) 27 | parser.add_argument('--scene_embedding_dim', default=128, type=int) 28 | parser.add_argument('--task_embedding_dim', default=128, type=int) 29 | parser.add_argument('--num_obj_types', default=5, type=int) 30 | parser.add_argument('--task_length', default=2, type=int) 31 | parser.add_argument('--update_interval', default=1, type=int) 32 | parser.add_argument('--scene_num', default=5, type=int) 33 | parser.add_argument('--task_num', default=5, type=int) 34 | parser.add_argument('--reward_prediction', default=1, type=int, 35 | help="loss weight of reward prediction objective") 36 | parser.add_argument('--scene_disentanglement', default=0.1, type=float, 37 | help="loss weight of scene disentanglement prediction objective") 38 | parser.add_argument('--task_disentanglement', default=0.1, type=float, 39 | 
help="loss weight of task disentanglement prediction objective") 40 | parser.add_argument('--split_filepath', default=None, type=str, required=True, 41 | help="train/test split filepath") 42 | parser.add_argument('--lr', default=0.001, type=float, 43 | help="base learning rate") 44 | parser.add_argument('--wd', action='store_true', 45 | help="enable weight decay") 46 | parser.add_argument('--mode', default='cloning', choices=['cloning'], 47 | help="training mode [only behavior cloing available for now]") 48 | parser.add_argument('--network', default='synpo', choices=['mlp', 'mtl', 'synpo'], 49 | help="select model architecture") 50 | parser.add_argument('--postfix', default='', type=str, 51 | help="postfix to the log file") 52 | parser.add_argument('--repeat', default=10, type=int, 53 | help="number of test run") 54 | parser.add_argument('--evaluate', action='store_true', 55 | help="evaluation mode") 56 | parser.add_argument('--visualize', action='store_true', 57 | help="visualize policy [only in evaluation mode]") 58 | parser.add_argument('--random_seed', default=0, type=int, 59 | help="random seed value") 60 | parser.add_argument('--logger_name', default='log/synpo_{}_{}_{}_{}.log', type=str, 61 | help="logger name format [must have for slots to fill]") 62 | parser.add_argument('--norm', action='store_true', 63 | help="whether normalize the scene/task embedding") 64 | parser.add_argument('--extend_mode', action='store_true', 65 | help="train on the first (10 ENV, 10 TASK) combinations.") 66 | args = parser.parse_args() 67 | 68 | def get_network(task): 69 | arg_dim = task.env.observation_space.spaces[1].shape[0] 70 | grid_dim = task.env.observation_space.spaces[0].shape[0] 71 | action_dim = task.env.action_space.n 72 | if args.network == 'mlp': 73 | network = GridWorldMLP(grid_dim, action_dim, arg_dim, 74 | scene_num=args.scene_num, 75 | task_num=args.task_num, 76 | embed_dim=args.embedding_dim, 77 | scene_dim=args.scene_embedding_dim, 78 | task_dim=args.task_embedding_dim, 79 | gpu=args.gpu_id, 80 | scene_disentanglement=args.scene_disentanglement, 81 | task_disentanglement=args.task_disentanglement, 82 | norm=args.norm) 83 | elif args.network == 'mtl': 84 | network = GridWorldMTL(grid_dim, action_dim, arg_dim, 85 | scene_num=args.scene_num, 86 | task_num=args.task_num, 87 | embed_dim=args.embedding_dim, 88 | scene_dim=args.scene_embedding_dim, 89 | task_dim=args.task_embedding_dim, 90 | gpu=args.gpu_id, 91 | scene_disentanglement=args.scene_disentanglement, 92 | task_disentanglement=args.task_disentanglement, 93 | norm=args.norm) 94 | elif args.network == 'synpo': 95 | network = GridWorldSynPo(grid_dim, action_dim, arg_dim, 96 | scene_num=args.scene_num, 97 | task_num=args.task_num, 98 | embed_dim=args.embedding_dim, 99 | scene_dim=args.scene_embedding_dim, 100 | task_dim=args.task_embedding_dim, 101 | gpu=args.gpu_id, 102 | norm=args.norm) 103 | else: 104 | raise ValueError('Non-supported Network') 105 | return network 106 | 107 | def gridworld_behaviour_cloning(args, layouts, train_combos, test_combos): 108 | config = Config() 109 | grid_world_task = GridWorldTask(layouts, 110 | num_obj_types=args.num_obj_types, 111 | task_length=args.task_length, 112 | history_length= config.history_length, 113 | train_combos=train_combos, 114 | test_combos=test_combos) 115 | config.task_fn = lambda: grid_world_task 116 | if args.wd: 117 | print('with weight decay!') 118 | config.optimizer_fn = lambda params: torch.optim.Adam(params, lr=args.lr, weight_decay=10e-5) 119 | else: 120 | print('without 
weight decay!') 121 | config.optimizer_fn = lambda params: torch.optim.Adam(params, lr=args.lr) 122 | 123 | network = get_network(grid_world_task) 124 | if args.weight is not None: network.load_state_dict(torch.load(args.weight)['best_model_weight']) 125 | 126 | print(network) 127 | 128 | config.network_fn = lambda: network 129 | config.replay_fn = lambda: TrajectoryReplay(memory_size=20000, 130 | max_length=200, 131 | batch_size=64) # number of trajectory per batch 132 | config.policy_fn = lambda: GreedyPolicy(epsilon=0.1, 133 | final_step=500000, 134 | min_epsilon=0.0) 135 | config.logger = Logger('./log', logger) 136 | config.test_interval = 2000 137 | config.exploration_steps = 50000 138 | config.postfix = args.postfix 139 | config.tag = network.__class__.__name__ 140 | config.update_interval = 1 # preset 141 | config.scene_disentanglement_coeff = args.scene_disentanglement 142 | config.task_disentanglement_coeff = args.task_disentanglement 143 | return GridBehaviourCloning(config) 144 | 145 | if __name__ == '__main__': 146 | mkdir('data') 147 | mkdir('log') 148 | os.system('export OMP_NUM_THREADS=1') 149 | 150 | if args.extend_mode: # Hardcoding numbers of scenes and tasks for training 151 | args.scene_num = 10 152 | args.task_num = 10 153 | 154 | set_seed(args.random_seed, c=args.random_seed) 155 | if args.split_filepath is None: # Default Multi-task Setting 156 | layouts = ['map{}'.format(i) for i in range(0, 20) ] 157 | train_combos = [(i, j) for i, j in product(range(args.scene_num), range(args.task_num))] 158 | test_combos = [(i, j) for i, j in product(range(args.scene_num), range(args.task_num))] 159 | else: 160 | with open(args.split_filepath, 'rb') as handle: 161 | data = pickle.load(handle) 162 | args.task_num = data['task_num'] 163 | args.scene_num = data['scene_num'] 164 | train_combos = data['train_combos'] 165 | test_combos = data['test_combos'] 166 | layouts = data['layouts'] 167 | print('num train:', len(train_combos), 'num test:', len(test_combos)) 168 | 169 | if args.mode == 'cloning': 170 | print('Loading Episodic Behavior Cloning') 171 | agent = gridworld_behaviour_cloning(args, layouts, train_combos, test_combos) 172 | 173 | agent.reward_prediction = args.reward_prediction 174 | if args.split_filepath is None: # Default Multi-task Setting 175 | agent.split_name = 'MTL' 176 | else: 177 | agent.split_name = "-".join(args.split_filepath.split('/')[-2:]) 178 | if args.evaluate: 179 | with slaunch_ipdb_on_exception(): 180 | traj_length = [] 181 | if args.scene is not None or args.task is not None: 182 | if args.scene is not None and args.task is None: 183 | index_scene = args.scene 184 | index_task = random.sample([x[1] for x in train_combos if x[0] == args.scene], 1)[0] 185 | else: 186 | index_scene = args.scene if args.scene is not None else np.random.randint(args.scene_num) 187 | index_task = args.task if args.task is not None else np.random.randint(args.task_num) 188 | for _ in tqdm(range(args.repeat)): 189 | success, traj_len, _, _ = agent.evaluate(visualize=args.visualize, 190 | index=(index_scene, index_task)) # main program 191 | if success: 192 | traj_length.append(traj_len) 193 | print('mean length:', np.mean(traj_length)) 194 | else: 195 | rates = [] 196 | for combo in train_combos: 197 | success_list = [] 198 | trajectory_list = [] 199 | for _ in tqdm(range(args.repeat)): 200 | success, traj_len, _ = agent.evaluate(visualize=args.visualize, index=combo) # main program 201 | success_list.append(success) 202 | trajectory_list.append(traj_len) 203 | 
success_rate = sum(success_list) / len(success_list) 204 | rates.append(success_rate) 205 | print('* [Task={}, # of Tests={}] Average success rate: {:.4f}, Average trajectory length: {}'.format( combo, args.repeat, 206 | success_rate, sum(trajectory_list) / len(trajectory_list) )) 207 | print('average success rate: {:.4f}'.format(np.mean(rates))) 208 | else: 209 | logger.setLevel(logging.INFO) 210 | handler = logging.FileHandler(args.logger_name.format(agent.__class__.__name__, 211 | agent.learning_network.__class__.__name__, 212 | datetime.now().strftime("%Y-%m-%d_%H:%M:%S"), 213 | args.postfix)) 214 | logger.addHandler(handler) 215 | with slaunch_ipdb_on_exception(): 216 | train_agent(agent) # main program 217 | 218 | --------------------------------------------------------------------------------
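
Note on the split files consumed above: `train_gridworld.py` loads the pickle written by `tools/generate_gridworld_split.py` or `tools/generate_gridworld_extend.py`, a dict with the keys `train_combos`, `test_combos`, `scene_num`, `task_num`, and `layouts`. The snippet below is a minimal standalone sketch, not part of the repository, for inspecting such a file before training; the helper name `inspect_split` and the printed summary are illustrative only, while the dict keys and the default path come from the scripts and data files shown above.

```python
# Standalone sketch: inspect a train/test split produced by
# tools/generate_gridworld_split.py before passing it to train_gridworld.py.
# Only the dict keys visible in those scripts are assumed; the function name
# and printed summary are illustrative.
import pickle
from itertools import product

def inspect_split(split_filepath='data/gridworld.20s20t/hard.split'):
  with open(split_filepath, 'rb') as handle:
    data = pickle.load(handle)

  train = set(map(tuple, data['train_combos']))
  test = set(map(tuple, data['test_combos']))
  full = set(product(range(data['scene_num']), range(data['task_num'])))

  # The generators build test_combos as the complement of train_combos, so the
  # two sets should be disjoint and jointly cover the full (scene, task) grid.
  assert train.isdisjoint(test), 'train/test (scene, task) pairs overlap'
  assert train | test == full, 'split does not cover the full scene x task grid'

  print('{} layouts, {} x {} grid, {} train / {} test combos'.format(
    len(data['layouts']), data['scene_num'], data['task_num'],
    len(train), len(test)))
  return data

if __name__ == '__main__':
  inspect_split()
```

A file that passes these checks can then be handed to `train_gridworld.py` via `--split_filepath`, together with the model and mode flags defined in its argument parser (e.g. `--network synpo --mode cloning`).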