├── __init__.py ├── dqn ├── __init__.py ├── config.yaml ├── model.py ├── train.py ├── experiment.py └── ai.py ├── tabular ├── __init__.py ├── config.yaml ├── ai.py ├── train.py └── experiment.py ├── results └── README.md ├── .gitignore ├── run.sh ├── README.md ├── LICENSE.txt ├── environment ├── README.md └── fruit_collection.py └── utils.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dqn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tabular/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /results/README.md: -------------------------------------------------------------------------------- 1 | This folder is used for logging. -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .vscode 3 | *.pyc 4 | __pycache__ 5 | *.pkl 6 | *.h5 7 | *.json 8 | *.pdf 9 | *.csv 10 | *.bak 11 | *.npy 12 | *render* 13 | test* -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | echo '=====' 2 | echo 'Tabular GVF ...' 3 | THEANO_FLAG="device=cpu" ipython ./tabular/train.py -- -o use_gvf True -o folder_name tabular_gvf_ -o nb_experiments 5 4 | 5 | echo '=====' 6 | echo 'Tabular NO_GVF ...' 7 | THEANO_FLAG="device=cpu" ipython ./tabular/train.py -- -o use_gvf False -o folder_name tabular_no-gvf_ -o nb_experiments 5 8 | 9 | echo '=====' 10 | echo 'DQN all baselines ...' 
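# dqn/train.py defaults to --mode all, so this single command runs the dqn, dqn+1, hra, and hra+1 baselines back to back.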
11 | THEANO_FLAG="device=cpu" ipython ./dqn/train.py -- -o nb_experiments 5 12 | -------------------------------------------------------------------------------- /tabular/config.yaml: -------------------------------------------------------------------------------- 1 | nb_experiments: 1 2 | nb_epochs: 200 3 | epoch_size: 1 4 | nb_eval_episodes: 5 5 | folder_location: '/results/' 6 | folder_name: 'tabular_' 7 | random_seed: 7654 8 | 9 | aggregator_epsilon: 1.0 10 | aggregator_final_epsilon: 1.0 11 | aggregator_decay_start: 1 12 | aggregator_decay_steps: 1 13 | 14 | init_q: 0.0 15 | gamma: 0.99 16 | alpha: 0.001 17 | final_alpha: 0.001 18 | alpha_decay_steps: 1 19 | alpha_decay_start: 1 20 | 21 | learning_method: mean # mean or max 22 | use_gvf: True 23 | -------------------------------------------------------------------------------- /dqn/config.yaml: -------------------------------------------------------------------------------- 1 | rendering: False 2 | nb_experiments: 1 3 | random_seed: 1234 4 | total_eps: 5000 5 | is_learning: True 6 | eps_per_epoch: 10 7 | max_start_nullops: 0 8 | is_testing: True 9 | eps_per_test: 100 10 | episode_max_len: 300 11 | folder_location: '/results/' 12 | test: False 13 | 14 | epsilon: 1.0 15 | annealing: False 16 | final_epsilon: 1.0 17 | test_epsilon: 0.0 18 | 19 | learning_rate: 0.001 20 | minibatch_size: 32 21 | history_len: 1 22 | replay_max_size: 10000 23 | replay_min_size: 1000 24 | learning_frequency: 4 25 | update_freq: 100 26 | action_dim: 1 27 | num_units: 250 28 | -------------------------------------------------------------------------------- /dqn/model.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Input 2 | from keras.layers.core import Dense, Flatten 3 | from keras.models import Model 4 | 5 | floatX = 'float32' 6 | 7 | 8 | def build_dense(state_shape, nb_units, nb_actions, nb_channels, remove_features=False): 9 | if remove_features: 10 | state_shape = state_shape[: -1] + [state_shape[-1] - nb_channels + 1] 11 | input_dim = tuple(state_shape) 12 | states = Input(shape=input_dim, dtype=floatX, name='states') 13 | flatten = Flatten()(states) 14 | hid = Dense(output_dim=nb_units, init='he_uniform', activation='relu', name='hidden')(flatten) 15 | out = Dense(output_dim=nb_actions, init='he_uniform', activation='linear', name='out')(hid) 16 | return Model(input=states, output=out) 17 | -------------------------------------------------------------------------------- /tabular/ai.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import numpy as np 3 | 4 | floatX = np.float32 5 | 6 | 7 | class AI(object): 8 | def __init__(self, nb_actions, init_q, gamma, alpha, learning_method, rng, freeze=False): 9 | self.nb_actions = nb_actions 10 | self.freeze = freeze # no learning if True 11 | self.init_q = init_q 12 | self.q = dict() 13 | self.gamma = gamma 14 | self.alpha = alpha 15 | self.start_alpha = alpha 16 | self.learning_method = learning_method 17 | self.rng = rng 18 | 19 | def reset(self): 20 | self.q = dict() 21 | self.alpha = deepcopy(self.start_alpha) 22 | 23 | def _get_q(self, s, a): 24 | sa = tuple(list(s) + [a]) 25 | if sa in self.q: 26 | return self.q[sa] 27 | else: 28 | self._set_init_q(s) 29 | return self.init_q 30 | 31 | def get_q(self, s, a=None): 32 | if a is not None: 33 | return self._get_q(s, a) 34 | else: 35 | return np.asarray([self._get_q(s, a) for a in range(self.nb_actions)], dtype=floatX) 36 | 37 | 
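# Q-values are stored sparsely in a dict keyed by the tuple (state..., action); unseen state-action pairs are lazily initialised to init_q via _set_init_q.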
def _set_q(self, s, a, q): 38 | sa = tuple(list(s) + [a]) 39 | if sa not in self.q: 40 | self._set_init_q(s) 41 | self.q[sa] = floatX(q) 42 | 43 | def _set_init_q(self, s): 44 | s = list(s) 45 | for a in range(self.nb_actions): 46 | self.q[tuple(s + [a])] = floatX(self.init_q) 47 | 48 | def get_max_action(self, s, stochastic=True): 49 | values = self.get_q(s) 50 | if stochastic: 51 | actions = np.where(values == values.max())[0] 52 | return self.rng.choice(actions) 53 | else: 54 | return np.argmax(values) 55 | 56 | def learn(self, s, a, r, s2, term): 57 | if self.freeze: 58 | return 59 | if term: 60 | q2 = 0. 61 | elif self.learning_method == 'max': 62 | q2 = np.max(self.get_q(s2)) 63 | elif self.learning_method == 'mean': 64 | q2 = np.mean(self.get_q(s2)) 65 | else: 66 | raise ValueError('Learning method is not known.') 67 | delta = r + self.gamma * q2 - self._get_q(s, a) 68 | self._set_q(s, a, self._get_q(s, a) + self.alpha * delta) 69 |
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hybrid Reward Architecture 2 | This repository hosts the code published along with the following NIPS article (Experiment 4.1: Fruit Collection Task): 3 | 4 | * https://arxiv.org/abs/1706.04208 5 | 6 | For more information about this article, see the following blog posts: 7 | 8 | * https://www.microsoft.com/en-us/research/blog/hybrid-reward-architecture-achieving-super-human-ms-pac-man-performance/ 9 | * https://blogs.microsoft.com/ai/2017/06/14/divide-conquer-microsoft-researchers-used-ai-master-ms-pac-man/ 10 | 11 | # Dependencies 12 | 13 | We strongly suggest using the [Anaconda distribution](https://www.anaconda.com/download/). 14 | 15 | * Python 3.5 or higher 16 | * pygame 1.9.2+ (pip install pygame) 17 | * click (pip install click) 18 | * numpy (pip install numpy -- or install Anaconda distribution) 19 | * [Keras](https://keras.io) 1.2.0+, but less than 2.0 (pip install keras==1.2) 20 | * Theano or Tensorflow. The code is fully tested on Theano. (pip install theano) 21 | 22 | # Usage 23 | 24 | While a run is in progress, the results as well as the **AI** models are saved in the `./results` subfolder. For a complete run (five experiments for each method), use the following command; it may take several hours depending on your machine: 25 | 26 | ``` 27 | ./run.sh 28 | ``` 29 | 30 | * NOTE: Because the state shape is relatively small, the deep RL methods in this code run faster on CPU. 31 | 32 | Alternatively, for a single run use the following commands: 33 | 34 | * Tabular GVF: 35 | ``` 36 | ipython ./tabular/train.py -- -o use_gvf True -o folder_name tabular_gvf_ -o nb_experiments 1 37 | ``` 38 | 39 | * Tabular no-GVF: 40 | ``` 41 | ipython ./tabular/train.py -- -o use_gvf False -o folder_name tabular_no-gvf_ -o nb_experiments 1 42 | ``` 43 | 44 | * DQN: 45 | ``` 46 | THEANO_FLAG="device=cpu" ipython ./dqn/train.py -- --mode hra+1 -o nb_experiments 1 47 | ``` 48 | * `--mode` can be any of `dqn`, `dqn+1`, `hra`, `hra+1`, or `all`. 49 | 50 | # Demo 51 | 52 | We also provide code to demo the Tabular GVF/no-GVF methods. First train a model using one of the above commands (Tabular GVF or no-GVF) and then run the demo.
For example, 53 | ``` 54 | ipython ./tabular/train.py -- -o use_gvf True -o folder_name tabular_gvf_ -o nb_experiments 1 55 | ipython ./tabular/train.py -- --demo -o folder_name tabular_gvf_ 56 | ``` 57 | 58 | If you would like to save the results, use the `--save` option: 59 | ``` 60 | ipython ./tabular/train.py -- --demo --save -o folder_name tabular_gvf_ 61 | ``` 62 | The rendered images will be saved in `./render` directory by default. 63 | 64 | # License 65 | 66 | Please refer to LICENSE.txt. 67 | -------------------------------------------------------------------------------- /dqn/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import click 3 | import yaml 4 | import numpy as np 5 | 6 | from utils import Font, set_params 7 | from dqn.experiment import DQNExperiment 8 | from environment.fruit_collection import FruitCollectionMini 9 | from dqn.ai import AI 10 | 11 | np.set_printoptions(suppress=True, linewidth=200, precision=2) 12 | floatX = 'float32' 13 | 14 | 15 | def worker(params): 16 | np.random.seed(seed=params['random_seed']) 17 | random_state = np.random.RandomState(params['random_seed']) 18 | env = FruitCollectionMini(rendering=False, game_length=300, state_mode='mini') 19 | params['reward_dim'] = len(env.possible_fruits) 20 | for ex in range(params['nb_experiments']): 21 | print('\n') 22 | print(Font.bold + Font.red + '>>>>> Experiment ', ex, ' >>>>>' + Font.end) 23 | print('\n') 24 | 25 | ai = AI(env.state_shape, env.nb_actions, params['action_dim'], params['reward_dim'], 26 | history_len=params['history_len'], gamma=params['gamma'], learning_rate=params['learning_rate'], 27 | epsilon=params['epsilon'], test_epsilon=params['test_epsilon'], minibatch_size=params['minibatch_size'], 28 | replay_max_size=params['replay_max_size'], update_freq=params['update_freq'], 29 | learning_frequency=params['learning_frequency'], num_units=params['num_units'], rng=random_state, 30 | remove_features=params['remove_features'], use_mean=params['use_mean'], use_hra=params['use_hra']) 31 | 32 | expt = DQNExperiment(env=env, ai=ai, episode_max_len=params['episode_max_len'], 33 | history_len=params['history_len'], max_start_nullops=params['max_start_nullops'], 34 | replay_min_size=params['replay_min_size'], folder_location=params['folder_location'], 35 | folder_name=params['folder_name'], testing=params['test'], score_window_size=100, 36 | rng=random_state) 37 | env.reset() 38 | if not params['test']: 39 | with open(expt.folder_name + '/config.yaml', 'w') as y: 40 | yaml.safe_dump(params, y) # saving params for future reference 41 | expt.do_training(total_eps=params['total_eps'], eps_per_epoch=params['eps_per_epoch'], 42 | eps_per_test=params['eps_per_test'], is_learning=True, is_testing=True) 43 | else: 44 | raise NotImplementedError 45 | 46 | 47 | @click.command() 48 | @click.option('--mode', default='all', help='Which method to run: dqn, dqn+1, hra, hra+1, all') 49 | @click.option('--options', '-o', multiple=True, nargs=2, type=click.Tuple([str, str])) 50 | def run(mode, options): 51 | valid_modes = ['dqn', 'dqn+1', 'hra', 'hra+1', 'all'] 52 | assert mode in valid_modes 53 | if mode in ['all']: 54 | modes = valid_modes[:-1] 55 | else: 56 | modes = [mode] 57 | 58 | dir_path = os.path.dirname(os.path.realpath(__file__)) 59 | cfg_file = os.path.join(dir_path, 'config.yaml') 60 | params = yaml.safe_load(open(cfg_file, 'r')) 61 | # replacing params with command line options 62 | for opt in options: 63 | assert opt[0] in params 64 | dtype = 
type(params[opt[0]]) 65 | if dtype == bool: 66 | new_opt = False if opt[1] != 'True' else True 67 | else: 68 | new_opt = dtype(opt[1]) 69 | params[opt[0]] = new_opt 70 | 71 | for m in modes: 72 | params = set_params(params, m) 73 | worker(params) 74 | 75 | 76 | if __name__ == '__main__': 77 | run() 78 | -------------------------------------------------------------------------------- /tabular/train.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import yaml 4 | 5 | import click 6 | import numpy as np 7 | 8 | from tabular.ai import AI 9 | from tabular.experiment import SoCExperiment 10 | from environment.fruit_collection import FruitCollectionMini 11 | 12 | np.set_printoptions(suppress=True, precision=2) 13 | 14 | 15 | def soc_agent(params): 16 | rng = np.random.RandomState(params['random_seed']) 17 | env = FruitCollectionMini(rendering=False, game_length=300) 18 | for mc_count in range(params['nb_experiments']): 19 | ai_list = [] 20 | if not params['use_gvf']: 21 | for _ in range(env.nb_targets): 22 | fruit_ai = AI(nb_actions=env.nb_actions, init_q=params['init_q'], gamma=params['gamma'], 23 | alpha=params['alpha'], learning_method=params['learning_method'], rng=rng) 24 | ai_list.append(fruit_ai) 25 | else: 26 | for _ in env.possible_fruits: 27 | gvf_ai = AI(nb_actions=env.nb_actions, init_q=params['init_q'], gamma=params['gamma'], 28 | alpha=params['alpha'], learning_method=params['learning_method'], rng=rng) 29 | ai_list.append(gvf_ai) 30 | expt = SoCExperiment(ai_list=ai_list, env=env, aggregator_epsilon=params['aggregator_epsilon'], 31 | aggregator_final_epsilon=params['aggregator_final_epsilon'], 32 | aggregator_decay_steps=params['aggregator_decay_steps'], 33 | aggregator_decay_start=params['aggregator_decay_start'], final_alpha=params['final_alpha'], 34 | alpha_decay_steps=params['alpha_decay_steps'], alpha_decay_start=params['alpha_decay_start'], 35 | epoch_size=params['epoch_size'], folder_name=params['folder_name'], 36 | folder_location=params['folder_location'], 37 | nb_eval_episodes=params['nb_eval_episodes'], use_gvf=params['use_gvf'], rng=rng) 38 | with open(expt.folder_name + '/config.yaml', 'w') as y: 39 | yaml.safe_dump(params, y) # saving params for future reference 40 | expt.do_epochs(number=params['nb_epochs']) 41 | 42 | 43 | def demo_soc(params, nb_episodes, rendering_sleep, saving): 44 | rng = np.random.RandomState(1234) 45 | env = FruitCollectionMini(rendering=True, lives=1, game_length=300, image_saving=saving, rng=rng) 46 | i = 0 47 | while os.path.exists(os.getcwd() + params['folder_location'] + params['folder_name'] + str(i)): 48 | i += 1 49 | file_name = os.getcwd() + params['folder_location'] + params['folder_name'] + str(i - 1) + '/soc_ai_list.pkl' 50 | print(file_name) 51 | with open(file_name, 'rb') as f: 52 | ai_list = pickle.load(f) 53 | expt = SoCExperiment(ai_list=ai_list, env=env, aggregator_epsilon=0., aggregator_final_epsilon=None, 54 | aggregator_decay_steps=None, aggregator_decay_start=None, use_gvf=params['use_gvf'], 55 | epoch_size=params['epoch_size'], make_folder=False, folder_name=params['folder_name'], 56 | folder_location=params['folder_location'], nb_eval_episodes=params['nb_eval_episodes'], 57 | final_alpha=None, alpha_decay_steps=None, alpha_decay_start=None, rng=rng) 58 | expt.demo(nb_episodes=nb_episodes, rendering_sleep=rendering_sleep) 59 | 60 | 61 | @click.command() 62 | @click.option('--demo/--no-demo', default=False, help='Do a demo.') 63 | 
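# --save only takes effect together with --demo: main() forwards it to demo_soc(..., saving=save).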
@click.option('--save/--no-save', default=False, help='Save images.') 64 | @click.option('--options', '-o', multiple=True, nargs=2, type=click.Tuple([str, str])) 65 | def main(options, demo, save): 66 | dir_path = os.path.dirname(os.path.realpath(__file__)) 67 | config = os.path.join(dir_path, 'config.yaml') 68 | with open(config, 'r') as f: 69 | params = yaml.safe_load(f) 70 | # replacing params with command line options 71 | for opt in options: 72 | assert opt[0] in params 73 | dtype = type(params[opt[0]]) 74 | if dtype == bool: 75 | new_opt = False if opt[1] != 'True' else True 76 | else: 77 | new_opt = dtype(opt[1]) 78 | params[opt[0]] = new_opt 79 | 80 | if demo: 81 | demo_soc(params, nb_episodes=3, rendering_sleep=0.1, saving=save) 82 | else: 83 | soc_agent(params) 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Microsoft Research License Agreement for HRA Samples 2 | 3 | This Microsoft Research License Agreement (“Agreement”) is a legal agreement between you and Microsoft Corporation (or based on where you live, one of its affiliates). They apply to the Microsoft Research software named above which may include source code and any associated materials or documentation and any updates we provide in our discretion (together, the “Software”). By agreeing to this Agreement and/or by using the Software, you accept these terms. If you do not accept them, do not use the Software. If you comply with these license terms, you have the rights below. 4 | 5 | 1. SCOPE OF RIGHTS. 6 | 7 | a. You may use, copy, modify, create derivative works, and distribute the Software for non-commercial purposes, subject to the restrictions in this Agreement. Examples of non-commercial uses are teaching, academic research, public demonstrations and personal experimentation. 8 | b. If you distribute the Software or any derivative works of the Software, you will distribute them under the same terms and conditions as in this license, and you will not grant other rights to the Software or derivative works that are different from those provided by this Agreement. 9 | c. If you have created derivative works of the Software, and distribute such derivative works, you will cause the modified files to carry prominent notices so that recipients know that they are not receiving the original software. Such notices must state: (i) that you have changed the software; and (ii) the date of any changes. 10 | 11 | 2. RESTRICTIONS AND LIMITATIONS. You may not: 12 | 13 | a. Alter any copyright, trademark or patent notice in the software; 14 | b. Use Microsoft’s trademarks in your programs’ names or in a way that suggests your derivative works or modifications come from or are endorsed by Microsoft; or 15 | c. Include the software in malicious, deceptive or unlawful programs. 16 | The Software is licensed, not sold. This Agreement only gives you some rights to use the Software. Microsoft reserves all other rights. The patent rights, if any, granted to you in this Agreement only apply to the Software, not to any derivative works you make. 17 | 18 | 3. LICENSE TO MICROSOFT. 
Microsoft is granted back, without any restrictions or limitations, a non-exclusive, perpetual, irrevocable, royalty-free, assignable and sub-licensable license, to reproduce, publicly perform or display, install, use, modify, post, distribute, make and have made, sell and transfer your modifications to and/or derivative works of the Software source code or data, for any purpose. 19 | 20 | 4. FEEDBACK. Any feedback about the Software provided by you to us is voluntarily given, and Microsoft shall be free to use the feedback as it sees fit without obligation or restriction of any kind, even if the feedback is designated by you as confidential. 21 | 22 | 5. TERMINATION. If you breach this Agreement or if you sue Microsoft or any other party over patents that you think may apply to or read on the Software or anyone's use of the Software, this Agreement (and your license and rights obtained herein) terminate automatically. If this Agreement is terminated, you must cease using and distributing any derivative works or modifications of the Software. Any sections that are intended to survive termination of this Agreement shall survive. 23 | 24 | 6. EXPORT RESTRICTIONS. The Software is subject to United States export laws and regulations. You must comply with all domestic and international export laws and regulations that apply to the Software. These laws include restrictions on destinations, end users and end use. For additional information, see www.microsoft.com/exporting. 25 | 26 | 7. ENTIRE AGREEMENT. This Agreement, any exhibits, and the terms for supplements, updates, Internet-based services and support services that you use, are the entire agreement for the Software and support services. 27 | 28 | 8. SEVERABILITY. If any court of competent jurisdiction determines that any provision of this Agreement is illegal, invalid or unenforceable, the remaining provisions will remain in full force and effect. 29 | 30 | 9. GOVERNING LAW AND VENUE. This Agreement is governed by and construed in accordance with the laws of the state of Washington, without reference to its choice of law principles to the contrary. Each party hereby consents to the jurisdiction and venue of the state and federal courts located in King County, Washington, with regard to any suit or claim arising under or by reason of this Agreement. 31 | 32 | 10. LEGAL EFFECT. This Agreement describes certain legal rights. You may have other rights under the laws of your country. You may also have rights with respect to the party from whom you acquired the dataset. This Agreement does not change your rights under the laws of your country if the laws of your country do not permit it to do so. 33 | 34 | 11. NO ASSIGNMENT. You may not assign this Agreement or any rights or obligations hereunder, except with Microsoft’s express written consent. Any attempted assignment in violation of this section will be void. 35 | 36 | 12. DISCLAIMER OF WARRANTY. THE SOFTWARE IS LICENSED “AS-IS.” YOU BEAR THE RISK OF USING IT. MICROSOFT GIVES NO EXPRESS WARRANTIES, GUARANTEES OR CONDITIONS. YOU MAY HAVE ADDITIONAL CONSUMER RIGHTS OR STATUTORY GUARANTEES UNDER YOUR LOCAL LAWS WHICH THIS AGREEMENT CANNOT CHANGE. TO THE EXTENT PERMITTED UNDER YOUR LOCAL LAWS, MICROSOFT EXCLUDES THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. 
37 | -------------------------------------------------------------------------------- /dqn/experiment.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | from utils import Font, plot_and_write, create_folder 4 | 5 | 6 | class DQNExperiment(object): 7 | def __init__(self, env, ai, episode_max_len, history_len=1, max_start_nullops=1, replay_min_size=0, 8 | score_window_size=100, rng=None, folder_location='/experiments/', folder_name='expt', testing=False): 9 | self.rng = rng 10 | self.fps = 0 11 | self.episode_num = 0 12 | self.last_episode_steps = 0 13 | self.total_training_steps = 0 14 | self.score_agent = 0 15 | self.eval_steps = [] 16 | self.eval_scores = [] 17 | self.env = env 18 | self.ai = ai 19 | self.history_len = history_len 20 | self.max_start_nullops = max_start_nullops 21 | if not testing: 22 | self.folder_name = create_folder(folder_location, folder_name) 23 | self.episode_max_len = episode_max_len 24 | self.score_agent_window = np.zeros(score_window_size) 25 | self.steps_agent_window = np.zeros(score_window_size) 26 | self.replay_min_size = max(self.ai.minibatch_size, replay_min_size) 27 | self.last_state = np.empty(tuple([self.history_len] + self.env.state_shape), dtype=np.uint8) 28 | 29 | def do_training(self, total_eps=5000, eps_per_epoch=100, eps_per_test=100, is_learning=True, is_testing=True): 30 | while self.episode_num < total_eps: 31 | print(Font.yellow + Font.bold + 'Training ... ' + str(self.episode_num) + '/' + str(total_eps) + Font.end, 32 | end='\n') 33 | self.do_episodes(number=eps_per_epoch, is_learning=is_learning) 34 | 35 | if is_testing: 36 | eval_scores, eval_steps = self.do_episodes(number=eps_per_test, is_learning=False) 37 | self.eval_steps.append(eval_steps) 38 | self.eval_scores.append(eval_scores) 39 | plot_and_write(plot_dict={'steps': self.eval_steps}, loc=self.folder_name + "/steps", 40 | x_label="Episodes", y_label="Steps", title="", kind='line', legend=True) 41 | plot_and_write(plot_dict={'scores': self.eval_scores}, loc=self.folder_name + "/scores", 42 | x_label="Episodes", y_label="Scores", title="", kind='line', legend=True) 43 | self.ai.dump_network(weights_file_path=self.folder_name + '/q_network_weights.h5', 44 | overwrite=True) 45 | 46 | def do_episodes(self, number=1, is_learning=True): 47 | scores = [] 48 | steps = [] 49 | for num in range(number): 50 | self._do_episode(is_learning=is_learning) 51 | scores.append(self.score_agent) 52 | steps.append(self.last_episode_steps) 53 | if not is_learning: 54 | self.score_agent_window = self._update_window(self.score_agent_window, self.score_agent) 55 | self.steps_agent_window = self._update_window(self.steps_agent_window, self.last_episode_steps) 56 | else: 57 | self.episode_num += 1 58 | return np.mean(scores), np.mean(steps) 59 | 60 | def _do_episode(self, is_learning=True): 61 | rewards = [] 62 | self.env.reset() 63 | self._reset() 64 | term = False 65 | self.fps = 0 66 | start_time = time.time() 67 | while not term: 68 | reward, term = self._step(evaluate=not is_learning) 69 | rewards.append(reward) 70 | if self.ai.transitions.size >= self.replay_min_size and is_learning and \ 71 | self.last_episode_steps % self.ai.learning_frequency == 0: 72 | self.ai.learn() 73 | self.score_agent += reward 74 | if not term and self.last_episode_steps >= self.episode_max_len: 75 | print('Reaching maximum number of steps in the current episode.') 76 | term = True 77 | self.fps = int(self.last_episode_steps / max(0.1, (time.time() - 
start_time))) 78 | return rewards 79 | 80 | def _step(self, evaluate=False): 81 | self.last_episode_steps += 1 82 | action = self.ai.get_action(self.last_state, evaluate) 83 | new_obs, reward, game_over, info = self.env.step(action) 84 | reward_channels = info['head_reward'] 85 | if new_obs.ndim == 1 and len(self.env.state_shape) == 2: 86 | new_obs = new_obs.reshape(self.env.state_shape) 87 | if not evaluate: 88 | self.ai.transitions.add(s=self.last_state[-1].astype('float32'), a=action, r=reward_channels, t=game_over) 89 | self.total_training_steps += 1 90 | if new_obs.ndim == 1 and len(self.env.state_shape) == 2: 91 | new_obs = new_obs.reshape(self.env.state_shape) 92 | self._update_state(new_obs) 93 | return reward, game_over 94 | 95 | def _reset(self): 96 | self.last_episode_steps = 0 97 | self.score_agent = 0 98 | 99 | assert self.max_start_nullops >= self.history_len or self.max_start_nullops == 0 100 | if self.max_start_nullops != 0: 101 | num_nullops = self.rng.randint(self.history_len, self.max_start_nullops) 102 | for i in range(num_nullops - self.history_len): 103 | self.env.step(0) 104 | 105 | for i in range(self.history_len): 106 | if i > 0: 107 | self.env.step(0) 108 | obs = self.env.get_state() 109 | if obs.ndim == 1 and len(self.env.state_shape) == 2: 110 | obs = obs.reshape(self.env.state_shape) 111 | self.last_state[i] = obs 112 | 113 | def _update_state(self, new_obs): 114 | temp_buffer = np.empty(self.last_state.shape, dtype=np.uint8) 115 | temp_buffer[:-1] = self.last_state[-self.history_len + 1:] 116 | temp_buffer[-1] = new_obs 117 | self.last_state = temp_buffer 118 | 119 | @staticmethod 120 | def _update_window(window, new_value): 121 | window[:-1] = window[1:] 122 | window[-1] = new_value 123 | return window 124 |
-------------------------------------------------------------------------------- /environment/README.md: -------------------------------------------------------------------------------- 1 | # Fruit-Collection Environment 2 | This environment provides the game of Fruit-Collection (a fruit-collection and ghost-avoidance game), which is a good testbed for several RL algorithms. We have used different variations of this environment in some of our papers, including: 3 | 4 | * https://arxiv.org/abs/1706.04208 5 | * https://arxiv.org/abs/1612.05159 6 | * https://arxiv.org/abs/1704.00756 7 | 8 | # About 9 | Fruit-Collection is a highly configurable game with very interesting properties, among which is a state space that can be huge and completely intractable for classical tabular methods. 10 | The game consists of an arbitrary maze (with or without walls), in which the player moves around and eats fixed fruits (blue blocks). Each fruit results in a +1 reward. There may also be a number of ghosts (red squares) that cause the player to lose a life and receive a negative reward. 11 | Based on the maze shape, the number of ghosts, the (fixed or random) position of fruits, the number of lives, and the initial position of the player, the game can easily be designed so that it serves best as a testbed for a given algorithm of interest. 12 | 13 | # Dependencies 14 | 15 | * Python 3.5 or higher 16 | * pygame (pip install pygame) 17 | * click (pip install click) 18 | * numpy (pip install numpy -- or install Anaconda distribution) 19 | 20 | # Getting Started 21 | Three particular environments are provided out of the box; some of them have been used in previously published research.
Nevertheless, making a new version is still very easy and only requires overriding two methods of the base class (see the example below). 22 | 23 | * `mini`: 10 x 10 maze + no ghost + no wall + random selection of 5 fruits from a pre-defined set of 10 possible fruits at each episode + random initial position of the player. 24 | 25 | * `small`: 11 x 11 maze with some dead ends + two ghosts doing a random walk + random fruits which are initialized at each episode with a 50% chance of existence at each location + fixed initial position of the player at the lower right corner. 26 | 27 | * `large`: 21 x 14 maze with two horizontal passing corridors + 4 ghosts doing a random walk + random fruits similar to `small` + initial position at [10, 8]. 28 | 29 | The main arguments and other internal properties (such as the reward scheme) are straightforward. 30 | 31 | ## Arguments: 32 | 33 | * game_length: `int` maximum number of steps in each episode. 34 | 35 | * lives: `int` number of times the player can hit a ghost before the episode ends. 36 | 37 | * state_mode: `str` specifies the format of the returned state (see below). 38 | 39 | * is_fruit: `bool` whether or not fruits appear in the maze (has no effect if no fruit is present in the game). 40 | 41 | * is_ghost: `bool` whether or not ghosts appear in the maze (has no effect if no ghost is present in the game, e.g. in `mini`). 42 | 43 | * rendering: `bool` renders the game if `True`. 44 | 45 | * image_saving: `bool` saves the rendered frame as a bitmap at each `render` call. 46 | 47 | * render_dir: `str` the directory in which rendered frames are saved: if `image_saving=True`, a new folder called `render0` is created inside `render_dir` (with a new integer suffix assigned automatically for each episode). If none is provided, `render` is used by default. 48 | 49 | ## Return state 50 | Three options are available (see argument `state_mode`): 51 | 52 | * `pixel`: includes 4 binary channels (matrices) corresponding to walls, fruits, player position, and ghost positions. 53 | 54 | * `1hot`: includes an augmentation of one-hot vectors for the player position and ghost positions. 55 | 56 | * `multi-head`: returns a concatenation of three binary vectors for the player position (one-hot), not-yet-eaten fruits (binary), and ghost positions (binary). 57 | 58 | 59 | ## Main usage: 60 | ``` 61 | import numpy as np 62 | import time 63 | from fruit_collection import FruitCollectionSmall # or FruitCollectionLarge or FruitCollectionMini 64 | 65 | env = FruitCollectionSmall(rendering=True, lives=1, is_fruit=True, is_ghost=True, image_saving=False) 66 | env.reset() 67 | env.render() 68 | 69 | for _ in range(50): 70 | action = np.random.choice(env.legal_actions) 71 | obs, r, term, info = env.step(action) 72 | env.render() 73 | time.sleep(.2) 74 | ``` 75 | 76 | ## Human Play! 77 | You can also play Fruit-Collection yourself: 78 | ``` 79 | python fruit_collection.py -m small 80 | ``` 81 | or, to enable image saving and deactivate the ghosts: 82 | ``` 83 | python fruit_collection.py -m small --save --no-ghost 84 | ``` 85 | 86 | Press `Q` to finish the game. 87 | 88 | # Making Your Own Fruit-Collection 89 | It is quite easy to specialize Fruit-Collection for your own experiment. You mainly need to define the two methods `init_with_mode` and `_reset_targets` in your own subclass. 90 | `init_with_mode` defines basic properties such as the maze shape, the location of walls (if any), etc.
On the other hand, `_reset_targets` is called at the `reset` method of base class and defines properties at the begining of each episode (such as how fruits are spawned and player's init position). 91 | 92 | See the following example (also see the three superclasses for Mini, Small, and Large): 93 | ``` 94 | class MyFruitCollection(FruitCollection): 95 | def init_with_mode(self): 96 | self.is_ghost = False 97 | self.is_fruit = True 98 | self.nb_fruits = 4 99 | self.scr_w = 5 100 | self.scr_h = 5 101 | self.possible_fruits = [[0, 0], [0, 4], [4, 0], [4, 4]] 102 | self.rendering_scale = 50 103 | self.walls = [[1, 0], [2, 0], [4, 1], [0, 2], [2, 2], [3, 3], [1, 4]] 104 | if self.is_ghost: 105 | self.ghosts = [{'colour': RED, 'reward': self.reward_scheme['ghost'], 'location': [0, 1], 106 | 'active': True}] 107 | else: 108 | self.ghosts = [] 109 | 110 | def _reset_targets(self): 111 | while True: 112 | self.player_pos_x, self.player_pos_y = np.random.randint(0, self.scr_w), np.random.randint(0, self.scr_h) 113 | if [self.player_pos_x, self.player_pos_y] not in self.possible_fruits + self.walls: 114 | break 115 | self.fruits = [] 116 | self.active_fruits = [] 117 | if self.is_fruit: 118 | for x in range(self.scr_w): 119 | for y in range(self.scr_h): 120 | self.fruits.append({'colour': BLUE, 'reward': self.reward_scheme['fruit'], 121 | 'location': [x, y], 'active': False}) 122 | self.active_fruits.append(False) 123 | fruits_idx = deepcopy(self.possible_fruits) 124 | np.random.shuffle(fruits_idx) 125 | fruits_idx = fruits_idx[:self.nb_fruits] 126 | self.mini_target = [False] * len(self.possible_fruits) 127 | for f in fruits_idx: 128 | idx = f[1] * self.scr_w + f[0] 129 | self.fruits[idx]['active'] = True 130 | self.active_fruits[idx] = True 131 | self.mini_target[self.possible_fruits.index(f)] = True 132 | ``` 133 | 134 | -------------------------------------------------------------------------------- /dqn/ai.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import numpy as np 4 | from keras import backend as K 5 | from keras.optimizers import RMSprop 6 | 7 | from utils import ExperienceReplay, slice_tensor_tensor, flatten 8 | from dqn.model import build_dense 9 | 10 | floatX = 'float32' 11 | 12 | 13 | class AI(object): 14 | def __init__(self, state_shape, nb_actions, action_dim, reward_dim, history_len=1, gamma=.99, 15 | learning_rate=0.00025, epsilon=0.05, final_epsilon=0.05, test_epsilon=0.0, 16 | minibatch_size=32, replay_max_size=100, update_freq=50, learning_frequency=1, 17 | num_units=250, remove_features=False, use_mean=False, use_hra=True, rng=None): 18 | self.rng = rng 19 | self.history_len = history_len 20 | self.state_shape = [1] + state_shape 21 | self.nb_actions = nb_actions 22 | self.action_dim = action_dim 23 | self.reward_dim = reward_dim 24 | self.gamma = gamma 25 | self.learning_rate = learning_rate 26 | self.learning_rate_start = learning_rate 27 | self.epsilon = epsilon 28 | self.start_epsilon = epsilon 29 | self.test_epsilon = test_epsilon 30 | self.final_epsilon = final_epsilon 31 | self.minibatch_size = minibatch_size 32 | self.update_freq = update_freq 33 | self.update_counter = 0 34 | self.nb_units = num_units 35 | self.use_mean = use_mean 36 | self.use_hra = use_hra 37 | self.remove_features = remove_features 38 | self.learning_frequency = learning_frequency 39 | self.replay_max_size = replay_max_size 40 | self.transitions = ExperienceReplay(max_size=self.replay_max_size, history_len=history_len, 
rng=self.rng, 41 | state_shape=state_shape, action_dim=action_dim, reward_dim=reward_dim) 42 | self.networks = [self._build_network() for _ in range(self.reward_dim)] 43 | self.target_networks = [self._build_network() for _ in range(self.reward_dim)] 44 | self.all_params = flatten([network.trainable_weights for network in self.networks]) 45 | self.all_target_params = flatten([target_network.trainable_weights for target_network in self.target_networks]) 46 | self.weight_transfer(from_model=self.networks, to_model=self.target_networks) 47 | self._compile_learning() 48 | print('Compiled Model and Learning.') 49 | 50 | def _build_network(self): 51 | return build_dense(self.state_shape, int(self.nb_units / self.reward_dim), 52 | self.nb_actions, self.reward_dim, self.remove_features) 53 | 54 | def _remove_features(self, s, i): 55 | return K.concatenate([s[:, :, :, : -self.reward_dim], 56 | K.expand_dims(s[:, :, :, self.state_shape[-1] - self.reward_dim + i], dim=-1)]) 57 | 58 | def _compute_cost(self, q, a, r, t, q2): 59 | preds = slice_tensor_tensor(q, a) 60 | bootstrap = K.max if not self.use_mean else K.mean 61 | targets = r + (1 - t) * self.gamma * bootstrap(q2, axis=1) 62 | cost = K.sum((targets - preds) ** 2) 63 | return cost 64 | 65 | def _compile_learning(self): 66 | s = K.placeholder(shape=tuple([None] + [self.history_len] + self.state_shape)) 67 | a = K.placeholder(ndim=1, dtype='int32') 68 | r = K.placeholder(ndim=2, dtype='float32') 69 | s2 = K.placeholder(shape=tuple([None] + [self.history_len] + self.state_shape)) 70 | t = K.placeholder(ndim=1, dtype='float32') 71 | 72 | updates = [] 73 | costs = 0 74 | qs = [] 75 | q2s = [] 76 | for i in range(len(self.networks)): 77 | local_s = s 78 | local_s2 = s2 79 | if self.remove_features: 80 | local_s = self._remove_features(local_s, i) 81 | local_s2 = self._remove_features(local_s2, i) 82 | qs.append(self.networks[i](local_s)) 83 | q2s.append(self.target_networks[i](local_s2)) 84 | if self.use_hra: 85 | cost = self._compute_cost(qs[-1], a, r[:, i], t, q2s[-1]) 86 | optimizer = RMSprop(lr=self.learning_rate, rho=.95, epsilon=1e-7) 87 | updates += optimizer.get_updates(params=self.networks[i].trainable_weights, loss=cost, constraints={}) 88 | costs += cost 89 | if not self.use_hra: 90 | q = sum(qs) 91 | q2 = sum(q2s) 92 | summed_reward = K.sum(r, axis=-1) 93 | cost = self._compute_cost(q, a, summed_reward, t, q2) 94 | optimizer = RMSprop(lr=self.learning_rate, rho=.95, epsilon=1e-7) 95 | updates += optimizer.get_updates(params=self.all_params, loss=cost, constraints={}) 96 | costs += cost 97 | 98 | target_updates = [] 99 | for network, target_network in zip(self.networks, self.target_networks): 100 | for target_weight, network_weight in zip(target_network.trainable_weights, network.trainable_weights): 101 | target_updates.append(K.update(target_weight, network_weight)) 102 | 103 | self._train_on_batch = K.function(inputs=[s, a, r, s2, t], outputs=[costs], updates=updates) 104 | self.predict_network = K.function(inputs=[s], outputs=qs) 105 | self.update_weights = K.function(inputs=[], outputs=[], updates=target_updates) 106 | 107 | def update_lr(self, cur_step, total_steps): 108 | self.learning_rate = ((total_steps - cur_step - 1) / total_steps) * self.learning_rate_start 109 | 110 | def get_max_action(self, states): 111 | states = self._reshape(states) 112 | q = np.array(self.predict_network([states])) 113 | q = np.sum(q, axis=0) 114 | return np.argmax(q, axis=1) 115 | 116 | def get_action(self, states, evaluate): 117 | eps = self.epsilon 
if not evaluate else self.test_epsilon 118 | if self.rng.binomial(1, eps): 119 | return self.rng.randint(self.nb_actions) 120 | else: 121 | return self.get_max_action(states=states) 122 | 123 | def train_on_batch(self, s, a, r, s2, t): 124 | s = self._reshape(s) 125 | s2 = self._reshape(s2) 126 | if len(r.shape) == 1: 127 | r = np.expand_dims(r, axis=-1) 128 | return self._train_on_batch([s, a, r, s2, t]) 129 | 130 | def learn(self): 131 | assert self.minibatch_size <= self.transitions.size, 'not enough data in the pool' 132 | s, a, r, s2, term = self.transitions.sample(self.minibatch_size) 133 | objective = self.train_on_batch(s, a, r, s2, term) 134 | if self.update_counter == self.update_freq: 135 | self.update_weights([]) 136 | self.update_counter = 0 137 | else: 138 | self.update_counter += 1 139 | return objective 140 | 141 | def dump_network(self, weights_file_path='q_network_weights.h5', overwrite=True): 142 | for i, network in enumerate(self.networks): 143 | network.save_weights(weights_file_path[:-3] + str(i) + weights_file_path[-3:], overwrite=overwrite) 144 | 145 | def load_weights(self, weights_file_path='q_network_weights.h5'): 146 | for i, network in enumerate(self.networks): 147 | network.load_weights(weights_file_path[:-3] + str(i) + weights_file_path[-3:]) 148 | self.update_weights([]) 149 | 150 | @staticmethod 151 | def _reshape(states): 152 | if len(states.shape) == 2: 153 | states = np.expand_dims(states, axis=0) 154 | if len(states.shape) == 3: 155 | states = np.expand_dims(states, axis=1) 156 | return states 157 | 158 | @staticmethod 159 | def weight_transfer(from_model, to_model): 160 | for f_model, t_model in zip(from_model, to_model): 161 | t_model.set_weights(deepcopy(f_model.get_weights())) 162 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | import numpy as np 5 | from numpy import all, uint8 6 | import pandas as pd 7 | import matplotlib as mpl 8 | from keras import backend as K 9 | 10 | mpl.use('Agg') 11 | import matplotlib.pyplot as plt 12 | 13 | 14 | def flatten(l): 15 | return [item for sublist in l for item in sublist] 16 | 17 | 18 | def set_params(params, mode, gamma=None, lr=None, folder_name=None): 19 | if mode == 'dqn': 20 | params['gamma'] = .85 21 | params['learning_rate'] = .0005 22 | params['remove_features'] = False 23 | params['use_mean'] = False 24 | params['use_hra'] = False 25 | elif mode == 'dqn+1': 26 | params['gamma'] = .85 27 | params['learning_rate'] = .0005 28 | params['remove_features'] = True 29 | params['use_mean'] = False 30 | params['use_hra'] = False 31 | elif mode == 'hra': 32 | params['gamma'] = .99 33 | params['learning_rate'] = .001 34 | params['remove_features'] = False 35 | params['use_mean'] = True 36 | params['use_hra'] = True 37 | elif mode == 'hra+1': 38 | params['gamma'] = .99 39 | params['learning_rate'] = .001 40 | params['remove_features'] = True 41 | params['use_mean'] = True 42 | params['use_hra'] = True 43 | if gamma is not None: 44 | params['gamma'] = gamma 45 | params['learning_rate'] = lr 46 | if folder_name is None: 47 | params['folder_name'] = mode + '__g' + str(params['gamma']) + '__lr' + str(params['learning_rate']) + '__' 48 | else: 49 | params['folder_name'] = folder_name 50 | return params 51 | 52 | 53 | def slice_tensor_tensor(tensor, tensor_slice): 54 | """ 55 | Theano and tensorflow differ in the method of extracting the value of the 
actions taken 56 | arg1: the tensor to be slice i.e Q(s) 57 | arg2: the indices to slice by ie a 58 | """ 59 | if K.backend() == 'theano': 60 | output = tensor[K.T.arange(tensor_slice.shape[0]), tensor_slice] 61 | elif K.backend() == 'tensorflow': 62 | amask = K.tf.one_hot(tensor_slice, tensor.get_shape()[1], 1.0, 0.0) 63 | output = K.tf.reduce_sum(tensor * amask, reduction_indices=1) 64 | else: 65 | raise Exception("Not using theano or tensor flow as backend") 66 | return output 67 | 68 | 69 | def plot(data={}, loc="visualization.pdf", x_label="", y_label="", title="", kind='line', 70 | legend=True, index_col=None, clip=None, moving_average=False): 71 | if all([len(data[key]) > 1 for key in data]): 72 | if moving_average: 73 | smoothed_data = {} 74 | for key in data: 75 | smooth_scores = [np.mean(data[key][max(0, i - 10):i + 1]) for i in range(len(data[key]))] 76 | smoothed_data['smoothed_' + key] = smooth_scores 77 | smoothed_data[key] = data[key] 78 | data = smoothed_data 79 | df = pd.DataFrame(data=data) 80 | ax = df.plot(kind=kind, legend=legend, ylim=clip) 81 | ax.set_xlabel(x_label) 82 | ax.set_ylabel(y_label) 83 | ax.set_title(title) 84 | plt.tight_layout() 85 | plt.savefig(loc) 86 | plt.close() 87 | 88 | 89 | def write_to_csv(data={}, loc="data.csv"): 90 | if all([len(data[key]) > 1 for key in data]): 91 | df = pd.DataFrame(data=data) 92 | df.to_csv(loc) 93 | 94 | 95 | def plot_and_write(plot_dict, loc, x_label="", y_label="", title="", kind='line', legend=True, 96 | moving_average=False): 97 | for key in plot_dict: 98 | plot(data={key: plot_dict[key]}, loc=loc + ".pdf", x_label=x_label, y_label=y_label, title=title, 99 | kind=kind, legend=legend, index_col=None, moving_average=moving_average) 100 | write_to_csv(data={key: plot_dict[key]}, loc=loc + ".csv") 101 | 102 | 103 | def create_folder(folder_location, folder_name): 104 | i = 0 105 | while os.path.exists(os.getcwd() + folder_location + folder_name + str(i)): 106 | i += 1 107 | folder_name = os.getcwd() + folder_location + folder_name + str(i) 108 | os.mkdir(folder_name) 109 | return folder_name 110 | 111 | 112 | class Font: 113 | purple = '\033[95m' 114 | cyan = '\033[96m' 115 | darkcyan = '\033[36m' 116 | blue = '\033[94m' 117 | green = '\033[92m' 118 | yellow = '\033[93m' 119 | red = '\033[91m' 120 | bgblue = '\033[44m' 121 | bold = '\033[1m' 122 | underline = '\033[4m' 123 | end = '\033[0m' 124 | 125 | 126 | class ExperienceReplay(object): 127 | """ 128 | Efficient experience replay pool for DQN. 
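Observations are written once into a circular buffer (tracked by head/tail); when sampling, s and s2 are rebuilt from consecutive indices with wrap-around, so successive frames are not stored twice.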
129 | """ 130 | def __init__(self, max_size=100, history_len=1, state_shape=None, action_dim=1, reward_dim=1, state_dtype=np.uint8, 131 | rng=None): 132 | if rng is None: 133 | self.rng = np.random.RandomState(1234) 134 | else: 135 | self.rng = rng 136 | self.size = 0 137 | self.head = 0 138 | self.tail = 0 139 | self.max_size = max_size 140 | self.history_len = history_len 141 | self.state_shape = state_shape 142 | self.action_dim = action_dim 143 | self.reward_dim = reward_dim 144 | self.state_dtype = state_dtype 145 | self._minibatch_size = None 146 | self.states = np.zeros([self.max_size] + list(self.state_shape), dtype=self.state_dtype) 147 | self.terms = np.zeros(self.max_size, dtype='bool') 148 | if self.action_dim == 1: 149 | self.actions = np.zeros(self.max_size, dtype='int32') 150 | else: 151 | self.actions = np.zeros((self.max_size, self.action_dim), dtype='int32') 152 | if self.reward_dim == 1: 153 | self.rewards = np.zeros(self.max_size, dtype='float32') 154 | else: 155 | self.rewards = np.zeros((self.max_size, self.reward_dim), dtype='float32') 156 | 157 | def _init_batch(self, number): 158 | self.s = np.zeros([number] + [self.history_len] + list(self.state_shape), dtype=self.states[0].dtype) 159 | self.s2 = np.zeros([number] + [self.history_len] + list(self.state_shape), dtype=self.states[0].dtype) 160 | self.t = np.zeros(number, dtype='bool') 161 | action_indicator = self.actions[0] 162 | if self.actions.ndim == 1: 163 | self.a = np.zeros(number, dtype='int32') 164 | else: 165 | self.a = np.zeros((number, action_indicator.size), dtype=action_indicator.dtype) 166 | if self.rewards.ndim == 1: 167 | self.r = np.zeros(number, dtype='float32') 168 | else: 169 | self.r = np.zeros((number, self.reward_dim), dtype='float32') 170 | 171 | def sample(self, num=1): 172 | if self.size == 0: 173 | logging.error('cannot sample from empty transition table') 174 | elif num <= self.size: 175 | if not self._minibatch_size or num != self._minibatch_size: 176 | self._init_batch(number=num) 177 | self._minibatch_size = num 178 | for i in range(num): 179 | self.s[i], self.a[i], self.r[i], self.s2[i], self.t[i] = self._get_transition() 180 | return self.s, self.a, self.r, self.s2, self.t 181 | elif num > self.size: 182 | logging.error('transition table has only {0} elements; {1} requested'.format(self.size, num)) 183 | 184 | def _get_transition(self): 185 | sample_success = False 186 | while not sample_success: 187 | randint = self.rng.randint(self.head, self.head + self.size - self.history_len) 188 | state_indices = np.arange(randint, randint + self.history_len) 189 | next_state_indices = state_indices + 1 190 | transition_index = randint + self.history_len - 1 191 | a_axis = None if self.action_dim == 1 else 0 192 | r_axis = None if self.reward_dim == 1 else 0 193 | if not np.any(self.terms.take(state_indices[:-1], mode='wrap')): 194 | s = self.states.take(state_indices, mode='wrap', axis=0) 195 | a = self.actions.take(transition_index, mode='wrap', axis=a_axis) 196 | r = self.rewards.take(transition_index, mode='wrap', axis=r_axis) 197 | t = self.terms.take(transition_index, mode='wrap') 198 | s2 = self.states.take(next_state_indices, mode='wrap', axis=0) 199 | sample_success = True 200 | return s, a, r, s2, t 201 | 202 | def add(self, s, a, r, t): 203 | self.states[self.tail] = s 204 | self.actions[self.tail] = a 205 | self.rewards[self.tail] = r 206 | self.terms[self.tail] = t 207 | self.tail = (self.tail + 1) % self.max_size 208 | if self.size == self.max_size: 209 | self.head = (self.head + 
1) % self.max_size 210 | else: 211 | self.size += 1 212 | 213 | def reset(self): 214 | self.size = 0 215 | self.head = 0 216 | self.tail = 0 217 | self._minibatch_size = None 218 | self.states = np.zeros([self.max_size] + list(self.state_shape), dtype=self.state_dtype) 219 | self.terms = np.zeros(self.max_size, dtype='bool') 220 | if isinstance(self.action_dim, int): 221 | self.actions = np.zeros(self.max_size, dtype='int32') 222 | else: 223 | self.actions = np.zeros((self.max_size, self.action_dim.size), dtype=self.action_dim.dtype) 224 | if isinstance(self.reward_dim, int): 225 | self.rewards = np.zeros(self.max_size, dtype='float32') 226 | else: 227 | self.rewards = np.zeros((self.max_size, 2), dtype='float32') 228 | -------------------------------------------------------------------------------- /tabular/experiment.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import time 3 | from copy import deepcopy 4 | import numpy as np 5 | from utils import Font, plot_and_write, create_folder 6 | 7 | floatX = np.float32 8 | 9 | 10 | class SoCExperiment(object): 11 | def __init__(self, ai_list, env, aggregator_epsilon, aggregator_final_epsilon, aggregator_decay_start, 12 | aggregator_decay_steps, epoch_size, nb_eval_episodes, final_alpha, alpha_decay_steps, 13 | alpha_decay_start, use_gvf, rng, make_folder=True, folder_location='/results/', folder_name='expt'): 14 | self.ai_list = ai_list 15 | self.env_done_for_agent = None 16 | self.learning_flag = None 17 | self.env = env 18 | self.use_gvf = use_gvf 19 | self.rng = rng 20 | self.aggregator_start_epsilon = aggregator_epsilon 21 | self.aggregator_epsilon = aggregator_epsilon 22 | self.aggregator_final_epsilon = aggregator_final_epsilon 23 | self.aggregator_decay_start = aggregator_decay_start 24 | self.aggregator_decay_steps = aggregator_decay_steps 25 | self.last_state = None 26 | self.action = None 27 | self.score_agent = 0 28 | self.step_in_episode = 0 29 | self.total_learning_steps = 0 # is not reset 30 | self.count_episode = 0 31 | self.epoch_size = epoch_size 32 | self.nb_eval_episodes = nb_eval_episodes 33 | self.eval_flag = False 34 | self.episode_done = False 35 | if make_folder: 36 | self.folder_name = create_folder(folder_location, folder_name) 37 | self.final_alpha = final_alpha 38 | self.alpha_decay_steps = alpha_decay_steps 39 | self.alpha_decay_start = alpha_decay_start 40 | self.reset() 41 | 42 | def reset(self): 43 | self.env.reset() 44 | if self.env.rendering: 45 | self.env.render() 46 | self.episode_done = False 47 | if self.use_gvf: 48 | self.env_done_for_agent = [not t for t in self.env.mini_target] 49 | self.last_state = [self.env.player_pos_y, self.env.player_pos_x] 50 | else: 51 | self.env_done_for_agent = [not t for t in self.env.active_targets] 52 | self.last_state = self.env.get_soc_state() 53 | self.learning_flag = deepcopy(self.env.active_targets) 54 | self.action = None 55 | self.step_in_episode = 0 56 | self.score_agent = 0 57 | 58 | def do_epochs(self, number): 59 | self.count_episode = 0 60 | eval_returns = [] 61 | eval_steps = [] 62 | self.eval_flag = False 63 | 64 | def do_eval(): 65 | eval_return = 0 66 | eval_episode_steps = 0 67 | for eval_episode in range(self.nb_eval_episodes): 68 | if eval_episode in []: 69 | self.env.rendering = True 70 | else: 71 | self.env.rendering = False 72 | print(Font.bold + Font.blue + '>>> Eval Episode {}'.format(eval_episode) + Font.end) 73 | eval_return += self._do_episode(is_learning=False, rendering_sleep=None) 74 | 
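# step_in_episode still holds the length of the episode just played by _do_episode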
eval_episode_steps += self.step_in_episode 75 | eval_returns.append(eval_return / self.nb_eval_episodes) 76 | eval_steps.append(eval_episode_steps / self.nb_eval_episodes) 77 | plot_and_write(plot_dict={'scores': eval_returns}, loc=self.folder_name + "/scores", 78 | x_label="Epochs", y_label="Mean Score", title="", kind='line', legend=True, 79 | moving_average=True) 80 | plot_and_write(plot_dict={'steps': eval_steps}, loc=self.folder_name + "/steps", 81 | x_label="Epochs", y_label="Mean Steps", title="", kind='line', legend=True) 82 | with open(self.folder_name + "/soc_ai_list.pkl", 'wb') as f: 83 | pickle.dump(self.ai_list, f) 84 | self.eval_flag = False 85 | 86 | for count_epoch in range(number): 87 | # Evaluation: 88 | do_eval() 89 | for count_episode in range(self.epoch_size): 90 | print(Font.bold + 'Epoch: {} '.format(count_epoch) + 91 | Font.yellow + '>>> Episode {}'.format(self.count_episode) + Font.end) 92 | self._do_episode(is_learning=True, rendering_sleep=None) 93 | self.count_episode += 1 94 | do_eval() 95 | return eval_returns 96 | 97 | def demo(self, nb_episodes, rendering_sleep): 98 | assert self.env.rendering is True 99 | for k in range(nb_episodes): 100 | print('\nDemo Episode ', k) 101 | self._do_episode(is_learning=False, rendering_sleep=rendering_sleep) 102 | 103 | def _do_episode(self, is_learning, rendering_sleep): 104 | self.reset() 105 | episode_return = 0 106 | while not self.episode_done: 107 | r = self._step(is_learning=is_learning) 108 | if self.env.rendering: 109 | self.env.render() 110 | time.sleep(rendering_sleep) 111 | episode_return += r # undiscounted return (for eval purposes) 112 | print('Aggregator eps: ', round(self.aggregator_epsilon, 2), ' | alpha: ', round(self.ai_list[0].alpha, 4), 113 | Font.cyan + ' | Episode Score: ' + Font.end, round(episode_return, 2)) 114 | return episode_return 115 | 116 | def _get_action(self, s, explore): 117 | if explore and self.rng.binomial(1, self.aggregator_epsilon): # aggregator exploration 118 | action = self.rng.randint(self.env.nb_actions) 119 | else: 120 | # sum all q's then select max action 121 | q = [] 122 | if self.use_gvf: 123 | s = tuple(s) 124 | for gvf_idx in range(len(self.env.possible_fruits)): 125 | if not self.env_done_for_agent[gvf_idx]: 126 | q.append(self.ai_list[gvf_idx].get_q(s)) 127 | else: 128 | for agent_idx, agent_state in enumerate(s): 129 | if self.learning_flag[agent_idx] is True: 130 | q.append(self.ai_list[agent_idx].get_q(agent_state)) 131 | if self.env.rendering is True: 132 | print(Font.bold + Font.cyan + 'Values:' + Font.end) 133 | print(self.env.action_meanings) 134 | for kk, qq in enumerate(q): 135 | print('-'*35) 136 | print('alpha: ', self.ai_list[kk].alpha, 'gamma: ', self.ai_list[kk].gamma) 137 | string = ' '.join('{:0.2f}'.format(i) for i in qq) 138 | print(string) 139 | print('sum: ', np.sum(q, axis=0)) 140 | q_aggregate = np.sum(q, axis=0) 141 | actions = np.where(q_aggregate == q_aggregate.max())[0] # is biased if using np.argmax(q_aggregate) 142 | action = self.rng.choice(actions) 143 | return action 144 | 145 | def _step(self, is_learning=True): 146 | action = self._get_action(self.last_state, is_learning) 147 | _, r_env, self.episode_done, info = self.env.step(action) 148 | if self.use_gvf: 149 | s2 = [self.env.player_pos_y, self.env.player_pos_x] 150 | if is_learning: 151 | # Hint: ALL gvf agents learn in parallel at each transition (regardless of player's position) 152 | for gvf_idx, gvf_goal in enumerate(self.env.possible_fruits): 153 | if gvf_goal != s2: 154 | 
self.ai_list[gvf_idx].learn(self.last_state, action, 0., s2, False) 155 | else: 156 | self.ai_list[gvf_idx].learn(self.last_state, action, 1., s2, True) 157 | self.env_done_for_agent[gvf_idx] = True 158 | else: 159 | for gvf_idx, gvf_goal in enumerate(self.env.possible_fruits): 160 | if gvf_goal == s2: 161 | self.env_done_for_agent[gvf_idx] = True 162 | else: 163 | s2 = self.env.get_soc_state() 164 | for k, s2_agent in enumerate(s2): 165 | r_base = self.env.targets[k]['reward'] 166 | if info['fruit'] is not None: 167 | if info['fruit'] == k: 168 | r = r_base 169 | self.env_done_for_agent[k] = True 170 | else: 171 | r = 0. 172 | else: 173 | r = 0. 174 | if is_learning and self.learning_flag[k] is True: 175 | self.ai_list[k].learn(self.last_state[k], action, r, s2_agent, self.env_done_for_agent[k]) 176 | if self.env_done_for_agent[k] is True: # if env terminates for agent_k it stops learning from next step 177 | self.learning_flag[k] = False 178 | self.last_state = deepcopy(s2) 179 | self.score_agent += r_env 180 | self.step_in_episode += 1 181 | if is_learning: 182 | self.total_learning_steps += 1 183 | self._anneal_eps() 184 | self._anneal_alpha() 185 | if self.total_learning_steps % self.epoch_size == 0: 186 | self.eval_flag = True 187 | return r_env 188 | 189 | def _anneal_eps(self): 190 | # linear annealing 191 | if self.total_learning_steps < self.aggregator_decay_start: 192 | return 193 | if self.aggregator_epsilon > self.aggregator_final_epsilon: 194 | decay = (self.aggregator_start_epsilon - self.aggregator_final_epsilon) \ 195 | * (self.total_learning_steps - self.aggregator_decay_start) / self.aggregator_decay_steps 196 | temp = self.aggregator_start_epsilon - decay 197 | if temp > self.aggregator_final_epsilon: 198 | self.aggregator_epsilon = temp 199 | else: 200 | self.aggregator_epsilon = self.aggregator_final_epsilon 201 | 202 | def _anneal_alpha(self): 203 | # linear annealing 204 | if self.total_learning_steps < self.alpha_decay_start: 205 | return 206 | for ai in self.ai_list: 207 | if ai.alpha > self.final_alpha: 208 | decay = (ai.start_alpha - self.final_alpha) * (self.total_learning_steps - self.alpha_decay_start) \ 209 | / self.alpha_decay_steps 210 | temp = ai.start_alpha - decay 211 | if temp > ai.alpha: 212 | ai.alpha = temp 213 | else: 214 | ai.alpha = self.final_alpha 215 | -------------------------------------------------------------------------------- /environment/fruit_collection.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy 3 | import pygame 4 | import numpy as np 5 | import click 6 | 7 | 8 | # RGB colors 9 | WHITE = (255, 255, 255) 10 | BLACK = (0, 0, 0) 11 | RED = (255, 0, 0) 12 | BLUE = (0, 100, 255) 13 | WALL = (80, 80, 80) 14 | 15 | 16 | class FruitCollection(object): 17 | def __init__(self, game_length=300, lives=1e6, state_mode='pixel', is_fruit=True, is_ghost=True, 18 | rng=None, rendering=False, image_saving=False, render_dir=None): 19 | self.game_length = game_length 20 | self.lives = lives 21 | self.is_fruit = is_fruit 22 | self.is_ghost = is_ghost 23 | self.legal_actions = [0, 1, 2, 3] 24 | self.action_meanings = ['up', 'down', 'left', 'right'] 25 | self.reward_scheme = {'ghost': -10.0, 'fruit': +1.0, 'step': 0.0, 'wall': 0.0} 26 | self.nb_actions = len(self.legal_actions) 27 | if rng is None: 28 | self.rng = np.random.RandomState(1234) 29 | else: 30 | self.rng = rng 31 | self.player_pos_x = None 32 | self.player_pos_y = None 33 | self.agent_init_pos = None 34 | 
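# rows where moving off the left/right edge wraps the player to the opposite side; None disables wrap-around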
self.pass_wall_rows = None 35 | self.init_lives = deepcopy(self.lives) 36 | self.step_reward = 0.0 37 | self.possible_fruits = [] 38 | self.state_mode = state_mode # how the returned state look like ('pixel' or '1hot' or 'multi-head') 39 | self.nb_fruits = None 40 | self.scr_w = None 41 | self.scr_h = None 42 | self.rendering_scale = None 43 | self.walls = None 44 | self.fruits = None 45 | self.ghosts = None 46 | self.init_with_mode() 47 | self.nb_non_wall = self.scr_w * self.scr_h - len(self.walls) 48 | self.init_ghosts = deepcopy(self.ghosts) 49 | self._rendering = rendering 50 | if rendering: 51 | self._init_pygame() 52 | self.image_saving = image_saving 53 | self.render_dir_main = render_dir 54 | self.render_dir = None 55 | self.targets = None # fruits + ghosts 56 | self.active_targets = None # boolean list 57 | self.active_fruits = None 58 | self.nb_targets = None 59 | self.init_targets = None 60 | self.nb_ghosts = None 61 | self.soc_state_shape = None 62 | self.state_shape = None 63 | self.state = None 64 | self.step_id = 0 65 | self.game_over = False 66 | self.mini_target = [] # only is used for mini 67 | self.reset() 68 | 69 | def init_with_mode(self): 70 | raise NotImplementedError 71 | 72 | @property 73 | def rendering(self): 74 | return self._rendering 75 | 76 | @rendering.setter 77 | def rendering(self, flag): 78 | if flag is True: 79 | if self._rendering is False: 80 | self._init_pygame() 81 | self._rendering = True 82 | else: 83 | self.close() 84 | self._rendering = False 85 | 86 | def _init_pygame(self): 87 | pygame.init() 88 | size = [self.rendering_scale * self.scr_w, self.rendering_scale * self.scr_h] 89 | self.screen = pygame.display.set_mode(size) 90 | pygame.display.set_caption("Fruit Collection") 91 | 92 | def _init_rendering_folder(self): 93 | if self.render_dir_main is None: 94 | self.render_dir_main = 'render' 95 | if not os.path.exists(os.path.join(os.getcwd(), self.render_dir_main)): 96 | os.mkdir(os.path.join(os.getcwd(), self.render_dir_main)) 97 | i = 0 98 | while os.path.exists(os.path.join(os.getcwd(), self.render_dir_main, 'render' + str(i))): 99 | i += 1 100 | self.render_dir = os.path.join(os.getcwd(), self.render_dir_main, 'render' + str(i)) 101 | os.mkdir(self.render_dir) 102 | 103 | def reset(self): 104 | if self.image_saving: 105 | self._init_rendering_folder() 106 | self.game_over = False 107 | self.step_id = 0 108 | self._reset_targets() 109 | self.nb_ghosts = len(self.ghosts) 110 | self.targets = deepcopy(self.fruits) + deepcopy(self.ghosts) 111 | self.nb_targets = len(self.targets) 112 | self.active_targets = self.active_fruits + [True] * len(self.ghosts) 113 | self.lives = deepcopy(self.init_lives) 114 | self.soc_state_shape = [self.scr_w, self.scr_h, self.scr_w + 1, self.scr_h + 1] 115 | if self.state_mode == '1hot': 116 | self.state_shape = [self.nb_non_wall * self.nb_fruits + self.nb_ghosts * (self.nb_non_wall ** 2)] 117 | elif self.state_mode == 'pixel': 118 | self.state_shape = [4, self.scr_w, self.scr_h] 119 | elif self.state_mode == 'multi-head': 120 | self.state_shape = [3 * self.scr_w * self.scr_h] 121 | elif self.state_mode == 'mini': 122 | self.state_shape = [100 + len(self.possible_fruits)] 123 | 124 | def _reset_targets(self): 125 | raise NotImplementedError 126 | 127 | def close(self): 128 | if self.rendering: 129 | pygame.quit() 130 | 131 | def _move_player(self, action): 132 | assert action in self.legal_actions, 'Illegal action.' 
133 | hit_wall = False 134 | if action == 3: # right 135 | passed_wall = False 136 | if self.pass_wall_rows is not None: 137 | for wall_row in self.pass_wall_rows: 138 | if [self.player_pos_x, self.player_pos_y] == [self.scr_w - 1, wall_row]: 139 | self.player_pos_x = 0 140 | passed_wall = True 141 | break 142 | if not passed_wall: 143 | if [self.player_pos_x + 1, self.player_pos_y] not in self.walls and self.player_pos_x < self.scr_w - 1: 144 | self.player_pos_x += 1 145 | else: 146 | hit_wall = True 147 | elif action == 2: # left 148 | passed_wall = False 149 | if self.pass_wall_rows is not None: 150 | for wall_row in self.pass_wall_rows: 151 | if [self.player_pos_x, self.player_pos_y] == [0, wall_row]: 152 | self.player_pos_x = self.scr_w - 1 153 | passed_wall = True 154 | break 155 | if not passed_wall: 156 | if [self.player_pos_x - 1, self.player_pos_y] not in self.walls and self.player_pos_x > 0: 157 | self.player_pos_x -= 1 158 | else: 159 | hit_wall = True 160 | elif action == 1: # down 161 | if [self.player_pos_x, self.player_pos_y + 1] not in self.walls and self.player_pos_y < self.scr_h - 1: 162 | self.player_pos_y += 1 163 | else: 164 | hit_wall = True 165 | elif action == 0: # up 166 | if [self.player_pos_x, self.player_pos_y - 1] not in self.walls and self.player_pos_y > 0: 167 | self.player_pos_y -= 1 168 | else: 169 | hit_wall = True 170 | return hit_wall 171 | 172 | def _check_fruit(self): 173 | if not self.is_fruit: 174 | return None 175 | caught_target = None 176 | caught_target_idx = None 177 | target_count = -1 178 | for k, target in enumerate(self.targets): 179 | target_count += 1 180 | if target['reward'] < 0: # not fruit 181 | continue 182 | if target['location'] == [self.player_pos_x, self.player_pos_y] and target['active'] is True: 183 | caught_target = deepcopy([self.player_pos_y, self.player_pos_x]) 184 | caught_target_idx = k 185 | target['active'] = False 186 | target['location'] = [self.scr_w, self.scr_h] # null value 187 | break 188 | check = [] 189 | for target in self.targets: 190 | if target['reward'] > 0: 191 | check.append(target['active']) 192 | if True not in check: 193 | self.game_over = True 194 | return caught_target, caught_target_idx 195 | 196 | def _check_ghost(self): 197 | if not self.is_ghost: 198 | return None 199 | caught_target = None 200 | for k, target in enumerate(self.targets): 201 | if target['reward'] > 0: # not ghost 202 | continue 203 | if target['location'] == [self.player_pos_x, self.player_pos_y] and target['active'] is True: 204 | caught_target = k 205 | # target['active'] = False 206 | target['locations'] = [self.scr_w, self.scr_h] # null value 207 | self.lives -= 1 208 | break 209 | return caught_target 210 | 211 | def _move_ghosts(self): 212 | if not self.is_ghost: 213 | return 214 | for target in self.targets: 215 | if target['reward'] < 0: 216 | loc = target['location'] 217 | not_moved = True 218 | while not_moved: 219 | direction = self.rng.randint(0, 4) 220 | if direction == 0 and loc[0] < self.scr_w - 1 and [loc[0] + 1, loc[1]] not in self.walls: 221 | loc[0] += 1 222 | not_moved = False 223 | elif direction == 1 and loc[0] > 0 and [loc[0] - 1, loc[1]] not in self.walls: 224 | loc[0] -= 1 225 | not_moved = False 226 | elif direction == 2 and loc[1] < self.scr_h - 1 and [loc[0], loc[1] + 1] not in self.walls: 227 | loc[1] += 1 228 | not_moved = False 229 | elif direction == 3 and loc[1] > 0 and [loc[0], loc[1] - 1] not in self.walls: 230 | loc[1] -= 1 231 | not_moved = False 232 | 233 | def get_state(self): 234 | if 
self.state_mode == 'pixel': 235 | return self.get_state_pixel() 236 | elif self.state_mode == '1hot': 237 | return self.get_1hot_features() 238 | elif self.state_mode == 'multi-head': 239 | return self.get_state_multi_head() 240 | elif self.state_mode == 'mini': 241 | return self.get_mini_state() 242 | else: 243 | raise ValueError('State-mode is not known.') 244 | 245 | def get_mini_state(self): 246 | state = np.zeros((self.scr_w * self.scr_h + len(self.possible_fruits)), dtype=np.int8) 247 | state[self.player_pos_y * self.scr_h + self.player_pos_x] = 1 248 | for target in self.targets: 249 | if target['active'] and target['reward'] > 0: 250 | offset = self.possible_fruits.index([target['location'][1], target['location'][0]]) 251 | index = (self.scr_w * self.scr_h) + offset 252 | state[index] = 1 253 | return state 254 | 255 | def get_state_multi_head(self): 256 | # three binary heads: player, fruits, ghosts 257 | state = np.zeros(3 * self.scr_w * self.scr_h, dtype=np.int8) 258 | state[self.player_pos_y * self.scr_h + self.player_pos_x] = 1 259 | for target in self.targets: 260 | if target['active']: 261 | if target['reward'] > 0: 262 | index = (self.scr_w * self.scr_h) + (target['location'][1] * self.scr_h + target['location'][0]) 263 | else: 264 | index = 2 * (self.scr_w * self.scr_h) + \ 265 | (target['location'][1] * self.scr_h + target['location'][0]) 266 | state[index] = 1 267 | return state 268 | 269 | def get_state_pixel(self): 270 | state = np.zeros((self.state_shape[1], self.state_shape[2], self.state_shape[0]), dtype=np.int8) 271 | # walls, fruits, player, ghost 272 | player_pos = [self.player_pos_x, self.player_pos_y] 273 | fruits = [] 274 | ghosts = [] 275 | for target in self.targets: 276 | if target['active'] is True: 277 | if target['reward'] > 0: 278 | fruits.append(target['location']) 279 | elif target['reward'] < 0: 280 | ghosts.append(target['location']) 281 | for loc in fruits: 282 | if loc in ghosts and self.is_ghost: 283 | # state[tuple(loc)] = self.code['fruit+ghost'] 284 | state[tuple(loc)][1] = 1 285 | state[tuple(loc)][3] = 1 286 | ghosts.remove(loc) 287 | else: 288 | state[tuple(loc)][1] = 1 289 | # state[tuple(loc)] = self.code['fruit'] 290 | if player_pos in ghosts and self.is_ghost: 291 | state[tuple(player_pos)][2] = 1 292 | state[tuple(player_pos)][3] = 1 293 | ghosts.remove(player_pos) 294 | else: 295 | state[tuple(player_pos)][2] = 1 296 | if self.is_ghost: 297 | for loc in ghosts: 298 | state[tuple(loc)][3] = 1 299 | # state[tuple(loc)] = self.code['ghost'] 300 | for loc in self.walls: 301 | state[tuple(loc)][0] = 1 302 | # state[tuple(loc)] = self.code['wall'] 303 | return deepcopy(state.T) 304 | 305 | def get_soc_state(self): 306 | # call this after each step to get SoC state list (len = self.nb_targets) 307 | # returns list of 4-tuples; one 4-tuple for each target. 
308 | state = [] 309 | for target in self.targets: 310 | target_state = [self.player_pos_x, self.player_pos_y] 311 | if target['active'] is True: 312 | target_state.extend(target['location']) 313 | else: 314 | target_state.extend([self.scr_w, self.scr_h]) 315 | state.append(target_state) 316 | return deepcopy(state) 317 | 318 | def get_1hot_features(self): 319 | agent_idx = self._get_idx(self.player_pos_x, self.player_pos_y) 320 | agent_state = np.zeros(self.nb_non_wall, dtype=np.int8) 321 | agent_state[agent_idx] = 1 322 | state = np.zeros(self.state_shape, dtype=np.int8) 323 | i = -1 324 | for target in self.targets: 325 | if target['reward'] > 0: 326 | i += 1 327 | if target['active'] is True: 328 | state[i * self.nb_non_wall: (i + 1) * self.nb_non_wall] = agent_state.copy() 329 | ghost_indices = [] 330 | for target in self.targets: 331 | if target['reward'] < 0: 332 | ghost_indices.append(self._get_idx(target['location'][0], target['location'][1])) 333 | last_fruit_pointer = self.nb_non_wall * self.nb_fruits 334 | for i, ghost_idx in enumerate(ghost_indices): 335 | ghost_agent_idx = agent_idx * self.nb_non_wall + ghost_idx 336 | state[last_fruit_pointer + i * (self.nb_non_wall ** 2) + ghost_agent_idx] = 1 337 | return state.copy() 338 | 339 | def _get_idx(self, x, y): 340 | assert [x, y] not in self.walls 341 | idx = 0 342 | flag = False 343 | for i in range(self.scr_w): 344 | for j in range(self.scr_h): 345 | if [i, j] in self.walls: 346 | continue 347 | if [i, j] == [x, y]: 348 | flag = True 349 | break 350 | else: 351 | idx += 1 352 | if flag: 353 | break 354 | return idx 355 | 356 | def step(self, action): 357 | # actions: [0, 1, 2, 3] == [up, down, left, right] 358 | if self.game_over: 359 | raise ValueError('Environment has already been terminated.') 360 | if self.step_id >= self.game_length - 1: 361 | self.game_over = True 362 | if self.state_mode == 'mini': 363 | head_reward = np.zeros(len(self.possible_fruits), dtype=np.float32) 364 | else: 365 | head_reward = [] 366 | return self.get_state(), 0., self.game_over, \ 367 | {'ghost': None, 'fruit': None, 'hit_wall': False, 'head_reward': head_reward} 368 | last_player_position = deepcopy([self.player_pos_x, self.player_pos_y]) 369 | hit_wall = self._move_player(action) 370 | if hit_wall: 371 | wall_reward = self.reward_scheme['wall'] 372 | else: 373 | wall_reward = 0.0 374 | possible_caught_ghost = self._check_ghost() 375 | if possible_caught_ghost is not None: 376 | last_ghost_position = deepcopy(self.targets[possible_caught_ghost]['location']) 377 | self._move_ghosts() 378 | swap_flag = False # in a T-situation it is possible that no hit happens 379 | if possible_caught_ghost is not None: 380 | if last_player_position == self.targets[possible_caught_ghost]['location'] and \ 381 | last_ghost_position == [self.player_pos_x, self.player_pos_y]: 382 | swap_flag = True 383 | # check for ghost hit head-to-head after the moves 384 | caught_ghost = self._check_ghost() 385 | if caught_ghost is None and swap_flag: # if a swap occurred 386 | caught_ghost = possible_caught_ghost 387 | if caught_ghost is not None: 388 | ghost_reward = self.reward_scheme['ghost'] 389 | else: 390 | ghost_reward = 0. 
391 | caught_fruit, caught_fruit_idx = self._check_fruit() 392 | if self.state_mode == 'mini': 393 | head_reward = np.zeros(len(self.possible_fruits), dtype=np.float32) 394 | else: 395 | head_reward = [] 396 | if caught_fruit is not None: 397 | fruit_reward = self.reward_scheme['fruit'] 398 | if self.state_mode == 'mini': 399 | head_reward[self.possible_fruits.index(caught_fruit)] = 1. 400 | else: 401 | fruit_reward = 0. 402 | if self.lives == 0: 403 | self.game_over = True 404 | self.step_id += 1 405 | return self.get_state(), ghost_reward + fruit_reward + wall_reward, \ 406 | self.game_over, {'fruit': caught_fruit_idx, 'ghost': caught_ghost, 'head_reward': head_reward, 407 | 'hit_wall': hit_wall} 408 | 409 | def render(self): 410 | if not self.rendering: 411 | return 412 | pygame.event.pump() 413 | self.screen.fill(BLACK) 414 | size = [self.rendering_scale, self.rendering_scale] 415 | player = pygame.Rect(self.rendering_scale * self.player_pos_x, self.rendering_scale * self.player_pos_y, 416 | size[0], size[1]) 417 | pygame.draw.rect(self.screen, WHITE, player) 418 | for target in self.targets: 419 | if target['active'] is True: 420 | pos = target['location'] 421 | p = [self.rendering_scale * pos[0], self.rendering_scale * pos[1]] 422 | gl = pygame.Rect(p[0], p[1], size[0], size[1]) 423 | pygame.draw.rect(self.screen, target['colour'], gl) 424 | for wall_pos in self.walls: 425 | p = [self.rendering_scale * wall_pos[0], self.rendering_scale * wall_pos[1]] 426 | wall = pygame.Rect(p[0], p[1], size[0], size[1]) 427 | pygame.draw.rect(self.screen, WALL, wall) 428 | pygame.display.flip() 429 | 430 | if self.image_saving: 431 | self.save_image() 432 | 433 | def save_image(self): 434 | if self.rendering and self.render_dir is not None: 435 | pygame.image.save(self.screen, self.render_dir + '/render' + str(self.step_id) + '.jpg') 436 | else: 437 | raise ValueError('env.rendering is False and/or environment has not been reset.') 438 | 439 | 440 | class FruitCollectionSmall(FruitCollection): 441 | def init_with_mode(self): 442 | self.nb_fruits = None 443 | self.scr_w = 11 444 | self.scr_h = 11 445 | self.rendering_scale = 40 446 | self.walls = [[4, 0], [6, 0], [1, 1], [2, 1], [4, 1], [6, 1], [8, 1], [9, 1], [1, 3], [3, 3], [4, 3], 447 | [6, 3], [7, 3], [9, 3], [1, 4], [3, 4], [4, 4], [6, 4], [7, 4], [9, 4], [1, 5], [9, 5], 448 | [1, 6], [2, 6], [3, 6], [5, 6], [7, 6], [8, 6], [9, 6], [0, 8], [1, 8], [2, 8], [4, 8], 449 | [5, 8], [6, 8], [8, 8], [9, 8], [10, 8], [4, 9], [5, 9], [6, 9], [1, 10], [2, 10], [8, 10], 450 | [9, 10]] 451 | if self.is_ghost: 452 | if self.is_fruit: 453 | self.ghosts = [{'colour': RED, 'reward': self.reward_scheme['ghost'], 'location': [0, 5], 454 | 'active': True}, 455 | {'colour': RED, 'reward': self.reward_scheme['ghost'], 'location': [4, 5], 456 | 'active': True}] 457 | else: 458 | self.ghosts = [] # will be reset 459 | else: 460 | self.ghosts = [] 461 | 462 | def _reset_targets(self): 463 | if self.is_ghost and not self.is_fruit: 464 | if self.rng.binomial(1, 0.5): 465 | self.ghosts = [{'colour': RED, 'reward': self.reward_scheme['ghost'], 'location': [0, 5], 466 | 'active': True}] 467 | else: 468 | self.ghosts = [{'colour': RED, 'reward': self.reward_scheme['ghost'], 'location': [4, 5], 469 | 'active': True}] 470 | [self.player_pos_x, self.player_pos_y] = deepcopy([self.scr_w - 1, self.scr_h - 1]) 471 | # Targets: Format: [ {colour: c1, reward: r1, locations: list_l1, 'active': list_a1}, ... 
] 472 | occupied = self.walls + [[self.player_pos_x, self.player_pos_y]] 473 | self.fruits = [] 474 | self.active_fruits = [] 475 | if self.is_fruit: 476 | for x in range(self.scr_w): 477 | for y in range(self.scr_h): 478 | if [x, y] not in occupied: 479 | if self.rng.binomial(1, 0.5): 480 | self.fruits.append({'colour': BLUE, 'reward': self.reward_scheme['fruit'], 481 | 'location': [x, y], 'active': True}) 482 | self.active_fruits.append(True) 483 | else: 484 | self.fruits.append({'colour': BLUE, 'reward': self.reward_scheme['fruit'], 485 | 'location': [x, y], 'active': False}) 486 | self.active_fruits.append(False) 487 | self.nb_fruits = len(self.fruits) 488 | 489 | 490 | class FruitCollectionMini(FruitCollection): 491 | def init_with_mode(self): 492 | self.is_ghost = False 493 | self.is_fruit = True 494 | self.nb_fruits = 5 495 | self.possible_fruits = [[0, 0], [0, 9], [1, 2], [3, 6], [4, 4], [5, 7], [6, 2], [7, 7], [8, 8], [9, 0]] 496 | self.scr_w = 10 497 | self.scr_h = 10 498 | self.rendering_scale = 50 499 | self.walls = [] 500 | if self.is_ghost: 501 | self.ghosts = [{'colour': RED, 'reward': self.reward_scheme['ghost'], 'location': [0, 1], 502 | 'active': True}] 503 | else: 504 | self.ghosts = [] 505 | 506 | def _reset_targets(self): 507 | while True: 508 | self.player_pos_x, self.player_pos_y = self.rng.randint(0, self.scr_w), self.rng.randint(0, self.scr_h) 509 | if [self.player_pos_x, self.player_pos_y] not in self.possible_fruits: 510 | break 511 | # Targets: Format: [ {colour: c1, reward: r1, locations: list_l1, 'active': list_a1}, ... ] 512 | self.fruits = [] 513 | self.active_fruits = [] 514 | if self.is_fruit: 515 | for x in range(self.scr_w): 516 | for y in range(self.scr_h): 517 | self.fruits.append({'colour': BLUE, 'reward': self.reward_scheme['fruit'], 518 | 'location': [x, y], 'active': False}) 519 | self.active_fruits.append(False) 520 | fruits_idx = deepcopy(self.possible_fruits) 521 | self.rng.shuffle(fruits_idx) 522 | fruits_idx = fruits_idx[:self.nb_fruits] 523 | self.mini_target = [False] * len(self.possible_fruits) 524 | for f in fruits_idx: 525 | idx = f[1] * self.scr_w + f[0] 526 | self.fruits[idx]['active'] = True 527 | self.active_fruits[idx] = True 528 | self.mini_target[self.possible_fruits.index(f)] = True 529 | 530 | 531 | class FruitCollectionLarge(FruitCollection): 532 | def init_with_mode(self): 533 | self.nb_fruits = None 534 | self.scr_w = 21 535 | self.scr_h = 14 536 | self.rendering_scale = 30 537 | self.pass_wall_rows = [4, 8] 538 | self.walls = [[0, 0], [5, 0], [15, 0], [20, 0], 539 | [0, 1], [2, 1], [3, 1], [5, 1], [7, 1], [8, 1], [9, 1], [10, 1], [11, 1], [12, 1], 540 | [13, 1], [15, 1], [17, 1], [18, 1], [20, 1], 541 | [0, 2], [20, 2], 542 | [0, 3], [1, 3], [3, 3], [5, 3], [6, 3], [8, 3], [9, 3], [10, 3], [11, 3], [12, 3], 543 | [14, 3], [15, 3], [17, 3], [19, 3], [20, 3], 544 | [3, 4], [17, 4], 545 | [0, 5], [1, 5], [3, 5], [4, 5], [5, 5], [6, 5], [8, 5], [9, 5], [10, 5], [11, 5], 546 | [12, 5], [14, 5], [15, 5], [16, 5], [17, 5], [19, 5], [20, 5], 547 | [0, 6], [1, 6], [8, 6], [9, 6], [10, 6], [11, 6], [12, 6], [19, 6], [20, 6], 548 | [0, 7], [1, 7], [3, 7], [4, 7], [5, 7], [6, 7], [8, 7], [9, 7], [10, 7], [11, 7], 549 | [12, 7], [14, 7], [15, 7], [16, 7], [17, 7], [19, 7], [20, 7], 550 | [3, 8], [17, 8], 551 | [0, 9], [1, 9], [3, 9], [5, 9], [7, 9], [9, 9], [10, 9], [11, 9], [13, 9], [15, 9], 552 | [17, 9], [19, 9], [20, 9], 553 | [0, 10], [5, 10], [7, 10], [13, 10], [15, 10], [20, 10], 554 | [0, 11], [2, 11], [3, 11], [5, 11], [9, 11], 
[10, 11], [11, 11], [15, 11], [17, 11], 555 | [18, 11], [20, 11], 556 | [0, 12], [2, 12], [3, 12], [5, 12], [6, 12], [7, 12], [9, 12], [10, 12], [11, 12], 557 | [13, 12], [14, 12], [15, 12], [17, 12], [18, 12], [20, 12], 558 | [0, 13], [20, 13]] 559 | if self.is_ghost: 560 | self.ghosts = [{'colour': RED, 'reward': self.reward_scheme['ghost'], 'location': [1, 4], 561 | 'active': True}, 562 | {'colour': RED, 'reward': self.reward_scheme['ghost'], 'location': [8, 4], 563 | 'active': True}, 564 | {'colour': RED, 'reward': self.reward_scheme['ghost'], 'location': [12, 4], 565 | 'active': True}, 566 | {'colour': RED, 'reward': self.reward_scheme['ghost'], 'location': [19, 4], 567 | 'active': True}] 568 | else: 569 | self.ghosts = [] 570 | 571 | def _reset_targets(self): 572 | self.ghosts = deepcopy(self.init_ghosts) 573 | [self.player_pos_x, self.player_pos_y] = [10, 8] 574 | # Targets: Format: [ {colour: c1, reward: r1, locations: list_l1, 'active': list_a1}, ... ] 575 | occupied = self.walls + [[self.player_pos_x, self.player_pos_y]] 576 | self.fruits = [] 577 | self.active_fruits = [] 578 | if self.is_fruit: 579 | for x in range(self.scr_w): 580 | for y in range(self.scr_h): 581 | if [x, y] not in occupied: 582 | if self.rng.binomial(1, 0.5): 583 | self.fruits.append({'colour': BLUE, 'reward': self.reward_scheme['fruit'], 584 | 'location': [x, y], 'active': True}) 585 | self.active_fruits.append(True) 586 | else: 587 | self.fruits.append({'colour': BLUE, 'reward': self.reward_scheme['fruit'], 588 | 'location': [x, y], 'active': False}) 589 | self.active_fruits.append(False) 590 | self.nb_fruits = len(self.fruits) 591 | 592 | 593 | @click.command() 594 | @click.option('--mode', '-m', help="'small' or 'large' or 'mini'") 595 | @click.option('--fruit/--no-fruit', default=True, help='Activates fruits.') 596 | @click.option('--ghost/--no-ghost', default=True, help='Activates ghosts.') 597 | @click.option('--save/--no-save', default=False, help='Saving rendering screen.') 598 | def test(mode, fruit, ghost, save): 599 | if mode == 'small': 600 | e = FruitCollectionSmall 601 | elif mode == 'mini': 602 | e = FruitCollectionMini 603 | elif mode == 'large': 604 | e = FruitCollectionLarge 605 | else: 606 | raise ValueError('Incorrect mode.') 607 | env = e(rendering=True, lives=1, is_fruit=fruit, is_ghost=ghost, image_saving=save) 608 | print('state shape', env.state_shape) 609 | for _ in range(1): 610 | env.reset() 611 | env.render() 612 | while not env.game_over: 613 | action = None 614 | events = pygame.event.get() 615 | for event in events: 616 | if event.type == pygame.KEYDOWN: 617 | if event.key == pygame.K_UP: 618 | action = 0 619 | if event.key == pygame.K_DOWN: 620 | action = 1 621 | if event.key == pygame.K_LEFT: 622 | action = 2 623 | if event.key == pygame.K_RIGHT: 624 | action = 3 625 | if event.key == pygame.K_q: 626 | return 627 | if action is None: 628 | continue 629 | obs, r, term, info = env.step(action) 630 | env.render() 631 | print("\033[2J\033[H\033[2J", end="") 632 | print() 633 | print('pos: ', env.player_pos_x, env.player_pos_y) 634 | print('reward: ', r) 635 | print('state:') 636 | print('─' * 30) 637 | print(obs) 638 | print('─' * 30) 639 | 640 | 641 | if __name__ == '__main__': 642 | test() 643 | --------------------------------------------------------------------------------
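
Note: for quick sanity checks outside the training scripts, the environment in environment/fruit_collection.py can be driven directly. The snippet below is a minimal sketch (not part of the repository) of a random-agent rollout on FruitCollectionMini; it assumes the repository root is on PYTHONPATH and that pygame is installed, since the module imports it even when rendering is off.

import numpy as np
from environment.fruit_collection import FruitCollectionMini

# 'mini' state mode: 100-dim one-hot player position + 10 active-fruit flags.
env = FruitCollectionMini(rendering=False, lives=1, state_mode='mini',
                          rng=np.random.RandomState(123))
episode_return = 0.0
while not env.game_over:
    action = env.rng.randint(env.nb_actions)  # uniform random action in {0, 1, 2, 3}
    state, reward, terminated, info = env.step(action)
    episode_return += reward
print('return:', episode_return, '| steps:', env.step_id)
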
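The 'mini' observation used above packs two pieces of information into one binary vector: the first scr_w * scr_h entries one-hot encode the player cell (indexed as player_pos_y * scr_h + player_pos_x, unambiguous here because the mini grid is square), and the last len(possible_fruits) entries flag which candidate fruit cells are still active. The small decoding sketch below is illustrative only; decode_mini_state is not a repository function.

import numpy as np

SCR_W = SCR_H = 10
POSSIBLE_FRUITS = [[0, 0], [0, 9], [1, 2], [3, 6], [4, 4],
                   [5, 7], [6, 2], [7, 7], [8, 8], [9, 0]]  # [y, x] pairs, as in FruitCollectionMini

def decode_mini_state(state):
    grid = np.asarray(state[:SCR_W * SCR_H]).reshape(SCR_H, SCR_W)
    player_y, player_x = np.argwhere(grid)[0]       # inverse of y * SCR_H + x on a square grid
    flags = state[SCR_W * SCR_H:]
    active = [POSSIBLE_FRUITS[i] for i, f in enumerate(flags) if f]
    return (int(player_x), int(player_y)), active
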
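On the agent side, _get_action in tabular/experiment.py combines the per-goal value functions by summing their Q-vectors and acting greedily, breaking ties uniformly at random (np.argmax alone would always favour the lowest-indexed maximising action). Below is a self-contained sketch of that aggregation step with hypothetical Q-values; aggregate_greedy_action is illustrative and not a repository function.

import numpy as np

def aggregate_greedy_action(q_per_goal, rng):
    """q_per_goal: one Q-vector of length nb_actions per still-active goal."""
    q_aggregate = np.sum(q_per_goal, axis=0)              # elementwise sum over goals
    best = np.where(q_aggregate == q_aggregate.max())[0]  # every maximising action
    return rng.choice(best)                               # unbiased tie-breaking

rng = np.random.RandomState(7654)
q_values = [np.array([0.1, 0.4, 0.4, 0.0]), np.array([0.2, 0.1, 0.1, 0.5])]
print(aggregate_greedy_action(q_values, rng))  # sums to [0.3, 0.5, 0.5, 0.5] -> picks 1, 2 or 3 at random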