├── .gitignore
├── 1-gym_developing
│   ├── README.md
│   ├── core.py
│   ├── grid_game.py
│   ├── maze_game.py
│   └── suceess.png
├── 2-markov_decision_process
│   ├── __init__.py
│   ├── game.py
│   └── our_life.py
├── 3-dynamic_program
│   ├── grid_game_with_average_policy.py
│   ├── grid_game_with_policy_iterate.py
│   ├── grid_game_with_value_iterate.py
│   ├── maze_game_with_dynamic_program.py
│   ├── policy_iteration_algorithm.png
│   └── value_iteration_algorithm.png
├── 4-monte_carlo
│   ├── monte_carlo_evaluate.py
│   └── monte_carlo_sample.py
├── 5-temporal_difference
│   ├── README.md
│   ├── push_box_game.py
│   ├── push_box_game
│   │   ├── agent.py
│   │   ├── main.py
│   │   └── q_table.pkl
│   ├── q_learning_algortihm.png
│   ├── sarsa_algorithm.png
│   └── sarsa_lambda_algorithm.png
├── 6-value_function_approximate
│   ├── deep_learning_flappy_bird
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── assets
│   │   │   ├── audio
│   │   │   │   ├── die.ogg
│   │   │   │   ├── die.wav
│   │   │   │   ├── hit.ogg
│   │   │   │   ├── hit.wav
│   │   │   │   ├── point.ogg
│   │   │   │   ├── point.wav
│   │   │   │   ├── swoosh.ogg
│   │   │   │   ├── swoosh.wav
│   │   │   │   ├── wing.ogg
│   │   │   │   └── wing.wav
│   │   │   └── sprites
│   │   │       ├── 0.png
│   │   │       ├── 1.png
│   │   │       ├── 2.png
│   │   │       ├── 3.png
│   │   │       ├── 4.png
│   │   │       ├── 5.png
│   │   │       ├── 6.png
│   │   │       ├── 7.png
│   │   │       ├── 8.png
│   │   │       ├── 9.png
│   │   │       ├── background-black.png
│   │   │       ├── base.png
│   │   │       ├── pipe-green.png
│   │   │       ├── redbird-downflap.png
│   │   │       ├── redbird-midflap.png
│   │   │       └── redbird-upflap.png
│   │   ├── deep_q_network.py
│   │   ├── game
│   │   │   ├── flappy_bird_utils.py
│   │   │   └── wrapped_flappy_bird.py
│   │   ├── images
│   │   │   ├── flappy_bird_demp.gif
│   │   │   ├── network.png
│   │   │   └── preprocess.png
│   │   ├── logs_bird
│   │   │   ├── hidden.txt
│   │   │   └── readout.txt
│   │   └── saved_networks
│   │       ├── bird-dqn-2880000
│   │       ├── bird-dqn-2880000.meta
│   │       ├── bird-dqn-2890000
│   │       ├── bird-dqn-2890000.meta
│   │       ├── bird-dqn-2900000
│   │       ├── bird-dqn-2900000.meta
│   │       ├── bird-dqn-2910000
│   │       ├── bird-dqn-2910000.meta
│   │       ├── bird-dqn-2920000
│   │       ├── bird-dqn-2920000.meta
│   │       ├── checkpoint
│   │       └── pretrained_model
│   │           └── bird-dqn-policy
│   ├── deep_q_network_algortihm.png
│   └── deep_q_network_template.py
└── README.md
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python:
2 | *.py[cod]
3 | *.so
4 | *.egg
5 | *.egg-info
6 | dist
7 | build
8 | .idea
9 | # pycharm setting:
10 | *.xml
--------------------------------------------------------------------------------
/1-gym_developing/README.md:
--------------------------------------------------------------------------------
1 | ## Usage
2 | > The steps below are best performed with administrator privileges.
3 | 1. First, locate the gym installation directory; for example, mine is `C:\Program Files\Python36\Lib\site-packages\gym`.
4 | ---
5 | 2. Copy [core.py](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/1-gym_developing/core.py)
6 | so that it overwrites the `core.py` file in the gym installation directory.
7 | ---
8 | 3. Go to `/envs/classic_control` under the gym installation directory and copy the corresponding environment files ([grid_game.py](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/1-gym_developing/grid_game.py) and
9 | [maze_game.py](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/1-gym_developing/maze_game.py))
10 | into that directory.
11 | ---
12 | 4. Open the `/envs/classic_control/__init__.py` file under the gym installation directory and append:
13 |
14 | ```python
15 | from gym.envs.classic_control.grid_game import GridEnv
16 | from gym.envs.classic_control.maze_game import MazeEnv
17 | ```
18 | ---
19 | 5. Open the `/envs/__init__.py` file under the gym installation directory and append:
20 | ```python
21 | register(
22 | id='GridGame-v0',
23 | entry_point='gym.envs.classic_control:GridEnv',
24 | max_episode_steps=200,
25 | reward_threshold=100.0,
26 | )
27 | register(
28 | id='MazeGame-v0',
29 | entry_point='gym.envs.classic_control:MazeEnv',
30 | max_episode_steps=200,
31 | reward_threshold=100.0,
32 | )
33 | ```
34 | ---
35 | 6. Verify the setup
36 | by running the following Python code:
37 | ```python
38 | import gym
39 | env = gym.make("GridGame-v0")
40 | env.reset()
41 | env.render()
42 | ```
43 | If the window below appears, the setup succeeded.
44 |
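Once both environments are registered, a quick sanity check is to run a short random rollout in the same style as `2-markov_decision_process/game.py` (a minimal sketch; it relies on the fact that these environments expose `action_space` as a plain tuple of action names):

```python
import gym
import numpy as np

env = gym.make("MazeGame-v0")  # or "GridGame-v0"
state = env.reset()
for _ in range(20):
    env.render()
    # action_space is a tuple such as ('n', 'e', 's', 'w'), so sampling is a plain random choice
    action = np.random.choice(env.action_space)
    state, reward, done, _ = env.step(action)
    if done:
        break
env.close()
```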
--------------------------------------------------------------------------------
/1-gym_developing/core.py:
--------------------------------------------------------------------------------
1 | from gym import logger
2 | import numpy as np
3 |
4 | import gym
5 | from gym import error
6 | from gym.utils import closer
7 |
8 | env_closer = closer.Closer()
9 |
10 | # Env-related abstractions
11 |
12 | class Env(object):
13 | """The main OpenAI Gym class. It encapsulates an environment with
14 | arbitrary behind-the-scenes dynamics. An environment can be
15 | partially or fully observed.
16 |
17 | The main API methods that users of this class need to know are:
18 |
19 | transform
20 | step
21 | reset
22 | render
23 | close
24 | seed
25 |
26 | And set the following attributes:
27 |
28 | action_space: The Space object corresponding to valid actions
29 | observation_space: The Space object corresponding to valid observations
30 | reward_range: A tuple corresponding to the min and max possible rewards
31 |
32 | Note: a default reward range set to [-inf,+inf] already exists. Set it if you want a narrower range.
33 |
34 | The methods are accessed publicly as "step", "reset", etc.. The
35 | non-underscored versions are wrapper methods to which we may add
36 | functionality over time.
37 | """
38 |
39 | # Set this in SOME subclasses
40 | metadata = {'render.modes': []}
41 | reward_range = (-np.inf, np.inf)
42 | spec = None
43 |
44 | # Set these in ALL subclasses
45 | action_space = None
46 | observation_space = None
47 |
48 | def transform(self, state, action):
49 |         '''
50 |         Run one timestep of the environment's dynamics starting from an
51 |         explicitly given state, without modifying the environment's own
52 |         internal state.
53 | 
54 |         Accepts a state and an action, and returns a tuple (observation, reward, done, info).
55 | 
56 |         Args:
57 |             state (object): a state of the environment
58 |             action (object): an action provided by the agent
59 | 
60 |         Returns:
61 |             observation (object): agent's observation of the current environment
62 |             reward (float) : amount of reward returned after previous action
63 |             done (boolean): whether the episode has ended, in which case further step() calls will return undefined results
64 |             info (dict): contains auxiliary diagnostic information (helpful for debugging, and sometimes learning)
65 |         '''
66 | pass
67 |
68 | def step(self, action):
69 | """Run one timestep of the environment's dynamics. When end of
70 | episode is reached, you are responsible for calling `reset()`
71 | to reset this environment's state.
72 |
73 | Accepts an action and returns a tuple (observation, reward, done, info).
74 |
75 | Args:
76 |             action (object): an action provided by the agent
77 |
78 | Returns:
79 | observation (object): agent's observation of the current environment
80 | reward (float) : amount of reward returned after previous action
81 | done (boolean): whether the episode has ended, in which case further step() calls will return undefined results
82 | info (dict): contains auxiliary diagnostic information (helpful for debugging, and sometimes learning)
83 | """
84 | raise NotImplementedError
85 |
86 | def reset(self):
87 | """Resets the state of the environment and returns an initial observation.
88 |
89 | Returns: observation (object): the initial observation of the
90 | space.
91 | """
92 | raise NotImplementedError
93 |
94 | def render(self, mode='human'):
95 | """Renders the environment.
96 |
97 | The set of supported modes varies per environment. (And some
98 | environments do not support rendering at all.) By convention,
99 | if mode is:
100 |
101 | - human: render to the current display or terminal and
102 | return nothing. Usually for human consumption.
103 | - rgb_array: Return an numpy.ndarray with shape (x, y, 3),
104 | representing RGB values for an x-by-y pixel image, suitable
105 | for turning into a video.
106 | - ansi: Return a string (str) or StringIO.StringIO containing a
107 | terminal-style text representation. The text can include newlines
108 | and ANSI escape sequences (e.g. for colors).
109 |
110 | Note:
111 | Make sure that your class's metadata 'render.modes' key includes
112 | the list of supported modes. It's recommended to call super()
113 | in implementations to use the functionality of this method.
114 |
115 | Args:
116 | mode (str): the mode to render with
117 | close (bool): close all open renderings
118 |
119 | Example:
120 |
121 | class MyEnv(Env):
122 | metadata = {'render.modes': ['human', 'rgb_array']}
123 |
124 | def render(self, mode='human'):
125 | if mode == 'rgb_array':
126 | return np.array(...) # return RGB frame suitable for video
127 |                 elif mode == 'human':
128 | ... # pop up a window and render
129 | else:
130 | super(MyEnv, self).render(mode=mode) # just raise an exception
131 | """
132 | raise NotImplementedError
133 |
134 | def close(self):
135 | """Override _close in your subclass to perform any necessary cleanup.
136 |
137 | Environments will automatically close() themselves when
138 | garbage collected or when the program exits.
139 | """
140 | return
141 |
142 | def seed(self, seed=None):
143 | """Sets the seed for this env's random number generator(s).
144 |
145 | Note:
146 | Some environments use multiple pseudorandom number generators.
147 | We want to capture all such seeds used in order to ensure that
148 | there aren't accidental correlations between multiple generators.
149 |
150 | Returns:
151 | list: Returns the list of seeds used in this env's random
152 | number generators. The first value in the list should be the
153 | "main" seed, or the value which a reproducer should pass to
154 | 'seed'. Often, the main seed equals the provided 'seed', but
155 | this won't be true if seed=None, for example.
156 | """
157 | logger.warn("Could not seed environment %s", self)
158 | return
159 |
160 | @property
161 | def unwrapped(self):
162 | """Completely unwrap this env.
163 |
164 | Returns:
165 | gym.Env: The base non-wrapped gym.Env instance
166 | """
167 | return self
168 |
169 | def __str__(self):
170 | if self.spec is None:
171 | return '<{} instance>'.format(type(self).__name__)
172 | else:
173 | return '<{}<{}>>'.format(type(self).__name__, self.spec.id)
174 |
175 |
176 | class GoalEnv(Env):
177 | """A goal-based environment. It functions just as any regular OpenAI Gym environment but it
178 | imposes a required structure on the observation_space. More concretely, the observation
179 | space is required to contain at least three elements, namely `observation`, `desired_goal`, and
180 | `achieved_goal`. Here, `desired_goal` specifies the goal that the agent should attempt to achieve.
181 | `achieved_goal` is the goal that it currently achieved instead. `observation` contains the
182 | actual observations of the environment as per usual.
183 | """
184 |
185 | def reset(self):
186 | # Enforce that each GoalEnv uses a Goal-compatible observation space.
187 | if not isinstance(self.observation_space, gym.spaces.Dict):
188 | raise error.Error('GoalEnv requires an observation space of type gym.spaces.Dict')
189 | result = super(GoalEnv, self).reset()
190 | for key in ['observation', 'achieved_goal', 'desired_goal']:
191 | if key not in result:
192 | raise error.Error('GoalEnv requires the "{}" key to be part of the observation dictionary.'.format(key))
193 | return result
194 |
195 | def compute_reward(self, achieved_goal, desired_goal, info):
196 | """Compute the step reward. This externalizes the reward function and makes
197 | it dependent on an a desired goal and the one that was achieved. If you wish to include
198 | additional rewards that are independent of the goal, you can include the necessary values
199 | to derive it in info and compute it accordingly.
200 |
201 | Args:
202 | achieved_goal (object): the goal that was achieved during execution
203 | desired_goal (object): the desired goal that we asked the agent to attempt to achieve
204 | info (dict): an info dictionary with additional information
205 |
206 | Returns:
207 | float: The reward that corresponds to the provided achieved goal w.r.t. to the desired
208 | goal. Note that the following should always hold true:
209 |
210 | ob, reward, done, info = env.step()
211 | assert reward == env.compute_reward(ob['achieved_goal'], ob['goal'], info)
212 | """
213 | raise NotImplementedError()
214 |
215 | # Space-related abstractions
216 |
217 | class Space(object):
218 | """Defines the observation and action spaces, so you can write generic
219 | code that applies to any Env. For example, you can choose a random
220 | action.
221 | """
222 | def __init__(self, shape=None, dtype=None):
223 | self.shape = None if shape is None else tuple(shape)
224 | self.dtype = None if dtype is None else np.dtype(dtype)
225 |
226 | def sample(self):
227 | """
228 | Uniformly randomly sample a random element of this space
229 | """
230 | raise NotImplementedError
231 |
232 | def contains(self, x):
233 | """
234 | Return boolean specifying if x is a valid
235 | member of this space
236 | """
237 | raise NotImplementedError
238 |
239 | def to_jsonable(self, sample_n):
240 | """Convert a batch of samples from this space to a JSONable data type."""
241 | # By default, assume identity is JSONable
242 | return sample_n
243 |
244 | def from_jsonable(self, sample_n):
245 | """Convert a JSONable data type to a batch of samples from this space."""
246 | # By default, assume identity is JSONable
247 | return sample_n
248 |
249 |
250 | warn_once = True
251 |
252 | def deprecated_warn_once(text):
253 | global warn_once
254 | if not warn_once: return
255 | warn_once = False
256 | logger.warn(text)
257 |
258 |
259 | class Wrapper(Env):
260 | env = None
261 |
262 | def __init__(self, env):
263 | self.env = env
264 | self.action_space = self.env.action_space
265 | self.observation_space = self.env.observation_space
266 | self.reward_range = self.env.reward_range
267 | self.metadata = self.env.metadata
268 | self._warn_double_wrap()
269 |
270 | @classmethod
271 | def class_name(cls):
272 | return cls.__name__
273 |
274 | def _warn_double_wrap(self):
275 | env = self.env
276 | while True:
277 | if isinstance(env, Wrapper):
278 | if env.class_name() == self.class_name():
279 | raise error.DoubleWrapperError("Attempted to double wrap with Wrapper: {}".format(self.__class__.__name__))
280 | env = env.env
281 | else:
282 | break
283 |
284 | def transform(self, state, action):
285 | return self.env.transform(state, action)
286 |
287 | def step(self, action):
288 | if hasattr(self, "_step"):
289 | deprecated_warn_once("%s doesn't implement 'step' method, but it implements deprecated '_step' method." % type(self))
290 | self.step = self._step
291 | return self.step(action)
292 | else:
293 | deprecated_warn_once("%s doesn't implement 'step' method, " % type(self) +
294 | "which is required for wrappers derived directly from Wrapper. Deprecated default implementation is used.")
295 | return self.env.step(action)
296 |
297 | def reset(self, **kwargs):
298 | if hasattr(self, "_reset"):
299 | deprecated_warn_once("%s doesn't implement 'reset' method, but it implements deprecated '_reset' method." % type(self))
300 | self.reset = self._reset
301 | return self._reset(**kwargs)
302 | else:
303 | deprecated_warn_once("%s doesn't implement 'reset' method, " % type(self) +
304 | "which is required for wrappers derived directly from Wrapper. Deprecated default implementation is used.")
305 | return self.env.reset(**kwargs)
306 |
307 | def render(self, mode='human'):
308 | return self.env.render(mode)
309 |
310 | def close(self):
311 | if self.env:
312 | return self.env.close()
313 |
314 | def seed(self, seed=None):
315 | return self.env.seed(seed)
316 |
317 | def compute_reward(self, achieved_goal, desired_goal, info):
318 | return self.env.compute_reward(achieved_goal, desired_goal, info)
319 |
320 | def __str__(self):
321 | return '<{}{}>'.format(type(self).__name__, self.env)
322 |
323 | def __repr__(self):
324 | return str(self)
325 |
326 | @property
327 | def unwrapped(self):
328 | return self.env.unwrapped
329 |
330 | @property
331 | def spec(self):
332 | return self.env.spec
333 |
334 |
335 | class ObservationWrapper(Wrapper):
336 | def transform(self, state, action):
337 | return self.env.transform(state, action)
338 |
339 | def step(self, action):
340 | observation, reward, done, info = self.env.step(action)
341 | return self.observation(observation), reward, done, info
342 |
343 | def reset(self, **kwargs):
344 | observation = self.env.reset(**kwargs)
345 | return self.observation(observation)
346 |
347 | def observation(self, observation):
348 | deprecated_warn_once("%s doesn't implement 'observation' method. Maybe it implements deprecated '_observation' method." % type(self))
349 | return self._observation(observation)
350 |
351 |
352 | class RewardWrapper(Wrapper):
353 | def transform(self, state, action):
354 | return self.env.transform(state, action)
355 |
356 | def reset(self):
357 | return self.env.reset()
358 |
359 | def step(self, action):
360 | observation, reward, done, info = self.env.step(action)
361 | return observation, self.reward(reward), done, info
362 |
363 | def reward(self, reward):
364 | deprecated_warn_once("%s doesn't implement 'reward' method. Maybe it implements deprecated '_reward' method." % type(self))
365 | return self._reward(reward)
366 |
367 |
368 | class ActionWrapper(Wrapper):
369 | def transform(self, state, action):
370 | return self.env.transform(state, action)
371 |
372 | def step(self, action):
373 | action = self.action(action)
374 | return self.env.step(action)
375 |
376 | def reset(self):
377 | return self.env.reset()
378 |
379 | def action(self, action):
380 | deprecated_warn_once("%s doesn't implement 'action' method. Maybe it implements deprecated '_action' method." % type(self))
381 | return self._action(action)
382 |
383 | def reverse_action(self, action):
384 | deprecated_warn_once("%s doesn't implement 'reverse_action' method. Maybe it implements deprecated '_reverse_action' method." % type(self))
385 | return self._reverse_action(action)
386 |
--------------------------------------------------------------------------------
/1-gym_developing/grid_game.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # author : zlq16
4 | # date : 2018/4/7
5 | import gym
6 | from gym.utils import seeding
7 | import logging
8 | import numpy as np
9 | import pandas as pd
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class GridEnv(gym.Env):
15 | metadata = {
16 | 'render.modes': ['human', 'rgb_array'],
17 | 'video.frames_per_second': 2
18 | }
19 |
20 | def __init__(self):
21 |
22 |         self.observation_space = (1, 2, 3, 4, 5, 6, 7, 8)  # state space
23 | self.x = [140, 220, 300, 380, 460, 140, 300, 460]
24 | self.y = [250, 250, 250, 250, 250, 150, 150, 150]
25 |         self.__terminal_space = dict()  # terminal states, stored as a dict
26 | self.__terminal_space[6] = 1
27 | self.__terminal_space[7] = 1
28 | self.__terminal_space[8] = 1
29 |
30 |         # state transitions, stored as a table (DataFrame)
31 | self.action_space = ('n', 'e', 's', 'w')
32 | self.t = pd.DataFrame(data=None, index=self.observation_space, columns=self.action_space)
33 | self.t.loc[1, "s"] = 6
34 | self.t.loc[1, "e"] = 2
35 | self.t.loc[2, "w"] = 1
36 | self.t.loc[2, "e"] = 3
37 | self.t.loc[3, "s"] = 7
38 | self.t.loc[3, "w"] = 2
39 | self.t.loc[3, "e"] = 4
40 | self.t.loc[4, "w"] = 3
41 | self.t.loc[4, "e"] = 5
42 | self.t.loc[5, "s"] = 8
43 | self.t.loc[5, "w"] = 4
44 |         self.__gamma = 0.8  # discount factor
45 | self.viewer = None
46 | self.__state = None
47 | self.seed()
48 |
49 | def _reward(self, state):
50 | r = 0.0
51 | if state in (6, 8):
52 | r = -1.0
53 | elif state == 7:
54 | r = 1.0
55 | return r
56 |
57 | def seed(self, seed=None):
58 | self.np_random, seed = seeding.np_random(seed)
59 | return [seed]
60 |
61 | def close(self):
62 | if self.viewer:
63 | self.viewer.close()
64 |
65 | def transform(self, state, action):
66 |         # guard clause: terminal states map to themselves
67 | if state in self.__terminal_space:
68 | return state, self._reward(state), True, {}
69 |
70 |         # state transition
71 | if pd.isna(self.t.loc[state, action]):
72 | next_state = state
73 | else:
74 | next_state = self.t.loc[state, action]
75 |
76 |         # compute the reward
77 | r = self._reward(next_state)
78 |
79 |         # check for termination
80 | is_terminal = False
81 | if next_state in self.__terminal_space:
82 | is_terminal = True
83 |
84 | return next_state, r, is_terminal, {}
85 |
86 | def step(self, action):
87 | state = self.__state
88 |
89 | next_state, r, is_terminal,_ = self.transform(state, action)
90 |
91 | self.__state = next_state
92 |
93 | return next_state, r, is_terminal, {}
94 |
95 | def reset(self):
96 | self.__state = np.random.choice(self.observation_space)
97 | return self.__state
98 |
99 | def render(self, mode='human', close=False):
100 | if close:
101 | if self.viewer is not None:
102 | self.viewer.close()
103 | self.viewer = None
104 | return
105 | screen_width = 600
106 | screen_height = 400
107 |
108 | if self.viewer is None:
109 | from gym.envs.classic_control import rendering
110 | self.viewer = rendering.Viewer(screen_width, screen_height)
111 |             # draw the grid world
112 | self.line1 = rendering.Line((100, 300), (500, 300))
113 | self.line2 = rendering.Line((100, 200), (500, 200))
114 | self.line3 = rendering.Line((100, 300), (100, 100))
115 | self.line4 = rendering.Line((180, 300), (180, 100))
116 | self.line5 = rendering.Line((260, 300), (260, 100))
117 | self.line6 = rendering.Line((340, 300), (340, 100))
118 | self.line7 = rendering.Line((420, 300), (420, 100))
119 | self.line8 = rendering.Line((500, 300), (500, 100))
120 | self.line9 = rendering.Line((100, 100), (180, 100))
121 | self.line10 = rendering.Line((260, 100), (340, 100))
122 | self.line11 = rendering.Line((420, 100), (500, 100))
123 |             # draw the first skull
124 | self.kulo1 = rendering.make_circle(40)
125 | self.circletrans = rendering.Transform(translation=(140, 150))
126 | self.kulo1.add_attr(self.circletrans)
127 | self.kulo1.set_color(0, 0, 0)
128 |             # draw the second skull
129 | self.kulo2 = rendering.make_circle(40)
130 | self.circletrans = rendering.Transform(translation=(460, 150))
131 | self.kulo2.add_attr(self.circletrans)
132 | self.kulo2.set_color(0, 0, 0)
133 |             # draw the gold bar
134 | self.gold = rendering.make_circle(40)
135 | self.circletrans = rendering.Transform(translation=(300, 150))
136 | self.gold.add_attr(self.circletrans)
137 | self.gold.set_color(1, 0.9, 0)
138 |             # draw the robot
139 | self.robot = rendering.make_circle(30)
140 | self.robotrans = rendering.Transform()
141 | self.robot.add_attr(self.robotrans)
142 | self.robot.set_color(0.8, 0.6, 0.4)
143 |
144 | self.line1.set_color(0, 0, 0)
145 | self.line2.set_color(0, 0, 0)
146 | self.line3.set_color(0, 0, 0)
147 | self.line4.set_color(0, 0, 0)
148 | self.line5.set_color(0, 0, 0)
149 | self.line6.set_color(0, 0, 0)
150 | self.line7.set_color(0, 0, 0)
151 | self.line8.set_color(0, 0, 0)
152 | self.line9.set_color(0, 0, 0)
153 | self.line10.set_color(0, 0, 0)
154 | self.line11.set_color(0, 0, 0)
155 |
156 | self.viewer.add_geom(self.line1)
157 | self.viewer.add_geom(self.line2)
158 | self.viewer.add_geom(self.line3)
159 | self.viewer.add_geom(self.line4)
160 | self.viewer.add_geom(self.line5)
161 | self.viewer.add_geom(self.line6)
162 | self.viewer.add_geom(self.line7)
163 | self.viewer.add_geom(self.line8)
164 | self.viewer.add_geom(self.line9)
165 | self.viewer.add_geom(self.line10)
166 | self.viewer.add_geom(self.line11)
167 | self.viewer.add_geom(self.kulo1)
168 | self.viewer.add_geom(self.kulo2)
169 | self.viewer.add_geom(self.gold)
170 | self.viewer.add_geom(self.robot)
171 |
172 | if self.__state is None:
173 | return None
174 |
175 | self.robotrans.set_translation(self.x[self.__state - 1], self.y[self.__state - 1])
176 | return self.viewer.render(return_rgb_array=(mode == 'rgb_array'))
177 |
178 |
179 | if __name__ == '__main__':
180 | env = GridEnv()
181 | env.reset()
182 | env.render()
183 |
--------------------------------------------------------------------------------
/1-gym_developing/maze_game.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # author : zlq16
4 | # date : 2018/4/7
5 | import gym
6 | from gym.utils import seeding
7 | import logging
8 | import numpy as np
9 | import pandas as pd
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class MazeEnv(gym.Env):
15 | metadata = {
16 | 'render.modes': ['human', 'rgb_array'],
17 | 'video.frames_per_second': 1
18 | }
19 |
20 | def __init__(self):
21 | self.map = np.array([
22 | [0, 0, 1, 0, 0],
23 | [0, 0, 1, 0, 0],
24 | [1, 0, 0, 0, 0],
25 | [1, 0, 0, 1, 1],
26 |             [1, 0, 0, 0, 0]], dtype=bool)
27 |
28 | self.observation_space = [tuple(s) for s in np.argwhere(self.map == 0)]
29 | self.walls = [tuple(s) for s in np.argwhere(self.map)]
30 |         self.__terminal_space = ((4, 2),)  # terminal states
31 |
32 |         # state transitions, stored as a table (DataFrame)
33 | self.action_space = ('n', 'e', 's', 'w')
34 | self.t = pd.DataFrame(data=None, index=self.observation_space, columns=self.action_space)
35 | self._trans_make()
36 | self.viewer = None
37 | self.__state = None
38 | self.seed()
39 |
40 | def _trans_make(self):
41 | for s in self.observation_space:
42 | for a in self.action_space:
43 | if a == "n":
44 | n_s = np.array(s) + np.array([0, 1])
45 | elif a == "e":
46 | n_s = np.array(s) + np.array([1, 0])
47 | elif a == "s":
48 | n_s = np.array(s) + np.array([0, -1])
49 | elif a == "w":
50 | n_s = np.array(s) + np.array([-1, 0])
51 | if (0 <= n_s).all() and (n_s <= 4).all() and not self.map[n_s[0], n_s[1]]:
52 | self.t.loc[s, a] = tuple(n_s)
53 | else:
54 | self.t.loc[s, a] = s
55 |
56 | def _reward(self, state):
57 | r = 0.0
58 | n_s = np.array(state)
59 | if (0 <= n_s).all() and (n_s <= 4).all() and tuple(n_s) in self.__terminal_space:
60 | r = 1.0
61 | return r
62 |
63 | def seed(self, seed=None):
64 | self.np_random, seed = seeding.np_random(seed)
65 | return [seed]
66 |
67 | def close(self):
68 | if self.viewer:
69 | self.viewer.close()
70 |
71 | def transform(self, state, action):
72 |         # guard clause: terminal states map to themselves
73 | if state in self.__terminal_space:
74 | return state, self._reward(state), True, {}
75 |
76 |         # state transition
77 | next_state = self.t[action][state]
78 |
79 |         # compute the reward
80 | r = self._reward(next_state)
81 |
82 |         # check for termination
83 | is_terminal = False
84 | if next_state in self.__terminal_space:
85 | is_terminal = True
86 |
87 | return next_state, r, is_terminal, {}
88 |
89 | def step(self, action):
90 | state = self.__state
91 |
92 | next_state, r, is_terminal, _ = self.transform(state, action)
93 |
94 | self.__state = next_state
95 |
96 | return next_state, r, is_terminal, {}
97 |
98 | def reset(self):
99 |
100 | while True:
101 | self.__state = self.observation_space[np.random.choice(len(self.observation_space))]
102 | if self.__state not in self.__terminal_space:
103 | break
104 | return self.__state
105 |
106 | def render(self, mode='human', close=False):
107 | if close:
108 | if self.viewer is not None:
109 | self.viewer.close()
110 | self.viewer = None
111 | return
112 | unit = 50
113 | screen_width = 5 * unit
114 | screen_height = 5 * unit
115 |
116 | if self.viewer is None:
117 | from gym.envs.classic_control import rendering
118 | self.viewer = rendering.Viewer(screen_width, screen_height)
119 |
120 |             # draw the grid
121 | for c in range(5):
122 | line = rendering.Line((0, c*unit), (screen_width, c*unit))
123 | line.set_color(0, 0, 0)
124 | self.viewer.add_geom(line)
125 | for r in range(5):
126 | line = rendering.Line((r*unit, 0), (r*unit, screen_height))
127 | line.set_color(0, 0, 0)
128 | self.viewer.add_geom(line)
129 |
130 |             # draw the walls
131 | for x, y in self.walls:
132 | r = rendering.make_polygon(
133 | v=[[x * unit, y * unit],
134 | [(x + 1) * unit, y * unit],
135 | [(x + 1) * unit, (y + 1) * unit],
136 | [x * unit, (y + 1) * unit],
137 | [x * unit, y * unit]])
138 | r.set_color(0, 0, 0)
139 | self.viewer.add_geom(r)
140 |
141 |             # draw the robot
142 | self.robot = rendering.make_circle(20)
143 | self.robotrans = rendering.Transform()
144 | self.robot.add_attr(self.robotrans)
145 | self.robot.set_color(0.8, 0.6, 0.4)
146 | self.viewer.add_geom(self.robot)
147 |
148 |             # draw the exit
149 | self.exit = rendering.make_circle(20)
150 | self.exitrans = rendering.Transform(translation=(4*unit+unit/2, 2*unit+unit/2))
151 | self.exit.add_attr(self.exitrans)
152 | self.exit.set_color(0, 1, 0)
153 | self.viewer.add_geom(self.exit)
154 |
155 | if self.__state is None:
156 | return None
157 |
158 | self.robotrans.set_translation(self.__state[0] * unit + unit / 2, self.__state[1] * unit + unit / 2)
159 | return self.viewer.render(return_rgb_array=(mode == 'rgb_array'))
160 |
161 |
162 | if __name__ == '__main__':
163 | env = MazeEnv()
164 | env.reset()
165 | env.render()
166 |
--------------------------------------------------------------------------------
/1-gym_developing/suceess.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/1-gym_developing/suceess.png
--------------------------------------------------------------------------------
/2-markov_decision_process/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # author : zlq16
4 | # date : 2018/4/2
--------------------------------------------------------------------------------
/2-markov_decision_process/game.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # author : zlq16
4 | # date : 2018/4/7
5 | import gym
6 | import numpy as np
7 |
8 |
9 | def main():
10 | env = gym.make("MazeGame-v0")
11 | s = env.reset()
12 | a_s = env.action_space
13 | for i in range(100):
14 | env.render()
15 | a = np.random.choice(a_s)
16 | print(a)
17 | s, r, t, _ = env.step(a)
18 |
19 | if __name__ == '__main__':
20 | main()
21 |
--------------------------------------------------------------------------------
/2-markov_decision_process/our_life.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # author : zlq16
4 | # date : 2018/4/2
5 | import numpy as np
6 | import pandas as pd
7 |
8 |
9 | def main():
10 | S = ("S1", "S2", "S3", "S4", "S5")
11 |     A = ("玩", "退出", "学习", "发表", "睡觉")  # actions: play, quit, study, publish, sleep
12 | strategy = pd.DataFrame(data=None, index=S, columns=A)
13 | reward = pd.DataFrame(data=None, index=S, columns=A)
14 |     gamma = 1  # discount factor (defined here but not used below)
15 | strategy.loc["S1", :] = np.array([0.9, 0.1, 0, 0, 0])
16 | strategy.loc["S2", :] = np.array([0.5, 0, 0.5, 0, 0])
17 | strategy.loc["S3", :] = np.array([0, 0, 0.8, 0, 0.2])
18 | strategy.loc["S4", :] = np.array([0, 0, 0, 0.4, 0.6])
19 | strategy.loc["S5", :] = np.array([0, 0, 0, 0, 0])
20 |     print("Policy")
21 | print(strategy)
22 | reward.loc["S1", :] = np.array([-1, 0, 0, 0, 0])
23 | reward.loc["S2", :] = np.array([-1, 0, -2, 0, 0])
24 | reward.loc["S3", :] = np.array([0, 0, -2, 0, 0])
25 | reward.loc["S4", :] = np.array([0, 0, 0, 10, 1])
26 | reward.loc["S5", :] = np.array([0, 0, 0, 0, 0])
27 |     print("Reward function")
28 | print(reward)
29 |
30 |
31 | if __name__ == '__main__':
32 | main()
33 |
--------------------------------------------------------------------------------
/3-dynamic_program/grid_game_with_average_policy.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # author : zlq16
4 | # date : 2018/4/9
5 | '''
6 | A note on this implementation: following the book's author,
7 | when the chosen action cannot move the agent, the iteration still uses
8 | v(s') = v(s)
9 | inside the update v(s) = Σ π(a|s) * (r + γ * v(s')).
10 | I don't think this is quite right:
11 | if no move happens, there should be no reward r and no discounting.
12 | Nevertheless, this file keeps the book's formulation;
13 | the policy iteration and value iteration scripts that follow are revised.
14 | '''
15 | import pandas as pd
16 | import numpy as np
17 |
18 |
19 | class GridMDP:
20 | def __init__(self, **kwargs):
21 | for k, v in kwargs.items():
22 | setattr(self, k, v)
23 | self.__action_dir = pd.Series(
24 | data=[np.array((-1, 0)),
25 | np.array((1, 0)),
26 | np.array((0, -1)),
27 | np.array((0, 1))],
28 | index=self.action_space)
29 | self.terminal_states = [(0, 0), (3, 3)]
30 |
31 | def transform(self, state, action):
32 | dir = self.__action_dir[action]
33 | state_ = np.array(state) + dir
34 | if (state_ >= 0).all() and (state_ < 4).all():
35 | state_ = tuple(state_)
36 | else:
37 | state_ = state
38 | return state_
39 |
40 |
41 | def average_policy(mdp, v_s, policy):
42 | state_space = mdp.state_space
43 | action_space = mdp.action_space
44 | reward = mdp.reward
45 | gamma = mdp.gamma
46 | while True:
47 | print(v_s)
48 | v_s_ = v_s.copy()
49 | for state in state_space:
50 | v_s_a = pd.Series()
51 | for action in action_space:
52 | state_ = mdp.transform(state,action)
53 | if state_ in mdp.terminal_states:
54 | v_s_a[action] = 0
55 | elif state_ != state:
56 | v_s_a[action] = v_s_[state_]
57 | else:
58 | v_s_a[action] = v_s_[state]
59 | v_s[state] = sum([policy[action] * (reward + gamma * v_s_a[action]) for action in action_space])
60 | if (np.abs(v_s_ - v_s) < 1e-8).all():
61 | break
62 | return v_s
63 |
64 |
65 | def main():
66 | state_space = [(i, j) for i in range(4) for j in range(4)]
67 | state_space.remove((0, 0))
68 | state_space.remove((3, 3))
69 | mdp = GridMDP(
70 | state_space=state_space,
71 | action_space=["n", "s", "w", "e"],
72 | reward=-1,
73 | gamma=1)
74 | v_s = pd.Series(np.zeros((len(state_space))),index=state_space)
75 | policy = pd.Series(data=0.25 * np.ones(shape=(4)), index=mdp.action_space)
76 | v_s = average_policy(mdp, v_s, policy)
77 |     print("converged state values:")
78 | print(v_s)
79 |
80 |
81 | if __name__ == '__main__':
82 | main()
83 |
--------------------------------------------------------------------------------
/3-dynamic_program/grid_game_with_policy_iterate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # author : zlq16
4 | # date : 2018/4/9
5 | import pandas as pd
6 | import numpy as np
7 |
8 |
9 | class GridMDP:
10 | def __init__(self, **kwargs):
11 | for k, v in kwargs.items():
12 | setattr(self, k, v)
13 | self.__action_dir = pd.Series(
14 | data=[np.array((-1, 0)),
15 | np.array((1, 0)),
16 | np.array((0, -1)),
17 | np.array((0, 1))],
18 | index=self.action_space)
19 | self.terminal_space = [(0, 0), (3, 3)]
20 |
21 | def transform(self, state, action):
22 | dir = self.__action_dir[action]
23 | state_ = np.array(state) + dir
24 | if (state_ >= 0).all() and (state_ < 4).all():
25 | state_ = tuple(state_)
26 | else:
27 | state_ = state
28 | return state_
29 |
30 |
31 | def policy_evaluate(v_s, policy, mdp):
32 | state_space = mdp.state_space
33 | gamma = mdp.gamma
34 | reward = mdp.reward
35 | while True:
36 | v_s_ = v_s.copy()
37 | for state in state_space:
38 | action = policy[state]
39 | state_ = mdp.transform(state, action)
40 |             if state_ in mdp.terminal_space:  # next state is terminal
41 |                 v_s[state] = reward + 0.0
42 |             elif state_ != state:  # the agent moved to a non-terminal state
43 |                 v_s[state] = reward + gamma * v_s_[state_]
44 |             else:  # the move was blocked, so the state is unchanged
45 |                 v_s[state] = reward + gamma * v_s_[state_]
46 |
47 | if (np.abs(v_s_ - v_s) < 1e-8).all():
48 | break
49 | return v_s
50 |
51 |
52 | def policy_improve(v_s, mdp):
53 | state_space = mdp.state_space
54 | action_space = mdp.action_space
55 | gamma = mdp.gamma
56 | reward = mdp.reward
57 | policy_ = pd.Series(index=state_space)
58 | for state in state_space:
59 | v_s_a = pd.Series()
60 | for action in action_space:
61 | state_ = mdp.transform(state, action)
62 | if state_ in mdp.terminal_space:
63 | v_s_a[action] = reward
64 | else:
65 | v_s_a[action] = reward + gamma * v_s[state_]
66 |
67 |         # pick randomly among the actions with the maximal value
68 | m = v_s_a.max()
69 | policy_[state] = np.random.choice(v_s_a[v_s_a == m].index)
70 | return policy_
71 |
72 |
73 | def policy_iterate(mdp):
74 | v_s = pd.Series(data=np.zeros(shape=len(mdp.state_space)), index=mdp.state_space)
75 | policy = pd.Series(data=np.random.choice(mdp.action_space, size=(len(mdp.state_space))), index=mdp.state_space)
76 | while True:
77 | print(policy)
78 | v_s = policy_evaluate(v_s, policy, mdp)
79 | policy_ = policy_improve(v_s, mdp)
80 | if (policy_ == policy).all():
81 | break
82 | else:
83 | policy = policy_
84 | return policy
85 |
86 |
87 | def main():
88 | state_space = [(i, j) for i in range(4) for j in range(4)]
89 | state_space.remove((0, 0))
90 | state_space.remove((3, 3))
91 | mdp = GridMDP(
92 | state_space=state_space,
93 | action_space=["n", "s", "w", "e"],
94 | reward=-1,
95 | gamma=0.9)
96 | policy = policy_iterate(mdp)
97 | print("convergence policy is:")
98 | print(policy)
99 |
100 |
101 | if __name__ == '__main__':
102 | main()
103 |
--------------------------------------------------------------------------------
/3-dynamic_program/grid_game_with_value_iterate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # author : zlq16
4 | # date : 2018/4/9
5 | import pandas as pd
6 | import numpy as np
7 |
8 |
9 | class GridMDP:
10 | def __init__(self, **kwargs):
11 | for k, v in kwargs.items():
12 | setattr(self, k, v)
13 | self.__action_dir = pd.Series(
14 | data=[np.array((-1, 0)),
15 | np.array((1, 0)),
16 | np.array((0, -1)),
17 | np.array((0, 1))],
18 | index=self.action_space)
19 | self.terminal_space = [(0, 0), (3, 3)]
20 |
21 | def transform(self, state, action):
22 | dir = self.__action_dir[action]
23 | state_ = np.array(state) + dir
24 | if (state_ >= 0).all() and (state_ < 4).all():
25 | state_ = tuple(state_)
26 | else:
27 | state_ = state
28 | return state_
29 |
30 |
31 | def value_iterate(mdp):
32 | state_space = mdp.state_space
33 | action_space = mdp.action_space
34 | gamma = mdp.gamma
35 | reward = mdp.reward
36 | v_s = pd.Series(data=np.zeros(shape=len(state_space)), index=state_space)
37 | policy = pd.Series(index=state_space)
38 | while True:
39 | print(v_s)
40 | v_s_ = v_s.copy()
41 | for state in state_space:
42 | v_s_a = pd.Series()
43 | for action in action_space:
44 | state_ = mdp.transform(state,action)
45 | if state_ in mdp.terminal_space:
46 | v_s_a[action] = reward
47 | else:
48 | v_s_a[action] = reward + gamma * v_s_[state_]
49 |
50 | v_s[state] = v_s_a.max()
51 | policy[state] = np.random.choice(v_s_a[v_s_a == v_s[state]].index)
52 |
53 | if (np.abs(v_s_ - v_s) < 1e-8).all():
54 | break
55 | return policy
56 |
57 |
58 | def main():
59 | state_space = [(i, j) for i in range(4) for j in range(4)]
60 | state_space.remove((0, 0))
61 | state_space.remove((3, 3))
62 | mdp = GridMDP(
63 | state_space=state_space,
64 | action_space=["n", "s", "w", "e"],
65 | reward=-1,
66 | gamma=0.9)
67 | policy = value_iterate(mdp)
68 | print("convergence policy is:")
69 | print(policy)
70 |
71 |
72 | if __name__ == '__main__':
73 | main()
74 |
--------------------------------------------------------------------------------
/3-dynamic_program/maze_game_with_dynamic_program.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # author : zlq16
4 | # date : 2018/4/9
5 | import gym
6 | import pandas as pd
7 | import numpy as np
8 |
9 |
10 | def value_iterate(env):
11 | state_space = env.observation_space
12 | action_space = env.action_space
13 | v_s = pd.Series(data=np.zeros(shape=len(state_space)), index=state_space)
14 | policy = pd.Series(index=state_space)
15 | gamma = 0.8
16 | while True:
17 | print(v_s)
18 | v_s_ = v_s.copy()
19 | for state in state_space:
20 | v_s_a = pd.Series()
21 | for action in action_space:
22 | state_, reward, is_done, _ = env.transform(state, action)
23 | if is_done:
24 | v_s_a[action] = reward
25 | else:
26 | v_s_a[action] = reward + gamma*v_s_[state_]
27 | v_s[state] = v_s_a.max()
28 |             policy[state] = np.random.choice(v_s_a[v_s_a == v_s[state]].index)  # pick randomly among maximizing actions
29 | if (np.abs(v_s_ - v_s) < 1e-8).all():
30 | break
31 | return policy
32 |
33 |
34 | ### This main() is essentially pseudocode: it assumes the MazeGame-v0 environment registered as in chapter 1 ###
35 | def main():
36 | env = gym.make("MazeGame-v0")
37 | print(env.action_space)
38 | print(env.observation_space)
39 | policy = value_iterate(env)
40 |     print("converged policy is:")
41 | print(policy)
42 |
43 |
44 | if __name__ == '__main__':
45 | main()
46 |
--------------------------------------------------------------------------------
/3-dynamic_program/policy_iteration_algorithm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/3-dynamic_program/policy_iteration_algorithm.png
--------------------------------------------------------------------------------
/3-dynamic_program/value_iteration_algorithm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/3-dynamic_program/value_iteration_algorithm.png
--------------------------------------------------------------------------------
/4-monte_carlo/monte_carlo_evaluate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # author : zlq16
4 | # date : 2018/5/2
5 | from collections import defaultdict
6 |
7 |
8 | def mc(gamma, env, state_sample, reward_sample):
9 | V = defaultdict(float)
10 | N = defaultdict(int)
11 | states = env.observation_space
12 | num = len(state_sample)
13 | for i in range(num):
14 | G = 0.0
15 | episode_len = len(state_sample[i])
16 |         # accumulate the discounted return G of the whole episode, working backwards
17 | for episode in range(episode_len-1, -1, -1):
18 | G *= gamma
19 | G += reward_sample[i][episode]
20 |
21 |         # walk forward: add the return from each visited state onward, then strip the first reward and one discount factor
22 | for episode in range(episode_len):
23 | s = state_sample[i][episode]
24 | V[s] += G
25 | N[s] += 1
26 | G -= reward_sample[i][episode]
27 | G /= gamma
28 |
29 |     # empirical average over visit counts
30 |     for s in states:
31 |         if N[s] > 0:
32 | V[s] /= N[s]
33 | return V
34 |
--------------------------------------------------------------------------------
/4-monte_carlo/monte_carlo_sample.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # author : zlq16
4 | # date : 2018/4/30
5 | import numpy as np
6 |
7 |
8 | def gen_random(env, num):
9 | state_sample = []
10 | action_sample = []
11 | reward_sample = []
12 |     # simulate num sampled episodes
13 | for i in range(num):
14 | s_tmp = []
15 | a_tmp = []
16 | r_tmp = []
17 | s = env.reset()
18 | is_done = False
19 |         # roll out one episode under a uniformly random policy
20 | while not is_done:
21 | a = np.random.choice(env.action_space)
22 | s_, r, is_done, _ = env.transform(s, a)
23 | s_tmp.append(s)
24 | a_tmp.append(a)
25 | r_tmp.append(r)
26 | s = s_
27 | state_sample.append(s_tmp)
28 | action_sample.append(a_tmp)
29 | reward_sample.append(r_tmp)
30 | return state_sample, action_sample, reward_sample
31 |
--------------------------------------------------------------------------------
/5-temporal_difference/README.md:
--------------------------------------------------------------------------------
1 | ## Usage
2 | 1. Follow the [usage instructions](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/1-gym_developing/README.md) in 1-gym_developing to install the push_box_game environment.
3 |    The environment file is push_box_game.py.
4 |
5 |
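Concretely, the registration entry follows the same pattern as in chapter 1 (a sketch only; the `PushBoxGame-v0` id matches the one used in `push_box_game/main.py`, and the `classic_control` entry point assumes you copied `push_box_game.py` there as described above):

```python
register(
    id='PushBoxGame-v0',
    entry_point='gym.envs.classic_control:PushBoxEnv',
    max_episode_steps=200,
    reward_threshold=100.0,
)
```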
--------------------------------------------------------------------------------
/5-temporal_difference/push_box_game.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # author : zlq16
4 | # date : 2018/5/4
5 | import gym
6 | from gym.utils import seeding
7 | from gym.envs.classic_control import rendering
8 | import logging
9 | import numpy as np
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class PushBoxEnv(gym.Env):
15 | metadata = {
16 | 'render.modes': ['human', 'rgb_array'],
17 | 'video.frames_per_second': 1
18 | }
19 |
20 | def __init__(self):
21 | self.reward = np.array([
22 | [-100, -100, -100, -100, -100],
23 | [-100, 0, 0, 0, -100],
24 | [-100, 0, -100, 0, -100],
25 | [-100, 0, -100, 0, -100],
26 | [-100, 0, 100, 0, -100],
27 | ]).T
28 | self.walls= [(0, 1), (0, 3), (2, 2), (2, 3)]
29 | self.dest_position = (2, 4)
30 | self.work_position = (2, 0)
31 | self.box_position = (2, 1)
32 | self.action_space = ((1, 0), (-1, 0), (0, 1), (0, -1))
33 | self.viewer = None
34 | self.__state = None
35 | self.seed()
36 |
37 | def _is_terminal(self, state):
38 | box_position = state[2:]
39 | if self.reward[box_position] != 0:
40 | return True
41 | return False
42 |
43 | def _move_ok(self, position):
44 | if tuple(position) not in self.walls and \
45 | (position >= 0).all() and \
46 | (position < 5).all():
47 | return True
48 | return False
49 |
50 | def _trans_make(self, state, action):
51 | work_position = np.array(state[:2])
52 | box_position = np.array(state[2:])
53 | action = np.array(action)
54 | next_work_position = work_position + action
55 |         # check whether the worker can move
56 | if self._move_ok(next_work_position):
57 | work_position = next_work_position
58 |             # if the worker moves onto the box, try to push the box as well
59 | if (next_work_position == box_position).all():
60 |
61 | next_box_position = box_position + action
62 | if self._move_ok(next_box_position):
63 |                     # the box can move
64 | box_position = next_box_position
65 | else:
66 |                     # the box is blocked and just blocks the path, so undo the worker's move
67 | work_position -= action
68 | return tuple(np.hstack((work_position, box_position)))
69 |
70 | def _reward(self, state):
71 | box_position = state[2:]
72 | return self.reward[box_position]
73 |
74 | def seed(self, seed=None):
75 | self.np_random, seed = seeding.np_random(seed)
76 | return [seed]
77 |
78 | def close(self):
79 | if self.viewer:
80 | self.viewer.close()
81 |
82 | def transform(self, state, action):
83 |         # guard clause: terminal states map to themselves
84 | if self._is_terminal(state):
85 | return state, self._reward(state), True, {}
86 |
87 |         # state transition
88 | next_state = self._trans_make(state, action)
89 |
90 |         # compute the reward
91 | r = self._reward(next_state)
92 |
93 |         # check for termination
94 | is_terminal = False
95 | if self._is_terminal(next_state):
96 | is_terminal = True
97 |
98 | return next_state, r, is_terminal, {}
99 |
100 | def step(self, action):
101 |         # current state of the environment
102 | state = self.__state
103 |
104 |         # delegate the transition to transform()
105 | next_state, r, is_terminal, _ = self.transform(state, action)
106 |
107 |         # update the internal state
108 | self.__state = next_state
109 |
110 | return next_state, r, is_terminal, {}
111 |
112 | def reset(self):
113 | self.__state = tuple(np.hstack((self.work_position, self.box_position)))
114 | return self.__state
115 |
116 | def render(self, mode='human', close=False):
117 | if close:
118 | if self.viewer is not None:
119 | self.viewer.close()
120 | self.viewer = None
121 | return
122 | unit = 50
123 | screen_width = 5 * unit
124 | screen_height = 5 * unit
125 |
126 | if self.viewer is None:
127 |
128 | self.viewer = rendering.Viewer(screen_width, screen_height)
129 |
130 |             # draw the grid
131 | for c in range(5):
132 | line = rendering.Line((0, c*unit), (screen_width, c*unit))
133 | line.set_color(0, 0, 0)
134 | self.viewer.add_geom(line)
135 | for r in range(5):
136 | line = rendering.Line((r*unit, 0), (r*unit, screen_height))
137 | line.set_color(0, 0, 0)
138 | self.viewer.add_geom(line)
139 |
140 |             # draw the walls
141 | for x, y in self.walls:
142 | r = rendering.make_polygon(
143 | v=[[x * unit, y * unit],
144 | [(x + 1) * unit, y * unit],
145 | [(x + 1) * unit, (y + 1) * unit],
146 | [x * unit, (y + 1) * unit],
147 | [x * unit, y * unit]])
148 | r.set_color(0, 0, 0)
149 | self.viewer.add_geom(r)
150 |
151 |             # draw the destination
152 | d_x, d_y = self.dest_position
153 | dest = rendering.make_polygon(v=[
154 | [d_x*unit, d_y*unit],
155 | [(d_x+1)*unit, d_y*unit],
156 | [(d_x+1)*unit, (d_y+1)*unit],
157 | [d_x*unit, (d_y+1)*unit],
158 | [d_x*unit, d_y*unit]])
159 | dest_trans = rendering.Transform()
160 | dest.add_attr(dest_trans)
161 | dest.set_color(1, 0, 0)
162 | self.viewer.add_geom(dest)
163 |
164 |             # draw the worker
165 | self.work = rendering.make_circle(20)
166 | self.work_trans = rendering.Transform()
167 | self.work.add_attr(self.work_trans)
168 | self.work.set_color(0, 1, 0)
169 | self.viewer.add_geom(self.work)
170 |
171 |             # draw the box
172 | self.box = rendering.make_circle(20)
173 | self.box_trans = rendering.Transform()
174 | self.box.add_attr(self.box_trans)
175 | self.box.set_color(0, 0, 1)
176 | self.viewer.add_geom(self.box)
177 |
178 | if self.__state is None:
179 | return None
180 | w_x,w_y = self.__state[:2]
181 | b_x,b_y = self.__state[2:]
182 | self.work_trans.set_translation(w_x*unit+unit/2, w_y*unit+unit/2)
183 | self.box_trans.set_translation(b_x*unit+unit/2, b_y*unit+unit/2)
184 | return self.viewer.render(return_rgb_array=(mode == 'rgb_array'))
185 |
186 |
187 | if __name__ == '__main__':
188 | env = PushBoxEnv()
189 | env.reset()
190 | env.render()
191 | env.step((-1, 0))
192 | env.render()
193 |
--------------------------------------------------------------------------------
/5-temporal_difference/push_box_game/agent.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # author : zlq16
4 | # date : 2017/10/4
5 |
6 | import numpy as np
7 | from sklearn.preprocessing import LabelBinarizer
8 | import cv2
9 | import random
10 | import pandas as pd
11 | import os
12 | import tensorflow as tf
13 | from collections import deque
14 |
15 |
16 | class Agent(object):
17 |
18 | def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
19 | self.actions = actions
20 | self.learning_rate = learning_rate
21 | self.reward_decay = reward_decay
22 | self.epsilon = e_greedy
23 | self.q_table = pd.DataFrame()
24 |
25 | def check_state_exist(self, state):
26 | if tuple(state) not in self.q_table.columns:
27 | self.q_table[tuple(state)] = [0]*len(self.actions)
28 |
29 | def choose_action(self, observation):
30 | observation = tuple(observation)
31 | self.check_state_exist(observation)
32 | if np.random.uniform() < self.epsilon:
33 | state_action = self.q_table[observation]
34 | state_action = state_action.reindex(np.random.permutation(state_action.index)) # some action_space have same value
35 | action_idx = state_action.argmax()
36 | else:
37 | action_idx = np.random.choice(range(len(self.actions)))
38 | return action_idx
39 |
40 | def save_train_parameter(self, name):
41 | self.q_table.to_pickle(name)
42 |
43 | def load_train_parameter(self, name):
44 | self.q_table = pd.read_pickle(name)
45 |
46 | def learn(self, *args, **kwargs):
47 | pass
48 |
49 | def train(self, *args, **kwargs):
50 | pass
51 |
52 |
53 | class QLearningAgent(Agent):
54 |
55 | def learn(self, s, a, r, s_, done):
56 | self.check_state_exist(s_)
57 | q_predict = self.q_table[s][a]
58 | if not done:
59 |             q_target = r + self.reward_decay * self.q_table[s_].max()  # next state is not terminal
60 | else:
61 |             q_target = r  # next state is terminal
62 | self.q_table[s][a] += self.learning_rate * (q_target - q_predict) # update
63 |
64 | def train(self, env, max_iterator=100):
65 | self.load_train_parameter("q_table.pkl")
66 | for episode in range(max_iterator):
67 |
68 | observation = env.reset()
69 |
70 | while True:
71 | # env.render()
72 |
73 | action_idx = self.choose_action(observation)
74 |
75 | observation_, reward, done, _ = env.step(self.actions[action_idx])
76 |
77 | print(observation, reward)
78 |
79 | self.learn(observation, action_idx, reward, observation_, done)
80 |
81 | observation = observation_
82 |
83 | if done:
84 | self.save_train_parameter("q_table.pkl")
85 | break
86 |
87 |
88 | class SarsaAgent(Agent):
89 |
90 | def learn(self, s, a, r, s_, a_, done):
91 | self.check_state_exist(s_)
92 | q_predict = self.q_table[s][a]
93 | if not done:
94 |             q_target = r + self.reward_decay * self.q_table[s_][a_]  # next state is not terminal
95 | else:
96 |             q_target = r  # next state is terminal
97 | self.q_table[s][a] += self.learning_rate*(q_target - q_predict) # update
98 |
99 | def train(self, env, max_iterator=100):
100 | self.load_train_parameter("q_table.pkl")
101 | for episode in range(max_iterator):
102 |
103 | observation = env.reset()
104 | action_idx = self.choose_action(observation)
105 | while True:
106 | # env.render()
107 | print(observation)
108 |
109 | observation_, reward, done, _ = env.step(self.actions[action_idx])
110 |
111 | action_idx_ = self.choose_action(observation_)
112 |
113 | self.learn(observation, action_idx, reward, observation_, action_idx_, done)
114 |
115 | observation = observation_
116 | action_idx = action_idx_
117 |
118 | if done:
119 | self.save_train_parameter("q_table.pkl")
120 | break
121 |
--------------------------------------------------------------------------------
/5-temporal_difference/push_box_game/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # author : zlq16
4 | # date : 2017/10/4
5 | import gym
6 | from agent import QLearningAgent
7 | if __name__ == "__main__":
8 | env = gym.make("PushBoxGame-v0")
9 | RL = QLearningAgent(actions=env.action_space)
10 | RL.train(env)
--------------------------------------------------------------------------------
/5-temporal_difference/push_box_game/q_table.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/5-temporal_difference/push_box_game/q_table.pkl
--------------------------------------------------------------------------------
/5-temporal_difference/q_learning_algortihm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/5-temporal_difference/q_learning_algortihm.png
--------------------------------------------------------------------------------
/5-temporal_difference/sarsa_algorithm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/5-temporal_difference/sarsa_algorithm.png
--------------------------------------------------------------------------------
/5-temporal_difference/sarsa_lambda_algorithm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/5-temporal_difference/sarsa_lambda_algorithm.png
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/.gitignore:
--------------------------------------------------------------------------------
1 | # ignore all pyc files.
2 | *.pyc
3 |
4 |
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/README.md:
--------------------------------------------------------------------------------
1 | # Using Deep Q-Network to Learn How To Play Flappy Bird
2 |
3 |
4 |
5 | 7 mins version: [DQN for flappy bird](https://www.youtube.com/watch?v=THhUXIhjkCM)
6 |
7 | ## Overview
8 | This project follows the description of the Deep Q Learning algorithm described in Playing Atari with Deep Reinforcement Learning [2] and shows that this learning algorithm can be further generalized to the notorious Flappy Bird.
9 |
10 | ## Installation Dependencies:
11 | * Python 2.7 or 3
12 | * TensorFlow 0.7
13 | * pygame
14 | * OpenCV-Python
15 |
16 | ## How to Run?
17 | ```
18 | git clone https://github.com/yenchenlin1994/DeepLearningFlappyBird.git
19 | cd DeepLearningFlappyBird
20 | python deep_q_network.py
21 | ```
22 |
23 | ## What is Deep Q-Network?
24 | It is a convolutional neural network, trained with a variant of Q-learning, whose input is raw pixels and whose output is a value function estimating future rewards.
25 |
26 | For those who are interested in deep reinforcement learning, I highly recommend reading the following post:
27 |
28 | [Demystifying Deep Reinforcement Learning](http://www.nervanasys.com/demystifying-deep-reinforcement-learning/)
29 |
30 | ## Deep Q-Network Algorithm
31 |
32 | The pseudo-code for the Deep Q Learning algorithm, as given in [1], can be found below:
33 |
34 | ```
35 | Initialize replay memory D to size N
36 | Initialize action-value function Q with random weights
37 | for episode = 1, M do
38 | Initialize state s_1
39 | for t = 1, T do
40 | With probability ϵ select random action a_t
41 |         otherwise select a_t = argmax_a Q(s_t,a; θ_i)
42 | Execute action a_t in emulator and observe r_t and s_(t+1)
43 | Store transition (s_t,a_t,r_t,s_(t+1)) in D
44 | Sample a minibatch of transitions (s_j,a_j,r_j,s_(j+1)) from D
45 | Set y_j:=
46 | r_j for terminal s_(j+1)
47 | r_j+γ*max_(a^' ) Q(s_(j+1),a'; θ_i) for non-terminal s_(j+1)
48 | Perform a gradient step on (y_j-Q(s_j,a_j; θ_i))^2 with respect to θ
49 | end for
50 | end for
51 | ```
52 |
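The target y_j above is just a per-transition switch between the immediate reward and the bootstrapped Bellman value. A minimal NumPy sketch of that target computation (the helper name `q_learning_targets` and its arguments are illustrative, not part of this repo):

```python
import numpy as np

def q_learning_targets(rewards, q_next, terminal, gamma=0.99):
    """y_j = r_j if s_(j+1) is terminal, else r_j + gamma * max_a' Q(s_(j+1), a')."""
    rewards = np.asarray(rewards, dtype=np.float32)
    terminal = np.asarray(terminal, dtype=bool)
    max_q_next = np.max(np.asarray(q_next, dtype=np.float32), axis=1)
    return np.where(terminal, rewards, rewards + gamma * max_q_next)

# toy minibatch of 3 transitions with 2 actions each
print(q_learning_targets(rewards=[1.0, 0.1, -1.0],
                         q_next=[[0.2, 0.5], [0.0, 0.3], [0.4, 0.1]],
                         terminal=[False, False, True]))
```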
53 | ## Experiments
54 |
55 | #### Environment
56 | Since the deep Q-network is trained on the raw pixel values observed from the game screen at each time step, [3] finds that removing the background that appears in the original game makes it converge faster. This process can be visualized as the following figure:
57 |
58 |
59 |
60 | #### Network Architecture
61 | According to [1], I first preprocess the game screens with the following steps (a short sketch follows the list):
62 |
63 | 1. Convert image to grayscale
64 | 2. Resize image to 80x80
65 | 3. Stack last 4 frames to produce an 80x80x4 input array for network
66 |
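To make these steps concrete, here is a small sketch that loosely mirrors the preprocessing in `deep_q_network.py` (which additionally binarizes the grayscale frame); the helper names are illustrative only:

```python
import cv2
import numpy as np

def preprocess_frame(frame):
    """Resize one raw frame to 80x80, convert to grayscale and binarize it."""
    gray = cv2.cvtColor(cv2.resize(frame, (80, 80)), cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
    return binary

def stack_frames(new_frame, prev_stack=None):
    """Keep the four most recent 80x80 frames as an 80x80x4 network input."""
    if prev_stack is None:
        return np.stack([new_frame] * 4, axis=2)
    return np.append(new_frame[:, :, None], prev_stack[:, :, :3], axis=2)
```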
67 | The architecture of the network is shown in the figure below. The first layer convolves the input image with an 8x8x4x32 kernel at a stride of 4, and the output is passed through a 2x2 max-pooling layer. The second layer convolves with a 4x4x32x64 kernel at a stride of 2, and the third layer convolves with a 3x3x64x64 kernel at a stride of 1 (in this repository's `deep_q_network.py` the max pooling after the second and third convolutions is commented out). The resulting feature map is flattened and fed into a last hidden layer of 512 fully connected ReLU nodes.
68 |
69 |
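For reference, a quick sanity check of the spatial sizes implied by this architecture (SAME padding, as in `deep_q_network.py`) explains the 1600-unit flatten in the code; this is a sketch, not code from the repo:

```python
import math

def same_conv(size, stride):
    """Output size of a SAME-padded convolution or pooling layer."""
    return math.ceil(size / stride)

size = 80
size = same_conv(size, 4)   # conv1, stride 4        -> 20
size = same_conv(size, 2)   # 2x2 max pool, stride 2 -> 10
size = same_conv(size, 2)   # conv2, stride 2        -> 5
size = same_conv(size, 1)   # conv3, stride 1        -> 5
print(size * size * 64)     # 5 * 5 * 64 = 1600 flattened features
```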
70 |
71 | The final output layer has the same dimensionality as the number of valid actions which can be performed in the game, where the 0th index always corresponds to doing nothing. The values at this output layer represent the Q function given the input state for each valid action. At each time step, the network performs whichever action corresponds to the highest Q value using a ϵ greedy policy.
72 |
73 |
74 | #### Training
75 | At first, I initialize all weight matrices randomly using a normal distribution with a standard deviation of 0.01, then set the replay memory to a maximum size of 50,000 experiences.
76 |
77 | I start training by choosing actions uniformly at random for the first 10,000 time steps, without updating the network weights. This allows the system to populate the replay memory before training begins.
78 |
79 | Note that unlike [1], which initializes ϵ = 1, I linearly anneal ϵ from 0.1 to 0.0001 over the course of the next 3,000,000 frames. The reason I set it this way is that the agent can choose an action every 0.03 s (FPS = 30) in our game, so a high ϵ makes it **flap** too much, which keeps it at the top of the game screen until it clumsily bumps into a pipe. This makes the Q function converge relatively slowly, since it only starts to explore other situations once ϵ is low.
80 | However, in other games, initializing ϵ to 1 is more reasonable.
81 |
82 | During training, at each time step the network samples minibatches of size 32 from the replay memory and performs a gradient step on the loss function described above, using the Adam optimizer with a learning rate of 0.000001. After annealing finishes, the network continues to train indefinitely with ϵ fixed at 0.0001.
83 |
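The ϵ schedule described above is a simple linear interpolation that only starts after the observation phase. A minimal sketch using the values from this section and the "How to reproduce?" FAQ below (note that the shipped `deep_q_network.py` defaults differ, since they target the pretrained model; the function name is illustrative):

```python
OBSERVE = 10000         # random-action steps before annealing/training starts
EXPLORE = 3000000       # frames over which epsilon is annealed
INITIAL_EPSILON = 0.1
FINAL_EPSILON = 0.0001

def epsilon_at(t):
    """Linearly anneal epsilon from INITIAL_EPSILON to FINAL_EPSILON after OBSERVE steps."""
    if t <= OBSERVE:
        return INITIAL_EPSILON
    steps = min(t - OBSERVE, EXPLORE)
    return INITIAL_EPSILON - (INITIAL_EPSILON - FINAL_EPSILON) * steps / EXPLORE
```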
84 | ## FAQ
85 |
86 | #### Checkpoint not found
87 | Change [first line of `saved_networks/checkpoint`](https://github.com/yenchenlin1994/DeepLearningFlappyBird/blob/master/saved_networks/checkpoint#L1) to
88 |
89 | `model_checkpoint_path: "saved_networks/bird-dqn-2920000"`
90 |
91 | #### How to reproduce?
92 | 1. Comment out [these lines](https://github.com/yenchenlin1994/DeepLearningFlappyBird/blob/master/deep_q_network.py#L108-L112)
93 |
94 | 2. Modify `deep_q_network.py`'s parameters as follows:
95 | ```python
96 | OBSERVE = 10000
97 | EXPLORE = 3000000
98 | FINAL_EPSILON = 0.0001
99 | INITIAL_EPSILON = 0.1
100 | ```
101 |
102 | ## References
103 |
104 | [1] Mnih Volodymyr, Koray Kavukcuoglu, David Silver, Andrei A. Rusu, Joel Veness, Marc G. Bellemare, Alex Graves, Martin Riedmiller, Andreas K. Fidjeland, Georg Ostrovski, Stig Petersen, Charles Beattie, Amir Sadik, Ioannis Antonoglou, Helen King, Dharshan Kumaran, Daan Wierstra, Shane Legg, and Demis Hassabis. **Human-level Control through Deep Reinforcement Learning**. Nature, 518(7540):529-533, 2015.
105 |
106 | [2] Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Alex Graves, Ioannis Antonoglou, Daan Wierstra, and Martin Riedmiller. **Playing Atari with Deep Reinforcement Learning**. NIPS Deep Learning Workshop, 2013.
107 |
108 | [3] Kevin Chen. **Deep Reinforcement Learning for Flappy Bird** [Report](http://cs229.stanford.edu/proj2015/362_report.pdf) | [Youtube result](https://youtu.be/9WKBzTUsPKc)
109 |
110 | ## Disclaimer
111 | This work is based heavily on the following repos:
112 |
113 | 1. [sourabhv/FlapPyBird](https://github.com/sourabhv/FlapPyBird)
114 | 2. [asrivat1/DeepLearningVideoGames](https://github.com/asrivat1/DeepLearningVideoGames)
115 |
116 |
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/die.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/die.ogg
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/die.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/die.wav
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/hit.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/hit.ogg
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/hit.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/hit.wav
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/point.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/point.ogg
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/point.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/point.wav
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/swoosh.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/swoosh.ogg
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/swoosh.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/swoosh.wav
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/wing.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/wing.ogg
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/wing.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/audio/wing.wav
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/0.png
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/1.png
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/2.png
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/3.png
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/4.png
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/5.png
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/6.png
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/7.png
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/8.png
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/9.png
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/background-black.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/background-black.png
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/base.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/base.png
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/pipe-green.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/pipe-green.png
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/redbird-downflap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/redbird-downflap.png
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/redbird-midflap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/redbird-midflap.png
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/redbird-upflap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/assets/sprites/redbird-upflap.png
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/deep_q_network.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from __future__ import print_function
3 |
4 | import tensorflow as tf
5 | import cv2
6 | import sys
7 | sys.path.append("game/")
8 | import game.wrapped_flappy_bird as game
9 | import random
10 | import numpy as np
11 | from collections import deque
12 |
13 | GAME = 'bird' # the name of the game being played for log files
14 | ACTIONS = 2 # number of valid actions
15 | GAMMA = 0.99 # decay rate of past observations
16 | OBSERVE = 100000. # timesteps to observe before training
17 | EXPLORE = 2000000. # frames over which to anneal epsilon
18 | FINAL_EPSILON = 0.0001 # final value of epsilon
19 | INITIAL_EPSILON = 0.0001 # starting value of epsilon
20 | REPLAY_MEMORY = 50000 # number of previous transitions to remember
21 | BATCH = 32 # size of minibatch
22 | FRAME_PER_ACTION = 1
23 |
24 | def weight_variable(shape):
25 | initial = tf.truncated_normal(shape, stddev = 0.01)
26 | return tf.Variable(initial)
27 |
28 | def bias_variable(shape):
29 | initial = tf.constant(0.01, shape = shape)
30 | return tf.Variable(initial)
31 |
32 | def conv2d(x, W, stride):
33 | return tf.nn.conv2d(x, W, strides = [1, stride, stride, 1], padding = "SAME")
34 |
35 | def max_pool_2x2(x):
36 | return tf.nn.max_pool(x, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = "SAME")
37 |
38 | def createNetwork():
39 | # network weights
40 | W_conv1 = weight_variable([8, 8, 4, 32])
41 | b_conv1 = bias_variable([32])
42 |
43 | W_conv2 = weight_variable([4, 4, 32, 64])
44 | b_conv2 = bias_variable([64])
45 |
46 | W_conv3 = weight_variable([3, 3, 64, 64])
47 | b_conv3 = bias_variable([64])
48 |
49 | W_fc1 = weight_variable([1600, 512])
50 | b_fc1 = bias_variable([512])
51 |
52 | W_fc2 = weight_variable([512, ACTIONS])
53 | b_fc2 = bias_variable([ACTIONS])
54 |
55 | # input layer
56 | s = tf.placeholder("float", [None, 80, 80, 4])
57 |
58 | # hidden layers
59 | h_conv1 = tf.nn.relu(conv2d(s, W_conv1, 4) + b_conv1)
60 | h_pool1 = max_pool_2x2(h_conv1)
61 |
62 | h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2, 2) + b_conv2)
63 | #h_pool2 = max_pool_2x2(h_conv2)
64 |
65 | h_conv3 = tf.nn.relu(conv2d(h_conv2, W_conv3, 1) + b_conv3)
66 | #h_pool3 = max_pool_2x2(h_conv3)
67 |
68 | #h_pool3_flat = tf.reshape(h_pool3, [-1, 256])
69 | h_conv3_flat = tf.reshape(h_conv3, [-1, 1600])
70 |
71 | h_fc1 = tf.nn.relu(tf.matmul(h_conv3_flat, W_fc1) + b_fc1)
72 |
73 | # readout layer
74 | readout = tf.matmul(h_fc1, W_fc2) + b_fc2
75 |
76 | return s, readout, h_fc1
77 |
78 | def trainNetwork(s, readout, h_fc1, sess):
79 | # define the cost function
80 | a = tf.placeholder("float", [None, ACTIONS])
81 | y = tf.placeholder("float", [None])
82 | readout_action = tf.reduce_sum(tf.multiply(readout, a), reduction_indices=1)
83 | cost = tf.reduce_mean(tf.square(y - readout_action))
84 | train_step = tf.train.AdamOptimizer(1e-6).minimize(cost)
85 |
86 |     # open up a game state to communicate with the emulator
87 | game_state = game.GameState()
88 |
89 | # store the previous observations in replay memory
90 | D = deque()
91 |
92 | # printing
93 | a_file = open("logs_" + GAME + "/readout.txt", 'w')
94 | h_file = open("logs_" + GAME + "/hidden.txt", 'w')
95 |
96 |     # get the first state by doing nothing and preprocess the image to 80x80x4
97 | do_nothing = np.zeros(ACTIONS)
98 | do_nothing[0] = 1
99 | x_t, r_0, terminal = game_state.frame_step(do_nothing)
100 | x_t = cv2.cvtColor(cv2.resize(x_t, (80, 80)), cv2.COLOR_BGR2GRAY)
101 | ret, x_t = cv2.threshold(x_t,1,255,cv2.THRESH_BINARY)
102 | s_t = np.stack((x_t, x_t, x_t, x_t), axis=2)
103 |
104 | # saving and loading networks
105 | saver = tf.train.Saver()
106 | sess.run(tf.initialize_all_variables())
107 | checkpoint = tf.train.get_checkpoint_state("saved_networks")
108 | if checkpoint and checkpoint.model_checkpoint_path:
109 | saver.restore(sess, checkpoint.model_checkpoint_path)
110 | print("Successfully loaded:", checkpoint.model_checkpoint_path)
111 | else:
112 | print("Could not find old network weights")
113 |
114 | # start training
115 | epsilon = INITIAL_EPSILON
116 | t = 0
117 | while "flappy bird" != "angry bird":
118 | # choose an action epsilon greedily
119 | readout_t = readout.eval(feed_dict={s : [s_t]})[0]
120 | a_t = np.zeros([ACTIONS])
121 | action_index = 0
122 |
123 | if t % FRAME_PER_ACTION == 0:
124 | if random.random() <= epsilon:
125 | print("----------Random Action----------")
126 | action_index = random.randrange(ACTIONS)
127 | a_t[random.randrange(ACTIONS)] = 1
128 | else:
129 | action_index = np.argmax(readout_t)
130 | a_t[action_index] = 1
131 | else:
132 | a_t[0] = 1 # do nothing
133 |
134 | # scale down epsilon
135 | if epsilon > FINAL_EPSILON and t > OBSERVE:
136 | epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE
137 |
138 |         # run the selected action and observe the next state and reward
139 | x_t1_colored, r_t, terminal = game_state.frame_step(a_t)
140 | x_t1 = cv2.cvtColor(cv2.resize(x_t1_colored, (80, 80)), cv2.COLOR_BGR2GRAY)
141 | ret, x_t1 = cv2.threshold(x_t1, 1, 255, cv2.THRESH_BINARY)
142 | x_t1 = np.reshape(x_t1, (80, 80, 1))
143 | #s_t1 = np.append(x_t1, s_t[:,:,1:], axis = 2)
144 | s_t1 = np.append(x_t1, s_t[:, :, :3], axis=2)
145 |
146 | # store the transition in D
147 | D.append((s_t, a_t, r_t, s_t1, terminal))
148 | if len(D) > REPLAY_MEMORY:
149 | D.popleft()
150 |
151 | # only train if done observing
152 | if t > OBSERVE:
153 | # sample a minibatch to train on
154 | minibatch = random.sample(D, BATCH)
155 |
156 | # get the batch variables
157 | s_j_batch = [d[0] for d in minibatch]
158 | a_batch = [d[1] for d in minibatch]
159 | r_batch = [d[2] for d in minibatch]
160 | s_j1_batch = [d[3] for d in minibatch]
161 |
162 | y_batch = []
163 | readout_j1_batch = readout.eval(feed_dict = {s : s_j1_batch})
164 | for i in range(0, len(minibatch)):
165 | terminal = minibatch[i][4]
166 |                 # if terminal, the target is just the reward
167 | if terminal:
168 | y_batch.append(r_batch[i])
169 | else:
170 | y_batch.append(r_batch[i] + GAMMA * np.max(readout_j1_batch[i]))
171 |
172 | # perform gradient step
173 | train_step.run(feed_dict = {
174 | y : y_batch,
175 | a : a_batch,
176 | s : s_j_batch}
177 | )
178 |
179 | # update the old values
180 | s_t = s_t1
181 | t += 1
182 |
183 | # save progress every 10000 iterations
184 | if t % 10000 == 0:
185 | saver.save(sess, 'saved_networks/' + GAME + '-dqn', global_step = t)
186 |
187 | # print info
188 | state = ""
189 | if t <= OBSERVE:
190 | state = "observe"
191 | elif t > OBSERVE and t <= OBSERVE + EXPLORE:
192 | state = "explore"
193 | else:
194 | state = "train"
195 |
196 | print("TIMESTEP", t, "/ STATE", state, \
197 | "/ EPSILON", epsilon, "/ ACTION", action_index, "/ REWARD", r_t, \
198 | "/ Q_MAX %e" % np.max(readout_t))
199 | # write info to files
200 | '''
201 | if t % 10000 <= 100:
202 | a_file.write(",".join([str(x) for x in readout_t]) + '\n')
203 | h_file.write(",".join([str(x) for x in h_fc1.eval(feed_dict={s:[s_t]})[0]]) + '\n')
204 | cv2.imwrite("logs_tetris/frame" + str(t) + ".png", x_t1)
205 | '''
206 |
207 | def playGame():
208 | sess = tf.InteractiveSession()
209 | s, readout, h_fc1 = createNetwork()
210 | trainNetwork(s, readout, h_fc1, sess)
211 |
212 | def main():
213 | playGame()
214 |
215 | if __name__ == "__main__":
216 | main()
217 |
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/game/flappy_bird_utils.py:
--------------------------------------------------------------------------------
1 | import pygame
2 | import sys
3 | def load():
4 |     # paths of the player sprites for the different flap states
5 | PLAYER_PATH = (
6 | 'assets/sprites/redbird-upflap.png',
7 | 'assets/sprites/redbird-midflap.png',
8 | 'assets/sprites/redbird-downflap.png'
9 | )
10 |
11 | # path of background
12 | BACKGROUND_PATH = 'assets/sprites/background-black.png'
13 |
14 | # path of pipe
15 | PIPE_PATH = 'assets/sprites/pipe-green.png'
16 |
17 | IMAGES, SOUNDS, HITMASKS = {}, {}, {}
18 |
19 | # numbers sprites for score display
20 | IMAGES['numbers'] = (
21 | pygame.image.load('assets/sprites/0.png').convert_alpha(),
22 | pygame.image.load('assets/sprites/1.png').convert_alpha(),
23 | pygame.image.load('assets/sprites/2.png').convert_alpha(),
24 | pygame.image.load('assets/sprites/3.png').convert_alpha(),
25 | pygame.image.load('assets/sprites/4.png').convert_alpha(),
26 | pygame.image.load('assets/sprites/5.png').convert_alpha(),
27 | pygame.image.load('assets/sprites/6.png').convert_alpha(),
28 | pygame.image.load('assets/sprites/7.png').convert_alpha(),
29 | pygame.image.load('assets/sprites/8.png').convert_alpha(),
30 | pygame.image.load('assets/sprites/9.png').convert_alpha()
31 | )
32 |
33 | # base (ground) sprite
34 | IMAGES['base'] = pygame.image.load('assets/sprites/base.png').convert_alpha()
35 |
36 | # sounds
37 | if 'win' in sys.platform:
38 | soundExt = '.wav'
39 | else:
40 | soundExt = '.ogg'
41 |
42 | SOUNDS['die'] = pygame.mixer.Sound('assets/audio/die' + soundExt)
43 | SOUNDS['hit'] = pygame.mixer.Sound('assets/audio/hit' + soundExt)
44 | SOUNDS['point'] = pygame.mixer.Sound('assets/audio/point' + soundExt)
45 | SOUNDS['swoosh'] = pygame.mixer.Sound('assets/audio/swoosh' + soundExt)
46 | SOUNDS['wing'] = pygame.mixer.Sound('assets/audio/wing' + soundExt)
47 |
48 |     # background sprite
49 | IMAGES['background'] = pygame.image.load(BACKGROUND_PATH).convert()
50 |
51 |     # player sprites
52 | IMAGES['player'] = (
53 | pygame.image.load(PLAYER_PATH[0]).convert_alpha(),
54 | pygame.image.load(PLAYER_PATH[1]).convert_alpha(),
55 | pygame.image.load(PLAYER_PATH[2]).convert_alpha(),
56 | )
57 |
58 |     # pipe sprites (rotated copy for the upper pipe, original for the lower pipe)
59 | IMAGES['pipe'] = (
60 | pygame.transform.rotate(
61 | pygame.image.load(PIPE_PATH).convert_alpha(), 180),
62 | pygame.image.load(PIPE_PATH).convert_alpha(),
63 | )
64 |
65 |     # hitmask for pipes
66 | HITMASKS['pipe'] = (
67 | getHitmask(IMAGES['pipe'][0]),
68 | getHitmask(IMAGES['pipe'][1]),
69 | )
70 |
71 | # hitmask for player
72 | HITMASKS['player'] = (
73 | getHitmask(IMAGES['player'][0]),
74 | getHitmask(IMAGES['player'][1]),
75 | getHitmask(IMAGES['player'][2]),
76 | )
77 |
78 | return IMAGES, SOUNDS, HITMASKS
79 |
80 | def getHitmask(image):
81 | """returns a hitmask using an image's alpha."""
82 | mask = []
83 | for x in range(image.get_width()):
84 | mask.append([])
85 | for y in range(image.get_height()):
86 | mask[x].append(bool(image.get_at((x,y))[3]))
87 | return mask
88 |
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/game/wrapped_flappy_bird.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sys
3 | import random
4 | import pygame
5 | import flappy_bird_utils
6 | import pygame.surfarray as surfarray
7 | from pygame.locals import *
8 | from itertools import cycle
9 |
10 | FPS = 30
11 | SCREENWIDTH = 288
12 | SCREENHEIGHT = 512
13 |
14 | pygame.init()
15 | FPSCLOCK = pygame.time.Clock()
16 | SCREEN = pygame.display.set_mode((SCREENWIDTH, SCREENHEIGHT))
17 | pygame.display.set_caption('Flappy Bird')
18 |
19 | IMAGES, SOUNDS, HITMASKS = flappy_bird_utils.load()
20 | PIPEGAPSIZE = 100 # gap between upper and lower part of pipe
21 | BASEY = SCREENHEIGHT * 0.79
22 |
23 | PLAYER_WIDTH = IMAGES['player'][0].get_width()
24 | PLAYER_HEIGHT = IMAGES['player'][0].get_height()
25 | PIPE_WIDTH = IMAGES['pipe'][0].get_width()
26 | PIPE_HEIGHT = IMAGES['pipe'][0].get_height()
27 | BACKGROUND_WIDTH = IMAGES['background'].get_width()
28 |
29 | PLAYER_INDEX_GEN = cycle([0, 1, 2, 1])
30 |
31 |
32 | class GameState:
33 | def __init__(self):
34 | self.score = self.playerIndex = self.loopIter = 0
35 | self.playerx = int(SCREENWIDTH * 0.2)
36 | self.playery = int((SCREENHEIGHT - PLAYER_HEIGHT) / 2)
37 | self.basex = 0
38 | self.baseShift = IMAGES['base'].get_width() - BACKGROUND_WIDTH
39 |
40 | newPipe1 = getRandomPipe()
41 | newPipe2 = getRandomPipe()
42 | self.upperPipes = [
43 | {'x': SCREENWIDTH, 'y': newPipe1[0]['y']},
44 | {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[0]['y']},
45 | ]
46 | self.lowerPipes = [
47 | {'x': SCREENWIDTH, 'y': newPipe1[1]['y']},
48 | {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[1]['y']},
49 | ]
50 |
51 |         # player velocity, max velocity, downward acceleration, acceleration on flap
52 | self.pipeVelX = -4
53 | self.playerVelY = 0 # player's velocity along Y, default same as playerFlapped
54 | self.playerMaxVelY = 10 # max vel along Y, max descend speed
55 | self.playerMinVelY = -8 # min vel along Y, max ascend speed
56 |         self.playerAccY    =   1   # player's downward acceleration
57 |         self.playerFlapAcc =  -9   # player's speed on flapping
58 | self.playerFlapped = False # True when player flaps
59 |
60 | def frame_step(self, input_actions):
61 | pygame.event.pump()
62 |
63 | reward = 0.1
64 | terminal = False
65 |
66 | if sum(input_actions) != 1:
67 |             raise ValueError('Exactly one input action must be selected!')
68 |
69 | # input_actions[0] == 1: do nothing
70 | # input_actions[1] == 1: flap the bird
71 | if input_actions[1] == 1:
72 | if self.playery > -2 * PLAYER_HEIGHT:
73 | self.playerVelY = self.playerFlapAcc
74 | self.playerFlapped = True
75 | #SOUNDS['wing'].play()
76 |
77 | # check for score
78 | playerMidPos = self.playerx + PLAYER_WIDTH / 2
79 | for pipe in self.upperPipes:
80 | pipeMidPos = pipe['x'] + PIPE_WIDTH / 2
81 | if pipeMidPos <= playerMidPos < pipeMidPos + 4:
82 | self.score += 1
83 | #SOUNDS['point'].play()
84 | reward = 1
85 |
86 | # playerIndex basex change
87 | if (self.loopIter + 1) % 3 == 0:
88 | self.playerIndex = next(PLAYER_INDEX_GEN)
89 | self.loopIter = (self.loopIter + 1) % 30
90 | self.basex = -((-self.basex + 100) % self.baseShift)
91 |
92 | # player's movement
93 | if self.playerVelY < self.playerMaxVelY and not self.playerFlapped:
94 | self.playerVelY += self.playerAccY
95 | if self.playerFlapped:
96 | self.playerFlapped = False
97 | self.playery += min(self.playerVelY, BASEY - self.playery - PLAYER_HEIGHT)
98 | if self.playery < 0:
99 | self.playery = 0
100 |
101 | # move pipes to left
102 | for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes):
103 | uPipe['x'] += self.pipeVelX
104 | lPipe['x'] += self.pipeVelX
105 |
106 | # add new pipe when first pipe is about to touch left of screen
107 | if 0 < self.upperPipes[0]['x'] < 5:
108 | newPipe = getRandomPipe()
109 | self.upperPipes.append(newPipe[0])
110 | self.lowerPipes.append(newPipe[1])
111 |
112 |         # remove the first pipe if it is out of the screen
113 | if self.upperPipes[0]['x'] < -PIPE_WIDTH:
114 | self.upperPipes.pop(0)
115 | self.lowerPipes.pop(0)
116 |
117 | # check if crash here
118 |         isCrash = checkCrash({'x': self.playerx, 'y': self.playery,
119 | 'index': self.playerIndex},
120 | self.upperPipes, self.lowerPipes)
121 | if isCrash:
122 | #SOUNDS['hit'].play()
123 | #SOUNDS['die'].play()
124 | terminal = True
125 | self.__init__()
126 | reward = -1
127 |
128 | # draw sprites
129 | SCREEN.blit(IMAGES['background'], (0,0))
130 |
131 | for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes):
132 | SCREEN.blit(IMAGES['pipe'][0], (uPipe['x'], uPipe['y']))
133 | SCREEN.blit(IMAGES['pipe'][1], (lPipe['x'], lPipe['y']))
134 |
135 | SCREEN.blit(IMAGES['base'], (self.basex, BASEY))
136 | # print score so player overlaps the score
137 | # showScore(self.score)
138 | SCREEN.blit(IMAGES['player'][self.playerIndex],
139 | (self.playerx, self.playery))
140 |
141 | image_data = pygame.surfarray.array3d(pygame.display.get_surface())
142 | pygame.display.update()
143 | FPSCLOCK.tick(FPS)
144 | #print self.upperPipes[0]['y'] + PIPE_HEIGHT - int(BASEY * 0.2)
145 | return image_data, reward, terminal
146 |
147 | def getRandomPipe():
148 | """returns a randomly generated pipe"""
149 | # y of gap between upper and lower pipe
150 | gapYs = [20, 30, 40, 50, 60, 70, 80, 90]
151 | index = random.randint(0, len(gapYs)-1)
152 | gapY = gapYs[index]
153 |
154 | gapY += int(BASEY * 0.2)
155 | pipeX = SCREENWIDTH + 10
156 |
157 | return [
158 | {'x': pipeX, 'y': gapY - PIPE_HEIGHT}, # upper pipe
159 | {'x': pipeX, 'y': gapY + PIPEGAPSIZE}, # lower pipe
160 | ]
161 |
162 |
163 | def showScore(score):
164 | """displays score in center of screen"""
165 | scoreDigits = [int(x) for x in list(str(score))]
166 | totalWidth = 0 # total width of all numbers to be printed
167 |
168 | for digit in scoreDigits:
169 | totalWidth += IMAGES['numbers'][digit].get_width()
170 |
171 | Xoffset = (SCREENWIDTH - totalWidth) / 2
172 |
173 | for digit in scoreDigits:
174 | SCREEN.blit(IMAGES['numbers'][digit], (Xoffset, SCREENHEIGHT * 0.1))
175 | Xoffset += IMAGES['numbers'][digit].get_width()
176 |
177 |
178 | def checkCrash(player, upperPipes, lowerPipes):
179 | """returns True if player collders with base or pipes."""
180 | pi = player['index']
181 | player['w'] = IMAGES['player'][0].get_width()
182 | player['h'] = IMAGES['player'][0].get_height()
183 |
184 | # if player crashes into ground
185 | if player['y'] + player['h'] >= BASEY - 1:
186 | return True
187 | else:
188 |
189 | playerRect = pygame.Rect(player['x'], player['y'],
190 | player['w'], player['h'])
191 |
192 | for uPipe, lPipe in zip(upperPipes, lowerPipes):
193 | # upper and lower pipe rects
194 | uPipeRect = pygame.Rect(uPipe['x'], uPipe['y'], PIPE_WIDTH, PIPE_HEIGHT)
195 | lPipeRect = pygame.Rect(lPipe['x'], lPipe['y'], PIPE_WIDTH, PIPE_HEIGHT)
196 |
197 | # player and upper/lower pipe hitmasks
198 | pHitMask = HITMASKS['player'][pi]
199 | uHitmask = HITMASKS['pipe'][0]
200 | lHitmask = HITMASKS['pipe'][1]
201 |
202 | # if bird collided with upipe or lpipe
203 | uCollide = pixelCollision(playerRect, uPipeRect, pHitMask, uHitmask)
204 | lCollide = pixelCollision(playerRect, lPipeRect, pHitMask, lHitmask)
205 |
206 | if uCollide or lCollide:
207 | return True
208 |
209 | return False
210 |
211 | def pixelCollision(rect1, rect2, hitmask1, hitmask2):
212 | """Checks if two objects collide and not just their rects"""
213 | rect = rect1.clip(rect2)
214 |
215 | if rect.width == 0 or rect.height == 0:
216 | return False
217 |
218 | x1, y1 = rect.x - rect1.x, rect.y - rect1.y
219 | x2, y2 = rect.x - rect2.x, rect.y - rect2.y
220 |
221 | for x in range(rect.width):
222 | for y in range(rect.height):
223 | if hitmask1[x1+x][y1+y] and hitmask2[x2+x][y2+y]:
224 | return True
225 | return False
226 |
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/images/flappy_bird_demp.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/images/flappy_bird_demp.gif
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/images/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/images/network.png
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/images/preprocess.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/images/preprocess.png
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/logs_bird/hidden.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/logs_bird/hidden.txt
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/logs_bird/readout.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/logs_bird/readout.txt
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2880000:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2880000
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2880000.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2880000.meta
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2890000:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2890000
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2890000.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2890000.meta
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2900000:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2900000
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2900000.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2900000.meta
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2910000:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2910000
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2910000.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2910000.meta
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2920000:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2920000
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2920000.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/bird-dqn-2920000.meta
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "bird-dqn-2920000"
2 | all_model_checkpoint_paths: "bird-dqn-2880000"
3 | all_model_checkpoint_paths: "bird-dqn-2890000"
4 | all_model_checkpoint_paths: "bird-dqn-2900000"
5 | all_model_checkpoint_paths: "bird-dqn-2910000"
6 | all_model_checkpoint_paths: "bird-dqn-2920000"
7 |
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/pretrained_model/bird-dqn-policy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_learning_flappy_bird/saved_networks/pretrained_model/bird-dqn-policy
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_q_network_algortihm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuliquan/reinforcement_learning_basic_book/b4bbcb11d5ed74a4b1f9e071d6af7744275f2b32/6-value_function_approximate/deep_q_network_algortihm.png
--------------------------------------------------------------------------------
/6-value_function_approximate/deep_q_network_template.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # author : zlq16
4 | # date : 2017/10/31
5 | import tensorflow as tf
6 | import cv2
7 | import random
8 | import numpy as np
9 | from collections import deque
10 |
11 | # Definition of the Game class. What the game actually is does not matter here, as long as it provides a method to execute an action and a way to get the current game-area pixels.
12 | class Game(object):
13 |     def __init__(self): # initialize the game
14 |         # the actions are MOVE_STAY, MOVE_LEFT and MOVE_RIGHT
15 |         # the AI moves the paddle left/right; the game returns the screen pixels and the corresponding reward (pixels -> reward -> reinforce moving the paddle toward higher reward)
16 | pass
17 | def step(self, action):
18 | pass
19 | # discount factor for future rewards
20 | GAMMA = 0.99
21 | # epsilon-greedy exploration: initial and final epsilon
22 | INITIAL_EPSILON = 1.0
23 | FINAL_EPSILON = 0.05
24 | # exploration annealing length and observation period before training
25 | EXPLORE = 500000
26 | OBSERVE = 500
27 | # size of the replay memory
28 | REPLAY_MEMORY = 500000
29 | # number of transitions sampled per training step
30 | BATCH = 100
31 | # number of output neurons, one per action - MOVE_STAY:[1, 0, 0] MOVE_LEFT:[0, 1, 0] MOVE_RIGHT:[0, 0, 1]
32 | output = 3
33 | MOVE_STAY =[1, 0, 0]
34 | MOVE_LEFT =[0, 1, 0]
35 | MOVE_RIGHT=[0, 0, 1]
36 | input_image = tf.placeholder("float", [None, 80, 100, 4])  # game screen pixels (stack of 4 frames)
37 | action = tf.placeholder("float", [None, output])           # selected action (one-hot)
38 |
39 | # define the CNN (convolutional neural network)
40 | def convolutional_neural_network(input_image):
41 | weights = {'w_conv1':tf.Variable(tf.zeros([8, 8, 4, 32])),
42 | 'w_conv2':tf.Variable(tf.zeros([4, 4, 32, 64])),
43 | 'w_conv3':tf.Variable(tf.zeros([3, 3, 64, 64])),
44 | 'w_fc4':tf.Variable(tf.zeros([3456, 784])),
45 | 'w_out':tf.Variable(tf.zeros([784, output]))}
46 |
47 | biases = {'b_conv1':tf.Variable(tf.zeros([32])),
48 | 'b_conv2':tf.Variable(tf.zeros([64])),
49 | 'b_conv3':tf.Variable(tf.zeros([64])),
50 | 'b_fc4':tf.Variable(tf.zeros([784])),
51 | 'b_out':tf.Variable(tf.zeros([output]))}
52 |
53 | conv1 = tf.nn.relu(tf.nn.conv2d(input_image, weights['w_conv1'], strides = [1, 4, 4, 1], padding = "VALID") + biases['b_conv1'])
54 | conv2 = tf.nn.relu(tf.nn.conv2d(conv1, weights['w_conv2'], strides = [1, 2, 2, 1], padding = "VALID") + biases['b_conv2'])
55 | conv3 = tf.nn.relu(tf.nn.conv2d(conv2, weights['w_conv3'], strides = [1, 1, 1, 1], padding = "VALID") + biases['b_conv3'])
56 | conv3_flat = tf.reshape(conv3, [-1, 3456])
57 | fc4 = tf.nn.relu(tf.matmul(conv3_flat, weights['w_fc4']) + biases['b_fc4'])
58 |
59 | output_layer = tf.matmul(fc4, weights['w_out']) + biases['b_out']
60 | return output_layer
61 |
62 | # train the neural network
63 | def train_neural_network(input_image):
64 | argmax = tf.placeholder("float", [None, output])
65 | gt = tf.placeholder("float", [None])
66 |
67 |     # loss function
68 | predict_action = convolutional_neural_network(input_image)
69 |     action = tf.reduce_sum(tf.multiply(predict_action, argmax), reduction_indices = 1) # Q(S,a) of the chosen action
70 | cost = tf.reduce_mean(tf.square(action - gt))
71 | optimizer = tf.train.AdamOptimizer(1e-6).minimize(cost)
72 |
73 |     # start the game
74 | game = Game()
75 | D = deque()
76 | _, image = game.step(MOVE_STAY)
77 | image = cv2.cvtColor(cv2.resize(image, (100, 80)), cv2.COLOR_BGR2GRAY)
78 | ret, image = cv2.threshold(image, 1, 255, cv2.THRESH_BINARY)
79 | input_image_data = np.stack((image, image, image, image), axis = 2)
80 |
81 | with tf.Session() as sess:
82 |         # initialize all network parameters
83 | sess.run(tf.initialize_all_variables())
84 |         # saver for checkpointing the network parameters
85 | saver = tf.train.Saver()
86 |
87 |         # total number of steps run
88 | n = 0
89 | epsilon = INITIAL_EPSILON
90 | while True:
91 |
92 |             # the network outputs the Q(S,:) values
93 | action_t = predict_action.eval(feed_dict = {input_image : [input_image_data]})[0]
94 |             argmax_t = np.zeros([output], dtype=int)
95 |
96 |             # epsilon-greedy action selection
97 |             if random.random() <= epsilon:
98 | maxIndex = random.randrange(output)
99 | else:
100 | maxIndex = np.argmax(action_t)
101 |
102 |             # encode the chosen action as a one-hot vector
103 | argmax_t[maxIndex] = 1
104 |
105 |             # anneal epsilon so the policy becomes increasingly greedy
106 | if epsilon > FINAL_EPSILON:
107 | epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE
108 |
109 |             # apply the chosen action to the environment and observe the next state S' and the reward
110 | reward, image = game.step(list(argmax_t))
111 |
112 |             # preprocess the returned frame so it can be fed to the network as input
113 | image = cv2.cvtColor(cv2.resize(image, (100, 80)), cv2.COLOR_BGR2GRAY)
114 | ret, image = cv2.threshold(image, 1, 255, cv2.THRESH_BINARY)
115 | image = np.reshape(image, (80, 100, 1))
116 | input_image_data1 = np.append(image, input_image_data[:, :, 0:3], axis = 2)
117 |
118 |             # store the transition (S, a, r, S') in the replay memory
119 | D.append((input_image_data, argmax_t, reward, input_image_data1))
120 |
121 |             # the replay memory has a limited capacity
122 | if len(D) > REPLAY_MEMORY:
123 | D.popleft()
124 |
125 |             # start training the network once the observation period is over
126 | if n > OBSERVE:
127 |
128 |                 # randomly sample a minibatch of memories for training
129 |                 minibatch = random.sample(D, BATCH)
130 |                 # extract the state S from each memory
131 |                 input_image_data_batch = [d[0] for d in minibatch]
132 |                 # extract the action a from each memory
133 |                 argmax_batch = [d[1] for d in minibatch]
134 |                 # extract the reward r from each memory
135 |                 reward_batch = [d[2] for d in minibatch]
136 |                 # extract the next state S' from each memory
137 | input_image_data1_batch = [d[3] for d in minibatch]
138 |
139 | gt_batch = []
140 |                 # use the current network to compute Q(S',:)
141 | out_batch = predict_action.eval(feed_dict = {input_image : input_image_data1_batch})
142 |
143 |                 # Bellman backup for the long-term return: r + γ * max(Q(S',:)) (terminal states are not treated specially in this template)
144 | for i in range(0, len(minibatch)):
145 | gt_batch.append(reward_batch[i] + GAMMA * np.max(out_batch[i]))
146 |
147 |                 # run the optimizer defined above to update the network parameters
148 | print("gt_batch:", gt_batch, "argmax:", argmax_batch)
149 | optimizer.run(feed_dict = {gt : gt_batch, argmax : argmax_batch, input_image : input_image_data_batch})
150 |
151 | input_image_data = input_image_data1
152 | n = n+1
153 | print(n, "epsilon:", epsilon, " ", "action:", maxIndex, " ", "_reward:", reward)
154 |
155 | train_neural_network(input_image)
156 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Code Description
2 | ## Overview
3 | > This is the code repository from my study of the book 《深入浅出强化学习-原理入门》. It mainly contains the examples from the book and the code for the exercises at the end of each chapter.
4 | ## Contents
5 | ### 1-Gym secondary development (gym developing)
6 | 1. [File configuration for extending gym](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/1-gym_developing/README.md)
7 | 2. [Modified core.py for gym](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/1-gym_developing/core.py)
8 | 3. [A grid game example built on gym](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/1-gym_developing/grid_game.py)
9 | 4. [A maze game example built on gym](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/1-gym_developing/maze_game.py)
10 | ### 2-Markov Decision Process
11 | 1. [A "student life" example](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/2-markov_decision_process/our_life.py)
12 | 2. [Exercise: simulating the maze environment](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/2-markov_decision_process/game.py)
13 | ### 3-Dynamic Programming
14 | 1. [Policy evaluation of the grid game under a uniform policy](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/3-dynamic_program/grid_game_with_average_policy.py)
15 | 2. [Flow chart of the policy iteration algorithm](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/3-dynamic_program/policy_iteration_algorithm.png)
16 | 3. [Policy iteration for the grid game under a greedy policy](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/3-dynamic_program/grid_game_with_policy_iterate.py)
17 | 4. [Flow chart of the value iteration algorithm](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/3-dynamic_program/value_iteration_algorithm.png)
18 | 5. [Value iteration for the grid game under a greedy policy](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/3-dynamic_program/grid_game_with_value_iterate.py)
19 | 6. [Exercise: the maze game solved with dynamic programming](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/3-dynamic_program/maze_game_with_dynamic_program.py)
20 | ### 4-Monte Carlo
21 | 1. [Monte Carlo sampling](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/4-monte_carlo/monte_carlo_sample.py)
22 | 2. [Monte Carlo evaluation](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/4-monte_carlo/monte_carlo_evaluate.py)
23 | ### 5-Temporal Difference
24 | 1. [Flow chart of the Q-learning algorithm](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/5-temporal_difference/q_learning_algortihm.png)
25 | 2. [Flow chart of the Sarsa algorithm](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/5-temporal_difference/sarsa_algorithm.png)
26 | 3. [Flow chart of the Sarsa(λ) algorithm](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/5-temporal_difference/sarsa_lambda_algorithm.png)
27 | 4. [A push-box (Sokoban) game example built on gym](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/5-temporal_difference/push_box_game.py)
28 | 5. [Learning the push-box game with temporal difference methods](https://github.com/zhuliquan/reinforcement_learning_basic_book/tree/master/5-temporal_difference/push_box_game)
29 | ### 6-Value Function Approximation
30 | 1. [Flow chart of the Deep Q-learning algorithm](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/6-value_function_approximate/deep_q_network_algortihm.png)
31 | 2. [Deep Q-learning algorithm template](https://github.com/zhuliquan/reinforcement_learning_basic_book/blob/master/6-value_function_approximate/deep_q_network_template.py)
32 | 3. [A Flappy Bird game trained with Deep Q-learning](https://github.com/zhuliquan/reinforcement_learning_basic_book/tree/master/6-value_function_approximate/deep_learning_flappy_bird)
33 |
--------------------------------------------------------------------------------