├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── deep_rl ├── __init__.py ├── agent │ ├── A2C_agent.py │ ├── BaseAgent.py │ ├── CategoricalDQN_agent.py │ ├── DDPG_agent.py │ ├── DQN_agent.py │ ├── NStepDQN_agent.py │ ├── OptionCritic_agent.py │ ├── PPO_agent.py │ ├── QuantileRegressionDQN_agent.py │ ├── TD3_agent.py │ └── __init__.py ├── component │ ├── __init__.py │ ├── envs.py │ ├── random_process.py │ └── replay.py ├── network │ ├── __init__.py │ ├── network_bodies.py │ ├── network_heads.py │ └── network_utils.py └── utils │ ├── __init__.py │ ├── config.py │ ├── logger.py │ ├── misc.py │ ├── normalizer.py │ ├── plot.py │ ├── schedule.py │ ├── sum_tree.py │ └── torch_utils.py ├── docker_batch.sh ├── docker_build.sh ├── docker_clean.sh ├── docker_python.sh ├── docker_shell.sh ├── docker_stop.sh ├── examples.py ├── images ├── Breakout.png ├── PPO.png └── mujoco_eval.png ├── requirements.txt ├── setup.py ├── template_jobs.py └── template_plot.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | .idea 92 | .vscode 93 | data 94 | dataset 95 | log 96 | old_logs 97 | figure 98 | images_data 99 | mjkey.txt 100 | .DS_Store 101 | tf_log -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.0-base 2 | 3 | RUN apt update && DEBIAN_FRONTEND=noninteractive apt install -y --allow-unauthenticated --no-install-recommends \ 4 | build-essential apt-utils cmake git curl vim ca-certificates \ 5 | libjpeg-dev libpng-dev \ 6 | libgtk3.0 libsm6 cmake ffmpeg pkg-config \ 7 | qtbase5-dev libqt5opengl5-dev libassimp-dev \ 8 | libboost-python-dev libtinyxml-dev bash \ 9 | wget unzip libosmesa6-dev software-properties-common \ 10 | libopenmpi-dev libglew-dev openssh-server \ 11 | libosmesa6-dev libgl1-mesa-glx libgl1-mesa-dev patchelf libglfw3 12 | 13 | RUN rm -rf /var/lib/apt/lists/* 14 | 15 | ARG UID 16 | RUN useradd -u $UID --create-home user 17 | USER user 18 | WORKDIR 
/home/user 19 | 20 | RUN wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 21 | bash Miniconda3-latest-Linux-x86_64.sh -b -p miniconda3 && \ 22 | rm Miniconda3-latest-Linux-x86_64.sh 23 | ENV PATH /home/user/miniconda3/bin:$PATH 24 | 25 | RUN mkdir -p .mujoco \ 26 | && wget https://www.roboti.us/download/mjpro150_linux.zip -O mujoco.zip \ 27 | && unzip mujoco.zip -d .mujoco \ 28 | && rm mujoco.zip 29 | RUN wget https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip \ 30 | && unzip mujoco.zip -d .mujoco \ 31 | && rm mujoco.zip 32 | 33 | # Make sure you have a license, otherwise comment this line out 34 | # Of course you then cannot use Mujoco and DM Control, but Roboschool is still available 35 | COPY ./mjkey.txt .mujoco/mjkey.txt 36 | 37 | ENV LD_LIBRARY_PATH /home/user/.mujoco/mjpro150/bin:${LD_LIBRARY_PATH} 38 | ENV LD_LIBRARY_PATH /home/user/.mujoco/mjpro200_linux/bin:${LD_LIBRARY_PATH} 39 | 40 | RUN conda install -y python=3.6 41 | RUN conda install mpi4py 42 | COPY requirements.txt requirements.txt 43 | RUN pip install -r requirements.txt 44 | RUN pip install glfw Cython imageio lockfile 45 | RUN pip install mujoco-py==1.50.1.68 46 | RUN pip install git+git://github.com/deepmind/dm_control.git@103834 47 | RUN pip install git+https://github.com/ShangtongZhang/dm_control2gym.git@scalar_fix 48 | RUN pip install git+git://github.com/openai/baselines.git@8e56dd#egg=baselines 49 | WORKDIR /home/user/deep_rl 50 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Shangtong Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepRL 2 | 3 | > If you have any question or want to report a bug, please open an issue instead of emailing me directly. 4 | 5 | Modularized implementation of popular deep RL algorithms in PyTorch. 6 | Easy switch between toy tasks and challenging games. 
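A quick-start sketch is shown below. It is a minimal illustration only, not the maintained entry point: ```examples.py``` is authoritative, and ```Config```, ```run_steps```, ```CategoricalActorCriticNet```, and ```FCBody``` come from ```deep_rl/utils``` and ```deep_rl/network```, whose definitions are not reproduced in this excerpt, so treat their exact signatures as assumptions. The point is that switching between a toy task and a harder game is mostly a matter of changing the name passed to ```Task```.
```python
# Minimal A2C sketch on CartPole. The Config fields mirror what A2CAgent reads
# in deep_rl/agent/A2C_agent.py; Config, Task, run_steps and the network
# helpers are assumed from the rest of the repo (see examples.py).
from deep_rl import *


def a2c_cartpole(game='CartPole-v0'):
    config = Config()
    config.num_workers = 5
    # Vectorized training envs plus a single evaluation env for the same game.
    config.task_fn = lambda: Task(game, num_envs=config.num_workers)
    config.eval_env = Task(game)  # assumed to populate config.state_dim / config.action_dim
    config.optimizer_fn = lambda params: torch.optim.RMSprop(params, lr=1e-3)
    config.network_fn = lambda: CategoricalActorCriticNet(
        config.state_dim, config.action_dim, FCBody(config.state_dim))
    config.discount = 0.99
    config.use_gae = False
    config.entropy_weight = 0.01
    config.value_loss_weight = 1.0
    config.rollout_length = 5
    config.gradient_clip = 0.5
    config.max_steps = int(1e5)
    run_steps(A2CAgent(config))


if __name__ == '__main__':
    a2c_cartpole()
```
Moving to an Atari task such as ```BreakoutNoFrameskip-v4``` would mainly mean changing the game name and swapping ```FCBody``` for a convolutional body; the Atari and Mujoco examples in ```examples.py``` show the recommended settings.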
7 |
8 | Implemented algorithms:
9 | * (Double/Dueling/Prioritized) Deep Q-Learning (DQN)
10 | * Categorical DQN (C51)
11 | * Quantile Regression DQN (QR-DQN)
12 | * (Continuous/Discrete) Synchronous Advantage Actor Critic (A2C)
13 | * Synchronous N-Step Q-Learning (N-Step DQN)
14 | * Deep Deterministic Policy Gradient (DDPG)
15 | * Proximal Policy Optimization (PPO)
16 | * The Option-Critic Architecture (OC)
17 | * Twin Delayed DDPG (TD3)
18 | * [Off-PAC-KL/TruncatedETD/DifferentialGQ/MVPI/ReverseRL/COF-PAC/GradientDICE/Bi-Res-DDPG/DAC/Geoff-PAC/QUOTA/ACE](#code-of-my-papers)
19 |
20 | The DQN agent, as well as C51 and QR-DQN, has an asynchronous actor for data generation and an asynchronous replay buffer for transferring data to the GPU.
21 | Using one RTX 2080 Ti and three threads, the DQN agent runs 10M steps (40M frames, 2.5M gradient updates) on Breakout within 6 hours.
22 |
23 | # Dependency
24 | * PyTorch v1.5.1
25 | * See ```Dockerfile``` and ```requirements.txt``` for more details
26 |
27 | # Usage
28 |
29 | ```examples.py``` contains examples for all the implemented algorithms.
30 | ```Dockerfile``` contains the environment for generating the curves below.
31 | Please use this BibTeX entry if you want to cite this repo:
32 | ```
33 | @misc{deeprl,
34 | author = {Zhang, Shangtong},
35 | title = {Modularized Implementation of Deep RL Algorithms in PyTorch},
36 | year = {2018},
37 | publisher = {GitHub},
38 | journal = {GitHub Repository},
39 | howpublished = {\url{https://github.com/ShangtongZhang/DeepRL}},
40 | }
41 | ```
42 |
43 | # Curves (commit ```9e811e```)
44 |
45 | ## BreakoutNoFrameskip-v4 (1 run)
46 |
47 | ![Loading...](https://raw.githubusercontent.com/ShangtongZhang/DeepRL/master/images/Breakout.png)
48 |
49 | ## Mujoco
50 |
51 | * DDPG/TD3 evaluation performance.
52 | ![Loading...](https://raw.githubusercontent.com/ShangtongZhang/DeepRL/master/images/mujoco_eval.png)
53 | (5 runs, mean + standard error)
54 |
55 | * PPO online performance.
56 | ![Loading...](https://raw.githubusercontent.com/ShangtongZhang/DeepRL/master/images/PPO.png)
57 | (5 runs, mean + standard error, smoothed by a window of size 10)
58 |
59 |
60 | # References
61 | * [Human Level Control through Deep Reinforcement Learning](https://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
62 | * [Asynchronous Methods for Deep Reinforcement Learning](https://arxiv.org/abs/1602.01783)
63 | * [Deep Reinforcement Learning with Double Q-learning](https://arxiv.org/abs/1509.06461)
64 | * [Dueling Network Architectures for Deep Reinforcement Learning](https://arxiv.org/abs/1511.06581)
65 | * [Playing Atari with Deep Reinforcement Learning](https://arxiv.org/abs/1312.5602)
66 | * [HOGWILD!: A Lock-Free Approach to Parallelizing Stochastic Gradient Descent](https://arxiv.org/abs/1106.5730)
67 | * [Deterministic Policy Gradient Algorithms](http://proceedings.mlr.press/v32/silver14.pdf)
68 | * [Continuous control with deep reinforcement learning](https://arxiv.org/abs/1509.02971)
69 | * [High-Dimensional Continuous Control Using Generalized Advantage Estimation](https://arxiv.org/abs/1506.02438)
70 | * [Hybrid Reward Architecture for Reinforcement Learning](https://arxiv.org/abs/1706.04208)
71 | * [Trust Region Policy Optimization](https://arxiv.org/abs/1502.05477)
72 | * [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347)
73 | * [Emergence of Locomotion Behaviours in Rich Environments](https://arxiv.org/abs/1707.02286)
74 | * [Action-Conditional Video Prediction using Deep Networks in Atari Games](https://arxiv.org/abs/1507.08750)
75 | * [A Distributional Perspective on Reinforcement Learning](https://arxiv.org/abs/1707.06887)
76 | * [Distributional Reinforcement Learning with Quantile Regression](https://arxiv.org/abs/1710.10044)
77 | * [The Option-Critic Architecture](https://arxiv.org/abs/1609.05140)
78 | * [Addressing Function Approximation Error in Actor-Critic Methods](https://arxiv.org/abs/1802.09477)
79 | * Some hyper-parameters are from [DeepMind Control Suite](https://arxiv.org/abs/1801.00690), [OpenAI Baselines](https://github.com/openai/baselines) and [Ilya Kostrikov](https://github.com/ikostrikov/pytorch-a2c-ppo-acktr)
80 |
81 | # Code of My Papers
82 | > They are located in other branches of this repo and seem to be good examples for using this codebase.
83 | * [Global Optimality and Finite Sample Analysis of Softmax Off-Policy Actor Critic under State Distribution Mismatch](https://arxiv.org/abs/2111.02997) [[Off-PAC-KL](https://github.com/ShangtongZhang/DeepRL/tree/Off-PAC-KL)] 84 | * [Truncated Emphatic Temporal Difference Methods for Prediction and Control](https://arxiv.org/abs/2108.05338) [[TruncatedETD](https://github.com/ShangtongZhang/DeepRL/tree/TruncatedETD)] 85 | * [A Deeper Look at Discounting Mismatch in Actor-Critic Algorithms](https://arxiv.org/abs/2010.01069) [[Discounting](https://github.com/ShangtongZhang/DeepRL/tree/discounting)] 86 | * [Breaking the Deadly Triad with a Target Network](https://arxiv.org/abs/2101.08862) [[TargetNetwork](https://github.com/ShangtongZhang/DeepRL/tree/TargetNetwork)] 87 | * [Average-Reward Off-Policy Policy Evaluation with Function Approximation](https://arxiv.org/abs/2101.02808) [[DifferentialGQ](https://github.com/ShangtongZhang/DeepRL/tree/DifferentialGQ)] 88 | * [Mean-Variance Policy Iteration for Risk-Averse Reinforcement Learning](https://arxiv.org/abs/2004.10888) [[MVPI](https://github.com/ShangtongZhang/DeepRL/tree/MVPI)] 89 | * [Learning Retrospective Knowledge with Reverse Reinforcement Learning](https://arxiv.org/abs/2007.06703) [[ReverseRL](https://github.com/ShangtongZhang/DeepRL/tree/ReverseRL)] 90 | * [Provably Convergent Two-Timescale Off-Policy Actor-Critic with Function Approximation](https://arxiv.org/abs/1911.04384) [[COF-PAC](https://github.com/ShangtongZhang/DeepRL/tree/COF-PAC), [TD3-random](https://github.com/ShangtongZhang/DeepRL/tree/TD3-random)] 91 | * [GradientDICE: Rethinking Generalized Offline Estimation of Stationary Values](https://arxiv.org/abs/2001.11113) [[GradientDICE](https://github.com/ShangtongZhang/DeepRL/tree/GradientDICE)] 92 | * [Deep Residual Reinforcement Learning](https://arxiv.org/abs/1905.01072) [[Bi-Res-DDPG](https://github.com/ShangtongZhang/DeepRL/tree/Bi-Res-DDPG)] 93 | * [Generalized Off-Policy Actor-Critic](https://arxiv.org/abs/1903.11329) [[Geoff-PAC](https://github.com/ShangtongZhang/DeepRL/tree/Geoff-PAC), [TD3-random](https://github.com/ShangtongZhang/DeepRL/tree/TD3-random)] 94 | * [DAC: The Double Actor-Critic Architecture for Learning Options](https://arxiv.org/abs/1904.12691) [[DAC](https://github.com/ShangtongZhang/DeepRL/tree/DAC)] 95 | * [QUOTA: The Quantile Option Architecture for Reinforcement Learning](https://arxiv.org/abs/1811.02073) [[QUOTA-discrete](https://github.com/ShangtongZhang/DeepRL/tree/QUOTA-discrete), [QUOTA-continuous](https://github.com/ShangtongZhang/DeepRL/tree/QUOTA-continuous)] 96 | * [ACE: An Actor Ensemble Algorithm for Continuous Control with Tree Search](https://arxiv.org/abs/1811.02696) [[ACE](https://github.com/ShangtongZhang/DeepRL/tree/ACE)] 97 | -------------------------------------------------------------------------------- /deep_rl/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent import * 2 | from .component import * 3 | from .network import * 4 | from .utils import * -------------------------------------------------------------------------------- /deep_rl/agent/A2C_agent.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | 
####################################################################### 6 | 7 | from ..network import * 8 | from ..component import * 9 | from .BaseAgent import * 10 | 11 | 12 | class A2CAgent(BaseAgent): 13 | def __init__(self, config): 14 | BaseAgent.__init__(self, config) 15 | self.config = config 16 | self.task = config.task_fn() 17 | self.network = config.network_fn() 18 | self.optimizer = config.optimizer_fn(self.network.parameters()) 19 | self.total_steps = 0 20 | self.states = self.task.reset() 21 | 22 | def step(self): 23 | config = self.config 24 | storage = Storage(config.rollout_length) 25 | states = self.states 26 | for _ in range(config.rollout_length): 27 | prediction = self.network(config.state_normalizer(states)) 28 | next_states, rewards, terminals, info = self.task.step(to_np(prediction['action'])) 29 | self.record_online_return(info) 30 | rewards = config.reward_normalizer(rewards) 31 | storage.feed(prediction) 32 | storage.feed({'reward': tensor(rewards).unsqueeze(-1), 33 | 'mask': tensor(1 - terminals).unsqueeze(-1)}) 34 | 35 | states = next_states 36 | self.total_steps += config.num_workers 37 | 38 | self.states = states 39 | prediction = self.network(config.state_normalizer(states)) 40 | storage.feed(prediction) 41 | storage.placeholder() 42 | 43 | advantages = tensor(np.zeros((config.num_workers, 1))) 44 | returns = prediction['v'].detach() 45 | for i in reversed(range(config.rollout_length)): 46 | returns = storage.reward[i] + config.discount * storage.mask[i] * returns 47 | if not config.use_gae: 48 | advantages = returns - storage.v[i].detach() 49 | else: 50 | td_error = storage.reward[i] + config.discount * storage.mask[i] * storage.v[i + 1] - storage.v[i] 51 | advantages = advantages * config.gae_tau * config.discount * storage.mask[i] + td_error 52 | storage.advantage[i] = advantages.detach() 53 | storage.ret[i] = returns.detach() 54 | 55 | entries = storage.extract(['log_pi_a', 'v', 'ret', 'advantage', 'entropy']) 56 | policy_loss = -(entries.log_pi_a * entries.advantage).mean() 57 | value_loss = 0.5 * (entries.ret - entries.v).pow(2).mean() 58 | entropy_loss = entries.entropy.mean() 59 | 60 | self.optimizer.zero_grad() 61 | (policy_loss - config.entropy_weight * entropy_loss + 62 | config.value_loss_weight * value_loss).backward() 63 | nn.utils.clip_grad_norm_(self.network.parameters(), config.gradient_clip) 64 | self.optimizer.step() 65 | -------------------------------------------------------------------------------- /deep_rl/agent/BaseAgent.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | 7 | import torch 8 | import numpy as np 9 | from ..utils import * 10 | import torch.multiprocessing as mp 11 | from collections import deque 12 | from skimage.io import imsave 13 | 14 | 15 | class BaseAgent: 16 | def __init__(self, config): 17 | self.config = config 18 | self.logger = get_logger(tag=config.tag, log_level=config.log_level) 19 | self.task_ind = 0 20 | 21 | def close(self): 22 | close_obj(self.task) 23 | 24 | def save(self, filename): 25 | torch.save(self.network.state_dict(), '%s.model' % (filename)) 26 | with open('%s.stats' % (filename), 'wb') as f: 27 | 
pickle.dump(self.config.state_normalizer.state_dict(), f) 28 | 29 | def load(self, filename): 30 | state_dict = torch.load('%s.model' % filename, map_location=lambda storage, loc: storage) 31 | self.network.load_state_dict(state_dict) 32 | with open('%s.stats' % (filename), 'rb') as f: 33 | self.config.state_normalizer.load_state_dict(pickle.load(f)) 34 | 35 | def eval_step(self, state): 36 | raise NotImplementedError 37 | 38 | def eval_episode(self): 39 | env = self.config.eval_env 40 | state = env.reset() 41 | while True: 42 | action = self.eval_step(state) 43 | state, reward, done, info = env.step(action) 44 | ret = info[0]['episodic_return'] 45 | if ret is not None: 46 | break 47 | return ret 48 | 49 | def eval_episodes(self): 50 | episodic_returns = [] 51 | for ep in range(self.config.eval_episodes): 52 | total_rewards = self.eval_episode() 53 | episodic_returns.append(np.sum(total_rewards)) 54 | self.logger.info('steps %d, episodic_return_test %.2f(%.2f)' % ( 55 | self.total_steps, np.mean(episodic_returns), np.std(episodic_returns) / np.sqrt(len(episodic_returns)) 56 | )) 57 | self.logger.add_scalar('episodic_return_test', np.mean(episodic_returns), self.total_steps) 58 | return { 59 | 'episodic_return_test': np.mean(episodic_returns), 60 | } 61 | 62 | def record_online_return(self, info, offset=0): 63 | if isinstance(info, dict): 64 | ret = info['episodic_return'] 65 | if ret is not None: 66 | self.logger.add_scalar('episodic_return_train', ret, self.total_steps + offset) 67 | self.logger.info('steps %d, episodic_return_train %s' % (self.total_steps + offset, ret)) 68 | elif isinstance(info, tuple): 69 | for i, info_ in enumerate(info): 70 | self.record_online_return(info_, i) 71 | else: 72 | raise NotImplementedError 73 | 74 | def switch_task(self): 75 | config = self.config 76 | if not config.tasks: 77 | return 78 | segs = np.linspace(0, config.max_steps, len(config.tasks) + 1) 79 | if self.total_steps > segs[self.task_ind + 1]: 80 | self.task_ind += 1 81 | self.task = config.tasks[self.task_ind] 82 | self.states = self.task.reset() 83 | self.states = config.state_normalizer(self.states) 84 | 85 | def record_episode(self, dir, env): 86 | mkdir(dir) 87 | steps = 0 88 | state = env.reset() 89 | while True: 90 | self.record_obs(env, dir, steps) 91 | action = self.record_step(state) 92 | state, reward, done, info = env.step(action) 93 | ret = info[0]['episodic_return'] 94 | steps += 1 95 | if ret is not None: 96 | break 97 | 98 | def record_step(self, state): 99 | raise NotImplementedError 100 | 101 | # For DMControl 102 | def record_obs(self, env, dir, steps): 103 | env = env.env.envs[0] 104 | obs = env.render(mode='rgb_array') 105 | imsave('%s/%04d.png' % (dir, steps), obs) 106 | 107 | 108 | class BaseActor(mp.Process): 109 | STEP = 0 110 | RESET = 1 111 | EXIT = 2 112 | SPECS = 3 113 | NETWORK = 4 114 | CACHE = 5 115 | 116 | def __init__(self, config): 117 | mp.Process.__init__(self) 118 | self.config = config 119 | self.__pipe, self.__worker_pipe = mp.Pipe() 120 | 121 | self._state = None 122 | self._task = None 123 | self._network = None 124 | self._total_steps = 0 125 | self.__cache_len = 2 126 | 127 | if not config.async_actor: 128 | self.start = lambda: None 129 | self.step = self._sample 130 | self.close = lambda: None 131 | self._set_up() 132 | self._task = config.task_fn() 133 | 134 | def _sample(self): 135 | transitions = [] 136 | for _ in range(self.config.sgd_update_frequency): 137 | transition = self._transition() 138 | if transition is not None: 139 | 
transitions.append(transition) 140 | return transitions 141 | 142 | def run(self): 143 | self._set_up() 144 | config = self.config 145 | self._task = config.task_fn() 146 | 147 | cache = deque([], maxlen=2) 148 | while True: 149 | op, data = self.__worker_pipe.recv() 150 | if op == self.STEP: 151 | if not len(cache): 152 | cache.append(self._sample()) 153 | cache.append(self._sample()) 154 | self.__worker_pipe.send(cache.popleft()) 155 | cache.append(self._sample()) 156 | elif op == self.EXIT: 157 | self.__worker_pipe.close() 158 | return 159 | elif op == self.NETWORK: 160 | self._network = data 161 | else: 162 | raise NotImplementedError 163 | 164 | def _transition(self): 165 | raise NotImplementedError 166 | 167 | def _set_up(self): 168 | pass 169 | 170 | def step(self): 171 | self.__pipe.send([self.STEP, None]) 172 | return self.__pipe.recv() 173 | 174 | def close(self): 175 | self.__pipe.send([self.EXIT, None]) 176 | self.__pipe.close() 177 | 178 | def set_network(self, net): 179 | if not self.config.async_actor: 180 | self._network = net 181 | else: 182 | self.__pipe.send([self.NETWORK, net]) 183 | -------------------------------------------------------------------------------- /deep_rl/agent/CategoricalDQN_agent.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | 7 | from ..network import * 8 | from ..component import * 9 | from ..utils import * 10 | import time 11 | from .BaseAgent import * 12 | from .DQN_agent import * 13 | 14 | 15 | class CategoricalDQNActor(DQNActor): 16 | def __init__(self, config): 17 | super().__init__(config) 18 | 19 | def _set_up(self): 20 | self.config.atoms = tensor(self.config.atoms) 21 | 22 | def compute_q(self, prediction): 23 | q_values = (prediction['prob'] * self.config.atoms).sum(-1) 24 | return to_np(q_values) 25 | 26 | 27 | class CategoricalDQNAgent(DQNAgent): 28 | def __init__(self, config): 29 | BaseAgent.__init__(self, config) 30 | self.config = config 31 | config.lock = mp.Lock() 32 | config.atoms = np.linspace(config.categorical_v_min, 33 | config.categorical_v_max, config.categorical_n_atoms) 34 | 35 | self.replay = config.replay_fn() 36 | self.actor = CategoricalDQNActor(config) 37 | 38 | self.network = config.network_fn() 39 | self.network.share_memory() 40 | self.target_network = config.network_fn() 41 | self.target_network.load_state_dict(self.network.state_dict()) 42 | self.optimizer = config.optimizer_fn(self.network.parameters()) 43 | 44 | self.actor.set_network(self.network) 45 | 46 | self.total_steps = 0 47 | self.batch_indices = range_tensor(config.batch_size) 48 | self.atoms = tensor(config.atoms) 49 | self.delta_atom = (config.categorical_v_max - config.categorical_v_min) / float(config.categorical_n_atoms - 1) 50 | 51 | def eval_step(self, state): 52 | self.config.state_normalizer.set_read_only() 53 | state = self.config.state_normalizer(state) 54 | prediction = self.network(state) 55 | q = (prediction['prob'] * self.atoms).sum(-1) 56 | action = to_np(q.argmax(-1)) 57 | self.config.state_normalizer.unset_read_only() 58 | return action 59 | 60 | def compute_loss(self, transitions): 61 | config = self.config 62 | states = self.config.state_normalizer(transitions.state) 63 | 
next_states = self.config.state_normalizer(transitions.next_state) 64 | with torch.no_grad(): 65 | prob_next = self.target_network(next_states)['prob'] 66 | q_next = (prob_next * self.atoms).sum(-1) 67 | if config.double_q: 68 | a_next = torch.argmax((self.network(next_states)['prob'] * self.atoms).sum(-1), dim=-1) 69 | else: 70 | a_next = torch.argmax(q_next, dim=-1) 71 | prob_next = prob_next[self.batch_indices, a_next, :] 72 | 73 | rewards = tensor(transitions.reward).unsqueeze(-1) 74 | masks = tensor(transitions.mask).unsqueeze(-1) 75 | atoms_target = rewards + self.config.discount ** config.n_step * masks * self.atoms.view(1, -1) 76 | atoms_target.clamp_(self.config.categorical_v_min, self.config.categorical_v_max) 77 | atoms_target = atoms_target.unsqueeze(1) 78 | target_prob = (1 - (atoms_target - self.atoms.view(1, -1, 1)).abs() / self.delta_atom).clamp(0, 1) * \ 79 | prob_next.unsqueeze(1) 80 | target_prob = target_prob.sum(-1) 81 | 82 | log_prob = self.network(states)['log_prob'] 83 | actions = tensor(transitions.action).long() 84 | log_prob = log_prob[self.batch_indices, actions, :] 85 | KL = (target_prob * target_prob.add(1e-5).log() - target_prob * log_prob).sum(-1) 86 | return KL 87 | 88 | def reduce_loss(self, loss): 89 | return loss.mean() 90 | -------------------------------------------------------------------------------- /deep_rl/agent/DDPG_agent.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | 7 | from ..network import * 8 | from ..component import * 9 | from .BaseAgent import * 10 | import torchvision 11 | 12 | 13 | class DDPGAgent(BaseAgent): 14 | def __init__(self, config): 15 | BaseAgent.__init__(self, config) 16 | self.config = config 17 | self.task = config.task_fn() 18 | self.network = config.network_fn() 19 | self.target_network = config.network_fn() 20 | self.target_network.load_state_dict(self.network.state_dict()) 21 | self.replay = config.replay_fn() 22 | self.random_process = config.random_process_fn() 23 | self.total_steps = 0 24 | self.state = None 25 | 26 | def soft_update(self, target, src): 27 | for target_param, param in zip(target.parameters(), src.parameters()): 28 | target_param.detach_() 29 | target_param.copy_(target_param * (1.0 - self.config.target_network_mix) + 30 | param * self.config.target_network_mix) 31 | 32 | def eval_step(self, state): 33 | self.config.state_normalizer.set_read_only() 34 | state = self.config.state_normalizer(state) 35 | action = self.network(state) 36 | self.config.state_normalizer.unset_read_only() 37 | return to_np(action) 38 | 39 | def step(self): 40 | config = self.config 41 | if self.state is None: 42 | self.random_process.reset_states() 43 | self.state = self.task.reset() 44 | self.state = config.state_normalizer(self.state) 45 | 46 | if self.total_steps < config.warm_up: 47 | action = [self.task.action_space.sample()] 48 | else: 49 | action = self.network(self.state) 50 | action = to_np(action) 51 | action += self.random_process.sample() 52 | action = np.clip(action, self.task.action_space.low, self.task.action_space.high) 53 | next_state, reward, done, info = self.task.step(action) 54 | next_state = self.config.state_normalizer(next_state) 55 | 
self.record_online_return(info) 56 | reward = self.config.reward_normalizer(reward) 57 | 58 | self.replay.feed(dict( 59 | state=self.state, 60 | action=action, 61 | reward=reward, 62 | next_state=next_state, 63 | mask=1-np.asarray(done, dtype=np.int32), 64 | )) 65 | 66 | if done[0]: 67 | self.random_process.reset_states() 68 | self.state = next_state 69 | self.total_steps += 1 70 | 71 | if self.replay.size() >= config.warm_up: 72 | transitions = self.replay.sample() 73 | states = tensor(transitions.state) 74 | actions = tensor(transitions.action) 75 | rewards = tensor(transitions.reward).unsqueeze(-1) 76 | next_states = tensor(transitions.next_state) 77 | mask = tensor(transitions.mask).unsqueeze(-1) 78 | 79 | phi_next = self.target_network.feature(next_states) 80 | a_next = self.target_network.actor(phi_next) 81 | q_next = self.target_network.critic(phi_next, a_next) 82 | q_next = config.discount * mask * q_next 83 | q_next.add_(rewards) 84 | q_next = q_next.detach() 85 | phi = self.network.feature(states) 86 | q = self.network.critic(phi, actions) 87 | critic_loss = (q - q_next).pow(2).mul(0.5).sum(-1).mean() 88 | 89 | self.network.zero_grad() 90 | critic_loss.backward() 91 | self.network.critic_opt.step() 92 | 93 | phi = self.network.feature(states) 94 | action = self.network.actor(phi) 95 | policy_loss = -self.network.critic(phi.detach(), action).mean() 96 | 97 | self.network.zero_grad() 98 | policy_loss.backward() 99 | self.network.actor_opt.step() 100 | 101 | self.soft_update(self.target_network, self.network) 102 | -------------------------------------------------------------------------------- /deep_rl/agent/DQN_agent.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | 7 | from ..network import * 8 | from ..component import * 9 | from ..utils import * 10 | import time 11 | from .BaseAgent import * 12 | 13 | 14 | class DQNActor(BaseActor): 15 | def __init__(self, config): 16 | BaseActor.__init__(self, config) 17 | self.config = config 18 | self.start() 19 | 20 | def compute_q(self, prediction): 21 | q_values = to_np(prediction['q']) 22 | return q_values 23 | 24 | def _transition(self): 25 | if self._state is None: 26 | self._state = self._task.reset() 27 | config = self.config 28 | if config.noisy_linear: 29 | self._network.reset_noise() 30 | with config.lock: 31 | prediction = self._network(config.state_normalizer(self._state)) 32 | q_values = self.compute_q(prediction) 33 | 34 | if config.noisy_linear: 35 | epsilon = 0 36 | elif self._total_steps < config.exploration_steps: 37 | epsilon = 1 38 | else: 39 | epsilon = config.random_action_prob() 40 | action = epsilon_greedy(epsilon, q_values) 41 | next_state, reward, done, info = self._task.step(action) 42 | entry = [self._state, action, reward, next_state, done, info] 43 | self._total_steps += 1 44 | self._state = next_state 45 | return entry 46 | 47 | 48 | class DQNAgent(BaseAgent): 49 | def __init__(self, config): 50 | BaseAgent.__init__(self, config) 51 | self.config = config 52 | config.lock = mp.Lock() 53 | 54 | self.replay = config.replay_fn() 55 | self.actor = DQNActor(config) 56 | 57 | self.network = config.network_fn() 58 | self.network.share_memory() 59 | 
self.target_network = config.network_fn() 60 | self.target_network.load_state_dict(self.network.state_dict()) 61 | self.optimizer = config.optimizer_fn(self.network.parameters()) 62 | 63 | self.actor.set_network(self.network) 64 | self.total_steps = 0 65 | 66 | def close(self): 67 | close_obj(self.replay) 68 | close_obj(self.actor) 69 | 70 | def eval_step(self, state): 71 | self.config.state_normalizer.set_read_only() 72 | state = self.config.state_normalizer(state) 73 | q = self.network(state)['q'] 74 | action = to_np(q.argmax(-1)) 75 | self.config.state_normalizer.unset_read_only() 76 | return action 77 | 78 | def reduce_loss(self, loss): 79 | return loss.pow(2).mul(0.5).mean() 80 | 81 | def compute_loss(self, transitions): 82 | config = self.config 83 | states = self.config.state_normalizer(transitions.state) 84 | next_states = self.config.state_normalizer(transitions.next_state) 85 | with torch.no_grad(): 86 | q_next = self.target_network(next_states)['q'].detach() 87 | if self.config.double_q: 88 | best_actions = torch.argmax(self.network(next_states)['q'], dim=-1) 89 | q_next = q_next.gather(1, best_actions.unsqueeze(-1)).squeeze(1) 90 | else: 91 | q_next = q_next.max(1)[0] 92 | masks = tensor(transitions.mask) 93 | rewards = tensor(transitions.reward) 94 | q_target = rewards + self.config.discount ** config.n_step * q_next * masks 95 | actions = tensor(transitions.action).long() 96 | q = self.network(states)['q'] 97 | q = q.gather(1, actions.unsqueeze(-1)).squeeze(-1) 98 | loss = q_target - q 99 | return loss 100 | 101 | def step(self): 102 | config = self.config 103 | transitions = self.actor.step() 104 | for states, actions, rewards, next_states, dones, info in transitions: 105 | self.record_online_return(info) 106 | self.total_steps += 1 107 | self.replay.feed(dict( 108 | state=np.array([s[-1] if isinstance(s, LazyFrames) else s for s in states]), 109 | action=actions, 110 | reward=[config.reward_normalizer(r) for r in rewards], 111 | mask=1 - np.asarray(dones, dtype=np.int32), 112 | )) 113 | 114 | if self.total_steps > self.config.exploration_steps: 115 | transitions = self.replay.sample() 116 | if config.noisy_linear: 117 | self.target_network.reset_noise() 118 | self.network.reset_noise() 119 | loss = self.compute_loss(transitions) 120 | if isinstance(transitions, PrioritizedTransition): 121 | priorities = loss.abs().add(config.replay_eps).pow(config.replay_alpha) 122 | idxs = tensor(transitions.idx).long() 123 | self.replay.update_priorities(zip(to_np(idxs), to_np(priorities))) 124 | sampling_probs = tensor(transitions.sampling_prob) 125 | weights = sampling_probs.mul(sampling_probs.size(0)).add(1e-6).pow(-config.replay_beta()) 126 | weights = weights / weights.max() 127 | loss = loss.mul(weights) 128 | 129 | loss = self.reduce_loss(loss) 130 | self.optimizer.zero_grad() 131 | loss.backward() 132 | nn.utils.clip_grad_norm_(self.network.parameters(), self.config.gradient_clip) 133 | with config.lock: 134 | self.optimizer.step() 135 | 136 | if self.total_steps / self.config.sgd_update_frequency % \ 137 | self.config.target_network_update_freq == 0: 138 | self.target_network.load_state_dict(self.network.state_dict()) 139 | -------------------------------------------------------------------------------- /deep_rl/agent/NStepDQN_agent.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission 
given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | 7 | from ..network import * 8 | from ..component import * 9 | from ..utils import * 10 | from .BaseAgent import * 11 | 12 | 13 | class NStepDQNAgent(BaseAgent): 14 | def __init__(self, config): 15 | BaseAgent.__init__(self, config) 16 | self.config = config 17 | self.task = config.task_fn() 18 | self.network = config.network_fn() 19 | self.target_network = config.network_fn() 20 | self.optimizer = config.optimizer_fn(self.network.parameters()) 21 | self.target_network.load_state_dict(self.network.state_dict()) 22 | 23 | self.total_steps = 0 24 | self.states = self.task.reset() 25 | 26 | def step(self): 27 | config = self.config 28 | storage = Storage(config.rollout_length) 29 | 30 | states = self.states 31 | for _ in range(config.rollout_length): 32 | q = self.network(self.config.state_normalizer(states))['q'] 33 | 34 | epsilon = config.random_action_prob(config.num_workers) 35 | actions = epsilon_greedy(epsilon, to_np(q)) 36 | 37 | next_states, rewards, terminals, info = self.task.step(actions) 38 | self.record_online_return(info) 39 | rewards = config.reward_normalizer(rewards) 40 | 41 | storage.feed({'q': q, 42 | 'action': tensor(actions).unsqueeze(-1).long(), 43 | 'reward': tensor(rewards).unsqueeze(-1), 44 | 'mask': tensor(1 - terminals).unsqueeze(-1)}) 45 | 46 | states = next_states 47 | 48 | self.total_steps += config.num_workers 49 | if self.total_steps // config.num_workers % config.target_network_update_freq == 0: 50 | self.target_network.load_state_dict(self.network.state_dict()) 51 | 52 | self.states = states 53 | 54 | storage.placeholder() 55 | 56 | ret = self.target_network(config.state_normalizer(states))['q'].detach() 57 | ret = torch.max(ret, dim=1, keepdim=True)[0] 58 | for i in reversed(range(config.rollout_length)): 59 | ret = storage.reward[i] + config.discount * storage.mask[i] * ret 60 | storage.ret[i] = ret 61 | 62 | entries = storage.extract(['q', 'action', 'ret']) 63 | loss = 0.5 * (entries.q.gather(1, entries.action) - entries.ret).pow(2).mean() 64 | self.optimizer.zero_grad() 65 | loss.backward() 66 | nn.utils.clip_grad_norm_(self.network.parameters(), config.gradient_clip) 67 | self.optimizer.step() 68 | -------------------------------------------------------------------------------- /deep_rl/agent/OptionCritic_agent.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | 7 | from ..network import * 8 | from ..component import * 9 | from .BaseAgent import * 10 | 11 | 12 | class OptionCriticAgent(BaseAgent): 13 | def __init__(self, config): 14 | BaseAgent.__init__(self, config) 15 | self.config = config 16 | self.task = config.task_fn() 17 | self.network = config.network_fn() 18 | self.target_network = config.network_fn() 19 | self.optimizer = config.optimizer_fn(self.network.parameters()) 20 | self.target_network.load_state_dict(self.network.state_dict()) 21 | 22 | self.total_steps = 0 23 | self.worker_index = tensor(np.arange(config.num_workers)).long() 24 | 25 | self.states = self.config.state_normalizer(self.task.reset()) 26 | 
self.is_initial_states = tensor(np.ones((config.num_workers))).byte() 27 | self.prev_options = self.is_initial_states.clone().long() 28 | 29 | def sample_option(self, prediction, epsilon, prev_option, is_intial_states): 30 | with torch.no_grad(): 31 | q_option = prediction['q'] 32 | pi_option = torch.zeros_like(q_option).add(epsilon / q_option.size(1)) 33 | greedy_option = q_option.argmax(dim=-1, keepdim=True) 34 | prob = 1 - epsilon + epsilon / q_option.size(1) 35 | prob = torch.zeros_like(pi_option).add(prob) 36 | pi_option.scatter_(1, greedy_option, prob) 37 | 38 | mask = torch.zeros_like(q_option) 39 | mask[self.worker_index, prev_option] = 1 40 | beta = prediction['beta'] 41 | pi_hat_option = (1 - beta) * mask + beta * pi_option 42 | 43 | dist = torch.distributions.Categorical(probs=pi_option) 44 | options = dist.sample() 45 | dist = torch.distributions.Categorical(probs=pi_hat_option) 46 | options_hat = dist.sample() 47 | 48 | options = torch.where(is_intial_states, options, options_hat) 49 | return options 50 | 51 | def step(self): 52 | config = self.config 53 | storage = Storage(config.rollout_length, ['beta', 'option', 'beta_advantage', 'prev_option', 'init_state', 'eps']) 54 | 55 | for _ in range(config.rollout_length): 56 | prediction = self.network(self.states) 57 | epsilon = config.random_option_prob(config.num_workers) 58 | options = self.sample_option(prediction, epsilon, self.prev_options, self.is_initial_states) 59 | prediction['pi'] = prediction['pi'][self.worker_index, options] 60 | prediction['log_pi'] = prediction['log_pi'][self.worker_index, options] 61 | dist = torch.distributions.Categorical(probs=prediction['pi']) 62 | actions = dist.sample() 63 | entropy = dist.entropy() 64 | 65 | next_states, rewards, terminals, info = self.task.step(to_np(actions)) 66 | self.record_online_return(info) 67 | next_states = config.state_normalizer(next_states) 68 | rewards = config.reward_normalizer(rewards) 69 | storage.feed(prediction) 70 | storage.feed({'reward': tensor(rewards).unsqueeze(-1), 71 | 'mask': tensor(1 - terminals).unsqueeze(-1), 72 | 'option': options.unsqueeze(-1), 73 | 'prev_option': self.prev_options.unsqueeze(-1), 74 | 'entropy': entropy.unsqueeze(-1), 75 | 'action': actions.unsqueeze(-1), 76 | 'init_state': self.is_initial_states.unsqueeze(-1).float(), 77 | 'eps': epsilon}) 78 | 79 | self.is_initial_states = tensor(terminals).byte() 80 | self.prev_options = options 81 | self.states = next_states 82 | 83 | self.total_steps += config.num_workers 84 | if self.total_steps // config.num_workers % config.target_network_update_freq == 0: 85 | self.target_network.load_state_dict(self.network.state_dict()) 86 | 87 | with torch.no_grad(): 88 | prediction = self.target_network(self.states) 89 | storage.placeholder() 90 | betas = prediction['beta'][self.worker_index, self.prev_options] 91 | ret = (1 - betas) * prediction['q'][self.worker_index, self.prev_options] + \ 92 | betas * torch.max(prediction['q'], dim=-1)[0] 93 | ret = ret.unsqueeze(-1) 94 | 95 | for i in reversed(range(config.rollout_length)): 96 | ret = storage.reward[i] + config.discount * storage.mask[i] * ret 97 | adv = ret - storage.q[i].gather(1, storage.option[i]) 98 | storage.ret[i] = ret 99 | storage.advantage[i] = adv 100 | 101 | v = storage.q[i].max(dim=-1, keepdim=True)[0] * (1 - storage.eps[i]) + storage.q[i].mean(-1).unsqueeze(-1) * \ 102 | storage.eps[i] 103 | q = storage.q[i].gather(1, storage.prev_option[i]) 104 | storage.beta_advantage[i] = q - v + config.termination_regularizer 105 | 106 | 
entries = storage.extract( 107 | ['q', 'beta', 'log_pi', 'ret', 'advantage', 'beta_advantage', 'entropy', 'option', 'action', 'init_state', 'prev_option']) 108 | 109 | q_loss = (entries.q.gather(1, entries.option) - entries.ret.detach()).pow(2).mul(0.5).mean() 110 | pi_loss = -(entries.log_pi.gather(1, 111 | entries.action) * entries.advantage.detach()) - config.entropy_weight * entries.entropy 112 | pi_loss = pi_loss.mean() 113 | beta_loss = (entries.beta.gather(1, entries.prev_option) * entries.beta_advantage.detach() * (1 - entries.init_state)).mean() 114 | 115 | self.optimizer.zero_grad() 116 | (pi_loss + q_loss + beta_loss).backward() 117 | nn.utils.clip_grad_norm_(self.network.parameters(), config.gradient_clip) 118 | self.optimizer.step() 119 | -------------------------------------------------------------------------------- /deep_rl/agent/PPO_agent.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | 7 | from ..network import * 8 | from ..component import * 9 | from .BaseAgent import * 10 | 11 | 12 | class PPOAgent(BaseAgent): 13 | def __init__(self, config): 14 | BaseAgent.__init__(self, config) 15 | self.config = config 16 | self.task = config.task_fn() 17 | self.network = config.network_fn() 18 | if config.shared_repr: 19 | self.opt = config.optimizer_fn(self.network.parameters()) 20 | else: 21 | self.actor_opt = config.actor_opt_fn(self.network.actor_params) 22 | self.critic_opt = config.critic_opt_fn(self.network.critic_params) 23 | self.total_steps = 0 24 | self.states = self.task.reset() 25 | self.states = config.state_normalizer(self.states) 26 | if config.shared_repr: 27 | self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda step: 1 - step / config.max_steps) 28 | 29 | def step(self): 30 | config = self.config 31 | storage = Storage(config.rollout_length) 32 | states = self.states 33 | for _ in range(config.rollout_length): 34 | prediction = self.network(states) 35 | next_states, rewards, terminals, info = self.task.step(to_np(prediction['action'])) 36 | self.record_online_return(info) 37 | rewards = config.reward_normalizer(rewards) 38 | next_states = config.state_normalizer(next_states) 39 | storage.feed(prediction) 40 | storage.feed({'reward': tensor(rewards).unsqueeze(-1), 41 | 'mask': tensor(1 - terminals).unsqueeze(-1), 42 | 'state': tensor(states)}) 43 | states = next_states 44 | self.total_steps += config.num_workers 45 | 46 | self.states = states 47 | prediction = self.network(states) 48 | storage.feed(prediction) 49 | storage.placeholder() 50 | 51 | advantages = tensor(np.zeros((config.num_workers, 1))) 52 | returns = prediction['v'].detach() 53 | for i in reversed(range(config.rollout_length)): 54 | returns = storage.reward[i] + config.discount * storage.mask[i] * returns 55 | if not config.use_gae: 56 | advantages = returns - storage.v[i].detach() 57 | else: 58 | td_error = storage.reward[i] + config.discount * storage.mask[i] * storage.v[i + 1] - storage.v[i] 59 | advantages = advantages * config.gae_tau * config.discount * storage.mask[i] + td_error 60 | storage.advantage[i] = advantages.detach() 61 | storage.ret[i] = returns.detach() 62 | 63 | entries = storage.extract(['state', 
'action', 'log_pi_a', 'ret', 'advantage']) 64 | EntryCLS = entries.__class__ 65 | entries = EntryCLS(*list(map(lambda x: x.detach(), entries))) 66 | entries.advantage.copy_((entries.advantage - entries.advantage.mean()) / entries.advantage.std()) 67 | 68 | if config.shared_repr: 69 | self.lr_scheduler.step(self.total_steps) 70 | 71 | for _ in range(config.optimization_epochs): 72 | sampler = random_sample(np.arange(entries.state.size(0)), config.mini_batch_size) 73 | for batch_indices in sampler: 74 | batch_indices = tensor(batch_indices).long() 75 | entry = EntryCLS(*list(map(lambda x: x[batch_indices], entries))) 76 | 77 | prediction = self.network(entry.state, entry.action) 78 | ratio = (prediction['log_pi_a'] - entry.log_pi_a).exp() 79 | obj = ratio * entry.advantage 80 | obj_clipped = ratio.clamp(1.0 - self.config.ppo_ratio_clip, 81 | 1.0 + self.config.ppo_ratio_clip) * entry.advantage 82 | policy_loss = -torch.min(obj, obj_clipped).mean() - config.entropy_weight * prediction['entropy'].mean() 83 | 84 | value_loss = 0.5 * (entry.ret - prediction['v']).pow(2).mean() 85 | 86 | approx_kl = (entry.log_pi_a - prediction['log_pi_a']).mean() 87 | if config.shared_repr: 88 | self.opt.zero_grad() 89 | (policy_loss + value_loss).backward() 90 | nn.utils.clip_grad_norm_(self.network.parameters(), config.gradient_clip) 91 | self.opt.step() 92 | else: 93 | if approx_kl <= 1.5 * config.target_kl: 94 | self.actor_opt.zero_grad() 95 | policy_loss.backward() 96 | self.actor_opt.step() 97 | self.critic_opt.zero_grad() 98 | value_loss.backward() 99 | self.critic_opt.step() 100 | 101 | -------------------------------------------------------------------------------- /deep_rl/agent/QuantileRegressionDQN_agent.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | 7 | from ..network import * 8 | from ..component import * 9 | from ..utils import * 10 | from .BaseAgent import * 11 | from .DQN_agent import * 12 | 13 | 14 | class QuantileRegressionDQNActor(DQNActor): 15 | def __init__(self, config): 16 | super().__init__(config) 17 | 18 | def compute_q(self, prediction): 19 | q_values = prediction['quantile'].mean(-1) 20 | return to_np(q_values) 21 | 22 | 23 | class QuantileRegressionDQNAgent(DQNAgent): 24 | def __init__(self, config): 25 | BaseAgent.__init__(self, config) 26 | self.config = config 27 | config.lock = mp.Lock() 28 | 29 | self.replay = config.replay_fn() 30 | self.actor = QuantileRegressionDQNActor(config) 31 | 32 | self.network = config.network_fn() 33 | self.network.share_memory() 34 | self.target_network = config.network_fn() 35 | self.target_network.load_state_dict(self.network.state_dict()) 36 | self.optimizer = config.optimizer_fn(self.network.parameters()) 37 | 38 | self.actor.set_network(self.network) 39 | 40 | self.total_steps = 0 41 | self.batch_indices = range_tensor(config.batch_size) 42 | 43 | self.quantile_weight = 1.0 / self.config.num_quantiles 44 | self.cumulative_density = tensor( 45 | (2 * np.arange(self.config.num_quantiles) + 1) / (2.0 * self.config.num_quantiles)).view(1, -1) 46 | 47 | def eval_step(self, state): 48 | self.config.state_normalizer.set_read_only() 49 | state = self.config.state_normalizer(state) 50 | q = 
self.network(state)['quantile'].mean(-1) 51 | action = np.argmax(to_np(q).flatten()) 52 | self.config.state_normalizer.unset_read_only() 53 | return [action] 54 | 55 | def compute_loss(self, transitions): 56 | states = self.config.state_normalizer(transitions.state) 57 | next_states = self.config.state_normalizer(transitions.next_state) 58 | 59 | quantiles_next = self.target_network(next_states)['quantile'].detach() 60 | a_next = torch.argmax(quantiles_next.sum(-1), dim=-1) 61 | quantiles_next = quantiles_next[self.batch_indices, a_next, :] 62 | 63 | rewards = tensor(transitions.reward).unsqueeze(-1) 64 | masks = tensor(transitions.mask).unsqueeze(-1) 65 | quantiles_next = rewards + self.config.discount ** self.config.n_step * masks * quantiles_next 66 | 67 | quantiles = self.network(states)['quantile'] 68 | actions = tensor(transitions.action).long() 69 | quantiles = quantiles[self.batch_indices, actions, :] 70 | 71 | quantiles_next = quantiles_next.t().unsqueeze(-1) 72 | diff = quantiles_next - quantiles 73 | loss = huber(diff) * (self.cumulative_density - (diff.detach() < 0).float()).abs() 74 | return loss.sum(-1).mean(1) 75 | 76 | def reduce_loss(self, loss): 77 | return loss.mean() 78 | -------------------------------------------------------------------------------- /deep_rl/agent/TD3_agent.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | 7 | from ..network import * 8 | from ..component import * 9 | from .BaseAgent import * 10 | import torchvision 11 | 12 | 13 | class TD3Agent(BaseAgent): 14 | def __init__(self, config): 15 | BaseAgent.__init__(self, config) 16 | self.config = config 17 | self.task = config.task_fn() 18 | self.network = config.network_fn() 19 | self.target_network = config.network_fn() 20 | self.target_network.load_state_dict(self.network.state_dict()) 21 | self.replay = config.replay_fn() 22 | self.random_process = config.random_process_fn() 23 | self.total_steps = 0 24 | self.state = None 25 | 26 | def soft_update(self, target, src): 27 | for target_param, param in zip(target.parameters(), src.parameters()): 28 | target_param.detach_() 29 | target_param.copy_(target_param * (1.0 - self.config.target_network_mix) + 30 | param * self.config.target_network_mix) 31 | 32 | def eval_step(self, state): 33 | self.config.state_normalizer.set_read_only() 34 | state = self.config.state_normalizer(state) 35 | action = self.network(state) 36 | self.config.state_normalizer.unset_read_only() 37 | return to_np(action) 38 | 39 | def step(self): 40 | config = self.config 41 | if self.state is None: 42 | self.random_process.reset_states() 43 | self.state = self.task.reset() 44 | self.state = config.state_normalizer(self.state) 45 | 46 | if self.total_steps < config.warm_up: 47 | action = [self.task.action_space.sample()] 48 | else: 49 | action = self.network(self.state) 50 | action = to_np(action) 51 | action += self.random_process.sample() 52 | action = np.clip(action, self.task.action_space.low, self.task.action_space.high) 53 | next_state, reward, done, info = self.task.step(action) 54 | next_state = self.config.state_normalizer(next_state) 55 | self.record_online_return(info) 56 | reward = 
self.config.reward_normalizer(reward) 57 | 58 | self.replay.feed(dict( 59 | state=self.state, 60 | action=action, 61 | reward=reward, 62 | next_state=next_state, 63 | mask=1-np.asarray(done, dtype=np.int32), 64 | )) 65 | 66 | if done[0]: 67 | self.random_process.reset_states() 68 | self.state = next_state 69 | self.total_steps += 1 70 | 71 | if self.total_steps >= config.warm_up: 72 | transitions = self.replay.sample() 73 | states = tensor(transitions.state) 74 | actions = tensor(transitions.action) 75 | rewards = tensor(transitions.reward).unsqueeze(-1) 76 | next_states = tensor(transitions.next_state) 77 | mask = tensor(transitions.mask).unsqueeze(-1) 78 | 79 | a_next = self.target_network(next_states) 80 | noise = torch.randn_like(a_next).mul(config.td3_noise) 81 | noise = noise.clamp(-config.td3_noise_clip, config.td3_noise_clip) 82 | 83 | min_a = float(self.task.action_space.low[0]) 84 | max_a = float(self.task.action_space.high[0]) 85 | a_next = (a_next + noise).clamp(min_a, max_a) 86 | 87 | q_1, q_2 = self.target_network.q(next_states, a_next) 88 | target = rewards + config.discount * mask * torch.min(q_1, q_2) 89 | target = target.detach() 90 | 91 | q_1, q_2 = self.network.q(states, actions) 92 | critic_loss = F.mse_loss(q_1, target) + F.mse_loss(q_2, target) 93 | 94 | self.network.zero_grad() 95 | critic_loss.backward() 96 | self.network.critic_opt.step() 97 | 98 | if self.total_steps % config.td3_delay: 99 | action = self.network(states) 100 | policy_loss = -self.network.q(states, action)[0].mean() 101 | 102 | self.network.zero_grad() 103 | policy_loss.backward() 104 | self.network.actor_opt.step() 105 | 106 | self.soft_update(self.target_network, self.network) 107 | -------------------------------------------------------------------------------- /deep_rl/agent/__init__.py: -------------------------------------------------------------------------------- 1 | from .DQN_agent import * 2 | from .DDPG_agent import * 3 | from .A2C_agent import * 4 | from .CategoricalDQN_agent import * 5 | from .NStepDQN_agent import * 6 | from .QuantileRegressionDQN_agent import * 7 | from .PPO_agent import * 8 | from .OptionCritic_agent import * 9 | from .TD3_agent import * 10 | -------------------------------------------------------------------------------- /deep_rl/component/__init__.py: -------------------------------------------------------------------------------- 1 | from .replay import * 2 | from .random_process import * 3 | from .envs import Task 4 | from .envs import LazyFrames 5 | -------------------------------------------------------------------------------- /deep_rl/component/envs.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | 7 | import os 8 | import gym 9 | import numpy as np 10 | import torch 11 | from gym.spaces.box import Box 12 | from gym.spaces.discrete import Discrete 13 | 14 | from baselines.common.atari_wrappers import make_atari, wrap_deepmind 15 | from baselines.common.atari_wrappers import FrameStack as FrameStack_ 16 | from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv, VecEnv 17 | 18 | from ..utils import * 19 | 20 | try: 21 | import roboschool 22 | except ImportError: 23 | pass 24 | 25 | 26 | # 
adapted from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/envs.py 27 | def make_env(env_id, seed, rank, episode_life=True): 28 | def _thunk(): 29 | random_seed(seed) 30 | if env_id.startswith("dm"): 31 | import dm_control2gym 32 | _, domain, task = env_id.split('-') 33 | env = dm_control2gym.make(domain_name=domain, task_name=task) 34 | else: 35 | env = gym.make(env_id) 36 | is_atari = hasattr(gym.envs, 'atari') and isinstance( 37 | env.unwrapped, gym.envs.atari.atari_env.AtariEnv) 38 | if is_atari: 39 | env = make_atari(env_id) 40 | env.seed(seed + rank) 41 | env = OriginalReturnWrapper(env) 42 | if is_atari: 43 | env = wrap_deepmind(env, 44 | episode_life=episode_life, 45 | clip_rewards=False, 46 | frame_stack=False, 47 | scale=False) 48 | obs_shape = env.observation_space.shape 49 | if len(obs_shape) == 3: 50 | env = TransposeImage(env) 51 | env = FrameStack(env, 4) 52 | 53 | return env 54 | 55 | return _thunk 56 | 57 | 58 | class OriginalReturnWrapper(gym.Wrapper): 59 | def __init__(self, env): 60 | gym.Wrapper.__init__(self, env) 61 | self.total_rewards = 0 62 | 63 | def step(self, action): 64 | obs, reward, done, info = self.env.step(action) 65 | self.total_rewards += reward 66 | if done: 67 | info['episodic_return'] = self.total_rewards 68 | self.total_rewards = 0 69 | else: 70 | info['episodic_return'] = None 71 | return obs, reward, done, info 72 | 73 | def reset(self): 74 | return self.env.reset() 75 | 76 | 77 | class TransposeImage(gym.ObservationWrapper): 78 | def __init__(self, env=None): 79 | super(TransposeImage, self).__init__(env) 80 | obs_shape = self.observation_space.shape 81 | self.observation_space = Box( 82 | self.observation_space.low[0, 0, 0], 83 | self.observation_space.high[0, 0, 0], 84 | [obs_shape[2], obs_shape[1], obs_shape[0]], 85 | dtype=self.observation_space.dtype) 86 | 87 | def observation(self, observation): 88 | return observation.transpose(2, 0, 1) 89 | 90 | 91 | # The original LayzeFrames doesn't work well 92 | class LazyFrames(object): 93 | def __init__(self, frames): 94 | """This object ensures that common frames between the observations are only stored once. 95 | It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay 96 | buffers. 97 | 98 | This object should only be converted to numpy array before being passed to the model. 
99 | 100 | You'd not believe how complex the previous solution was.""" 101 | self._frames = frames 102 | 103 | def __array__(self, dtype=None): 104 | out = np.concatenate(self._frames, axis=0) 105 | if dtype is not None: 106 | out = out.astype(dtype) 107 | return out 108 | 109 | def __len__(self): 110 | return len(self.__array__()) 111 | 112 | def __getitem__(self, i): 113 | return self.__array__()[i] 114 | 115 | 116 | class FrameStack(FrameStack_): 117 | def __init__(self, env, k): 118 | FrameStack_.__init__(self, env, k) 119 | 120 | def _get_ob(self): 121 | assert len(self.frames) == self.k 122 | return LazyFrames(list(self.frames)) 123 | 124 | 125 | # The original one in baselines is really bad 126 | class DummyVecEnv(VecEnv): 127 | def __init__(self, env_fns): 128 | self.envs = [fn() for fn in env_fns] 129 | env = self.envs[0] 130 | VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) 131 | self.actions = None 132 | 133 | def step_async(self, actions): 134 | self.actions = actions 135 | 136 | def step_wait(self): 137 | data = [] 138 | for i in range(self.num_envs): 139 | obs, rew, done, info = self.envs[i].step(self.actions[i]) 140 | if done: 141 | obs = self.envs[i].reset() 142 | data.append([obs, rew, done, info]) 143 | obs, rew, done, info = zip(*data) 144 | return obs, np.asarray(rew), np.asarray(done), info 145 | 146 | def reset(self): 147 | return [env.reset() for env in self.envs] 148 | 149 | def close(self): 150 | return 151 | 152 | 153 | class Task: 154 | def __init__(self, 155 | name, 156 | num_envs=1, 157 | single_process=True, 158 | log_dir=None, 159 | episode_life=True, 160 | seed=None): 161 | if seed is None: 162 | seed = np.random.randint(int(1e9)) 163 | if log_dir is not None: 164 | mkdir(log_dir) 165 | envs = [make_env(name, seed, i, episode_life) for i in range(num_envs)] 166 | if single_process: 167 | Wrapper = DummyVecEnv 168 | else: 169 | Wrapper = SubprocVecEnv 170 | self.env = Wrapper(envs) 171 | self.name = name 172 | self.observation_space = self.env.observation_space 173 | self.state_dim = int(np.prod(self.env.observation_space.shape)) 174 | 175 | self.action_space = self.env.action_space 176 | if isinstance(self.action_space, Discrete): 177 | self.action_dim = self.action_space.n 178 | elif isinstance(self.action_space, Box): 179 | self.action_dim = self.action_space.shape[0] 180 | else: 181 | raise RuntimeError('unknown action space') 182 | 183 | def reset(self): 184 | return self.env.reset() 185 | 186 | def step(self, actions): 187 | if isinstance(self.action_space, Box): 188 | actions = np.clip(actions, self.action_space.low, self.action_space.high) 189 | return self.env.step(actions) 190 | 191 | 192 | if __name__ == '__main__': 193 | task = Task('Hopper-v2', 5, single_process=False) 194 | state = task.reset() 195 | while True: 196 | action = np.random.rand(task.env.num_envs, task.action_dim) 197 | next_state, reward, done, _ = task.step(action) 198 | print(done) 199 | -------------------------------------------------------------------------------- /deep_rl/component/random_process.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | 7 | import numpy as np 8 | 9 | 10 | class 
RandomProcess(object): 11 | def reset_states(self): 12 | pass 13 | 14 | 15 | class GaussianProcess(RandomProcess): 16 | def __init__(self, size, std): 17 | self.size = size 18 | self.std = std 19 | 20 | def sample(self): 21 | return np.random.randn(*self.size) * self.std() 22 | 23 | 24 | class OrnsteinUhlenbeckProcess(RandomProcess): 25 | def __init__(self, size, std, theta=.15, dt=1e-2, x0=None): 26 | self.theta = theta 27 | self.mu = 0 28 | self.std = std 29 | self.dt = dt 30 | self.x0 = x0 31 | self.size = size 32 | self.reset_states() 33 | 34 | def sample(self): 35 | x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + self.std() * np.sqrt( 36 | self.dt) * np.random.randn(*self.size) 37 | self.x_prev = x 38 | return x 39 | 40 | def reset_states(self): 41 | self.x_prev = self.x0 if self.x0 is not None else np.zeros(self.size) 42 | -------------------------------------------------------------------------------- /deep_rl/component/replay.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | 7 | import torch 8 | import numpy as np 9 | import torch.multiprocessing as mp 10 | from collections import deque 11 | from ..utils import * 12 | import random 13 | from collections import namedtuple 14 | 15 | Transition = namedtuple('Transition', ['state', 'action', 'reward', 'next_state', 'mask']) 16 | PrioritizedTransition = namedtuple('Transition', 17 | ['state', 'action', 'reward', 'next_state', 'mask', 'sampling_prob', 'idx']) 18 | 19 | 20 | class Storage: 21 | def __init__(self, memory_size, keys=None): 22 | if keys is None: 23 | keys = [] 24 | keys = keys + ['state', 'action', 'reward', 'mask', 25 | 'v', 'q', 'pi', 'log_pi', 'entropy', 26 | 'advantage', 'ret', 'q_a', 'log_pi_a', 27 | 'mean', 'next_state'] 28 | self.keys = keys 29 | self.memory_size = memory_size 30 | self.reset() 31 | 32 | def feed(self, data): 33 | for k, v in data.items(): 34 | if k not in self.keys: 35 | raise RuntimeError('Undefined key') 36 | getattr(self, k).append(v) 37 | 38 | def placeholder(self): 39 | for k in self.keys: 40 | v = getattr(self, k) 41 | if len(v) == 0: 42 | setattr(self, k, [None] * self.memory_size) 43 | 44 | def reset(self): 45 | for key in self.keys: 46 | setattr(self, key, []) 47 | self.pos = 0 48 | self._size = 0 49 | 50 | def extract(self, keys): 51 | data = [getattr(self, k)[:self.memory_size] for k in keys] 52 | data = map(lambda x: torch.cat(x, dim=0), data) 53 | Entry = namedtuple('Entry', keys) 54 | return Entry(*list(data)) 55 | 56 | 57 | class UniformReplay(Storage): 58 | TransitionCLS = Transition 59 | 60 | def __init__(self, memory_size, batch_size, n_step=1, discount=1, history_length=1, keys=None): 61 | super(UniformReplay, self).__init__(memory_size, keys) 62 | self.batch_size = batch_size 63 | self.n_step = n_step 64 | self.discount = discount 65 | self.history_length = history_length 66 | self.pos = 0 67 | self._size = 0 68 | 69 | def compute_valid_indices(self): 70 | indices = [] 71 | indices.extend(list(range(self.history_length - 1, self.pos - self.n_step))) 72 | indices.extend(list(range(self.pos + self.history_length - 1, self.size() - self.n_step))) 73 | return np.asarray(indices) 74 | 75 | def feed(self, 
data): 76 | for k, vs in data.items(): 77 | if k not in self.keys: 78 | raise RuntimeError('Undefined key') 79 | storage = getattr(self, k) 80 | pos = self.pos 81 | size = self.size() 82 | for v in vs: 83 | if pos >= len(storage): 84 | storage.append(v) 85 | size += 1 86 | else: 87 | storage[pos] = v 88 | pos = (pos + 1) % self.memory_size 89 | self.pos = pos 90 | self._size = size 91 | 92 | def sample(self, batch_size=None): 93 | if batch_size is None: 94 | batch_size = self.batch_size 95 | 96 | sampled_data = [] 97 | while len(sampled_data) < batch_size: 98 | transition = self.construct_transition(np.random.randint(0, self.size())) 99 | if transition is not None: 100 | sampled_data.append(transition) 101 | sampled_data = zip(*sampled_data) 102 | sampled_data = list(map(lambda x: np.asarray(x), sampled_data)) 103 | return Transition(*sampled_data) 104 | 105 | def valid_index(self, index): 106 | if index - self.history_length + 1 >= 0 and index + self.n_step < self.pos: 107 | return True 108 | if index - self.history_length + 1 >= self.pos and index + self.n_step < self.size(): 109 | return True 110 | return False 111 | 112 | def construct_transition(self, index): 113 | if not self.valid_index(index): 114 | return None 115 | s_start = index - self.history_length + 1 116 | s_end = index 117 | if s_start < 0: 118 | raise RuntimeError('Invalid index') 119 | next_s_start = s_start + self.n_step 120 | next_s_end = s_end + self.n_step 121 | if s_end < self.pos and next_s_end >= self.pos: 122 | raise RuntimeError('Invalid index') 123 | 124 | state = [self.state[i] for i in range(s_start, s_end + 1)] 125 | next_state = [self.state[i] for i in range(next_s_start, next_s_end + 1)] 126 | action = self.action[s_end] 127 | reward = [self.reward[i] for i in range(s_end, s_end + self.n_step)] 128 | mask = [self.mask[i] for i in range(s_end, s_end + self.n_step)] 129 | if self.history_length == 1: 130 | # eliminate the extra dimension if no frame stack 131 | state = state[0] 132 | next_state = next_state[0] 133 | state = np.array(state) 134 | next_state = np.array(next_state) 135 | cum_r = 0 136 | cum_mask = 1 137 | for i in reversed(np.arange(self.n_step)): 138 | cum_r = reward[i] + mask[i] * self.discount * cum_r 139 | cum_mask = cum_mask and mask[i] 140 | return Transition(state=state, action=action, reward=cum_r, next_state=next_state, mask=cum_mask) 141 | 142 | def size(self): 143 | return self._size 144 | 145 | def full(self): 146 | return self._size == self.memory_size 147 | 148 | def update_priorities(self, info): 149 | raise NotImplementedError 150 | 151 | 152 | class PrioritizedReplay(UniformReplay): 153 | TransitionCLS = PrioritizedTransition 154 | 155 | def __init__(self, memory_size, batch_size, n_step=1, discount=1, history_length=1, keys=None): 156 | super(PrioritizedReplay, self).__init__(memory_size, batch_size, n_step, discount, history_length, keys) 157 | self.tree = SumTree(memory_size) 158 | self.max_priority = 1 159 | 160 | def feed(self, data): 161 | super().feed(data) 162 | self.tree.add(self.max_priority, None) 163 | 164 | def sample(self, batch_size=None): 165 | if batch_size is None: 166 | batch_size = self.batch_size 167 | 168 | segment = self.tree.total() / batch_size 169 | 170 | sampled_data = [] 171 | for i in range(batch_size): 172 | a = segment * i 173 | b = segment * (i + 1) 174 | s = random.uniform(a, b) 175 | (idx, p, data_index) = self.tree.get(s) 176 | transition = super().construct_transition(data_index) 177 | if transition is None: 178 | continue 179 | 
sampled_data.append(PrioritizedTransition( 180 | *transition, 181 | sampling_prob=p / self.tree.total(), 182 | idx=idx, 183 | )) 184 | while len(sampled_data) < batch_size: 185 | # This should rarely happen 186 | sampled_data.append(random.choice(sampled_data)) 187 | 188 | sampled_data = zip(*sampled_data) 189 | sampled_data = list(map(lambda x: np.asarray(x), sampled_data)) 190 | sampled_data = PrioritizedTransition(*sampled_data) 191 | return sampled_data 192 | 193 | def update_priorities(self, info): 194 | for idx, priority in info: 195 | self.max_priority = max(self.max_priority, priority) 196 | self.tree.update(idx, priority) 197 | 198 | 199 | class ReplayWrapper(mp.Process): 200 | FEED = 0 201 | SAMPLE = 1 202 | EXIT = 2 203 | UPDATE_PRIORITIES = 3 204 | 205 | def __init__(self, replay_cls, replay_kwargs, async=True): 206 | mp.Process.__init__(self) 207 | self.replay_kwargs = replay_kwargs 208 | self.replay_cls = replay_cls 209 | self.cache_len = 2 210 | if async: 211 | self.pipe, self.worker_pipe = mp.Pipe() 212 | self.start() 213 | else: 214 | self.replay = replay_cls(**replay_kwargs) 215 | self.sample = self.replay.sample 216 | self.feed = self.replay.feed 217 | self.update_priorities = self.replay.update_priorities 218 | 219 | def run(self): 220 | replay = self.replay_cls(**self.replay_kwargs) 221 | 222 | cache = [] 223 | 224 | cache_initialized = False 225 | cur_cache = 0 226 | 227 | def set_up_cache(): 228 | batch_data = replay.sample() 229 | batch_data = [tensor(x) for x in batch_data] 230 | for i in range(self.cache_len): 231 | cache.append([x.clone() for x in batch_data]) 232 | for x in cache[i]: x.share_memory_() 233 | sample(0) 234 | sample(1) 235 | 236 | def sample(cur_cache): 237 | batch_data = replay.sample() 238 | batch_data = [tensor(x) for x in batch_data] 239 | for cache_x, x in zip(cache[cur_cache], batch_data): 240 | cache_x.copy_(x) 241 | 242 | while True: 243 | op, data = self.worker_pipe.recv() 244 | if op == self.FEED: 245 | replay.feed(data) 246 | elif op == self.SAMPLE: 247 | if cache_initialized: 248 | self.worker_pipe.send([cur_cache, None]) 249 | else: 250 | set_up_cache() 251 | cache_initialized = True 252 | self.worker_pipe.send([cur_cache, cache]) 253 | cur_cache = (cur_cache + 1) % 2 254 | sample(cur_cache) 255 | elif op == self.UPDATE_PRIORITIES: 256 | replay.update_priorities(data) 257 | elif op == self.EXIT: 258 | self.worker_pipe.close() 259 | return 260 | else: 261 | raise Exception('Unknown command') 262 | 263 | def feed(self, exp): 264 | self.pipe.send([self.FEED, exp]) 265 | 266 | def sample(self): 267 | self.pipe.send([self.SAMPLE, None]) 268 | cache_id, data = self.pipe.recv() 269 | if data is not None: 270 | self.cache = data 271 | return self.replay_cls.TransitionCLS(*self.cache[cache_id]) 272 | 273 | def update_priorities(self, info): 274 | self.pipe.send([self.UPDATE_PRIORITIES, info]) 275 | 276 | def close(self): 277 | self.pipe.send([self.EXIT, None]) 278 | self.pipe.close() 279 | -------------------------------------------------------------------------------- /deep_rl/network/__init__.py: -------------------------------------------------------------------------------- 1 | from .network_utils import * 2 | from .network_bodies import * 3 | from .network_heads import * 4 | -------------------------------------------------------------------------------- /deep_rl/network/network_bodies.py: -------------------------------------------------------------------------------- 1 | 
####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | 7 | from .network_utils import * 8 | 9 | 10 | class NatureConvBody(nn.Module): 11 | def __init__(self, in_channels=4, noisy_linear=False): 12 | super(NatureConvBody, self).__init__() 13 | self.feature_dim = 512 14 | self.conv1 = layer_init(nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)) 15 | self.conv2 = layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)) 16 | self.conv3 = layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)) 17 | if noisy_linear: 18 | self.fc4 = NoisyLinear(7 * 7 * 64, self.feature_dim) 19 | else: 20 | self.fc4 = layer_init(nn.Linear(7 * 7 * 64, self.feature_dim)) 21 | self.noisy_linear = noisy_linear 22 | 23 | def reset_noise(self): 24 | if self.noisy_linear: 25 | self.fc4.reset_noise() 26 | 27 | def forward(self, x): 28 | y = F.relu(self.conv1(x)) 29 | y = F.relu(self.conv2(y)) 30 | y = F.relu(self.conv3(y)) 31 | y = y.view(y.size(0), -1) 32 | y = F.relu(self.fc4(y)) 33 | return y 34 | 35 | 36 | class DDPGConvBody(nn.Module): 37 | def __init__(self, in_channels=4): 38 | super(DDPGConvBody, self).__init__() 39 | self.feature_dim = 39 * 39 * 32 40 | self.conv1 = layer_init(nn.Conv2d(in_channels, 32, kernel_size=3, stride=2)) 41 | self.conv2 = layer_init(nn.Conv2d(32, 32, kernel_size=3)) 42 | 43 | def forward(self, x): 44 | y = F.elu(self.conv1(x)) 45 | y = F.elu(self.conv2(y)) 46 | y = y.view(y.size(0), -1) 47 | return y 48 | 49 | 50 | class FCBody(nn.Module): 51 | def __init__(self, state_dim, hidden_units=(64, 64), gate=F.relu, noisy_linear=False): 52 | super(FCBody, self).__init__() 53 | dims = (state_dim,) + hidden_units 54 | if noisy_linear: 55 | self.layers = nn.ModuleList( 56 | [NoisyLinear(dim_in, dim_out) for dim_in, dim_out in zip(dims[:-1], dims[1:])]) 57 | else: 58 | self.layers = nn.ModuleList( 59 | [layer_init(nn.Linear(dim_in, dim_out)) for dim_in, dim_out in zip(dims[:-1], dims[1:])]) 60 | 61 | self.gate = gate 62 | self.feature_dim = dims[-1] 63 | self.noisy_linear = noisy_linear 64 | 65 | def reset_noise(self): 66 | if self.noisy_linear: 67 | for layer in self.layers: 68 | layer.reset_noise() 69 | 70 | def forward(self, x): 71 | for layer in self.layers: 72 | x = self.gate(layer(x)) 73 | return x 74 | 75 | 76 | class DummyBody(nn.Module): 77 | def __init__(self, state_dim): 78 | super(DummyBody, self).__init__() 79 | self.feature_dim = state_dim 80 | 81 | def forward(self, x): 82 | return x 83 | -------------------------------------------------------------------------------- /deep_rl/network/network_heads.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | 7 | from .network_utils import * 8 | from .network_bodies import * 9 | 10 | 11 | class VanillaNet(nn.Module, BaseNet): 12 | def __init__(self, output_dim, body): 13 | super(VanillaNet, self).__init__() 14 | self.fc_head = layer_init(nn.Linear(body.feature_dim, output_dim)) 15 | self.body = body 16 | 
self.to(Config.DEVICE) 17 | 18 | def forward(self, x): 19 | phi = self.body(tensor(x)) 20 | q = self.fc_head(phi) 21 | return dict(q=q) 22 | 23 | 24 | class DuelingNet(nn.Module, BaseNet): 25 | def __init__(self, action_dim, body): 26 | super(DuelingNet, self).__init__() 27 | self.fc_value = layer_init(nn.Linear(body.feature_dim, 1)) 28 | self.fc_advantage = layer_init(nn.Linear(body.feature_dim, action_dim)) 29 | self.body = body 30 | self.to(Config.DEVICE) 31 | 32 | def forward(self, x, to_numpy=False): 33 | phi = self.body(tensor(x)) 34 | value = self.fc_value(phi) 35 | advantage = self.fc_advantage(phi) 36 | q = value.expand_as(advantage) + (advantage - advantage.mean(1, keepdim=True).expand_as(advantage)) 37 | return dict(q=q) 38 | 39 | 40 | class CategoricalNet(nn.Module, BaseNet): 41 | def __init__(self, action_dim, num_atoms, body): 42 | super(CategoricalNet, self).__init__() 43 | self.fc_categorical = layer_init(nn.Linear(body.feature_dim, action_dim * num_atoms)) 44 | self.action_dim = action_dim 45 | self.num_atoms = num_atoms 46 | self.body = body 47 | self.to(Config.DEVICE) 48 | 49 | def forward(self, x): 50 | phi = self.body(tensor(x)) 51 | pre_prob = self.fc_categorical(phi).view((-1, self.action_dim, self.num_atoms)) 52 | prob = F.softmax(pre_prob, dim=-1) 53 | log_prob = F.log_softmax(pre_prob, dim=-1) 54 | return dict(prob=prob, log_prob=log_prob) 55 | 56 | 57 | class RainbowNet(nn.Module, BaseNet): 58 | def __init__(self, action_dim, num_atoms, body, noisy_linear): 59 | super(RainbowNet, self).__init__() 60 | if noisy_linear: 61 | self.fc_value = NoisyLinear(body.feature_dim, num_atoms) 62 | self.fc_advantage = NoisyLinear(body.feature_dim, action_dim * num_atoms) 63 | else: 64 | self.fc_value = layer_init(nn.Linear(body.feature_dim, num_atoms)) 65 | self.fc_advantage = layer_init(nn.Linear(body.feature_dim, action_dim * num_atoms)) 66 | 67 | self.action_dim = action_dim 68 | self.num_atoms = num_atoms 69 | self.body = body 70 | self.noisy_linear = noisy_linear 71 | self.to(Config.DEVICE) 72 | 73 | def reset_noise(self): 74 | if self.noisy_linear: 75 | self.fc_value.reset_noise() 76 | self.fc_advantage.reset_noise() 77 | self.body.reset_noise() 78 | 79 | def forward(self, x): 80 | phi = self.body(tensor(x)) 81 | value = self.fc_value(phi).view((-1, 1, self.num_atoms)) 82 | advantage = self.fc_advantage(phi).view(-1, self.action_dim, self.num_atoms) 83 | q = value + (advantage - advantage.mean(1, keepdim=True)) 84 | prob = F.softmax(q, dim=-1) 85 | log_prob = F.log_softmax(q, dim=-1) 86 | return dict(prob=prob, log_prob=log_prob) 87 | 88 | 89 | class QuantileNet(nn.Module, BaseNet): 90 | def __init__(self, action_dim, num_quantiles, body): 91 | super(QuantileNet, self).__init__() 92 | self.fc_quantiles = layer_init(nn.Linear(body.feature_dim, action_dim * num_quantiles)) 93 | self.action_dim = action_dim 94 | self.num_quantiles = num_quantiles 95 | self.body = body 96 | self.to(Config.DEVICE) 97 | 98 | def forward(self, x): 99 | phi = self.body(tensor(x)) 100 | quantiles = self.fc_quantiles(phi) 101 | quantiles = quantiles.view((-1, self.action_dim, self.num_quantiles)) 102 | return dict(quantile=quantiles) 103 | 104 | 105 | class OptionCriticNet(nn.Module, BaseNet): 106 | def __init__(self, body, action_dim, num_options): 107 | super(OptionCriticNet, self).__init__() 108 | self.fc_q = layer_init(nn.Linear(body.feature_dim, num_options)) 109 | self.fc_pi = layer_init(nn.Linear(body.feature_dim, num_options * action_dim)) 110 | self.fc_beta = 
layer_init(nn.Linear(body.feature_dim, num_options)) 111 | self.num_options = num_options 112 | self.action_dim = action_dim 113 | self.body = body 114 | self.to(Config.DEVICE) 115 | 116 | def forward(self, x): 117 | phi = self.body(tensor(x)) 118 | q = self.fc_q(phi) 119 | beta = F.sigmoid(self.fc_beta(phi)) 120 | pi = self.fc_pi(phi) 121 | pi = pi.view(-1, self.num_options, self.action_dim) 122 | log_pi = F.log_softmax(pi, dim=-1) 123 | pi = F.softmax(pi, dim=-1) 124 | return {'q': q, 125 | 'beta': beta, 126 | 'log_pi': log_pi, 127 | 'pi': pi} 128 | 129 | 130 | class DeterministicActorCriticNet(nn.Module, BaseNet): 131 | def __init__(self, 132 | state_dim, 133 | action_dim, 134 | actor_opt_fn, 135 | critic_opt_fn, 136 | phi_body=None, 137 | actor_body=None, 138 | critic_body=None): 139 | super(DeterministicActorCriticNet, self).__init__() 140 | if phi_body is None: phi_body = DummyBody(state_dim) 141 | if actor_body is None: actor_body = DummyBody(phi_body.feature_dim) 142 | if critic_body is None: critic_body = DummyBody(phi_body.feature_dim) 143 | self.phi_body = phi_body 144 | self.actor_body = actor_body 145 | self.critic_body = critic_body 146 | self.fc_action = layer_init(nn.Linear(actor_body.feature_dim, action_dim), 1e-3) 147 | self.fc_critic = layer_init(nn.Linear(critic_body.feature_dim, 1), 1e-3) 148 | 149 | self.actor_params = list(self.actor_body.parameters()) + list(self.fc_action.parameters()) 150 | self.critic_params = list(self.critic_body.parameters()) + list(self.fc_critic.parameters()) 151 | self.phi_params = list(self.phi_body.parameters()) 152 | 153 | self.actor_opt = actor_opt_fn(self.actor_params + self.phi_params) 154 | self.critic_opt = critic_opt_fn(self.critic_params + self.phi_params) 155 | self.to(Config.DEVICE) 156 | 157 | def forward(self, obs): 158 | phi = self.feature(obs) 159 | action = self.actor(phi) 160 | return action 161 | 162 | def feature(self, obs): 163 | obs = tensor(obs) 164 | return self.phi_body(obs) 165 | 166 | def actor(self, phi): 167 | return torch.tanh(self.fc_action(self.actor_body(phi))) 168 | 169 | def critic(self, phi, a): 170 | return self.fc_critic(self.critic_body(torch.cat([phi, a], dim=1))) 171 | 172 | 173 | class GaussianActorCriticNet(nn.Module, BaseNet): 174 | def __init__(self, 175 | state_dim, 176 | action_dim, 177 | phi_body=None, 178 | actor_body=None, 179 | critic_body=None): 180 | super(GaussianActorCriticNet, self).__init__() 181 | if phi_body is None: phi_body = DummyBody(state_dim) 182 | if actor_body is None: actor_body = DummyBody(phi_body.feature_dim) 183 | if critic_body is None: critic_body = DummyBody(phi_body.feature_dim) 184 | self.phi_body = phi_body 185 | self.actor_body = actor_body 186 | self.critic_body = critic_body 187 | self.fc_action = layer_init(nn.Linear(actor_body.feature_dim, action_dim), 1e-3) 188 | self.fc_critic = layer_init(nn.Linear(critic_body.feature_dim, 1), 1e-3) 189 | self.std = nn.Parameter(torch.zeros(action_dim)) 190 | self.phi_params = list(self.phi_body.parameters()) 191 | 192 | self.actor_params = list(self.actor_body.parameters()) + list(self.fc_action.parameters()) + self.phi_params 193 | self.actor_params.append(self.std) 194 | self.critic_params = list(self.critic_body.parameters()) + list(self.fc_critic.parameters()) + self.phi_params 195 | 196 | self.to(Config.DEVICE) 197 | 198 | def forward(self, obs, action=None): 199 | obs = tensor(obs) 200 | phi = self.phi_body(obs) 201 | phi_a = self.actor_body(phi) 202 | phi_v = self.critic_body(phi) 203 | mean = 
torch.tanh(self.fc_action(phi_a)) 204 | v = self.fc_critic(phi_v) 205 | dist = torch.distributions.Normal(mean, F.softplus(self.std)) 206 | if action is None: 207 | action = dist.sample() 208 | log_prob = dist.log_prob(action).sum(-1).unsqueeze(-1) 209 | entropy = dist.entropy().sum(-1).unsqueeze(-1) 210 | return {'action': action, 211 | 'log_pi_a': log_prob, 212 | 'entropy': entropy, 213 | 'mean': mean, 214 | 'v': v} 215 | 216 | 217 | class CategoricalActorCriticNet(nn.Module, BaseNet): 218 | def __init__(self, 219 | state_dim, 220 | action_dim, 221 | phi_body=None, 222 | actor_body=None, 223 | critic_body=None): 224 | super(CategoricalActorCriticNet, self).__init__() 225 | if phi_body is None: phi_body = DummyBody(state_dim) 226 | if actor_body is None: actor_body = DummyBody(phi_body.feature_dim) 227 | if critic_body is None: critic_body = DummyBody(phi_body.feature_dim) 228 | self.phi_body = phi_body 229 | self.actor_body = actor_body 230 | self.critic_body = critic_body 231 | self.fc_action = layer_init(nn.Linear(actor_body.feature_dim, action_dim), 1e-3) 232 | self.fc_critic = layer_init(nn.Linear(critic_body.feature_dim, 1), 1e-3) 233 | 234 | self.actor_params = list(self.actor_body.parameters()) + list(self.fc_action.parameters()) 235 | self.critic_params = list(self.critic_body.parameters()) + list(self.fc_critic.parameters()) 236 | self.phi_params = list(self.phi_body.parameters()) 237 | 238 | self.to(Config.DEVICE) 239 | 240 | def forward(self, obs, action=None): 241 | obs = tensor(obs) 242 | phi = self.phi_body(obs) 243 | phi_a = self.actor_body(phi) 244 | phi_v = self.critic_body(phi) 245 | logits = self.fc_action(phi_a) 246 | v = self.fc_critic(phi_v) 247 | dist = torch.distributions.Categorical(logits=logits) 248 | if action is None: 249 | action = dist.sample() 250 | log_prob = dist.log_prob(action).unsqueeze(-1) 251 | entropy = dist.entropy().unsqueeze(-1) 252 | return {'action': action, 253 | 'log_pi_a': log_prob, 254 | 'entropy': entropy, 255 | 'v': v} 256 | 257 | 258 | class TD3Net(nn.Module, BaseNet): 259 | def __init__(self, 260 | action_dim, 261 | actor_body_fn, 262 | critic_body_fn, 263 | actor_opt_fn, 264 | critic_opt_fn, 265 | ): 266 | super(TD3Net, self).__init__() 267 | self.actor_body = actor_body_fn() 268 | self.critic_body_1 = critic_body_fn() 269 | self.critic_body_2 = critic_body_fn() 270 | 271 | self.fc_action = layer_init(nn.Linear(self.actor_body.feature_dim, action_dim), 1e-3) 272 | self.fc_critic_1 = layer_init(nn.Linear(self.critic_body_1.feature_dim, 1), 1e-3) 273 | self.fc_critic_2 = layer_init(nn.Linear(self.critic_body_2.feature_dim, 1), 1e-3) 274 | 275 | self.actor_params = list(self.actor_body.parameters()) + list(self.fc_action.parameters()) 276 | self.critic_params = list(self.critic_body_1.parameters()) + list(self.fc_critic_1.parameters()) +\ 277 | list(self.critic_body_2.parameters()) + list(self.fc_critic_2.parameters()) 278 | 279 | self.actor_opt = actor_opt_fn(self.actor_params) 280 | self.critic_opt = critic_opt_fn(self.critic_params) 281 | self.to(Config.DEVICE) 282 | 283 | def forward(self, obs): 284 | obs = tensor(obs) 285 | return torch.tanh(self.fc_action(self.actor_body(obs))) 286 | 287 | def q(self, obs, a): 288 | obs = tensor(obs) 289 | a = tensor(a) 290 | x = torch.cat([obs, a], dim=1) 291 | q_1 = self.fc_critic_1(self.critic_body_1(x)) 292 | q_2 = self.fc_critic_2(self.critic_body_2(x)) 293 | return q_1, q_2 294 | -------------------------------------------------------------------------------- 
/deep_rl/network/network_utils.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import numpy as np 11 | import math 12 | from ..utils import * 13 | 14 | 15 | class BaseNet: 16 | def __init__(self): 17 | pass 18 | 19 | def reset_noise(self): 20 | pass 21 | 22 | 23 | def layer_init(layer, w_scale=1.0): 24 | nn.init.orthogonal_(layer.weight.data) 25 | layer.weight.data.mul_(w_scale) 26 | nn.init.constant_(layer.bias.data, 0) 27 | return layer 28 | 29 | 30 | # Adapted from https://github.com/saj1919/RL-Adventure/blob/master/5.noisy%20dqn.ipynb 31 | class NoisyLinear(nn.Module): 32 | def __init__(self, in_features, out_features, std_init=0.4): 33 | super(NoisyLinear, self).__init__() 34 | 35 | self.in_features = in_features 36 | self.out_features = out_features 37 | self.std_init = std_init 38 | 39 | self.weight_mu = nn.Parameter(torch.zeros((out_features, in_features)), requires_grad=True) 40 | self.weight_sigma = nn.Parameter(torch.zeros((out_features, in_features)), requires_grad=True) 41 | self.register_buffer('weight_epsilon', torch.zeros((out_features, in_features))) 42 | 43 | self.bias_mu = nn.Parameter(torch.zeros(out_features), requires_grad=True) 44 | self.bias_sigma = nn.Parameter(torch.zeros(out_features), requires_grad=True) 45 | self.register_buffer('bias_epsilon', torch.zeros(out_features)) 46 | 47 | self.register_buffer('noise_in', torch.zeros(in_features)) 48 | self.register_buffer('noise_out_weight', torch.zeros(out_features)) 49 | self.register_buffer('noise_out_bias', torch.zeros(out_features)) 50 | 51 | self.reset_parameters() 52 | self.reset_noise() 53 | 54 | def forward(self, x): 55 | if self.training: 56 | weight = self.weight_mu + self.weight_sigma.mul(self.weight_epsilon) 57 | bias = self.bias_mu + self.bias_sigma.mul(self.bias_epsilon) 58 | else: 59 | weight = self.weight_mu 60 | bias = self.bias_mu 61 | 62 | return F.linear(x, weight, bias) 63 | 64 | def reset_parameters(self): 65 | mu_range = 1 / math.sqrt(self.weight_mu.size(1)) 66 | 67 | self.weight_mu.data.uniform_(-mu_range, mu_range) 68 | self.weight_sigma.data.fill_(self.std_init / math.sqrt(self.weight_sigma.size(1))) 69 | 70 | self.bias_mu.data.uniform_(-mu_range, mu_range) 71 | self.bias_sigma.data.fill_(self.std_init / math.sqrt(self.bias_sigma.size(0))) 72 | 73 | def reset_noise(self): 74 | self.noise_in.normal_(std=Config.NOISY_LAYER_STD) 75 | self.noise_out_weight.normal_(std=Config.NOISY_LAYER_STD) 76 | self.noise_out_bias.normal_(std=Config.NOISY_LAYER_STD) 77 | 78 | self.weight_epsilon.copy_(self.transform_noise(self.noise_out_weight).ger( 79 | self.transform_noise(self.noise_in))) 80 | self.bias_epsilon.copy_(self.transform_noise(self.noise_out_bias)) 81 | 82 | def transform_noise(self, x): 83 | return x.sign().mul(x.abs().sqrt()) 84 | -------------------------------------------------------------------------------- /deep_rl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import * 2 | from .normalizer import * 3 | from .misc import * 4 | from .logger import * 5 | from .plot import Plotter 6 | 
from .schedule import * 7 | from .torch_utils import * 8 | from .sum_tree import * 9 | -------------------------------------------------------------------------------- /deep_rl/utils/config.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | from .normalizer import * 7 | import argparse 8 | import torch 9 | 10 | 11 | class Config: 12 | DEVICE = torch.device('cpu') 13 | NOISY_LAYER_STD = 0.1 14 | DEFAULT_REPLAY = 'replay' 15 | PRIORITIZED_REPLAY = 'prioritized_replay' 16 | 17 | def __init__(self): 18 | self.parser = argparse.ArgumentParser() 19 | self.task_fn = None 20 | self.optimizer_fn = None 21 | self.actor_optimizer_fn = None 22 | self.critic_optimizer_fn = None 23 | self.network_fn = None 24 | self.actor_network_fn = None 25 | self.critic_network_fn = None 26 | self.replay_fn = None 27 | self.random_process_fn = None 28 | self.discount = None 29 | self.target_network_update_freq = None 30 | self.exploration_steps = None 31 | self.log_level = 0 32 | self.history_length = None 33 | self.double_q = False 34 | self.tag = 'vanilla' 35 | self.num_workers = 1 36 | self.gradient_clip = None 37 | self.entropy_weight = 0 38 | self.use_gae = False 39 | self.gae_tau = 1.0 40 | self.target_network_mix = 0.001 41 | self.state_normalizer = RescaleNormalizer() 42 | self.reward_normalizer = RescaleNormalizer() 43 | self.min_memory_size = None 44 | self.max_steps = 0 45 | self.rollout_length = None 46 | self.value_loss_weight = 1.0 47 | self.iteration_log_interval = 30 48 | self.categorical_v_min = None 49 | self.categorical_v_max = None 50 | self.categorical_n_atoms = 51 51 | self.num_quantiles = None 52 | self.optimization_epochs = 4 53 | self.mini_batch_size = 64 54 | self.termination_regularizer = 0 55 | self.sgd_update_frequency = None 56 | self.random_action_prob = None 57 | self.__eval_env = None 58 | self.log_interval = int(1e3) 59 | self.save_interval = 0 60 | self.eval_interval = 0 61 | self.eval_episodes = 10 62 | self.async_actor = True 63 | self.tasks = False 64 | self.replay_type = Config.DEFAULT_REPLAY 65 | self.decaying_lr = False 66 | self.shared_repr = False 67 | self.noisy_linear = False 68 | self.n_step = 1 69 | 70 | @property 71 | def eval_env(self): 72 | return self.__eval_env 73 | 74 | @eval_env.setter 75 | def eval_env(self, env): 76 | self.__eval_env = env 77 | self.state_dim = env.state_dim 78 | self.action_dim = env.action_dim 79 | self.task_name = env.name 80 | 81 | def add_argument(self, *args, **kwargs): 82 | self.parser.add_argument(*args, **kwargs) 83 | 84 | def merge(self, config_dict=None): 85 | if config_dict is None: 86 | args = self.parser.parse_args() 87 | config_dict = args.__dict__ 88 | for key in config_dict.keys(): 89 | setattr(self, key, config_dict[key]) 90 | -------------------------------------------------------------------------------- /deep_rl/utils/logger.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | 
####################################################################### 6 | 7 | from torch.utils.tensorboard import SummaryWriter 8 | import os 9 | import numpy as np 10 | import torch 11 | import logging 12 | 13 | logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s: %(message)s') 14 | from .misc import * 15 | 16 | 17 | def get_logger(tag='default', log_level=0): 18 | logger = logging.getLogger() 19 | logger.setLevel(logging.INFO) 20 | if tag is not None: 21 | fh = logging.FileHandler('./log/%s-%s.txt' % (tag, get_time_str())) 22 | fh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s: %(message)s')) 23 | fh.setLevel(logging.INFO) 24 | logger.addHandler(fh) 25 | return Logger(logger, './tf_log/logger-%s-%s' % (tag, get_time_str()), log_level) 26 | 27 | 28 | class Logger(object): 29 | def __init__(self, vanilla_logger, log_dir, log_level=0): 30 | self.log_level = log_level 31 | self.writer = None 32 | if vanilla_logger is not None: 33 | self.info = vanilla_logger.info 34 | self.debug = vanilla_logger.debug 35 | self.warning = vanilla_logger.warning 36 | self.all_steps = {} 37 | self.log_dir = log_dir 38 | 39 | def lazy_init_writer(self): 40 | if self.writer is None: 41 | self.writer = SummaryWriter(self.log_dir) 42 | 43 | def to_numpy(self, v): 44 | if isinstance(v, torch.Tensor): 45 | v = v.cpu().detach().numpy() 46 | return v 47 | 48 | def get_step(self, tag): 49 | if tag not in self.all_steps: 50 | self.all_steps[tag] = 0 51 | step = self.all_steps[tag] 52 | self.all_steps[tag] += 1 53 | return step 54 | 55 | def add_scalar(self, tag, value, step=None, log_level=0): 56 | self.lazy_init_writer() 57 | if log_level > self.log_level: 58 | return 59 | value = self.to_numpy(value) 60 | if step is None: 61 | step = self.get_step(tag) 62 | if np.isscalar(value): 63 | value = np.asarray([value]) 64 | self.writer.add_scalar(tag, value, step) 65 | 66 | def add_histogram(self, tag, values, step=None, log_level=0): 67 | self.lazy_init_writer() 68 | if log_level > self.log_level: 69 | return 70 | values = self.to_numpy(values) 71 | if step is None: 72 | step = self.get_step(tag) 73 | self.writer.add_histogram(tag, values, step) 74 | -------------------------------------------------------------------------------- /deep_rl/utils/misc.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | 7 | import numpy as np 8 | import pickle 9 | import os 10 | import datetime 11 | import torch 12 | import time 13 | from .torch_utils import * 14 | from pathlib import Path 15 | import itertools 16 | from collections import OrderedDict, Sequence 17 | 18 | 19 | def run_steps(agent): 20 | config = agent.config 21 | agent_name = agent.__class__.__name__ 22 | t0 = time.time() 23 | while True: 24 | if config.save_interval and not agent.total_steps % config.save_interval: 25 | agent.save('data/%s-%s-%d' % (agent_name, config.tag, agent.total_steps)) 26 | if config.log_interval and not agent.total_steps % config.log_interval: 27 | agent.logger.info('steps %d, %.2f steps/s' % (agent.total_steps, config.log_interval / (time.time() - t0))) 28 | t0 = time.time() 29 | if config.eval_interval and not agent.total_steps % config.eval_interval: 30 | 
agent.eval_episodes() 31 | if config.max_steps and agent.total_steps >= config.max_steps: 32 | agent.close() 33 | break 34 | agent.step() 35 | agent.switch_task() 36 | 37 | 38 | def get_time_str(): 39 | return datetime.datetime.now().strftime("%y%m%d-%H%M%S") 40 | 41 | 42 | def get_default_log_dir(name): 43 | return './log/%s-%s' % (name, get_time_str()) 44 | 45 | 46 | def mkdir(path): 47 | Path(path).mkdir(parents=True, exist_ok=True) 48 | 49 | 50 | def close_obj(obj): 51 | if hasattr(obj, 'close'): 52 | obj.close() 53 | 54 | 55 | def random_sample(indices, batch_size): 56 | indices = np.asarray(np.random.permutation(indices)) 57 | batches = indices[:len(indices) // batch_size * batch_size].reshape(-1, batch_size) 58 | for batch in batches: 59 | yield batch 60 | r = len(indices) % batch_size 61 | if r: 62 | yield indices[-r:] 63 | 64 | 65 | def is_plain_type(x): 66 | for t in [str, int, float, bool]: 67 | if isinstance(x, t): 68 | return True 69 | return False 70 | 71 | 72 | def generate_tag(params): 73 | if 'tag' in params.keys(): 74 | return 75 | game = params['game'] 76 | params.setdefault('run', 0) 77 | run = params['run'] 78 | del params['game'] 79 | del params['run'] 80 | str = ['%s_%s' % (k, v if is_plain_type(v) else v.__name__) for k, v in sorted(params.items())] 81 | tag = '%s-%s-run-%d' % (game, '-'.join(str), run) 82 | params['tag'] = tag 83 | params['game'] = game 84 | params['run'] = run 85 | 86 | 87 | def translate(pattern): 88 | groups = pattern.split('.') 89 | pattern = ('\.').join(groups) 90 | return pattern 91 | 92 | 93 | def split(a, n): 94 | k, m = divmod(len(a), n) 95 | return (a[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n)) 96 | 97 | 98 | class HyperParameter: 99 | def __init__(self, id, param): 100 | self.id = id 101 | self.param = dict() 102 | for key, item in param: 103 | self.param[key] = item 104 | 105 | def __str__(self): 106 | return str(self.id) 107 | 108 | def dict(self): 109 | return self.param 110 | 111 | 112 | class HyperParameters(Sequence): 113 | def __init__(self, ordered_params): 114 | if not isinstance(ordered_params, OrderedDict): 115 | raise NotImplementedError 116 | params = [] 117 | for key in ordered_params.keys(): 118 | param = [[key, iterm] for iterm in ordered_params[key]] 119 | params.append(param) 120 | self.params = list(itertools.product(*params)) 121 | 122 | def __getitem__(self, index): 123 | return HyperParameter(index, self.params[index]) 124 | 125 | def __len__(self): 126 | return len(self.params) -------------------------------------------------------------------------------- /deep_rl/utils/normalizer.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | import numpy as np 7 | import torch 8 | from baselines.common.running_mean_std import RunningMeanStd 9 | 10 | 11 | class BaseNormalizer: 12 | def __init__(self, read_only=False): 13 | self.read_only = read_only 14 | 15 | def set_read_only(self): 16 | self.read_only = True 17 | 18 | def unset_read_only(self): 19 | self.read_only = False 20 | 21 | def state_dict(self): 22 | return None 23 | 24 | def load_state_dict(self, _): 25 | return 26 | 27 | 28 | class MeanStdNormalizer(BaseNormalizer): 29 | def 
__init__(self, read_only=False, clip=10.0, epsilon=1e-8): 30 | BaseNormalizer.__init__(self, read_only) 31 | self.read_only = read_only 32 | self.rms = None 33 | self.clip = clip 34 | self.epsilon = epsilon 35 | 36 | def __call__(self, x): 37 | x = np.asarray(x) 38 | if self.rms is None: 39 | self.rms = RunningMeanStd(shape=(1,) + x.shape[1:]) 40 | if not self.read_only: 41 | self.rms.update(x) 42 | return np.clip((x - self.rms.mean) / np.sqrt(self.rms.var + self.epsilon), 43 | -self.clip, self.clip) 44 | 45 | def state_dict(self): 46 | return {'mean': self.rms.mean, 47 | 'var': self.rms.var} 48 | 49 | def load_state_dict(self, saved): 50 | self.rms.mean = saved['mean'] 51 | self.rms.var = saved['var'] 52 | 53 | class RescaleNormalizer(BaseNormalizer): 54 | def __init__(self, coef=1.0): 55 | BaseNormalizer.__init__(self) 56 | self.coef = coef 57 | 58 | def __call__(self, x): 59 | if not isinstance(x, torch.Tensor): 60 | x = np.asarray(x) 61 | return self.coef * x 62 | 63 | 64 | class ImageNormalizer(RescaleNormalizer): 65 | def __init__(self): 66 | RescaleNormalizer.__init__(self, 1.0 / 255) 67 | 68 | 69 | class SignNormalizer(BaseNormalizer): 70 | def __call__(self, x): 71 | return np.sign(x) 72 | -------------------------------------------------------------------------------- /deep_rl/utils/plot.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | 7 | import numpy as np 8 | import os 9 | import re 10 | 11 | 12 | class Plotter: 13 | COLORS = ['blue', 'green', 'red', 'black', 'cyan', 'magenta', 'yellow', 'brown', 'purple', 'pink', 14 | 'orange', 'teal', 'coral', 'lightblue', 'lime', 'lavender', 'turquoise', 15 | 'darkgreen', 'tan', 'salmon', 'gold', 'lightpurple', 'darkred', 'darkblue'] 16 | 17 | RETURN_TRAIN = 'episodic_return_train' 18 | RETURN_TEST = 'episodic_return_test' 19 | 20 | def __init__(self): 21 | pass 22 | 23 | def _rolling_window(self, a, window): 24 | shape = a.shape[:-1] + (a.shape[-1] - window + 1, window) 25 | strides = a.strides + (a.strides[-1],) 26 | return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides) 27 | 28 | def _window_func(self, x, y, window, func): 29 | yw = self._rolling_window(y, window) 30 | yw_func = func(yw, axis=-1) 31 | return x[window - 1:], yw_func 32 | 33 | def load_results(self, dirs, **kwargs): 34 | kwargs.setdefault('tag', self.RETURN_TRAIN) 35 | kwargs.setdefault('right_align', False) 36 | kwargs.setdefault('window', 0) 37 | kwargs.setdefault('top_k', 0) 38 | kwargs.setdefault('top_k_measure', None) 39 | kwargs.setdefault('interpolation', 100) 40 | xy_list = self.load_log_dirs(dirs, **kwargs) 41 | 42 | if kwargs['top_k']: 43 | perf = [kwargs['top_k_measure'](y) for _, y in xy_list] 44 | top_k_runs = np.argsort(perf)[-kwargs['top_k']:] 45 | new_xy_list = [] 46 | for r, (x, y) in enumerate(xy_list): 47 | if r in top_k_runs: 48 | new_xy_list.append((x, y)) 49 | xy_list = new_xy_list 50 | 51 | if kwargs['interpolation']: 52 | x_right = float('inf') 53 | for x, y in xy_list: 54 | x_right = min(x_right, x[-1]) 55 | x = np.arange(0, x_right, kwargs['interpolation']) 56 | y = [] 57 | for x_, y_ in xy_list: 58 | y.append(np.interp(x, x_, y_)) 59 | y = np.asarray(y) 60 | else: 61 | x 
= xy_list[0][0] 62 | y = [y for _, y in xy_list] 63 | x = np.asarray(x) 64 | y = np.asarray(y) 65 | 66 | return x, y 67 | 68 | def filter_log_dirs(self, pattern, negative_pattern=' ', root='./log', **kwargs): 69 | dirs = [item[0] for item in os.walk(root)] 70 | leaf_dirs = [] 71 | for i in range(len(dirs)): 72 | if i + 1 < len(dirs) and dirs[i + 1].startswith(dirs[i]): 73 | continue 74 | leaf_dirs.append(dirs[i]) 75 | names = [] 76 | p = re.compile(pattern) 77 | neg_p = re.compile(negative_pattern) 78 | for dir in leaf_dirs: 79 | if p.match(dir) and not neg_p.match(dir): 80 | names.append(dir) 81 | print(dir) 82 | print('') 83 | return sorted(names) 84 | 85 | def load_log_dirs(self, dirs, **kwargs): 86 | kwargs.setdefault('right_align', False) 87 | kwargs.setdefault('window', 0) 88 | kwargs.setdefault('right_most', 0) 89 | xy_list = [] 90 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator 91 | for dir in dirs: 92 | event_acc = EventAccumulator(dir) 93 | event_acc.Reload() 94 | _, x, y = zip(*event_acc.Scalars(kwargs['tag'])) 95 | xy_list.append([x, y]) 96 | if kwargs['right_align']: 97 | x_max = float('inf') 98 | for x, y in xy_list: 99 | x_max = min(x_max, len(y)) 100 | xy_list = [[x[:x_max], y[:x_max]] for x, y in xy_list] 101 | x_max = kwargs['right_most'] 102 | if x_max: 103 | xy_list = [[x[:x_max], y[:x_max]] for x, y in xy_list] 104 | if kwargs['window']: 105 | xy_list = [self._window_func(np.asarray(x), np.asarray(y), kwargs['window'], np.mean) for x, y in xy_list] 106 | return xy_list 107 | 108 | def plot_mean(self, data, x=None, **kwargs): 109 | import matplotlib.pyplot as plt 110 | if x is None: 111 | x = np.arange(data.shape[1]) 112 | if kwargs['error'] == 'se': 113 | e_x = np.std(data, axis=0) / np.sqrt(data.shape[0]) 114 | elif kwargs['error'] == 'std': 115 | e_x = np.std(data, axis=0) 116 | else: 117 | raise NotImplementedError 118 | m_x = np.mean(data, axis=0) 119 | del kwargs['error'] 120 | plt.plot(x, m_x, **kwargs) 121 | del kwargs['label'] 122 | plt.fill_between(x, m_x + e_x, m_x - e_x, alpha=0.3, **kwargs) 123 | 124 | def plot_median_std(self, data, x=None, **kwargs): 125 | import matplotlib.pyplot as plt 126 | if x is None: 127 | x = np.arange(data.shape[1]) 128 | e_x = np.std(data, axis=0) 129 | m_x = np.median(data, axis=0) 130 | plt.plot(x, m_x, **kwargs) 131 | del kwargs['label'] 132 | plt.fill_between(x, m_x + e_x, m_x - e_x, alpha=0.3, **kwargs) 133 | 134 | def plot_games(self, games, **kwargs): 135 | kwargs.setdefault('agg', 'mean') 136 | import matplotlib.pyplot as plt 137 | l = len(games) 138 | plt.figure(figsize=(l * 5, 5)) 139 | for i, game in enumerate(games): 140 | plt.subplot(1, l, i + 1) 141 | for j, p in enumerate(kwargs['patterns']): 142 | label = kwargs['labels'][j] 143 | color = self.COLORS[j] 144 | log_dirs = self.filter_log_dirs(pattern='.*%s.*%s' % (game, p), **kwargs) 145 | x, y = self.load_results(log_dirs, **kwargs) 146 | if kwargs['downsample']: 147 | indices = np.linspace(0, len(x) - 1, kwargs['downsample']).astype(int) 148 | x = x[indices] 149 | y = y[:, indices] 150 | if kwargs['agg'] == 'mean': 151 | self.plot_mean(y, x, label=label, color=color, error='se') 152 | elif kwargs['agg'] == 'mean_std': 153 | self.plot_mean(y, x, label=label, color=color, error='std') 154 | elif kwargs['agg'] == 'median': 155 | self.plot_median_std(y, x, label=label, color=color) 156 | else: 157 | for k in range(y.shape[0]): 158 | plt.plot(x, y[k], label=label, color=color) 159 | label = None 160 | plt.xlabel('steps') 161 | if not 
i: 162 | plt.ylabel(kwargs['tag']) 163 | plt.title(game) 164 | plt.legend() 165 | 166 | def select_best_parameters(self, patterns, **kwargs): 167 | scores = [] 168 | for pattern in patterns: 169 | log_dirs = self.filter_log_dirs(pattern, **kwargs) 170 | xy_list = self.load_log_dirs(log_dirs, **kwargs) 171 | y = np.asarray([xy[1] for xy in xy_list]) 172 | scores.append(kwargs['score'](y)) 173 | indices = np.argsort(-np.asarray(scores)) 174 | return indices 175 | 176 | 177 | def reduce_dir(self, root, tag, ids, score_fn): 178 | tf_log_info = {} 179 | for dir, _, files in os.walk(root): 180 | for file in files: 181 | if 'tfevents' in file: 182 | dir = os.path.basename(dir) 183 | dir = re.sub(r'hp_\d+', 'placeholder', dir) 184 | dir = re.sub(r'run.*', 'run', dir) 185 | tf_log_info[dir] = {} 186 | for key in tf_log_info.keys(): 187 | scores = [] 188 | for id in ids: 189 | dir = key.replace('placeholder', 'hp_%s' % (id)) 190 | names = self.filter_log_dirs('.*%s.*' % (dir), root=root) 191 | xy_list = self.load_log_dirs(names, tag=tag, right_align=True) 192 | scores.append(score_fn(np.asarray([y for x, y in xy_list]))) 193 | best = np.nanargmax(scores) 194 | tf_log_info[key]['hp'] = ids[best] 195 | tf_log_info[key]['score'] = scores[best] 196 | return tf_log_info 197 | 198 | 199 | def reduce_patterns(self, patterns, root, tag, ids, score_fn): 200 | new_patterns = [] 201 | best_ids = [] 202 | for pattern in patterns: 203 | scores = [] 204 | pattern = re.sub(r'hp_\d+', 'placeholder', pattern) 205 | ps = [] 206 | for id in ids: 207 | p = pattern.replace('placeholder', 'hp_%s' % (id)) 208 | ps.append(p) 209 | names = self.filter_log_dirs('.*%s.*' % (p), root=root) 210 | xy_list = self.load_log_dirs(names, tag=tag, right_align=True) 211 | scores.append(score_fn(np.asarray([y for x, y in xy_list]))) 212 | try: 213 | best = np.nanargmax(scores) 214 | except ValueError as e: 215 | print(e) 216 | best = 0 217 | best_ids.append(best) 218 | new_patterns.append(ps[best]) 219 | return dict(patterns=new_patterns, ids=best_ids) 220 | 221 | -------------------------------------------------------------------------------- /deep_rl/utils/schedule.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | 7 | class ConstantSchedule: 8 | def __init__(self, val): 9 | self.val = val 10 | 11 | def __call__(self, steps=1): 12 | return self.val 13 | 14 | 15 | class LinearSchedule: 16 | def __init__(self, start, end=None, steps=None): 17 | if end is None: 18 | end = start 19 | steps = 1 20 | self.inc = (end - start) / float(steps) 21 | self.current = start 22 | self.end = end 23 | if end > start: 24 | self.bound = min 25 | else: 26 | self.bound = max 27 | 28 | def __call__(self, steps=1): 29 | val = self.current 30 | self.current = self.bound(self.current + self.inc * steps, self.end) 31 | return val 32 | -------------------------------------------------------------------------------- /deep_rl/utils/sum_tree.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/rlcode/per/blob/master/SumTree.py 2 | 3 | import numpy 4 | # SumTree 5 | # a binary tree data structure where the parent’s value is the sum of its 
children 6 | class SumTree: 7 | write = 0 8 | def __init__(self, capacity): 9 | self.capacity = capacity 10 | self.tree = numpy.zeros(2 * capacity - 1) 11 | self.data = numpy.zeros(capacity, dtype=object) 12 | self.n_entries = 0 13 | self.pending_idx = set() 14 | 15 | # update to the root node 16 | def _propagate(self, idx, change): 17 | parent = (idx - 1) // 2 18 | self.tree[parent] += change 19 | if parent != 0: 20 | self._propagate(parent, change) 21 | 22 | # find sample on leaf node 23 | def _retrieve(self, idx, s): 24 | left = 2 * idx + 1 25 | right = left + 1 26 | 27 | if left >= len(self.tree): 28 | return idx 29 | 30 | if s <= self.tree[left]: 31 | return self._retrieve(left, s) 32 | else: 33 | return self._retrieve(right, s - self.tree[left]) 34 | 35 | def total(self): 36 | return self.tree[0] 37 | 38 | # store priority and sample 39 | def add(self, p, data): 40 | idx = self.write + self.capacity - 1 41 | self.pending_idx.add(idx) 42 | 43 | self.data[self.write] = data 44 | self.update(idx, p) 45 | 46 | self.write += 1 47 | if self.write >= self.capacity: 48 | self.write = 0 49 | 50 | if self.n_entries < self.capacity: 51 | self.n_entries += 1 52 | 53 | # update priority 54 | def update(self, idx, p): 55 | if idx not in self.pending_idx: 56 | return 57 | self.pending_idx.remove(idx) 58 | change = p - self.tree[idx] 59 | self.tree[idx] = p 60 | self._propagate(idx, change) 61 | 62 | # get priority and sample 63 | def get(self, s): 64 | idx = self._retrieve(0, s) 65 | dataIdx = idx - self.capacity + 1 66 | self.pending_idx.add(idx) 67 | return (idx, self.tree[idx], dataIdx) -------------------------------------------------------------------------------- /deep_rl/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | 7 | from .config import * 8 | import torch 9 | import os 10 | 11 | 12 | def select_device(gpu_id): 13 | # if torch.cuda.is_available() and gpu_id >= 0: 14 | if gpu_id >= 0: 15 | Config.DEVICE = torch.device('cuda:%d' % (gpu_id)) 16 | else: 17 | Config.DEVICE = torch.device('cpu') 18 | 19 | 20 | def tensor(x): 21 | if isinstance(x, torch.Tensor): 22 | return x 23 | x = np.asarray(x, dtype=np.float32) 24 | x = torch.from_numpy(x).to(Config.DEVICE) 25 | return x 26 | 27 | 28 | def range_tensor(end): 29 | return torch.arange(end).long().to(Config.DEVICE) 30 | 31 | 32 | def to_np(t): 33 | return t.cpu().detach().numpy() 34 | 35 | 36 | def random_seed(seed=None): 37 | np.random.seed(seed) 38 | torch.manual_seed(np.random.randint(int(1e6))) 39 | 40 | 41 | def set_one_thread(): 42 | os.environ['OMP_NUM_THREADS'] = '1' 43 | os.environ['MKL_NUM_THREADS'] = '1' 44 | torch.set_num_threads(1) 45 | 46 | 47 | def huber(x, k=1.0): 48 | return torch.where(x.abs() < k, 0.5 * x.pow(2), k * (x.abs() - 0.5 * k)) 49 | 50 | 51 | def epsilon_greedy(epsilon, x): 52 | if len(x.shape) == 1: 53 | return np.random.randint(len(x)) if np.random.rand() < epsilon else np.argmax(x) 54 | elif len(x.shape) == 2: 55 | random_actions = np.random.randint(x.shape[1], size=x.shape[0]) 56 | greedy_actions = np.argmax(x, axis=-1) 57 | dice = np.random.rand(x.shape[0]) 58 | return np.where(dice < epsilon, random_actions, greedy_actions) 59 | 60 
| 61 | def sync_grad(target_network, src_network): 62 | for param, src_param in zip(target_network.parameters(), src_network.parameters()): 63 | if src_param.grad is not None: 64 | param._grad = src_param.grad.clone() 65 | 66 | 67 | # adapted from https://github.com/pytorch/pytorch/issues/12160 68 | def batch_diagonal(input): 69 | # idea from here: https://discuss.pytorch.org/t/batch-of-diagonal-matrix/13560 70 | # batches a stack of vectors (batch x N) -> a stack of diagonal matrices (batch x N x N) 71 | # works in 2D -> 3D, should also work in higher dimensions 72 | # make a zero matrix, which duplicates the last dim of input 73 | dims = input.size() 74 | dims = dims + dims[-1:] 75 | output = torch.zeros(dims, device=input.device) 76 | # stride across the first dimensions, add one to get the diagonal of the last dimension 77 | strides = [output.stride(i) for i in range(input.dim() - 1)] 78 | strides.append(output.size(-1) + 1) 79 | # stride and copy the input to the diagonal 80 | output.as_strided(input.size(), strides).copy_(input) 81 | return output 82 | 83 | 84 | def batch_trace(input): 85 | i = range_tensor(input.size(-1)) 86 | t = input[:, i, i].sum(-1).unsqueeze(-1).unsqueeze(-1) 87 | return t 88 | 89 | 90 | class DiagonalNormal: 91 | def __init__(self, mean, std): 92 | self.dist = torch.distributions.Normal(mean, std) 93 | self.sample = self.dist.sample 94 | 95 | def log_prob(self, action): 96 | return self.dist.log_prob(action).sum(-1).unsqueeze(-1) 97 | 98 | def entropy(self): 99 | return self.dist.entropy().sum(-1).unsqueeze(-1) 100 | 101 | def cdf(self, action): 102 | return self.dist.cdf(action).prod(-1).unsqueeze(-1) 103 | 104 | 105 | class BatchCategorical: 106 | def __init__(self, logits): 107 | self.pre_shape = logits.size()[:-1] 108 | logits = logits.view(-1, logits.size(-1)) 109 | self.dist = torch.distributions.Categorical(logits=logits) 110 | 111 | def log_prob(self, action): 112 | log_pi = self.dist.log_prob(action.view(-1)) 113 | log_pi = log_pi.view(action.size()[:-1] + (-1,)) 114 | return log_pi 115 | 116 | def entropy(self): 117 | ent = self.dist.entropy() 118 | ent = ent.view(self.pre_shape + (-1,)) 119 | return ent 120 | 121 | def sample(self, sample_shape=torch.Size([])): 122 | ret = self.dist.sample(sample_shape) 123 | ret = ret.view(sample_shape + self.pre_shape + (-1,)) 124 | return ret 125 | 126 | 127 | class Grad: 128 | def __init__(self, network=None, grads=None): 129 | if grads is not None: 130 | self.grads = grads 131 | else: 132 | self.grads = [] 133 | for param in network.parameters(): 134 | self.grads.append(torch.zeros(param.data.size(), device=Config.DEVICE)) 135 | 136 | def add(self, op): 137 | if isinstance(op, Grad): 138 | for grad, op_grad in zip(self.grads, op.grads): 139 | grad.add_(op_grad) 140 | elif isinstance(op, torch.nn.Module): 141 | for grad, param in zip(self.grads, op.parameters()): 142 | if param.grad is not None: 143 | grad.add_(param.grad) 144 | return self 145 | 146 | def mul(self, coef): 147 | for grad in self.grads: 148 | grad.mul_(coef) 149 | return self 150 | 151 | def assign(self, network): 152 | for grad, param in zip(self.grads, network.parameters()): 153 | param._grad = grad.clone() 154 | 155 | def zero(self): 156 | for grad in self.grads: 157 | grad.zero_() 158 | 159 | def clone(self): 160 | return Grad(grads=[grad.clone() for grad in self.grads]) 161 | 162 | 163 | class Grads: 164 | def __init__(self, network=None, n=0, grads=None): 165 | if grads is not None: 166 | self.grads = grads 167 | else: 168 | self.grads = 
[Grad(network) for _ in range(n)] 169 | 170 | def clone(self): 171 | return Grads(grads=[grad.clone() for grad in self.grads]) 172 | 173 | def mul(self, op): 174 | if np.isscalar(op): 175 | for grad in self.grads: 176 | grad.mul(op) 177 | elif isinstance(op, torch.Tensor): 178 | op = op.view(-1) 179 | for i, grad in enumerate(self.grads): 180 | grad.mul(op[i]) 181 | else: 182 | raise NotImplementedError 183 | return self 184 | 185 | def add(self, op): 186 | if np.isscalar(op): 187 | for grad in self.grads: 188 | grad.mul(op) 189 | elif isinstance(op, Grads): 190 | for grad, op_grad in zip(self.grads, op.grads): 191 | grad.add(op_grad) 192 | elif isinstance(op, torch.Tensor): 193 | op = op.view(-1) 194 | for i, grad in enumerate(self.grads): 195 | grad.mul(op[i]) 196 | else: 197 | raise NotImplementedError 198 | return self 199 | 200 | def mean(self): 201 | grad = self.grads[0].clone() 202 | grad.zero() 203 | for g in self.grads: 204 | grad.add(g) 205 | grad.mul(1 / len(self.grads)) 206 | return grad 207 | 208 | 209 | def escape_float(x): 210 | return ('%s' % x).replace('.', '\.') -------------------------------------------------------------------------------- /docker_batch.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | GPUs=(0 1 2 3 4 5 6 7) 3 | 4 | #for i in $(seq 0 7); do 5 | # for j in $(seq 0 0); do 6 | # nohup bash docker_python.sh ${GPUs[$i]} "template_jobs.py --i $i --j $j" >| job_${i}_${j}.out & 7 | # done 8 | #done 9 | 10 | 11 | rm -f jobs.txt 12 | touch jobs.txt 13 | for i in $(seq 0 100); do 14 | echo "$i" >> jobs.txt 15 | done 16 | cat jobs.txt | xargs -n 1 -P 40 sh -c 'bash docker_python.sh 0 "template_jobs.py --i $0"' 17 | rm -f jobs.txt 18 | 19 | 20 | #rm -f jobs.txt 21 | #touch jobs.txt 22 | #for i in $(seq 0 6); do 23 | # echo "$i ${GPUs[$(($i % 8))]}" >> jobs.txt 24 | #done 25 | #cat jobs.txt | xargs -n 2 -P 40 sh -c 'bash docker_python.sh $1 "template_jobs.py --i $0"' 26 | #rm -f jobs.txt 27 | -------------------------------------------------------------------------------- /docker_build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | docker build --build-arg UID=$UID -t deep_rl:v1.5 . 
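The Grad and Grads helpers in deep_rl/utils/torch_utils.py above maintain per-sample gradient buffers outside the usual .grad fields. Below is a minimal sketch of one way they can be used, assuming Grads and select_device are importable from the top-level deep_rl package (as the star imports in examples.py suggest); the toy linear network, SGD optimizer, and random inputs are placeholders, not code from the repo:

import torch
from deep_rl import Grads, select_device

select_device(-1)                          # Grad allocates its buffers on Config.DEVICE
net = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.1)

samples = [torch.randn(1, 4) for _ in range(8)]
per_sample = Grads(net, n=len(samples))    # one zero-initialized buffer set per sample
for buf, x in zip(per_sample.grads, samples):
    net.zero_grad()
    net(x).sum().backward()                # stand-in for a real per-sample loss
    buf.add(net)                           # copy this sample's gradients into its buffer

mean_grad = per_sample.mean()              # average of the per-sample gradient buffers
mean_grad.assign(net)                      # write the average back into param.grad
opt.step()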
-------------------------------------------------------------------------------- /docker_clean.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | rm -rf log tf_log *.out -------------------------------------------------------------------------------- /docker_python.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | if hash nvidia-docker 2>/dev/null; then 4 | cmd=nvidia-docker 5 | else 6 | cmd=docker 7 | fi 8 | 9 | NV_GPU=$1 ${cmd} run --rm -v `pwd`:/home/user/deep_rl --entrypoint '/bin/bash' deep_rl:v1.5 -c "OMP_NUM_THREADS=1 python $2" -------------------------------------------------------------------------------- /docker_shell.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | if hash nvidia-docker 2>/dev/null; then 4 | cmd=nvidia-docker 5 | else 6 | cmd=docker 7 | fi 8 | 9 | ${cmd} run --rm -v `pwd`:/home/user/deep_rl -it deep_rl:v1.5 10 | -------------------------------------------------------------------------------- /docker_stop.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | if hash nvidia-docker 2>/dev/null; then 4 | cmd=nvidia-docker 5 | else 6 | cmd=docker 7 | fi 8 | 9 | ${cmd} ps -a | awk '{ print $1,$2 }' | grep deep_rl:v1.5 | awk '{print $1 }' | xargs -I {} ${cmd} kill {} 10 | ${cmd} ps -a | awk '{ print $1,$2 }' | grep deep_rl:v1.5 | awk '{print $1 }' | xargs -I {} ${cmd} rm {} 11 | -------------------------------------------------------------------------------- /examples.py: -------------------------------------------------------------------------------- 1 | ####################################################################### 2 | # Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # 3 | # Permission given to modify the code as long as you keep this # 4 | # declaration at the top # 5 | ####################################################################### 6 | 7 | from deep_rl import * 8 | 9 | 10 | # DQN 11 | def dqn_feature(**kwargs): 12 | generate_tag(kwargs) 13 | kwargs.setdefault('log_level', 0) 14 | kwargs.setdefault('n_step', 1) 15 | kwargs.setdefault('replay_cls', UniformReplay) 16 | kwargs.setdefault('async_replay', True) 17 | config = Config() 18 | config.merge(kwargs) 19 | 20 | config.task_fn = lambda: Task(config.game) 21 | config.eval_env = config.task_fn() 22 | 23 | config.optimizer_fn = lambda params: torch.optim.RMSprop(params, 0.001) 24 | config.network_fn = lambda: VanillaNet(config.action_dim, FCBody(config.state_dim)) 25 | # config.network_fn = lambda: DuelingNet(config.action_dim, FCBody(config.state_dim)) 26 | config.history_length = 1 27 | config.batch_size = 10 28 | config.discount = 0.99 29 | config.max_steps = 1e5 30 | 31 | replay_kwargs = dict( 32 | memory_size=int(1e4), 33 | batch_size=config.batch_size, 34 | n_step=config.n_step, 35 | discount=config.discount, 36 | history_length=config.history_length) 37 | 38 | config.replay_fn = lambda: ReplayWrapper(config.replay_cls, replay_kwargs, config.async_replay) 39 | config.replay_eps = 0.01 40 | config.replay_alpha = 0.5 41 | config.replay_beta = LinearSchedule(0.4, 1.0, config.max_steps) 42 | 43 | config.random_action_prob = LinearSchedule(1.0, 0.1, 1e4) 44 | config.target_network_update_freq = 200 45 | config.exploration_steps = 1000 46 | # config.double_q = True 47 | config.double_q = False 48 | config.sgd_update_frequency 
= 4 49 | config.gradient_clip = 5 50 | config.eval_interval = int(5e3) 51 | config.async_actor = False 52 | run_steps(DQNAgent(config)) 53 | 54 | 55 | def dqn_pixel(**kwargs): 56 | generate_tag(kwargs) 57 | kwargs.setdefault('log_level', 0) 58 | kwargs.setdefault('n_step', 1) 59 | kwargs.setdefault('replay_cls', UniformReplay) 60 | kwargs.setdefault('async_replay', True) 61 | config = Config() 62 | config.merge(kwargs) 63 | 64 | config.task_fn = lambda: Task(config.game) 65 | config.eval_env = config.task_fn() 66 | 67 | config.optimizer_fn = lambda params: torch.optim.RMSprop( 68 | params, lr=0.00025, alpha=0.95, eps=0.01, centered=True) 69 | config.network_fn = lambda: VanillaNet(config.action_dim, NatureConvBody(in_channels=config.history_length)) 70 | # config.network_fn = lambda: DuelingNet(config.action_dim, NatureConvBody(in_channels=config.history_length)) 71 | config.random_action_prob = LinearSchedule(1.0, 0.01, 1e6) 72 | config.batch_size = 32 73 | config.discount = 0.99 74 | config.history_length = 4 75 | config.max_steps = int(2e7) 76 | replay_kwargs = dict( 77 | memory_size=int(1e6), 78 | batch_size=config.batch_size, 79 | n_step=config.n_step, 80 | discount=config.discount, 81 | history_length=config.history_length, 82 | ) 83 | config.replay_fn = lambda: ReplayWrapper(config.replay_cls, replay_kwargs, config.async_replay) 84 | config.replay_eps = 0.01 85 | config.replay_alpha = 0.5 86 | config.replay_beta = LinearSchedule(0.4, 1.0, config.max_steps) 87 | 88 | config.state_normalizer = ImageNormalizer() 89 | config.reward_normalizer = SignNormalizer() 90 | config.target_network_update_freq = 10000 91 | config.exploration_steps = 50000 92 | # config.exploration_steps = 100 93 | config.sgd_update_frequency = 4 94 | config.gradient_clip = 5 95 | config.double_q = False 96 | config.async_actor = True 97 | run_steps(DQNAgent(config)) 98 | 99 | 100 | # QR DQN 101 | def quantile_regression_dqn_feature(**kwargs): 102 | generate_tag(kwargs) 103 | kwargs.setdefault('log_level', 0) 104 | config = Config() 105 | config.merge(kwargs) 106 | 107 | config.task_fn = lambda: Task(config.game) 108 | config.eval_env = config.task_fn() 109 | config.optimizer_fn = lambda params: torch.optim.RMSprop(params, 0.001) 110 | config.network_fn = lambda: QuantileNet(config.action_dim, config.num_quantiles, FCBody(config.state_dim)) 111 | 112 | config.batch_size = 10 113 | replay_kwargs = dict( 114 | memory_size=int(1e4), 115 | batch_size=config.batch_size) 116 | config.replay_fn = lambda: ReplayWrapper(UniformReplay, replay_kwargs, async=True) 117 | 118 | config.random_action_prob = LinearSchedule(1.0, 0.1, 1e4) 119 | config.discount = 0.99 120 | config.target_network_update_freq = 200 121 | config.exploration_steps = 100 122 | config.num_quantiles = 20 123 | config.gradient_clip = 5 124 | config.sgd_update_frequency = 4 125 | config.eval_interval = int(5e3) 126 | config.max_steps = 1e5 127 | run_steps(QuantileRegressionDQNAgent(config)) 128 | 129 | 130 | def quantile_regression_dqn_pixel(**kwargs): 131 | generate_tag(kwargs) 132 | kwargs.setdefault('log_level', 0) 133 | config = Config() 134 | config.merge(kwargs) 135 | 136 | config.task_fn = lambda: Task(config.game) 137 | config.eval_env = config.task_fn() 138 | 139 | config.optimizer_fn = lambda params: torch.optim.Adam(params, lr=0.00005, eps=0.01 / 32) 140 | config.network_fn = lambda: QuantileNet(config.action_dim, config.num_quantiles, NatureConvBody()) 141 | config.random_action_prob = LinearSchedule(1.0, 0.01, 1e6) 142 | 143 | config.batch_size = 
32 144 | replay_kwargs = dict( 145 | memory_size=int(1e6), 146 | batch_size=config.batch_size, 147 | history_length=4, 148 | ) 149 | config.replay_fn = lambda: ReplayWrapper(UniformReplay, replay_kwargs, async=True) 150 | 151 | config.state_normalizer = ImageNormalizer() 152 | config.reward_normalizer = SignNormalizer() 153 | config.discount = 0.99 154 | config.target_network_update_freq = 10000 155 | config.exploration_steps = 50000 156 | config.sgd_update_frequency = 4 157 | config.gradient_clip = 5 158 | config.num_quantiles = 200 159 | config.max_steps = int(2e7) 160 | run_steps(QuantileRegressionDQNAgent(config)) 161 | 162 | 163 | # C51 164 | def categorical_dqn_feature(**kwargs): 165 | generate_tag(kwargs) 166 | kwargs.setdefault('log_level', 0) 167 | config = Config() 168 | config.merge(kwargs) 169 | 170 | config.task_fn = lambda: Task(config.game) 171 | config.eval_env = config.task_fn() 172 | config.optimizer_fn = lambda params: torch.optim.RMSprop(params, 0.001) 173 | config.network_fn = lambda: CategoricalNet(config.action_dim, config.categorical_n_atoms, FCBody(config.state_dim)) 174 | config.random_action_prob = LinearSchedule(1.0, 0.1, 1e4) 175 | 176 | config.batch_size = 10 177 | replay_kwargs = dict( 178 | memory_size=int(1e4), 179 | batch_size=config.batch_size) 180 | config.replay_fn = lambda: ReplayWrapper(UniformReplay, replay_kwargs, async=True) 181 | 182 | config.discount = 0.99 183 | config.target_network_update_freq = 200 184 | config.exploration_steps = 100 185 | config.categorical_v_max = 100 186 | config.categorical_v_min = -100 187 | config.categorical_n_atoms = 50 188 | config.gradient_clip = 5 189 | config.sgd_update_frequency = 4 190 | 191 | config.eval_interval = int(5e3) 192 | config.max_steps = 1e5 193 | run_steps(CategoricalDQNAgent(config)) 194 | 195 | 196 | def categorical_dqn_pixel(**kwargs): 197 | generate_tag(kwargs) 198 | kwargs.setdefault('log_level', 0) 199 | config = Config() 200 | config.merge(kwargs) 201 | 202 | config.task_fn = lambda: Task(config.game) 203 | config.eval_env = config.task_fn() 204 | config.optimizer_fn = lambda params: torch.optim.Adam(params, lr=0.00025, eps=0.01 / 32) 205 | config.network_fn = lambda: CategoricalNet(config.action_dim, config.categorical_n_atoms, NatureConvBody()) 206 | config.random_action_prob = LinearSchedule(1.0, 0.01, 1e6) 207 | 208 | config.batch_size = 32 209 | replay_kwargs = dict( 210 | memory_size=int(1e6), 211 | batch_size=config.batch_size, 212 | history_length=4, 213 | ) 214 | config.replay_fn = lambda: ReplayWrapper(UniformReplay, replay_kwargs, async=True) 215 | 216 | config.discount = 0.99 217 | config.state_normalizer = ImageNormalizer() 218 | config.reward_normalizer = SignNormalizer() 219 | config.target_network_update_freq = 10000 220 | config.exploration_steps = 50000 221 | config.categorical_v_max = 10 222 | config.categorical_v_min = -10 223 | config.categorical_n_atoms = 51 224 | config.sgd_update_frequency = 4 225 | config.gradient_clip = 0.5 226 | config.max_steps = int(2e7) 227 | run_steps(CategoricalDQNAgent(config)) 228 | 229 | 230 | # Rainbow 231 | def rainbow_feature(**kwargs): 232 | generate_tag(kwargs) 233 | kwargs.setdefault('log_level', 0) 234 | kwargs.setdefault('n_step', 3) 235 | kwargs.setdefault('replay_cls', PrioritizedReplay) 236 | kwargs.setdefault('async_replay', True) 237 | config = Config() 238 | config.merge(kwargs) 239 | 240 | config.task_fn = lambda: Task(config.game) 241 | config.eval_env = config.task_fn() 242 | 243 | config.max_steps = 1e5 244 | 
config.optimizer_fn = lambda params: torch.optim.RMSprop(params, 0.001) 245 | config.noisy_linear = True 246 | config.network_fn = lambda: RainbowNet( 247 | config.action_dim, 248 | config.categorical_n_atoms, 249 | FCBody(config.state_dim, noisy_linear=config.noisy_linear), 250 | noisy_linear=config.noisy_linear 251 | ) 252 | config.categorical_v_max = 100 253 | config.categorical_v_min = -100 254 | config.categorical_n_atoms = 50 255 | 256 | config.discount = 0.99 257 | config.batch_size = 32 258 | replay_kwargs = dict( 259 | memory_size=int(1e4), 260 | batch_size=config.batch_size, 261 | n_step=config.n_step, 262 | discount=config.discount, 263 | history_length=1) 264 | 265 | config.replay_fn = lambda: ReplayWrapper(config.replay_cls, replay_kwargs, config.async_replay) 266 | 267 | config.replay_eps = 0.01 268 | config.replay_alpha = 0.5 269 | config.replay_beta = LinearSchedule(0.4, 1, config.max_steps) 270 | config.random_action_prob = LinearSchedule(1.0, 0.1, 1e4) 271 | 272 | config.target_network_update_freq = 200 273 | config.exploration_steps = 1000 274 | config.double_q = True 275 | config.sgd_update_frequency = 4 276 | config.eval_interval = int(5e3) 277 | config.async_actor = True 278 | config.gradient_clip = 10 279 | 280 | run_steps(CategoricalDQNAgent(config)) 281 | 282 | 283 | def rainbow_pixel(**kwargs): 284 | generate_tag(kwargs) 285 | kwargs.setdefault('log_level', 0) 286 | kwargs.setdefault('n_step', 1) 287 | kwargs.setdefault('replay_cls', PrioritizedReplay) 288 | kwargs.setdefault('async_replay', True) 289 | kwargs.setdefault('noisy_linear', True) 290 | config = Config() 291 | config.merge(kwargs) 292 | 293 | config.task_fn = lambda: Task(config.game) 294 | config.eval_env = config.task_fn() 295 | 296 | config.max_steps = int(2e7) 297 | Config.NOISY_LAYER_STD = 0.5 298 | config.optimizer_fn = lambda params: torch.optim.Adam( 299 | params, lr=0.000625, eps=1.5e-4) 300 | config.network_fn = lambda: RainbowNet( 301 | config.action_dim, 302 | config.categorical_n_atoms, 303 | NatureConvBody(noisy_linear=config.noisy_linear), 304 | noisy_linear=config.noisy_linear, 305 | ) 306 | config.categorical_v_max = 10 307 | config.categorical_v_min = -10 308 | config.categorical_n_atoms = 51 309 | 310 | config.random_action_prob = LinearSchedule(1, 0.01, 25e4) 311 | 312 | config.batch_size = 32 313 | config.discount = 0.99 314 | config.history_length = 4 315 | replay_kwargs = dict( 316 | memory_size=int(1e6), 317 | batch_size=config.batch_size, 318 | n_step=config.n_step, 319 | discount=config.discount, 320 | history_length=config.history_length, 321 | ) 322 | config.replay_fn = lambda: ReplayWrapper(config.replay_cls, replay_kwargs, config.async_replay) 323 | config.replay_eps = 0.01 324 | config.replay_alpha = 0.5 325 | config.replay_beta = LinearSchedule(0.4, 1.0, config.max_steps) 326 | 327 | config.state_normalizer = ImageNormalizer() 328 | config.reward_normalizer = SignNormalizer() 329 | config.target_network_update_freq = 2000 330 | config.exploration_steps = 20000 331 | # config.exploration_steps = 200 332 | config.sgd_update_frequency = 4 333 | config.double_q = True 334 | config.async_actor = True 335 | config.gradient_clip = 10 336 | run_steps(CategoricalDQNAgent(config)) 337 | 338 | 339 | # A2C 340 | def a2c_feature(**kwargs): 341 | generate_tag(kwargs) 342 | kwargs.setdefault('log_level', 0) 343 | config = Config() 344 | config.merge(kwargs) 345 | 346 | config.num_workers = 5 347 | config.task_fn = lambda: Task(config.game, num_envs=config.num_workers) 348 | 
config.eval_env = Task(config.game) 349 | config.optimizer_fn = lambda params: torch.optim.RMSprop(params, 0.001) 350 | config.network_fn = lambda: CategoricalActorCriticNet( 351 | config.state_dim, config.action_dim, FCBody(config.state_dim, gate=F.tanh)) 352 | config.discount = 0.99 353 | config.use_gae = True 354 | config.gae_tau = 0.95 355 | config.entropy_weight = 0.01 356 | config.rollout_length = 5 357 | config.gradient_clip = 0.5 358 | run_steps(A2CAgent(config)) 359 | 360 | 361 | def a2c_pixel(**kwargs): 362 | generate_tag(kwargs) 363 | kwargs.setdefault('log_level', 0) 364 | config = Config() 365 | config.merge(kwargs) 366 | 367 | config.num_workers = 16 368 | config.task_fn = lambda: Task(config.game, num_envs=config.num_workers) 369 | config.eval_env = Task(config.game) 370 | config.optimizer_fn = lambda params: torch.optim.RMSprop(params, lr=1e-4, alpha=0.99, eps=1e-5) 371 | config.network_fn = lambda: CategoricalActorCriticNet(config.state_dim, config.action_dim, NatureConvBody()) 372 | config.state_normalizer = ImageNormalizer() 373 | config.reward_normalizer = SignNormalizer() 374 | config.discount = 0.99 375 | config.use_gae = True 376 | config.gae_tau = 1.0 377 | config.entropy_weight = 0.01 378 | config.rollout_length = 5 379 | config.gradient_clip = 5 380 | config.max_steps = int(2e7) 381 | run_steps(A2CAgent(config)) 382 | 383 | 384 | def a2c_continuous(**kwargs): 385 | generate_tag(kwargs) 386 | kwargs.setdefault('log_level', 0) 387 | config = Config() 388 | config.merge(kwargs) 389 | 390 | config.num_workers = 16 391 | config.task_fn = lambda: Task(config.game, num_envs=config.num_workers) 392 | config.eval_env = Task(config.game) 393 | config.optimizer_fn = lambda params: torch.optim.RMSprop(params, lr=0.0007) 394 | config.network_fn = lambda: GaussianActorCriticNet( 395 | config.state_dim, config.action_dim, 396 | actor_body=FCBody(config.state_dim), critic_body=FCBody(config.state_dim)) 397 | config.discount = 0.99 398 | config.use_gae = True 399 | config.gae_tau = 1.0 400 | config.entropy_weight = 0.01 401 | config.rollout_length = 5 402 | config.gradient_clip = 5 403 | config.max_steps = int(2e7) 404 | run_steps(A2CAgent(config)) 405 | 406 | 407 | # N-Step DQN 408 | def n_step_dqn_feature(**kwargs): 409 | generate_tag(kwargs) 410 | kwargs.setdefault('log_level', 0) 411 | config = Config() 412 | config.merge(kwargs) 413 | 414 | config.task_fn = lambda: Task(config.game, num_envs=config.num_workers) 415 | config.eval_env = Task(config.game) 416 | config.num_workers = 5 417 | config.optimizer_fn = lambda params: torch.optim.RMSprop(params, 0.001) 418 | config.network_fn = lambda: VanillaNet(config.action_dim, FCBody(config.state_dim)) 419 | config.random_action_prob = LinearSchedule(1.0, 0.1, 1e4) 420 | config.discount = 0.99 421 | config.target_network_update_freq = 200 422 | config.rollout_length = 5 423 | config.gradient_clip = 5 424 | run_steps(NStepDQNAgent(config)) 425 | 426 | 427 | def n_step_dqn_pixel(**kwargs): 428 | generate_tag(kwargs) 429 | kwargs.setdefault('log_level', 0) 430 | config = Config() 431 | config.merge(kwargs) 432 | 433 | config.task_fn = lambda: Task(config.game, num_envs=config.num_workers) 434 | config.eval_env = Task(config.game) 435 | config.num_workers = 16 436 | config.optimizer_fn = lambda params: torch.optim.RMSprop(params, lr=1e-4, alpha=0.99, eps=1e-5) 437 | config.network_fn = lambda: VanillaNet(config.action_dim, NatureConvBody()) 438 | config.random_action_prob = LinearSchedule(1.0, 0.05, 1e6) 439 | config.state_normalizer = 
ImageNormalizer() 440 | config.reward_normalizer = SignNormalizer() 441 | config.discount = 0.99 442 | config.target_network_update_freq = 10000 443 | config.rollout_length = 5 444 | config.gradient_clip = 5 445 | config.max_steps = int(2e7) 446 | run_steps(NStepDQNAgent(config)) 447 | 448 | 449 | # Option-Critic 450 | def option_critic_feature(**kwargs): 451 | generate_tag(kwargs) 452 | kwargs.setdefault('log_level', 0) 453 | config = Config() 454 | config.merge(kwargs) 455 | 456 | config.num_workers = 5 457 | config.task_fn = lambda: Task(config.game, num_envs=config.num_workers) 458 | config.eval_env = Task(config.game) 459 | config.optimizer_fn = lambda params: torch.optim.RMSprop(params, 0.001) 460 | config.network_fn = lambda: OptionCriticNet(FCBody(config.state_dim), config.action_dim, num_options=2) 461 | config.random_option_prob = LinearSchedule(1.0, 0.1, 1e4) 462 | config.discount = 0.99 463 | config.target_network_update_freq = 200 464 | config.rollout_length = 5 465 | config.termination_regularizer = 0.01 466 | config.entropy_weight = 0.01 467 | config.gradient_clip = 5 468 | run_steps(OptionCriticAgent(config)) 469 | 470 | 471 | def option_critic_pixel(**kwargs): 472 | generate_tag(kwargs) 473 | kwargs.setdefault('log_level', 0) 474 | config = Config() 475 | config.merge(kwargs) 476 | 477 | config.task_fn = lambda: Task(config.game, num_envs=config.num_workers) 478 | config.eval_env = Task(config.game) 479 | config.num_workers = 16 480 | config.optimizer_fn = lambda params: torch.optim.RMSprop(params, lr=1e-4, alpha=0.99, eps=1e-5) 481 | config.network_fn = lambda: OptionCriticNet(NatureConvBody(), config.action_dim, num_options=4) 482 | config.random_option_prob = LinearSchedule(0.1) 483 | config.state_normalizer = ImageNormalizer() 484 | config.reward_normalizer = SignNormalizer() 485 | config.discount = 0.99 486 | config.target_network_update_freq = 10000 487 | config.rollout_length = 5 488 | config.gradient_clip = 5 489 | config.max_steps = int(2e7) 490 | config.entropy_weight = 0.01 491 | config.termination_regularizer = 0.01 492 | run_steps(OptionCriticAgent(config)) 493 | 494 | 495 | # PPO 496 | def ppo_continuous(**kwargs): 497 | generate_tag(kwargs) 498 | kwargs.setdefault('log_level', 0) 499 | config = Config() 500 | config.merge(kwargs) 501 | 502 | config.task_fn = lambda: Task(config.game) 503 | config.eval_env = config.task_fn() 504 | 505 | config.network_fn = lambda: GaussianActorCriticNet( 506 | config.state_dim, config.action_dim, actor_body=FCBody(config.state_dim, gate=torch.tanh), 507 | critic_body=FCBody(config.state_dim, gate=torch.tanh)) 508 | config.actor_opt_fn = lambda params: torch.optim.Adam(params, 3e-4) 509 | config.critic_opt_fn = lambda params: torch.optim.Adam(params, 1e-3) 510 | config.discount = 0.99 511 | config.use_gae = True 512 | config.gae_tau = 0.95 513 | config.gradient_clip = 0.5 514 | config.rollout_length = 2048 515 | config.optimization_epochs = 10 516 | config.mini_batch_size = 64 517 | config.ppo_ratio_clip = 0.2 518 | config.log_interval = 2048 519 | config.max_steps = 3e6 520 | config.target_kl = 0.01 521 | config.state_normalizer = MeanStdNormalizer() 522 | run_steps(PPOAgent(config)) 523 | 524 | 525 | def ppo_pixel(**kwargs): 526 | generate_tag(kwargs) 527 | kwargs.setdefault('skip', False) 528 | config = Config() 529 | config.merge(kwargs) 530 | 531 | config.task_fn = lambda: Task(config.game, num_envs=config.num_workers) 532 | config.eval_env = Task(config.game) 533 | config.num_workers = 8 534 | config.optimizer_fn = 
lambda params: torch.optim.Adam(params, lr=2.5e-4) 535 | config.network_fn = lambda: CategoricalActorCriticNet(config.state_dim, config.action_dim, NatureConvBody()) 536 | config.state_normalizer = ImageNormalizer() 537 | config.reward_normalizer = SignNormalizer() 538 | config.discount = 0.99 539 | config.use_gae = True 540 | config.gae_tau = 0.95 541 | config.entropy_weight = 0.01 542 | config.gradient_clip = 0.5 543 | config.rollout_length = 128 544 | config.optimization_epochs = 4 545 | config.mini_batch_size = config.rollout_length * config.num_workers // 4 546 | config.ppo_ratio_clip = 0.1 547 | config.log_interval = config.rollout_length * config.num_workers 548 | config.shared_repr = True 549 | config.max_steps = int(2e7) 550 | run_steps(PPOAgent(config)) 551 | 552 | 553 | # DDPG 554 | def ddpg_continuous(**kwargs): 555 | generate_tag(kwargs) 556 | kwargs.setdefault('log_level', 0) 557 | config = Config() 558 | config.merge(kwargs) 559 | 560 | config.task_fn = lambda: Task(config.game) 561 | config.eval_env = config.task_fn() 562 | config.max_steps = int(1e6) 563 | config.eval_interval = int(1e4) 564 | config.eval_episodes = 20 565 | 566 | config.network_fn = lambda: DeterministicActorCriticNet( 567 | config.state_dim, config.action_dim, 568 | actor_body=FCBody(config.state_dim, (400, 300), gate=F.relu), 569 | critic_body=FCBody(config.state_dim + config.action_dim, (400, 300), gate=F.relu), 570 | actor_opt_fn=lambda params: torch.optim.Adam(params, lr=1e-3), 571 | critic_opt_fn=lambda params: torch.optim.Adam(params, lr=1e-3)) 572 | 573 | config.replay_fn = lambda: UniformReplay(memory_size=int(1e6), batch_size=100) 574 | config.discount = 0.99 575 | config.random_process_fn = lambda: OrnsteinUhlenbeckProcess( 576 | size=(config.action_dim,), std=LinearSchedule(0.2)) 577 | config.warm_up = int(1e4) 578 | config.target_network_mix = 5e-3 579 | run_steps(DDPGAgent(config)) 580 | 581 | 582 | # TD3 583 | def td3_continuous(**kwargs): 584 | generate_tag(kwargs) 585 | kwargs.setdefault('log_level', 0) 586 | config = Config() 587 | config.merge(kwargs) 588 | 589 | config.task_fn = lambda: Task(config.game) 590 | config.eval_env = config.task_fn() 591 | config.max_steps = int(1e6) 592 | config.eval_interval = int(1e4) 593 | config.eval_episodes = 20 594 | 595 | config.network_fn = lambda: TD3Net( 596 | config.action_dim, 597 | actor_body_fn=lambda: FCBody(config.state_dim, (400, 300), gate=F.relu), 598 | critic_body_fn=lambda: FCBody( 599 | config.state_dim + config.action_dim, (400, 300), gate=F.relu), 600 | actor_opt_fn=lambda params: torch.optim.Adam(params, lr=1e-3), 601 | critic_opt_fn=lambda params: torch.optim.Adam(params, lr=1e-3)) 602 | 603 | replay_kwargs = dict( 604 | memory_size=int(1e6), 605 | batch_size=100, 606 | ) 607 | 608 | config.replay_fn = lambda: ReplayWrapper(UniformReplay, replay_kwargs) 609 | config.discount = 0.99 610 | config.random_process_fn = lambda: GaussianProcess( 611 | size=(config.action_dim,), std=LinearSchedule(0.1)) 612 | config.td3_noise = 0.2 613 | config.td3_noise_clip = 0.5 614 | config.td3_delay = 2 615 | config.warm_up = int(1e4) 616 | config.target_network_mix = 5e-3 617 | run_steps(TD3Agent(config)) 618 | 619 | 620 | if __name__ == '__main__': 621 | mkdir('log') 622 | mkdir('tf_log') 623 | set_one_thread() 624 | random_seed() 625 | # -1 is CPU, a positive integer is the index of GPU 626 | select_device(-1) 627 | # select_device(0) 628 | 629 | game = 'CartPole-v0' 630 | # dqn_feature(game=game, n_step=1, replay_cls=UniformReplay, 
async_replay=True, noisy_linear=True) 631 | # quantile_regression_dqn_feature(game=game) 632 | # categorical_dqn_feature(game=game) 633 | # rainbow_feature(game=game) 634 | # a2c_feature(game=game) 635 | # n_step_dqn_feature(game=game) 636 | # option_critic_feature(game=game) 637 | 638 | game = 'HalfCheetah-v2' 639 | # game = 'Hopper-v2' 640 | # a2c_continuous(game=game) 641 | # ppo_continuous(game=game) 642 | # ddpg_continuous(game=game) 643 | # td3_continuous(game=game) 644 | 645 | game = 'BreakoutNoFrameskip-v4' 646 | dqn_pixel(game=game, n_step=1, replay_cls=UniformReplay, async_replay=False) 647 | # quantile_regression_dqn_pixel(game=game) 648 | # categorical_dqn_pixel(game=game) 649 | # rainbow_pixel(game=game, async_replay=False) 650 | # a2c_pixel(game=game) 651 | # n_step_dqn_pixel(game=game) 652 | # option_critic_pixel(game=game) 653 | # ppo_pixel(game=game) 654 | -------------------------------------------------------------------------------- /images/Breakout.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShangtongZhang/DeepRL/c0968b5c046314af09944843cdf06f4100c6bb95/images/Breakout.png -------------------------------------------------------------------------------- /images/PPO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShangtongZhang/DeepRL/c0968b5c046314af09944843cdf06f4100c6bb95/images/PPO.png -------------------------------------------------------------------------------- /images/mujoco_eval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShangtongZhang/DeepRL/c0968b5c046314af09944843cdf06f4100c6bb95/images/mujoco_eval.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.5.1 2 | gym==0.10.8 3 | tensorflow==1.15.0 4 | atari-py==0.1.7 5 | opencv-python==4.0.0.21 6 | scikit-image==0.14.2 7 | tqdm==4.31.1 8 | pandas==0.24.1 9 | pathlib==1.0.1 10 | seaborn==0.9.0 11 | roboschool==1.0.49 12 | torchmeta 13 | torchvision 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import sys 3 | 4 | print('Please install OpenAI Baselines (commit 8e56dd) and requirement.txt') 5 | if not (sys.version.startswith('3.5') or sys.version.startswith('3.6')): 6 | raise Exception('Only Python 3.5 and 3.6 are supported') 7 | 8 | setup(name='deep_rl', 9 | packages=[package for package in find_packages() 10 | if package.startswith('deep_rl')], 11 | install_requires=[], 12 | description="Modularized Implementation of Deep RL Algorithms", 13 | author="Shangtong Zhang", 14 | url='https://github.com/ShangtongZhang/DeepRL', 15 | author_email="zhangshangtong.cpp@gmail.com", 16 | version="1.5") -------------------------------------------------------------------------------- /template_jobs.py: -------------------------------------------------------------------------------- 1 | from examples import * 2 | 3 | 4 | def batch_atari(): 5 | cf = Config() 6 | cf.add_argument('--i', type=int, default=0) 7 | cf.add_argument('--j', type=int, default=0) 8 | cf.merge() 9 | 10 | games = [ 11 | 'BreakoutNoFrameskip-v4', 12 | # 'AlienNoFrameskip-v4', 13 | # 'DemonAttackNoFrameskip-v4', 14 | # 
'MsPacmanNoFrameskip-v4' 15 | ] 16 | 17 | algos = [ 18 | dqn_pixel, 19 | quantile_regression_dqn_pixel, 20 | categorical_dqn_pixel, 21 | rainbow_pixel, 22 | a2c_pixel, 23 | n_step_dqn_pixel, 24 | option_critic_pixel, 25 | ppo_pixel, 26 | ] 27 | 28 | params = [] 29 | 30 | for game in games: 31 | for r in range(1): 32 | for algo in algos: 33 | params.append([algo, dict(game=game, run=r, remark=algo.__name__)]) 34 | # for n_step in [1, 2, 3]: 35 | # for double_q in [True, False]: 36 | # params.extend([ 37 | # [dqn_pixel, 38 | # dict(game=game, run=r, n_step=n_step, replay_cls=PrioritizedReplay, double_q=double_q, 39 | # remark=dqn_pixel.__name__)], 40 | # [rainbow_pixel, 41 | # dict(game=game, run=r, n_step=n_step, noisy_linear=False, remark=rainbow_pixel.__name__)] 42 | # ]) 43 | # params.append( 44 | # [categorical_dqn_pixel, dict(game=game, run=r, remark=categorical_dqn_pixel.__name__)]), 45 | # params.append([dqn_pixel, dict(game=game, run=r, remark=dqn_pixel.__name__)]) 46 | 47 | algo, param = params[cf.i] 48 | algo(**param) 49 | exit() 50 | 51 | 52 | def batch_mujoco(): 53 | cf = Config() 54 | cf.add_argument('--i', type=int, default=0) 55 | cf.add_argument('--j', type=int, default=0) 56 | cf.merge() 57 | 58 | games = [ 59 | 'dm-acrobot-swingup', 60 | 'dm-acrobot-swingup_sparse', 61 | 'dm-ball_in_cup-catch', 62 | 'dm-cartpole-swingup', 63 | 'dm-cartpole-swingup_sparse', 64 | 'dm-cartpole-balance', 65 | 'dm-cartpole-balance_sparse', 66 | 'dm-cheetah-run', 67 | 'dm-finger-turn_hard', 68 | 'dm-finger-spin', 69 | 'dm-finger-turn_easy', 70 | 'dm-fish-upright', 71 | 'dm-fish-swim', 72 | 'dm-hopper-stand', 73 | 'dm-hopper-hop', 74 | 'dm-humanoid-stand', 75 | 'dm-humanoid-walk', 76 | 'dm-humanoid-run', 77 | 'dm-manipulator-bring_ball', 78 | 'dm-pendulum-swingup', 79 | 'dm-point_mass-easy', 80 | 'dm-reacher-easy', 81 | 'dm-reacher-hard', 82 | 'dm-swimmer-swimmer15', 83 | 'dm-swimmer-swimmer6', 84 | 'dm-walker-stand', 85 | 'dm-walker-walk', 86 | 'dm-walker-run', 87 | ] 88 | 89 | games = [ 90 | 'HalfCheetah-v2', 91 | 'Walker2d-v2', 92 | 'Swimmer-v2', 93 | 'Hopper-v2', 94 | 'Reacher-v2', 95 | 'Ant-v2', 96 | 'Humanoid-v2', 97 | 'HumanoidStandup-v2', 98 | ] 99 | 100 | params = [] 101 | 102 | for game in games: 103 | if 'Humanoid' in game: 104 | algos = [ppo_continuous] 105 | else: 106 | algos = [ppo_continuous, ddpg_continuous, td3_continuous] 107 | for algo in algos: 108 | for r in range(5): 109 | params.append([algo, dict(game=game, run=r)]) 110 | 111 | algo, param = params[cf.i] 112 | algo(**param, remark=algo.__name__) 113 | 114 | exit() 115 | 116 | 117 | if __name__ == '__main__': 118 | mkdir('log') 119 | mkdir('data') 120 | random_seed() 121 | 122 | # select_device(0) 123 | # batch_atari() 124 | 125 | select_device(-1) 126 | batch_mujoco() 127 | -------------------------------------------------------------------------------- /template_plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | # matplotlib.use('Agg') 3 | import matplotlib.pyplot as plt 4 | # plt.rc('text', usetex=True) 5 | from deep_rl import * 6 | 7 | 8 | def plot_ppo(): 9 | plotter = Plotter() 10 | games = [ 11 | 'HalfCheetah-v2', 12 | 'Walker2d-v2', 13 | 'Hopper-v2', 14 | 'Swimmer-v2', 15 | 'Reacher-v2', 16 | 'Ant-v2', 17 | 'Humanoid-v2', 18 | 'HumanoidStandup-v2', 19 | ] 20 | 21 | patterns = [ 22 | 'remark_ppo', 23 | ] 24 | 25 | labels = [ 26 | 'PPO' 27 | ] 28 | 29 | plotter.plot_games(games=games, 30 | patterns=patterns, 31 | agg='mean', 32 | downsample=0, 33 | 
labels=labels, 34 | right_align=False, 35 | tag=plotter.RETURN_TRAIN, 36 | root='./data/benchmark/mujoco', 37 | interpolation=100, 38 | window=10, 39 | ) 40 | 41 | # plt.show() 42 | plt.tight_layout() 43 | plt.savefig('images/PPO.png', bbox_inches='tight') 44 | 45 | 46 | def plot_ddpg_td3(): 47 | plotter = Plotter() 48 | games = [ 49 | 'HalfCheetah-v2', 50 | 'Walker2d-v2', 51 | 'Hopper-v2', 52 | 'Swimmer-v2', 53 | 'Reacher-v2', 54 | 'Ant-v2', 55 | ] 56 | 57 | patterns = [ 58 | 'remark_ddpg', 59 | 'remark_td3', 60 | ] 61 | 62 | labels = [ 63 | 'DDPG', 64 | 'TD3', 65 | ] 66 | 67 | plotter.plot_games(games=games, 68 | patterns=patterns, 69 | agg='mean', 70 | downsample=0, 71 | labels=labels, 72 | right_align=False, 73 | tag=plotter.RETURN_TEST, 74 | root='./data/benchmark/mujoco', 75 | interpolation=0, 76 | window=0, 77 | ) 78 | 79 | # plt.show() 80 | plt.tight_layout() 81 | plt.savefig('images/mujoco_eval.png', bbox_inches='tight') 82 | 83 | 84 | def plot_atari(): 85 | plotter = Plotter() 86 | games = [ 87 | 'BreakoutNoFrameskip-v4', 88 | ] 89 | 90 | patterns = [ 91 | 'remark_a2c', 92 | 'remark_categorical', 93 | 'remark_dqn', 94 | 'remark_n_step_dqn', 95 | 'remark_option_critic', 96 | 'remark_quantile', 97 | 'remark_ppo', 98 | # 'remark_rainbow', 99 | ] 100 | 101 | labels = [ 102 | 'A2C', 103 | 'C51', 104 | 'DQN', 105 | 'N-Step DQN', 106 | 'OC', 107 | 'QR-DQN', 108 | 'PPO', 109 | # 'Rainbow' 110 | ] 111 | 112 | plotter.plot_games(games=games, 113 | patterns=patterns, 114 | agg='mean', 115 | downsample=100, 116 | labels=labels, 117 | right_align=False, 118 | tag=plotter.RETURN_TRAIN, 119 | root='./data/benchmark/atari', 120 | interpolation=0, 121 | window=100, 122 | ) 123 | 124 | # plt.show() 125 | plt.tight_layout() 126 | plt.savefig('images/Breakout.png', bbox_inches='tight') 127 | 128 | 129 | if __name__ == '__main__': 130 | mkdir('images') 131 | # plot_ppo() 132 | # plot_ddpg_td3() 133 | plot_atari() --------------------------------------------------------------------------------
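Each algorithm entry point in examples.py is an ordinary function, so a single run can also be launched from a short standalone script instead of editing the __main__ blocks of examples.py or template_jobs.py. A minimal sketch follows; the game, run, and remark values are arbitrary, and the dependencies pinned in requirements.txt (plus OpenAI Baselines, as noted in setup.py) are assumed to be installed:

from examples import dqn_feature
from deep_rl import mkdir, random_seed, select_device, set_one_thread

if __name__ == '__main__':
    mkdir('log')                     # same setup as examples.py's __main__ block
    mkdir('tf_log')                  # tensorboard event files
    set_one_thread()
    random_seed()
    select_device(-1)                # -1 is CPU; a GPU index such as 0 selects CUDA
    dqn_feature(game='CartPole-v0', run=0, remark='smoke_test')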