├── LICENSE ├── README.md ├── SelfPlay ├── __init__.py ├── agent │ ├── __init__.py │ ├── base_agent.py │ ├── human_agent.py │ ├── random_agent.py │ ├── registry.py │ └── reinforce_agent.py ├── app │ ├── __init__.py │ ├── acrobot_reinforce.py │ ├── cartpole_reinforce.py │ ├── env_reinforce.py │ ├── mazebase_reinforce.py │ ├── mountaincar_reinforce.py │ ├── plot.py │ ├── selfplay.py │ └── util.py ├── config │ └── config.cfg.sample ├── environment │ ├── __init__.py │ ├── acrobot.py │ ├── cartpole.py │ ├── env.py │ ├── mazebase_wrapper.py │ ├── mountain_car.py │ ├── observation.py │ ├── registry.py │ ├── selfplay.py │ ├── selfplay_memory.py │ └── selfplay_target.py ├── memory │ ├── __init__.py │ ├── base_memory.py │ ├── lstm_memory.py │ └── memory_config.py ├── model │ ├── __init__.py │ └── base_model.py ├── plotter │ ├── __init__.py │ ├── plot_from_dir.py │ ├── plot_from_file.py │ └── util.py ├── policy │ ├── __init__.py │ ├── acrobot_policy.py │ ├── base_policy.py │ ├── cartpole_policy.py │ ├── mazebase_policy.py │ ├── mountaincar_policy.py │ ├── policy_config.py │ └── registry.py ├── requirements.txt ├── scripts │ └── filter_json_lines.py ├── test.sh └── utils │ ├── __init__.py │ ├── argument_parser.py │ ├── config.py │ ├── constant.py │ ├── log.py │ ├── optim_registry.py │ └── util.py ├── _config.yml ├── assets └── images │ ├── acrobot_pca.png │ └── mazebase_selfplay_compare.png ├── docs └── report.pdf ├── model └── .keep └── plot └── .keep /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Shagun Sodhani 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Memory Augmented Self-Play 2 | 3 | [Self-play](https://arxiv.org/abs/1703.05407) is an unsupervised training procedure which enables the reinforcement learning agents to explore the environment without requiring any external rewards. We augment the self-play setting by providing an external memory where the agent can store experience from the previous tasks. This enables the agent to come up with more diverse self-play tasks resulting in faster exploration of the environment. The agent pretrained in the memory augmented self-play setting easily outperforms the agent pretrained in no-memory self-play setting. 
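In the asymmetric self-play phase, one agent (Alice) proposes a task by acting in the environment and a second agent (Bob) then has to repeat or undo it; Alice is rewarded when Bob needs more steps than she did, while Bob is penalised for every step he takes (see `SelfPlay/app/selfplay.py`). The following is a simplified sketch of a single self-play episode; the function and attribute names are illustrative rather than the exact API of this repository:

```python
# Simplified sketch of one asymmetric self-play episode (illustrative names only;
# the actual implementation lives in SelfPlay/app/selfplay.py and SelfPlay/environment/selfplay.py).

def selfplay_episode(selfplay_env, alice, bob, reward_scale=0.1, t_max=80):
    selfplay_env.reset()

    # Alice's turn: she acts for t_a steps and thereby defines a task for Bob
    # (reach her end state in COPY mode, or return to her start state in UNDO mode).
    t_a = 0
    while t_a < t_max:
        t_a += 1
        observation = selfplay_env.alice_observe()
        observation = selfplay_env.act(alice.get_action(observation))
        if observation.is_episode_over:
            break

    # Bob's turn: he must complete the task Alice proposed, within the remaining budget.
    selfplay_env.bob_start()
    t_b = 0
    while t_a + t_b < t_max:
        observation = selfplay_env.bob_observe()
        if observation.is_episode_over:  # Bob reached the target state
            break
        t_b += 1
        selfplay_env.act(bob.get_action(observation))

    # Self-play rewards (as in SelfPlay/app/selfplay.py): Alice is paid for proposing
    # tasks that take Bob longer than they took her; Bob is penalised for every step.
    reward_alice = reward_scale * max(0, t_b - t_a)
    reward_bob = -reward_scale * t_b
    return reward_alice, reward_bob
```

In the memory augmented variant, Alice's policy additionally receives an external memory that stores the (start state, end state) pairs from her previous self-play episodes (see `update_memory` in `SelfPlay/app/selfplay.py`).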
4 | 
5 | ## Paper
6 | 
7 | * [Arxiv](https://arxiv.org/abs/1805.11016) [Submitted to ICML Workshop]
8 | 
9 | ## Setup
10 | 
11 | * Install requirements using `pip install -r SelfPlay/requirements.txt`.
12 | * Copy `SelfPlay/config/config.cfg.sample` to `SelfPlay/config/config.cfg` and update the parameters as required.
13 | * `cd SelfPlay` and run `./test.sh`.
14 | * Refer to `SelfPlay/test.sh` to see the different scripts that are supported.
15 | 
16 | ## Results
17 | 
18 | 
19 | For the Mazebase task, Figure 1 shows the trend of the current reward (computed as a running average over the last 10000 episodes) on the target task as the number of episodes increases. The key observation is that **memory augmented self-play** consistently performs better than both the **self-play** and the no self-play settings. Even though all the models converge towards the same reward value, **memory augmented self-play** converges faster. All the curves show the average over runs with 5 different seeds. We also observe that **self-play** (without memory) performs better than the no self-play setting. This was observed in the **self-play** paper, and our experiments validate that observation.
20 | 
21 | ![Comparison of different approaches on Mazebase task](assets/images/mazebase_selfplay_compare.png)
22 | 
23 | **Figure 1: Comparison of different approaches on Mazebase task**
24 | 
25 | For the Acrobot task, we observe that the memory augmented agent performs the best. Even though both the self-play and no self-play variants perform poorly, the no self-play version is slightly better. One of our motivations behind adding memory was that an explicit memory would allow Alice to remember which states she has already visited. This would enable her to explore more parts of the environment compared to the **self-play** setting. To validate this hypothesis, we perform a simple analysis. We compile the list of all the start and end states that Alice encounters during training. Even though the start states are chosen randomly, the end states are determined by the actions she takes. We embed all the states into a 2-dimensional space using PCA and plot the line segments connecting the start and the end states for each episode. The resulting plot is shown in Figure 2. We observe that using LSTM memory results in a wider distribution of end-state points than **self-play** with no memory. The mean Euclidean distance between start and end points (in PCA space) increases from 0.0192 (**self-play** without memory) to 0.1079 (**memory augmented self-play**), more than a 5x improvement. This affirms our hypothesis that **memory augmented self-play** enables Alice to explore more parts of the environment and to come up with a more diverse set of tasks for Bob to train on.
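As a concrete illustration, the analysis above can be reproduced along the following lines once the logged start and end states are collected into two arrays of shape `(num_episodes, state_dim)`. This is only a sketch: it assumes scikit-learn and matplotlib are available, and the function and variable names are illustrative rather than part of the repository.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA


def plot_start_end_states(start_states, end_states):
    """start_states, end_states: arrays of shape (num_episodes, state_dim)."""
    # Fit a single 2D PCA on all states so that start and end points share one embedding.
    pca = PCA(n_components=2)
    all_points = pca.fit_transform(np.vstack([start_states, end_states]))
    starts_2d, ends_2d = np.split(all_points, 2)

    # Mean Euclidean distance (in PCA space) between the start and end state of each episode.
    mean_distance = np.linalg.norm(ends_2d - starts_2d, axis=1).mean()

    # One line segment per episode, connecting its start state to its end state.
    for (x0, y0), (x1, y1) in zip(starts_2d, ends_2d):
        plt.plot([x0, x1], [y0, y1], alpha=0.3, linewidth=0.5)
    plt.title("Alice start/end states, mean distance = {:.4f}".format(mean_distance))
    plt.show()
    return mean_distance
```

Fitting a single PCA on the union of start and end states keeps both sets in the same coordinate system, so the reported distances are comparable across the with-memory and without-memory settings.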
26 | 27 | 28 | 29 | ![Plot of start and end states in 2D with and without memory augmentation](assets/images/acrobot_pca.png) 30 | 31 | **Figure 2: Plot of start and end states in 2D with and without memory augmentation** 32 | 33 | ## References 34 | 35 | ``` 36 | @article{sukhbaatar2017intrinsic, 37 | title={Intrinsic motivation and automatic curricula via asymmetric self-play}, 38 | author={Sukhbaatar, Sainbayar and Lin, Zeming and Kostrikov, Ilya and Synnaeve, Gabriel and Szlam, Arthur and Fergus, Rob}, 39 | journal={arXiv preprint arXiv:1703.05407}, 40 | year={2017} 41 | } 42 | ``` 43 | 44 | -------------------------------------------------------------------------------- /SelfPlay/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shagunsodhani/memory-augmented-self-play/1032891cb149cdc8411eb2fdba8745094524b5ce/SelfPlay/__init__.py -------------------------------------------------------------------------------- /SelfPlay/agent/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shagunsodhani/memory-augmented-self-play/1032891cb149cdc8411eb2fdba8745094524b5ce/SelfPlay/agent/__init__.py -------------------------------------------------------------------------------- /SelfPlay/agent/base_agent.py: -------------------------------------------------------------------------------- 1 | import random 2 | from abc import ABC, abstractmethod 3 | 4 | from utils.constant import * 5 | 6 | class BaseAgent(ABC): 7 | """ 8 | Base class for the agents 9 | """ 10 | 11 | def __init__(self, config, possible_actions=[], name=None, **kwargs): 12 | self._type = config[MODEL][AGENT] 13 | self.actions = possible_actions 14 | self.gamma = config[MODEL][GAMMA] 15 | self._lambda = config[MODEL][LAMBDA] 16 | self.save_dir = config[MODEL][SAVE_DIR] 17 | self.load_path = config[MODEL][LOAD_PATH] 18 | self.num_optimisers = config[MODEL][NUM_OPTIMIZERS] 19 | if name: 20 | self.name = name 21 | else: 22 | self.name = ALICE 23 | if(self.num_optimisers==1): 24 | self.learning_rate = config[MODEL][LEARNING_RATE] 25 | elif(self.num_optimisers==2): 26 | self.learning_rate_actor = config[MODEL][LEARNING_RATE_ACTOR] 27 | self.learning_rate_critic = config[MODEL][LEARNING_RATE_CRITIC] 28 | 29 | # cant use "lambda" as it is a keyword in python 30 | 31 | @abstractmethod 32 | def get_action(self, observation): 33 | pass 34 | 35 | def get_random_action(self): 36 | """ 37 | Return a random action 38 | """ 39 | return random.choice(self.actions) 40 | 41 | def get_optimisers(self, optimiser_name): 42 | return None 43 | 44 | def update_policy(self, optimizer, **kwargs): 45 | return optimizer 46 | 47 | def set_initial_state(self): 48 | pass 49 | -------------------------------------------------------------------------------- /SelfPlay/agent/human_agent.py: -------------------------------------------------------------------------------- 1 | from agent.base_agent import BaseAgent 2 | 3 | 4 | class HumanAgent(BaseAgent): 5 | """ 6 | An agent controlled by human 7 | """ 8 | 9 | def __init__(self, config, possible_actions=[], name=None, **kwargs): 10 | super(HumanAgent, self).__init__(config, possible_actions=possible_actions, name=name, **kwargs) 11 | 12 | def get_action(self, observation): 13 | """ 14 | This code is borrowed from: 15 | https://github.com/facebook/MazeBase/blob/23454fe092ecf35a8aab4da4972f231c6458209b/py/example.py#L172 16 | """ 17 | print(list(enumerate(self.actions))) 18 
| action_index = -1 19 | while action_index not in range(len(self.actions)): 20 | action_index = input("Input a number to choose the action: ") 21 | try: 22 | action_index = int(action_index) 23 | except ValueError: 24 | action_index = -1 25 | return self.actions[action_index] 26 | -------------------------------------------------------------------------------- /SelfPlay/agent/random_agent.py: -------------------------------------------------------------------------------- 1 | from agent.base_agent import BaseAgent 2 | 3 | class RandomAgent(BaseAgent): 4 | """ 5 | A randomly behaving agent 6 | """ 7 | 8 | def __init__(self, config, possible_actions=[], name=None, **kwargs): 9 | super(RandomAgent, self).__init__(config, possible_actions=possible_actions, name=name, **kwargs) 10 | 11 | def get_action(self, observation): 12 | """Return a random action""" 13 | return self.get_random_action() 14 | 15 | -------------------------------------------------------------------------------- /SelfPlay/agent/registry.py: -------------------------------------------------------------------------------- 1 | from agent.human_agent import HumanAgent 2 | from agent.random_agent import RandomAgent 3 | from agent.reinforce_agent import ReinforceAgent 4 | from utils.constant import RANDOM, HUMAN, REINFORCE 5 | 6 | 7 | def choose_agent(agent_type=RANDOM): 8 | if (agent_type == RANDOM): 9 | return RandomAgent 10 | elif (agent_type == HUMAN): 11 | return HumanAgent 12 | elif (agent_type == REINFORCE): 13 | return ReinforceAgent 14 | 15 | 16 | def get_supported_agents(): 17 | return set([RANDOM, HUMAN, REINFORCE]) 18 | -------------------------------------------------------------------------------- /SelfPlay/agent/reinforce_agent.py: -------------------------------------------------------------------------------- 1 | from agent.base_agent import BaseAgent 2 | from policy.registry import choose_policy 3 | from utils.constant import MODEL, USE_BASELINE, BATCH_SIZE, ENV, IS_SELF_PLAY, \ 4 | IS_SELF_PLAY_WITH_MEMORY, EPISODE_MEMORY_SIZE 5 | from utils.optim_registry import choose_optimiser 6 | 7 | 8 | class ReinforceAgent(BaseAgent): 9 | """ 10 | An agent trained using REINFORCE-Baseline algorithm 11 | """ 12 | 13 | def __init__(self, config, possible_actions=[], name=None, input_size=None): 14 | super(ReinforceAgent, self).__init__(config=config, 15 | possible_actions=possible_actions, name=name) 16 | self.policy = choose_policy(env_name=config[MODEL][ENV], 17 | agent_type=self._type, 18 | use_baseline=config[MODEL][USE_BASELINE], 19 | input_size=input_size, 20 | num_actions=len(self.actions), 21 | batch_size=config[MODEL][BATCH_SIZE], 22 | is_self_play=config[MODEL][IS_SELF_PLAY], 23 | is_self_play_with_memory=config[MODEL][IS_SELF_PLAY_WITH_MEMORY], 24 | _lambda=self._lambda, 25 | episode_memory_size=config[MODEL][EPISODE_MEMORY_SIZE]) 26 | 27 | def get_optimisers(self, optimiser_name): 28 | optimiser = choose_optimiser(optimiser_name=optimiser_name) 29 | params = self.policy.memory.get_params() 30 | if (self.num_optimisers == 1): 31 | params += self.policy.parameters() 32 | return (optimiser(params, self.learning_rate),) 33 | # elif (self.num_optimisers == 2): 34 | # return (optimiser(self.policy.get_actor_params(), self.learning_rate_actor), 35 | # optimiser(self.policy.get_critic_params(), self.learning_rate_critic)) 36 | 37 | def update_policy(self, optimizers, observation=None): 38 | # Only to be called once the episode is over 39 | return self.policy.update(optimisers=optimizers, gamma=self.gamma, 
agent_name=self.name) 40 | 41 | def get_action(self, observation): 42 | return self.actions[self.policy.get_action(observation)] 43 | 44 | def save_model(self, epochs, optimisers, name, timestamp): 45 | return self.policy.save_model(epochs=epochs, optimisers=optimisers, 46 | save_dir=self.save_dir, name=name, timestamp=timestamp) 47 | 48 | def load_model(self, optimisers, name, timestamp): 49 | return self.policy.load_model(optimisers=optimisers, 50 | load_path=self.load_path, name=name, timestamp=timestamp) 51 | -------------------------------------------------------------------------------- /SelfPlay/app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shagunsodhani/memory-augmented-self-play/1032891cb149cdc8411eb2fdba8745094524b5ce/SelfPlay/app/__init__.py -------------------------------------------------------------------------------- /SelfPlay/app/acrobot_reinforce.py: -------------------------------------------------------------------------------- 1 | from app.env_reinforce import run 2 | from app.util import bootstrap 3 | from utils.constant import * 4 | 5 | 6 | def main(): 7 | config = bootstrap() 8 | 9 | ############## 10 | config[MODEL][ENV] = ACROBOT 11 | config[MODEL][AGENT] = REINFORCE 12 | config[MODEL][USE_BASELINE] = True 13 | ############## 14 | 15 | run(config=config) 16 | 17 | 18 | if __name__ == '__main__': 19 | main() 20 | -------------------------------------------------------------------------------- /SelfPlay/app/cartpole_reinforce.py: -------------------------------------------------------------------------------- 1 | from app.env_reinforce import run 2 | from app.util import bootstrap 3 | from utils.constant import * 4 | 5 | 6 | def main(): 7 | config = bootstrap() 8 | 9 | ############## 10 | config[MODEL][ENV] = CARTPOLE 11 | config[MODEL][AGENT] = REINFORCE 12 | config[MODEL][USE_BASELINE] = True 13 | ############## 14 | 15 | run(config=config) 16 | 17 | 18 | if __name__ == '__main__': 19 | main() 20 | 21 | main() 22 | -------------------------------------------------------------------------------- /SelfPlay/app/env_reinforce.py: -------------------------------------------------------------------------------- 1 | from agent.registry import choose_agent 2 | from app.util import bootstrap 3 | from environment.registry import choose_env 4 | from utils.constant import * 5 | from utils.log import write_config_log, write_reward_log 6 | 7 | 8 | def run_episode(env, agent, optimisers, total_episodic_rewards, i_episode, max_steps_per_episode=1000): 9 | current_episodic_reward = 0.0 10 | env.reset() 11 | agent.set_initial_state() 12 | for t in range(max_steps_per_episode): # Don't infinite loop while learning 13 | observation = env.observe() 14 | action = agent.get_action(observation) 15 | observation = env.act(action) 16 | current_episodic_reward += observation.reward 17 | if observation.is_episode_over: 18 | break 19 | optimisers = agent.update_policy(optimisers, observation=observation) 20 | total_episodic_rewards += current_episodic_reward 21 | 22 | if i_episode % 1 == 0: 23 | write_reward_log(episode_number=i_episode, current_episodic_reward=current_episodic_reward, 24 | average_episodic_reward=total_episodic_rewards / i_episode, agent=agent.name, 25 | environment=env.name) 26 | return agent, optimisers, total_episodic_rewards 27 | 28 | 29 | def run(config): 30 | write_config_log(config) 31 | env = choose_env(env=config[MODEL][ENV])() 32 | possible_actions = env.all_possible_actions() 33 | 
agent = config[MODEL][AGENT] 34 | 35 | agent = choose_agent(agent_type=agent) \ 36 | (config=config, possible_actions=possible_actions) 37 | optimisers = agent.get_optimisers(optimiser_name=config[MODEL][OPTIMISER]) 38 | total_episodic_rewards = 0.0 39 | for i_episode in range(1, config[MODEL][NUM_EPOCHS] + 1): 40 | agent, optimisers, total_episodic_rewards = run_episode(env, agent, optimisers, 41 | total_episodic_rewards, i_episode, 42 | max_steps_per_episode=config[MODEL][ 43 | MAX_STEPS_PER_EPISODE]) 44 | 45 | 46 | if __name__ == '__main__': 47 | config = bootstrap() 48 | run(config) 49 | -------------------------------------------------------------------------------- /SelfPlay/app/mazebase_reinforce.py: -------------------------------------------------------------------------------- 1 | from app.env_reinforce import run 2 | from app.util import bootstrap 3 | from utils.constant import * 4 | 5 | 6 | def main(): 7 | config = bootstrap() 8 | ############## 9 | config[MODEL][ENV] = MAZEBASE 10 | config[MODEL][AGENT] = REINFORCE 11 | config[MODEL][USE_BASELINE] = True 12 | ############## 13 | 14 | run(config=config) 15 | 16 | 17 | if __name__ == '__main__': 18 | main() 19 | -------------------------------------------------------------------------------- /SelfPlay/app/mountaincar_reinforce.py: -------------------------------------------------------------------------------- 1 | from app.env_reinforce import run 2 | from app.util import bootstrap 3 | from utils.constant import * 4 | 5 | 6 | def main(): 7 | config = bootstrap() 8 | 9 | ############## 10 | config[MODEL][ENV] = MOUNTAINCAR 11 | config[MODEL][AGENT] = REINFORCE 12 | config[MODEL][USE_BASELINE] = True 13 | ############## 14 | 15 | run(config=config) 16 | 17 | 18 | if __name__ == '__main__': 19 | main() 20 | -------------------------------------------------------------------------------- /SelfPlay/app/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | 3 | matplotlib.use('agg') 4 | import sys 5 | 6 | from utils.config import get_config 7 | from utils.constant import * 8 | from plotter.plot_from_dir import plot_from_dir 9 | from plotter.plot_from_file import plot_from_file 10 | from pathlib import Path 11 | 12 | 13 | def run(logs_path, env=MAZEBASE, dir_to_save_plots=None, last_n=0, window_size=1): 14 | logs_path = Path(logs_path).resolve() 15 | print(logs_path) 16 | 17 | if (logs_path.is_dir()): 18 | plot_from_dir(logs_path=logs_path, env=env, dir_to_save_plots=dir_to_save_plots, last_n=last_n, window_size=window_size) 19 | elif (logs_path.is_file()): 20 | plot_from_file(log_file_path=logs_path, env=env, dir_to_save_plots=dir_to_save_plots, last_n=last_n, window_size=window_size) 21 | 22 | 23 | if __name__ == '__main__': 24 | logs_path = None 25 | config = None 26 | window_size = 5000 27 | if (len(sys.argv) == 2): 28 | logs_path = sys.argv[1] 29 | config = get_config(use_cmd_config=False) 30 | else: 31 | config = get_config(use_cmd_config=True) 32 | logs_path = config[LOG][FILE_PATH] 33 | # logs_path = "/Users/shagun/projects/self-play/SelfPlay/selfplay-with-memory-logs1" 34 | env = config[MODEL][ENV] 35 | last_n = 0 36 | dir_to_save_plots = config[PLOT][BASE_PATH] 37 | run(logs_path=logs_path, 38 | env=env, 39 | dir_to_save_plots=dir_to_save_plots, 40 | last_n=last_n, 41 | window_size=window_size) 42 | -------------------------------------------------------------------------------- /SelfPlay/app/selfplay.py: 
-------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from time import time 3 | 4 | import numpy as np 5 | from agent.registry import choose_agent 6 | from app.util import bootstrap 7 | from environment.registry import choose_env, choose_selfplay 8 | from environment.selfplay_target import SelfPlayTarget 9 | from utils.constant import * 10 | from utils.log import write_config_log, write_reward_log, write_time_log 11 | 12 | 13 | def run_target_episode(selfplay_target, bob, optimisers_bob, max_steps_per_episode): 14 | current_episodic_reward = 0.0 15 | selfplay_target.reset() 16 | bob.set_initial_state() 17 | for t in range(max_steps_per_episode): # Don't infinite loop while learning 18 | observation = selfplay_target.bob_observe() 19 | action = bob.get_action(observation) 20 | observation = selfplay_target.act(action) 21 | current_episodic_reward += observation[0].reward 22 | if observation[0].is_episode_over: 23 | break 24 | optimisers_bob = bob.update_policy(optimisers_bob, observation=observation) 25 | 26 | return bob, optimisers_bob, current_episodic_reward 27 | 28 | 29 | def run_target_epochs(selfplay_target, bob, optimisers_bob, batch_size, total_episodic_rewards, total_episodes, 30 | max_steps_per_episode=1000): 31 | for i in range(batch_size): 32 | bob, optimisers_bob, current_episodic_reward = run_target_episode(selfplay_target, bob, 33 | optimisers_bob, max_steps_per_episode) 34 | 35 | total_episodes += 1 36 | total_episodic_rewards += current_episodic_reward 37 | for (player, total_reward, current_reward) in [(BOB, total_episodic_rewards, current_episodic_reward)]: 38 | if total_episodes % 1 == 0: 39 | write_reward_log(episode_number=str(total_episodes), agent=player, 40 | current_episodic_reward=current_reward, 41 | average_episodic_reward=total_reward / total_episodes, 42 | environment=selfplay_target.name) 43 | 44 | return bob, optimisers_bob, total_episodic_rewards, total_episodes 45 | 46 | 47 | def run_selfplay_episode(selfplay, alice, bob, optimisers_alice, optimisers_bob, reward_scale=1.0, tMax=80): 48 | tA = 0 49 | selfplay.reset() 50 | selfplay.alice_start() 51 | while True: 52 | tA += 1 53 | observation = selfplay.alice_observe() 54 | action = alice.get_action(observation) 55 | observation = selfplay.act(action) 56 | if (tA >= tMax or observation[0].is_episode_over): 57 | selfplay.alice_stop() 58 | break 59 | write_time_log(time_alice=tA, agent=ALICE, environment=selfplay.name, time=tA) 60 | # write_position_log(alice_start_position=list(selfplay.alice_observations.start.state.astype(np.float64)), 61 | # alice_end_position=list(selfplay.alice_observations.end.state.astype(np.float64))) 62 | 63 | selfplay.bob_start() 64 | tB = 0 65 | while True: 66 | observation = selfplay.bob_observe() 67 | if (observation[0].is_episode_over or tA + tB >= tMax): 68 | if (observation[0].is_episode_over): 69 | print("solved") 70 | else: 71 | print("notsolved") 72 | selfplay.bob_stop() 73 | break 74 | tB += 1 75 | action = bob.get_action(observation) 76 | selfplay.act(action) 77 | write_time_log(time_bob=tB, agent=BOB, environment=selfplay.name, time=tA) 78 | # write_position_log(bob_end_position=list(observation[0].state.astype(np.float64))) 79 | 80 | rA = reward_scale * max(0, tB - tA) 81 | rB = -reward_scale * (tB) 82 | alice_reward_observation = selfplay.alice_observations.end 83 | alice_reward_observation.reward = rA 84 | alice.get_action(observation=(alice_reward_observation, alice_reward_observation)) 85 | 86 | 
bob_reward_observation = selfplay.observe() 87 | bob_reward_observation.reward = rB 88 | bob.get_action(observation=(bob_reward_observation, bob_reward_observation)) 89 | 90 | optimisers_alice = alice.update_policy(optimisers_alice) 91 | optimisers_bob = bob.update_policy(optimisers_bob) 92 | 93 | return alice, bob, optimisers_alice, optimisers_bob, rA, rB, selfplay 94 | 95 | 96 | def run_selfplay_epoch(selfplay, alice, bob, optimisers_alice, optimisers_bob, batch_size, 97 | total_rA, total_rB, total_episodes, reward_scale=1.0, tMax=80, use_memory=False): 98 | for i in range(batch_size): 99 | alice, bob, optimizers_alice, optimizers_bob, rA, rB, selfplay = run_selfplay_episode(selfplay, alice, bob, 100 | optimisers_alice, 101 | optimisers_bob, 102 | reward_scale=reward_scale, 103 | tMax=tMax) 104 | if (use_memory): 105 | alice_history = np.concatenate((selfplay.alice_observations.start.state.reshape(1, -1), 106 | selfplay.alice_observations.end.state.reshape(1, -1)), axis=1) 107 | 108 | alice.policy.update_memory(alice_history) 109 | 110 | total_rA += rA 111 | total_rB += rB 112 | total_episodes += 1 113 | 114 | for (player, total_reward, current_reward) in [(ALICE, total_rA, rA), (BOB, total_rB, rB)]: 115 | if total_episodes % 1 == 0: 116 | write_reward_log(episode_number=str(total_episodes), agent=player, 117 | current_episodic_reward=current_reward, 118 | average_episodic_reward=total_reward / total_episodes, 119 | environment=selfplay.name) 120 | 121 | return alice, bob, optimizers_alice, optimizers_bob, total_rA, total_rB, total_episodes 122 | 123 | 124 | def run(config): 125 | config_alice = deepcopy(config) 126 | config_bob = deepcopy(config) 127 | config_bob[MODEL][IS_SELF_PLAY_WITH_MEMORY] = False 128 | write_config_log(config_alice) 129 | write_config_log(config_bob) 130 | 131 | use_memory = config[MODEL][IS_SELF_PLAY] and config[MODEL][IS_SELF_PLAY_WITH_MEMORY] 132 | task = config[MODEL][SELFPLAY_TYPE] 133 | env_for_selfplay = choose_env(env=config[MODEL][ENV])() 134 | env_for_selfplay_target = choose_env(env=config[MODEL][ENV])() 135 | agent = config[MODEL][AGENT] 136 | batch_size = config[MODEL][BATCH_SIZE] 137 | reward_scale = config[MODEL][REWARD_SCALE] 138 | max_steps_per_episode_self_play = config[MODEL][MAX_STEPS_PER_EPISODE_SELFPLAY] 139 | target_to_selfplay_ratio = config[MODEL][TARGET_TO_SELFPLAY_RATIO] 140 | 141 | #################Self Play Env################# 142 | selfplay = choose_selfplay(config=config_alice)(environment=env_for_selfplay, task=task) 143 | # selfplay = SelfPlayEnv(environment=env_for_selfplay, task=task) 144 | possible_actions_alice = selfplay.all_possible_actions(agent=ALICE) 145 | possible_actions_bob = selfplay.all_possible_actions(agent=BOB) 146 | 147 | alice = choose_agent(agent_type=agent) \ 148 | (config=config, possible_actions=possible_actions_alice, name=ALICE) 149 | bob = choose_agent(agent_type=agent) \ 150 | (config=config_bob, possible_actions=possible_actions_bob, name=BOB) 151 | 152 | # I am not sure why we want this 153 | selfplay.reset() 154 | 155 | optimisers_alice = alice.get_optimisers(optimiser_name=config[MODEL][OPTIMISER]) 156 | optimisers_bob = bob.get_optimisers(optimiser_name=config[MODEL][OPTIMISER]) 157 | total_rA = 0.0 158 | total_rB = 0.0 159 | 160 | if (config[MODEL][LOAD]): 161 | timestamp = config[MODEL][LOAD_TIMESTAMP] 162 | alice.load_model(optimizers=optimisers_alice, name=ALICE, timestamp=timestamp) 163 | bob.load_model(optimizers=optimisers_bob, name=BOB, timestamp=timestamp) 164 | 165 | #################Self 
Play Target Env################# 166 | selfplay_target = SelfPlayTarget(environment=env_for_selfplay_target) 167 | total_target_episodic_rewards = 0.0 168 | 169 | episode_counter_selfplay = 0.0 170 | episode_counter_target = 0.0 171 | 172 | for i_epoch in range(config[MODEL][NUM_EPOCHS]): 173 | 174 | alice, bob, optimizers_alice, optimizers_bob, total_rA, total_rB, \ 175 | episode_counter_selfplay = run_selfplay_epoch(selfplay, alice, bob, optimisers_alice, optimisers_bob, 176 | batch_size, 177 | total_rA, total_rB, total_episodes=episode_counter_selfplay, 178 | reward_scale=reward_scale, tMax=max_steps_per_episode_self_play, 179 | use_memory=use_memory) 180 | bob, optimizers_bob, total_target_episodic_rewards, episode_counter_target = run_target_epochs( 181 | selfplay_target=selfplay_target, 182 | bob=bob, optimisers_bob=optimisers_bob, batch_size=batch_size * target_to_selfplay_ratio, 183 | total_episodic_rewards=total_target_episodic_rewards, total_episodes=episode_counter_target, 184 | max_steps_per_episode=config[MODEL][MAX_STEPS_PER_EPISODE] 185 | ) 186 | if (i_epoch % config[MODEL][PERSIST_PER_EPOCH] == 0): 187 | timestamp = str(int(time() * 10000)) 188 | alice.save_model(epochs=i_epoch, optimisers=optimisers_alice, name=ALICE, timestamp=timestamp) 189 | bob.save_model(epochs=i_epoch, optimisers=optimisers_bob, name=BOB, timestamp=timestamp) 190 | 191 | 192 | if __name__ == '__main__': 193 | config = bootstrap() 194 | config[MODEL][IS_SELF_PLAY] = True 195 | run(config) 196 | -------------------------------------------------------------------------------- /SelfPlay/app/util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from utils.config import get_config 4 | from utils.constant import * 5 | from utils.util import set_seed 6 | 7 | 8 | def bootstrap(): 9 | config = get_config() 10 | set_seed(seed=config[GENERAL][SEED]) 11 | log_file_name = config[LOG][FILE_PATH] 12 | print("Writing logs to file name: {}".format(log_file_name)) 13 | logging.basicConfig(filename=log_file_name, format='%(message)s', filemode='w', level=logging.DEBUG) 14 | return config 15 | -------------------------------------------------------------------------------- /SelfPlay/config/config.cfg.sample: -------------------------------------------------------------------------------- 1 | [general] 2 | base_path = 3 | device = cpu 4 | seed = 42 5 | id = 6 | 7 | [model] 8 | env = mazebase 9 | agent = reinforce 10 | batch_size = 10 11 | num_epochs = 10 12 | learning_rate = 0.001 13 | learning_rate_actor = 0.001 14 | learning_rate_critic = 0.001 15 | persist_per_epoch = 1000 16 | early_stopping_patience = 10 17 | gamma = 0.1 18 | lambda = 0.1 19 | use_baseline = True 20 | num_optimizers = 1 21 | load = False 22 | save_dir = 23 | load_path = 24 | load_timestamp = 15233139888052 25 | optimiser = adam 26 | max_steps_per_episode=1000 27 | is_self_play = False 28 | is_self_play_with_memory = False 29 | reward_scale = 0.1 30 | max_steps_per_episode_selfplay = 80 31 | target_to_selfplay_ratio = 4 32 | episode_memory_size = 10 33 | memory_type = base_memory 34 | selfplay_type = copy 35 | 36 | [log] 37 | file_path = 38 | 39 | [plot] 40 | base_path = 41 | 42 | [tb] 43 | base_path = -------------------------------------------------------------------------------- /SelfPlay/environment/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shagunsodhani/memory-augmented-self-play/1032891cb149cdc8411eb2fdba8745094524b5ce/SelfPlay/environment/__init__.py -------------------------------------------------------------------------------- /SelfPlay/environment/acrobot.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import random 4 | from copy import deepcopy 5 | 6 | from environment.env import Environment 7 | from environment.observation import Observation 8 | 9 | from numpy import sin, cos, pi 10 | from gym import core, spaces 11 | from utils.constant import ACROBOT 12 | 13 | class Acrobot(Environment): 14 | metadata = { 15 | 'render.modes': ['human', 'rgb_array'], 16 | 'video.frames_per_second': 15 17 | } 18 | 19 | dt = .2 20 | 21 | LINK_LENGTH_1 = 1. # [m] 22 | LINK_LENGTH_2 = 1. # [m] 23 | LINK_MASS_1 = 1. #: [kg] mass of link 1 24 | LINK_MASS_2 = 1. #: [kg] mass of link 2 25 | LINK_COM_POS_1 = 0.5 #: [m] position of the center of mass of link 1 26 | LINK_COM_POS_2 = 0.5 #: [m] position of the center of mass of link 2 27 | LINK_MOI = 1. #: moments of inertia for both links 28 | 29 | MAX_VEL_1 = 4 * np.pi 30 | MAX_VEL_2 = 9 * np.pi 31 | 32 | AVAIL_TORQUE = [-1., 0., +1] 33 | 34 | torque_noise_max = 0. 35 | 36 | #: use dynamics equations from the nips paper or the book 37 | book_or_nips = "book" 38 | action_arrow = None 39 | domain_fig = None 40 | actions_num = 3 41 | 42 | def __init__(self): 43 | self.name = ACROBOT 44 | self.viewer = None 45 | high = np.array([1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2]) 46 | low = -high 47 | self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32) 48 | self.action_space = spaces.Discrete(3) 49 | self.state = None 50 | 51 | self.n_max_steps = 2000 52 | self.steps_elapsed = 0 53 | self.reward = None 54 | self.reset() 55 | 56 | def act(self, action): 57 | s = self.state 58 | torque = self.AVAIL_TORQUE[action] 59 | 60 | # Add noise to the force action 61 | if self.torque_noise_max > 0: 62 | torque += np.random.uniform(-self.torque_noise_max, self.torque_noise_max) 63 | 64 | # Now, augment the state with our force action so it can be passed to 65 | # _dsdt 66 | s_augmented = np.append(s, torque) 67 | 68 | ns = rk4(self._dsdt, s_augmented, [0, self.dt]) 69 | # only care about final timestep of integration returned by integrator 70 | ns = ns[-1] 71 | ns = ns[:4] # omit action 72 | # ODEINT IS TOO SLOW! 73 | # ns_continuous = integrate.odeint(self._dsdt, self.s_continuous, [0, self.dt]) 74 | # self.s_continuous = ns_continuous[-1] # We only care about the state 75 | # at the ''final timestep'', self.dt 76 | 77 | ns[0] = wrap(ns[0], -pi, pi) 78 | ns[1] = wrap(ns[1], -pi, pi) 79 | ns[2] = bound(ns[2], -self.MAX_VEL_1, self.MAX_VEL_1) 80 | ns[3] = bound(ns[3], -self.MAX_VEL_2, self.MAX_VEL_2) 81 | self.state = ns 82 | self.steps_elapsed += 1 83 | 84 | is_over = self.is_over() 85 | self.reward = -1. if not is_over else 0. 
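        # Reward shaping: -1 for every step while the episode is still running, 0 once it ends
        # (goal height reached or step budget exhausted); packaged into an Observation below.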
86 | 87 | self.observation = Observation(reward=self.reward, 88 | state=np.array([cos(ns[0]), np.sin(ns[0]), cos(ns[1]), sin(ns[1]), ns[2], ns[3]]), 89 | is_episode_over=self.is_over()) 90 | return self.observe() 91 | 92 | def observe(self): 93 | self.observation = Observation(reward=self.reward, 94 | state=np.array([cos(self.state[0]), np.sin(self.state[0]), cos(self.state[1]), sin(self.state[1]), self.state[2], self.state[3]]), 95 | is_episode_over=self.is_over()) 96 | return self.observation 97 | 98 | def reset(self): 99 | self.state = np.random.uniform(low=-0.1, high=0.1, size=(4,)) 100 | self.steps_elapsed = 0 101 | self.reward = 0.0 102 | self.observation = Observation(reward=self.reward, 103 | state=np.array([cos(self.state[0]), np.sin(self.state[0]), cos(self.state[1]), sin(self.state[1]), self.state[2], self.state[3]]), 104 | is_episode_over=self.is_over()) 105 | return self.observe() 106 | 107 | def is_over(self): 108 | if self.steps_elapsed >= self.n_max_steps: 109 | return True 110 | s = self.state 111 | return bool(-np.cos(s[0]) - np.cos(s[1] + s[0]) > 1.) 112 | 113 | def _dsdt(self, s_augmented, t): 114 | m1 = self.LINK_MASS_1 115 | m2 = self.LINK_MASS_2 116 | l1 = self.LINK_LENGTH_1 117 | lc1 = self.LINK_COM_POS_1 118 | lc2 = self.LINK_COM_POS_2 119 | I1 = self.LINK_MOI 120 | I2 = self.LINK_MOI 121 | g = 9.8 122 | a = s_augmented[-1] 123 | s = s_augmented[:-1] 124 | theta1 = s[0] 125 | theta2 = s[1] 126 | dtheta1 = s[2] 127 | dtheta2 = s[3] 128 | d1 = m1 * lc1 ** 2 + m2 * \ 129 | (l1 ** 2 + lc2 ** 2 + 2 * l1 * lc2 * np.cos(theta2)) + I1 + I2 130 | d2 = m2 * (lc2 ** 2 + l1 * lc2 * np.cos(theta2)) + I2 131 | phi2 = m2 * lc2 * g * np.cos(theta1 + theta2 - np.pi / 2.) 132 | phi1 = - m2 * l1 * lc2 * dtheta2 ** 2 * np.sin(theta2) \ 133 | - 2 * m2 * l1 * lc2 * dtheta2 * dtheta1 * np.sin(theta2) \ 134 | + (m1 * lc1 + m2 * l1) * g * np.cos(theta1 - np.pi / 2) + phi2 135 | if self.book_or_nips == "nips": 136 | # the following line is consistent with the description in the 137 | # paper 138 | ddtheta2 = (a + d2 / d1 * phi1 - phi2) / \ 139 | (m2 * lc2 ** 2 + I2 - d2 ** 2 / d1) 140 | else: 141 | # the following line is consistent with the java implementation and the 142 | # book 143 | ddtheta2 = (a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1 ** 2 * np.sin(theta2) - phi2) \ 144 | / (m2 * lc2 ** 2 + I2 - d2 ** 2 / d1) 145 | ddtheta1 = -(d2 * ddtheta2 + phi1) / d1 146 | return (dtheta1, dtheta2, ddtheta1, ddtheta2, 0.) 
147 | 148 | def display(self, mode='human'): 149 | from gym.envs.classic_control import rendering 150 | 151 | s = self.state 152 | 153 | if self.viewer is None: 154 | self.viewer = rendering.Viewer(500, 500) 155 | self.viewer.set_bounds(-2.2, 2.2, -2.2, 2.2) 156 | 157 | if s is None: return None 158 | 159 | p1 = [-self.LINK_LENGTH_1 * 160 | np.cos(s[0]), self.LINK_LENGTH_1 * np.sin(s[0])] 161 | 162 | p2 = [p1[0] - self.LINK_LENGTH_2 * np.cos(s[0] + s[1]), 163 | p1[1] + self.LINK_LENGTH_2 * np.sin(s[0] + s[1])] 164 | 165 | xys = np.array([[0, 0], p1, p2])[:, ::-1] 166 | thetas = [s[0] - np.pi / 2, s[0] + s[1] - np.pi / 2] 167 | 168 | self.viewer.draw_line((-2.2, 1), (2.2, 1)) 169 | for ((x, y), th) in zip(xys, thetas): 170 | l, r, t, b = 0, 1, .1, -.1 171 | jtransform = rendering.Transform(rotation=th, translation=(x, y)) 172 | link = self.viewer.draw_polygon([(l, b), (l, t), (r, t), (r, b)]) 173 | link.add_attr(jtransform) 174 | link.set_color(0, .8, .8) 175 | circ = self.viewer.draw_circle(.1) 176 | circ.set_color(.8, .8, 0) 177 | circ.add_attr(jtransform) 178 | 179 | return self.viewer.render(return_rgb_array=mode == 'rgb_array') 180 | 181 | def close(self): 182 | if self.viewer: self.viewer.close() 183 | 184 | def set_seed(self, seed): 185 | # @todo 186 | pass 187 | 188 | def all_possible_actions(self): 189 | return list(range(self.action_space.n)) 190 | 191 | def are_states_equal(self, state_1, state_2): 192 | return (np.linalg.norm(state_1 - state_2) < 0.1) 193 | 194 | def create_copy(self): 195 | return deepcopy(self.state) 196 | 197 | def load_copy(self, env_copy): 198 | self.state = env_copy 199 | 200 | def wrap(x, m, M): 201 | """ 202 | :param x: a scalar 203 | :param m: minimum possible value in range 204 | :param M: maximum possible value in range 205 | Wraps ``x`` so m <= x <= M; but unlike ``bound()`` which 206 | truncates, ``wrap()`` wraps x around the coordinate system defined by m,M.\n 207 | For example, m = -180, M = 180 (degrees), x = 360 --> returns 0. 208 | """ 209 | diff = M - m 210 | while x > M: 211 | x = x - diff 212 | while x < m: 213 | x = x + diff 214 | return x 215 | 216 | def bound(x, m, M=None): 217 | """ 218 | :param x: scalar 219 | Either have m as scalar, so bound(x,m,M) which returns m <= x <= M *OR* 220 | have m as length 2 vector, bound(x,m, ) returns m[0] <= x <= m[1]. 221 | """ 222 | if M is None: 223 | M = m[1] 224 | m = m[0] 225 | # bound x between min (m) and Max (M) 226 | return min(max(x, m), M) 227 | 228 | def rk4(derivs, y0, t, *args, **kwargs): 229 | """ 230 | Integrate 1D or ND system of ODEs using 4-th order Runge-Kutta. 231 | This is a toy implementation which may be useful if you find 232 | yourself stranded on a system w/o scipy. Otherwise use 233 | :func:`scipy.integrate`. 
234 | *y0* 235 | initial state vector 236 | *t* 237 | sample times 238 | *derivs* 239 | returns the derivative of the system and has the 240 | signature ``dy = derivs(yi, ti)`` 241 | *args* 242 | additional arguments passed to the derivative function 243 | *kwargs* 244 | additional keyword arguments passed to the derivative function 245 | Example 1 :: 246 | ## 2D system 247 | def derivs6(x,t): 248 | d1 = x[0] + 2*x[1] 249 | d2 = -3*x[0] + 4*x[1] 250 | return (d1, d2) 251 | dt = 0.0005 252 | t = arange(0.0, 2.0, dt) 253 | y0 = (1,2) 254 | yout = rk4(derivs6, y0, t) 255 | Example 2:: 256 | ## 1D system 257 | alpha = 2 258 | def derivs(x,t): 259 | return -alpha*x + exp(-t) 260 | y0 = 1 261 | yout = rk4(derivs, y0, t) 262 | If you have access to scipy, you should probably be using the 263 | scipy.integrate tools rather than this function. 264 | """ 265 | 266 | try: 267 | Ny = len(y0) 268 | except TypeError: 269 | yout = np.zeros((len(t),), np.float_) 270 | else: 271 | yout = np.zeros((len(t), Ny), np.float_) 272 | 273 | yout[0] = y0 274 | 275 | for i in np.arange(len(t) - 1): 276 | thist = t[i] 277 | dt = t[i + 1] - thist 278 | dt2 = dt / 2.0 279 | y0 = yout[i] 280 | 281 | k1 = np.asarray(derivs(y0, thist, *args, **kwargs)) 282 | k2 = np.asarray(derivs(y0 + dt2 * k1, thist + dt2, *args, **kwargs)) 283 | k3 = np.asarray(derivs(y0 + dt2 * k2, thist + dt2, *args, **kwargs)) 284 | k4 = np.asarray(derivs(y0 + dt * k3, thist + dt, *args, **kwargs)) 285 | yout[i + 1] = y0 + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4) 286 | return yout 287 | 288 | 289 | 290 | if __name__ == "__main__": 291 | game = Acrobot() 292 | 293 | game.reset() 294 | game.display() 295 | 296 | actions = game.all_possible_actions() 297 | print(actions) 298 | print(type(actions)) 299 | 300 | for i in range(10000): 301 | print(i) 302 | print("==============") 303 | _action = random.choice(actions) 304 | # print(_action) 305 | game.act(_action) 306 | game.display() 307 | if game.is_over(): 308 | break 309 | 310 | game.close() 311 | 312 | 313 | 314 | 315 | -------------------------------------------------------------------------------- /SelfPlay/environment/cartpole.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | from gym import spaces, logger 4 | import random 5 | from copy import deepcopy 6 | 7 | from environment.env import Environment 8 | from environment.observation import Observation 9 | from utils.constant import CARTPOLE 10 | 11 | class CartPole(Environment): 12 | def __init__(self): 13 | self.name = CARTPOLE 14 | self.gravity = 9.8 15 | self.masscart = 1.0 16 | self.masspole = 0.1 17 | self.total_mass = (self.masspole + self.masscart) 18 | self.length = 0.5 # actually half the pole's length 19 | self.polemass_length = (self.masspole * self.length) 20 | self.force_mag = 10.0 21 | self.tau = 0.02 # seconds between state updates 22 | 23 | self.n_max_steps = 500 24 | self.steps_elapsed = 0 25 | 26 | # Angle at which to fail the episode 27 | self.theta_threshold_radians = 12 * 2 * math.pi / 360 28 | self.x_threshold = 2.4 29 | 30 | # Angle limit set to 2 * theta_threshold_radians so failing observation is still within bounds 31 | high = np.array([ 32 | self.x_threshold * 2, 33 | np.finfo(np.float32).max, 34 | self.theta_threshold_radians * 2, 35 | np.finfo(np.float32).max]) 36 | 37 | self.action_space = spaces.Discrete(2) 38 | self.observation_space = spaces.Box(-high, high, dtype=np.float32) 39 | 40 | self.viewer = None 41 | self.state = None 42 | 43 | 
self.steps_beyond_done = None 44 | self.reset() 45 | 46 | def act(self, action): 47 | assert self.action_space.contains(action), "%r (%s) invalid" % (action, type(action)) 48 | state = self.state 49 | x, x_dot, theta, theta_dot = state 50 | force = self.force_mag if action == 1 else -self.force_mag 51 | costheta = math.cos(theta) 52 | sintheta = math.sin(theta) 53 | temp = (force + self.polemass_length * theta_dot * theta_dot * sintheta) / self.total_mass 54 | thetaacc = (self.gravity * sintheta - costheta * temp) / ( 55 | self.length * (4.0 / 3.0 - self.masspole * costheta * costheta / self.total_mass)) 56 | xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass 57 | x = x + self.tau * x_dot 58 | x_dot = x_dot + self.tau * xacc 59 | theta = theta + self.tau * theta_dot 60 | theta_dot = theta_dot + self.tau * thetaacc 61 | self.state = (x, x_dot, theta, theta_dot) 62 | self.steps_elapsed += 1 63 | 64 | done = self.is_over() 65 | 66 | if not done: 67 | reward = 1.0 68 | elif self.steps_beyond_done is None: 69 | # Pole just fell! 70 | self.steps_beyond_done = 0 71 | reward = 1.0 72 | else: 73 | if self.steps_beyond_done == 0: 74 | logger.warn( 75 | "You are calling 'step()' even though this environment has already returned done = True. You should always call 'reset()' once you receive 'done = True' -- any further steps are undefined behavior.") 76 | self.steps_beyond_done += 1 77 | reward = 0.0 78 | 79 | self.observation = Observation(reward=reward, 80 | state=np.array(self.state), 81 | is_episode_over=self.is_over()) 82 | return self.observe() 83 | 84 | def observe(self): 85 | return self.observation 86 | 87 | def is_over(self): 88 | if self.steps_elapsed >= self.n_max_steps: 89 | return True 90 | done = self.state[0] < -self.x_threshold \ 91 | or self.state[0] > self.x_threshold \ 92 | or self.state[2] < -self.theta_threshold_radians \ 93 | or self.state[2] > self.theta_threshold_radians 94 | done = bool(done) 95 | return done 96 | 97 | def reset(self): 98 | self.state = np.random.uniform(low=-0.05, high=0.05, size=(4,)) 99 | self.steps_beyond_done = None 100 | self.steps_elapsed = 0 101 | 102 | self.observation = Observation( 103 | reward=0.0, 104 | state=np.array(self.state), 105 | is_episode_over=self.is_over() 106 | ) 107 | 108 | return self.observe() 109 | 110 | def display(self, mode='human'): 111 | screen_width = 600 112 | screen_height = 400 113 | 114 | world_width = self.x_threshold * 2 115 | scale = screen_width / world_width 116 | carty = 100 # TOP OF CART 117 | polewidth = 10.0 118 | polelen = scale * 1.0 119 | cartwidth = 50.0 120 | cartheight = 30.0 121 | 122 | if self.viewer is None: 123 | from gym.envs.classic_control import rendering 124 | self.viewer = rendering.Viewer(screen_width, screen_height) 125 | l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2 126 | axleoffset = cartheight / 4.0 127 | cart = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)]) 128 | self.carttrans = rendering.Transform() 129 | cart.add_attr(self.carttrans) 130 | self.viewer.add_geom(cart) 131 | l, r, t, b = -polewidth / 2, polewidth / 2, polelen - polewidth / 2, -polewidth / 2 132 | pole = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)]) 133 | pole.set_color(.8, .6, .4) 134 | self.poletrans = rendering.Transform(translation=(0, axleoffset)) 135 | pole.add_attr(self.poletrans) 136 | pole.add_attr(self.carttrans) 137 | self.viewer.add_geom(pole) 138 | self.axle = rendering.make_circle(polewidth / 2) 139 | self.axle.add_attr(self.poletrans) 
140 | self.axle.add_attr(self.carttrans) 141 | self.axle.set_color(.5, .5, .8) 142 | self.viewer.add_geom(self.axle) 143 | self.track = rendering.Line((0, carty), (screen_width, carty)) 144 | self.track.set_color(0, 0, 0) 145 | self.viewer.add_geom(self.track) 146 | 147 | if self.state is None: return None 148 | 149 | x = self.state 150 | cartx = x[0] * scale + screen_width / 2.0 # MIDDLE OF CART 151 | self.carttrans.set_translation(cartx, carty) 152 | self.poletrans.set_rotation(-x[2]) 153 | 154 | return self.viewer.render(return_rgb_array=mode == 'rgb_array') 155 | 156 | def close(self): 157 | if self.viewer: self.viewer.close() 158 | 159 | def all_possible_actions(self): 160 | return list(range(self.action_space.n)) 161 | 162 | def set_seed(self, seed): 163 | # @todo 164 | pass 165 | 166 | def are_states_equal(self, state_1, state_2): 167 | # @todo 168 | pass 169 | 170 | def create_copy(self): 171 | return deepcopy(self.state) 172 | 173 | def load_copy(self, env_copy): 174 | self.state = env_copy 175 | 176 | if __name__ == "__main__": 177 | game = CartPole() 178 | 179 | game.reset() 180 | game.display() 181 | 182 | actions = game.all_possible_actions() 183 | print(actions) 184 | print(type(actions)) 185 | 186 | for i in range(10000): 187 | print(i) 188 | print("==============") 189 | _action = random.choice(actions) 190 | print(_action) 191 | game.act(_action) 192 | game.display() 193 | if game.is_over(): 194 | break 195 | 196 | game.close() -------------------------------------------------------------------------------- /SelfPlay/environment/env.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from copy import deepcopy 3 | 4 | class Environment(ABC): 5 | """ 6 | Base class for the environments 7 | """ 8 | 9 | def __init__(self): 10 | self.name = None 11 | pass 12 | 13 | @abstractmethod 14 | def observe(self): 15 | '''Return an object of class environment.observation''' 16 | pass 17 | 18 | @abstractmethod 19 | def reset(self): 20 | '''Return an object of class environment.observation''' 21 | pass 22 | 23 | @abstractmethod 24 | def display(self): 25 | '''Prints the environment on the screen and does not return anything''' 26 | pass 27 | 28 | @abstractmethod 29 | def is_over(self): 30 | '''Return a boolean''' 31 | pass 32 | 33 | @abstractmethod 34 | def act(self, action): 35 | '''Return an object of class environment.observation''' 36 | pass 37 | 38 | @abstractmethod 39 | def all_possible_actions(self): 40 | '''Return a list of possible actions(ints)''' 41 | pass 42 | 43 | @abstractmethod 44 | def set_seed(self, seed): 45 | '''Method to set the seed for the environment''' 46 | pass 47 | 48 | def are_states_equal(self, state_1, state_2): 49 | '''Method to compare if two states are equal or sufficiently close''' 50 | return state_1 == state_2 51 | 52 | def are_observations_equal(self, obs1, obs2): 53 | '''Method to compare if two states observations are equal or "sufficiently" close''' 54 | return obs1.are_equal(self, other = obs2, are_states_equal = self.are_states_equal) 55 | 56 | def validate_action(self, action): 57 | '''Method to check if an action is supported''' 58 | if action not in self.all_possible_actions(): 59 | raise Exception("Invalid action ({}) being passed. 
Only following actions are supported: ({})\n" 60 | .format(action, ", ".join(self.all_possible_actions()))) 61 | 62 | def create_copy(self): 63 | return deepcopy(self) 64 | 65 | def load_copy(self, env_copy): 66 | self = deepcopy(env_copy) 67 | 68 | 69 | -------------------------------------------------------------------------------- /SelfPlay/environment/mazebase_wrapper.py: -------------------------------------------------------------------------------- 1 | import random 2 | from copy import deepcopy 3 | 4 | import mazebase 5 | # These weird import statements are taken from https://github.com/facebook/MazeBase/blob/23454fe092ecf35a8aab4da4972f231c6458209b/py/example.py#L12 6 | import mazebase.games as mazebase_games 7 | import numpy as np 8 | from mazebase.games import curriculum 9 | from mazebase.games import featurizers 10 | 11 | from environment.env import Environment 12 | from environment.observation import Observation 13 | from utils.constant import * 14 | 15 | 16 | class MazebaseWrapper(Environment): 17 | """ 18 | Wrapper class over maze base environment 19 | """ 20 | 21 | def __init__(self): 22 | super(MazebaseWrapper, self).__init__() 23 | self.name = MAZEBASE 24 | try: 25 | # Reference: https://github.com/facebook/MazeBase/blob/3e505455cae6e4ec442541363ef701f084aa1a3b/py/mazebase/games/mazegame.py#L454 26 | small_size = (10, 10, 10, 10) 27 | lk = curriculum.CurriculumWrappedGame( 28 | mazebase_games.LightKey, 29 | curriculums={ 30 | 'map_size': mazebase_games.curriculum.MapSizeCurriculum( 31 | small_size, 32 | small_size, 33 | (10, 10, 10, 10) 34 | ) 35 | } 36 | ) 37 | 38 | game = mazebase_games.MazeGame( 39 | games=[lk], 40 | featurizer=mazebase_games.featurizers.GridFeaturizer() 41 | ) 42 | 43 | 44 | except mazebase.utils.mazeutils.MazeException as e: 45 | print(e) 46 | self.game = game 47 | self.actions = self.game.all_possible_actions() 48 | 49 | def observe(self): 50 | game_observation = self.game.observe() 51 | # Logic borrowed from: 52 | # https://github.com/facebook/MazeBase/blob/23454fe092ecf35a8aab4da4972f231c6458209b/py/example.py#L192 53 | obs, info = game_observation[OBSERVATION] 54 | featurizers.grid_one_hot(self.game, obs) 55 | obs = np.array(obs) 56 | featurizers.vocabify(self.game, info) 57 | info = np.array(obs) 58 | game_observation[OBSERVATION] = np.concatenate((obs, info), 2).flatten() 59 | is_episode_over = self.game.is_over() 60 | return Observation(id=game_observation[ID], 61 | reward=game_observation[REWARD], 62 | state=game_observation[OBSERVATION], 63 | is_episode_over=is_episode_over) 64 | 65 | def reset(self): 66 | try: 67 | self.game.reset() 68 | except Exception as e: 69 | print(e) 70 | return self.observe() 71 | 72 | def display(self): 73 | return self.game.display() 74 | 75 | def is_over(self): 76 | return self.game.is_over() 77 | 78 | def act(self, action): 79 | self.game.act(action=action) 80 | return self.observe() 81 | 82 | def all_possible_actions(self): 83 | return self.actions 84 | 85 | def set_seed(self, seed): 86 | # Not needed here as we already set the numpy seed 87 | pass 88 | 89 | def create_copy(self): 90 | return deepcopy(self.game.game) 91 | 92 | def load_copy(self, env_copy): 93 | self.game.game = env_copy 94 | 95 | def are_states_equal(self, state_1, state_2): 96 | return np.array_equal(state_1, state_2) 97 | 98 | 99 | if __name__ == "__main__": 100 | env = MazebaseWrapper() 101 | env.display() 102 | actions = env.all_possible_actions() 103 | print(actions) 104 | for i in range(10): 105 | print("==============") 106 | _action = 
random.choice(actions) 107 | print(_action) 108 | env.act(_action) 109 | env.display() 110 | -------------------------------------------------------------------------------- /SelfPlay/environment/mountain_car.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | from gym import spaces 4 | from gym.utils import seeding 5 | import random 6 | from copy import deepcopy 7 | 8 | from environment.env import Environment 9 | from environment.observation import Observation 10 | from utils.constant import MOUNTAINCAR 11 | 12 | 13 | class MountainCar(Environment): 14 | def __init__(self): 15 | self.name = MOUNTAINCAR 16 | self.min_position = -1.2 17 | self.max_position = 0.6 18 | self.max_speed = 0.07 19 | self.goal_position = 0.5 20 | self.state = None 21 | self.observation = None 22 | self.n_max_steps = 10000 23 | self.steps_elapsed = 0 24 | 25 | self.low = np.array([self.min_position, -self.max_speed]) 26 | self.high = np.array([self.max_position, self.max_speed]) 27 | 28 | self.viewer = None 29 | 30 | self.action_space = spaces.Discrete(3) 31 | self.observation_space = spaces.Box(self.low, self.high, dtype=np.float32) 32 | 33 | self.reset() 34 | 35 | def act(self, action): 36 | assert self.action_space.contains(action), "%r (%s) invalid" % (action, type(action)) 37 | 38 | position, velocity = self.state 39 | velocity += (action - 1) * 0.001 + math.cos(3 * position) * (-0.0025) 40 | velocity = np.clip(velocity, -self.max_speed, self.max_speed) 41 | position += velocity 42 | position = np.clip(position, self.min_position, self.max_position) 43 | if (position == self.min_position and velocity < 0): velocity = 0 44 | 45 | reward = -1.0 46 | self.state = (position, velocity) 47 | self.steps_elapsed += 1 48 | 49 | self.observation = Observation(reward=reward, 50 | state=np.array(self.state), 51 | is_episode_over=self.is_over()) 52 | return self.observe() 53 | 54 | def observe(self): 55 | return self.observation 56 | 57 | def reset(self): 58 | self.state = np.array([np.random.uniform(low=-0.6, high=-0.4), 0]) 59 | self.observation = Observation( 60 | reward=0.0, 61 | state=np.array(self.state), 62 | is_episode_over=self.is_over() 63 | ) 64 | self.steps_elapsed = 0 65 | return self.observe() 66 | 67 | 68 | def is_over(self): 69 | if self.steps_elapsed >= self.n_max_steps: 70 | return True 71 | return bool(self.state[0] >= self.goal_position) 72 | 73 | def _height(self, xs): 74 | return np.sin(3 * xs) * .45 + .55 75 | 76 | def display(self, mode='human'): 77 | screen_width = 600 78 | screen_height = 400 79 | 80 | world_width = self.max_position - self.min_position 81 | scale = screen_width / world_width 82 | carwidth = 40 83 | carheight = 20 84 | 85 | if self.viewer is None: 86 | from gym.envs.classic_control import rendering 87 | self.viewer = rendering.Viewer(screen_width, screen_height) 88 | xs = np.linspace(self.min_position, self.max_position, 100) 89 | ys = self._height(xs) 90 | xys = list(zip((xs - self.min_position) * scale, ys * scale)) 91 | 92 | self.track = rendering.make_polyline(xys) 93 | self.track.set_linewidth(4) 94 | self.viewer.add_geom(self.track) 95 | 96 | clearance = 10 97 | 98 | l, r, t, b = -carwidth / 2, carwidth / 2, carheight, 0 99 | car = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)]) 100 | car.add_attr(rendering.Transform(translation=(0, clearance))) 101 | self.cartrans = rendering.Transform() 102 | car.add_attr(self.cartrans) 103 | self.viewer.add_geom(car) 104 | frontwheel = 
rendering.make_circle(carheight / 2.5) 105 | frontwheel.set_color(.5, .5, .5) 106 | frontwheel.add_attr(rendering.Transform(translation=(carwidth / 4, clearance))) 107 | frontwheel.add_attr(self.cartrans) 108 | self.viewer.add_geom(frontwheel) 109 | backwheel = rendering.make_circle(carheight / 2.5) 110 | backwheel.add_attr(rendering.Transform(translation=(-carwidth / 4, clearance))) 111 | backwheel.add_attr(self.cartrans) 112 | backwheel.set_color(.5, .5, .5) 113 | self.viewer.add_geom(backwheel) 114 | flagx = (self.goal_position - self.min_position) * scale 115 | flagy1 = self._height(self.goal_position) * scale 116 | flagy2 = flagy1 + 50 117 | flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2)) 118 | self.viewer.add_geom(flagpole) 119 | flag = rendering.FilledPolygon([(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)]) 120 | flag.set_color(.8, .8, 0) 121 | self.viewer.add_geom(flag) 122 | 123 | pos = self.state[0] 124 | self.cartrans.set_translation((pos - self.min_position) * scale, self._height(pos) * scale) 125 | self.cartrans.set_rotation(math.cos(3 * pos)) 126 | 127 | return self.viewer.render(return_rgb_array=mode == 'rgb_array') 128 | 129 | def close(self): 130 | if self.viewer: self.viewer.close() 131 | 132 | def all_possible_actions(self): 133 | return list(range(self.action_space.n)) 134 | 135 | def set_seed(self, seed): 136 | # @todo 137 | pass 138 | 139 | def are_states_equal(self, state_1, state_2): 140 | return (np.linalg.norm(state_1 - state_2) < 0.2) 141 | 142 | def create_copy(self): 143 | return deepcopy(self.state) 144 | 145 | def load_copy(self, env_copy): 146 | self.state = env_copy 147 | 148 | if __name__ == "__main__": 149 | game = MountainCar() 150 | 151 | game.reset() 152 | game.display() 153 | 154 | actions = game.all_possible_actions() 155 | print(actions) 156 | for i in range(10000): 157 | print(i) 158 | print("==============") 159 | _action = random.choice(actions) 160 | print(_action) 161 | game.act(_action) 162 | game.display() 163 | if game.is_over(): 164 | break 165 | 166 | game.close() 167 | 168 | 169 | -------------------------------------------------------------------------------- /SelfPlay/environment/observation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Observation(): 4 | def __init__(self, id=-1, reward=0, state=None, is_episode_over=False): 5 | self.id = id 6 | self.reward = reward 7 | self.state = state 8 | self.is_episode_over = is_episode_over 9 | 10 | def are_equal(self, other, are_states_equal): 11 | if (self.id != other.id): 12 | return False 13 | elif (self.reward != other.reward): 14 | return False 15 | elif (not are_states_equal(self.state, other.state)): 16 | return False 17 | elif (self.is_episode_over != other.is_episode_over): 18 | return False 19 | return True 20 | 21 | 22 | class ObservationTuple(): 23 | def __init__(self, start=None, end=None, target=None, memory = None): 24 | self.start = start 25 | self.end = end 26 | self.target = target 27 | self.memory = memory 28 | -------------------------------------------------------------------------------- /SelfPlay/environment/registry.py: -------------------------------------------------------------------------------- 1 | from environment.acrobot import Acrobot 2 | from environment.cartpole import CartPole 3 | from environment.mazebase_wrapper import MazebaseWrapper 4 | from environment.mountain_car import MountainCar 5 | from environment.selfplay import SelfPlay 6 | from 
environment.selfplay_memory import SelfPlayMemory 7 | 8 | from utils.constant import * 9 | 10 | 11 | def choose_env(env=MAZEBASE): 12 | if (env == ACROBOT): 13 | return Acrobot 14 | elif (env == CARTPOLE): 15 | return CartPole 16 | elif (env == MAZEBASE): 17 | return MazebaseWrapper 18 | elif (env == MOUNTAINCAR): 19 | return MountainCar 20 | elif (env == SELFPLAY): 21 | return SelfPlay 22 | 23 | 24 | def get_supported_envs(): 25 | return set([ACROBOT, CARTPOLE, MAZEBASE, MOUNTAINCAR, SELFPLAY]) 26 | 27 | 28 | def choose_selfplay(config): 29 | if (config[MODEL][IS_SELF_PLAY]): 30 | if (config[MODEL][IS_SELF_PLAY_WITH_MEMORY]): 31 | return SelfPlayMemory 32 | else: 33 | return SelfPlay 34 | -------------------------------------------------------------------------------- /SelfPlay/environment/selfplay.py: -------------------------------------------------------------------------------- 1 | import random 2 | from copy import deepcopy 3 | 4 | from environment.env import Environment 5 | from environment.mazebase_wrapper import MazebaseWrapper 6 | from environment.observation import ObservationTuple, Observation 7 | from utils.constant import * 8 | 9 | 10 | class SelfPlay(Environment): 11 | """ 12 | Wrapper class over self play environment 13 | 14 | SelfPlay supports two modes: 15 | * COPY aka REPEAT, where Bob should repeat what Alice did. 16 | In this setting, Bob starts from the same position as Alice and has to reach the same end position as Alice. 17 | * UNDO, where Bob should undo what Alice did. 18 | In this setting, Bob starts from the end position of Alice and has to reach the start position of Alice. 19 | """ 20 | 21 | def __init__(self, environment, task=None): 22 | super(SelfPlay, self).__init__() 23 | self.environment = environment 24 | self.name = SELFPLAY + "_" + self.environment.name 25 | self.alice_start_environment = None 26 | self.alice_end_environment = None 27 | # The environment (along with the state) in which alice starts 28 | self.agent_id = 0 29 | self.agents = (ALICE, BOB) 30 | self.observation = Observation() 31 | self.alice_observations = ObservationTuple() 32 | self.bob_observations = ObservationTuple() 33 | _all_possible_actions = self.environment.all_possible_actions() 34 | self.stop_action = len(_all_possible_actions) 35 | self.actions = _all_possible_actions 36 | self.is_over = None 37 | self.task = task 38 | 39 | def _process_observation(self): 40 | self.observation.reward = 0.0 41 | self.observation.is_episode_over = self.is_over 42 | 43 | def observe(self): 44 | self._process_observation() 45 | return self.observation 46 | 47 | def reset(self): 48 | self.observation = self.environment.reset() 49 | self.alice_observations = ObservationTuple() 50 | self.bob_observations = ObservationTuple() 51 | self.is_over = False 52 | self.agent_id = 0 53 | return self.observe() 54 | 55 | def alice_observe(self): 56 | observation = self.observe() 57 | return (observation, self.alice_observations.start) 58 | 59 | def bob_observe(self): 60 | if(self.environment.are_states_equal(self.observation.state, self.bob_observations.target.state)): 61 | self.is_over = True 62 | observation = self.observe() 63 | return (observation, self.bob_observations.target) 64 | 65 | def alice_start(self): 66 | # Memory=None is provided to make the interface same as selfplay_memory 67 | self.agent_id = 0 68 | self.is_over = False 69 | self.alice_observations.start = deepcopy(self.observe()) 70 | if (self.task == COPY): 71 | self.alice_start_environment = self.environment.create_copy() 72 | 73 | def 
alice_stop(self): 74 | self.agent_id = -1 75 | self.is_over = True 76 | self.alice_observations.end = deepcopy(self.observe()) 77 | if(self.task == UNDO): 78 | self.alice_end_environment = self.environment.create_copy() 79 | 80 | def bob_start(self): 81 | self.agent_id = 1 82 | self.is_over = False 83 | if (self.task == COPY): 84 | self.environment.load_copy(self.alice_start_environment) 85 | self.observation = self.environment.observe() 86 | self.bob_observations.start = deepcopy(self.observe()) 87 | self.bob_observations.target = deepcopy(self.alice_observations.end) 88 | if (not self.environment.are_states_equal(self.bob_observations.start.state, 89 | self.alice_observations.start.state)): 90 | print("Error in initialising Bob's environment") 91 | elif(self.task == UNDO): 92 | self.environment.load_copy(self.alice_end_environment) 93 | self.observation = self.environment.observe() 94 | self.bob_observations.start = deepcopy(self.observe()) 95 | self.bob_observations.target = deepcopy(self.alice_observations.start) 96 | if(not self.environment.are_states_equal(self.bob_observations.start.state, 97 | self.alice_observations.end.state)): 98 | print("Error in initialising Bob's environment") 99 | 100 | def bob_stop(self): 101 | self.agent_id = -1 102 | self.is_over = True 103 | self.bob_observations.end = deepcopy(self.observe()) 104 | 105 | def agent_stop(self): 106 | if (self.agent_id == 0): 107 | self.alice_stop() 108 | elif (self.agent_id == 1): 109 | self.bob_stop() 110 | 111 | def get_current_agent(self): 112 | return self.agents[self.agent_id] 113 | 114 | def switch_player(self): 115 | self.agent_id = (self.agent_id + 1) % 2 116 | 117 | def display(self): 118 | return self.environment.display() 119 | 120 | def is_over(self): 121 | return self.is_over 122 | 123 | def act(self, action): 124 | prev_agent_id = self.agent_id 125 | if (action == self.stop_action): 126 | self.agent_stop() 127 | elif (action != self.stop_action): 128 | self.observation = self.environment.act(action=action) 129 | if (prev_agent_id == 0): 130 | return self.alice_observe() 131 | elif (prev_agent_id == 1): 132 | return self.bob_observe() 133 | else: 134 | return self.observe() 135 | 136 | def all_possible_actions(self, agent=ALICE): 137 | if (agent == ALICE): 138 | return self.actions + [self.stop_action] 139 | elif (agent == BOB): 140 | return self.actions 141 | 142 | return self.actions 143 | 144 | def get_task(self): 145 | return self.task 146 | 147 | def set_task(self, task=COPY): 148 | self.task = task 149 | 150 | def set_seed(self, seed): 151 | self.environment.set_seed(seed) 152 | 153 | 154 | if __name__ == "__main__": 155 | play = SelfPlay(environment=MazebaseWrapper()) 156 | # env.display() 157 | actions = play.all_possible_actions() 158 | print(actions) 159 | for i in range(100): 160 | print("==============") 161 | _action = random.choice(actions) 162 | print(_action) 163 | play.act(_action) 164 | print((play.observe()).reward) 165 | -------------------------------------------------------------------------------- /SelfPlay/environment/selfplay_memory.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from environment.env import Environment 4 | from environment.mazebase_wrapper import MazebaseWrapper 5 | from environment.observation import ObservationTuple, Observation 6 | from utils.constant import * 7 | from copy import deepcopy 8 | from environment.selfplay import SelfPlay 9 | 10 | class SelfPlayMemory(SelfPlay): 11 | """ 12 | 
Wrapper class over self play environment 13 | """ 14 | 15 | def __init__(self, environment, task=COPY): 16 | super(SelfPlayMemory, self).__init__(environment=environment, task=task) 17 | self.environment = environment 18 | self.name = SELFPLAY + "_" + MEMORY + "_" + self.environment.name 19 | # The environment (along with the state) in which alice starts 20 | 21 | def observe(self): 22 | self._process_observation() 23 | return self.observation 24 | 25 | def alice_observe(self): 26 | observation = self.observe() 27 | return (observation, self.alice_observations.start) 28 | 29 | if __name__ == "__main__": 30 | play = SelfPlayMemory(environment=MazebaseWrapper()) 31 | actions = play.all_possible_actions() 32 | print(actions) 33 | for i in range(100): 34 | print("==============") 35 | _action = random.choice(actions) 36 | print(_action) 37 | play.act(_action) 38 | print((play.observe()).reward) 39 | -------------------------------------------------------------------------------- /SelfPlay/environment/selfplay_target.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from environment.selfplay import SelfPlay 4 | from environment.mazebase_wrapper import MazebaseWrapper 5 | from environment.observation import ObservationTuple, Observation 6 | from utils.constant import * 7 | from copy import deepcopy 8 | import numpy as np 9 | 10 | class SelfPlayTarget(SelfPlay): 11 | """ 12 | Wrapper class over self play environment 13 | """ 14 | 15 | def __init__(self, environment, task=TARGET): 16 | super(SelfPlayTarget, self).__init__(environment=environment, task=task) 17 | self.name = SELFPLAY + "_" + TARGET + "_" + self.environment.name 18 | self.alice_start_environment = None 19 | self.agent_id = 1 20 | self.agents = (BOB) 21 | self.observation = Observation() 22 | self.alice_observations = None 23 | self.bob_observations = ObservationTuple() 24 | _all_possible_actions = self.environment.all_possible_actions() 25 | self.stop_action = None 26 | self.actions = _all_possible_actions 27 | self.is_over = None 28 | self.task = task 29 | 30 | def observe(self): 31 | return self.observation 32 | 33 | def reset(self): 34 | self.observation = self.environment.reset() 35 | self.bob_observations.start = deepcopy(self.observation) 36 | self.bob_observations.start.state = np.zeros_like(self.bob_observations.start.state) 37 | self.is_over = False 38 | self.agent_id = 1 39 | return self.observe() 40 | 41 | def alice_observe(self): 42 | return None 43 | 44 | def bob_observe(self): 45 | observation = self.observe() 46 | return (observation, self.bob_observations.start) 47 | 48 | 49 | def alice_start(self): 50 | return None 51 | 52 | def alice_stop(self): 53 | return None 54 | 55 | def bob_start(self): 56 | self.agent_id = 1 57 | self.is_over = False 58 | 59 | def bob_stop(self): 60 | return None 61 | 62 | def agent_stop(self): 63 | return None 64 | 65 | def display(self): 66 | return self.environment.display() 67 | 68 | def is_over(self): 69 | return self.is_over 70 | 71 | def act(self, action): 72 | self.observation = self.environment.act(action=action) 73 | return self.bob_observe() 74 | 75 | if __name__ == "__main__": 76 | play = SelfPlay(environment=MazebaseWrapper(), task=COPY) 77 | # env.display() 78 | actions = play.all_possible_actions() 79 | print(actions) 80 | for i in range(100): 81 | print("==============") 82 | _action = random.choice(actions) 83 | print(_action) 84 | play.act(_action) 85 | print((play.observe()).reward) 86 | 
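# Illustrative sketch: one way to drive the Alice -> Bob phases of the SelfPlay wrapper
# defined in environment/selfplay.py above. The actual training driver lives in
# app/selfplay.py (not shown here), so the COPY task, the step budget of 40 and the
# random action choices below are placeholder assumptions, not the project's settings.
def _selfplay_phase_sketch():
    import random
    from environment.mazebase_wrapper import MazebaseWrapper
    from environment.selfplay import SelfPlay
    from utils.constant import ALICE, BOB, COPY

    play = SelfPlay(environment=MazebaseWrapper(), task=COPY)
    play.reset()

    # Alice phase: she acts until she picks the extra stop action (or the budget runs out).
    play.alice_start()
    alice_actions = play.all_possible_actions(agent=ALICE)  # base actions + stop_action
    for _ in range(40):
        action = random.choice(alice_actions)
        observation, _alice_start_obs = play.act(action)  # alice_observe(): (obs, Alice's start obs)
        if action == play.stop_action:  # act() has already routed this through alice_stop()
            break
    else:
        play.agent_stop()  # budget exhausted: force alice_stop() so Alice's end state is recorded

    # Bob phase: for COPY he is reset to Alice's start state and must reach her end state;
    # bob_observe() marks the episode as over once are_states_equal() matches the target.
    play.bob_start()
    bob_actions = play.all_possible_actions(agent=BOB)  # no stop action for Bob
    for _ in range(40):
        observation, _target_obs = play.act(random.choice(bob_actions))
        if observation.is_episode_over:
            break
    play.bob_stop()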
-------------------------------------------------------------------------------- /SelfPlay/memory/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shagunsodhani/memory-augmented-self-play/1032891cb149cdc8411eb2fdba8745094524b5ce/SelfPlay/memory/__init__.py -------------------------------------------------------------------------------- /SelfPlay/memory/base_memory.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | import numpy as np 4 | import torch.nn as nn 5 | 6 | from model.base_model import BaseModel 7 | from memory.memory_config import MemoryConfig 8 | from utils.constant import * 9 | import torch 10 | from torch.autograd import Variable 11 | 12 | 13 | class BaseMemory(BaseModel): 14 | def __init__(self, memory_config): 15 | super(BaseMemory, self).__init__() 16 | self.memory_config = memory_config 17 | self.internal_memory = deque(maxlen=memory_config[EPISODE_MEMORY_SIZE]) 18 | self.hidden_state = None 19 | self.reset() 20 | 21 | # response_network is the network that map the start and the end observations into a single vector 22 | self.response_network = self.get_response_network() 23 | self.summary_network = self.get_summary_network() 24 | self.init_weights() 25 | 26 | 27 | def reset(self): 28 | self.internal_memory.clear() 29 | self.hidden_state = Variable(torch.from_numpy(np.zeros(shape=(self.memory_config[INPUT_DIM])) 30 | )).float() 31 | 32 | # def get_response_network(self): 33 | # return nn.Sequential( 34 | # nn.Linear(self.memory_config[INPUT_DIM], self.memory_config[OUTPUT_DIM]) 35 | # ) 36 | 37 | def get_response_network(self): 38 | def _reponse(): 39 | return self.hidden_state 40 | return _reponse 41 | 42 | def get_summary_network(self): 43 | def _summariser(): 44 | num_entries = len(self.internal_memory) 45 | if(num_entries ==0 ): 46 | return self.hidden_state 47 | else: 48 | return (sum(self.internal_memory)/num_entries).squeeze(0) 49 | return _summariser 50 | 51 | def update_memory(self, history): 52 | self.internal_memory.append(history) 53 | # self.internal_memory.append(Variable(torch.from_numpy(history)).float()) 54 | 55 | def summarize(self): 56 | # This is the method which the different classes would override 57 | return self.forward() 58 | 59 | def forward(self): 60 | self.hidden_state = self.summary_network() 61 | return self.response_network() 62 | 63 | def init_weights(self): 64 | self.init_weights_response_network() 65 | self.init_weights_summary_network() 66 | 67 | def init_weights_response_network(self): 68 | # based on https://discuss.pytorch.org/t/how-to-initialize-weights-in-nn-sequential-container/8534/3 69 | self.apply(self._init_weights) 70 | 71 | def init_weights_summary_network(self): 72 | pass 73 | 74 | def get_params(self): 75 | params = [] 76 | for param in self.named_parameters(): 77 | if param[1].requires_grad: 78 | params.append(param[1]) 79 | return params 80 | 81 | def _init_weights(self, module): 82 | if type(module) == nn.Linear: 83 | module.weight.data.fill_(1.0) 84 | 85 | if __name__ == '__main__': 86 | memory = BaseMemory(memory_config=MemoryConfig()) 87 | print(memory) 88 | -------------------------------------------------------------------------------- /SelfPlay/memory/lstm_memory.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | import numpy as np 4 | import torch.nn as nn 5 | 6 | from memory.base_memory 
import BaseMemory 7 | from memory.memory_config import MemoryConfig 8 | from utils.constant import * 9 | import torch 10 | from torch.autograd import Variable 11 | 12 | 13 | class LstmMemory(BaseMemory): 14 | def __init__(self, memory_config): 15 | super(LstmMemory, self).__init__(memory_config) 16 | self.internal_memory = nn.LSTM(self.memory_config[OUTPUT_DIM], self.memory_config[OUTPUT_DIM], 1) 17 | self.hidden_state = None 18 | self.cell_state = None 19 | self.reset() 20 | 21 | 22 | def reset(self): 23 | self.hidden_state = Variable(torch.from_numpy(np.zeros(shape=(1, 1, self.memory_config[OUTPUT_DIM])) 24 | )).float() 25 | self.cell_state = Variable(torch.from_numpy(np.zeros(shape=(1, 1, self.memory_config[OUTPUT_DIM])) 26 | )).float() 27 | 28 | def get_summary_network(self): 29 | def _summariser(): 30 | return self.hidden_state 31 | return _summariser 32 | 33 | def update_memory(self, history): 34 | # print(history.unsqueeze(0).size()) 35 | # _, (self.hidden_state, _) = self.internal_memory(history, self.hidden_state) 36 | _, (self.hidden_state, self.cell_state) = self.internal_memory(history.unsqueeze(0), (self.hidden_state, self.cell_state)) 37 | # self.internal_memory.append(Variable(torch.from_numpy(history)).float()) 38 | # print(self.hidden_state.size()) 39 | 40 | def summarize(self): 41 | # This is the method which the different classes would override 42 | return self.forward() 43 | 44 | def forward(self): 45 | self.hidden_state = self.summary_network() 46 | return self.response_network() 47 | 48 | def init_weights(self): 49 | self.init_weights_response_network() 50 | self.init_weights_summary_network() 51 | 52 | def init_weights_response_network(self): 53 | # based on https://discuss.pytorch.org/t/how-to-initialize-weights-in-nn-sequential-container/8534/3 54 | self.apply(self._init_weights) 55 | 56 | def init_weights_summary_network(self): 57 | pass 58 | 59 | def get_params(self): 60 | params = [] 61 | for param in self.internal_memory.named_parameters(): 62 | if param[1].requires_grad: 63 | params.append(param[1]) 64 | return params 65 | 66 | def _init_weights(self, module): 67 | if type(module) == nn.Linear: 68 | module.weight.data.fill_(1.0) 69 | 70 | if __name__ == '__main__': 71 | memory = LstmMemory(memory_config=MemoryConfig()) 72 | print(memory) 73 | -------------------------------------------------------------------------------- /SelfPlay/memory/memory_config.py: -------------------------------------------------------------------------------- 1 | from utils.constant import * 2 | 3 | 4 | class MemoryConfig: 5 | def __init__(self, episode_memory_size=10, input_dim=156*10*10*2, output_dim=50): 6 | self.data = { 7 | EPISODE_MEMORY_SIZE: episode_memory_size, 8 | INPUT_DIM: input_dim, 9 | OUTPUT_DIM: output_dim 10 | } 11 | 12 | def __getitem__(self, key): 13 | return self.data[key] 14 | 15 | def __setitem__(self, key, value): 16 | self.data[key] = value -------------------------------------------------------------------------------- /SelfPlay/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shagunsodhani/memory-augmented-self-play/1032891cb149cdc8411eb2fdba8745094524b5ce/SelfPlay/model/__init__.py -------------------------------------------------------------------------------- /SelfPlay/model/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from time import time 4 | 5 | import numpy as np 6 | import torch 7 
| 8 | from utils.constant import * 9 | 10 | 11 | class BaseModel(torch.nn.Module): 12 | def __init__(self): 13 | super(BaseModel, self).__init__() 14 | 15 | def forward(self, data): 16 | pass 17 | 18 | def save_model(self, epochs=-1, optimisers=None, save_dir=None, name=ALICE, timestamp=None): 19 | ''' 20 | Method to persist the model 21 | ''' 22 | if not timestamp: 23 | timestamp = str(int(time())) 24 | state = { 25 | EPOCHS: epochs + 1, 26 | STATE_DICT: self.state_dict(), 27 | OPTIMISER: [optimiser.state_dict() for optimiser in optimisers], 28 | NP_RANDOM_STATE: np.random.get_state(), 29 | PYTHON_RANDOM_STATE: random.getstate(), 30 | PYTORCH_RANDOM_STATE: torch.get_rng_state() 31 | } 32 | path = os.path.join(save_dir, 33 | name + "_model_timestamp_" + timestamp + ".tar") 34 | torch.save(state, path) 35 | print("saved model to path = {}".format(path)) 36 | 37 | def load_model(self, optimisers, load_path=None, name=ALICE, timestamp=None): 38 | timestamp = str(timestamp) 39 | path = os.path.join(load_path, 40 | name + "_model_timestamp_" + timestamp + ".tar") 41 | print("Loading model from path {}".format(path)) 42 | checkpoint = torch.load(path) 43 | epochs = checkpoint[EPOCHS] 44 | self._load_metadata(checkpoint) 45 | self._load_model_params(checkpoint[STATE_DICT]) 46 | 47 | for i, _ in enumerate(optimisers): 48 | optimisers[i].load_state_dict(checkpoint[OPTIMISER][i]) 49 | return optimisers, epochs 50 | 51 | def _load_metadata(self, checkpoint): 52 | np.random.set_state(checkpoint[NP_RANDOM_STATE]) 53 | random.setstate(checkpoint[PYTHON_RANDOM_STATE]) 54 | torch.set_rng_state(checkpoint[PYTORCH_RANDOM_STATE]) 55 | 56 | def _load_model_params(self, state_dict): 57 | self.load_state_dict(state_dict) 58 | -------------------------------------------------------------------------------- /SelfPlay/plotter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shagunsodhani/memory-augmented-self-play/1032891cb149cdc8411eb2fdba8745094524b5ce/SelfPlay/plotter/__init__.py -------------------------------------------------------------------------------- /SelfPlay/plotter/plot_from_dir.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy 3 | from pathlib import Path 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | 8 | from plotter.util import metrics_key_set, matplotlib_args, agents, get_env_list, matplotlib_kwargs, compute_running_average 9 | from utils.constant import * 10 | from utils.log import parse_log_file 11 | from utils.util import make_dir 12 | 13 | 14 | def get_dir_to_save_plots(logs_path, dir_to_save_plots): 15 | dir_to_save_plots = (Path(dir_to_save_plots) 16 | .parent 17 | .joinpath( 18 | logs_path 19 | .as_posix() 20 | .rsplit("/", 1)[1] 21 | )).as_posix() 22 | 23 | make_dir(dir_to_save_plots) 24 | 25 | return dir_to_save_plots 26 | 27 | 28 | def parse_logs_from_dir(logs_path, envs): 29 | dir_logs = {} 30 | for agent in agents: 31 | dir_logs[agent] = {} 32 | 33 | for log_idx, log_file_path in enumerate(logs_path.glob("**/log*.txt")): 34 | print("Parsing {}".format(log_file_path)) 35 | for agent in agents: 36 | logs = parse_log_file(log_file_path=log_file_path, agent=agent, env_list=envs) 37 | for env in logs: 38 | if env not in dir_logs[agent]: 39 | dir_logs[agent][env] = {} 40 | for key in logs[env]: 41 | if (key in metrics_key_set): 42 | dir_logs[agent][env][key] = [] 43 | for key in logs[env]: 44 | if (key in 
metrics_key_set): 45 | dir_logs[agent][env][key].append(logs[env][key]) 46 | 47 | for agent in agents: 48 | for env in list(dir_logs[agent].keys()): 49 | for key in dir_logs[agent][env]: 50 | dir_logs[agent][env][key] = np.asarray(dir_logs[agent][env][key]) 51 | 52 | return dir_logs 53 | 54 | 55 | def transform_logs_to_aggregated_logs(dir_logs, window_size=1000): 56 | aggregated_logs = deepcopy(dir_logs) 57 | for agent, agent_val in dir_logs.items(): 58 | for env, env_value in agent_val.items(): 59 | for key, key_val in env_value.items(): 60 | min_len = min(map(lambda x: len(x), key_val)) 61 | metric_val = np.asarray(list(map(lambda x: compute_running_average(x[:min_len], 62 | window_size=window_size), key_val))) 63 | _metric = { 64 | AVERAGE: np.mean(metric_val, axis=0), 65 | STDDEV: np.std(metric_val, axis=0) 66 | } 67 | aggregated_logs[agent][env][key] = _metric 68 | del dir_logs 69 | return aggregated_logs 70 | 71 | 72 | def plot_from_dir(logs_path, env=MAZEBASE, dir_to_save_plots=None, last_n=0, window_size=1000): 73 | ''' 74 | This method wraps the ` 75 | :return: 76 | ''' 77 | dir_to_save_plots = get_dir_to_save_plots(logs_path, dir_to_save_plots) 78 | 79 | envs = get_env_list(base_env=env) 80 | 81 | dir_logs = parse_logs_from_dir(logs_path=logs_path, envs=envs) 82 | 83 | np.save("{}/dir_logs.npy".format(dir_to_save_plots), 84 | dir_logs) 85 | 86 | aggregated_logs = transform_logs_to_aggregated_logs(dir_logs, window_size=window_size) 87 | 88 | for agent, agent_val in aggregated_logs.items(): 89 | for env, env_value in agent_val.items(): 90 | for key, key_val in env_value.items(): 91 | if (last_n > 0): 92 | plot_last_n_aggregated(agent_val, last_n, key, agent, dir_to_save_plots, env, 93 | *matplotlib_args, 94 | **matplotlib_kwargs) 95 | else: 96 | plot_aggregated(agent_val, key, agent, dir_to_save_plots, env, 97 | *matplotlib_args, 98 | **matplotlib_kwargs) 99 | 100 | print(dir_to_save_plots) 101 | 102 | 103 | def plot_aggregated(logs, key, agent, dir_to_save_plots, env, *args, **kwargs): 104 | new_metric = { 105 | AVERAGE: logs[env][key][AVERAGE], 106 | STDDEV: logs[env][key][STDDEV] 107 | } 108 | metric_average = new_metric[AVERAGE] 109 | std_average = new_metric[STDDEV] 110 | if (len(metric_average) > 2): 111 | plt.plot(metric_average, *args, **kwargs) 112 | # plt.show() 113 | ax = plt.gca() 114 | ax.fill_between(range(len(std_average)), metric_average + std_average, metric_average - std_average, alpha=0.2) 115 | # plt.show() 116 | if (key in set([CURRENT_EPISODIC_REWARD, AVERAGE_EPISODIC_REWARD])): 117 | ylabel = REWARD 118 | xlabel = "Number of Episodes" 119 | title = key 120 | elif (key in set([AVERAGE_BATCH_LOSS])): 121 | ylabel = LOSS 122 | xlabel = "Number of Batches" 123 | title = key 124 | elif (key in set([TIME])): 125 | ylabel = "time taken in self play" 126 | xlabel = "Number of Episodes" 127 | title = ylabel 128 | title = agent + "_" + title 129 | title = title + "___" + ENVIRONMENT + "_" + env 130 | plt.ylabel(ylabel) 131 | plt.xlabel(xlabel) 132 | plt.title(title) 133 | plt.show() 134 | if dir_to_save_plots: 135 | path = os.path.join(dir_to_save_plots, title) 136 | plt.savefig(path) 137 | plt.clf() 138 | else: 139 | print("Not enough data to plot anything for key = {}, agent = {}, env = {}".format(key, agent, env)) 140 | 141 | 142 | def plot_last_n_aggregated(logs, n, key, agent, dir_to_save_plots, env, *args, **kwargs): 143 | recent_logs = {} 144 | recent_logs[env] = {} 145 | recent_logs[env][key] = logs[env][key] 146 | print("mean = {} for key = {}, agent = {}, env 
= {}".format( 147 | np.mean(np.asarray(recent_logs[env][key][AVERAGE])), 148 | key, agent, env)) 149 | if (n > 0 and len(recent_logs[env][key]) > n): 150 | recent_logs[env][key] = recent_logs[env][key][-n:] 151 | if (key in [CURRENT_EPISODIC_REWARD, AVERAGE_EPISODIC_REWARD, AVERAGE_BATCH_LOSS, TIME]): 152 | return plot_aggregated(recent_logs, key, agent, dir_to_save_plots, env, *args, **kwargs) 153 | -------------------------------------------------------------------------------- /SelfPlay/plotter/plot_from_file.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | from plotter.util import metrics_key_set, matplotlib_args, agents, get_env_list, matplotlib_kwargs, compute_running_average 7 | from utils.constant import * 8 | from utils.log import parse_log_file 9 | 10 | 11 | def plot_from_file(log_file_path, env=MAZEBASE, dir_to_save_plots=None, last_n=0, window_size = 1000): 12 | envs = get_env_list(base_env=env) 13 | for agent in agents: 14 | logs = parse_log_file(log_file_path=log_file_path, agent=agent, env_list=envs) 15 | for key in metrics_key_set: 16 | for env in envs: 17 | if (env in logs): 18 | if (last_n > 0): 19 | plot_last_n(logs, last_n, key, agent, dir_to_save_plots, env, window_size, *matplotlib_args, **matplotlib_kwargs) 20 | else: 21 | plot(logs, key, agent, dir_to_save_plots, env, window_size, *matplotlib_args, **matplotlib_kwargs) 22 | 23 | 24 | def plot(logs, key, agent, dir_to_save_plots, env, window_size, *args, **kwargs): 25 | new_metric = compute_running_average(logs[env][key], window_size=window_size) 26 | if (len(new_metric) > 2): 27 | plt.plot(new_metric, *args, **kwargs) 28 | if (key in set([CURRENT_EPISODIC_REWARD, AVERAGE_EPISODIC_REWARD])): 29 | ylabel = REWARD 30 | xlabel = "Number of Episodes" 31 | title = key 32 | elif (key in set([AVERAGE_BATCH_LOSS])): 33 | ylabel = LOSS 34 | xlabel = "Number of Batches" 35 | title = key 36 | elif (key in set([TIME])): 37 | ylabel = "time taken in self play" 38 | xlabel = "Number of Episodes" 39 | title = ylabel 40 | title = agent + "_" + title 41 | title = title + "___" + ENVIRONMENT + "_" + env 42 | plt.ylabel(ylabel) 43 | plt.xlabel(xlabel) 44 | plt.title(title) 45 | plt.show() 46 | if dir_to_save_plots: 47 | path = os.path.join(dir_to_save_plots, title) 48 | plt.savefig(path) 49 | plt.clf() 50 | else: 51 | print("Not enough data to plot anything for key = {}, agent = {}, env = {}".format(key, agent, env)) 52 | 53 | 54 | def plot_last_n(logs, n, key, agent, dir_to_save_plots, env, window_size, *args, **kwargs): 55 | recent_logs = {} 56 | recent_logs[env] = {} 57 | recent_logs[env][key] = logs[env][key] 58 | print("mean = {} for key = {}, agent = {}, env = {}".format( 59 | np.mean(np.asarray(recent_logs[env][key])), 60 | key, agent, env)) 61 | if (n > 0 and len(recent_logs[env][key]) > n): 62 | recent_logs[env][key] = recent_logs[env][key][-n:] 63 | if (key in [CURRENT_EPISODIC_REWARD, AVERAGE_EPISODIC_REWARD, AVERAGE_BATCH_LOSS, TIME]): 64 | return plot(recent_logs, key, agent, dir_to_save_plots, env, window_size, *args, **kwargs) 65 | -------------------------------------------------------------------------------- /SelfPlay/plotter/util.py: -------------------------------------------------------------------------------- 1 | from utils.constant import CURRENT_EPISODIC_REWARD, AVERAGE_EPISODIC_REWARD, AVERAGE_BATCH_LOSS, TIME, \ 2 | ALICE, BOB, SELFPLAY 3 | import numpy as np 4 | 5 | matplotlib_args = 
["--bo"] 6 | matplotlib_kwargs = {"ms":1.0} 7 | metrics_key_set = set([CURRENT_EPISODIC_REWARD, AVERAGE_EPISODIC_REWARD, AVERAGE_BATCH_LOSS, TIME]) 8 | agents = [ALICE, BOB] 9 | 10 | def get_env_list(base_env): 11 | return [SELFPLAY + "_" + base_env, SELFPLAY + "_target_" + base_env, base_env, SELFPLAY + "_memory_" + base_env] 12 | 13 | def compute_running_average(metric, window_size=1000): 14 | new_metric = [] 15 | for i in range(window_size, len(metric)-window_size-1): 16 | new_metric.append(sum(metric[i-window_size:i])/window_size) 17 | return np.asarray(new_metric) 18 | -------------------------------------------------------------------------------- /SelfPlay/policy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shagunsodhani/memory-augmented-self-play/1032891cb149cdc8411eb2fdba8745094524b5ce/SelfPlay/policy/__init__.py -------------------------------------------------------------------------------- /SelfPlay/policy/acrobot_policy.py: -------------------------------------------------------------------------------- 1 | from policy.base_policy import BasePolicyReinforce, BasePolicyReinforceWithBaseline 2 | from utils.constant import * 3 | 4 | default_input_size = 6 5 | 6 | 7 | class AcrobotPolicyReinforce(BasePolicyReinforce): 8 | def __init__(self, policy_config): 9 | if not policy_config[INPUT_SIZE]: 10 | policy_config[INPUT_SIZE] = default_input_size 11 | policy_config[SHARED_FEATURES_SIZE] = 128 12 | super(AcrobotPolicyReinforce, self).__init__(policy_config) 13 | self.init_weights() 14 | 15 | 16 | class AcrobotPolicyReinforceWithBaseline(BasePolicyReinforceWithBaseline): 17 | def __init__(self, policy_config): 18 | if not policy_config[INPUT_SIZE]: 19 | policy_config[INPUT_SIZE] = default_input_size 20 | policy_config[SHARED_FEATURES_SIZE] = 128 21 | super(AcrobotPolicyReinforceWithBaseline, self).__init__(policy_config) 22 | self.init_weights() 23 | -------------------------------------------------------------------------------- /SelfPlay/policy/base_policy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from time import time 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | from torch.distributions import Categorical 11 | from torch.nn.init import xavier_uniform 12 | 13 | from memory.base_memory import BaseMemory 14 | from memory.lstm_memory import LstmMemory 15 | from memory.memory_config import MemoryConfig 16 | from utils.constant import * 17 | from utils.log import write_loss_log 18 | 19 | 20 | class BasePolicy(torch.nn.Module): 21 | def __init__(self, policy_config): 22 | super(BasePolicy, self).__init__() 23 | self.logits = [] 24 | # This corresponds to the log(pi) values 25 | self.returns = [] 26 | self.use_baseline = policy_config[USE_BASELINE] 27 | self.losses = [0.0, 0.0] 28 | self.update_frequency = int(policy_config[BATCH_SIZE]) 29 | self.update_counter = 0 30 | self.shared_features = None 31 | self.actor = None 32 | self.critic = None 33 | self.is_self_play = policy_config[IS_SELF_PLAY] 34 | self.is_self_play_with_memory = bool(policy_config[IS_SELF_PLAY_WITH_MEMORY] * self.is_self_play) 35 | self.input_size = policy_config[INPUT_SIZE] 36 | self.shared_features_size = policy_config[SHARED_FEATURES_SIZE] 37 | self.shared_features_size_output = self.shared_features_size 38 | memory_config = 
MemoryConfig(episode_memory_size=policy_config[EPISODE_MEMORY_SIZE], 39 | input_dim=self.shared_features_size, output_dim=self.shared_features_size) 40 | if policy_config[MEMORY_TYPE] == BASE_MEMORY: 41 | self.memory = BaseMemory(memory_config=memory_config) 42 | elif policy_config[MEMORY_TYPE] == LSTM_MEMORY: 43 | self.memory = LstmMemory(memory_config=memory_config) 44 | 45 | if (self.is_self_play): 46 | self.input_size = self.input_size * 2 47 | if(self.is_self_play_with_memory): 48 | self.shared_features_size_output = self.shared_features_size*2 49 | # if (self.is_self_play_with_memory): 50 | # self.input_size = self.input_size * 3 51 | # else: 52 | # self.input_size = self.input_size * 2 53 | # Mote that the name of the environment and the agent are provided only for the sake of book-keeping 54 | self.bookkeeping = {} 55 | self.bookkeeping[ENVIRONMENT] = policy_config[ENVIRONMENT] 56 | if (self.is_self_play): 57 | self.bookkeeping[ENVIRONMENT] = SELFPLAY + "_" + self.bookkeeping[ENVIRONMENT] 58 | self.bookkeeping[AGENT] = policy_config[AGENT] 59 | self.num_actions = policy_config[NUM_ACTIONS] 60 | 61 | self.shared_features = nn.Sequential( 62 | nn.Linear(self.input_size, self.shared_features_size), 63 | nn.ReLU() 64 | ) 65 | 66 | self.actor = nn.Sequential( 67 | nn.Linear(self.shared_features_size_output, self.num_actions) 68 | ) 69 | 70 | def update_memory(self, history): 71 | # Tuple of Observations(start_state, end_state) 72 | history = self.shared_features(Variable(torch.from_numpy(history)).float()) 73 | self.memory.update_memory(history=history) 74 | 75 | def summarize_memory(self): 76 | return self.memory.summarize() 77 | 78 | def forward(self, data): 79 | pass 80 | 81 | def _to_do_update(self, agent_name): 82 | if (self.update_counter % self.update_frequency == 0): 83 | write_loss_log(average_batch_loss=(sum(self.losses)).data[0] / self.update_frequency, agent=agent_name, 84 | environment=self.bookkeeping[ENVIRONMENT]) 85 | return True 86 | return False 87 | 88 | def init_weights(self): 89 | self.init_weights_shared_features() 90 | self.init_weights_actor() 91 | self.init_weights_critic() 92 | 93 | def init_weights_shared_features(self): 94 | if self.shared_features: 95 | for layer in self.shared_features: 96 | self._init_weights_layer(layer) 97 | 98 | def init_weights_actor(self): 99 | if self.actor: 100 | for layer in self.actor: 101 | self._init_weights_layer(layer) 102 | 103 | def init_weights_critic(self): 104 | if self.critic: 105 | for layer in self.critic: 106 | self._init_weights_layer(layer) 107 | 108 | def _init_weights_layer(self, layer): 109 | '''Method to initialise the weights for a given layer''' 110 | if isinstance(layer, nn.Linear): 111 | xavier_uniform(layer.weight.data) 112 | # xavier_uniform(layer.bias.data) 113 | 114 | def save_model(self, epochs=-1, optimisers=None, save_dir=None, name=ALICE, timestamp=None): 115 | ''' 116 | Method to persist the model 117 | ''' 118 | if not timestamp: 119 | timestamp = str(int(time())) 120 | state = { 121 | EPOCHS: epochs + 1, 122 | STATE_DICT: self.state_dict(), 123 | OPTIMISER: [optimiser.state_dict() for optimiser in optimisers], 124 | NP_RANDOM_STATE: np.random.get_state(), 125 | PYTHON_RANDOM_STATE: random.getstate(), 126 | PYTORCH_RANDOM_STATE: torch.get_rng_state() 127 | } 128 | path = os.path.join(save_dir, 129 | name + "_model_timestamp_" + timestamp + ".tar") 130 | torch.save(state, path) 131 | print("saved model to path = {}".format(path)) 132 | 133 | def load_model(self, optimisers, load_path=None, 
name=ALICE, timestamp=None): 134 | timestamp = str(timestamp) 135 | path = os.path.join(load_path, 136 | name + "_model_timestamp_" + timestamp + ".tar") 137 | print("Loading model from path {}".format(path)) 138 | checkpoint = torch.load(path) 139 | epochs = checkpoint[EPOCHS] 140 | self._load_metadata(checkpoint) 141 | self._load_model_params(checkpoint[STATE_DICT]) 142 | 143 | for i, _ in enumerate(optimisers): 144 | optimisers[i].load_state_dict(checkpoint[OPTIMISER][i]) 145 | return optimisers, epochs 146 | 147 | def _load_metadata(self, checkpoint): 148 | np.random.set_state(checkpoint[NP_RANDOM_STATE]) 149 | random.setstate(checkpoint[PYTHON_RANDOM_STATE]) 150 | torch.set_rng_state(checkpoint[PYTORCH_RANDOM_STATE]) 151 | 152 | def _load_model_params(self, state_dict): 153 | self.load_state_dict(state_dict) 154 | 155 | def get_reward(self, observation): 156 | if (self.is_self_play and isinstance(observation, tuple) and len(observation) > 1): 157 | # This should be modified in the future 158 | reward = observation[0].reward 159 | else: 160 | reward = observation.reward 161 | return reward 162 | 163 | def get_state(self, observation): 164 | if (self.is_self_play and isinstance(observation, tuple)): 165 | # Should be removed 166 | if (self.is_self_play_with_memory and len(observation) == 3): 167 | S = Variable(torch.cat((torch.from_numpy(observation[0].state).float(), 168 | torch.from_numpy(observation[1].state).float(), 169 | torch.from_numpy(observation[2].state).float())).unsqueeze(0)) 170 | # elif (not self.is_self_play_with_memory and len(observation) == 2): 171 | elif (len(observation) == 2): 172 | S = Variable(torch.cat((torch.from_numpy(observation[0].state).float(), 173 | torch.from_numpy(observation[1].state).float())).unsqueeze(0)) 174 | else: 175 | S = Variable((torch.from_numpy(observation.state).float().unsqueeze(0))) 176 | return S 177 | 178 | 179 | class BasePolicyReinforce(BasePolicy): 180 | # def __init__(self, input_size=-1, batch_size=32, is_self_play=False): 181 | def __init__(self, policy_config): 182 | super(BasePolicyReinforce, self).__init__(policy_config) 183 | 184 | # super(BasePolicyReinforce, self).__init__(use_baseline=False, input_size=input_size, batch_size=batch_size, is_self_play=is_self_play) 185 | self.logits = [] 186 | # This corresponds to the log(pi) values 187 | self.returns = [] 188 | 189 | def forward(self, data): 190 | shared_features = F.relu(self.shared_features(data)) 191 | if(self.is_self_play and self.is_self_play_with_memory): 192 | shared_features = torch.cat((F.relu(self.shared_features(data)), 193 | self.summarize_memory().unsqueeze(0).detach()), dim=1) 194 | action_logits = self.actor(shared_features) 195 | return F.softmax(action_logits, dim=1) 196 | 197 | def get_action(self, observation): 198 | reward = self.get_reward(observation) 199 | S = self.get_state(observation) 200 | self.returns.append(reward) 201 | action_prob = self.forward(S) 202 | distribution = Categorical(action_prob) 203 | action = distribution.sample() 204 | self.logits.append(distribution.log_prob(action)) 205 | return action.data[0] 206 | 207 | def update(self, optimisers, gamma, agent_name): 208 | # In this case, the list of optimisers has just 1 value 209 | optimiser = optimisers[0] 210 | running_return = 0 211 | policy_loss = [] 212 | _returns = [] 213 | gamma_exps = [] 214 | current_gamma_exp = 1.0 215 | for _return in self.returns[::-1]: 216 | running_return = _return + gamma * running_return 217 | _returns.insert(0, running_return) 218 | 
gamma_exps.append(current_gamma_exp) 219 | current_gamma_exp = current_gamma_exp * gamma 220 | _returns = torch.FloatTensor(_returns) 221 | gamma_exps = torch.FloatTensor(gamma_exps) 222 | _returns = (_returns - _returns.mean()) / (_returns.std(unbiased=False) + np.finfo(np.float32).eps) 223 | for logit, _return, current_gamma_exp in zip(self.logits, _returns, gamma_exps): 224 | policy_loss.append(-logit * _return * current_gamma_exp) 225 | total_loss = torch.cat(policy_loss).sum() 226 | self.losses[0] += total_loss 227 | 228 | self.update_counter += 1 229 | 230 | if (self._to_do_update(agent_name=agent_name)): 231 | optimiser.zero_grad() 232 | loss = self.losses[0] / self.update_frequency 233 | loss.backward() 234 | optimiser.step() 235 | self.losses[0] = 0.0 236 | self.returns = [] 237 | self.logits = [] 238 | return (optimiser,) 239 | 240 | 241 | class BasePolicyReinforceWithBaseline(BasePolicy): 242 | def __init__(self, policy_config): 243 | super(BasePolicyReinforceWithBaseline, self).__init__(policy_config) 244 | self.logits = [] 245 | # This corresponds to the log(pi) values 246 | self.returns = [] 247 | self.state_values = [] 248 | self.actor_params_names = set([SHARED_FEATURES, ACTOR]) 249 | self.critic_params_names = set([SHARED_FEATURES, CRITIC]) 250 | self._lambda = policy_config[LAMBDA] 251 | self.is_self_play = policy_config[IS_SELF_PLAY] 252 | self.critic = nn.Sequential( 253 | nn.Linear(self.shared_features_size_output, 1) 254 | ) 255 | 256 | def forward(self, data): 257 | shared_features = F.relu(self.shared_features(data)) 258 | if(self.is_self_play and self.is_self_play_with_memory): 259 | shared_features = torch.cat((shared_features, 260 | self.summarize_memory().unsqueeze(0).detach()), dim=1) 261 | action_logits = self.actor(shared_features) 262 | state_values = self.critic(shared_features) 263 | return F.softmax(action_logits, dim=1), state_values 264 | 265 | def get_action(self, observation): 266 | reward = self.get_reward(observation) 267 | S = self.get_state(observation) 268 | self.returns.append(reward) 269 | action_prob, state_value = self.forward(S) 270 | distribution = Categorical(action_prob) 271 | action = distribution.sample() 272 | self.logits.append(distribution.log_prob(action)) 273 | self.state_values.append(state_value) 274 | return action.data[0] 275 | 276 | def get_actor_params(self): 277 | params = [] 278 | for param in self.named_parameters(): 279 | if param[0].split(".")[0] in self.actor_params_names and param[1].requires_grad: 280 | params.append(param[1]) 281 | return params 282 | 283 | def get_memory_params(self): 284 | return self.memory.get_params() 285 | 286 | def get_critic_params(self): 287 | params = [] 288 | for param in self.named_parameters(): 289 | if param[0].split(".")[0] in self.critic_params_names and param[1].requires_grad: 290 | params.append(param[1]) 291 | return params 292 | 293 | def update(self, optimisers, gamma, agent_name): 294 | num_optimisers = len(optimisers) 295 | if (num_optimisers == 1): 296 | optimiser = optimisers[0] 297 | elif (num_optimisers == 2): 298 | actor_optimiser, critic_optimiser = optimisers 299 | 300 | running_return = 0 301 | policy_loss = [] 302 | state_value_loss = [] 303 | _returns = [] 304 | gamma_exps = [] 305 | current_gamma_exp = 1.0 306 | for _return in self.returns[::-1]: 307 | running_return = _return + gamma * running_return 308 | _returns.insert(0, running_return) 309 | gamma_exps.append(current_gamma_exp) 310 | current_gamma_exp = current_gamma_exp * gamma 311 | _returns = 
torch.FloatTensor(_returns) 312 | gamma_exps = torch.FloatTensor(gamma_exps) 313 | _returns = (_returns - _returns.mean()) / (_returns.std(unbiased=False) + np.finfo(np.float32).eps) 314 | for logit, state_value, _return, current_gamma_exp in zip(self.logits, self.state_values, _returns, gamma_exps): 315 | state_value_loss.append(F.smooth_l1_loss(state_value, Variable(torch.Tensor([_return])))) 316 | _return = _return - state_value.data[0][0] 317 | policy_loss.append(-logit * _return * current_gamma_exp) 318 | 319 | self.returns = [] 320 | self.logits = [] 321 | self.state_values = [] 322 | 323 | actor_loss = torch.cat(policy_loss).sum() 324 | critic_loss = torch.cat(state_value_loss).sum() 325 | 326 | self.update_counter += 1 327 | 328 | if (num_optimisers == 1): 329 | loss = actor_loss + self._lambda * critic_loss 330 | self.losses[0] += loss 331 | if (self._to_do_update(agent_name=agent_name)): 332 | optimiser.zero_grad() 333 | loss = self.losses[0] / self.update_frequency 334 | loss.backward() 335 | optimiser.step() 336 | self.losses[0] = 0.0 337 | return (optimiser,) 338 | 339 | elif (num_optimisers == 2): 340 | self.losses[0] += actor_loss 341 | self.losses[1] += critic_loss 342 | if (self._to_do_update(agent_name=agent_name)): 343 | actor_optimiser.zero_grad() 344 | actor_loss.backward(retain_graph=True) 345 | actor_optimiser.step() 346 | 347 | critic_optimiser.zero_grad() 348 | critic_loss.backward() 349 | critic_optimiser.step() 350 | 351 | self.losses[0] = 0.0 352 | self.losses[1] = 0.0 353 | 354 | return (actor_optimiser, critic_optimiser) 355 | -------------------------------------------------------------------------------- /SelfPlay/policy/cartpole_policy.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from policy.base_policy import BasePolicyReinforce, BasePolicyReinforceWithBaseline 4 | from utils.constant import * 5 | 6 | default_input_size = 4 7 | 8 | 9 | class CartpolePolicyReinforce(BasePolicyReinforce): 10 | def __init__(self, policy_config): 11 | if not policy_config[INPUT_SIZE]: 12 | policy_config[INPUT_SIZE] = default_input_size 13 | policy_config[SHARED_FEATURES_SIZE] = 10 14 | super(CartpolePolicyReinforce, self).__init__(policy_config) 15 | self.init_weights() 16 | 17 | 18 | class CartpolePolicyReinforceWithBaseline(BasePolicyReinforceWithBaseline): 19 | def __init__(self, policy_config): 20 | if not policy_config[INPUT_SIZE]: 21 | policy_config[INPUT_SIZE] = default_input_size 22 | policy_config[SHARED_FEATURES_SIZE] = 10 23 | super(CartpolePolicyReinforceWithBaseline, self).__init__(policy_config) 24 | self.init_weights() 25 | -------------------------------------------------------------------------------- /SelfPlay/policy/mazebase_policy.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from policy.base_policy import BasePolicyReinforce, BasePolicyReinforceWithBaseline 4 | from utils.constant import * 5 | 6 | default_input_size = 10 * 10 * 156 7 | 8 | class MazebasePolicyReinforce(BasePolicyReinforce): 9 | def __init__(self, policy_config): 10 | if not policy_config[INPUT_SIZE]: 11 | policy_config[INPUT_SIZE] = default_input_size 12 | policy_config[SHARED_FEATURES_SIZE] = 50 13 | super(MazebasePolicyReinforce, self).__init__(policy_config) 14 | self.init_weights() 15 | 16 | class MazebasePolicyReinforceWithBaseline(BasePolicyReinforceWithBaseline): 17 | def __init__(self,policy_config): 18 | if not 
policy_config[INPUT_SIZE]: 19 | policy_config[INPUT_SIZE] = default_input_size 20 | policy_config[SHARED_FEATURES_SIZE] = 50 21 | super(MazebasePolicyReinforceWithBaseline, self).__init__(policy_config) 22 | self.shared_features = nn.Sequential( 23 | nn.Linear(self.input_size, self.shared_features_size), 24 | nn.Tanh() 25 | ) 26 | self.init_weights() -------------------------------------------------------------------------------- /SelfPlay/policy/mountaincar_policy.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from policy.base_policy import BasePolicyReinforce, BasePolicyReinforceWithBaseline 4 | from utils.constant import * 5 | 6 | default_input_size = 2 7 | 8 | class MountainCarPolicyReinforce(BasePolicyReinforce): 9 | def __init__(self, policy_config): 10 | if not policy_config[INPUT_SIZE]: 11 | policy_config[INPUT_SIZE] = default_input_size 12 | policy_config[SHARED_FEATURES_SIZE] = 128 13 | super(MountainCarPolicyReinforce, self).__init__(policy_config) 14 | self.init_weights() 15 | 16 | 17 | class MountainCarPolicyReinforceWithBaseline(BasePolicyReinforceWithBaseline): 18 | def __init__(self, policy_config): 19 | if not policy_config[INPUT_SIZE]: 20 | policy_config[INPUT_SIZE] = default_input_size 21 | policy_config[SHARED_FEATURES_SIZE] = 128 22 | super(MountainCarPolicyReinforceWithBaseline, self).__init__(policy_config) 23 | self.init_weights() 24 | -------------------------------------------------------------------------------- /SelfPlay/policy/policy_config.py: -------------------------------------------------------------------------------- 1 | from utils.constant import * 2 | 3 | 4 | class PolicyConfig: 5 | def __init__(self, env_name = ENVIRONMENT, 6 | agent_type = REINFORCE_AGENT, 7 | use_baseline=True, 8 | input_size=None, 9 | num_actions=10, 10 | batch_size=32, 11 | is_self_play=False, 12 | is_self_play_with_memory=False, 13 | _lambda=0.3, 14 | agent=AGENT, 15 | episode_memory_size=10, 16 | memory_type=BASE_MEMORY): 17 | self.data = { 18 | ENVIRONMENT: env_name, 19 | AGENT_TYPE: agent_type, 20 | USE_BASELINE: use_baseline, 21 | INPUT_SIZE: input_size, 22 | NUM_ACTIONS: num_actions, 23 | BATCH_SIZE: batch_size, 24 | IS_SELF_PLAY: is_self_play, 25 | IS_SELF_PLAY_WITH_MEMORY: is_self_play_with_memory, 26 | LAMBDA: _lambda, 27 | AGENT: agent, 28 | EPISODE_MEMORY_SIZE: episode_memory_size, 29 | MEMORY_TYPE: memory_type, 30 | } 31 | 32 | def __getitem__(self, key): 33 | return self.data[key] 34 | 35 | def __setitem__(self, key, value): 36 | self.data[key] = value -------------------------------------------------------------------------------- /SelfPlay/policy/registry.py: -------------------------------------------------------------------------------- 1 | from policy.acrobot_policy import AcrobotPolicyReinforce, AcrobotPolicyReinforceWithBaseline 2 | from policy.cartpole_policy import CartpolePolicyReinforce, CartpolePolicyReinforceWithBaseline 3 | from policy.mazebase_policy import MazebasePolicyReinforce, MazebasePolicyReinforceWithBaseline 4 | from policy.mountaincar_policy import MountainCarPolicyReinforce, MountainCarPolicyReinforceWithBaseline 5 | from policy.policy_config import PolicyConfig 6 | from utils.constant import * 7 | 8 | 9 | def choose_policy(env_name=ENVIRONMENT, 10 | agent_type=REINFORCE_AGENT, 11 | use_baseline=True, 12 | input_size=10 * 10 * 156, 13 | num_actions=10, 14 | batch_size=32, 15 | is_self_play=False, 16 | is_self_play_with_memory=False, 17 | _lambda=0.3, 18 | agent=AGENT, 19 
| episode_memory_size=10): 20 | policy_name = env_name + "_" + POLICY + "_" + agent_type.split("_")[0] 21 | 22 | policy_config = PolicyConfig(env_name=env_name, 23 | agent_type=agent_type, 24 | use_baseline=use_baseline, 25 | input_size=input_size, 26 | num_actions=num_actions, 27 | batch_size=batch_size, 28 | is_self_play=is_self_play, 29 | is_self_play_with_memory=is_self_play_with_memory, 30 | _lambda=_lambda, 31 | agent=agent, 32 | episode_memory_size=episode_memory_size 33 | ) 34 | 35 | if (use_baseline): 36 | policy_name += "_with_baseline" 37 | 38 | if (policy_name == MAZEBASE_POLICY_REINFORCE): 39 | return MazebasePolicyReinforce(policy_config) 40 | if (policy_name == MAZEBASE_POLICY_REINFORCE_WITH_BASELINE): 41 | return MazebasePolicyReinforceWithBaseline(policy_config) 42 | elif (policy_name == ACROBOT_POLICY_REINFORCE): 43 | return AcrobotPolicyReinforce(policy_config) 44 | elif (policy_name == ACROBOT_POLICY_REINFORCE_WITH_BASELINE): 45 | return AcrobotPolicyReinforceWithBaseline(policy_config) 46 | elif (policy_name == CARTPOLE_POLICY_REINFORCE): 47 | return CartpolePolicyReinforce(policy_config) 48 | elif (policy_name == CARTPOLE_POLICY_REINFORCE_WITH_BASELINE): 49 | return CartpolePolicyReinforceWithBaseline(policy_config) 50 | elif (policy_name == MOUNTAINCAR_POLICY_REINFORCE): 51 | return MountainCarPolicyReinforce(policy_config) 52 | elif (policy_name == MOUNTAINCAR_POLICY_REINFORCE_WITH_BASELINE): 53 | return MountainCarPolicyReinforceWithBaseline(policy_config) 54 | -------------------------------------------------------------------------------- /SelfPlay/requirements.txt: -------------------------------------------------------------------------------- 1 | http://download.pytorch.org/whl/cu80/torch-0.3.1-cp36-cp36m-linux_x86_64.whl 2 | numpy 3 | gym 4 | git+git://github.com/facebook/MazeBase.git 5 | matplotlib -------------------------------------------------------------------------------- /SelfPlay/scripts/filter_json_lines.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | parser = argparse.ArgumentParser(description='Preprocess the log file by filtering non-json line') 5 | parser.add_argument('--input_file_path', type=str, 6 | help="Path of the input log file") 7 | parser.add_argument('--output_file_path', type=str, 8 | help="Path of the output log file") 9 | 10 | args = parser.parse_args() 11 | 12 | with open(args.output_file_path, "w") as output_file: 13 | with open(args.input_file_path, "r") as input_file: 14 | for log in input_file: 15 | try: 16 | data = json.loads(log) 17 | if (isinstance(data, dict)): 18 | output_file.write(log) 19 | except json.JSONDecodeError as e: 20 | pass 21 | -------------------------------------------------------------------------------- /SelfPlay/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | PYTHONPATH=$PWD python3 -u app/mountaincar_reinforce.py --model_num_epochs 6 --model_batch_size 2 --model_max_steps_per_episode 100 --model_is_self_play False> log.txt 3 | PYTHONPATH=$PWD python3 -u app/acrobot_reinforce.py --model_num_epochs 6 --model_batch_size 2 --model_max_steps_per_episode 100 --model_is_self_play False > log.txt 4 | PYTHONPATH=$PWD python3 -u app/mazebase_reinforce.py --model_num_epochs 6 --model_batch_size 2 --model_use_baseline False --model_max_steps_per_episode 100 --model_is_self_play False > log.txt 5 | PYTHONPATH=$PWD python3 -u app/cartpole_reinforce.py --model_num_epochs 
6 --model_batch_size 2 --model_max_steps_per_episode 100 --model_use_baseline True --model_is_self_play False 6 | PYTHONPATH=$PWD python3 -u app/selfplay.py --model_num_epochs 3 --model_batch_size 2 --model_max_steps_per_episode 40 --model_is_self_play True --model_is_self_play_with_memory True --model_memory_type lstm_memory > log.txt 7 | PYTHONPATH=$PWD python3 -u app/selfplay.py --model_num_epochs 3 --model_batch_size 2 --model_max_steps_per_episode 40 --model_is_self_play True --model_is_self_play_with_memory True --model_memory_type base_memory > log.txt 8 | PYTHONPATH=$PWD python3 -u app/selfplay.py --model_num_epochs 3 --model_batch_size 2 --model_max_steps_per_episode 40 --model_is_self_play True > log.txt 9 | PYTHONPATH=$PWD python3 -u app/plot.py 10 | -------------------------------------------------------------------------------- /SelfPlay/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shagunsodhani/memory-augmented-self-play/1032891cb149cdc8411eb2fdba8745094524b5ce/SelfPlay/utils/__init__.py -------------------------------------------------------------------------------- /SelfPlay/utils/argument_parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def str2bool(v): 5 | # Taken from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 6 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 7 | return True 8 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 9 | return False 10 | else: 11 | raise argparse.ArgumentTypeError('Boolean value expected.') 12 | 13 | 14 | def argument_parser(config): 15 | delimiter = "_" 16 | parser = argparse.ArgumentParser() 17 | for key1 in config: 18 | for key2 in config[key1]: 19 | argument = "--" + key1 + delimiter + key2 20 | _type = type(config[key1][key2]) 21 | if (_type == bool): 22 | parser.add_argument(argument, help="Refer config.cfg to know about the config params", 23 | type=str2bool) 24 | else: 25 | parser.add_argument(argument, help="Refer config.cfg to know about the config params", 26 | type=_type) 27 | args = vars(parser.parse_args()) 28 | 29 | for key, value in args.items(): 30 | key1, *key2 = key.split(delimiter) 31 | key2 = delimiter.join(key2) 32 | if value: 33 | config[key1][key2] = value 34 | 35 | return config 36 | -------------------------------------------------------------------------------- /SelfPlay/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from configparser import ConfigParser 3 | from datetime import datetime 4 | 5 | from agent.registry import get_supported_agents 6 | from environment.registry import get_supported_envs 7 | from utils.argument_parser import argument_parser 8 | from utils.constant import * 9 | from utils.optim_registry import get_supported_optimisers 10 | from utils.util import make_dir 11 | 12 | 13 | def _read_config(): 14 | ''' 15 | Method to read the config file and return as a dict 16 | :return: 17 | ''' 18 | config = ConfigParser() 19 | path = os.path.abspath(os.pardir).split('/SelfPlay')[0] 20 | config.read(os.path.join(path, 'SelfPlay/config', 'config.cfg')) 21 | return config._sections 22 | 23 | 24 | def _get_boolean_value(value): 25 | if (value.lower() == TRUE): 26 | return True 27 | else: 28 | return False 29 | 30 | 31 | def get_config(use_cmd_config = True): 32 | '''Method to prepare the config for all downstream tasks''' 33 | 34 | # Read the 
config file 35 | config = _read_config() 36 | 37 | if(use_cmd_config): 38 | config = argument_parser(config) 39 | 40 | if (config[GENERAL][BASE_PATH] == ""): 41 | base_path = os.getcwd().split('/SelfPlay')[0] 42 | config[GENERAL][BASE_PATH] = base_path 43 | 44 | if (config[GENERAL][DEVICE] == ""): 45 | config[GENERAL][DEVICE] = CPU 46 | 47 | for key in [SEED]: 48 | config[GENERAL][key] = int(config[GENERAL][key]) 49 | 50 | key = ID 51 | if config[GENERAL][key] == "": 52 | config[GENERAL][key] = str(config[GENERAL][SEED]) 53 | 54 | # Model Params 55 | for key in [NUM_EPOCHS, BATCH_SIZE, PERSIST_PER_EPOCH, EARLY_STOPPING_PATIENCE, 56 | NUM_OPTIMIZERS, LOAD_TIMESTAMP, MAX_STEPS_PER_EPISODE, MAX_STEPS_PER_EPISODE_SELFPLAY, 57 | TARGET_TO_SELFPLAY_RATIO, EPISODE_MEMORY_SIZE]: 58 | config[MODEL][key] = int(config[MODEL][key]) 59 | 60 | for key in [LEARNING_RATE, GAMMA, LAMBDA, LEARNING_RATE_ACTOR, LEARNING_RATE_CRITIC, REWARD_SCALE]: 61 | config[MODEL][key] = float(config[MODEL][key]) 62 | 63 | for key in [USE_BASELINE, LOAD, IS_SELF_PLAY, IS_SELF_PLAY_WITH_MEMORY]: 64 | config[MODEL][key] = _get_boolean_value(config[MODEL][key]) 65 | 66 | agent = config[MODEL][AGENT] 67 | 68 | if (agent not in get_supported_agents()): 69 | config[MODEL][AGENT] = REINFORCE 70 | 71 | env = config[MODEL][ENV] 72 | 73 | if (env not in get_supported_envs()): 74 | config[MODEL][ENV] = MAZEBASE 75 | 76 | optimiser = config[MODEL][OPTIMISER] 77 | if (optimiser not in get_supported_optimisers()): 78 | config[MODEL][OPTIMISER] = ADAM 79 | 80 | if (config[MODEL][SAVE_DIR] == ""): 81 | config[MODEL][SAVE_DIR] = os.path.join(config[GENERAL][BASE_PATH], "model") 82 | elif (config[MODEL][SAVE_DIR][0] != "/"): 83 | config[MODEL][SAVE_DIR] = os.path.join(config[GENERAL][BASE_PATH], config[MODEL][SAVE_DIR]) 84 | 85 | make_dir(config[MODEL][SAVE_DIR]) 86 | 87 | if (config[MODEL][LOAD_PATH] == ""): 88 | config[MODEL][LOAD_PATH] = os.path.join(config[GENERAL][BASE_PATH], "model") 89 | 90 | elif (config[MODEL][LOAD_PATH][0] != "/"): 91 | config[MODEL][LOAD_PATH] = os.path.join(config[GENERAL][BASE_PATH], config[MODEL][LOAD_PATH]) 92 | 93 | # TB Params 94 | config[TB][DIR] = os.path.join(config[TB][BASE_PATH], datetime.now().strftime('%b%d_%H-%M-%S')) 95 | config[TB][SCALAR_PATH] = os.path.join(config[TB][BASE_PATH], "all_scalars.json") 96 | 97 | # Log Params 98 | key = FILE_PATH 99 | if (config[LOG][key] == ""): 100 | config[LOG][key] = os.path.join(config[GENERAL][BASE_PATH], 101 | "SelfPlay", 102 | "log_{}.txt".format(str(config[GENERAL][SEED]))) 103 | 104 | # Plot Params 105 | if (config[PLOT][BASE_PATH] == ""): 106 | config[PLOT][BASE_PATH] = os.path.join(config[GENERAL][BASE_PATH], "plot", config[GENERAL][ID]) 107 | 108 | make_dir(path=config[PLOT][BASE_PATH]) 109 | 110 | return config 111 | -------------------------------------------------------------------------------- /SelfPlay/utils/constant.py: -------------------------------------------------------------------------------- 1 | _LAMBDA = "_lambda" 2 | AC_AGENT = "ac_agent" 3 | ACROBOT = "acrobot" 4 | ACROBOT_POLICY_REINFORCE = "acrobot_policy_reinforce" 5 | ACROBOT_POLICY_REINFORCE_WITH_BASELINE = "acrobot_policy_reinforce_with_baseline" 6 | ACTOR = "actor" 7 | ADAM = "adam" 8 | AGENT = "agent" 9 | AGENT_TYPE = "agent_type" 10 | ALICE = "alice" 11 | ALICE_END_POSITION = "alice_end_position" 12 | ALICE_POSITION = "alice_position" 13 | ALICE_START_POSITION = "alice_start_position" 14 | AVERAGE = "average" 15 | AVERAGE_BATCH_LOSS = "average_batch_loss" 16 | 
AVERAGE_EPISODIC_REWARD = "average_episodic_reward" 17 | BASE_AGENT = "base_agent" 18 | BASE_MEMORY = "base_memory" 19 | BASE_MODEL = "base_model" 20 | BASE_PATH = "base_path" 21 | BATCH_SIZE = "batch_size" 22 | BOB = "bob" 23 | BOB_END_POSITION = "bob_end_position" 24 | BOB_POSITION = "bob_position" 25 | BOB_START_POSITION = "bob_start_position" 26 | CARTPOLE = "cartpole" 27 | CARTPOLE_POLICY_AC = "cartpole_policy_ac" 28 | CARTPOLE_POLICY_REINFORCE = "cartpole_policy_reinforce" 29 | CARTPOLE_POLICY_REINFORCE_WITH_BASELINE = "cartpole_policy_reinforce_with_baseline" 30 | CONFIG ="config" 31 | COPY = "copy" 32 | CPU = "cpu" 33 | CRITIC = "critic" 34 | CURRENT_EPISODIC_REWARD = "current_episodic_reward" 35 | DATASET = "dataset" 36 | DEBUG = "debug" 37 | DESCRIPTION = "description" 38 | DEVICE = "device" 39 | DIM = "dim" 40 | DIR = "dir" 41 | DOTFULL = ".full" 42 | DOTMIN = ".min" 43 | EARLY_STOPPING_PATIENCE = "early_stopping_patience" 44 | EMBEDDING = "embedding" 45 | EMBEDDING_DIM = "embedding_dim" 46 | ENV = "env" 47 | ENVIRONMENT = "environment" 48 | EPISODE = "Episode" 49 | EPISODE_MEMORY_SIZE = "episode_memory_size" 50 | EPISODE_NUMBER = "episode_number" 51 | EPOCHS = "epochs" 52 | FILE_PATH = "file_path" 53 | FULL = "full" 54 | GAMMA = "gamma" 55 | GENERAL = "general" 56 | GPU = "gpu" 57 | HUMAN = "human" 58 | HUMAN_AGENT = "human_agent" 59 | ID = "id" 60 | INPUT_DIM = "input_dim" 61 | INPUT_SIZE = "input_size" 62 | IS_PREPROCESSED = "is_preprocessed" 63 | IS_SELF_PLAY = "is_self_play" 64 | IS_SELF_PLAY_WITH_MEMORY = "is_self_play_with_memory" 65 | JSON = "json" 66 | K = "k" 67 | LAMBDA = "lambda" 68 | LEARNING_RATE = "learning_rate" 69 | LEARNING_RATE_ACTOR = "learning_rate_actor" 70 | LEARNING_RATE_CRITIC = "learning_rate_critic" 71 | LOAD = "load" 72 | LOAD_PATH = "load_path" 73 | LOAD_TIMESTAMP = "load_timestamp" 74 | LOG = "log" 75 | LOSS = "loss" 76 | LSTM_MEMORY = "lstm_memory" 77 | MAX_MEMORY_SIZE = "max_memory_size" 78 | MAX_STEPS_PER_EPISODE = "max_steps_per_episode" 79 | MAX_STEPS_PER_EPISODE_SELFPLAY = "max_steps_per_episode_selfplay" 80 | MAZEBASE = "mazebase" 81 | MAZEBASE_POLICY = "mazebase_policy" 82 | MAZEBASE_POLICY_AC = "mazebase_policy_ac" 83 | MAZEBASE_POLICY_REINFORCE = "mazebase_policy_reinforce" 84 | MAZEBASE_POLICY_REINFORCE_WITH_BASELINE = "mazebase_policy_reinforce_with_baseline" 85 | MEMORY = "memory" 86 | MEMORY_SIZE = "memory_size" 87 | MEMORY_TYPE = "memory_type" 88 | MIN = "min" 89 | MODEL = "model" 90 | MOUNTAINCAR = "mountaincar" 91 | MOUNTAINCAR_POLICY_REINFORCE = "mountaincar_policy_reinforce" 92 | MOUNTAINCAR_POLICY_REINFORCE_WITH_BASELINE = "mountaincar_policy_reinforce_with_baseline" 93 | NAME = "name" 94 | NEXT_STATE = "next_state" 95 | NP_RANDOM_STATE = "np_random_state" 96 | NUM_ACTIONS = "num_actions" 97 | NUM_EPOCHS = "num_epochs" 98 | NUM_OPTIMIZERS = "num_optimizers" 99 | OBSERVATION = "observation" 100 | OPTIMISER = "optimiser" 101 | OUTPUT_DIM = "output_dim" 102 | OUTPUTS = "outputs" 103 | PATH = "path" 104 | PERSIST_PER_EPOCH = "persist_per_epoch" 105 | PLOT = "plot" 106 | POLICY = "policy" 107 | POSITION = "position" 108 | PRECISION = "precision" 109 | PTB = "ptb" 110 | PYTHON_RANDOM_STATE = "python_random_state" 111 | PYTORCH_RANDOM_STATE = "pytorch_random_state" 112 | RANDOM = "random" 113 | RANDOM_AGENT = "random_agent" 114 | RECALL = "recall" 115 | REINFORCE = "reinforce" 116 | REINFORCE_AGENT = "reinforce_agent" 117 | REINFORCE_BASELINE = "reinforce_baseline" 118 | REINFORCE_BASELINE_AGENT = "reinforce_baseline_agent" 119 | 
REVERSE = "reverse" 120 | REWARD = "reward" 121 | REWARD_SCALE = "reward_scale" 122 | RMSPROP = "rmsprop" 123 | SAVE_DIR = "save_dir" 124 | SCALAR_PATH = "scalar_path" 125 | SEED = "seed" 126 | SELFPLAY = "selfplay" 127 | SELFPLAY_POLICY_REINFORCE = "selfplay_policy_reinforce" 128 | SELFPLAY_POLICY_REINFORCE_WITH_BASELINE = "selfplay_policy_reinforce_with_baseline" 129 | SGD = "sgd" 130 | SHARED_FEATURES = "shared_features" 131 | SHARED_FEATURES_SIZE = "shared_features_size" 132 | STATE_DICT = "state_dict" 133 | TARGET = "target" 134 | TARGET_TO_SELFPLAY_RATIO = "target_to_selfplay_ratio" 135 | TB = "tb" 136 | TEST = "test" 137 | TEST_ACCURACY = "test_accuracy" 138 | TEST_LOSS = "test_loss" 139 | TEST_MACRO_FSCORE = "test_macro_fscore" 140 | TEST_MICRO_FSCORE = "test_micro_fscore" 141 | TEST_NEG = "test_neg" 142 | TEST_NO_EDGE = "test_no_edge" 143 | TEST_POS = "test_pos" 144 | TEST_TIME = "test_time" 145 | TIME = "time" 146 | TIME_ALICE = "time_alice" 147 | TIME_BOB = "time_bob" 148 | TIMESTAMP = "timestamp" 149 | TRAIN = "train" 150 | TRAIN_ACCURACY = "train_accuracy" 151 | TRAIN_LOSS = "train_loss" 152 | TRAIN_TIME = "train_time" 153 | TRUE = "true" 154 | TYPE = "type" 155 | USE_BASELINE = "use_baseline" 156 | USE_MIN = "use_min" 157 | USE_NEURAL_BASELINE = "use_neural_baseline" 158 | USE_PRETRAINED_EMBEDDING = "use_pretrained_embedding" 159 | USE_SPARSE_EMBEDDING = "use_sparse_embedding" 160 | VAL = "val" 161 | VAL_ACCURACY = "val_accuracy" 162 | VAL_LOSS = "val_loss" 163 | VAL_MACRO_FSCORE = "val_macro_fscore" 164 | VAL_MICRO_FSCORE = "val_micro_fscore" 165 | VAL_TIME = "val_time" 166 | UNDO = "undo" 167 | SELFPLAY_TYPE = "selfplay_type" 168 | STDDEV = "stddev" -------------------------------------------------------------------------------- /SelfPlay/utils/log.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | 4 | from utils.constant import * 5 | import numpy as np 6 | 7 | log_types = [CONFIG, REWARD, LOSS, AGENT, ENVIRONMENT] 8 | 9 | 10 | def _format_log(log): 11 | return json.dumps(log) 12 | 13 | 14 | def write_log(log): 15 | '''This is the default method to write a log. 
It is assumed that the log has already been processed 16 | before feeding to this method''' 17 | print(log) 18 | logging.info(log) 19 | 20 | 21 | def read_logs(log): 22 | '''This is the single point to read any log message from the file since all the log messages are persisted as jsons''' 23 | default_data = { 24 | TYPE: "" 25 | } 26 | try: 27 | data = json.loads(log) 28 | if (not isinstance(data, dict)): 29 | data = default_data 30 | except json.JSONDecodeError as e: 31 | data = default_data 32 | return data 33 | 34 | 35 | def _format_custom_logs(keys=[], raw_log={}, _type=REWARD): 36 | log = {} 37 | if (keys): 38 | for key in keys: 39 | if key in raw_log: 40 | log[key] = raw_log[key] 41 | else: 42 | log = raw_log 43 | log[TYPE] = _type 44 | return _format_log(log) 45 | 46 | 47 | delimiter = "\t" 48 | quantifier = ":" 49 | 50 | 51 | def format_reward_log(**kwargs): 52 | ''' 53 | Method to return the formatted string about reward information to be pushed into the logs 54 | ''' 55 | # message = EPISODE + quantifier + " {}" + delimiter + \ 56 | # CURRENT_EPISODIC_REWARD + quantifier + " {:5f}" + delimiter + \ 57 | # AVERAGE_EPISODIC_REWARD + quantifier + " {:5f}" 58 | # return message.format( 59 | # episode_number, current_episodic_reward, average_episodic_reward) 60 | 61 | return json.dumps(kwargs) 62 | 63 | 64 | def format_config_log(config): 65 | ''' 66 | Method to return the formatted string about config information to be pushed into the logs 67 | ''' 68 | return json.dumps(config) 69 | 70 | 71 | def format_loss_log(**kwargs): 72 | return json.dumps(kwargs) 73 | 74 | 75 | def write_config_log(config): 76 | config[TYPE] = CONFIG 77 | log = _format_log(config) 78 | write_log(log) 79 | 80 | 81 | def write_reward_log(**kwargs): 82 | keys = [CURRENT_EPISODIC_REWARD, AVERAGE_EPISODIC_REWARD, AGENT, ENVIRONMENT] 83 | log = _format_custom_logs(keys=keys, raw_log=kwargs, _type=REWARD) 84 | write_log(log) 85 | 86 | def write_time_log(**kwargs): 87 | keys = [TIME_ALICE, TIME_BOB, AGENT, ENVIRONMENT] 88 | log = _format_custom_logs(keys=keys, raw_log=kwargs, _type=TIME) 89 | write_log(log) 90 | 91 | def write_position_log(**kwargs): 92 | keys = [ALICE_START_POSITION, ALICE_END_POSITION, BOB_START_POSITION, BOB_END_POSITION] 93 | log = _format_custom_logs(keys=keys, raw_log=kwargs, _type=POSITION) 94 | write_log(log) 95 | 96 | 97 | def write_loss_log(**kwargs): 98 | keys = [AVERAGE_BATCH_LOSS, AGENT, ENVIRONMENT] 99 | # if (AGENT, ENV in kwargs): 100 | # keys.append(AGENT) 101 | log = _format_custom_logs(keys=keys, raw_log=kwargs, _type=LOSS) 102 | write_log(log) 103 | 104 | 105 | def pprint(config): 106 | print(json.dumps(config, indent=4)) 107 | 108 | 109 | def parse_log_file(log_file_path, agent=None, env_list=None): 110 | logs = {} 111 | agent_keys = [CURRENT_EPISODIC_REWARD, AVERAGE_EPISODIC_REWARD, AVERAGE_BATCH_LOSS, 112 | ENVIRONMENT, TIME] 113 | common_keys = [CONFIG] 114 | keys = agent_keys + common_keys 115 | for env in env_list: 116 | logs[env] = {} 117 | for key in keys: 118 | logs[env][key] = [] 119 | with open(log_file_path, "r") as f: 120 | for line in f: 121 | data = read_logs(log=line) 122 | _type = data[TYPE] 123 | if (_type in [REWARD, LOSS]): 124 | if (data[AGENT] == agent): 125 | if(data[ENVIRONMENT] in logs): 126 | for key in agent_keys: 127 | if key in data: 128 | logs[data[ENVIRONMENT]][key].append(data[key]) 129 | elif (_type == TIME and (TIME_ALICE in data or TIME_BOB in data)):  # skip TIME logs that carry neither agent's time, so _agent and key below are always bound 130 | data[ENVIRONMENT] = env_list[0] 131 | if(TIME_ALICE in data): 132 | _agent = ALICE 133 | key = TIME_ALICE 134 | elif(TIME_BOB in data):
135 | _agent = BOB 136 | key = TIME_BOB 137 | if(_agent == agent): 138 | logs[data[ENVIRONMENT]][TIME].append(data[key]) 139 | else: 140 | if _type == CONFIG: 141 | key = CONFIG 142 | for env in env_list: 143 | logs[env][key].append(data) 144 | 145 | for env in list(logs.keys()): 146 | keys = set([CURRENT_EPISODIC_REWARD, AVERAGE_EPISODIC_REWARD, AVERAGE_BATCH_LOSS, TIME]) 147 | keys_to_modify = [CONFIG, ENVIRONMENT] 148 | to_delete = True 149 | for key in logs[env]: 150 | if(key in keys and len(logs[env][key]) > 1): 151 | logs[env][key] = np.asarray(logs[env][key]) 152 | to_delete = False 153 | if (to_delete): 154 | del logs[env] 155 | else: 156 | for key in keys_to_modify: 157 | logs[env][key] = logs[env][key][0] 158 | 159 | return logs 160 | 161 | -------------------------------------------------------------------------------- /SelfPlay/utils/optim_registry.py: -------------------------------------------------------------------------------- 1 | import torch.optim 2 | from utils.constant import * 3 | 4 | 5 | def get_supported_optimisers(): 6 | return set([ADAM, SGD, RMSPROP]) 7 | 8 | 9 | def choose_optimiser(optimiser_name=ADAM): 10 | if (optimiser_name == ADAM): 11 | return torch.optim.Adam 12 | elif (optimiser_name == SGD): 13 | return torch.optim.SGD 14 | elif (optimiser_name == RMSPROP): 15 | return torch.optim.RMSprop 16 | -------------------------------------------------------------------------------- /SelfPlay/utils/util.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | import pathlib 6 | 7 | 8 | def set_seed(seed): 9 | torch.manual_seed(seed) 10 | np.random.seed(seed) 11 | random.seed(seed) 12 | if torch.cuda.is_available(): 13 | torch.cuda.manual_seed_all(seed) 14 | 15 | def make_dir(path): 16 | pathlib.Path(path).mkdir(parents=True, exist_ok=True) -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-slate -------------------------------------------------------------------------------- /assets/images/acrobot_pca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shagunsodhani/memory-augmented-self-play/1032891cb149cdc8411eb2fdba8745094524b5ce/assets/images/acrobot_pca.png -------------------------------------------------------------------------------- /assets/images/mazebase_selfplay_compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shagunsodhani/memory-augmented-self-play/1032891cb149cdc8411eb2fdba8745094524b5ce/assets/images/mazebase_selfplay_compare.png -------------------------------------------------------------------------------- /docs/report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shagunsodhani/memory-augmented-self-play/1032891cb149cdc8411eb2fdba8745094524b5ce/docs/report.pdf -------------------------------------------------------------------------------- /model/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shagunsodhani/memory-augmented-self-play/1032891cb149cdc8411eb2fdba8745094524b5ce/model/.keep -------------------------------------------------------------------------------- /plot/.keep: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/shagunsodhani/memory-augmented-self-play/1032891cb149cdc8411eb2fdba8745094524b5ce/plot/.keep --------------------------------------------------------------------------------
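
A minimal sketch of how the utilities dumped above fit together, assuming the repository layout shown in this tree, a `SelfPlay/config/config.cfg` that `_read_config` can load, and execution from the `SelfPlay` directory with `PYTHONPATH=$PWD` (as in `test.sh`). The `run` function below is illustrative only and is not part of the repository; the actual entry points are the scripts under `app/` invoked in `test.sh`.

```python
from utils.config import get_config
from utils.constant import GENERAL, MODEL, OPTIMISER, SEED
from utils.log import write_config_log
from utils.optim_registry import choose_optimiser
from utils.util import set_seed


def run():
    # Read config.cfg and let command-line flags (e.g. --model_num_epochs 3) override it.
    config = get_config(use_cmd_config=True)

    # Fix the python, numpy and torch seeds so runs are reproducible.
    set_seed(config[GENERAL][SEED])

    # Persist the resolved config as a JSON log line of type "config".
    write_config_log(config)

    # Map the optimiser name from the config ("adam", "sgd" or "rmsprop")
    # to the corresponding torch.optim class.
    optimiser_cls = choose_optimiser(config[MODEL][OPTIMISER])
    print(optimiser_cls)


if __name__ == "__main__":
    run()
```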