├── .gitignore ├── 1-gym_developing ├── README.md ├── core.py ├── grid_game.py ├── maze_game.py └── suceess.png ├── 2-markov_decision_process ├── __init__.py ├── game.py └── our_life.py ├── 3-dynamic_program ├── grid_game_with_average_policy.py ├── grid_game_with_policy_iterate.py ├── grid_game_with_value_iterate.py ├── maze_game_with_dynamic_program.py ├── policy_iteration_algorithm.png └── value_iteration_algorithm.png ├── 4-monte_carlo ├── monte_carlo_evaluate.py └── monte_carlo_sample.py ├── 5-temporal_difference ├── README.md ├── push_box_game.py ├── push_box_game │ ├── agent.py │ ├── main.py │ └── q_table.pkl ├── q_learning_algortihm.png ├── sarsa_algorithm.png └── sarsa_lambda_algorithm.png ├── 6-value_function_approximate ├── deep_learning_flappy_bird │ ├── .gitignore │ ├── README.md │ ├── assets │ │ ├── audio │ │ │ ├── die.ogg │ │ │ ├── die.wav │ │ │ ├── hit.ogg │ │ │ ├── hit.wav │ │ │ ├── point.ogg │ │ │ ├── point.wav │ │ │ ├── swoosh.ogg │ │ │ ├── swoosh.wav │ │ │ ├── wing.ogg │ │ │ └── wing.wav │ │ └── sprites │ │ │ ├── 0.png │ │ │ ├── 1.png │ │ │ ├── 2.png │ │ │ ├── 3.png │ │ │ ├── 4.png │ │ │ ├── 5.png │ │ │ ├── 6.png │ │ │ ├── 7.png │ │ │ ├── 8.png │ │ │ ├── 9.png │ │ │ ├── background-black.png │ │ │ ├── base.png │ │ │ ├── pipe-green.png │ │ │ ├── redbird-downflap.png │ │ │ ├── redbird-midflap.png │ │ │ └── redbird-upflap.png │ ├── deep_q_network.py │ ├── game │ │ ├── flappy_bird_utils.py │ │ └── wrapped_flappy_bird.py │ ├── images │ │ ├── flappy_bird_demp.gif │ │ ├── network.png │ │ └── preprocess.png │ ├── logs_bird │ │ ├── hidden.txt │ │ └── readout.txt │ └── saved_networks │ │ ├── bird-dqn-2880000 │ │ ├── bird-dqn-2880000.meta │ │ ├── bird-dqn-2890000 │ │ ├── bird-dqn-2890000.meta │ │ ├── bird-dqn-2900000 │ │ ├── bird-dqn-2900000.meta │ │ ├── bird-dqn-2910000 │ │ ├── bird-dqn-2910000.meta │ │ ├── bird-dqn-2920000 │ │ ├── bird-dqn-2920000.meta │ │ ├── checkpoint │ │ └── pretrained_model │ │ └── bird-dqn-policy ├── deep_q_network_algortihm.png └── deep_q_network_template.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | # Python: 2 | *.py[cod] 3 | *.so 4 | *.egg 5 | *.egg-info 6 | dist 7 | build 8 | .idea 9 | # pycharm setting: 10 | *.xml -------------------------------------------------------------------------------- /1-gym_developing/README.md: -------------------------------------------------------------------------------- 1 | ## Usage 2 | > The steps below are best performed with administrator privileges. 3 | 1. First, locate the gym installation directory; mine, for example, is `C:\Program Files\Python36\Lib\site-packages\gym`. 4 | --- 5 | 2. Copy [core.py](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/1-gym_developing/core.py) 6 | over the `core.py` file in the gym installation directory. 7 | --- 8 | 3. Go to `/envs/classic_control` under the gym installation directory and copy the environment files ([grid_game.py](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/1-gym_developing/grid_game.py) and 9 | [maze_game.py](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/1-gym_developing/maze_game.py)) 10 | into that directory. 11 | --- 12 | 4. Find the `/envs/classic_control/__init__.py` file under the gym installation directory and append 13 | 14 | ```python 15 | from gym.envs.classic_control.grid_game import GridEnv 16 | from gym.envs.classic_control.maze_game import MazeEnv 17 | ``` 18 | --- 19 | 5. 
Find the `/envs/__init__.py` file under the gym installation directory and append 20 | ```python 21 | register( 22 | id='GridGame-v0', 23 | entry_point='gym.envs.classic_control:GridEnv', 24 | max_episode_steps=200, 25 | reward_threshold=100.0, 26 | ) 27 | register( 28 | id='MazeGame-v0', 29 | entry_point='gym.envs.classic_control:MazeEnv', 30 | max_episode_steps=200, 31 | reward_threshold=100.0, 32 | ) 33 | ``` 34 | --- 35 | 6. Verify that the configuration works 36 | Run the following Python code 37 | ```python 38 | import gym 39 | env = gym.make("GridGame-v0") 40 | env.reset() 41 | env.render() 42 | ``` 43 | If the window shown below appears, the configuration succeeded. 44 | -------------------------------------------------------------------------------- /1-gym_developing/core.py: -------------------------------------------------------------------------------- 1 | from gym import logger 2 | import numpy as np 3 | 4 | import gym 5 | from gym import error 6 | from gym.utils import closer 7 | 8 | env_closer = closer.Closer() 9 | 10 | # Env-related abstractions 11 | 12 | class Env(object): 13 | """The main OpenAI Gym class. It encapsulates an environment with 14 | arbitrary behind-the-scenes dynamics. An environment can be 15 | partially or fully observed. 16 | 17 | The main API methods that users of this class need to know are: 18 | 19 | transform 20 | step 21 | reset 22 | render 23 | close 24 | seed 25 | 26 | And set the following attributes: 27 | 28 | action_space: The Space object corresponding to valid actions 29 | observation_space: The Space object corresponding to valid observations 30 | reward_range: A tuple corresponding to the min and max possible rewards 31 | 32 | Note: a default reward range set to [-inf,+inf] already exists. Set it if you want a narrower range. 33 | 34 | The methods are accessed publicly as "step", "reset", etc.. The 35 | non-underscored versions are wrapper methods to which we may add 36 | functionality over time. 37 | """ 38 | 39 | # Set this in SOME subclasses 40 | metadata = {'render.modes': []} 41 | reward_range = (-np.inf, np.inf) 42 | spec = None 43 | 44 | # Set these in ALL subclasses 45 | action_space = None 46 | observation_space = None 47 | 48 | def transform(self, state, action): 49 | ''' 50 | Run one timestep of the environment's dynamics starting from an 51 | explicitly supplied state. Unlike `step()`, this does not modify 52 | the environment's internal state. 53 | 54 | Accepts a state and an action and returns a tuple (observation, reward, done, info). 55 | 56 | Args: 57 | state (object): the state to transition from 58 | action (object): the action to apply in that state 59 | 60 | Returns: 61 | observation (object): agent's observation of the current environment 62 | reward (float) : amount of reward returned after previous action 63 | done (boolean): whether the episode has ended, in which case further step() calls will return undefined results 64 | info (dict): contains auxiliary diagnostic information (helpful for debugging, and sometimes learning) 65 | ''' 66 | pass 67 | 68 | def step(self, action): 69 | """Run one timestep of the environment's dynamics. When end of 70 | episode is reached, you are responsible for calling `reset()` 71 | to reset this environment's state. 72 | 73 | Accepts an action and returns a tuple (observation, reward, done, info). 
74 | 75 | Args: 76 | action (object): an action provided by the environment 77 | 78 | Returns: 79 | observation (object): agent's observation of the current environment 80 | reward (float) : amount of reward returned after previous action 81 | done (boolean): whether the episode has ended, in which case further step() calls will return undefined results 82 | info (dict): contains auxiliary diagnostic information (helpful for debugging, and sometimes learning) 83 | """ 84 | raise NotImplementedError 85 | 86 | def reset(self): 87 | """Resets the state of the environment and returns an initial observation. 88 | 89 | Returns: observation (object): the initial observation of the 90 | space. 91 | """ 92 | raise NotImplementedError 93 | 94 | def render(self, mode='human'): 95 | """Renders the environment. 96 | 97 | The set of supported modes varies per environment. (And some 98 | environments do not support rendering at all.) By convention, 99 | if mode is: 100 | 101 | - human: render to the current display or terminal and 102 | return nothing. Usually for human consumption. 103 | - rgb_array: Return an numpy.ndarray with shape (x, y, 3), 104 | representing RGB values for an x-by-y pixel image, suitable 105 | for turning into a video. 106 | - ansi: Return a string (str) or StringIO.StringIO containing a 107 | terminal-style text representation. The text can include newlines 108 | and ANSI escape sequences (e.g. for colors). 109 | 110 | Note: 111 | Make sure that your class's metadata 'render.modes' key includes 112 | the list of supported modes. It's recommended to call super() 113 | in implementations to use the functionality of this method. 114 | 115 | Args: 116 | mode (str): the mode to render with 117 | close (bool): close all open renderings 118 | 119 | Example: 120 | 121 | class MyEnv(Env): 122 | metadata = {'render.modes': ['human', 'rgb_array']} 123 | 124 | def render(self, mode='human'): 125 | if mode == 'rgb_array': 126 | return np.array(...) # return RGB frame suitable for video 127 | elif mode is 'human': 128 | ... # pop up a window and render 129 | else: 130 | super(MyEnv, self).render(mode=mode) # just raise an exception 131 | """ 132 | raise NotImplementedError 133 | 134 | def close(self): 135 | """Override _close in your subclass to perform any necessary cleanup. 136 | 137 | Environments will automatically close() themselves when 138 | garbage collected or when the program exits. 139 | """ 140 | return 141 | 142 | def seed(self, seed=None): 143 | """Sets the seed for this env's random number generator(s). 144 | 145 | Note: 146 | Some environments use multiple pseudorandom number generators. 147 | We want to capture all such seeds used in order to ensure that 148 | there aren't accidental correlations between multiple generators. 149 | 150 | Returns: 151 | list: Returns the list of seeds used in this env's random 152 | number generators. The first value in the list should be the 153 | "main" seed, or the value which a reproducer should pass to 154 | 'seed'. Often, the main seed equals the provided 'seed', but 155 | this won't be true if seed=None, for example. 156 | """ 157 | logger.warn("Could not seed environment %s", self) 158 | return 159 | 160 | @property 161 | def unwrapped(self): 162 | """Completely unwrap this env. 
163 | 164 | Returns: 165 | gym.Env: The base non-wrapped gym.Env instance 166 | """ 167 | return self 168 | 169 | def __str__(self): 170 | if self.spec is None: 171 | return '<{} instance>'.format(type(self).__name__) 172 | else: 173 | return '<{}<{}>>'.format(type(self).__name__, self.spec.id) 174 | 175 | 176 | class GoalEnv(Env): 177 | """A goal-based environment. It functions just as any regular OpenAI Gym environment but it 178 | imposes a required structure on the observation_space. More concretely, the observation 179 | space is required to contain at least three elements, namely `observation`, `desired_goal`, and 180 | `achieved_goal`. Here, `desired_goal` specifies the goal that the agent should attempt to achieve. 181 | `achieved_goal` is the goal that it currently achieved instead. `observation` contains the 182 | actual observations of the environment as per usual. 183 | """ 184 | 185 | def reset(self): 186 | # Enforce that each GoalEnv uses a Goal-compatible observation space. 187 | if not isinstance(self.observation_space, gym.spaces.Dict): 188 | raise error.Error('GoalEnv requires an observation space of type gym.spaces.Dict') 189 | result = super(GoalEnv, self).reset() 190 | for key in ['observation', 'achieved_goal', 'desired_goal']: 191 | if key not in result: 192 | raise error.Error('GoalEnv requires the "{}" key to be part of the observation dictionary.'.format(key)) 193 | return result 194 | 195 | def compute_reward(self, achieved_goal, desired_goal, info): 196 | """Compute the step reward. This externalizes the reward function and makes 197 | it dependent on an a desired goal and the one that was achieved. If you wish to include 198 | additional rewards that are independent of the goal, you can include the necessary values 199 | to derive it in info and compute it accordingly. 200 | 201 | Args: 202 | achieved_goal (object): the goal that was achieved during execution 203 | desired_goal (object): the desired goal that we asked the agent to attempt to achieve 204 | info (dict): an info dictionary with additional information 205 | 206 | Returns: 207 | float: The reward that corresponds to the provided achieved goal w.r.t. to the desired 208 | goal. Note that the following should always hold true: 209 | 210 | ob, reward, done, info = env.step() 211 | assert reward == env.compute_reward(ob['achieved_goal'], ob['goal'], info) 212 | """ 213 | raise NotImplementedError() 214 | 215 | # Space-related abstractions 216 | 217 | class Space(object): 218 | """Defines the observation and action spaces, so you can write generic 219 | code that applies to any Env. For example, you can choose a random 220 | action. 
221 | """ 222 | def __init__(self, shape=None, dtype=None): 223 | self.shape = None if shape is None else tuple(shape) 224 | self.dtype = None if dtype is None else np.dtype(dtype) 225 | 226 | def sample(self): 227 | """ 228 | Uniformly randomly sample a random element of this space 229 | """ 230 | raise NotImplementedError 231 | 232 | def contains(self, x): 233 | """ 234 | Return boolean specifying if x is a valid 235 | member of this space 236 | """ 237 | raise NotImplementedError 238 | 239 | def to_jsonable(self, sample_n): 240 | """Convert a batch of samples from this space to a JSONable data type.""" 241 | # By default, assume identity is JSONable 242 | return sample_n 243 | 244 | def from_jsonable(self, sample_n): 245 | """Convert a JSONable data type to a batch of samples from this space.""" 246 | # By default, assume identity is JSONable 247 | return sample_n 248 | 249 | 250 | warn_once = True 251 | 252 | def deprecated_warn_once(text): 253 | global warn_once 254 | if not warn_once: return 255 | warn_once = False 256 | logger.warn(text) 257 | 258 | 259 | class Wrapper(Env): 260 | env = None 261 | 262 | def __init__(self, env): 263 | self.env = env 264 | self.action_space = self.env.action_space 265 | self.observation_space = self.env.observation_space 266 | self.reward_range = self.env.reward_range 267 | self.metadata = self.env.metadata 268 | self._warn_double_wrap() 269 | 270 | @classmethod 271 | def class_name(cls): 272 | return cls.__name__ 273 | 274 | def _warn_double_wrap(self): 275 | env = self.env 276 | while True: 277 | if isinstance(env, Wrapper): 278 | if env.class_name() == self.class_name(): 279 | raise error.DoubleWrapperError("Attempted to double wrap with Wrapper: {}".format(self.__class__.__name__)) 280 | env = env.env 281 | else: 282 | break 283 | 284 | def transform(self, state, action): 285 | return self.env.transform(state, action) 286 | 287 | def step(self, action): 288 | if hasattr(self, "_step"): 289 | deprecated_warn_once("%s doesn't implement 'step' method, but it implements deprecated '_step' method." % type(self)) 290 | self.step = self._step 291 | return self.step(action) 292 | else: 293 | deprecated_warn_once("%s doesn't implement 'step' method, " % type(self) + 294 | "which is required for wrappers derived directly from Wrapper. Deprecated default implementation is used.") 295 | return self.env.step(action) 296 | 297 | def reset(self, **kwargs): 298 | if hasattr(self, "_reset"): 299 | deprecated_warn_once("%s doesn't implement 'reset' method, but it implements deprecated '_reset' method." % type(self)) 300 | self.reset = self._reset 301 | return self._reset(**kwargs) 302 | else: 303 | deprecated_warn_once("%s doesn't implement 'reset' method, " % type(self) + 304 | "which is required for wrappers derived directly from Wrapper. 
Deprecated default implementation is used.") 305 | return self.env.reset(**kwargs) 306 | 307 | def render(self, mode='human'): 308 | return self.env.render(mode) 309 | 310 | def close(self): 311 | if self.env: 312 | return self.env.close() 313 | 314 | def seed(self, seed=None): 315 | return self.env.seed(seed) 316 | 317 | def compute_reward(self, achieved_goal, desired_goal, info): 318 | return self.env.compute_reward(achieved_goal, desired_goal, info) 319 | 320 | def __str__(self): 321 | return '<{}{}>'.format(type(self).__name__, self.env) 322 | 323 | def __repr__(self): 324 | return str(self) 325 | 326 | @property 327 | def unwrapped(self): 328 | return self.env.unwrapped 329 | 330 | @property 331 | def spec(self): 332 | return self.env.spec 333 | 334 | 335 | class ObservationWrapper(Wrapper): 336 | def transform(self, state, action): 337 | return self.env.transform(state, action) 338 | 339 | def step(self, action): 340 | observation, reward, done, info = self.env.step(action) 341 | return self.observation(observation), reward, done, info 342 | 343 | def reset(self, **kwargs): 344 | observation = self.env.reset(**kwargs) 345 | return self.observation(observation) 346 | 347 | def observation(self, observation): 348 | deprecated_warn_once("%s doesn't implement 'observation' method. Maybe it implements deprecated '_observation' method." % type(self)) 349 | return self._observation(observation) 350 | 351 | 352 | class RewardWrapper(Wrapper): 353 | def transform(self, state, action): 354 | return self.env.transform(state, action) 355 | 356 | def reset(self): 357 | return self.env.reset() 358 | 359 | def step(self, action): 360 | observation, reward, done, info = self.env.step(action) 361 | return observation, self.reward(reward), done, info 362 | 363 | def reward(self, reward): 364 | deprecated_warn_once("%s doesn't implement 'reward' method. Maybe it implements deprecated '_reward' method." % type(self)) 365 | return self._reward(reward) 366 | 367 | 368 | class ActionWrapper(Wrapper): 369 | def transform(self, state, action): 370 | return self.env.transform(state, action) 371 | 372 | def step(self, action): 373 | action = self.action(action) 374 | return self.env.step(action) 375 | 376 | def reset(self): 377 | return self.env.reset() 378 | 379 | def action(self, action): 380 | deprecated_warn_once("%s doesn't implement 'action' method. Maybe it implements deprecated '_action' method." % type(self)) 381 | return self._action(action) 382 | 383 | def reverse_action(self, action): 384 | deprecated_warn_once("%s doesn't implement 'reverse_action' method. Maybe it implements deprecated '_reverse_action' method." 
% type(self)) 385 | return self._reverse_action(action) 386 | -------------------------------------------------------------------------------- /1-gym_developing/grid_game.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # author : zlq16 4 | # date : 2018/4/7 5 | import gym 6 | from gym.utils import seeding 7 | import logging 8 | import numpy as np 9 | import pandas as pd 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class GridEnv(gym.Env): 15 | metadata = { 16 | 'render.modes': ['human', 'rgb_array'], 17 | 'video.frames_per_second': 2 18 | } 19 | 20 | def __init__(self): 21 | 22 | self.observation_space = (1, 2, 3, 4, 5, 6, 7, 8) # 状态空间 23 | self.x = [140, 220, 300, 380, 460, 140, 300, 460] 24 | self.y = [250, 250, 250, 250, 250, 150, 150, 150] 25 | self.__terminal_space = dict() # 终止状态为字典格式 26 | self.__terminal_space[6] = 1 27 | self.__terminal_space[7] = 1 28 | self.__terminal_space[8] = 1 29 | 30 | # 状态转移的数据格式为字典 31 | self.action_space = ('n', 'e', 's', 'w') 32 | self.t = pd.DataFrame(data=None, index=self.observation_space, columns=self.action_space) 33 | self.t.loc[1, "s"] = 6 34 | self.t.loc[1, "e"] = 2 35 | self.t.loc[2, "w"] = 1 36 | self.t.loc[2, "e"] = 3 37 | self.t.loc[3, "s"] = 7 38 | self.t.loc[3, "w"] = 2 39 | self.t.loc[3, "e"] = 4 40 | self.t.loc[4, "w"] = 3 41 | self.t.loc[4, "e"] = 5 42 | self.t.loc[5, "s"] = 8 43 | self.t.loc[5, "w"] = 4 44 | self.__gamma = 0.8 # 折扣因子 45 | self.viewer = None 46 | self.__state = None 47 | self.seed() 48 | 49 | def _reward(self, state): 50 | r = 0.0 51 | if state in (6, 8): 52 | r = -1.0 53 | elif state == 7: 54 | r = 1.0 55 | return r 56 | 57 | def seed(self, seed=None): 58 | self.np_random, seed = seeding.np_random(seed) 59 | return [seed] 60 | 61 | def close(self): 62 | if self.viewer: 63 | self.viewer.close() 64 | 65 | def transform(self, state, action): 66 | #卫语句 67 | if state in self.__terminal_space: 68 | return state, self._reward(state), True, {} 69 | 70 | # 状态转移 71 | if pd.isna(self.t.loc[state, action]): 72 | next_state = state 73 | else: 74 | next_state = self.t.loc[state, action] 75 | 76 | # 计算回报 77 | r = self._reward(next_state) 78 | 79 | # 判断是否终止 80 | is_terminal = False 81 | if next_state in self.__terminal_space: 82 | is_terminal = True 83 | 84 | return next_state, r, is_terminal, {} 85 | 86 | def step(self, action): 87 | state = self.__state 88 | 89 | next_state, r, is_terminal,_ = self.transform(state, action) 90 | 91 | self.__state = next_state 92 | 93 | return next_state, r, is_terminal, {} 94 | 95 | def reset(self): 96 | self.__state = np.random.choice(self.observation_space) 97 | return self.__state 98 | 99 | def render(self, mode='human', close=False): 100 | if close: 101 | if self.viewer is not None: 102 | self.viewer.close() 103 | self.viewer = None 104 | return 105 | screen_width = 600 106 | screen_height = 400 107 | 108 | if self.viewer is None: 109 | from gym.envs.classic_control import rendering 110 | self.viewer = rendering.Viewer(screen_width, screen_height) 111 | # 创建网格世界 112 | self.line1 = rendering.Line((100, 300), (500, 300)) 113 | self.line2 = rendering.Line((100, 200), (500, 200)) 114 | self.line3 = rendering.Line((100, 300), (100, 100)) 115 | self.line4 = rendering.Line((180, 300), (180, 100)) 116 | self.line5 = rendering.Line((260, 300), (260, 100)) 117 | self.line6 = rendering.Line((340, 300), (340, 100)) 118 | self.line7 = rendering.Line((420, 300), (420, 100)) 119 | self.line8 = 
rendering.Line((500, 300), (500, 100)) 120 | self.line9 = rendering.Line((100, 100), (180, 100)) 121 | self.line10 = rendering.Line((260, 100), (340, 100)) 122 | self.line11 = rendering.Line((420, 100), (500, 100)) 123 | # 创建第一个骷髅 124 | self.kulo1 = rendering.make_circle(40) 125 | self.circletrans = rendering.Transform(translation=(140, 150)) 126 | self.kulo1.add_attr(self.circletrans) 127 | self.kulo1.set_color(0, 0, 0) 128 | # 创建第二个骷髅 129 | self.kulo2 = rendering.make_circle(40) 130 | self.circletrans = rendering.Transform(translation=(460, 150)) 131 | self.kulo2.add_attr(self.circletrans) 132 | self.kulo2.set_color(0, 0, 0) 133 | # 创建金条 134 | self.gold = rendering.make_circle(40) 135 | self.circletrans = rendering.Transform(translation=(300, 150)) 136 | self.gold.add_attr(self.circletrans) 137 | self.gold.set_color(1, 0.9, 0) 138 | # 创建机器人 139 | self.robot = rendering.make_circle(30) 140 | self.robotrans = rendering.Transform() 141 | self.robot.add_attr(self.robotrans) 142 | self.robot.set_color(0.8, 0.6, 0.4) 143 | 144 | self.line1.set_color(0, 0, 0) 145 | self.line2.set_color(0, 0, 0) 146 | self.line3.set_color(0, 0, 0) 147 | self.line4.set_color(0, 0, 0) 148 | self.line5.set_color(0, 0, 0) 149 | self.line6.set_color(0, 0, 0) 150 | self.line7.set_color(0, 0, 0) 151 | self.line8.set_color(0, 0, 0) 152 | self.line9.set_color(0, 0, 0) 153 | self.line10.set_color(0, 0, 0) 154 | self.line11.set_color(0, 0, 0) 155 | 156 | self.viewer.add_geom(self.line1) 157 | self.viewer.add_geom(self.line2) 158 | self.viewer.add_geom(self.line3) 159 | self.viewer.add_geom(self.line4) 160 | self.viewer.add_geom(self.line5) 161 | self.viewer.add_geom(self.line6) 162 | self.viewer.add_geom(self.line7) 163 | self.viewer.add_geom(self.line8) 164 | self.viewer.add_geom(self.line9) 165 | self.viewer.add_geom(self.line10) 166 | self.viewer.add_geom(self.line11) 167 | self.viewer.add_geom(self.kulo1) 168 | self.viewer.add_geom(self.kulo2) 169 | self.viewer.add_geom(self.gold) 170 | self.viewer.add_geom(self.robot) 171 | 172 | if self.__state is None: 173 | return None 174 | 175 | self.robotrans.set_translation(self.x[self.__state - 1], self.y[self.__state - 1]) 176 | return self.viewer.render(return_rgb_array=(mode == 'rgb_array')) 177 | 178 | 179 | if __name__ == '__main__': 180 | env = GridEnv() 181 | env.reset() 182 | env.render() 183 | -------------------------------------------------------------------------------- /1-gym_developing/maze_game.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # author : zlq16 4 | # date : 2018/4/7 5 | import gym 6 | from gym.utils import seeding 7 | import logging 8 | import numpy as np 9 | import pandas as pd 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class MazeEnv(gym.Env): 15 | metadata = { 16 | 'render.modes': ['human', 'rgb_array'], 17 | 'video.frames_per_second': 1 18 | } 19 | 20 | def __init__(self): 21 | self.map = np.array([ 22 | [0, 0, 1, 0, 0], 23 | [0, 0, 1, 0, 0], 24 | [1, 0, 0, 0, 0], 25 | [1, 0, 0, 1, 1], 26 | [1, 0, 0, 0, 0]], dtype=np.bool) 27 | 28 | self.observation_space = [tuple(s) for s in np.argwhere(self.map == 0)] 29 | self.walls = [tuple(s) for s in np.argwhere(self.map)] 30 | self.__terminal_space = ((4, 2),) # 终止状态为字典格式 31 | 32 | # 状态转移的数据格式为字典 33 | self.action_space = ('n', 'e', 's', 'w') 34 | self.t = pd.DataFrame(data=None, index=self.observation_space, columns=self.action_space) 35 | self._trans_make() 36 | self.viewer = None 37 | 
self.__state = None 38 | self.seed() 39 | 40 | def _trans_make(self): 41 | for s in self.observation_space: 42 | for a in self.action_space: 43 | if a == "n": 44 | n_s = np.array(s) + np.array([0, 1]) 45 | elif a == "e": 46 | n_s = np.array(s) + np.array([1, 0]) 47 | elif a == "s": 48 | n_s = np.array(s) + np.array([0, -1]) 49 | elif a == "w": 50 | n_s = np.array(s) + np.array([-1, 0]) 51 | if (0 <= n_s).all() and (n_s <= 4).all() and not self.map[n_s[0], n_s[1]]: 52 | self.t.loc[s, a] = tuple(n_s) 53 | else: 54 | self.t.loc[s, a] = s 55 | 56 | def _reward(self, state): 57 | r = 0.0 58 | n_s = np.array(state) 59 | if (0 <= n_s).all() and (n_s <= 4).all() and tuple(n_s) in self.__terminal_space: 60 | r = 1.0 61 | return r 62 | 63 | def seed(self, seed=None): 64 | self.np_random, seed = seeding.np_random(seed) 65 | return [seed] 66 | 67 | def close(self): 68 | if self.viewer: 69 | self.viewer.close() 70 | 71 | def transform(self, state, action): 72 | # 卫语言 73 | if state in self.__terminal_space: 74 | return state, self._reward(state), True, {} 75 | 76 | # 状态转移 77 | next_state = self.t[action][state] 78 | 79 | # 计算回报 80 | r = self._reward(next_state) 81 | 82 | # 判断是否终止 83 | is_terminal = False 84 | if next_state in self.__terminal_space: 85 | is_terminal = True 86 | 87 | return next_state, r, is_terminal, {} 88 | 89 | def step(self, action): 90 | state = self.__state 91 | 92 | next_state, r, is_terminal, _ = self.transform(state, action) 93 | 94 | self.__state = next_state 95 | 96 | return next_state, r, is_terminal, {} 97 | 98 | def reset(self): 99 | 100 | while True: 101 | self.__state = self.observation_space[np.random.choice(len(self.observation_space))] 102 | if self.__state not in self.__terminal_space: 103 | break 104 | return self.__state 105 | 106 | def render(self, mode='human', close=False): 107 | if close: 108 | if self.viewer is not None: 109 | self.viewer.close() 110 | self.viewer = None 111 | return 112 | unit = 50 113 | screen_width = 5 * unit 114 | screen_height = 5 * unit 115 | 116 | if self.viewer is None: 117 | from gym.envs.classic_control import rendering 118 | self.viewer = rendering.Viewer(screen_width, screen_height) 119 | 120 | #创建网格 121 | for c in range(5): 122 | line = rendering.Line((0, c*unit), (screen_width, c*unit)) 123 | line.set_color(0, 0, 0) 124 | self.viewer.add_geom(line) 125 | for r in range(5): 126 | line = rendering.Line((r*unit, 0), (r*unit, screen_height)) 127 | line.set_color(0, 0, 0) 128 | self.viewer.add_geom(line) 129 | 130 | # 创建墙壁 131 | for x, y in self.walls: 132 | r = rendering.make_polygon( 133 | v=[[x * unit, y * unit], 134 | [(x + 1) * unit, y * unit], 135 | [(x + 1) * unit, (y + 1) * unit], 136 | [x * unit, (y + 1) * unit], 137 | [x * unit, y * unit]]) 138 | r.set_color(0, 0, 0) 139 | self.viewer.add_geom(r) 140 | 141 | # 创建机器人 142 | self.robot = rendering.make_circle(20) 143 | self.robotrans = rendering.Transform() 144 | self.robot.add_attr(self.robotrans) 145 | self.robot.set_color(0.8, 0.6, 0.4) 146 | self.viewer.add_geom(self.robot) 147 | 148 | # 创建出口 149 | self.exit = rendering.make_circle(20) 150 | self.exitrans = rendering.Transform(translation=(4*unit+unit/2, 2*unit+unit/2)) 151 | self.exit.add_attr(self.exitrans) 152 | self.exit.set_color(0, 1, 0) 153 | self.viewer.add_geom(self.exit) 154 | 155 | if self.__state is None: 156 | return None 157 | 158 | self.robotrans.set_translation(self.__state[0] * unit + unit / 2, self.__state[1] * unit + unit / 2) 159 | return self.viewer.render(return_rgb_array=(mode == 'rgb_array')) 160 | 161 
| 162 | if __name__ == '__main__': 163 | env = MazeEnv() 164 | env.reset() 165 | env.render() 166 | -------------------------------------------------------------------------------- /1-gym_developing/suceess.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/1-gym_developing/suceess.png -------------------------------------------------------------------------------- /2-markov_decision_process/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # author : zlq16 4 | # date : 2018/4/2 -------------------------------------------------------------------------------- /2-markov_decision_process/game.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # author : zlq16 4 | # date : 2018/4/7 5 | import gym 6 | import numpy as np 7 | 8 | 9 | def main(): 10 | env = gym.make("MazeGame-v0") 11 | s = env.reset() 12 | a_s = env.action_space 13 | for i in range(100): 14 | env.render() 15 | a = np.random.choice(a_s) 16 | print(a) 17 | s, r, t, _ = env.step(a) 18 | 19 | if __name__ == '__main__': 20 | main() 21 | -------------------------------------------------------------------------------- /2-markov_decision_process/our_life.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # author : zlq16 4 | # date : 2018/4/2 5 | import numpy as np 6 | import pandas as pd 7 | 8 | 9 | def main(): 10 | S = ("S1", "S2", "S3", "S4", "S5") 11 | A = ("play", "quit", "study", "publish", "sleep") 12 | strategy = pd.DataFrame(data=None, index=S, columns=A) 13 | reward = pd.DataFrame(data=None, index=S, columns=A) 14 | gamma = 1 15 | strategy.loc["S1", :] = np.array([0.9, 0.1, 0, 0, 0]) 16 | strategy.loc["S2", :] = np.array([0.5, 0, 0.5, 0, 0]) 17 | strategy.loc["S3", :] = np.array([0, 0, 0.8, 0, 0.2]) 18 | strategy.loc["S4", :] = np.array([0, 0, 0, 0.4, 0.6]) 19 | strategy.loc["S5", :] = np.array([0, 0, 0, 0, 0]) 20 | print("Policy") 21 | print(strategy) 22 | reward.loc["S1", :] = np.array([-1, 0, 0, 0, 0]) 23 | reward.loc["S2", :] = np.array([-1, 0, -2, 0, 0]) 24 | reward.loc["S3", :] = np.array([0, 0, -2, 0, 0]) 25 | reward.loc["S4", :] = np.array([0, 0, 0, 10, 1]) 26 | reward.loc["S5", :] = np.array([0, 0, 0, 0, 0]) 27 | print("Reward function") 28 | print(reward) 29 | 30 | 31 | if __name__ == '__main__': 32 | main() 33 | -------------------------------------------------------------------------------- /3-dynamic_program/grid_game_with_average_policy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # author : zlq16 4 | # date : 2018/4/9 5 | ''' 6 | There are a few problems here. Following the book author's approach, 7 | if the chosen action cannot move the agent, then 8 | the value used in the iteration is v(s') = v(s), 9 | which is then plugged into v(s) = Σ π*(r + γ*v(s')). 10 | I don't think this is right: 11 | if no move happens, there should be no reward r and no discounting. 12 | Still, this file keeps the book's approach for now; 13 | the later policy iteration and value iteration versions are corrected. 14 | ''' 15 | import pandas as pd 16 | import numpy as np 17 | 18 | 19 | class GridMDP: 20 | def __init__(self, **kwargs): 21 | for k, v in kwargs.items(): 22 | setattr(self, k, v) 23 | self.__action_dir = pd.Series( 24 | data=[np.array((-1, 0)), 25 | np.array((1, 0)), 26 | np.array((0, -1)), 27 | np.array((0, 1))], 28 | index=self.action_space) 29 | self.terminal_states = [(0, 0), (3, 3)] 30 | 31 | def 
transform(self, state, action): 32 | dir = self.__action_dir[action] 33 | state_ = np.array(state) + dir 34 | if (state_ >= 0).all() and (state_ < 4).all(): 35 | state_ = tuple(state_) 36 | else: 37 | state_ = state 38 | return state_ 39 | 40 | 41 | def average_policy(mdp, v_s, policy): 42 | state_space = mdp.state_space 43 | action_space = mdp.action_space 44 | reward = mdp.reward 45 | gamma = mdp.gamma 46 | while True: 47 | print(v_s) 48 | v_s_ = v_s.copy() 49 | for state in state_space: 50 | v_s_a = pd.Series() 51 | for action in action_space: 52 | state_ = mdp.transform(state,action) 53 | if state_ in mdp.terminal_states: 54 | v_s_a[action] = 0 55 | elif state_ != state: 56 | v_s_a[action] = v_s_[state_] 57 | else: 58 | v_s_a[action] = v_s_[state] 59 | v_s[state] = sum([policy[action] * (reward + gamma * v_s_a[action]) for action in action_space]) 60 | if (np.abs(v_s_ - v_s) < 1e-8).all(): 61 | break 62 | return v_s 63 | 64 | 65 | def main(): 66 | state_space = [(i, j) for i in range(4) for j in range(4)] 67 | state_space.remove((0, 0)) 68 | state_space.remove((3, 3)) 69 | mdp = GridMDP( 70 | state_space=state_space, 71 | action_space=["n", "s", "w", "e"], 72 | reward=-1, 73 | gamma=1) 74 | v_s = pd.Series(np.zeros((len(state_space))),index=state_space) 75 | policy = pd.Series(data=0.25 * np.ones(shape=(4)), index=mdp.action_space) 76 | v_s = average_policy(mdp, v_s, policy) 77 | print("convergence valuation of __state is:") 78 | print(v_s) 79 | 80 | 81 | if __name__ == '__main__': 82 | main() 83 | -------------------------------------------------------------------------------- /3-dynamic_program/grid_game_with_policy_iterate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # author : zlq16 4 | # date : 2018/4/9 5 | import pandas as pd 6 | import numpy as np 7 | 8 | 9 | class GridMDP: 10 | def __init__(self, **kwargs): 11 | for k, v in kwargs.items(): 12 | setattr(self, k, v) 13 | self.__action_dir = pd.Series( 14 | data=[np.array((-1, 0)), 15 | np.array((1, 0)), 16 | np.array((0, -1)), 17 | np.array((0, 1))], 18 | index=self.action_space) 19 | self.terminal_space = [(0, 0), (3, 3)] 20 | 21 | def transform(self, state, action): 22 | dir = self.__action_dir[action] 23 | state_ = np.array(state) + dir 24 | if (state_ >= 0).all() and (state_ < 4).all(): 25 | state_ = tuple(state_) 26 | else: 27 | state_ = state 28 | return state_ 29 | 30 | 31 | def policy_evaluate(v_s, policy, mdp): 32 | state_space = mdp.state_space 33 | gamma = mdp.gamma 34 | reward = mdp.reward 35 | while True: 36 | v_s_ = v_s.copy() 37 | for state in state_space: 38 | action = policy[state] 39 | state_ = mdp.transform(state, action) 40 | if state_ in mdp.terminal_space: # 发生转移 41 | v_s[state] = reward + 0.0 42 | elif state_ != state: # 终点位置 43 | v_s[state] = reward + gamma * v_s_[state_] 44 | else: # 没有发生转移 45 | v_s[state] = reward + gamma * v_s_[state_] 46 | 47 | if (np.abs(v_s_ - v_s) < 1e-8).all(): 48 | break 49 | return v_s 50 | 51 | 52 | def policy_improve(v_s, mdp): 53 | state_space = mdp.state_space 54 | action_space = mdp.action_space 55 | gamma = mdp.gamma 56 | reward = mdp.reward 57 | policy_ = pd.Series(index=state_space) 58 | for state in state_space: 59 | v_s_a = pd.Series() 60 | for action in action_space: 61 | state_ = mdp.transform(state, action) 62 | if state_ in mdp.terminal_space: 63 | v_s_a[action] = reward 64 | else: 65 | v_s_a[action] = reward + gamma * v_s[state_] 66 | 67 | # 随机选取最大的值 68 | m = 
v_s_a.max() 69 | policy_[state] = np.random.choice(v_s_a[v_s_a == m].index) 70 | return policy_ 71 | 72 | 73 | def policy_iterate(mdp): 74 | v_s = pd.Series(data=np.zeros(shape=len(mdp.state_space)), index=mdp.state_space) 75 | policy = pd.Series(data=np.random.choice(mdp.action_space, size=(len(mdp.state_space))), index=mdp.state_space) 76 | while True: 77 | print(policy) 78 | v_s = policy_evaluate(v_s, policy, mdp) 79 | policy_ = policy_improve(v_s, mdp) 80 | if (policy_ == policy).all(): 81 | break 82 | else: 83 | policy = policy_ 84 | return policy 85 | 86 | 87 | def main(): 88 | state_space = [(i, j) for i in range(4) for j in range(4)] 89 | state_space.remove((0, 0)) 90 | state_space.remove((3, 3)) 91 | mdp = GridMDP( 92 | state_space=state_space, 93 | action_space=["n", "s", "w", "e"], 94 | reward=-1, 95 | gamma=0.9) 96 | policy = policy_iterate(mdp) 97 | print("convergence policy is:") 98 | print(policy) 99 | 100 | 101 | if __name__ == '__main__': 102 | main() 103 | -------------------------------------------------------------------------------- /3-dynamic_program/grid_game_with_value_iterate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # author : zlq16 4 | # date : 2018/4/9 5 | import pandas as pd 6 | import numpy as np 7 | 8 | 9 | class GridMDP: 10 | def __init__(self, **kwargs): 11 | for k, v in kwargs.items(): 12 | setattr(self, k, v) 13 | self.__action_dir = pd.Series( 14 | data=[np.array((-1, 0)), 15 | np.array((1, 0)), 16 | np.array((0, -1)), 17 | np.array((0, 1))], 18 | index=self.action_space) 19 | self.terminal_space = [(0, 0), (3, 3)] 20 | 21 | def transform(self, state, action): 22 | dir = self.__action_dir[action] 23 | state_ = np.array(state) + dir 24 | if (state_ >= 0).all() and (state_ < 4).all(): 25 | state_ = tuple(state_) 26 | else: 27 | state_ = state 28 | return state_ 29 | 30 | 31 | def value_iterate(mdp): 32 | state_space = mdp.state_space 33 | action_space = mdp.action_space 34 | gamma = mdp.gamma 35 | reward = mdp.reward 36 | v_s = pd.Series(data=np.zeros(shape=len(state_space)), index=state_space) 37 | policy = pd.Series(index=state_space) 38 | while True: 39 | print(v_s) 40 | v_s_ = v_s.copy() 41 | for state in state_space: 42 | v_s_a = pd.Series() 43 | for action in action_space: 44 | state_ = mdp.transform(state,action) 45 | if state_ in mdp.terminal_space: 46 | v_s_a[action] = reward 47 | else: 48 | v_s_a[action] = reward + gamma * v_s_[state_] 49 | 50 | v_s[state] = v_s_a.max() 51 | policy[state] = np.random.choice(v_s_a[v_s_a == v_s[state]].index) 52 | 53 | if (np.abs(v_s_ - v_s) < 1e-8).all(): 54 | break 55 | return policy 56 | 57 | 58 | def main(): 59 | state_space = [(i, j) for i in range(4) for j in range(4)] 60 | state_space.remove((0, 0)) 61 | state_space.remove((3, 3)) 62 | mdp = GridMDP( 63 | state_space=state_space, 64 | action_space=["n", "s", "w", "e"], 65 | reward=-1, 66 | gamma=0.9) 67 | policy = value_iterate(mdp) 68 | print("convergence policy is:") 69 | print(policy) 70 | 71 | 72 | if __name__ == '__main__': 73 | main() 74 | -------------------------------------------------------------------------------- /3-dynamic_program/maze_game_with_dynamic_program.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # author : zlq16 4 | # date : 2018/4/9 5 | import gym 6 | import pandas as pd 7 | import numpy as np 8 | 9 | 10 | def value_iterate(env): 11 | 
state_space = env.observation_space 12 | action_space = env.action_space 13 | v_s = pd.Series(data=np.zeros(shape=len(state_space)), index=state_space) 14 | policy = pd.Series(index=state_space) 15 | gamma = 0.8 16 | while True: 17 | print(v_s) 18 | v_s_ = v_s.copy() 19 | for state in state_space: 20 | v_s_a = pd.Series() 21 | for action in action_space: 22 | state_, reward, is_done, _ = env.transform(state, action) 23 | if is_done: 24 | v_s_a[action] = reward 25 | else: 26 | v_s_a[action] = reward + gamma*v_s_[state_] 27 | v_s[state] = v_s_a.max() 28 | policy[state] = np.random.choice(v_s_a[v_s_a == v_s[state]].index) # pick randomly among the maximizing actions 29 | if (np.abs(v_s_ - v_s) < 1e-8).all(): 30 | break 31 | return policy 32 | 33 | 34 | ### This is essentially pseudocode ### 35 | def main(): 36 | env = gym.make("MazeGame-v0") 37 | print(env.action_space) 38 | print(env.observation_space) 39 | policy = value_iterate(env) 40 | print("convergence policy is:") 41 | print(policy) 42 | 43 | 44 | if __name__ == '__main__': 45 | main() 46 | -------------------------------------------------------------------------------- /3-dynamic_program/policy_iteration_algorithm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/3-dynamic_program/policy_iteration_algorithm.png -------------------------------------------------------------------------------- /3-dynamic_program/value_iteration_algorithm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/3-dynamic_program/value_iteration_algorithm.png -------------------------------------------------------------------------------- /4-monte_carlo/monte_carlo_evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # author : zlq16 4 | # date : 2018/5/2 5 | from collections import defaultdict 6 | 7 | 8 | def mc(gamma, env, state_sample, reward_sample): 9 | V = defaultdict(float) 10 | N = defaultdict(int) 11 | states = env.observation_space 12 | num = len(state_sample) 13 | for i in range(num): 14 | G = 0.0 15 | episode_len = len(state_sample[i]) 16 | # accumulate the discounted return of the whole episode, working backwards 17 | for episode in range(episode_len-1, -1, -1): 18 | G *= gamma 19 | G += reward_sample[i][episode] 20 | 21 | # add the return following each state to that state's running total 22 | for episode in range(episode_len): 23 | s = state_sample[i][episode] 24 | V[s] += G 25 | N[s] += 1 26 | G -= reward_sample[i][episode] 27 | G /= gamma 28 | 29 | # average over the collected experience 30 | for s in states: 31 | if N[s] > 0: 32 | V[s] /= N[s] 33 | return V 34 | -------------------------------------------------------------------------------- /4-monte_carlo/monte_carlo_sample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # author : zlq16 4 | # date : 2018/4/30 5 | import numpy as np 6 | 7 | 8 | def gen_andom(env, num): 9 | state_sample = [] 10 | action_sample = [] 11 | reward_sample = [] 12 | # sample num episodes 13 | for i in range(num): 14 | s_tmp = [] 15 | a_tmp = [] 16 | r_tmp = [] 17 | s = env.reset() 18 | is_done = False 19 | # roll out a single episode 20 | while not is_done: 21 | a = np.random.choice(env.action_space) 22 | s_, r, is_done, _ = env.transform(s, a) 23 | s_tmp.append(s) 24 | a_tmp.append(a) 25 | r_tmp.append(r) 26 | s = s_ 27 | state_sample.append(s_tmp) 28 | 
action_sample.append(a_tmp) 29 | reward_sample.append(r_tmp) 30 | return state_sample, action_sample, reward_sample 31 | -------------------------------------------------------------------------------- /5-temporal_difference/README.md: -------------------------------------------------------------------------------- 1 | ## Usage 2 | 1. Follow the [usage instructions](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/1-gym_developing/README.md) in 1-gym_developing to install the push_box_game environment. 3 | The environment file is push_box_game.py 4 | 5 | -------------------------------------------------------------------------------- /5-temporal_difference/push_box_game.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # author : zlq16 4 | # date : 2018/5/4 5 | import gym 6 | from gym.utils import seeding 7 | from gym.envs.classic_control import rendering 8 | import logging 9 | import numpy as np 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class PushBoxEnv(gym.Env): 15 | metadata = { 16 | 'render.modes': ['human', 'rgb_array'], 17 | 'video.frames_per_second': 1 18 | } 19 | 20 | def __init__(self): 21 | self.reward = np.array([ 22 | [-100, -100, -100, -100, -100], 23 | [-100, 0, 0, 0, -100], 24 | [-100, 0, -100, 0, -100], 25 | [-100, 0, -100, 0, -100], 26 | [-100, 0, 100, 0, -100], 27 | ]).T 28 | self.walls = [(0, 1), (0, 3), (2, 2), (2, 3)] 29 | self.dest_position = (2, 4) 30 | self.work_position = (2, 0) 31 | self.box_position = (2, 1) 32 | self.action_space = ((1, 0), (-1, 0), (0, 1), (0, -1)) 33 | self.viewer = None 34 | self.__state = None 35 | self.seed() 36 | 37 | def _is_terminal(self, state): 38 | box_position = state[2:] 39 | if self.reward[box_position] != 0: 40 | return True 41 | return False 42 | 43 | def _move_ok(self, position): 44 | if tuple(position) not in self.walls and \ 45 | (position >= 0).all() and \ 46 | (position < 5).all(): 47 | return True 48 | return False 49 | 50 | def _trans_make(self, state, action): 51 | work_position = np.array(state[:2]) 52 | box_position = np.array(state[2:]) 53 | action = np.array(action) 54 | next_work_position = work_position + action 55 | # check whether the worker can move 56 | if self._move_ok(next_work_position): 57 | work_position = next_work_position 58 | # the worker stepped onto the box, so try to push it 59 | if (next_work_position == box_position).all(): 60 | 61 | next_box_position = box_position + action 62 | if self._move_ok(next_box_position): 63 | # the box can be pushed forward 64 | box_position = next_box_position 65 | else: 66 | # the box ahead cannot move and blocks the path, so undo the worker's move 67 | work_position -= action 68 | return tuple(np.hstack((work_position, box_position))) 69 | 70 | def _reward(self, state): 71 | box_position = state[2:] 72 | return self.reward[box_position] 73 | 74 | def seed(self, seed=None): 75 | self.np_random, seed = seeding.np_random(seed) 76 | return [seed] 77 | 78 | def close(self): 79 | if self.viewer: 80 | self.viewer.close() 81 | 82 | def transform(self, state, action): 83 | # guard clause 84 | if self._is_terminal(state): 85 | return state, self._reward(state), True, {} 86 | 87 | # state transition 88 | next_state = self._trans_make(state, action) 89 | 90 | # compute the reward 91 | r = self._reward(next_state) 92 | 93 | # check for termination 94 | is_terminal = False 95 | if self._is_terminal(next_state): 96 | is_terminal = True 97 | 98 | return next_state, r, is_terminal, {} 99 | 100 | def step(self, action): 101 | # current state of the environment 102 | state = self.__state 103 | 104 | # delegate the transition to transform() 105 | next_state, r, is_terminal, _ = self.transform(state, action) 106 | 107 | # update the internal state 108 | 
self.__state = next_state 109 | 110 | return next_state, r, is_terminal, {} 111 | 112 | def reset(self): 113 | self.__state = tuple(np.hstack((self.work_position, self.box_position))) 114 | return self.__state 115 | 116 | def render(self, mode='human', close=False): 117 | if close: 118 | if self.viewer is not None: 119 | self.viewer.close() 120 | self.viewer = None 121 | return 122 | unit = 50 123 | screen_width = 5 * unit 124 | screen_height = 5 * unit 125 | 126 | if self.viewer is None: 127 | 128 | self.viewer = rendering.Viewer(screen_width, screen_height) 129 | 130 | #创建网格 131 | for c in range(5): 132 | line = rendering.Line((0, c*unit), (screen_width, c*unit)) 133 | line.set_color(0, 0, 0) 134 | self.viewer.add_geom(line) 135 | for r in range(5): 136 | line = rendering.Line((r*unit, 0), (r*unit, screen_height)) 137 | line.set_color(0, 0, 0) 138 | self.viewer.add_geom(line) 139 | 140 | # 创建墙壁 141 | for x, y in self.walls: 142 | r = rendering.make_polygon( 143 | v=[[x * unit, y * unit], 144 | [(x + 1) * unit, y * unit], 145 | [(x + 1) * unit, (y + 1) * unit], 146 | [x * unit, (y + 1) * unit], 147 | [x * unit, y * unit]]) 148 | r.set_color(0, 0, 0) 149 | self.viewer.add_geom(r) 150 | 151 | # 创建终点 152 | d_x, d_y = self.dest_position 153 | dest = rendering.make_polygon(v=[ 154 | [d_x*unit, d_y*unit], 155 | [(d_x+1)*unit, d_y*unit], 156 | [(d_x+1)*unit, (d_y+1)*unit], 157 | [d_x*unit, (d_y+1)*unit], 158 | [d_x*unit, d_y*unit]]) 159 | dest_trans = rendering.Transform() 160 | dest.add_attr(dest_trans) 161 | dest.set_color(1, 0, 0) 162 | self.viewer.add_geom(dest) 163 | 164 | # 创建worker 165 | self.work = rendering.make_circle(20) 166 | self.work_trans = rendering.Transform() 167 | self.work.add_attr(self.work_trans) 168 | self.work.set_color(0, 1, 0) 169 | self.viewer.add_geom(self.work) 170 | 171 | # 创建箱子 172 | self.box = rendering.make_circle(20) 173 | self.box_trans = rendering.Transform() 174 | self.box.add_attr(self.box_trans) 175 | self.box.set_color(0, 0, 1) 176 | self.viewer.add_geom(self.box) 177 | 178 | if self.__state is None: 179 | return None 180 | w_x,w_y = self.__state[:2] 181 | b_x,b_y = self.__state[2:] 182 | self.work_trans.set_translation(w_x*unit+unit/2, w_y*unit+unit/2) 183 | self.box_trans.set_translation(b_x*unit+unit/2, b_y*unit+unit/2) 184 | return self.viewer.render(return_rgb_array=(mode == 'rgb_array')) 185 | 186 | 187 | if __name__ == '__main__': 188 | env = PushBoxEnv() 189 | env.reset() 190 | env.render() 191 | env.step((-1, 0)) 192 | env.render() 193 | -------------------------------------------------------------------------------- /5-temporal_difference/push_box_game/agent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # author : zlq16 4 | # date : 2017/10/4 5 | 6 | import numpy as np 7 | from sklearn.preprocessing import LabelBinarizer 8 | import cv2 9 | import random 10 | import pandas as pd 11 | import os 12 | import tensorflow as tf 13 | from collections import deque 14 | 15 | 16 | class Agent(object): 17 | 18 | def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9): 19 | self.actions = actions 20 | self.learning_rate = learning_rate 21 | self.reward_decay = reward_decay 22 | self.epsilon = e_greedy 23 | self.q_table = pd.DataFrame() 24 | 25 | def check_state_exist(self, state): 26 | if tuple(state) not in self.q_table.columns: 27 | self.q_table[tuple(state)] = [0]*len(self.actions) 28 | 29 | def choose_action(self, observation): 30 | 
observation = tuple(observation) 31 | self.check_state_exist(observation) 32 | if np.random.uniform() < self.epsilon: 33 | state_action = self.q_table[observation] 34 | state_action = state_action.reindex(np.random.permutation(state_action.index)) # some action_space have same value 35 | action_idx = state_action.argmax() 36 | else: 37 | action_idx = np.random.choice(range(len(self.actions))) 38 | return action_idx 39 | 40 | def save_train_parameter(self, name): 41 | self.q_table.to_pickle(name) 42 | 43 | def load_train_parameter(self, name): 44 | self.q_table = pd.read_pickle(name) 45 | 46 | def learn(self, *args, **kwargs): 47 | pass 48 | 49 | def train(self, *args, **kwargs): 50 | pass 51 | 52 | 53 | class QLearningAgent(Agent): 54 | 55 | def learn(self, s, a, r, s_, done): 56 | self.check_state_exist(s_) 57 | q_predict = self.q_table[s][a] 58 | if not done: 59 | q_target = r + self.reward_decay * self.q_table[s_].max() # next __state is not terminal 60 | else: 61 | q_target = r # next __state is terminal 62 | self.q_table[s][a] += self.learning_rate * (q_target - q_predict) # update 63 | 64 | def train(self, env, max_iterator=100): 65 | self.load_train_parameter("q_table.pkl") 66 | for episode in range(max_iterator): 67 | 68 | observation = env.reset() 69 | 70 | while True: 71 | # env.render() 72 | 73 | action_idx = self.choose_action(observation) 74 | 75 | observation_, reward, done, _ = env.step(self.actions[action_idx]) 76 | 77 | print(observation, reward) 78 | 79 | self.learn(observation, action_idx, reward, observation_, done) 80 | 81 | observation = observation_ 82 | 83 | if done: 84 | self.save_train_parameter("q_table.pkl") 85 | break 86 | 87 | 88 | class SarsaAgent(Agent): 89 | 90 | def learn(self, s, a, r, s_, a_, done): 91 | self.check_state_exist(s_) 92 | q_predict = self.q_table[s][a] 93 | if not done: 94 | q_target = r + self.reward_decay * self.q_table[s_][a_] # next __state is not terminal 95 | else: 96 | q_target = r # next __state is terminal 97 | self.q_table[s][a] += self.learning_rate*(q_target - q_predict) # update 98 | 99 | def train(self, env, max_iterator=100): 100 | self.load_train_parameter("q_table.pkl") 101 | for episode in range(max_iterator): 102 | 103 | observation = env.reset() 104 | action_idx = self.choose_action(observation) 105 | while True: 106 | # env.render() 107 | print(observation) 108 | 109 | observation_, reward, done, _ = env.step(self.actions[action_idx]) 110 | 111 | action_idx_ = self.choose_action(observation_) 112 | 113 | self.learn(observation, action_idx, reward, observation_, action_idx_, done) 114 | 115 | observation = observation_ 116 | action_idx = action_idx_ 117 | 118 | if done: 119 | self.save_train_parameter("q_table.pkl") 120 | break 121 | -------------------------------------------------------------------------------- /5-temporal_difference/push_box_game/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # author : zlq16 4 | # date : 2017/10/4 5 | import gym 6 | from agent import QLearningAgent 7 | if __name__ == "__main__": 8 | env = gym.make("PushBoxGame-v0") 9 | RL = QLearningAgent(actions=env.action_space) 10 | RL.train(env) -------------------------------------------------------------------------------- /5-temporal_difference/push_box_game/q_table.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/5-temporal_difference/push_box_game/q_table.pkl -------------------------------------------------------------------------------- /5-temporal_difference/q_learning_algortihm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/5-temporal_difference/q_learning_algortihm.png -------------------------------------------------------------------------------- /5-temporal_difference/sarsa_algorithm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/5-temporal_difference/sarsa_algorithm.png -------------------------------------------------------------------------------- /5-temporal_difference/sarsa_lambda_algorithm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/5-temporal_difference/sarsa_lambda_algorithm.png -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/.gitignore: -------------------------------------------------------------------------------- 1 | # ignore all pyc files. 2 | *.pyc 3 | 4 | -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/README.md: -------------------------------------------------------------------------------- 1 | # Using Deep Q-Network to Learn How To Play Flappy Bird 2 | 3 | 4 | 5 | 7 mins version: [DQN for flappy bird](https://www.youtube.com/watch?v=THhUXIhjkCM) 6 | 7 | ## Overview 8 | This project follows the description of the Deep Q Learning algorithm described in Playing Atari with Deep Reinforcement Learning [2] and shows that this learning algorithm can be further generalized to the notorious Flappy Bird. 9 | 10 | ## Installation Dependencies: 11 | * Python 2.7 or 3 12 | * TensorFlow 0.7 13 | * pygame 14 | * OpenCV-Python 15 | 16 | ## How to Run? 17 | ``` 18 | git clone https://github.com/yenchenlin1994/DeepLearningFlappyBird.git 19 | cd DeepLearningFlappyBird 20 | python deep_q_network.py 21 | ``` 22 | 23 | ## What is Deep Q-Network? 24 | It is a convolutional neural network, trained with a variant of Q-learning, whose input is raw pixels and whose output is a value function estimating future rewards. 
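As a rough, illustrative sketch (not code from this repo), the Q-learning target that such a network is trained to regress toward can be written in a few lines of NumPy; all array and function names here are hypothetical:

```python
# Sketch of the Q-learning target y_j used to train a DQN (illustrative only).
# q_next holds the network's Q-value estimates for the next states of a sampled minibatch.
import numpy as np

def q_targets(rewards, q_next, terminals, gamma=0.99):  # gamma is an assumed discount factor
    # y_j = r_j for terminal transitions, r_j + gamma * max_a' Q(s_{j+1}, a') otherwise
    return rewards + gamma * np.max(q_next, axis=1) * (1.0 - terminals)

rewards   = np.array([0.1, -1.0])                 # hypothetical minibatch rewards
q_next    = np.array([[0.2, 0.5], [0.0, 0.0]])    # hypothetical Q(s_{j+1}, .) estimates
terminals = np.array([0.0, 1.0])                  # 1.0 marks a terminal next state
print(q_targets(rewards, q_next, terminals))      # -> [0.595, -1.0]
```

The network's weights are then updated by a gradient step on the squared difference between these targets and the predicted Q-values, as in the pseudo-code below.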
25 | 26 | For those who are interested in deep reinforcement learning, I highly recommend reading the following post: 27 | 28 | [Demystifying Deep Reinforcement Learning](http://www.nervanasys.com/demystifying-deep-reinforcement-learning/) 29 | 30 | ## Deep Q-Network Algorithm 31 | 32 | The pseudo-code for the Deep Q Learning algorithm, as given in [1], can be found below: 33 | 34 | ``` 35 | Initialize replay memory D to size N 36 | Initialize action-value function Q with random weights 37 | for episode = 1, M do 38 | Initialize state s_1 39 | for t = 1, T do 40 | With probability ϵ select random action a_t 41 | otherwise select a_t=max_a Q(s_t,a; θ_i) 42 | Execute action a_t in emulator and observe r_t and s_(t+1) 43 | Store transition (s_t,a_t,r_t,s_(t+1)) in D 44 | Sample a minibatch of transitions (s_j,a_j,r_j,s_(j+1)) from D 45 | Set y_j:= 46 | r_j for terminal s_(j+1) 47 | r_j+γ*max_(a^' ) Q(s_(j+1),a'; θ_i) for non-terminal s_(j+1) 48 | Perform a gradient step on (y_j-Q(s_j,a_j; θ_i))^2 with respect to θ 49 | end for 50 | end for 51 | ``` 52 | 53 | ## Experiments 54 | 55 | #### Environment 56 | Since the deep Q-network is trained on the raw pixel values observed from the game screen at each time step, [3] finds that removing the background that appears in the original game makes it converge faster. This process can be visualized in the following figure: 57 | 58 | 59 | 60 | #### Network Architecture 61 | According to [1], I first preprocessed the game screens with the following steps: 62 | 63 | 1. Convert image to grayscale 64 | 2. Resize image to 80x80 65 | 3. Stack last 4 frames to produce an 80x80x4 input array for the network 66 | 67 | The architecture of the network is shown in the figure below. The first layer convolves the input image with an 8x8x4x32 kernel at a stride size of 4. The output is then put through a 2x2 max pooling layer. The second layer convolves with a 4x4x32x64 kernel at a stride of 2. We then max pool again. The third layer convolves with a 3x3x64x64 kernel at a stride of 1. We then max pool one more time. The last hidden layer consists of 256 fully connected ReLU nodes. 68 | 69 | 70 | 71 | The final output layer has the same dimensionality as the number of valid actions which can be performed in the game, where the 0th index always corresponds to doing nothing. The values at this output layer represent the Q function given the input state for each valid action. At each time step, the network performs whichever action corresponds to the highest Q value, following an ϵ-greedy policy. 72 | 73 | 74 | #### Training 75 | At first, I initialize all weight matrices randomly using a normal distribution with a standard deviation of 0.01, then set the replay memory to a maximum size of 50,000 experiences. 76 | 77 | I start training by choosing actions uniformly at random for the first 10,000 time steps, without updating the network weights. This allows the system to populate the replay memory before training begins. 78 | 79 | Note that unlike [1], which initializes ϵ = 1, I linearly anneal ϵ from 0.1 to 0.0001 over the course of the next 3,000,000 frames. The reason I set it this way is that the agent can choose an action every 0.03s (FPS=30) in our game, so a high ϵ makes it **flap** too much, which keeps it at the top of the game screen and eventually bumps it into a pipe in a clumsy way. This makes the Q function converge relatively slowly, since it only starts to explore other situations once ϵ is low. 80 | However, in other games, initializing ϵ to 1 is more reasonable. 
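As an illustrative sketch (not the repo's `deep_q_network.py`), the linear annealing schedule described above can be expressed as follows, assuming the `OBSERVE`, `EXPLORE`, `INITIAL_EPSILON` and `FINAL_EPSILON` values quoted in the FAQ below:

```python
# Linear epsilon annealing: random play for OBSERVE steps, then anneal epsilon
# from INITIAL_EPSILON to FINAL_EPSILON over EXPLORE frames (illustrative only).
OBSERVE = 10000
EXPLORE = 3000000
INITIAL_EPSILON = 0.1
FINAL_EPSILON = 0.0001

def epsilon_at(t):
    if t < OBSERVE:
        return INITIAL_EPSILON  # epsilon stays at its initial value during the random-play phase
    frac = min(1.0, (t - OBSERVE) / EXPLORE)
    return INITIAL_EPSILON - frac * (INITIAL_EPSILON - FINAL_EPSILON)

print(epsilon_at(0), epsilon_at(OBSERVE + EXPLORE // 2), epsilon_at(10 ** 7))
```

Once annealing is done, this schedule simply clamps at `FINAL_EPSILON`.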
81 |
82 | During training time, at each time step, the network samples minibatches of size 32 from the replay memory to train on, and performs a gradient step on the loss function described above using the Adam optimizer with a learning rate of 0.000001. After annealing finishes, the network continues to train indefinitely, with ϵ fixed at 0.0001.
83 |
84 | ## FAQ
85 |
86 | #### Checkpoint not found
87 | Change [first line of `saved_networks/checkpoint`](https://github.com/yenchenlin1994/DeepLearningFlappyBird/blob/master/saved_networks/checkpoint#L1) to
88 |
89 | `model_checkpoint_path: "saved_networks/bird-dqn-2920000"`
90 |
91 | #### How to reproduce?
92 | 1. Comment out [these lines](https://github.com/yenchenlin1994/DeepLearningFlappyBird/blob/master/deep_q_network.py#L108-L112)
93 |
94 | 2. Modify `deep_q_network.py`'s parameters as follows:
95 | ```python
96 | OBSERVE = 10000
97 | EXPLORE = 3000000
98 | FINAL_EPSILON = 0.0001
99 | INITIAL_EPSILON = 0.1
100 | ```
101 |
102 | ## References
103 |
104 | [1] Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Andrei A. Rusu, Joel Veness, Marc G. Bellemare, Alex Graves, Martin Riedmiller, Andreas K. Fidjeland, Georg Ostrovski, Stig Petersen, Charles Beattie, Amir Sadik, Ioannis Antonoglou, Helen King, Dharshan Kumaran, Daan Wierstra, Shane Legg, and Demis Hassabis. **Human-level Control through Deep Reinforcement Learning**. Nature, 518:529-533, 2015.
105 |
106 | [2] Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Alex Graves, Ioannis Antonoglou, Daan Wierstra, and Martin Riedmiller. **Playing Atari with Deep Reinforcement Learning**. NIPS Deep Learning Workshop, 2013.
107 |
108 | [3] Kevin Chen. **Deep Reinforcement Learning for Flappy Bird**. [Report](http://cs229.stanford.edu/proj2015/362_report.pdf) | [Youtube result](https://youtu.be/9WKBzTUsPKc)
109 |
110 | ## Disclaimer
111 | This work is largely based on the following repos:
112 |
113 | 1. [sourabhv/FlapPyBird](https://github.com/sourabhv/FlapPyBird)
114 | 2.
[asrivat1/DeepLearningVideoGames](https://github.com/asrivat1/DeepLearningVideoGames) 115 | 116 | -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/audio/die.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/die.ogg -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/audio/die.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/die.wav -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/audio/hit.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/hit.ogg -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/audio/hit.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/hit.wav -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/audio/point.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/point.ogg -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/audio/point.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/point.wav -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/audio/swoosh.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/swoosh.ogg -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/audio/swoosh.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/swoosh.wav -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/audio/wing.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/wing.ogg -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/audio/wing.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/wing.wav -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/0.png -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/1.png -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/2.png -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/3.png -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/4.png -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/5.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/5.png -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/6.png -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/7.png -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/8.png -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/9.png -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/background-black.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/background-black.png -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/base.png -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/pipe-green.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/pipe-green.png -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/redbird-downflap.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/redbird-downflap.png -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/redbird-midflap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/redbird-midflap.png -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/redbird-upflap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/redbird-upflap.png -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/deep_q_network.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import print_function 3 | 4 | import tensorflow as tf 5 | import cv2 6 | import sys 7 | sys.path.append("game/") 8 | import game.wrapped_flappy_bird as game 9 | import random 10 | import numpy as np 11 | from collections import deque 12 | 13 | GAME = 'bird' # the name of the game being played for log files 14 | ACTIONS = 2 # number of valid actions 15 | GAMMA = 0.99 # decay rate of past observations 16 | OBSERVE = 100000. # timesteps to observe before training 17 | EXPLORE = 2000000. 
# frames over which to anneal epsilon 18 | FINAL_EPSILON = 0.0001 # final value of epsilon 19 | INITIAL_EPSILON = 0.0001 # starting value of epsilon 20 | REPLAY_MEMORY = 50000 # number of previous transitions to remember 21 | BATCH = 32 # size of minibatch 22 | FRAME_PER_ACTION = 1 23 | 24 | def weight_variable(shape): 25 | initial = tf.truncated_normal(shape, stddev = 0.01) 26 | return tf.Variable(initial) 27 | 28 | def bias_variable(shape): 29 | initial = tf.constant(0.01, shape = shape) 30 | return tf.Variable(initial) 31 | 32 | def conv2d(x, W, stride): 33 | return tf.nn.conv2d(x, W, strides = [1, stride, stride, 1], padding = "SAME") 34 | 35 | def max_pool_2x2(x): 36 | return tf.nn.max_pool(x, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = "SAME") 37 | 38 | def createNetwork(): 39 | # network weights 40 | W_conv1 = weight_variable([8, 8, 4, 32]) 41 | b_conv1 = bias_variable([32]) 42 | 43 | W_conv2 = weight_variable([4, 4, 32, 64]) 44 | b_conv2 = bias_variable([64]) 45 | 46 | W_conv3 = weight_variable([3, 3, 64, 64]) 47 | b_conv3 = bias_variable([64]) 48 | 49 | W_fc1 = weight_variable([1600, 512]) 50 | b_fc1 = bias_variable([512]) 51 | 52 | W_fc2 = weight_variable([512, ACTIONS]) 53 | b_fc2 = bias_variable([ACTIONS]) 54 | 55 | # input layer 56 | s = tf.placeholder("float", [None, 80, 80, 4]) 57 | 58 | # hidden layers 59 | h_conv1 = tf.nn.relu(conv2d(s, W_conv1, 4) + b_conv1) 60 | h_pool1 = max_pool_2x2(h_conv1) 61 | 62 | h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2, 2) + b_conv2) 63 | #h_pool2 = max_pool_2x2(h_conv2) 64 | 65 | h_conv3 = tf.nn.relu(conv2d(h_conv2, W_conv3, 1) + b_conv3) 66 | #h_pool3 = max_pool_2x2(h_conv3) 67 | 68 | #h_pool3_flat = tf.reshape(h_pool3, [-1, 256]) 69 | h_conv3_flat = tf.reshape(h_conv3, [-1, 1600]) 70 | 71 | h_fc1 = tf.nn.relu(tf.matmul(h_conv3_flat, W_fc1) + b_fc1) 72 | 73 | # readout layer 74 | readout = tf.matmul(h_fc1, W_fc2) + b_fc2 75 | 76 | return s, readout, h_fc1 77 | 78 | def trainNetwork(s, readout, h_fc1, sess): 79 | # define the cost function 80 | a = tf.placeholder("float", [None, ACTIONS]) 81 | y = tf.placeholder("float", [None]) 82 | readout_action = tf.reduce_sum(tf.multiply(readout, a), reduction_indices=1) 83 | cost = tf.reduce_mean(tf.square(y - readout_action)) 84 | train_step = tf.train.AdamOptimizer(1e-6).minimize(cost) 85 | 86 | # open up a game __state to communicate with emulator 87 | game_state = game.GameState() 88 | 89 | # store the previous observations in replay memory 90 | D = deque() 91 | 92 | # printing 93 | a_file = open("logs_" + GAME + "/readout.txt", 'w') 94 | h_file = open("logs_" + GAME + "/hidden.txt", 'w') 95 | 96 | # get the first __state by doing nothing and preprocess the image to 80x80x4 97 | do_nothing = np.zeros(ACTIONS) 98 | do_nothing[0] = 1 99 | x_t, r_0, terminal = game_state.frame_step(do_nothing) 100 | x_t = cv2.cvtColor(cv2.resize(x_t, (80, 80)), cv2.COLOR_BGR2GRAY) 101 | ret, x_t = cv2.threshold(x_t,1,255,cv2.THRESH_BINARY) 102 | s_t = np.stack((x_t, x_t, x_t, x_t), axis=2) 103 | 104 | # saving and loading networks 105 | saver = tf.train.Saver() 106 | sess.run(tf.initialize_all_variables()) 107 | checkpoint = tf.train.get_checkpoint_state("saved_networks") 108 | if checkpoint and checkpoint.model_checkpoint_path: 109 | saver.restore(sess, checkpoint.model_checkpoint_path) 110 | print("Successfully loaded:", checkpoint.model_checkpoint_path) 111 | else: 112 | print("Could not find old network weights") 113 | 114 | # start training 115 | epsilon = INITIAL_EPSILON 116 | t = 0 117 | while 
"flappy bird" != "angry bird": 118 | # choose an action epsilon greedily 119 | readout_t = readout.eval(feed_dict={s : [s_t]})[0] 120 | a_t = np.zeros([ACTIONS]) 121 | action_index = 0 122 | 123 | if t % FRAME_PER_ACTION == 0: 124 | if random.random() <= epsilon: 125 | print("----------Random Action----------") 126 | action_index = random.randrange(ACTIONS) 127 | a_t[random.randrange(ACTIONS)] = 1 128 | else: 129 | action_index = np.argmax(readout_t) 130 | a_t[action_index] = 1 131 | else: 132 | a_t[0] = 1 # do nothing 133 | 134 | # scale down epsilon 135 | if epsilon > FINAL_EPSILON and t > OBSERVE: 136 | epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE 137 | 138 | # run the selected action and observe next __state and reward 139 | x_t1_colored, r_t, terminal = game_state.frame_step(a_t) 140 | x_t1 = cv2.cvtColor(cv2.resize(x_t1_colored, (80, 80)), cv2.COLOR_BGR2GRAY) 141 | ret, x_t1 = cv2.threshold(x_t1, 1, 255, cv2.THRESH_BINARY) 142 | x_t1 = np.reshape(x_t1, (80, 80, 1)) 143 | #s_t1 = np.append(x_t1, s_t[:,:,1:], axis = 2) 144 | s_t1 = np.append(x_t1, s_t[:, :, :3], axis=2) 145 | 146 | # store the transition in D 147 | D.append((s_t, a_t, r_t, s_t1, terminal)) 148 | if len(D) > REPLAY_MEMORY: 149 | D.popleft() 150 | 151 | # only train if done observing 152 | if t > OBSERVE: 153 | # sample a minibatch to train on 154 | minibatch = random.sample(D, BATCH) 155 | 156 | # get the batch variables 157 | s_j_batch = [d[0] for d in minibatch] 158 | a_batch = [d[1] for d in minibatch] 159 | r_batch = [d[2] for d in minibatch] 160 | s_j1_batch = [d[3] for d in minibatch] 161 | 162 | y_batch = [] 163 | readout_j1_batch = readout.eval(feed_dict = {s : s_j1_batch}) 164 | for i in range(0, len(minibatch)): 165 | terminal = minibatch[i][4] 166 | # if terminal, only equals reward 167 | if terminal: 168 | y_batch.append(r_batch[i]) 169 | else: 170 | y_batch.append(r_batch[i] + GAMMA * np.max(readout_j1_batch[i])) 171 | 172 | # perform gradient step 173 | train_step.run(feed_dict = { 174 | y : y_batch, 175 | a : a_batch, 176 | s : s_j_batch} 177 | ) 178 | 179 | # update the old values 180 | s_t = s_t1 181 | t += 1 182 | 183 | # save progress every 10000 iterations 184 | if t % 10000 == 0: 185 | saver.save(sess, 'saved_networks/' + GAME + '-dqn', global_step = t) 186 | 187 | # print info 188 | state = "" 189 | if t <= OBSERVE: 190 | state = "observe" 191 | elif t > OBSERVE and t <= OBSERVE + EXPLORE: 192 | state = "explore" 193 | else: 194 | state = "train" 195 | 196 | print("TIMESTEP", t, "/ STATE", state, \ 197 | "/ EPSILON", epsilon, "/ ACTION", action_index, "/ REWARD", r_t, \ 198 | "/ Q_MAX %e" % np.max(readout_t)) 199 | # write info to files 200 | ''' 201 | if t % 10000 <= 100: 202 | a_file.write(",".join([str(x) for x in readout_t]) + '\n') 203 | h_file.write(",".join([str(x) for x in h_fc1.eval(feed_dict={s:[s_t]})[0]]) + '\n') 204 | cv2.imwrite("logs_tetris/frame" + str(t) + ".png", x_t1) 205 | ''' 206 | 207 | def playGame(): 208 | sess = tf.InteractiveSession() 209 | s, readout, h_fc1 = createNetwork() 210 | trainNetwork(s, readout, h_fc1, sess) 211 | 212 | def main(): 213 | playGame() 214 | 215 | if __name__ == "__main__": 216 | main() 217 | -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/game/flappy_bird_utils.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import sys 3 | def load(): 4 | # path of player with different observation_space 5 | 
PLAYER_PATH = ( 6 | 'assets/sprites/redbird-upflap.png', 7 | 'assets/sprites/redbird-midflap.png', 8 | 'assets/sprites/redbird-downflap.png' 9 | ) 10 | 11 | # path of background 12 | BACKGROUND_PATH = 'assets/sprites/background-black.png' 13 | 14 | # path of pipe 15 | PIPE_PATH = 'assets/sprites/pipe-green.png' 16 | 17 | IMAGES, SOUNDS, HITMASKS = {}, {}, {} 18 | 19 | # numbers sprites for score display 20 | IMAGES['numbers'] = ( 21 | pygame.image.load('assets/sprites/0.png').convert_alpha(), 22 | pygame.image.load('assets/sprites/1.png').convert_alpha(), 23 | pygame.image.load('assets/sprites/2.png').convert_alpha(), 24 | pygame.image.load('assets/sprites/3.png').convert_alpha(), 25 | pygame.image.load('assets/sprites/4.png').convert_alpha(), 26 | pygame.image.load('assets/sprites/5.png').convert_alpha(), 27 | pygame.image.load('assets/sprites/6.png').convert_alpha(), 28 | pygame.image.load('assets/sprites/7.png').convert_alpha(), 29 | pygame.image.load('assets/sprites/8.png').convert_alpha(), 30 | pygame.image.load('assets/sprites/9.png').convert_alpha() 31 | ) 32 | 33 | # base (ground) sprite 34 | IMAGES['base'] = pygame.image.load('assets/sprites/base.png').convert_alpha() 35 | 36 | # sounds 37 | if 'win' in sys.platform: 38 | soundExt = '.wav' 39 | else: 40 | soundExt = '.ogg' 41 | 42 | SOUNDS['die'] = pygame.mixer.Sound('assets/audio/die' + soundExt) 43 | SOUNDS['hit'] = pygame.mixer.Sound('assets/audio/hit' + soundExt) 44 | SOUNDS['point'] = pygame.mixer.Sound('assets/audio/point' + soundExt) 45 | SOUNDS['swoosh'] = pygame.mixer.Sound('assets/audio/swoosh' + soundExt) 46 | SOUNDS['wing'] = pygame.mixer.Sound('assets/audio/wing' + soundExt) 47 | 48 | # select random background sprites 49 | IMAGES['background'] = pygame.image.load(BACKGROUND_PATH).convert() 50 | 51 | # select random player sprites 52 | IMAGES['player'] = ( 53 | pygame.image.load(PLAYER_PATH[0]).convert_alpha(), 54 | pygame.image.load(PLAYER_PATH[1]).convert_alpha(), 55 | pygame.image.load(PLAYER_PATH[2]).convert_alpha(), 56 | ) 57 | 58 | # select random pipe sprites 59 | IMAGES['pipe'] = ( 60 | pygame.transform.rotate( 61 | pygame.image.load(PIPE_PATH).convert_alpha(), 180), 62 | pygame.image.load(PIPE_PATH).convert_alpha(), 63 | ) 64 | 65 | # hismask for pipes 66 | HITMASKS['pipe'] = ( 67 | getHitmask(IMAGES['pipe'][0]), 68 | getHitmask(IMAGES['pipe'][1]), 69 | ) 70 | 71 | # hitmask for player 72 | HITMASKS['player'] = ( 73 | getHitmask(IMAGES['player'][0]), 74 | getHitmask(IMAGES['player'][1]), 75 | getHitmask(IMAGES['player'][2]), 76 | ) 77 | 78 | return IMAGES, SOUNDS, HITMASKS 79 | 80 | def getHitmask(image): 81 | """returns a hitmask using an image's alpha.""" 82 | mask = [] 83 | for x in range(image.get_width()): 84 | mask.append([]) 85 | for y in range(image.get_height()): 86 | mask[x].append(bool(image.get_at((x,y))[3])) 87 | return mask 88 | -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/game/wrapped_flappy_bird.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | import random 4 | import pygame 5 | import flappy_bird_utils 6 | import pygame.surfarray as surfarray 7 | from pygame.locals import * 8 | from itertools import cycle 9 | 10 | FPS = 30 11 | SCREENWIDTH = 288 12 | SCREENHEIGHT = 512 13 | 14 | pygame.init() 15 | FPSCLOCK = pygame.time.Clock() 16 | SCREEN = pygame.display.set_mode((SCREENWIDTH, SCREENHEIGHT)) 17 | 
pygame.display.set_caption('Flappy Bird') 18 | 19 | IMAGES, SOUNDS, HITMASKS = flappy_bird_utils.load() 20 | PIPEGAPSIZE = 100 # gap between upper and lower part of pipe 21 | BASEY = SCREENHEIGHT * 0.79 22 | 23 | PLAYER_WIDTH = IMAGES['player'][0].get_width() 24 | PLAYER_HEIGHT = IMAGES['player'][0].get_height() 25 | PIPE_WIDTH = IMAGES['pipe'][0].get_width() 26 | PIPE_HEIGHT = IMAGES['pipe'][0].get_height() 27 | BACKGROUND_WIDTH = IMAGES['background'].get_width() 28 | 29 | PLAYER_INDEX_GEN = cycle([0, 1, 2, 1]) 30 | 31 | 32 | class GameState: 33 | def __init__(self): 34 | self.score = self.playerIndex = self.loopIter = 0 35 | self.playerx = int(SCREENWIDTH * 0.2) 36 | self.playery = int((SCREENHEIGHT - PLAYER_HEIGHT) / 2) 37 | self.basex = 0 38 | self.baseShift = IMAGES['base'].get_width() - BACKGROUND_WIDTH 39 | 40 | newPipe1 = getRandomPipe() 41 | newPipe2 = getRandomPipe() 42 | self.upperPipes = [ 43 | {'x': SCREENWIDTH, 'y': newPipe1[0]['y']}, 44 | {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[0]['y']}, 45 | ] 46 | self.lowerPipes = [ 47 | {'x': SCREENWIDTH, 'y': newPipe1[1]['y']}, 48 | {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[1]['y']}, 49 | ] 50 | 51 | # player velocity, max velocity, downward accleration, accleration on flap 52 | self.pipeVelX = -4 53 | self.playerVelY = 0 # player's velocity along Y, default same as playerFlapped 54 | self.playerMaxVelY = 10 # max vel along Y, max descend speed 55 | self.playerMinVelY = -8 # min vel along Y, max ascend speed 56 | self.playerAccY = 1 # players downward accleration 57 | self.playerFlapAcc = -9 # players speed on flapping 58 | self.playerFlapped = False # True when player flaps 59 | 60 | def frame_step(self, input_actions): 61 | pygame.event.pump() 62 | 63 | reward = 0.1 64 | terminal = False 65 | 66 | if sum(input_actions) != 1: 67 | raise ValueError('Multiple input action_space!') 68 | 69 | # input_actions[0] == 1: do nothing 70 | # input_actions[1] == 1: flap the bird 71 | if input_actions[1] == 1: 72 | if self.playery > -2 * PLAYER_HEIGHT: 73 | self.playerVelY = self.playerFlapAcc 74 | self.playerFlapped = True 75 | #SOUNDS['wing'].play() 76 | 77 | # check for score 78 | playerMidPos = self.playerx + PLAYER_WIDTH / 2 79 | for pipe in self.upperPipes: 80 | pipeMidPos = pipe['x'] + PIPE_WIDTH / 2 81 | if pipeMidPos <= playerMidPos < pipeMidPos + 4: 82 | self.score += 1 83 | #SOUNDS['point'].play() 84 | reward = 1 85 | 86 | # playerIndex basex change 87 | if (self.loopIter + 1) % 3 == 0: 88 | self.playerIndex = next(PLAYER_INDEX_GEN) 89 | self.loopIter = (self.loopIter + 1) % 30 90 | self.basex = -((-self.basex + 100) % self.baseShift) 91 | 92 | # player's movement 93 | if self.playerVelY < self.playerMaxVelY and not self.playerFlapped: 94 | self.playerVelY += self.playerAccY 95 | if self.playerFlapped: 96 | self.playerFlapped = False 97 | self.playery += min(self.playerVelY, BASEY - self.playery - PLAYER_HEIGHT) 98 | if self.playery < 0: 99 | self.playery = 0 100 | 101 | # move pipes to left 102 | for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes): 103 | uPipe['x'] += self.pipeVelX 104 | lPipe['x'] += self.pipeVelX 105 | 106 | # add new pipe when first pipe is about to touch left of screen 107 | if 0 < self.upperPipes[0]['x'] < 5: 108 | newPipe = getRandomPipe() 109 | self.upperPipes.append(newPipe[0]) 110 | self.lowerPipes.append(newPipe[1]) 111 | 112 | # remove first pipe if its out of the screen 113 | if self.upperPipes[0]['x'] < -PIPE_WIDTH: 114 | self.upperPipes.pop(0) 115 | self.lowerPipes.pop(0) 116 
| 117 | # check if crash here 118 | isCrash= checkCrash({'x': self.playerx, 'y': self.playery, 119 | 'index': self.playerIndex}, 120 | self.upperPipes, self.lowerPipes) 121 | if isCrash: 122 | #SOUNDS['hit'].play() 123 | #SOUNDS['die'].play() 124 | terminal = True 125 | self.__init__() 126 | reward = -1 127 | 128 | # draw sprites 129 | SCREEN.blit(IMAGES['background'], (0,0)) 130 | 131 | for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes): 132 | SCREEN.blit(IMAGES['pipe'][0], (uPipe['x'], uPipe['y'])) 133 | SCREEN.blit(IMAGES['pipe'][1], (lPipe['x'], lPipe['y'])) 134 | 135 | SCREEN.blit(IMAGES['base'], (self.basex, BASEY)) 136 | # print score so player overlaps the score 137 | # showScore(self.score) 138 | SCREEN.blit(IMAGES['player'][self.playerIndex], 139 | (self.playerx, self.playery)) 140 | 141 | image_data = pygame.surfarray.array3d(pygame.display.get_surface()) 142 | pygame.display.update() 143 | FPSCLOCK.tick(FPS) 144 | #print self.upperPipes[0]['y'] + PIPE_HEIGHT - int(BASEY * 0.2) 145 | return image_data, reward, terminal 146 | 147 | def getRandomPipe(): 148 | """returns a randomly generated pipe""" 149 | # y of gap between upper and lower pipe 150 | gapYs = [20, 30, 40, 50, 60, 70, 80, 90] 151 | index = random.randint(0, len(gapYs)-1) 152 | gapY = gapYs[index] 153 | 154 | gapY += int(BASEY * 0.2) 155 | pipeX = SCREENWIDTH + 10 156 | 157 | return [ 158 | {'x': pipeX, 'y': gapY - PIPE_HEIGHT}, # upper pipe 159 | {'x': pipeX, 'y': gapY + PIPEGAPSIZE}, # lower pipe 160 | ] 161 | 162 | 163 | def showScore(score): 164 | """displays score in center of screen""" 165 | scoreDigits = [int(x) for x in list(str(score))] 166 | totalWidth = 0 # total width of all numbers to be printed 167 | 168 | for digit in scoreDigits: 169 | totalWidth += IMAGES['numbers'][digit].get_width() 170 | 171 | Xoffset = (SCREENWIDTH - totalWidth) / 2 172 | 173 | for digit in scoreDigits: 174 | SCREEN.blit(IMAGES['numbers'][digit], (Xoffset, SCREENHEIGHT * 0.1)) 175 | Xoffset += IMAGES['numbers'][digit].get_width() 176 | 177 | 178 | def checkCrash(player, upperPipes, lowerPipes): 179 | """returns True if player collders with base or pipes.""" 180 | pi = player['index'] 181 | player['w'] = IMAGES['player'][0].get_width() 182 | player['h'] = IMAGES['player'][0].get_height() 183 | 184 | # if player crashes into ground 185 | if player['y'] + player['h'] >= BASEY - 1: 186 | return True 187 | else: 188 | 189 | playerRect = pygame.Rect(player['x'], player['y'], 190 | player['w'], player['h']) 191 | 192 | for uPipe, lPipe in zip(upperPipes, lowerPipes): 193 | # upper and lower pipe rects 194 | uPipeRect = pygame.Rect(uPipe['x'], uPipe['y'], PIPE_WIDTH, PIPE_HEIGHT) 195 | lPipeRect = pygame.Rect(lPipe['x'], lPipe['y'], PIPE_WIDTH, PIPE_HEIGHT) 196 | 197 | # player and upper/lower pipe hitmasks 198 | pHitMask = HITMASKS['player'][pi] 199 | uHitmask = HITMASKS['pipe'][0] 200 | lHitmask = HITMASKS['pipe'][1] 201 | 202 | # if bird collided with upipe or lpipe 203 | uCollide = pixelCollision(playerRect, uPipeRect, pHitMask, uHitmask) 204 | lCollide = pixelCollision(playerRect, lPipeRect, pHitMask, lHitmask) 205 | 206 | if uCollide or lCollide: 207 | return True 208 | 209 | return False 210 | 211 | def pixelCollision(rect1, rect2, hitmask1, hitmask2): 212 | """Checks if two objects collide and not just their rects""" 213 | rect = rect1.clip(rect2) 214 | 215 | if rect.width == 0 or rect.height == 0: 216 | return False 217 | 218 | x1, y1 = rect.x - rect1.x, rect.y - rect1.y 219 | x2, y2 = rect.x - rect2.x, rect.y - rect2.y 
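    # Pixel-perfect check: (x1, y1) and (x2, y2) translate the clipped overlap
    # rectangle into each sprite's own hitmask coordinates, and the loop below
    # reports a collision only when both hitmasks are opaque at the same pixel.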
220 | 221 | for x in range(rect.width): 222 | for y in range(rect.height): 223 | if hitmask1[x1+x][y1+y] and hitmask2[x2+x][y2+y]: 224 | return True 225 | return False 226 | -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/images/flappy_bird_demp.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/images/flappy_bird_demp.gif -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/images/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/images/network.png -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/images/preprocess.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/images/preprocess.png -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/logs_bird/hidden.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/logs_bird/hidden.txt -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/logs_bird/readout.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/logs_bird/readout.txt -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2880000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2880000 -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2880000.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2880000.meta -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2890000: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2890000 -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2890000.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2890000.meta -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2900000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2900000 -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2900000.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2900000.meta -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2910000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2910000 -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2910000.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2910000.meta -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2920000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2920000 -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2920000.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2920000.meta -------------------------------------------------------------------------------- 
/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "bird-dqn-2920000" 2 | all_model_checkpoint_paths: "bird-dqn-2880000" 3 | all_model_checkpoint_paths: "bird-dqn-2890000" 4 | all_model_checkpoint_paths: "bird-dqn-2900000" 5 | all_model_checkpoint_paths: "bird-dqn-2910000" 6 | all_model_checkpoint_paths: "bird-dqn-2920000" 7 | -------------------------------------------------------------------------------- /6-value_function_approximate/deep_learning_flappy_bird/saved_networks/pretrained_model/bird-dqn-policy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/pretrained_model/bird-dqn-policy -------------------------------------------------------------------------------- /6-value_function_approximate/deep_q_network_algortihm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_q_network_algortihm.png -------------------------------------------------------------------------------- /6-value_function_approximate/deep_q_network_template.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # author : zlq16 4 | # date : 2017/10/31 5 | import tensorflow as tf 6 | import cv2 7 | import random 8 | import numpy as np 9 | from collections import deque 10 | 11 | #Game的定义类,此处Game是什么不重要,只要提供执行Action的方法,获取当前游戏区域像素的方法即可 12 | class Game(object): 13 | def __init__(self): #Game初始化 14 | # action是MOVE_STAY、MOVE_LEFT、MOVE_RIGHT 15 | # ai控制棒子左右移动;返回游戏界面像素数和对应的奖励。(像素->奖励->强化棒子往奖励高的方向移动) 16 | pass 17 | def step(self, action): 18 | pass 19 | # learning_rate 20 | GAMMA = 0.99 21 | # 跟新梯度 22 | INITIAL_EPSILON = 1.0 23 | FINAL_EPSILON = 0.05 24 | # 测试观测次数 25 | EXPLORE = 500000 26 | OBSERVE = 500 27 | # 记忆经验大小 28 | REPLAY_MEMORY = 500000 29 | # 每次训练取出的记录数 30 | BATCH = 100 31 | # 输出层神经元数。代表3种操作-MOVE_STAY:[1, 0, 0] MOVE_LEFT:[0, 1, 0] MOVE_RIGHT:[0, 0, 1] 32 | output = 3 33 | MOVE_STAY =[1, 0, 0] 34 | MOVE_LEFT =[0, 1, 0] 35 | MOVE_RIGHT=[0, 0, 1] 36 | input_image = tf.placeholder("float", [None, 80, 100, 4]) # 游戏像素 37 | action = tf.placeholder("float", [None, output]) # 操作 38 | 39 | #定义CNN-卷积神经网络 40 | def convolutional_neural_network(input_image): 41 | weights = {'w_conv1':tf.Variable(tf.zeros([8, 8, 4, 32])), 42 | 'w_conv2':tf.Variable(tf.zeros([4, 4, 32, 64])), 43 | 'w_conv3':tf.Variable(tf.zeros([3, 3, 64, 64])), 44 | 'w_fc4':tf.Variable(tf.zeros([3456, 784])), 45 | 'w_out':tf.Variable(tf.zeros([784, output]))} 46 | 47 | biases = {'b_conv1':tf.Variable(tf.zeros([32])), 48 | 'b_conv2':tf.Variable(tf.zeros([64])), 49 | 'b_conv3':tf.Variable(tf.zeros([64])), 50 | 'b_fc4':tf.Variable(tf.zeros([784])), 51 | 'b_out':tf.Variable(tf.zeros([output]))} 52 | 53 | conv1 = tf.nn.relu(tf.nn.conv2d(input_image, weights['w_conv1'], strides = [1, 4, 4, 1], padding = "VALID") + biases['b_conv1']) 54 | conv2 = tf.nn.relu(tf.nn.conv2d(conv1, weights['w_conv2'], strides = [1, 2, 2, 1], padding = "VALID") + biases['b_conv2']) 55 | conv3 = tf.nn.relu(tf.nn.conv2d(conv2, weights['w_conv3'], strides = [1, 1, 1, 1], 
padding = "VALID") + biases['b_conv3']) 56 | conv3_flat = tf.reshape(conv3, [-1, 3456]) 57 | fc4 = tf.nn.relu(tf.matmul(conv3_flat, weights['w_fc4']) + biases['b_fc4']) 58 | 59 | output_layer = tf.matmul(fc4, weights['w_out']) + biases['b_out'] 60 | return output_layer 61 | 62 | #训练神经网络 63 | def train_neural_network(input_image): 64 | argmax = tf.placeholder("float", [None, output]) 65 | gt = tf.placeholder("float", [None]) 66 | 67 | #损失函数 68 | predict_action = convolutional_neural_network(input_image) 69 | action = tf.reduce_sum(tf.mul(predict_action, argmax), reduction_indices = 1) #max(Q(S,:)) 70 | cost = tf.reduce_mean(tf.square(action - gt)) 71 | optimizer = tf.train.AdamOptimizer(1e-6).minimize(cost) 72 | 73 | #游戏开始 74 | game = Game() 75 | D = deque() 76 | _, image = game.step(MOVE_STAY) 77 | image = cv2.cvtColor(cv2.resize(image, (100, 80)), cv2.COLOR_BGR2GRAY) 78 | ret, image = cv2.threshold(image, 1, 255, cv2.THRESH_BINARY) 79 | input_image_data = np.stack((image, image, image, image), axis = 2) 80 | 81 | with tf.Session() as sess: 82 | #初始化神经网络各种参数 83 | sess.run(tf.initialize_all_variables()) 84 | #保存神经网络参数的模块 85 | saver = tf.train.Saver() 86 | 87 | #总的运行次数 88 | n = 0 89 | epsilon = INITIAL_EPSILON 90 | while True: 91 | 92 | #神经网络输出的是Q(S,:)值 93 | action_t = predict_action.eval(feed_dict = {input_image : [input_image_data]})[0] 94 | argmax_t = np.zeros([output], dtype=np.int) 95 | 96 | #贪心选取动作 97 | if(random.random() <= INITIAL_EPSILON): 98 | maxIndex = random.randrange(output) 99 | else: 100 | maxIndex = np.argmax(action_t) 101 | 102 | #将action对应的Q(S,a)最大值提取出来 103 | argmax_t[maxIndex] = 1 104 | 105 | #贪婪的部分开始不断的增加 106 | if epsilon > FINAL_EPSILON: 107 | epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE 108 | 109 | #将选取的动作带入到环境,观察环境状态S'与回报reward 110 | reward, image = game.step(list(argmax_t)) 111 | 112 | #将得到的图形进行变换用于神经网络的输出 113 | image = cv2.cvtColor(cv2.resize(image, (100, 80)), cv2.COLOR_BGR2GRAY) 114 | ret, image = cv2.threshold(image, 1, 255, cv2.THRESH_BINARY) 115 | image = np.reshape(image, (80, 100, 1)) 116 | input_image_data1 = np.append(image, input_image_data[:, :, 0:3], axis = 2) 117 | 118 | #将S,a,r,S'记录的大脑中 119 | D.append((input_image_data, argmax_t, reward, input_image_data1)) 120 | 121 | #大脑的记忆是有一定的限度的 122 | if len(D) > REPLAY_MEMORY: 123 | D.popleft() 124 | 125 | #如果达到观察期就要进行神经网络训练 126 | if n > OBSERVE: 127 | 128 | #随机的选取一定记忆的数据进行训练 129 | minibatch = random.sample(D, BATCH) 130 | #将里面的每一个记忆的S提取出来 131 | input_image_data_batch = [d[0] for d in minibatch] 132 | #将里面的每一个记忆的a提取出来 133 | argmax_batch = [d[1] for d in minibatch] 134 | #将里面的每一个记忆回报提取出来 135 | reward_batch = [d[2] for d in minibatch] 136 | #将里面的每一个记忆的下一步转台提取出来 137 | input_image_data1_batch = [d[3] for d in minibatch] 138 | 139 | gt_batch = [] 140 | #利用已经有的求解Q(S',:) 141 | out_batch = predict_action.eval(feed_dict = {input_image : input_image_data1_batch}) 142 | 143 | #利用bellman优化得到长期的回报r + γmax(Q(s',:)) 144 | for i in range(0, len(minibatch)): 145 | gt_batch.append(reward_batch[i] + GAMMA * np.max(out_batch[i])) 146 | 147 | #利用事先定义的优化函数进行优化神经网络参数 148 | print("gt_batch:", gt_batch, "argmax:", argmax_batch) 149 | optimizer.run(feed_dict = {gt : gt_batch, argmax : argmax_batch, input_image : input_image_data_batch}) 150 | 151 | input_image_data = input_image_data1 152 | n = n+1 153 | print(n, "epsilon:", epsilon, " ", "action:", maxIndex, " ", "_reward:", reward) 154 | 155 | train_neural_network(input_image) 156 | -------------------------------------------------------------------------------- /README.md: 
--------------------------------------------------------------------------------
1 | # Code Description
2 | ## Description
3 | > This is the repository of code I wrote while studying the book 《深入浅出强化学习-原理入门》; it mainly contains the examples from the book and the code for the end-of-chapter exercises.
4 | ## Contents
5 | ### 1-Secondary development of gym (gym developing)
6 | 1. [File configuration for extending gym](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/1-gym_developing/README.md)
7 | 2. [Modified core.py file under gym](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/1-gym_developing/core.py)
8 | 3. [A grid game example built on the extended gym](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/1-gym_developing/grid_game.py)
9 | 4. [A maze game example built on the extended gym](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/1-gym_developing/maze_game.py)
10 | ### 2-Markov Decision Process
11 | 1. [The student-life example](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/2-markov_decision_process/our_life.py)
12 | 2. [End-of-chapter exercise simulating the maze environment](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/2-markov_decision_process/game.py)
13 | ### 3-Dynamic Programming
14 | 1. [Policy evaluation of the grid game under a uniform policy](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/3-dynamic_program/grid_game_with_average_policy.py)
15 | 2. [Policy iteration algorithm flowchart](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/3-dynamic_program/policy_iteration_algorithm.png)
16 | 3. [Policy iteration of the grid game under a greedy policy](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/3-dynamic_program/grid_game_with_policy_iterate.py)
17 | 4. [Value iteration algorithm flowchart](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/3-dynamic_program/value_iteration_algorithm.png)
18 | 5. [Value iteration of the grid game under a greedy policy](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/3-dynamic_program/grid_game_with_value_iterate.py)
19 | 6. [End-of-chapter exercise: the maze game with dynamic programming](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/3-dynamic_program/maze_game_with_dynamic_program.py)
20 | ### 4-Monte Carlo
21 | 1. [Monte Carlo sampling](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/4-monte_carlo/monte_carlo_sample.py)
22 | 2. [Monte Carlo evaluation](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/4-monte_carlo/monte_carlo_evaluate.py)
23 | ### 5-Temporal Difference
24 | 1. [Q-learning algorithm flowchart](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/5-temporal_difference/q_learning_algortihm.png)
25 | 2. [Sarsa algorithm flowchart](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/5-temporal_difference/sarsa_algorithm.png)
26 | 3. [Sarsa(λ) algorithm flowchart](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/5-temporal_difference/sarsa_lambda_algorithm.png)
27 | 4. [A push-box game example built on the extended gym](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/5-temporal_difference/push_box_game.py)
28 | 5. [Learning the push-box game with temporal difference methods](https://github.com/zhuliquan/reinforcement_learning_basic_book/tree/master/5-temporal_difference/push_box_game)
29 | ### 6-Value Function Approximation
30 | 1. [Deep Q-learning algorithm flowchart](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/6-value_function_approximate/deep_q_network_algortihm.png)
31 | 2. [Deep Q-learning algorithm template](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/6-value_function_approximate/deep_q_network_template.py)
32 | 3. [Flappy Bird game written with Deep Q-learning](https://github.com/zhuliquan/reinforcement_learning_basic_book/tree/master/6-value_function_approximate/deep_learning_flappy_bird)
33 |
--------------------------------------------------------------------------------