├── img └── learning_curves.png ├── policy_improvement ├── __init__.py ├── optim_utils.py ├── dqn_update.py └── categorical_update.py ├── policy_evaluation ├── __init__.py ├── deterministic.py ├── categorical.py └── exploration_schedules.py ├── utils ├── __init__.py ├── torch_types.py ├── utils.py ├── parse_config.py └── wrappers.py ├── estimators ├── __init__.py ├── catch_net.py └── atari_net.py ├── agents ├── random_agent.py ├── __init__.py ├── evaluation_agent.py ├── categorical_dqn_agent.py ├── base_agent.py └── dqn_agent.py ├── configs ├── catch_dqn.yaml ├── catch_categorical.yaml ├── catch_dev.yaml ├── atari_dev.yaml └── atari_bench.yaml ├── LICENSE ├── Readme.md ├── data_structures ├── __init__.py ├── tensor_experience_replay.py └── ntuple_experience_replay.py ├── .gitignore ├── categorical.yml └── main.py /img/learning_curves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floringogianu/categorical-dqn/HEAD/img/learning_curves.png -------------------------------------------------------------------------------- /policy_improvement/__init__.py: -------------------------------------------------------------------------------- 1 | from policy_improvement.dqn_update import DQNPolicyImprovement 2 | from policy_improvement.categorical_update import CategoricalPolicyImprovement 3 | -------------------------------------------------------------------------------- /policy_evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from policy_evaluation.categorical import CategoricalPolicyEvaluation 2 | from policy_evaluation.deterministic import DeterministicPolicy 3 | from policy_evaluation.exploration_schedules import get_schedule 4 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.wrappers import EvaluationMonitor, PreprocessFrames 2 | from utils.wrappers import SqueezeRewards, DoneAfterLostLife 3 | from utils.utils import env_factory, not_implemented 4 | from utils.parse_config import get_config 5 | from utils.torch_types import TorchTypes 6 | -------------------------------------------------------------------------------- /estimators/__init__.py: -------------------------------------------------------------------------------- 1 | # Bitdefender, 2107 2 | from estimators.atari_net import AtariNet 3 | from estimators.catch_net import CatchNet 4 | 5 | ESTIMATORS = { 6 | "atari": AtariNet, 7 | "catch": CatchNet 8 | } 9 | 10 | 11 | def get_estimator(name, in_ch, hist_len, action_no, hidden_size=128): 12 | return ESTIMATORS[name](in_ch, hist_len, action_no, hidden_size) 13 | -------------------------------------------------------------------------------- /agents/random_agent.py: -------------------------------------------------------------------------------- 1 | from .base_agent import BaseAgent 2 | 3 | 4 | class RandomAgent(BaseAgent): 5 | def __init__(self, action_space, cmdl): 6 | BaseAgent.__init__(self, action_space) 7 | 8 | self.name = "RND_agent" 9 | 10 | def evaluate_policy(self, state): 11 | return self.action_space.sample() 12 | 13 | def improve_policy(self, _state, _action, reward, state, done): 14 | pass 15 | -------------------------------------------------------------------------------- /agents/__init__.py: -------------------------------------------------------------------------------- 1 | # from .nec_agent import 
NECAgent 2 | from agents.dqn_agent import DQNAgent 3 | from agents.categorical_dqn_agent import CategoricalDQNAgent 4 | from agents.random_agent import RandomAgent 5 | from agents.evaluation_agent import EvaluationAgent 6 | 7 | AGENTS = { 8 | # "nec": NECAgent, 9 | "evaluation": EvaluationAgent, 10 | "categorical": CategoricalDQNAgent, 11 | "dqn": DQNAgent, 12 | "random": RandomAgent 13 | } 14 | 15 | 16 | def get_agent(name): 17 | return AGENTS[name] 18 | -------------------------------------------------------------------------------- /utils/torch_types.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TorchTypes(object): 5 | 6 | def __init__(self, cuda=False): 7 | self.set_cuda(cuda) 8 | 9 | def set_cuda(self, use_cuda): 10 | if use_cuda: 11 | self.FT = torch.cuda.FloatTensor 12 | self.LT = torch.cuda.LongTensor 13 | self.BT = torch.cuda.ByteTensor 14 | self.IT = torch.cuda.IntTensor 15 | self.DT = torch.cuda.DoubleTensor 16 | else: 17 | self.FT = torch.FloatTensor 18 | self.LT = torch.LongTensor 19 | self.BT = torch.ByteTensor 20 | self.IT = torch.IntTensor 21 | self.DT = torch.DoubleTensor 22 | -------------------------------------------------------------------------------- /policy_evaluation/deterministic.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | 3 | 4 | class DeterministicPolicy(object): 5 | def __init__(self, policy): 6 | """Assumes policy returns an autograd.Variable""" 7 | 8 | self.name = "DP" 9 | self.policy = policy 10 | self.cuda = next(policy.parameters()).is_cuda 11 | 12 | def get_action(self, state): 13 | """ Takes best action based on estimated state-action values.""" 14 | state = state.cuda() if self.cuda else state 15 | q_val, argmax_a = self.policy( 16 | Variable(state, volatile=True)).data.max(1) 17 | """ 18 | result = self.policy(Variable(state_batch, volatile=True)) 19 | print(result) 20 | """ 21 | return (q_val.squeeze()[0], argmax_a.squeeze()[0]) 22 | -------------------------------------------------------------------------------- /policy_improvement/optim_utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import torch.optim as optim 3 | 4 | 5 | def float_range(start, end, step): 6 | x = start 7 | if step > 0: 8 | while x < end: 9 | yield x 10 | x += step 11 | else: 12 | while x > end: 13 | yield x 14 | x += step 15 | 16 | 17 | def lr_schedule(start, end, steps_no): 18 | start, end, steps_no = float(start), float(end), float(steps_no) 19 | step = (end - start) / (steps_no - 1.) 
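    # The float_range generator used below stops just before reaching `end`,
    # so the chained itertools.repeat(end) keeps yielding the floor value
    # afterwards. Illustrative (hypothetical) call: lr_schedule(1e-3, 1e-5, 3)
    # yields 1e-3, 5.05e-4, and then 1e-5 forever, i.e. a linear decay to `end`.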
20 | schedules = [float_range(start, end, step), itertools.repeat(end)] 21 | return itertools.chain(*schedules) 22 | 23 | 24 | def optim_factory(weights, cmdl): 25 | if cmdl.optim == "Adam": 26 | return optim.Adam(weights, lr=cmdl.lr, eps=cmdl.eps) 27 | elif cmdl.optim == "RMSprop": 28 | return optim.RMSprop(weights, lr=cmdl.lr, eps=cmdl.eps, 29 | alpha=cmdl.alpha) 30 | -------------------------------------------------------------------------------- /policy_evaluation/categorical.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from utils import TorchTypes 4 | 5 | 6 | class CategoricalPolicyEvaluation(object): 7 | def __init__(self, policy, cmdl): 8 | """Assumes policy returns an autograd.Variable""" 9 | self.name = "CP" 10 | self.cmdl = cmdl 11 | self.policy = policy 12 | 13 | self.dtype = dtype = TorchTypes(cmdl.cuda) 14 | self.support = torch.linspace(cmdl.v_min, cmdl.v_max, cmdl.atoms_no) 15 | self.support = self.support.type(dtype.FT) 16 | 17 | def get_action(self, state): 18 | """ Takes best action based on estimated state-action values.""" 19 | state = state.type(self.dtype.FT) 20 | probs = self.policy(Variable(state, volatile=True)).data 21 | support = self.support.expand_as(probs) 22 | q_val, argmax_a = torch.mul(probs, support).squeeze().sum(1).max(0) 23 | return (q_val[0], argmax_a[0]) 24 | -------------------------------------------------------------------------------- /configs/catch_dqn.yaml: -------------------------------------------------------------------------------- 1 | # Catch DQN config file 2 | 3 | # env info 4 | label: "Catcher DQN" 5 | env_name: "Catcher-Level0-v0" 6 | env_class: catch 7 | 8 | # agent info 9 | agent_type: "dqn" 10 | 11 | # experiment settings 12 | seed: 99 13 | cuda: yes 14 | display_plots: no # uses vizdom for plotting 15 | 16 | # training vars 17 | training_steps: 1000000 18 | 19 | # estimator settings 20 | estimator: catch 21 | batch_size: 64 22 | hidden_size: 128 23 | hist_len: 2 24 | 25 | # exploration, q-learning, optimization settings 26 | epsilon: 1 27 | epsilon_steps: 100000 28 | experience_replay: nTupleExperienceReplay 29 | replay_mem_size: 100000 30 | start_learning_after: 1000 31 | update_freq: 2 32 | target_update_freq: 24 33 | gamma: 0.99 34 | optim: Adam 35 | lr: .000635 # decrease to 0.000135 for hist_len 4 36 | eps: 0.00015 # RMSprop: 0.01 | Adam: 0.01/batch_size 37 | 38 | # evaluator 39 | eval_steps: 12000 40 | eval_frequency: 12000 41 | eval_start: 0 42 | eval_env_name: "Catcher-Level0-v0" # eg.: deterministic or wo skipping frames 43 | 44 | # reporting 45 | report_frequency: 12000 # steps 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Florin Gogianu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/catch_categorical.yaml: -------------------------------------------------------------------------------- 1 | # CategoricalDQN for catch 2 | 3 | # env info 4 | label: "Catcher CategoricalDQN" 5 | env_name: "Catcher-Level0-v0" 6 | env_class: catch 7 | 8 | # agent info 9 | agent_type: "categorical" 10 | 11 | # experiment settings 12 | seed: 99 13 | cuda: yes 14 | display_plots: no # uses vizdom for plotting 15 | 16 | # training vars 17 | training_steps: 1000000 18 | # estimator settings 19 | estimator: catch 20 | batch_size: 64 21 | hidden_size: 128 22 | hist_len: 2 23 | # exploration, q-learning, optimization settings 24 | epsilon: 1 25 | epsilon_steps: 100000 26 | experience_replay: nTupleExperienceReplay 27 | replay_mem_size: 100000 28 | start_learning_after: 1000 29 | update_freq: 2 30 | target_update_freq: 24 31 | gamma: 0.99 32 | optim: Adam 33 | lr: .000635 # decrease to 0.000135 for hist_len 4 34 | alpha: 0.95 # RMSprop smoothing constant 35 | eps: 0.00015 # RMSprop: 0.01 | Adam: 0.01/batch_size 36 | 37 | # support settings 38 | atoms_no: 51 39 | v_min: -3 40 | v_max: 3 41 | 42 | # evaluator 43 | eval_steps: 12000 44 | eval_frequency: 12000 45 | eval_start: 0 46 | eval_env_name: "Catcher-Level0-v0" # eg.: deterministic or wo skipping frames 47 | 48 | # reporting 49 | report_frequency: 12000 # steps 50 | -------------------------------------------------------------------------------- /configs/catch_dev.yaml: -------------------------------------------------------------------------------- 1 | # Dev config file, can contain unused fields for easy switching 2 | 3 | # env info 4 | label: "Catcher Dev" 5 | env_name: "Catcher-Level0-v0" 6 | env_class: catch 7 | 8 | # agent info 9 | agent_type: "dqn" 10 | 11 | # experiment settings 12 | seed: 99 13 | cuda: yes 14 | display_plots: no # uses vizdom for plotting 15 | 16 | # training vars 17 | training_steps: 1000000 18 | 19 | # estimator settings 20 | estimator: catch 21 | batch_size: 64 22 | hidden_size: 128 23 | hist_len: 2 24 | 25 | # exploration, q-learning, optimization settings 26 | epsilon: 1 27 | epsilon_steps: 100000 28 | experience_replay: TensorExperienceReplay 29 | replay_mem_size: 100000 30 | start_learning_after: 1000 31 | cache: 16 # no of batches to cache 32 | update_freq: 2 33 | target_update_freq: 24 34 | gamma: 0.99 35 | optim: Adam 36 | lr: .000635 # decrease to 0.000135 for hist_len 4 37 | eps: 0.00015 # RMSprop: 0.01 | Adam: 0.01/batch_size 38 | alpha: 0.95 # RMSprop smoothing constant 39 | 40 | # support settings 41 | atoms_no: 51 42 | v_min: -3 43 | v_max: 3 44 | 45 | # evaluator 46 | eval_steps: 12000 47 | eval_frequency: 12000 48 | eval_start: 0 49 | eval_env_name: "Catcher-Level0-v0" # eg.: deterministic or wo skipping frames 50 | 51 | # reporting 52 | report_frequency: 12000 # steps 53 | -------------------------------------------------------------------------------- /configs/atari_dev.yaml: 
-------------------------------------------------------------------------------- 1 | # Atari Dev configuration file 2 | 3 | # env info 4 | label: "Atari Dev" 5 | env_name: "Breakout-v0" 6 | env_class: atari 7 | 8 | # agent info 9 | agent_type: "dqn" 10 | rescale_dims: 84 11 | 12 | # experiment settings 13 | seed: 23 14 | cuda: yes 15 | display_plots: no # uses vizdom for plotting 16 | 17 | # training settings 18 | training_steps: 80000000 19 | done_after_lost_life: yes 20 | 21 | # estimator settings 22 | estimator: atari 23 | hidden_size: 128 24 | batch_size: 64 25 | hist_len: 1 26 | 27 | # exploration, q-learning, optimization settings 28 | epsilon: 1 29 | epsilon_steps: 1000000 30 | experience_replay: nTupleExperienceReplay 31 | replay_mem_size: 1000000 32 | start_learning_after: 1000 33 | update_freq: 4 34 | target_update_freq: 40000 35 | gamma: 0.99 36 | optim: RMSprop 37 | lr: .00025 # decrease to 0.000135 for hist_len 4 38 | eps: 0.01 # RMSprop: 0.01 | Adam: 0.01/batch_size (categorical) 39 | alpha: 0.95 # RMSprop smoothing constant 40 | # clamps 41 | reward_clamp: yes # [-1, 1] 42 | 43 | # support settings 44 | atoms_no: 51 45 | v_min: -10 46 | v_max: 10 47 | 48 | # evaluator 49 | eval_steps: 150000 50 | eval_frequency: 250000 51 | eval_start: 100000 52 | eval_env_name: "Breakout-v0" # eg.: deterministic or wo skipping frames 53 | 54 | # reporting 55 | report_frequency: 50000 # steps 56 | -------------------------------------------------------------------------------- /configs/atari_bench.yaml: -------------------------------------------------------------------------------- 1 | # Example config file 2 | 3 | # env info 4 | label: "Atari Benchmark" 5 | env_name: "Breakout-v0" 6 | env_class: atari 7 | 8 | # agent info 9 | agent_type: "dqn" 10 | rescale_dims: 84 11 | 12 | # experiment settings 13 | seed: 23 14 | cuda: yes 15 | display_plots: no # uses vizdom for plotting 16 | 17 | # training settings 18 | training_steps: 30000 19 | done_after_lost_life: yes 20 | 21 | # estimator settings 22 | estimator: atari 23 | hidden_size: 128 24 | batch_size: 32 25 | hist_len: 4 26 | 27 | # exploration, q-learning, optimization settings 28 | epsilon: 1 29 | epsilon_steps: 1000 30 | experience_replay: nTupleExperienceReplay 31 | replay_mem_size: 10000 32 | start_learning_after: 1000 33 | cache: 16 # no of batches to cache 34 | update_freq: 4 35 | target_update_freq: 10000 36 | gamma: 0.99 37 | optim: RMSprop 38 | lr: .00025 # decrease to 0.000135 for hist_len 4 39 | eps: 0.01 # RMSprop: 0.01 | Adam: 0.01/batch_size (categorical) 40 | alpha: 0.95 # RMSprop smoothing constant 41 | # clamps 42 | reward_clamp: yes # [-1, 1] 43 | 44 | # support settings 45 | atoms_no: 51 46 | v_min: -10 47 | v_max: 10 48 | 49 | # evaluator 50 | eval_steps: 2000 51 | eval_frequency: 10000 52 | eval_start: 5000 53 | eval_env_name: "Breakout-v0" # eg.: deterministic or wo skipping frames 54 | 55 | # reporting 56 | report_frequency: 10000 # steps 57 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # Categorical DQN. 2 | 3 | Implementation of the **Categorical DQN** as described in *A distributional 4 | Perspective on Reinforcement Learning*. 5 | 6 | Thanks to [@tudor-berariu](https://github.com/tudor-berariu) for optimisation 7 | and training tricks and for catching two nasty bugs. 
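
In a nutshell, instead of a single Q-value per action, the network outputs a
probability distribution over a fixed support of `atoms_no` values between
`v_min` and `v_max` (51 atoms in the configs), and acting greedily means
picking the action with the largest expected value under that distribution.
A minimal sketch of the idea, written against a recent PyTorch for clarity
(illustrative only; the full version lives in `policy_evaluation/categorical.py`):

```python
import torch
import torch.nn.functional as F

v_min, v_max, atoms_no = -3, 3, 51            # values from configs/catch_categorical.yaml
support = torch.linspace(v_min, v_max, atoms_no)

# probs: one distribution over the support per action (here 3 random ones)
probs = F.softmax(torch.randn(3, atoms_no), dim=1)

q_values = (probs * support).sum(1)           # expected return per action
action = q_values.argmax().item()             # greedy action
```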
8 | 9 | ## Dependencies 10 | 11 | You can take a look in the [env export file](categorical.yml) for the full 12 | list of dependencies. 13 | 14 | Install the game of Catch: 15 | ``` 16 | git clone https://github.com/floringogianu/gym_fast_envs 17 | cd gym_fast_envs 18 | 19 | pip install -r requirements.txt 20 | pip install -e . 21 | ``` 22 | 23 | Install `visdom` for reporting: `pip install visdom`. 24 | 25 | ## Training 26 | 27 | First start the `visdom` server: `python -m visdom.server`. If you don't want to install or use `visdom` make sure you deactivate the `display_plots` option in the `configs`. 28 | 29 | Train the Categorical DQN with `python main.py -cf configs/catch_categorical.yaml`. 30 | 31 | Train a DQN baseline with `python main.py -cf configs/catch_dqn.yaml`. 32 | 33 | ## To Do 34 | 35 | - [x] Migrate to `Pytorch 0.2.0`. Breaks compatibility with `0.1.12`. 36 | - [x] Add some training curves. 37 | - [x] Run on Atari. 38 | - [x] Add proper evaluation. 39 | 40 | ## Results 41 | 42 | First row is with batch size of 64, the second with 32. Will run on more seeds and average for a better comparison. Working on adding Atari results. 43 | 44 | ![Catch Learning Curves](img/learning_curves.png) 45 | -------------------------------------------------------------------------------- /data_structures/__init__.py: -------------------------------------------------------------------------------- 1 | from data_structures.ntuple_experience_replay import nTupleExperienceReplay 2 | from data_structures.ntuple_experience_replay import CachedExperienceReplay 3 | from data_structures.tensor_experience_replay import TensorExperienceReplay 4 | 5 | 6 | class ExperienceReplay(object): 7 | @staticmethod 8 | def factory(cmdl, state_dims): 9 | type_name = cmdl.experience_replay 10 | 11 | if type_name == "nTupleExperienceReplay": 12 | if hasattr(cmdl, 'cache') and cmdl.cuda: 13 | print("[ExperienceReplay] Cached Experience Replay " 14 | + "implemented by %s." % type_name) 15 | return CachedExperienceReplay( 16 | cmdl.replay_mem_size, cmdl.batch_size, 17 | cmdl.hist_len, cmdl.cuda, cmdl.cache 18 | ) 19 | else: 20 | print("[ExperienceReplay] Implemented by %s." % type_name) 21 | return nTupleExperienceReplay( 22 | cmdl.replay_mem_size, cmdl.batch_size, 23 | cmdl.hist_len, cmdl.cuda 24 | ) 25 | 26 | if type_name == "TensorExperienceReplay": 27 | if hasattr(cmdl, 'rescale_dims'): 28 | state_dims = (cmdl.rescale_dims, cmdl.rescale_dims) 29 | print("[ExperienceReplay] Implemented by %s." % type_name) 30 | return TensorExperienceReplay( 31 | cmdl.replay_mem_size, cmdl.batch_size, 32 | cmdl.hist_len, state_dims, cmdl.cuda 33 | ) 34 | assert 0, "Bad ExperienceReplay creation: " + type_name 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /agents/evaluation_agent.py: -------------------------------------------------------------------------------- 1 | from numpy.random import uniform 2 | from estimators import get_estimator as get_model 3 | from policy_evaluation import DeterministicPolicy 4 | from policy_evaluation import CategoricalPolicyEvaluation 5 | 6 | 7 | class EvaluationAgent(object): 8 | def __init__(self, env_space, cmdl): 9 | self.name = "Evaluation" 10 | 11 | self.actions = env_space[0] 12 | self.action_no = action_no = self.actions.n 13 | self.cmdl = cmdl 14 | self.epsilon = 0.05 15 | 16 | if cmdl.agent_type == "dqn": 17 | self.policy = policy = get_model(cmdl.estimator, 1, cmdl.hist_len, 18 | self.action_no, cmdl.hidden_size) 19 | if self.cmdl.cuda: 20 | self.policy.cuda() 21 | self.policy_evaluation = DeterministicPolicy(policy) 22 | elif cmdl.agent_type == "categorical": 23 | self.policy = policy = get_model(cmdl.estimator, 1, cmdl.hist_len, 24 | (action_no, cmdl.atoms_no), 25 | hidden_size=cmdl.hidden_size) 26 | if self.cmdl.cuda: 27 | self.policy.cuda() 28 | self.policy_evaluation = CategoricalPolicyEvaluation(policy, cmdl) 29 | print("[%s] Evaluating %s agent." % (self.name, cmdl.agent_type)) 30 | 31 | self.max_q = -1000 32 | 33 | def evaluate_policy(self, state): 34 | if self.epsilon < uniform(): 35 | qval, action = self.policy_evaluation.get_action(state) 36 | self.max_q = max(qval, self.max_q) 37 | return action 38 | else: 39 | return self.actions.sample() 40 | -------------------------------------------------------------------------------- /estimators/catch_net.py: -------------------------------------------------------------------------------- 1 | """ Neural Network architecture for low-dimensional games. 
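Two conv layers followed by two linear layers (`lin1`, `head`); when `out_size`
is an (action_no, atoms_no) tuple the head is reshaped into a per-action
softmax over the support atoms (Categorical DQN), otherwise it outputs plain
Q-values.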
2 | """ 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class CatchNet(nn.Module): 9 | def __init__(self, input_channels, hist_len, out_size, hidden_size=32): 10 | super(CatchNet, self).__init__() 11 | self.input_channels = input_channels 12 | self.hist_len = hist_len 13 | self.input_depth = hist_len * input_channels 14 | if type(out_size) is tuple: 15 | self.is_categorical = True 16 | self.action_no, self.atoms_no = out_size 17 | self.out_size = self.action_no * self.atoms_no 18 | else: 19 | self.is_categorical = False 20 | self.out_size = out_size 21 | self.hidden_size = hidden_size 22 | 23 | self.conv1 = nn.Conv2d(self.input_depth, 32, kernel_size=5, 24 | stride=2, padding=1) 25 | self.conv2 = nn.Conv2d(32, 32, kernel_size=5, stride=2) 26 | self.lin1 = nn.Linear(512, self.hidden_size) 27 | self.head = nn.Linear(self.hidden_size, self.out_size) 28 | 29 | def forward(self, x): 30 | x = F.relu(self.conv1(x)) 31 | x = F.relu(self.conv2(x)) 32 | x = F.relu(self.lin1(x.view(x.size(0), -1))) 33 | out = self.head(x.view(x.size(0), -1)) 34 | if self.is_categorical: 35 | splits = out.chunk(self.action_no, 1) 36 | return torch.stack(list(map(lambda s: F.softmax(s), splits)), 1) 37 | else: 38 | return out 39 | 40 | def get_attributes(self): 41 | return (self.input_channels, self.hist_len, self.action_no, 42 | self.hidden_size) 43 | -------------------------------------------------------------------------------- /estimators/atari_net.py: -------------------------------------------------------------------------------- 1 | """ Neural Network architecture for Atari games. 2 | """ 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class AtariNet(nn.Module): 9 | def __init__(self, input_channels, hist_len, out_size, hidden_size=256): 10 | super(AtariNet, self).__init__() 11 | self.input_channels = input_channels 12 | self.hist_len = hist_len 13 | self.input_depth = input_depth = hist_len * input_channels 14 | if type(out_size) is tuple: 15 | self.is_categorical = True 16 | self.action_no, self.atoms_no = out_size 17 | self.out_size = self.action_no * self.atoms_no 18 | else: 19 | self.is_categorical = False 20 | self.out_size = out_size 21 | self.hidden_size = hidden_size 22 | 23 | self.conv1 = nn.Conv2d(input_depth, 32, kernel_size=8, stride=4) 24 | self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2) 25 | self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1) 26 | self.lin1 = nn.Linear(64 * 7 * 7, self.hidden_size) 27 | self.head = nn.Linear(self.hidden_size, self.out_size) 28 | 29 | def forward(self, x): 30 | x = F.relu(self.conv1(x)) 31 | x = F.relu(self.conv2(x)) 32 | x = F.relu(self.conv3(x)) 33 | x = F.relu(self.lin1(x.view(x.size(0), -1))) 34 | out = self.head(x.view(x.size(0), -1)) 35 | if self.is_categorical: 36 | splits = out.chunk(self.action_no, 1) 37 | return torch.stack(list(map(lambda s: F.softmax(s), splits)), 1) 38 | else: 39 | return out 40 | 41 | def get_attributes(self): 42 | return (self.input_channels, self.hist_len, self.action_no, 43 | self.hidden_size) 44 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import gym 3 | import gym_fast_envs # noqa 4 | from termcolor import colored as clr 5 | 6 | from utils import EvaluationMonitor 7 | from utils import PreprocessFrames 8 | from utils import SqueezeRewards 9 | from utils import 
DoneAfterLostLife 10 | 11 | 12 | def env_factory(cmdl, mode): 13 | # Undo the default logger and configure a new one. 14 | gym.undo_logger_setup() 15 | logger = logging.getLogger() 16 | logger.setLevel(logging.WARNING) 17 | 18 | print(clr("[Main] Constructing %s environment." % mode, attrs=['bold'])) 19 | env = gym.make(cmdl.env_name) 20 | 21 | if hasattr(cmdl, 'rescale_dims'): 22 | state_dims = (cmdl.rescale_dims, cmdl.rescale_dims) 23 | else: 24 | state_dims = env.observation_space.shape[0:2] 25 | 26 | env_class, hist_len, cuda = cmdl.env_class, cmdl.hist_len, cmdl.cuda 27 | 28 | if mode == "training": 29 | env = PreprocessFrames(env, env_class, hist_len, state_dims, cuda) 30 | if hasattr(cmdl, 'reward_clamp') and cmdl.reward_clamp: 31 | env = SqueezeRewards(env) 32 | if hasattr(cmdl, 'done_after_lost_life') and cmdl.done_after_lost_life: 33 | env = DoneAfterLostLife(env) 34 | print('-' * 50) 35 | return env 36 | 37 | elif mode == "evaluation": 38 | if cmdl.eval_env_name != cmdl.env_name: 39 | print(clr("[%s] Warning! evaluating on a different env: %s" 40 | % ("Main", cmdl.eval_env_name), 'red', attrs=['bold'])) 41 | env = gym.make(cmdl.eval_env_name) 42 | 43 | env = PreprocessFrames(env, env_class, hist_len, state_dims, cuda) 44 | env = EvaluationMonitor(env, cmdl) 45 | print('-' * 50) 46 | return env 47 | 48 | 49 | def not_implemented(obj): 50 | import inspect 51 | method_name = inspect.stack()[1][3] 52 | raise RuntimeError( 53 | clr(("%s.%s not implemented nor delegated." % 54 | (obj.name, method_name)), 'white', 'on_red')) 55 | -------------------------------------------------------------------------------- /agents/categorical_dqn_agent.py: -------------------------------------------------------------------------------- 1 | from agents.dqn_agent import DQNAgent 2 | from estimators import get_estimator as get_model 3 | from policy_evaluation import CategoricalPolicyEvaluation 4 | from policy_improvement import CategoricalPolicyImprovement 5 | 6 | 7 | class CategoricalDQNAgent(DQNAgent): 8 | def __init__(self, action_space, cmdl): 9 | DQNAgent.__init__(self, action_space, cmdl) 10 | self.name = "Categorical_agent" 11 | self.cmdl = cmdl 12 | 13 | hist_len, action_no = cmdl.hist_len, self.action_no 14 | self.policy = policy = get_model(cmdl.estimator, 1, hist_len, 15 | (action_no, cmdl.atoms_no), 16 | hidden_size=cmdl.hidden_size) 17 | self.target = target = get_model(cmdl.estimator, 1, hist_len, 18 | (action_no, cmdl.atoms_no), 19 | hidden_size=cmdl.hidden_size) 20 | if self.cmdl.cuda: 21 | self.policy.cuda() 22 | self.target.cuda() 23 | 24 | self.policy_evaluation = CategoricalPolicyEvaluation(policy, cmdl) 25 | self.policy_improvement = CategoricalPolicyImprovement( 26 | policy, target, cmdl) 27 | 28 | def improve_policy(self, _s, _a, r, s, done): 29 | h = self.cmdl.hist_len - 1 30 | self.replay_memory.push(_s[0, h], _a, r, done) 31 | 32 | if len(self.replay_memory) < self.cmdl.start_learning_after: 33 | return 34 | 35 | if (self.step_cnt % self.cmdl.update_freq == 0) and ( 36 | len(self.replay_memory) > self.cmdl.batch_size): 37 | 38 | # get batch of transitions 39 | batch = self.replay_memory.sample() 40 | 41 | # compute gradients 42 | self.policy_improvement.accumulate_gradient(*batch) 43 | self.policy_improvement.update_model() 44 | 45 | if self.step_cnt % self.cmdl.target_update_freq == 0: 46 | self.policy_improvement.update_target_net() 47 | 48 | def display_model_stats(self): 49 | self.policy_improvement.get_model_stats() 50 | print("MaxQ=%2.2f. MemSz=%5d. Epsilon=%.2f." 
% ( 51 | self.max_q, len(self.replay_memory), self.epsilon)) 52 | -------------------------------------------------------------------------------- /categorical.yml: -------------------------------------------------------------------------------- 1 | name: atari 2 | channels: 3 | - soumith 4 | - defaults 5 | dependencies: 6 | - cffi=1.10.0=py36_0 7 | - cycler=0.10.0=py36_0 8 | - dbus=1.10.20=0 9 | - decorator=4.0.11=py36_0 10 | - expat=2.1.0=0 11 | - fontconfig=2.12.1=3 12 | - freetype=2.5.5=2 13 | - glib=2.50.2=1 14 | - gst-plugins-base=1.8.0=0 15 | - gstreamer=1.8.0=0 16 | - icu=54.1=0 17 | - ipython=6.1.0=py36_0 18 | - ipython_genutils=0.2.0=py36_0 19 | - jbig=2.1=0 20 | - jedi=0.10.2=py36_2 21 | - jpeg=9b=0 22 | - jsonschema=2.6.0=py36_0 23 | - jupyter_core=4.3.0=py36_0 24 | - libffi=3.2.1=1 25 | - libgcc=5.2.0=0 26 | - libgfortran=3.0.0=1 27 | - libiconv=1.14=0 28 | - libpng=1.6.27=0 29 | - libtiff=4.0.6=3 30 | - libxcb=1.12=1 31 | - libxml2=2.9.4=0 32 | - matplotlib=2.0.2=np113py36_0 33 | - mkl=2017.0.3=0 34 | - nbformat=4.3.0=py36_0 35 | - numpy=1.13.1=py36_0 36 | - olefile=0.44=py36_0 37 | - openssl=1.0.2l=0 38 | - pandas=0.20.2=np113py36_0 39 | - path.py=10.3.1=py36_0 40 | - pcre=8.39=1 41 | - pexpect=4.2.1=py36_0 42 | - pickleshare=0.7.4=py36_0 43 | - pillow=4.2.1=py36_0 44 | - pip=9.0.1=py36_1 45 | - prompt_toolkit=1.0.14=py36_0 46 | - ptyprocess=0.5.1=py36_0 47 | - pycparser=2.17=py36_0 48 | - pygments=2.2.0=py36_0 49 | - pyparsing=2.1.4=py36_0 50 | - pyqt=5.6.0=py36_2 51 | - python=3.6.2=0 52 | - python-dateutil=2.6.0=py36_0 53 | - pytz=2017.2=py36_0 54 | - qt=5.6.2=4 55 | - readline=6.2=2 56 | - requests=2.14.2=py36_0 57 | - scipy=0.19.1=np113py36_0 58 | - seaborn=0.7.1=py36_0 59 | - setuptools=27.2.0=py36_0 60 | - simplegeneric=0.8.1=py36_1 61 | - sip=4.18=py36_0 62 | - six=1.10.0=py36_0 63 | - sqlite=3.13.0=0 64 | - termcolor=1.1.0=py36_0 65 | - tk=8.5.18=0 66 | - traitlets=4.3.2=py36_0 67 | - wcwidth=0.1.7=py36_0 68 | - wheel=0.29.0=py36_0 69 | - xz=5.2.2=1 70 | - zlib=1.2.8=3 71 | - cuda80=1.0=0 72 | - pytorch=0.1.12=py36_2cu80 73 | - torchvision=0.1.8=py36_2 74 | - pip: 75 | - atari-py==0.1.1 76 | - certifi==2017.4.17 77 | - chardet==3.0.4 78 | - gym (/home/florin/Tools/pip-packages/gym)==0.9.2 79 | - gym-fast-envs (/home/florin/Tools/pip-packages/gym_fast_envs)==0.1 80 | - idna==2.5 81 | - ipython-genutils==0.2.0 82 | - jupyter-core==4.3.0 83 | - mccabe==0.6.1 84 | - prompt-toolkit==1.0.14 85 | - pycodestyle==2.3.1 86 | - pyflakes==1.5.0 87 | - pyglet==1.2.4 88 | - pyopengl==3.1.0 89 | - pyyaml==3.12 90 | - torch==0.1.12.post2 91 | - urllib3==1.21.1 92 | - visdom 93 | prefix: /home/florin/Tools/anaconda3/envs/atari 94 | 95 | -------------------------------------------------------------------------------- /utils/parse_config.py: -------------------------------------------------------------------------------- 1 | """ Functions and classes for parsing config files and command line arguments. 2 | """ 3 | import argparse 4 | import yaml 5 | import os 6 | from termcolor import colored as clr 7 | 8 | 9 | def parse_cmd_args(): 10 | """ Return parsed command line arguments. 
11 | """ 12 | p = argparse.ArgumentParser(description='') 13 | p.add_argument('-l', '--label', type=str, default="default_label", 14 | metavar='label_name::str', 15 | help='Label of the current experiment') 16 | p.add_argument('-id', '--id', type=int, default=0, 17 | metavar='label_name::str', 18 | help='Id of this instance running within the current' + 19 | 'experiment') 20 | p.add_argument('-cf', '--config', type=str, default="catch_dev", 21 | metavar='path::str', 22 | help='Path to the config file.') 23 | p.add_argument('-r', '--results', type=str, default="./experiments", 24 | metavar='path::str', 25 | help='Path of the results folder.') 26 | args = p.parse_args() 27 | return args 28 | 29 | 30 | def to_namespace(d): 31 | """ Convert a dict to a namespace. 32 | """ 33 | n = argparse.Namespace() 34 | for k, v in d.items(): 35 | setattr(n, k, to_namespace(v) if isinstance(v, dict) else v) 36 | return n 37 | 38 | 39 | def inject_args(n, args): 40 | # inject some of the cmdl args into the config namespace 41 | setattr(n, "experiment_id", args.id) 42 | setattr(n, "results_path", args.results) 43 | return n 44 | 45 | 46 | def check_paths(cmdl): 47 | if not os.path.exists(cmdl.results_path): 48 | print( 49 | clr("%s path for saving results does not exist. Please create it." 50 | % cmdl.results_path, 'red', attrs=['bold'])) 51 | raise IOError 52 | else: 53 | print(clr("Warning, data in %s will be overwritten." 54 | % cmdl.results_path, 'red', attrs=['bold'])) 55 | 56 | 57 | def parse_config_file(path): 58 | f = open(path) 59 | config_data = yaml.load(f, Loader=yaml.SafeLoader) 60 | f.close() 61 | return to_namespace(config_data) 62 | 63 | 64 | def get_config(): 65 | args = parse_cmd_args() 66 | cmdl = parse_config_file(args.config) 67 | cmdl = inject_args(cmdl, args) 68 | check_paths(cmdl) 69 | return cmdl 70 | -------------------------------------------------------------------------------- /policy_evaluation/exploration_schedules.py: -------------------------------------------------------------------------------- 1 | """ Various exploration schedules. 2 | 3 | * constant_schedule(value) 4 | constant_schedule(.1) => .1, .1, .1, .1, .1, ... 5 | 6 | * linear_schedule(start, end, steps_no) 7 | linear_schedule(.5, .1, 5) => .5, .4, .3, .2, .1, .1, .1, .1, ... 8 | 9 | * log_schedule(start, end, steps_no) 10 | log_schedule(1, 0.001, 3) => 1., .1, .01, .001, .001, .001, ... 11 | """ 12 | 13 | import itertools 14 | 15 | 16 | def float_range(start, end, step): 17 | x = start 18 | if step > 0: 19 | while x < end: 20 | yield x 21 | x += step 22 | else: 23 | while x > end: 24 | yield x 25 | x += step 26 | 27 | 28 | def constant_schedule(epsilon=0.05): 29 | return itertools.repeat(epsilon) 30 | 31 | 32 | def linear_schedule(start, end, steps_no): 33 | start, end, steps_no = float(start), float(end), float(steps_no) 34 | step = (end - start) / (steps_no - 1.) 
35 | schedules = [float_range(start, end, step), itertools.repeat(end)] 36 | return itertools.chain(*schedules) 37 | 38 | 39 | def log_schedule(start, end, steps_no): 40 | from math import log, exp 41 | log_start, log_end = log(start), log(end) 42 | log_step = (log_end - log_start) / (steps_no - 1) 43 | log_range = float_range(log_start, log_end, log_step) 44 | return itertools.chain(map(exp, log_range), itertools.repeat(end)) 45 | 46 | 47 | ALL_SCHEDULES = { 48 | "constant": constant_schedule, 49 | "linear": linear_schedule, 50 | "log": log_schedule 51 | } 52 | 53 | 54 | def get_schedule(name, *args): 55 | return ALL_SCHEDULES[name](*args) 56 | 57 | 58 | def get_random_schedule(args, probs): 59 | assert(len(args) == len(probs)) 60 | import numpy as np 61 | return get_schedule(*args[np.random.choice(len(args), p=probs)]) 62 | 63 | 64 | if __name__ == "__main__": 65 | import sys 66 | 67 | const = get_schedule("constant", [.1]) 68 | sys.stdout.write("Constant(0.1):") 69 | for _ in range(10): 70 | sys.stdout.write(" {:.2f}".format(next(const))) 71 | sys.stdout.write("\n") 72 | 73 | linear = get_schedule("linear", [.5, .1, 5]) 74 | sys.stdout.write("Linear Schedule(.5, .1, 5):") 75 | for _ in range(10): 76 | sys.stdout.write(" {:.2f}".format(next(linear))) 77 | sys.stdout.write("\n") 78 | 79 | logarithmic = get_schedule("log", [1, .001, 4]) 80 | sys.stdout.write("Logarithmic Schedule(1, .001, 4):") 81 | for _ in range(10): 82 | sys.stdout.write(" {:.3f}".format(next(logarithmic))) 83 | sys.stdout.write("\n") 84 | -------------------------------------------------------------------------------- /agents/base_agent.py: -------------------------------------------------------------------------------- 1 | import time 2 | from termcolor import colored as clr 3 | from utils import not_implemented 4 | 5 | 6 | class BaseAgent(object): 7 | def __init__(self, env_space): 8 | self.actions = env_space[0] 9 | self.action_no = self.actions.n 10 | self.state_dims = env_space[1].shape[0:2] 11 | 12 | self.step_cnt = 0 13 | self.ep_cnt = 0 14 | self.ep_reward_cnt = 0 15 | self.ep_reward = [] 16 | self.max_mean_rw = -100 17 | 18 | def evaluate_policy(self, obs): 19 | not_implemented(self) 20 | 21 | def improve_policy(self, _state, _action, reward, state, done): 22 | not_implemented(self) 23 | 24 | def gather_stats(self, reward, done): 25 | self.step_cnt += 1 26 | self.ep_reward_cnt += reward 27 | if done: 28 | self.ep_cnt += 1 29 | self.ep_reward.append(self.ep_reward_cnt) 30 | self.ep_reward_cnt = 0 31 | 32 | def display_setup(self, env, config): 33 | emph = ["env_name", "agent_type", "label", "batch_size", "lr", 34 | "hist_len"] 35 | print("-------------------------------------------------") 36 | for k in config.__dict__: 37 | if config.__dict__[k] is not None: 38 | v = config.__dict__[k] 39 | space = "." * (32 - len(k)) 40 | config_line = "%s: %s %s" % (k, space, v) 41 | for e in emph: 42 | if k == e: 43 | config_line = clr(config_line, attrs=['bold']) 44 | print(config_line) 45 | print("-------------------------------------------------") 46 | custom = {"no_of_actions": self.action_no} 47 | for k, v in custom.items(): 48 | space = "." 
* (32 - len(k)) 49 | print("%s: %s %s" % (k, space, v)) 50 | print("-------------------------------------------------") 51 | 52 | def display_stats(self, start_time): 53 | fps = self.cmdl.report_frequency / (time.perf_counter() - start_time) 54 | 55 | print(clr("[%s] step=%7d, fps=%.2f " % (self.name, self.step_cnt, fps), 56 | attrs=['bold'])) 57 | self.ep_reward.clear() 58 | 59 | def display_final_report(self, ep_cnt, step_cnt, global_time): 60 | elapsed_time = time.perf_counter() - global_time 61 | fps = step_cnt / elapsed_time 62 | print(clr("[ %s ] finished after %d eps, %d steps. " 63 | % ("Main", ep_cnt, step_cnt), 'white', 'on_grey')) 64 | print(clr("[ %s ] finished after %.2fs, %.2ffps. " 65 | % ("Main", elapsed_time, fps), 'white', 'on_grey')) 66 | 67 | def display_model_stats(self): 68 | pass 69 | -------------------------------------------------------------------------------- /agents/dqn_agent.py: -------------------------------------------------------------------------------- 1 | from numpy.random import uniform 2 | from agents.base_agent import BaseAgent 3 | from estimators import get_estimator as get_model 4 | from policy_evaluation import DeterministicPolicy as DQNEvaluation 5 | from policy_evaluation import get_schedule as get_epsilon_schedule 6 | from policy_improvement import DQNPolicyImprovement as DQNImprovement 7 | from data_structures import ExperienceReplay 8 | from utils import TorchTypes 9 | 10 | 11 | class DQNAgent(BaseAgent): 12 | def __init__(self, env_space, cmdl): 13 | BaseAgent.__init__(self, env_space) 14 | self.name = "DQN_agent" 15 | self.cmdl = cmdl 16 | eps = self.cmdl.epsilon 17 | e_steps = self.cmdl.epsilon_steps 18 | 19 | self.policy = policy = get_model(cmdl.estimator, 1, cmdl.hist_len, 20 | self.action_no, cmdl.hidden_size) 21 | self.target = target = get_model(cmdl.estimator, 1, cmdl.hist_len, 22 | self.action_no, cmdl.hidden_size) 23 | if self.cmdl.cuda: 24 | self.policy.cuda() 25 | self.target.cuda() 26 | self.policy_evaluation = DQNEvaluation(policy) 27 | self.policy_improvement = DQNImprovement(policy, target, cmdl) 28 | 29 | self.exploration = get_epsilon_schedule("linear", eps, 0.05, e_steps) 30 | self.replay_memory = ExperienceReplay.factory(cmdl, self.state_dims) 31 | 32 | self.dtype = TorchTypes(cmdl.cuda) 33 | self.max_q = -1000 34 | 35 | def evaluate_policy(self, state): 36 | self.epsilon = next(self.exploration) 37 | if self.epsilon < uniform(): 38 | qval, action = self.policy_evaluation.get_action(state) 39 | self.max_q = max(qval, self.max_q) 40 | return action 41 | else: 42 | return self.actions.sample() 43 | 44 | def improve_policy(self, _s, _a, r, s, done): 45 | h = self.cmdl.hist_len - 1 46 | self.replay_memory.push(_s[0, h], _a, r, done) 47 | 48 | if len(self.replay_memory) < self.cmdl.start_learning_after: 49 | return 50 | 51 | if (self.step_cnt % self.cmdl.update_freq == 0) and ( 52 | len(self.replay_memory) > self.cmdl.batch_size): 53 | 54 | # get batch of transitions 55 | batch = self.replay_memory.sample() 56 | 57 | # compute gradients 58 | self.policy_improvement.accumulate_gradient(*batch) 59 | self.policy_improvement.update_model() 60 | 61 | if self.step_cnt % self.cmdl.target_update_freq == 0: 62 | self.policy_improvement.update_target_net() 63 | 64 | def display_model_stats(self): 65 | self.policy_improvement.get_model_stats() 66 | print("MaxQ=%2.2f. MemSz=%5d. Epsilon=%.2f." 
% ( 67 | self.max_q, len(self.replay_memory), self.epsilon)) 68 | self.max_q = -1000 69 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import gc, time # noqa 2 | import gym, gym_fast_envs # noqa 3 | import torch, numpy # noqa 4 | 5 | import utils 6 | from agents import get_agent 7 | 8 | 9 | def train_agent(cmdl): 10 | global_time = time.perf_counter() 11 | 12 | env = utils.env_factory(cmdl, "training") 13 | eval_env = utils.env_factory(cmdl, "evaluation") 14 | 15 | name = cmdl.agent_type 16 | env_space = (env.action_space, env.observation_space) 17 | agent = get_agent(name)(env_space, cmdl) 18 | eval_env_space = (env.action_space, env.observation_space) 19 | eval_agent = get_agent("evaluation")(eval_env_space, cmdl) 20 | 21 | agent.display_setup(env, cmdl) 22 | 23 | ep_cnt = 1 24 | fps_time = time.perf_counter() 25 | 26 | s, r, done = env.reset(), 0, False 27 | 28 | for step_cnt in range(cmdl.training_steps): 29 | a = agent.evaluate_policy(s) 30 | _s, _a = s.clone(), a 31 | s, r, done, _ = env.step(a) 32 | agent.improve_policy(_s, _a, r, s, done) 33 | 34 | step_cnt += 1 35 | agent.gather_stats(r, done) 36 | 37 | # Do some reporting 38 | if step_cnt != 0 and step_cnt % cmdl.report_frequency == 0: 39 | agent.display_stats(fps_time) 40 | agent.display_model_stats() 41 | fps_time = time.perf_counter() 42 | gc.collect() 43 | 44 | # Start doing an evaluation 45 | eval_ready = step_cnt >= cmdl.eval_start 46 | if eval_ready and (step_cnt % cmdl.eval_frequency == 0): 47 | eval_time = time.perf_counter() 48 | evaluate_agent(step_cnt, eval_env, eval_agent, 49 | agent.policy, cmdl) 50 | gc.collect() 51 | fps_time = fps_time + (time.perf_counter() - eval_time) 52 | 53 | if done: 54 | ep_cnt += 1 55 | s, r, done = env.reset(), 0, False 56 | 57 | agent.display_final_report(ep_cnt, step_cnt, global_time) 58 | 59 | 60 | def evaluate_agent(crt_training_step, eval_env, eval_agent, policy, cmdl): 61 | print("[Evaluator] starting @ %d training steps:" % crt_training_step) 62 | agent = eval_agent 63 | 64 | eval_env.get_crt_step(crt_training_step) 65 | # need to change this 66 | agent.policy_evaluation.policy.load_state_dict(policy.state_dict()) 67 | 68 | step_cnt = 0 69 | s, r, done = eval_env.reset(), 0, False 70 | while step_cnt < cmdl.eval_steps: 71 | a = agent.evaluate_policy(s) 72 | s, r, done, _ = eval_env.step(a) 73 | step_cnt += 1 74 | if done: 75 | s, r, done = eval_env.reset(), 0, False 76 | 77 | 78 | if __name__ == "__main__": 79 | 80 | # Parse cmdl args for the config file and return config as Namespace 81 | config = utils.get_config() 82 | 83 | # Assuming everything in the config is deterministic already. 84 | torch.manual_seed(config.seed) 85 | numpy.random.seed(config.seed) 86 | torch.set_num_threads(4) 87 | 88 | # Let's do this! 
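    # Typical invocations (see Readme.md):
    #   python main.py -cf configs/catch_dqn.yaml
    #   python main.py -cf configs/catch_categorical.yaml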
89 | train_agent(config) 90 | -------------------------------------------------------------------------------- /data_structures/tensor_experience_replay.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import TorchTypes 3 | 4 | 5 | class TensorCircularBuffer(object): 6 | def __init__(self, capacity, hist_len, state_dims, cuda): 7 | self.capacity = capacity 8 | self.state_dims = state_dims 9 | self.dtype = TorchTypes(cuda) 10 | 11 | self.position = 0 12 | # we won't be initializing the full memory on the CPU by default 13 | self.memory = { 14 | "_state": torch.ByteTensor(capacity, 1, *state_dims).fill_(0), 15 | "_action": torch.LongTensor(capacity, 1).fill_(0), 16 | "reward": torch.FloatTensor(capacity, 1).fill_(0), 17 | "done": torch.ByteTensor(capacity, 1).fill_(0) 18 | } 19 | self.full_idx = -1 20 | print("[Experience Replay] Done allocating main memory.") 21 | 22 | def push(self, _s, _a, r, d): 23 | idx = self.position 24 | _s = _s.unsqueeze(0) # (24, 24) -> (1, 24, 24) 25 | self.memory["_state"][idx] = _s * 255 26 | self.memory["_action"][idx, 0] = _a 27 | self.memory["reward"][idx, 0] = r 28 | self.memory["done"][idx, 0] = 0 if d else 1 29 | 30 | self.position = (self.position + 1) % self.capacity 31 | if self.full_idx < (self.capacity - 2): 32 | self.full_idx += 1 33 | 34 | def __len__(self): 35 | return self.full_idx 36 | 37 | 38 | class TensorExperienceReplay(TensorCircularBuffer): 39 | def __init__(self, capacity, batch_size, hist_len, state_dims, cuda): 40 | TensorCircularBuffer.__init__(self, capacity, hist_len, state_dims, 41 | cuda) 42 | self.hist_len = hist_len 43 | self.batch_size = batch_size 44 | batch_state_dims = (batch_size, hist_len, *state_dims) 45 | dtype = self.dtype 46 | 47 | self._states = dtype.FT(*batch_state_dims).fill_(0) 48 | self._actions = dtype.LT(batch_size, 1).fill_(0) 49 | self.states = dtype.FT(*batch_state_dims).fill_(0) 50 | self.rewards = dtype.FT(batch_size, 1).fill_(0) 51 | self.done = dtype.BT(batch_size, 1).fill_(0) 52 | print("[Experience Replay] Done allocating cuda batch.") 53 | 54 | def sample(self): 55 | batch_sz = self.batch_size 56 | memory = self.memory 57 | h = self.hist_len 58 | 59 | idxs = torch.LongTensor(batch_sz).random_(h, self.full_idx - 1) 60 | 61 | # need to figure out how to use idx directly 62 | for i in range(batch_sz): 63 | idx = idxs[i] 64 | self._states[i] = memory["_state"][idx-h:idx].float() / 255 65 | self.states[i] = memory["_state"][(idx-h)+1:idx+1].float() / 255 66 | self._actions[i] = memory["_action"][idx-1] 67 | self.rewards[i] = memory["reward"][idx-1] 68 | self.done[i] = memory["done"][idx-1] 69 | 70 | return [batch_sz, self._states, self._actions, 71 | self.rewards, self.states, self.done] 72 | 73 | """ 74 | # This ain't faster. 
75 | # Need to find a better solution for indexing 76 | def sample(self): 77 | batch_sz = self.batch_size 78 | memory = self.memory 79 | h = self.hist_len 80 | dtype = self.dtype 81 | 82 | idxs = list(torch.LongTensor(batch_sz).random_(h, self.full_idx - 1)) 83 | 84 | s_idxs = [ix - j for ix in idxs for j in range(h)] 85 | ns_idxs = [(ix+1) - j for ix in idxs for j in range(h)] 86 | 87 | stx = torch.LongTensor(s_idxs).unsqueeze(1).unsqueeze(1).unsqueeze(1) 88 | nstx = torch.LongTensor(ns_idxs).unsqueeze(1).unsqueeze(1).unsqueeze(1) 89 | idxs = torch.LongTensor(idxs).unsqueeze(1) 90 | 91 | stx = stx.expand(len(stx), 1, *self.state_dims) 92 | nstx = stx.expand(len(nstx), 1, *self.state_dims) 93 | 94 | _states = (memory["_state"].gather(0, stx).float() / 255).view( 95 | batch_sz, h, *self.state_dims).type(dtype.FT) 96 | states = (memory["_state"].gather(0, nstx).float() / 255).view( 97 | batch_sz, h, *self.state_dims).type(dtype.FT) 98 | _actions = memory["_action"].gather(0, idxs).type(dtype.LT) 99 | rewards = memory["reward"].gather(0, idxs).type(dtype.FT) 100 | done = memory["done"].gather(0, idxs).type(dtype.BT) 101 | 102 | return [batch_sz, _states, _actions, rewards, states, done] 103 | """ 104 | -------------------------------------------------------------------------------- /policy_improvement/dqn_update.py: -------------------------------------------------------------------------------- 1 | """ Deep Q-Learning policy improvement. 2 | """ 3 | import torch 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | from termcolor import colored as clr 7 | from utils import TorchTypes 8 | from policy_improvement.optim_utils import optim_factory, lr_schedule 9 | 10 | 11 | class DQNPolicyImprovement(object): 12 | """ Deep Q-Learning training method. """ 13 | 14 | def __init__(self, policy, target_policy, cmdl): 15 | self.name = "DQN-PI" 16 | self.cmdl = cmdl 17 | self.policy = policy 18 | self.target_policy = target_policy 19 | self.lr = cmdl.lr 20 | self.gamma = cmdl.gamma 21 | 22 | self.optimizer = optim_factory(self.policy.parameters(), cmdl) 23 | self.optimizer.zero_grad() 24 | self.lr_generator = lr_schedule(cmdl.lr, 0.00001, cmdl.training_steps) 25 | 26 | self.dtype = TorchTypes(cmdl.cuda) 27 | 28 | def accumulate_gradient(self, batch_sz, states, actions, rewards, 29 | next_states, mask): 30 | """ Compute the temporal difference error. 
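        Q(s, a) comes from the online `policy` network, the bootstrap term
        from `target_policy`, and the error is passed through a Huber
        (smooth L1) loss before the gradients are accumulated: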
31 | td_error = (r + gamma * max Q(s_,a)) - Q(s,a) 32 | """ 33 | states = Variable(states) 34 | actions = Variable(actions) 35 | rewards = Variable(rewards.squeeze()) 36 | next_states = Variable(next_states, volatile=True) 37 | 38 | # Compute Q(s, a) 39 | q_values = self.policy(states) 40 | q_values = q_values.gather(1, actions) 41 | 42 | # Compute Q(s_, a) 43 | q_target_values = Variable(torch.zeros(batch_sz).type(self.dtype.FT)) 44 | 45 | # Bootstrap for non-terminal states 46 | q_target_values[mask] = self.target_policy(next_states).max( 47 | 1, keepdim=True)[0][mask] 48 | q_target_values.volatile = False # So we don't mess the huber loss 49 | expected_q_values = (q_target_values * self.gamma) + rewards 50 | 51 | # Compute Huber loss 52 | loss = F.smooth_l1_loss(q_values, expected_q_values) 53 | 54 | # Accumulate gradients 55 | loss.backward() 56 | 57 | def update_model(self): 58 | if self.cmdl.optim == "RMSprop": 59 | lr = next(self.lr_generator) 60 | for param_group in self.optimizer.param_groups: 61 | param_group['lr'] = lr 62 | self.optimizer.step() 63 | self.optimizer.zero_grad() 64 | 65 | def update_target_net(self): 66 | """ Update the target net with the parameters in the online model.""" 67 | self.target_policy.load_state_dict(self.policy.state_dict()) 68 | 69 | def get_model_stats(self): 70 | param_abs_mean = 0 71 | grad_abs_mean = 0 72 | t_param_abs_mean = 0 73 | n_params = 0 74 | for p in self.policy.parameters(): 75 | param_abs_mean += p.data.abs().sum() 76 | grad_abs_mean += p.grad.data.abs().sum() 77 | n_params += p.data.nelement() 78 | for t in self.target_policy.parameters(): 79 | t_param_abs_mean += t.data.abs().sum() 80 | 81 | print("Wm: %.9f | Gm: %.9f | Tm: %.9f" % ( 82 | param_abs_mean / n_params, 83 | grad_abs_mean / n_params, 84 | t_param_abs_mean / n_params)) 85 | 86 | def _debug_transitions(self, mask, reward_batch): 87 | if mask[0, 0] == 0: 88 | r = reward_batch.data[0, 0] 89 | if r == 1.0: 90 | print(r) 91 | 92 | def _debug_states(self, state_batch, next_state_batch, mask, target): 93 | batch_idx = 23 94 | for k in range(state_batch.size(1)): 95 | for i in range(24): 96 | for j in range(24): 97 | px = state_batch[batch_idx, k, i, j] 98 | if px < 0.90: 99 | print(clr("%.2f " % px, 'magenta'), end="") 100 | else: 101 | print(("%.2f " % px), end="") 102 | print() 103 | print() 104 | print("************ NEXT STATE *********************") 105 | for v in range(next_state_batch.size(1)): 106 | for i in range(24): 107 | for j in range(24): 108 | px = next_state_batch[batch_idx, v, i, j] 109 | if px < 0.90: 110 | print(clr("%.2f " % px, 'magenta'), end="") 111 | else: 112 | print(clr("%.2f " % px, 'white'), end="") 113 | print() 114 | print() 115 | if mask[batch_idx, 0] == 0: 116 | print(clr("Done batch ............", 'magenta')) 117 | print(target[batch_idx]) 118 | else: 119 | print(".......................") 120 | -------------------------------------------------------------------------------- /data_structures/ntuple_experience_replay.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import namedtuple 3 | from utils import TorchTypes 4 | 5 | 6 | Transition = namedtuple('Transition', 7 | ('state', 'action', 'reward', 'done')) 8 | BatchTransition = namedtuple('BatchTransition', 9 | ('state', 'action', 'reward', 'state_', 'done')) 10 | 11 | 12 | class CircularBuffer(object): 13 | def __init__(self, capacity=100000): 14 | self.capacity = capacity 15 | self.memory = [] 16 | self.position = 0 17 | 
self.fill_idx = -1 18 | 19 | def push(self, s, a, r, d): 20 | s = s.unsqueeze(0).unsqueeze(0) # [24 24] --> [1, 1, 24, 24] 21 | if len(self.memory) < self.capacity: 22 | self.memory.append(Transition((s * 255).byte(), a, r, d)) 23 | self.fill_idx += 1 24 | else: 25 | self.memory[self.position] = Transition((s * 255).byte(), a, r, d) 26 | self.position = (self.position + 1) % self.capacity 27 | 28 | def get_batch(self): 29 | return self.memory[:self.position] 30 | 31 | def reset(self): 32 | self.memory.clear() 33 | self.position = 0 34 | 35 | def __len__(self): 36 | return len(self.memory) 37 | 38 | 39 | class nTupleExperienceReplay(CircularBuffer): 40 | def __init__(self, capacity, batch_size, hist_len, cuda): 41 | CircularBuffer.__init__(self, capacity) 42 | self.batch_size = batch_size 43 | self.dtype = TorchTypes(cuda) 44 | self.hist_len = hist_len 45 | 46 | def sample(self, batch_size=None): 47 | batch_size = self.batch_size if batch_size is None else batch_size 48 | return self._sample(batch_size) 49 | 50 | def _sample(self, batch_size=None): 51 | fidx = self.fill_idx - 1 # we can only index up to capacity - 2 52 | hist_len = self.hist_len 53 | mem = self.memory 54 | 55 | # sample batch_size indices 56 | idxs = torch.LongTensor(batch_size).random_(hist_len, fidx) 57 | 58 | # retrieve a list of ((hist_len + 1 transitions) * batch_size) 59 | samples = [mem[idxs[i]-hist_len:idxs[i]+1] for i in range(batch_size)] 60 | 61 | # concatenate frames for s and s_ 62 | # and create a new list of transitions (s, a, r_, s_, d_) 63 | transitions = [BatchTransition( 64 | torch.cat([samples[j][i].state for i in range(hist_len)], 1), 65 | samples[j][hist_len-1].action, # after idx of s 66 | samples[j][hist_len-1].reward, # after idx of s 67 | torch.cat([samples[j][i].state for i in range(1, hist_len+1)], 1), 68 | samples[j][hist_len-1].done) for j in range(batch_size)] 69 | 70 | return self._batch2torch(transitions, self.batch_size) 71 | 72 | def _batch2torch(self, batch, batch_size): 73 | """ List of transitions -> Batch of transitions -> pytorch tensors. 74 | 75 | Returns: 76 | states: torch.size([batch_size, hist_len, w, h]) 77 | a/r/d: torch.size([batch_size, 1]) 78 | """ 79 | # check-out pytorch dqn tutorial. 80 | # (t1, t2, ... tn) -> t((s1, s2, ..., sn), (a1, a2, ... an) ...) 
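        # zip(*batch) transposes the list of transitions into one tuple per
        # field: [(s1, a1, r1, s1_, d1), (s2, a2, r2, s2_, d2)] becomes
        # ((s1, s2), (a1, a2), (r1, r2), (s1_, s2_), (d1, d2)).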
81 | batch = BatchTransition(*zip(*batch)) 82 | 83 | # lists to tensors 84 | state_batch = torch.cat(batch.state, 0).type(self.dtype.FT) / 255 85 | action_batch = self.dtype.LT(batch.action).unsqueeze(1) 86 | reward_batch = self.dtype.FT(batch.reward).unsqueeze(1) 87 | next_state_batch = torch.cat(batch.state_, 0).type(self.dtype.FT) / 255 88 | # [False, False, True, False] -> [1, 1, 0, 1]::ByteTensor 89 | mask = 1 - self.dtype.BT(batch.done).unsqueeze(1) 90 | 91 | return [batch_size, state_batch, action_batch, reward_batch, 92 | next_state_batch, mask] 93 | 94 | 95 | class CachedExperienceReplay(nTupleExperienceReplay): 96 | def __init__(self, capacity, batch_size, hist_len, cuda, cached_batches): 97 | nTupleExperienceReplay.__init__(self, capacity, batch_size, hist_len, 98 | cuda) 99 | 100 | self.cached_batches = cached_batches # no of cached batches 101 | self.cache_size = cached_batches * batch_size 102 | self.sample_idx = 0 103 | 104 | def sample(self): 105 | if self.sample_idx % self.cached_batches == 0: 106 | self._fill_cache() 107 | self.sample_idx = 0 108 | cache = self._sample_from_cache(self.sample_idx) 109 | self.sample_idx += 1 110 | return cache 111 | 112 | def _fill_cache(self): 113 | sz = self.cache_size 114 | _, self.cs, self.ca, self.cr, self.cns, self.cd = self._sample(sz) 115 | 116 | def _sample_from_cache(self, batch_idx): 117 | batch_sz = self.batch_size 118 | sidx = batch_sz * batch_idx 119 | eidx = sidx + batch_sz 120 | return [ 121 | batch_sz, 122 | self.cs[sidx:eidx], 123 | self.ca[sidx:eidx], 124 | self.cr[sidx:eidx], 125 | self.cns[sidx:eidx], 126 | self.cd[sidx:eidx] 127 | ] 128 | -------------------------------------------------------------------------------- /policy_improvement/categorical_update.py: -------------------------------------------------------------------------------- 1 | """ Categorical DQN policy improvement. 2 | """ 3 | import torch 4 | from torch.autograd import Variable 5 | # import torch.nn.functional as F 6 | from termcolor import colored as clr 7 | from utils import TorchTypes 8 | from policy_improvement.optim_utils import optim_factory, lr_schedule 9 | 10 | 11 | class CategoricalPolicyImprovement(object): 12 | """ Deep Q-Learning training method. """ 13 | 14 | def __init__(self, policy, target_policy, cmdl): 15 | self.name = "Categorical-PI" 16 | self.cmdl = cmdl 17 | self.policy = policy 18 | self.target_policy = target_policy 19 | self.lr = cmdl.lr 20 | self.gamma = cmdl.gamma 21 | 22 | self.optimizer = optim_factory(self.policy.parameters(), cmdl) 23 | self.optimizer.zero_grad() 24 | self.lr_generator = lr_schedule(cmdl.lr, 0.00001, cmdl.training_steps) 25 | 26 | self.dtype = dtype = TorchTypes(cmdl.cuda) 27 | self.v_min, self.v_max = v_min, v_max = cmdl.v_min, cmdl.v_max 28 | self.atoms_no = atoms_no = cmdl.atoms_no 29 | self.support = torch.linspace(v_min, v_max, atoms_no) 30 | self.support = self.support.type(dtype.FT) 31 | self.delta_z = (cmdl.v_max - cmdl.v_min) / (cmdl.atoms_no - 1) 32 | self.m = torch.zeros(cmdl.batch_size, self.atoms_no).type(dtype.FT) 33 | 34 | def accumulate_gradient(self, batch_sz, states, actions, rewards, 35 | next_states, mask): 36 | """ Compute the difference between the return distributions of Q(s,a) 37 | and TQ(s_,a). 
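        The target distribution phi(TZ(s_, a*)) is built in _get_categorical
        by projecting r + gamma * z_i back onto the fixed support, and the
        loss is the cross-entropy  -sum_i phi_i * log p_i(s, a).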
38 |         """
39 |         states = Variable(states)
40 |         actions = Variable(actions)
41 |         next_states = Variable(next_states, volatile=True)
42 | 
43 |         # Compute probabilities of Q(s,a*)
44 |         q_probs = self.policy(states)
45 |         actions = actions.view(batch_sz, 1, 1)
46 |         action_mask = actions.expand(batch_sz, 1, self.atoms_no)
47 |         qa_probs = q_probs.gather(1, action_mask).squeeze()
48 | 
49 |         # Compute distribution of Q(s_,a)
50 |         target_qa_probs = self._get_categorical(next_states, rewards, mask)
51 | 
52 |         # Compute the cross-entropy of phi(TZ(x_,a)) || Z(x,a)
53 |         qa_probs = qa_probs.clamp(min=1e-3)  # avoid log(0) nans (Tudor's trick)
54 |         loss = - torch.sum(target_qa_probs * torch.log(qa_probs))
55 | 
56 |         # Accumulate gradients
57 |         loss.backward()
58 | 
59 |     def update_model(self):
60 |         if self.cmdl.optim == "RMSprop":
61 |             lr = next(self.lr_generator)
62 |             for param_group in self.optimizer.param_groups:
63 |                 param_group['lr'] = lr
64 |         self.optimizer.step()
65 |         self.optimizer.zero_grad()
66 | 
67 |     def _get_categorical(self, next_states, rewards, mask):
68 |         batch_sz = next_states.size(0)
69 |         gamma = self.gamma
70 | 
71 |         # Compute probabilities p(x, a)
72 |         probs = self.target_policy(next_states).data
73 |         qs = torch.mul(probs, self.support.expand_as(probs))
74 |         argmax_a = qs.sum(2).max(1)[1].unsqueeze(1).unsqueeze(1)
75 |         action_mask = argmax_a.expand(batch_sz, 1, self.atoms_no)
76 |         qa_probs = probs.gather(1, action_mask).squeeze()
77 | 
78 |         # Mask gamma and reshape it together with rewards to fit p(x,a).
79 |         rewards = rewards.expand_as(qa_probs)
80 |         gamma = (mask.float() * gamma).expand_as(qa_probs)
81 | 
82 |         # Project the application of the Bellman operator onto the fixed support.
83 |         bellman_op = rewards + gamma * self.support.unsqueeze(0).expand_as(rewards)
84 |         bellman_op = torch.clamp(bellman_op, self.v_min, self.v_max)
85 | 
86 |         # Compute categorical indices for distributing the probability
87 |         m = self.m.fill_(0)
88 |         b = (bellman_op - self.v_min) / self.delta_z
89 |         l = b.floor().long()
90 |         u = b.ceil().long()
91 |         # Fix disappearing probability mass when l = b = u (b is int)
92 |         l[(u > 0) * (l == u)] -= 1
93 |         u[(l < (self.atoms_no - 1)) * (l == u)] += 1
94 | 
95 |         # Distribute probability
96 |         """
97 |         for i in range(batch_sz):
98 |             for j in range(self.atoms_no):
99 |                 uidx = u[i][j]
100 |                 lidx = l[i][j]
101 |                 m[i][lidx] = m[i][lidx] + qa_probs[i][j] * (uidx - b[i][j])
102 |                 m[i][uidx] = m[i][uidx] + qa_probs[i][j] * (b[i][j] - lidx)
103 |         for i in range(batch_sz):
104 |             m[i].index_add_(0, l[i], qa_probs[i] * (u[i].float() - b[i]))
105 |             m[i].index_add_(0, u[i], qa_probs[i] * (b[i] - l[i].float()))
106 | 
107 |         """
108 |         # Optimized by https://github.com/tudor-berariu
109 |         offset = torch.linspace(0, ((batch_sz - 1) * self.atoms_no), batch_sz)\
110 |             .type(self.dtype.LT)\
111 |             .unsqueeze(1).expand(batch_sz, self.atoms_no)
112 | 
113 |         m.view(-1).index_add_(0, (l + offset).view(-1),
114 |                               (qa_probs * (u.float() - b)).view(-1))
115 |         m.view(-1).index_add_(0, (u + offset).view(-1),
116 |                               (qa_probs * (b - l.float())).view(-1))
117 |         return Variable(m)
118 | 
119 |     def update_target_net(self):
120 |         """ Update the target net with the parameters in the online model."""
121 |         self.target_policy.load_state_dict(self.policy.state_dict())
122 | 
123 |     def get_model_stats(self):
124 |         param_abs_mean = 0
125 |         grad_abs_mean = 0
126 |         t_param_abs_mean = 0
127 |         n_params = 0
128 |         for p in self.policy.parameters():
129 |             param_abs_mean += p.data.abs().sum()
130 |             grad_abs_mean += p.grad.data.abs().sum()
131
| n_params += p.data.nelement() 132 | for t in self.target_policy.parameters(): 133 | t_param_abs_mean += t.data.abs().sum() 134 | 135 | print("Wm: %.9f | Gm: %.9f | Tm: %.9f" % ( 136 | param_abs_mean / n_params, 137 | grad_abs_mean / n_params, 138 | t_param_abs_mean / n_params)) 139 | 140 | def _debug_transitions(self, mask, reward_batch): 141 | if mask[0] == 0: 142 | r = reward_batch[0, 0] 143 | if r == 1.0: 144 | print(r) 145 | 146 | def _debug_states(self, state_batch, next_state_batch, mask): 147 | for i in range(24): 148 | for j in range(24): 149 | px = state_batch[0, 0, i, j] 150 | if px < 0.90: 151 | print(clr("%.2f " % px, 'magenta'), end="") 152 | else: 153 | print(("%.2f " % px), end="") 154 | print() 155 | for i in range(24): 156 | for j in range(24): 157 | px = next_state_batch[0, 0, i, j] 158 | if px < 0.90: 159 | print(clr("%.2f " % px, 'magenta'), end="") 160 | else: 161 | print(clr("%.2f " % px, 'white'), end="") 162 | print() 163 | if mask[0] == 0: 164 | print(clr("Done batch ............", 'magenta')) 165 | else: 166 | print(".......................") 167 | -------------------------------------------------------------------------------- /utils/wrappers.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import logging 3 | import torch 4 | import numpy as np 5 | import gym 6 | from gym import Wrapper 7 | from gym import ObservationWrapper 8 | from gym import RewardWrapper 9 | from PIL import Image 10 | from termcolor import colored as clr 11 | from collections import OrderedDict 12 | from utils.torch_types import TorchTypes 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class SqueezeRewards(RewardWrapper): 18 | def __init__(self, env): 19 | super(SqueezeRewards, self).__init__(env) 20 | print("[Reward Wrapper] for clamping rewards to -+1") 21 | 22 | def _reward(self, reward): 23 | return float(np.sign(reward)) 24 | 25 | 26 | class PreprocessFrames(ObservationWrapper): 27 | def __init__(self, env, env_type, hist_len, state_dims, cuda=None): 28 | super(PreprocessFrames, self).__init__(env) 29 | 30 | self.env_type = env_type 31 | self.state_dims = state_dims 32 | self.hist_len = hist_len 33 | self.env_wh = self.env.observation_space.shape[0:2] 34 | self.env_ch = self.env.observation_space.shape[2] 35 | self.wxh = self.env_wh[0] * self.env_wh[1] 36 | 37 | # need to find a better way 38 | if self.env_type == "atari": 39 | self._preprocess = self._atari_preprocess 40 | elif self.env_type == "catch": 41 | self._preprocess = self._catch_preprocess 42 | print("[Preprocess Wrapper] for %s with state history of %d frames." 
43 | % (self.env_type, hist_len)) 44 | 45 | self.cuda = False if cuda is None else cuda 46 | self.dtype = dtype = TorchTypes(self.cuda) 47 | self.rgb = dtype.FT([.2126, .7152, .0722]) 48 | 49 | # torch.size([1, 4, 24, 24]) 50 | """ 51 | self.hist_state = torch.FloatTensor(1, hist_len, *state_dims) 52 | self.hist_state.fill_(0) 53 | """ 54 | 55 | self.d = OrderedDict({i: torch.FloatTensor(1, 1, *state_dims).fill_(0) 56 | for i in range(hist_len)}) 57 | 58 | def _observation(self, o): 59 | return self._preprocess(o) 60 | 61 | def _reset(self): 62 | # self.hist_state.fill_(0) 63 | self.d = OrderedDict( 64 | {i: torch.FloatTensor(1, 1, *self.state_dims).fill_(0) 65 | for i in range(self.hist_len)}) 66 | observation = self.env.reset() 67 | return self._observation(observation) 68 | 69 | def _catch_preprocess(self, o): 70 | return self._get_concatenated_state(self._rgb2y(o)) 71 | 72 | def _atari_preprocess(self, o): 73 | img = Image.fromarray(self._rgb2y(o).numpy()) 74 | img = np.array(img.resize(self.state_dims, resample=Image.NEAREST)) 75 | th_img = torch.from_numpy(img) 76 | return self._get_concatenated_state(th_img) 77 | 78 | def _rgb2y(self, o): 79 | o = torch.from_numpy(o).type(self.dtype.FT) 80 | s = o.view(self.wxh, 3).mv(self.rgb).view(*self.env_wh) / 255 81 | return s.cpu() 82 | 83 | def _get_concatenated_state(self, o): 84 | hist_len = self.hist_len 85 | for i in range(hist_len - 1): 86 | self.d[i] = self.d[i + 1] 87 | self.d[hist_len - 1] = o.unsqueeze(0).unsqueeze(0) 88 | return torch.cat(list(self.d.values()), 1) 89 | 90 | """ 91 | def _get_concatenated_state(self, o): 92 | hist_len = self.hist_len # eg. 4 93 | 94 | # move frames already existent one position below 95 | if hist_len > 1: 96 | self.hist_state[0][0:hist_len - 1] = self.hist_state[0][1:hist_len] 97 | 98 | # concatenate the newest frame to the top of the augmented state 99 | self.hist_state[0][self.hist_len - 1] = o 100 | return self.hist_state 101 | """ 102 | 103 | 104 | class DoneAfterLostLife(gym.Wrapper): 105 | def __init__(self, env): 106 | super(DoneAfterLostLife, self).__init__(env) 107 | 108 | self.no_more_lives = True 109 | self.crt_live = env.unwrapped.ale.lives() 110 | self.has_many_lives = self.crt_live != 0 111 | 112 | if self.has_many_lives: 113 | self._step = self._many_lives_step 114 | else: 115 | self._step = self._one_live_step 116 | not_a = clr("not a", attrs=['bold']) 117 | 118 | print("[DoneAfterLostLife Wrapper] %s is %s many lives game." 
119 |               % (env.env.spec.id, "a" if self.has_many_lives else not_a))
120 | 
121 |     def _reset(self):
122 |         if self.no_more_lives:
123 |             obs = self.env.reset()
124 |             self.crt_live = self.env.unwrapped.ale.lives()
125 |             return obs
126 |         else:
127 |             return self.__obs
128 | 
129 |     def _many_lives_step(self, action):
130 |         obs, reward, done, info = self.env.step(action)
131 |         crt_live = self.env.unwrapped.ale.lives()
132 |         if crt_live < self.crt_live:
133 |             # just lost a life
134 |             done = True
135 |             self.crt_live = crt_live
136 | 
137 |         if crt_live == 0:
138 |             self.no_more_lives = True
139 |         else:
140 |             self.no_more_lives = False
141 |             self.__obs = obs
142 |         return obs, reward, done, info
143 | 
144 |     def _one_live_step(self, action):
145 |         return self.env.step(action)
146 | 
147 | 
148 | class EvaluationMonitor(Wrapper):
149 |     def __init__(self, env, cmdl):
150 |         super(EvaluationMonitor, self).__init__(env)
151 | 
152 |         self.freq = cmdl.eval_frequency  # in steps
153 |         self.eval_steps = cmdl.eval_steps
154 |         self.cmdl = cmdl
155 | 
156 |         if self.cmdl.display_plots:
157 |             from visdom import Visdom
158 |             self.vis = Visdom()
159 |             self.plot = self.vis.line(
160 |                 Y=np.array([0]), X=np.array([0]),
161 |                 opts=dict(
162 |                     title=cmdl.label,
163 |                     caption="Episodic reward per %d steps." % self.eval_steps)
164 |             )
165 | 
166 |         self.eval_cnt = 0
167 |         self.crt_training_step = 0
168 |         self.step_cnt = 0
169 |         self.ep_cnt = 1
170 |         self.total_rw = 0
171 |         self.max_mean_rw = -1000
172 | 
173 |         no_of_evals = cmdl.training_steps // cmdl.eval_frequency \
174 |             - (cmdl.eval_start-1) // cmdl.eval_frequency
175 | 
176 |         self.eval_frame_idx = torch.LongTensor(no_of_evals).fill_(0)
177 |         self.eval_rw_per_episode = torch.FloatTensor(no_of_evals).fill_(0)
178 |         self.eval_rw_per_frame = torch.FloatTensor(no_of_evals).fill_(0)
179 |         self.eval_eps_per_eval = torch.LongTensor(no_of_evals).fill_(0)
180 | 
181 |     def get_crt_step(self, crt_training_step):
182 |         self.crt_training_step = crt_training_step
183 | 
184 |     def _reset_monitor(self):
185 |         self.step_cnt, self.ep_cnt, self.total_rw = 0, 0, 0
186 | 
187 |     def _step(self, action):
188 |         # self._before_step(action)
189 |         observation, reward, done, info = self.env.step(action)
190 |         done = self._after_step(observation, reward, done, info)
191 |         return observation, reward, done, info
192 | 
193 |     def _reset(self):
194 |         observation = self.env.reset()
195 |         self._after_reset(observation)
196 |         return observation
197 | 
198 |     def _after_step(self, o, r, done, info):
199 |         self.total_rw += r
200 |         self.step_cnt += 1
201 | 
202 |         # Evaluation ends here
203 |         if self.step_cnt == self.eval_steps:
204 |             self._update()
205 |             self._reset_monitor()
206 |         return done
207 | 
208 |     def _after_reset(self, observation):
209 |         if self.step_cnt != self.eval_steps:
210 |             self.ep_cnt += 1
211 | 
212 |     def _update(self):
213 |         mean_rw = self.total_rw / (self.ep_cnt - 1)
214 |         max_mean_rw = self.max_mean_rw
215 |         self.max_mean_rw = mean_rw if mean_rw > max_mean_rw else max_mean_rw
216 | 
217 |         self._update_plot(self.crt_training_step, mean_rw)
218 |         self._display_logs(mean_rw, max_mean_rw)
219 |         self._update_reports(mean_rw)
220 |         self.eval_cnt += 1
221 | 
222 |     def _update_reports(self, mean_rw):
223 |         idx = self.eval_cnt
224 | 
225 |         self.eval_frame_idx[idx] = self.crt_training_step
226 |         self.eval_rw_per_episode[idx] = mean_rw
227 |         self.eval_rw_per_frame[idx] = self.total_rw / self.step_cnt
228 |         self.eval_eps_per_eval[idx] = (self.ep_cnt - 1)
229 | 
230 |         torch.save({
231 |             'eval_frame_idx': self.eval_frame_idx,
232 |
'eval_rw_per_episode': self.eval_rw_per_episode, 233 | 'eval_rw_per_frame': self.eval_rw_per_frame, 234 | 'eval_eps_per_eval': self.eval_eps_per_eval 235 | }, self.cmdl.results_path + "/eval_stats.torch") 236 | 237 | def _update_plot(self, crt_training_step, mean_rw): 238 | if self.cmdl.display_plots: 239 | self.vis.line( 240 | X=np.array([crt_training_step]), 241 | Y=np.array([mean_rw]), 242 | win=self.plot, 243 | update='append' 244 | ) 245 | 246 | def _display_logs(self, mean_rw, max_mean_rw): 247 | bg_color = 'on_magenta' if mean_rw > max_mean_rw else 'on_blue' 248 | print(clr("[Evaluator] done in %5d steps. " % self.step_cnt, 249 | attrs=['bold']) 250 | + clr(" rw/ep=%3.2f " % mean_rw, 'white', bg_color, 251 | attrs=['bold'])) 252 | 253 | 254 | class VisdomMonitor(Wrapper): 255 | def __init__(self, env, cmdl): 256 | super(VisdomMonitor, self).__init__(env) 257 | 258 | self.freq = cmdl.report_freq # in steps 259 | self.cmdl = cmdl 260 | 261 | if self.cmdl.display_plots: 262 | from visdom import Visdom 263 | self.vis = Visdom() 264 | self.plot = self.vis.line( 265 | Y=np.array([0]), X=np.array([0]), 266 | opts=dict( 267 | title=cmdl.label, 268 | caption="Episodic reward per 1200 steps.") 269 | ) 270 | 271 | self.step_cnt = 0 272 | self.ep_cnt = -1 273 | self.ep_rw = [] 274 | self.last_reported_ep = 0 275 | 276 | def _step(self, action): 277 | # self._before_step(action) 278 | observation, reward, done, info = self.env.step(action) 279 | done = self._after_step(observation, reward, done, info) 280 | return observation, reward, done, info 281 | 282 | def _reset(self): 283 | self._before_reset() 284 | observation = self.env.reset() 285 | self._after_reset(observation) 286 | return observation 287 | 288 | def _after_step(self, o, r, done, info): 289 | self.ep_rw[self.ep_cnt] += r 290 | self.step_cnt += 1 291 | if self.step_cnt % self.freq == 0: 292 | self._update_plot() 293 | return done 294 | 295 | def _before_reset(self): 296 | self.ep_rw.append(0) 297 | 298 | def _after_reset(self, observation): 299 | self.ep_cnt += 1 300 | # print("[%2d][%4d] RESET" % (self.ep_cnt, self.step_cnt)) 301 | 302 | def _update_plot(self): 303 | # print(self.last_reported_ep, self.ep_cnt + 1) 304 | completed_eps = self.ep_rw[self.last_reported_ep:self.ep_cnt + 1] 305 | ep_mean_reward = sum(completed_eps) / len(completed_eps) 306 | if self.cmdl.display_plots: 307 | self.vis.line( 308 | X=np.array([self.step_cnt]), 309 | Y=np.array([ep_mean_reward]), 310 | win=self.plot, 311 | update='append' 312 | ) 313 | self.last_reported_ep = self.ep_cnt + 1 314 | 315 | 316 | class TestAtariWrappers(unittest.TestCase): 317 | 318 | def _test_env(self, env_name): 319 | env = gym.make(env_name) 320 | env = DoneAfterLostLife(env) 321 | 322 | o = env.reset() 323 | 324 | for i in range(10000): 325 | o, r, d, _ = env.step(env.action_space.sample()) 326 | if d: 327 | o = env.reset() 328 | print("%3d, %s, %d" % (i, env_name, env.unwrapped.ale.lives())) 329 | 330 | def test_pong(self): 331 | print("Testing Pong") 332 | self._test_env("Pong-v0") 333 | 334 | def test_frostbite(self): 335 | print("Testing Frostbite") 336 | self._test_env("Frostbite-v0") 337 | 338 | 339 | if __name__ == "__main__": 340 | import unittest 341 | unittest.main() 342 | --------------------------------------------------------------------------------
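Illustrative sketch (not part of the repository): a minimal smoke test of the n-tuple replay buffer shown earlier. It assumes the repo root is on PYTHONPATH, that the class is importable from data_structures/ntuple_experience_replay.py (module path inferred from the file tree, not verified), and a PyTorch version contemporary with the rest of the code. The random 24x24 frames stand in for Catch observations.

import torch
from data_structures.ntuple_experience_replay import nTupleExperienceReplay

er = nTupleExperienceReplay(capacity=1000, batch_size=32, hist_len=4, cuda=False)

# Fill the buffer with fake 24x24 frames; push() adds the [1, 1, 24, 24]
# dimensions itself and stores each frame as bytes.
for step in range(200):
    er.push(torch.rand(24, 24), step % 3, 0.0, False)

# sample() returns [batch_size, states, actions, rewards, next_states, mask]
batch_sz, s, a, r, s_, mask = er.sample()
print(s.size())   # expected: torch.Size([32, 4, 24, 24]), floats in [0, 1]
print(a.size())   # expected: torch.Size([32, 1]), LongTensor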
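The core of CategoricalPolicyImprovement._get_categorical is the C51 projection of the Bellman-updated atoms back onto the fixed support. The standalone sketch below re-implements just that step against current PyTorch (so no Variable / volatile); project_distribution and its argument names are illustrative, not the repository's API.

import torch

def project_distribution(next_probs, rewards, not_done, support, gamma):
    # next_probs: [B, A] probabilities of the greedy action in the next state
    # rewards, not_done: [B, 1]; support: [A] fixed atoms z_1 .. z_A
    batch_sz, atoms_no = next_probs.shape
    v_min, v_max = support[0].item(), support[-1].item()
    delta_z = (v_max - v_min) / (atoms_no - 1)

    # Bellman update of every atom, clamped to the support range.
    tz = (rewards + not_done * gamma * support.unsqueeze(0)).clamp(v_min, v_max)
    b = (tz - v_min) / delta_z
    l, u = b.floor().long(), b.ceil().long()
    # keep the probability mass when b falls exactly on an atom (l == u)
    l[(u > 0) & (l == u)] -= 1
    u[(l < atoms_no - 1) & (l == u)] += 1

    # Split each atom's probability between its two neighbouring atoms.
    m = torch.zeros_like(next_probs)
    offset = (torch.arange(batch_sz) * atoms_no).unsqueeze(1)
    m.view(-1).index_add_(0, (l + offset).view(-1),
                          (next_probs * (u.float() - b)).view(-1))
    m.view(-1).index_add_(0, (u + offset).view(-1),
                          (next_probs * (b - l.float())).view(-1))
    return m

support = torch.linspace(-10, 10, 51)
next_probs = torch.softmax(torch.randn(32, 51), dim=1)
m = project_distribution(next_probs, torch.randn(32, 1),
                         torch.ones(32, 1), support, gamma=0.99)
assert torch.allclose(m.sum(1), torch.ones(32))   # probability mass preserved

Flattening m and offsetting each row's indices by row * atoms_no is what lets a single pair of index_add_ calls replace the per-sample loop kept in the commented-out block of _get_categorical.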