├── .gitignore
├── AlphaZero
│   ├── current_policy_0.model
│   ├── current_policy_1.model
│   ├── entropy.npy
│   ├── game.py
│   ├── loss.npy
│   ├── main.py
│   ├── mcts_alphaZero.py
│   ├── play.py
│   ├── policy_value_net_pytorch.py
│   └── train.py
├── README.md
├── ch_0
│   ├── sec_1
│   │   ├── kb_game.py
│   │   └── main.py
│   └── sec_2
│       ├── images
│       │   ├── background.jpg
│       │   ├── bird_female.png
│       │   ├── bird_male.png
│       │   └── obstacle.png
│       ├── load_images.py
│       └── yuan_yang_env.py
├── ch_1
│   ├── sec_3
│   │   ├── dp_policy_iter.py
│   │   ├── dp_value_iter.py
│   │   ├── load_images.py
│   │   └── yuan_yang_env.py
│   ├── sec_4
│   │   ├── load_images.py
│   │   ├── mc_rl.py
│   │   └── yuan_yang_env_mc.py
│   ├── sec_5
│   │   ├── TD_RL.py
│   │   ├── load_images.py
│   │   └── yuan_yang_env_td.py
│   └── sec_6
│       ├── flappy_bird
│       │   ├── assets
│       │   │   ├── audio
│       │   │   │   ├── die.ogg
│       │   │   │   ├── die.wav
│       │   │   │   ├── hit.ogg
│       │   │   │   ├── hit.wav
│       │   │   │   ├── point.ogg
│       │   │   │   ├── point.wav
│       │   │   │   ├── swoosh.ogg
│       │   │   │   ├── swoosh.wav
│       │   │   │   ├── wing.ogg
│       │   │   │   └── wing.wav
│       │   │   └── sprites
│       │   │       ├── 0.png
│       │   │       ├── 1.png
│       │   │       ├── 2.png
│       │   │       ├── 3.png
│       │   │       ├── 4.png
│       │   │       ├── 5.png
│       │   │       ├── 6.png
│       │   │       ├── 7.png
│       │   │       ├── 8.png
│       │   │       ├── 9.png
│       │   │       ├── background-black.png
│       │   │       ├── base.png
│       │   │       ├── pipe-green.png
│       │   │       ├── redbird-downflap.png
│       │   │       ├── redbird-midflap.png
│       │   │       └── redbird-upflap.png
│       │   ├── dqn_agent.py
│       │   ├── game
│       │   │   ├── flappy_bird_utils.py
│       │   │   ├── keyboard_agent.py
│       │   │   └── wrapped_flappy_bird.py
│       │   ├── images
│       │   │   ├── flappy_bird_demp.gif
│       │   │   ├── network.png
│       │   │   └── preprocess.png
│       │   ├── logs_bird
│       │   │   ├── hidden.txt
│       │   │   └── readout.txt
│       │   ├── saved_networks0
│       │   │   ├── saved_model.pb
│       │   │   └── variables
│       │   │       ├── variables.data-00000-of-00002
│       │   │       ├── variables.data-00001-of-00002
│       │   │       └── variables.index
│       │   └── saved_networks1
│       │       ├── saved_model.pb
│       │       └── variables
│       │           ├── variables.data-00000-of-00002
│       │           ├── variables.data-00001-of-00002
│       │           └── variables.index
│       ├── lfa_rl.py
│       ├── load_images.py
│       └── yuan_yang_env_fa.py
└── pyTorch_learn
    ├── a_1.py
    └── a_2.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 |
3 | /pyTorch_learn/data/
4 |
5 | *.psd
--------------------------------------------------------------------------------
/AlphaZero/current_policy_0.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/AlphaZero/current_policy_0.model
--------------------------------------------------------------------------------
/AlphaZero/current_policy_1.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/AlphaZero/current_policy_1.model
--------------------------------------------------------------------------------
/AlphaZero/entropy.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/AlphaZero/entropy.npy
--------------------------------------------------------------------------------
/AlphaZero/game.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class Board:
4 | def __init__(self, **kwargs):
5 | self.width = int(kwargs.get('width', 8))
6 | self.height = int(kwargs.get('height', 8))
7 | self.n_in_row = int(kwargs.get('n_in_row', 5))
8 | self.players = (1, 2)
9 | self.states = {}
10 |
11 | def init_board(self, start_player=0):
12 | if self.width < self.n_in_row or self.height < self.n_in_row:
13 |             raise Exception('board width and height can not be \
14 |                 less than {}'.format(self.n_in_row))
15 |         # current player number
16 | self.current_player = self.players[start_player]
17 | self.available = list(range(self.width * self.height))
18 | self.states = {}
19 | self.last_move = -1
20 |
21 | def move_to_location(self, move):
22 | """
23 | 3*3 board's moves like:
24 | 6 7 8
25 | 3 4 5
26 | 0 1 2
27 | and move 5's location is (1, 2)
28 | """
29 | h = move // self.width
30 | w = move % self.width
31 | return [h, w]
32 |
33 | def location_to_move(self, location):
34 | if len(location) != 2:
35 | return -1
36 | h = location[0]
37 | w = location[1]
38 | move = h * self.width + w
39 | if move not in range(self.width * self.height):
40 | return -1
41 | return move
42 |
43 | def do_move(self, move):
44 | self.states[move] = self.current_player
45 | self.available.remove(move)
46 | self.current_player = (
47 | self.players[0] if self.current_player == self.players[1]
48 | else self.players[1]
49 | )
50 | self.last_move = move
51 |
52 | def get_current_player(self):
53 | return self.current_player
54 |
55 | def current_state(self):
56 | square_state = np.zeros((4, self.width, self.height))
57 | if self.states:
58 | moves, players = np.array(list(zip(*self.states.items())))
59 | move_curr = moves[players == self.current_player]
60 | move_oppo = moves[players != self.current_player]
61 | square_state[0][move_curr // self.width, move_curr % self.height] = 1.0
62 | square_state[1][move_oppo // self.width, move_oppo % self.height] = 1.0
63 |             # plane 3: only one 1 on the whole board, marking where the last move was played
64 | square_state[2][self.last_move // self.width, self.last_move % self.height] = 1.0
65 | if len(self.states) % 2 == 0:
66 |                 # plane 4: all ones or all zeros, indicating which side is to move
67 | square_state[3][:, :] = 1.0
68 |         # why is the board flipped here?
69 | return square_state[:, ::-1, :]
70 |
71 | def has_a_winner(self):
72 | width = self.width
73 | height = self.height
74 | states = self.states
75 | n = self.n_in_row
76 |
77 |         # rough check that enough stones have been played for a win to be possible
78 | moved = list(set(range(width * height)) - set(self.available))
79 | if len(moved) < self.n_in_row + 2:
80 | return False, -1
81 |
82 | for m in moved:
83 | h = m // width
84 | w = m % width
85 | player = states[m]
86 |
87 | if w in range(width - n + 1) \
88 | and len(set(states.get(i, -1) for i in range(m, m+n))) == 1:
89 | return True, player
90 |
91 | if h in range(height - n + 1) \
92 | and len(set(states.get(i, -1) for i in range(m, m+n*width, width))) == 1:
93 | return True, player
94 |
95 | if w in range(width - n + 1) and h in range(height - n + 1) \
96 | and len(set(states.get(i, -1) for i in range(m, m+n*(width+1), width+1))) == 1:
97 | return True, player
98 |
99 | if w in range(n - 1, width) and h in range(height - n + 1) \
100 | and len(set(states.get(i, -1) for i in range(m, m+n*(width-1), width-1))) == 1:
101 | return True, player
102 |
103 | return False, -1
104 |
105 | def game_end(self):
106 | win, winner = self.has_a_winner()
107 | if win:
108 | return True, winner
109 | elif not len(self.available):
110 | return True, -1
111 | return False, -1
112 |
113 | class Game:
114 | def __init__(self, board: Board):
115 | self.board = board
116 |
117 | def start_self_play(self, player, is_shown=False, temp=1e-3):
118 | self.board.init_board()
119 | p1, p2 = self.board.players
120 | states, mcts_probs, current_players = [], [], []
121 | while True:
122 | move, move_probs = player.get_action(self.board, temp=temp, return_prob=1)
123 |             # store the self-play data
124 | states.append(self.board.current_state())
125 | mcts_probs.append(move_probs)
126 | current_players.append(self.board.current_player)
127 |             # play the move on the board
128 | self.board.do_move(move)
129 | if is_shown:
130 | self.graphic(self.board, p1, p2)
131 | end, winner = self.board.game_end()
132 | if end:
133 |                 # record win/loss from the perspective of the player to move in each state
134 | winner_z = np.zeros(len(current_players))
135 | if winner != -1:
136 | winner_z[np.array(current_players) == winner] = 1.0
137 | winner_z[np.array(current_players) != winner] = - 1.0
138 | player.reset_player()
139 | if is_shown:
140 | if winner != -1:
141 | print("Game end. Winner is player: ", winner)
142 | else:
143 | print("Game end. Tie")
144 | return winner, zip(states, mcts_probs, winner_z)
145 |
146 | def start_play(self, player1, player2, start_player=0, is_shown=1):
147 | if start_player not in (0, 1):
148 | raise Exception('start_player should be either 0 or 1')
149 | self.board.init_board(start_player)
150 | p1, p2 = self.board.players
151 | player1.set_player_ind(p1)
152 | player2.set_player_ind(p2)
153 | players = {p1: player1, p2: player2}
154 | if is_shown:
155 | self.graphic(self.board, player1.player, player2.player)
156 | while True:
157 | current_player = self.board.get_current_player()
158 | player_in_turn = players[current_player]
159 | move = player_in_turn.get_action(self.board)
160 | self.board.do_move(move)
161 | if is_shown:
162 | self.graphic(self.board, player1.player, player2.player)
163 | end, winner = self.board.game_end()
164 | if end:
165 | if is_shown:
166 | if winner != -1:
167 | print("Game end. Winner is ", players[winner])
168 | else:
169 | print("Game end. Tie")
170 | return winner
171 |
172 | def graphic(self, board, player1, player2):
173 | width = board.width
174 | height = board.height
175 |
176 | print("Player", player1, "with X".rjust(3))
177 | print("Player", player2, "with O".rjust(3))
178 | print()
179 | for x in range(width):
180 | print("{0:8}".format(x), end='')
181 | print('\r\n')
182 | for i in range(height - 1, -1, -1):
183 | print("{0:4d}".format(i), end='')
184 | for j in range(width):
185 | loc = i * width + j
186 | p = board.states.get(loc, -1)
187 | if p == player1:
188 | print('x'.center(8), end='')
189 | elif p == player2:
190 | print('O'.center(8), end='')
191 | else:
192 | print('_'.center(8), end='')
193 | print('\r\n\r\n')
194 |
--------------------------------------------------------------------------------
/AlphaZero/loss.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/AlphaZero/loss.npy
--------------------------------------------------------------------------------
/AlphaZero/main.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import sys
3 | DIRNAME = osp.dirname(__file__)
4 | sys.path.append(DIRNAME + '/..')
5 |
6 | from AlphaZero.train import TrainPipeline
7 |
8 | if __name__ == "__main__":
9 | training_pipeline = TrainPipeline(DIRNAME + '/current_policy.model')
10 | training_pipeline.run()
11 |
12 |
--------------------------------------------------------------------------------
/AlphaZero/mcts_alphaZero.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import numpy as np
3 |
4 | def Softmax(x):
5 | probs = np.exp(x - np.max(x))
6 | probs /= np.sum(probs)
7 | return probs
8 |
9 | class TreeNode:
10 | """ A node in the MCTS tree. """
11 | def __init__(self, parent, prior_p):
12 | self._parent = parent
13 | self._children = {}
14 | self._n_visits = 0
15 | self._Q = 0
16 | self._u = 0
17 | self._P = prior_p
18 |
19 | def select(self, c_puct):
20 | """ Return: A tuple of (action, next_node) """
21 | return max(
22 | self._children.items(),
23 | key=lambda act_node: act_node[1].get_value(c_puct)
24 | )
25 |
26 | def get_value(self, c_puct):
27 | self._u = (c_puct * self._P * np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
28 | return self._Q + self._u
29 |
30 | def expand(self, action_priors):
31 | for action, prob in action_priors:
32 | if action not in self._children:
33 | self._children[action] = TreeNode(self, prob)
34 |
35 | def update(self, leaf_value):
36 | self._n_visits += 1
37 | self._Q += 1.0 * (leaf_value - self._Q) / self._n_visits
38 |
39 | def update_recursive(self, leaf_value):
40 | if self._parent:
41 | self._parent.update_recursive(-leaf_value)
42 | self.update(leaf_value)
43 |
44 | def is_leaf(self):
45 | return self._children == {}
46 |
47 | def is_root(self):
48 | return self._parent is None
49 |
50 |
51 | class MCTS(object):
52 | """ An implementation of Monte Carlo Tree Search. """
53 | def __init__(self, policy_value_fn, c_puct=5, n_playout=10000):
54 | self._root = TreeNode(None, 1.0)
55 | self._policy = policy_value_fn
56 | self._c_puct = c_puct
57 | self._n_playout = n_playout
58 |
59 | def _playout(self, state):
60 | """ 完整的执行选择、扩展评估和回传更新等步骤 """
61 | node = self._root
62 |         # selection
63 | while True:
64 | if node.is_leaf():
65 | break
66 | action, node = node.select(self._c_puct)
67 | state.do_move(action)
68 |         # expansion and evaluation
69 | action_probs, leaf_value = self._policy(state)
70 | end, winner = state.game_end()
71 | if not end:
72 | node.expand(action_probs)
73 | else:
74 |             if winner == -1: # tie
75 | leaf_value = 0.0
76 | else:
77 | leaf_value = (
78 | 1.0 if winner == state.get_current_player() else -1.0
79 | )
80 |         # backup
81 | node.update_recursive(-leaf_value)
82 |
83 | def get_move_probs(self, state, temp=1e-3):
84 | for n in range(self._n_playout):
85 | state_copy = copy.deepcopy(state)
86 | self._playout(state_copy)
87 |
88 | act_visits = [
89 | (act, node._n_visits) for act, node in self._root._children.items()
90 | ]
91 | acts, visits = zip(*act_visits)
92 |         # note: the action probabilities are derived from the visit counts
93 | act_probs = Softmax(1.0 / temp * np.log(np.array(visits) + 1e-10))
94 |
95 | return acts, act_probs
96 |
97 | def update_with_move(self, last_move):
98 | if last_move in self._root._children:
99 | self._root = self._root._children[last_move]
100 | self._root._parent = None
101 | else:
102 | self._root = TreeNode(None, 1.0)
103 |
104 | class MCTSPlayer:
105 | """ AI player based on MCTS """
106 | def __init__(self, policy_value_function, c_puct=5, n_playout=2000, is_selfplay=0):
107 | self.mcts = MCTS(policy_value_function, c_puct, n_playout)
108 | self._is_selfplay = is_selfplay
109 |
110 | def get_action(self, board, temp=1e-3, return_prob=0):
111 | sensible_moves = board.available
112 | move_probs = np.zeros(board.width * board.height)
113 | if len(sensible_moves) > 0:
114 | acts, probs = self.mcts.get_move_probs(board, temp)
115 | move_probs[list(acts)] = probs
116 | if self._is_selfplay:
117 | move = np.random.choice(
118 | acts, p = 0.75 * probs + 0.25 * np.random.dirichlet(0.3 * np.ones(len(probs)))
119 | )
120 |                 # move the root to the played child to reuse the search subtree
121 | self.mcts.update_with_move(move)
122 | else:
123 | move = np.random.choice(acts, p=probs)
124 |                 # reset the root node
125 | self.mcts.update_with_move(-1)
126 | location = board.move_to_location(move)
127 |                 print("AI move: %d, %d\n" % (location[0], location[1]))
128 | if return_prob:
129 | return move, move_probs
130 | else:
131 | return move
132 | else:
133 | print("WARNING: the board is full")
134 |
135 | def set_player_ind(self, p):
136 | self.player = p
137 |
138 | def reset_player(self):
139 | self.mcts.update_with_move(-1)
140 |
141 | def __str__(self):
142 | return "MCTS {}".format(self.player)
143 |
--------------------------------------------------------------------------------
/AlphaZero/play.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import sys
3 | DIRNAME = osp.dirname(__file__)
4 | sys.path.append(DIRNAME + '/..')
5 |
6 | from AlphaZero.game import Board, Game
7 | from AlphaZero.mcts_alphaZero import MCTSPlayer
8 | from AlphaZero.policy_value_net_pytorch import PolicyValueNet
9 |
10 | import argparse
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('-t', '--trained', action='store_true')
13 | args = parser.parse_args()
14 |
15 | """
16 | input location as '3,3' to play
17 | """
18 |
19 | class Human:
20 | """ human player """
21 | def __init__(self):
22 | self.player = None
23 |
24 | def set_player_ind(self, p):
25 | self.player = p
26 |
27 | def get_action(self, board):
28 | try:
29 | location = input("Your move: ")
30 | if isinstance(location, str): # for Python3
31 | location = [int(n, 10) for n in location.split(",")]
32 | move = board.location_to_move(location)
33 | except Exception as e:
34 | move = -1
35 | if move == -1 or move not in board.available:
36 | print("invalid move")
37 | move = self.get_action(board)
38 | return move
39 |
40 | def __str__(self):
41 | return "Human {}".format(self.player)
42 |
43 | def run():
44 | n = 5
45 | width, height = 8, 8
46 | model_file = DIRNAME + '/current_policy' + ('_1' if args.trained else '_0') + '.model'
47 | try:
48 | board = Board(width=width, height=height, n_in_row=n)
49 | game = Game(board)
50 |         # create the AI player
51 | best_policy = PolicyValueNet(width, height, model_file=model_file)
52 | mcts_player = MCTSPlayer(best_policy.policy_value_fn, c_puct=5, n_playout=400)
53 |         # create the human player; example input: 2,3
54 | human = Human()
55 |         # set start_player = 0 to let the human move first
56 | game.start_play(human, mcts_player, start_player=1, is_shown=1)
57 | except KeyboardInterrupt:
58 | print('\n\rquit')
59 |
60 | if __name__ == "__main__":
61 | run()
--------------------------------------------------------------------------------
/AlphaZero/policy_value_net_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | import numpy as np
7 |
8 | def set_learning_rate(optimizer, lr):
9 | """ Set the learning rate to the given value """
10 | for param_group in optimizer.param_groups:
11 | param_group['lr'] = lr
12 |
13 | class Net(nn.Module):
14 |     """ Policy-value network architecture. """
15 | def __init__(self, board_width, board_height):
16 | super().__init__()
17 | self.board_width = board_width
18 | self.board_height = board_height
19 | # common layers
20 | self.conv1 = nn.Conv2d(4, 32, kernel_size=3, padding=1)
21 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
22 | self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
23 | # action policy layers
24 | self.act_conv1 = nn.Conv2d(128, 4, kernel_size=1)
25 | self.act_fc1 = nn.Linear(4 * board_width * board_height, board_width * board_height)
26 | # state value layers
27 | self.val_conv1 = nn.Conv2d(128, 2, kernel_size=1)
28 | self.val_fc1 = nn.Linear(2 * board_width * board_height, 64)
29 | self.val_fc2 = nn.Linear(64, 1)
30 |
31 | def forward(self, state_input):
32 | # common layers
33 | x = F.relu(self.conv1(state_input))
34 | x = F.relu(self.conv2(x))
35 | x = F.relu(self.conv3(x))
36 | # action policy layers
37 | x_act = F.relu(self.act_conv1(x))
38 | x_act = x_act.view(-1, 4 * self.board_width * self.board_height)
39 | x_act = F.log_softmax(self.act_fc1(x_act), dim=1)
40 | # state value layers
41 | x_val = F.relu(self.val_conv1(x))
42 | x_val = x_val.view(-1, 2 * self.board_width * self.board_height)
43 | x_val = F.relu(self.val_fc1(x_val))
44 | x_val = torch.tanh(self.val_fc2(x_val))
45 | return x_act, x_val
46 |
47 | class PolicyValueNet:
48 | def __init__(self, board_width, board_height, model_file=None):
49 | self.board_width = board_width
50 | self.board_height = board_height
51 | self.l2_const = 1e-4
52 | self.policy_value_net = Net(board_width, board_height)
53 | self.optimizer = optim.Adam(self.policy_value_net.parameters(), weight_decay=self.l2_const)
54 | if model_file:
55 | net_params = torch.load(model_file)
56 | self.policy_value_net.load_state_dict(net_params)
57 |
58 | def policy_value_fn(self, board):
59 | legal_positions = board.available
60 | current_state = np.ascontiguousarray(
61 | board.current_state().reshape(
62 | -1, 4, self.board_width, self.board_height
63 | )
64 | )
65 | log_act_probs, value = self.policy_value_net(
66 | Variable(torch.from_numpy(current_state)).float()
67 | )
68 | act_probs = np.exp(log_act_probs.data.numpy().flatten())
69 | act_probs = zip(legal_positions, act_probs[legal_positions])
70 | value = value.data[0][0]
71 | return act_probs, value
72 |
73 | def train_step(self, state_batch, mcts_probs, winner_batch, lr):
74 | """ perform a training step """
75 | state_batch = Variable(torch.FloatTensor(state_batch))
76 | mcts_probs = Variable(torch.FloatTensor(mcts_probs))
77 | winner_batch = Variable(torch.FloatTensor(winner_batch))
78 | # zero the parameter gradients
79 | self.optimizer.zero_grad()
80 | # set learning rate
81 | set_learning_rate(self.optimizer, lr)
82 | # forward
83 | log_act_probs, value = self.policy_value_net(state_batch)
84 | # define the loss
85 | value_loss = F.mse_loss(value.view(-1), winner_batch)
86 | policy_loss = - torch.mean(
87 | torch.sum(mcts_probs * log_act_probs, 1)
88 | )
89 | loss = value_loss + policy_loss
90 | # backward and optimize
91 | loss.backward()
92 | self.optimizer.step()
93 | # policy entropy, for monitoring only
94 | entropy = - torch.mean(
95 | torch.sum(torch.exp(log_act_probs) * log_act_probs, 1)
96 | )
97 | return loss.item(), entropy.item()
98 |
99 | def get_policy_param(self):
100 | net_params = self.policy_value_net.state_dict()
101 | return net_params
102 |
103 | def save_model(self, model_file):
104 | """ save model params to file """
105 | net_params = self.get_policy_param()
106 | torch.save(net_params, model_file)
107 |
108 |
--------------------------------------------------------------------------------
/AlphaZero/train.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from collections import deque
4 | import os.path as osp
5 | from .game import Board, Game
6 | from .mcts_alphaZero import MCTSPlayer
7 | from .policy_value_net_pytorch import PolicyValueNet
8 |
9 | DIRNAME = osp.dirname(__file__)
10 |
11 | class TrainPipeline:
12 | def __init__(self, init_model=None):
13 |         # board parameters
14 | self.board_width = 8
15 | self.board_height = 8
16 | self.n_in_row = 5
17 | self.board = Board(
18 | width=self.board_width,
19 | height=self.board_height,
20 | n_in_row=self.n_in_row
21 | )
22 | self.game = Game(self.board)
23 |         # self-play parameters
24 | self.temp = 1.0
25 | self.c_puct = 5
26 | self.n_playout = 400
27 |         # training / update parameters
28 | self.learn_rate = 2e-3
29 | self.buffer_size = 10000
30 | self.batch_size = 512
31 | self.data_buffer = deque(maxlen=self.buffer_size)
32 |         self.check_freq = 2 # how often (in iterations) the model is saved
33 |         self.game_batch_num = 3000 # number of self-play / training iterations
34 | if init_model:
35 |             # if an initial model is given, load it to initialize the policy-value net
36 | self.policy_value_net = PolicyValueNet(
37 | self.board_width,
38 | self.board_height,
39 | model_file=init_model
40 | )
41 | else:
42 |             # otherwise start from a randomly initialized policy-value net
43 | self.policy_value_net = PolicyValueNet(
44 | self.board_width,
45 | self.board_height
46 | )
47 | self.mcts_player = MCTSPlayer(
48 | self.policy_value_net.policy_value_fn,
49 | c_puct=self.c_puct,
50 | n_playout=self.n_playout,
51 | is_selfplay=1
52 | )
53 |
54 | def run(self):
55 |         """ Run the full training pipeline. """
56 | for i in range(self.game_batch_num):
57 | episode_len = self.collect_selfplay_data()
58 | if len(self.data_buffer) > self.batch_size:
59 | loss, entropy = self.policy_update()
60 | print((
61 | "batch i:{}, "
62 | "episode_len:{}, "
63 | "loss:{:.4f}, "
64 | "entropy:{:.4f}"
65 | ).format(i+1, episode_len, loss, entropy))
66 | # save performance per update
67 | loss_array = np.load(DIRNAME + '/loss.npy')
68 | entropy_array = np.load(DIRNAME + '/entropy.npy')
69 | loss_array = np.append(loss_array, loss)
70 | entropy_array = np.append(entropy_array, entropy)
71 | np.save(DIRNAME + '/loss.npy', loss_array)
72 | np.save(DIRNAME + '/entropy.npy', entropy_array)
73 | del loss_array
74 | del entropy_array
75 | else:
76 | print("batch i:{}, episode_len:{}".format(i+1, episode_len))
77 |             # save the model periodically
78 | if (i+1) % self.check_freq == 0:
79 | self.policy_value_net.save_model(
80 | DIRNAME + '/current_policy.model'
81 | )
82 |
83 | def collect_selfplay_data(self):
84 | """ collect self-play data for training """
85 | winner, play_data = self.game.start_self_play(self.mcts_player, temp=self.temp)
86 | play_data = list(play_data)[:]
87 | episode_len = len(play_data)
88 | # augment the data
89 | play_data = self.get_equi_data(play_data)
90 | self.data_buffer.extend(play_data)
91 | return episode_len
92 |
93 | def get_equi_data(self, play_data):
94 | """ play_data: [(state, mcts_prob, winner_z), ...] """
95 | extend_data = []
96 | for state, mcts_prob, winner in play_data:
97 | for i in [1, 2, 3, 4]:
98 |                 # rotate counter-clockwise
99 | equi_state = np.array([np.rot90(s, i) for s in state])
100 | equi_mcts_prob = np.rot90(np.flipud(
101 | mcts_prob.reshape(self.board_height, self.board_width)
102 | ), i)
103 | extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
104 |                 # flip horizontally
105 | equi_state = np.array([np.fliplr(s) for s in equi_state])
106 | equi_mcts_prob = np.fliplr(equi_mcts_prob)
107 | extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
108 | return extend_data
109 |
110 | def policy_update(self):
111 | """ update the policy-value net """
112 | mini_batch = random.sample(self.data_buffer, self.batch_size)
113 | state_batch = [data[0] for data in mini_batch]
114 | mcts_probs_batch = [data[1] for data in mini_batch]
115 | winner_batch = [data[2] for data in mini_batch]
116 | loss, entropy = self.policy_value_net.train_step(
117 | state_batch, mcts_probs_batch, winner_batch, self.learn_rate
118 | )
119 | return loss, entropy
120 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Reinforcement Learning Playground
2 | This repository was originally meant to reproduce the examples from a Chinese reinforcement learning textbook. Unfortunately the book turned out to be poorly written, so I dropped it. My review is here: [Douban review](https://book.douban.com/review/12673161/).
3 | 
4 | The repository now collects small reinforcement learning practice projects and algorithm experiments: code that does not deserve a repo of its own but is still worth discussing.
5 | 
6 | ### Where my notes live
7 | - 🥊 Introductory study / reading notes [GitHub: PiperLiu/Reinforcement-Learning-practice-zh](https://github.com/PiperLiu/Reinforcement-Learning-practice-zh)
8 | - 💻 Paper / video course notes [GitHub: PiperLiu/introRL](https://github.com/PiperLiu/introRL)
9 | - ✨ Assorted algorithms / practice playground [GitHub: PiperLiu/Approachable-Reinforcement-Learning](https://github.com/PiperLiu/Approachable-Reinforcement-Learning)
10 | 
11 | ### Repository index
12 | - Basic reinforcement learning reproductions [jump to summary and index](#basic-reinforcement-learning-reproductions)
13 |
14 |
15 | ****
16 |
17 | ## Basic Reinforcement Learning Reproductions
18 | 
19 | Based on the Chinese textbook 《深入浅出强化学习:编程实战》. Since the writing is poor, I only reproduced Part 0, Part 1 and the AlphaZero example before abandoning it.
20 | 
21 | - [Appendix A: Getting started with PyTorch](#A)
22 | - Part 0: Preliminaries
23 | - - [1 An extremely simple reinforcement learning example](#sec_1)
24 | - - [2 Markov decision processes](#sec_2)
25 | - Part 1: Value-function-based methods
26 | - - [3 Dynamic programming methods](#sec_3)
27 | - - [4 Monte Carlo methods](#sec_4)
28 | - - [5 Temporal-difference methods](#sec_5)
29 | - - [6 Function approximation methods](#sec_6)
30 | - AlphaZero in practice: Gomoku
31 | - - [link](#sec_11)
32 |
33 | ### Appendix A: Getting started with PyTorch
34 |
35 |
36 | [./pyTorch_learn/](./pyTorch_learn/)
37 |
38 | Introduces basic PyTorch usage. The main example builds an extremely small and simple convolutional neural network on the `CIFAR10` dataset. My main takeaways from Appendix A:
39 | - the input to a PyTorch `nn.Module` should be a `mini-batch`, i.e. it carries one more dimension than a single sample;
40 | - inputs to an `nn.Module` should be wrapped in a `Variable`;
41 | - inside a network class, `__init__()` does not really wire the layers together; the input-output relationships of the network are defined in `forward()`.
42 | 
43 | In addition, let us walk through the "learning" loop of a neural network:
44 | ```python
45 | import torch
46 | import torch.nn as nn
47 | import torch.nn.functional as F
48 | from torch.autograd import Variable
49 | import torch.optim as optim
50 |
51 | # the network object
52 | net = Net()
53 | # loss function: cross entropy
54 | criterion = nn.CrossEntropyLoss()
55 | # optimization method
56 | optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
57 | # loop over the whole dataset 5 times
58 | for epoch in range(5):
59 |     # data is one mini-batch; the batch size is decided by trainloader
60 | for i, data in enumerate(trainloader, 0):
61 |         # features and labels: note they are wrapped in Variable
62 | inputs, labels = data
63 | inputs, labels = Variable(inputs), Variable(labels)
64 |         # in PyTorch, backward() accumulates gradients rather than
65 |         # replacing them;
66 |         # but when processing one batch we do not want to mix in
67 |         # gradients accumulated from other batches,
68 |         # so zero_grad() is called once per batch
69 |         # to reset the parameter gradients to 0.
70 | optimizer.zero_grad()
71 |         # current outputs
72 | outputs = net(inputs)
73 |         # compute the loss
74 | loss = criterion(outputs, labels)
75 |         # when initializing the optimizer, we explicitly tell it which
76 |         # model parameters it should update; once loss.backward() has
77 |         # been called, the gradients are "stored" on the torch tensors
78 |         # (they have grad and requires_grad attributes);
79 |         # after the gradients of all tensors in the model are computed,
80 |         # optimizer.step() updates every parameter it manages using the stored grad.
81 | loss.backward()
82 | optimizer.step()
83 |
84 | # saving and loading the model
85 | torch.save(net.state_dict(), './pyTorch_learn/data/' + 'model.pt')
86 | net.load_state_dict(torch.load('./pyTorch_learn/data/' + 'model.pt'))
87 | ```
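For reference, `Net()` and `trainloader` above are assumed to be defined elsewhere (the accompanying scripts live under [./pyTorch_learn/](./pyTorch_learn/)); the snippet is not self-contained. A minimal sketch of what a small CIFAR10 network along these lines could look like, my own sketch rather than necessarily the book's exact architecture:

```python
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # CIFAR10 images are 3x32x32; there are 10 output classes
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5)   # -> 6x28x28
        self.pool = nn.MaxPool2d(2, 2)                # halves the spatial size
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)  # -> 16x10x10 after pooling
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # the layer-to-layer wiring lives here, not in __init__()
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
```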
88 |
89 | ### Part 0: Preliminaries
90 |
91 | #### 1 An extremely simple reinforcement learning example
92 |
93 |
94 | [./ch_0/sec_1/](./ch_0/sec_1/)
95 |
96 | A very simple example that explores the trade-off between exploration and exploitation.
97 | 
98 | I made a few improvements and typo fixes to the book's original code.
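For orientation: `kb_game.py` implements three action-selection rules, epsilon-greedy, Boltzmann (softmax over $q/\tau$), and UCB. The UCB rule used there is essentially $a = \arg\max_a \left[ \hat{q}(a) + c \sqrt{\ln N / N(a)} \right]$, where $N$ is the total number of plays and $N(a)$ the play count of arm $a$.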
99 |
100 | #### 2 Markov decision processes
101 |
102 |
103 | [./ch_0/sec_2/](./ch_0/sec_2/)
104 |
105 | Advantage function: $A(s, a) = q_\pi (s, a) - v_\pi (s)$
106 |
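A small side note (a standard result, my addition): since $v_\pi(s) = \sum_a \pi(a|s)\, q_\pi(s, a)$, the advantage averages to zero under the policy, $\sum_a \pi(a|s)\, A_\pi(s, a) = 0$, so a positive advantage marks an action that is better than the policy's average behaviour in that state.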
107 | Built an interactive environment, a typical "grid world", that later chapters reuse for testing.
108 |
109 | ### Part 1: Value-function-based methods
110 |
111 | #### 3 Dynamic programming methods
112 |
113 |
114 | [./ch_1/sec_3/](./ch_1/sec_3/)
115 |
116 | - Modified the "yuanyang (mandarin duck) environment": [yuan_yang_env.py](./ch_1/sec_3/yuan_yang_env.py)
117 | - Policy iteration: [dp_policy_iter.py](./ch_1/sec_3/dp_policy_iter.py)
118 | - Value iteration: [dp_value_iter.py](./ch_1/sec_3/dp_value_iter.py) (both scripts share the same one-step backup; see the note below)
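Both scripts rely on the same deterministic one-step backup of this grid world: $v(s) \leftarrow r + \gamma\, v(s')$, where $s'$ is the unique successor returned by `transform(s, a)`, and sweeps stop once the total change per sweep drops below $10^{-6}$.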
119 |
120 | #### 4 Monte Carlo methods
121 |
122 |
123 | [./ch_1/sec_4/](./ch_1/sec_4/)
124 |
125 | Based on Monte Carlo methods, I implemented:
126 | - a purely greedy policy;
127 | - on-policy epsilon-greedy control (the core return-averaging update is sketched below).
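A minimal sketch of that update (the names `episode`, `gamma`, `qvalue`, `n` follow the spirit of `mc_rl.py`, not its exact code): after an episode ends, the discounted return is accumulated backwards and folded into a running sample average for each visited (state, action) pair.

```python
g = 0.0
for s, a, r in reversed(episode):   # episode = [(state, action, reward), ...]
    g = r + gamma * g               # return observed from (s, a) onwards
    n[s, a] += 1
    qvalue[s, a] += (g - qvalue[s, a]) / n[s, a]   # incremental sample mean
```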
128 |
129 | #### 5 Temporal-difference methods
130 |
131 |
132 | [./ch_1/sec_5/](./ch_1/sec_5/)
133 |
134 | This chapter sharpened my understanding of "off-policy" methods:
135 | - an off-policy method can reuse old data and data produced by other policies;
136 | - because the off-policy update only needs (s, a, r, s') and never uses a';
137 | - in other words, the data need not be generated by the policy being learned (a minimal update sketch follows below).
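A minimal illustration of that point (my own sketch; the repo's actual TD code lives in [./ch_1/sec_5/TD_RL.py](./ch_1/sec_5/TD_RL.py)): only the observed transition enters the target, so it does not matter which behaviour policy produced it.

```python
import numpy as np

def q_learning_update(q, s, a, r, s_next, alpha=0.1, gamma=0.8):
    """One-step Q-learning: the target uses max over a', never the a'
    actually taken, so transitions from any behaviour policy can be reused."""
    q[s, a] += alpha * (r + gamma * np.max(q[s_next, :]) - q[s, a])
    return q
```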
138 |
139 | **I ran into an interesting phenomenon concerning the reward design.**
140 | 
141 | The book says this chapter can reuse the same env as the Monte Carlo chapter; that is not quite right.
142 | First, the reward scheme of the previous Monte Carlo chapter:
143 | - the bird hits a wall: -10;
144 | - the bird reaches the goal: +10;
145 | - the bird takes a step and nothing happens: -2.
146 |
147 | If you run the experiment you will find that, under TD(0), the bird becomes **afraid to move forward**:
148 | - under Monte Carlo this reward scheme works, because updates happen only after a whole episode finishes;
149 | - under TD(0), hitting the wall costs the bird -10, and walking five steps also costs -10 (ignoring discounting);
150 | - but in this environment the bird needs at least 20 steps to reach the goal;
151 | - so instead of "fighting its way to the goal", smashing into the wall right at the start actually yields more reward.
152 |
153 | This can be checked with a few experiments on lines 151-159 of [./ch_1/sec_5/yuan_yang_env_td.py](./ch_1/sec_5/yuan_yang_env_td.py):
154 | ```python
155 | flag_collide = self.collide(next_position)
156 | if flag_collide == 1:
157 | return self.position_to_state(current_position), -10, True
158 |
159 | flag_find = self.find(next_position)
160 | if flag_find == 1:
161 | return self.position_to_state(next_position), 10, True
162 |
163 | return self.position_to_state(next_position), -2, False
164 | ```
165 |
166 | Currently the rewards (hit wall, reach goal, take a step) are (-10, 10, -2), and the experimental result is: the bird would rather crash than leave home.
167 | 
168 | Let us tweak the pain of walking versus the joy of arriving:
169 | - (-10, 10, -1): walking hurts less, and the bird tends to reach the goal;
170 | - (-10, 10, -1.2): walking hurts more, and the bird tends to kill itself against the wall right at the start;
171 | - (-10, 100, -1.2): walking still hurts, but the joy of arriving is huge, so after weighing its options the bird decides going out is worth it after all; it is a far-sighted bird.
172 |
173 | Running these experiments, we find they are not stable: the bird's random choices differ between runs, and the "pain of walking" and the "joy of arriving" are too close, leaving the bird torn.
174 |
175 | So a better solution is: **do not overthink it! We want the bird to reach the goal, so simply stop making walking painful!**
176 | - With rewards (-10, 10, 0), I guarantee the bird will be "willing to leave home" and will reach the goal!
177 |
178 | The lesson for me: **as a reinforcement learning practitioner, do not micro-manage the agent's decisions! Give it a few simple rules and let it learn the rest. An over-engineered reward scheme leads to surprising behaviour!**
179 |
180 | There is another way to attack this problem: raise the pain of hitting the wall.
181 | - With rewards (-1000, 10, -2) the bird no longer dares to hit the wall, because it hurts far too much!
182 | - It will not "walk extra steps" either, since every step still hurts a little; with this setting the bird reliably finds the optimal (shortest) path. (A small refactoring sketch for playing with these reward triples follows.)
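To play with these triples without editing the env in several places, one option (my sketch, not how the repo is actually organized; `step_with_rewards` and `REWARDS` are hypothetical names) is to pull the three rewards out as a parameter and mirror the logic quoted above:

```python
# (hit-wall reward, reach-goal reward, per-step reward), e.g. (-10, 10, 0) or (-1000, 10, -2)
REWARDS = (-10, 10, 0)

def step_with_rewards(env, current_position, next_position, rewards=REWARDS):
    """Mirrors lines 151-159 of yuan_yang_env_td.py with the rewards made explicit."""
    collide_r, goal_r, step_r = rewards
    if env.collide(next_position) == 1:
        return env.position_to_state(current_position), collide_r, True
    if env.find(next_position) == 1:
        return env.position_to_state(next_position), goal_r, True
    return env.position_to_state(next_position), step_r, False
```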
183 |
184 | #### 6 Function approximation methods
185 |
186 |
187 | [./ch_1/sec_6/](./ch_1/sec_6/)
188 |
189 | The first half of this chapter is the last appearance of the "yuanyang" environment. I found that:
190 | - for both the ordinary vector state representation and the fixed sparse representation, the book drops `epsilon = epsilon * 0.99` from the iteration;
191 | - experiments show it should not be dropped, especially in the first setting, which is essentially the tabular Q-learning of the previous chapter;
192 | - with the fixed sparse representation, convergence is possible (in this environment) even without decaying the exploration rate.
193 |
194 | I also noticed:
195 | - with the fixed sparse representation the bird tends to walk in straight lines;
196 | - I think this is because the sparse features are shared across states with the same x or the same y, so their values end up closer together (a sketch of such a representation follows below).
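For readers unfamiliar with the term, a "fixed sparse representation" here can be as simple as concatenating a one-hot column code with a one-hot row code, so states sharing an x or a y share feature weights. A sketch under that assumption (the repo's actual code is in [./ch_1/sec_6/lfa_rl.py](./ch_1/sec_6/lfa_rl.py); `sparse_feature` is a hypothetical helper):

```python
import numpy as np

def sparse_feature(state, n_cols=10, n_rows=10):
    """One-hot x plus one-hot y; q(s, a) is then a linear function w[a] @ phi(s)."""
    x, y = state % n_cols, state // n_cols
    phi = np.zeros(n_cols + n_rows)
    phi[x] = 1.0
    phi[n_cols + y] = 1.0
    return phi
```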
197 |
198 | The second half of this chapter: nonlinear function approximation.
199 |
200 | The book gives no code link; after some searching I believe the author adapted this project: [https://github.com/yenchenlin/DeepLearningFlappyBird](https://github.com/yenchenlin/DeepLearningFlappyBird)
201 | - I placed my version under [./ch_1/sec_6/flappy_bird/](./ch_1/sec_6/flappy_bird/)
202 | - - I added a keyboard-controlled mode for playing the game by hand; press H to flap: [./ch_1/sec_6/flappy_bird/game/keyboard_agent.py](./ch_1/sec_6/flappy_bird/game/keyboard_agent.py)
203 | - - the book uses TF 1 code; I rewrote it with TF 2, referring to: [https://github.com/tomjur/TF2.0DQN](https://github.com/tomjur/TF2.0DQN)
204 | - train with `python -u "d:\GitHub\rl\Approachable-Reinforcement-Learning\ch_1\sec_6\flappy_bird\dqn_agent.py"`
205 |
206 | ### AlphaZero in practice: Gomoku
207 |
208 |
209 | The code lives in [./AlphaZero](./AlphaZero).
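Usage, as far as the scripts themselves indicate: the training entry point is `python AlphaZero/main.py`; note that `main.py` passes `AlphaZero/current_policy.model` as the initial model, so that file needs to exist or the path needs adjusting. To play against the net, run `python AlphaZero/play.py` (loads `current_policy_0.model`) or `python AlphaZero/play.py -t` (loads `current_policy_1.model`); the AI moves first (`start_player=1`), and human moves are typed as `row,col`, e.g. `2,3`.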
210 |
--------------------------------------------------------------------------------
/ch_0/sec_1/kb_game.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 | class KB_Game:
5 | def __init__(self):
6 | self.q = np.array([0.0, 0.0, 0.0])
7 | self.action_counts = np.array([0, 0, 0])
8 | self.current_cumulative_rewards = 0.0
9 | self.actions = [1, 2, 3]
10 | self.counts = 0
11 | self.counts_history = []
12 | self.cumulative_rewards_history = []
13 | self.average_rewards_history = []
14 | self.a = 1
15 | self.reward = 0
16 |
17 | def step(self, a):
18 | r = 0
19 | if a == 1:
20 | r = np.random.normal(1, 1)
21 | elif a == 2:
22 | r = np.random.normal(2, 1)
23 | elif a == 3:
24 | r = np.random.normal(1.5, 1)
25 | return r
26 |
27 | def choose_action(self, policy, **kwargs):
28 | action = 0
29 | if policy == 'e_greedy':
30 | if np.random.random() < kwargs['epsilon']:
31 | action = np.random.randint(1, 4)
32 | else:
33 | action = np.argmax(self.q) + 1
34 | if policy == 'ucb':
35 | c_ratio = kwargs['c_ratio']
36 | if 0 in self.action_counts:
37 | action = np.where(self.action_counts==0)[0][0] + 1
38 | else:
39 | value = self.q + c_ratio * np.sqrt(np.log(self.counts) /\
40 | self.action_counts)
41 | action = np.argmax(value) + 1
42 | if policy == 'boltzmann':
43 | tau = kwargs['temperature']
44 | p = np.exp(self.q / tau) / (np.sum(np.exp(self.q / tau)))
45 | action = np.random.choice([1, 2, 3], p=p.ravel())
46 | return action
47 |
48 | def train(self, play_total, policy, **kwargs):
49 | reward_1 = []
50 | reward_2 = []
51 | reward_3 = []
52 | for i in range(play_total):
53 | action = 0
54 | if policy == 'e_greedy':
55 | action = \
56 | self.choose_action(policy, epsilon=kwargs['epsilon'])
57 | if policy == 'ucb':
58 | action = \
59 | self.choose_action(policy, c_ratio=kwargs['c_ratio'])
60 | if policy == 'boltzmann':
61 | action = \
62 | self.choose_action(policy, temperature=kwargs['temperature'])
63 | self.a = action
64 |
65 | self.r = self.step(self.a)
66 | self.counts +=1
67 |
68 | self.q[self.a-1] = (self.q[self.a-1] * self.action_counts[self.a-1] + self.r) /\
69 | (self.action_counts[self.a-1] + 1)
70 | self.action_counts[self.a-1] += 1
71 |
72 | reward_1.append(self.q[0])
73 | reward_2.append(self.q[1])
74 | reward_3.append(self.q[2])
75 | self.current_cumulative_rewards += self.r
76 | self.cumulative_rewards_history.append(self.current_cumulative_rewards)
77 | self.average_rewards_history.append(self.current_cumulative_rewards /\
78 | (len(self.average_rewards_history) + 1))
79 | self.counts_history.append(i)
80 |
81 | def reset(self):
82 | self.q = np.array([0.0, 0.0, 0.0])
83 | self.action_counts = np.array([0, 0, 0])
84 | self.current_cumulative_rewards = 0.0
85 | self.actions = [1, 2, 3]
86 | self.counts = 0
87 | self.counts_history = []
88 | self.cumulative_rewards_history = []
89 | self.average_rewards_history = []
90 | self.a = 1
91 | self.reward = 0
92 |
93 | def plot(self, colors, policy, style):
94 | plt.figure(1)
95 | plt.plot(self.counts_history, self.average_rewards_history, colors+style, label=policy)
96 | plt.legend()
97 | plt.xlabel('n', fontsize=18)
98 | plt.ylabel('average rewards', fontsize=18)
99 |
--------------------------------------------------------------------------------
/ch_0/sec_1/main.py:
--------------------------------------------------------------------------------
1 | from kb_game import KB_Game
2 |
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 |
6 | np.random.seed(0)
7 | k_gamble = KB_Game()
8 | total = 2000
9 |
10 | k_gamble.train(play_total=total, policy='e_greedy', epsilon=0.05)
11 | k_gamble.plot(colors='r', policy='e_greedy', style='-.')
12 | k_gamble.reset()
13 |
14 | k_gamble.train(play_total=total, policy='boltzmann', temperature=1)
15 | k_gamble.plot(colors='b', policy='boltzmann', style='--')
16 | k_gamble.reset()
17 |
18 | k_gamble.train(play_total=total, policy='ucb', c_ratio=0.5)
19 | k_gamble.plot(colors='g', policy='ucb', style='-')
20 | k_gamble.reset()
21 |
22 | plt.show()
--------------------------------------------------------------------------------
/ch_0/sec_2/images/background.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_0/sec_2/images/background.jpg
--------------------------------------------------------------------------------
/ch_0/sec_2/images/bird_female.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_0/sec_2/images/bird_female.png
--------------------------------------------------------------------------------
/ch_0/sec_2/images/bird_male.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_0/sec_2/images/bird_male.png
--------------------------------------------------------------------------------
/ch_0/sec_2/images/obstacle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_0/sec_2/images/obstacle.png
--------------------------------------------------------------------------------
/ch_0/sec_2/load_images.py:
--------------------------------------------------------------------------------
1 | import pygame
2 | import os.path as osp
3 |
4 | path_file = osp.abspath(__file__)
5 | path_images = osp.join(path_file, '..', 'images')
6 |
7 | def load_bird_male():
8 | obj = 'bird_male.png'
9 | obj_path = osp.join(path_images, obj)
10 | return pygame.image.load(obj_path)
11 |
12 | def load_bird_female():
13 | obj = 'bird_female.png'
14 | obj_path = osp.join(path_images, obj)
15 | return pygame.image.load(obj_path)
16 |
17 | def load_background():
18 | obj = 'background.jpg'
19 | obj_path = osp.join(path_images, obj)
20 | return pygame.image.load(obj_path)
21 |
22 | def load_obstacle():
23 | obj = 'obstacle.png'
24 | obj_path = osp.join(path_images, obj)
25 | return pygame.image.load(obj_path)
26 |
--------------------------------------------------------------------------------
/ch_0/sec_2/yuan_yang_env.py:
--------------------------------------------------------------------------------
1 | import pygame
2 | import random
3 | from load_images import *
4 | import numpy as np
5 |
6 | class YuanYangEnv:
7 | def __init__(self):
8 | self.states = []
9 | for i in range(0, 100):
10 | self.states.append(i)
11 | self.actions = ['e', 's', 'w', 'n']
12 | self.gamma = 0.8
13 | self.value = np.zeros((10, 10))
14 |
15 | self.viewer = None
16 | self.FPSCLOCK = pygame.time.Clock()
17 |
18 | self.screen_size = (1200, 900)
19 | self.bird_position = (0, 0)
20 | self.limit_distance_x = 120
21 | self.limit_distance_y = 90
22 | self.obstacle_size = [120, 90]
23 | self.obstacle1_x = []
24 | self.obstacle1_y = []
25 | self.obstacle2_x = []
26 | self.obstacle2_y = []
27 |
28 | for i in range(8):
29 | # obstacle 1
30 | self.obstacle1_x.append(360)
31 | if i <= 3:
32 | self.obstacle1_y.append(90 * i)
33 | else:
34 | self.obstacle1_y.append(90 * (i + 2))
35 | # obstacle 2
36 | self.obstacle2_x.append(720)
37 | if i <= 4:
38 | self.obstacle2_y.append(90 * i)
39 | else:
40 | self.obstacle2_y.append(90 * (i + 2))
41 |
42 | self.bird_male_init_position = [0.0, 0.0]
43 | self.bird_male_position = [0, 0]
44 | self.bird_female_init_position = [1080, 0]
45 |
46 | def collide(self, state_position):
47 | flag = 1
48 | flag1 = 1
49 | flag2 = 1
50 |
51 | # obstacle 1
52 | dx = []
53 | dy = []
54 | for i in range(8):
55 | dx1 = abs(self.obstacle1_x[i] - state_position[0])
56 | dx.append(dx1)
57 | dy1 = abs(self.obstacle1_y[i] - state_position[1])
58 | dy.append(dy1)
59 | mindx = min(dx)
60 | mindy = min(dy)
61 | if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y:
62 | flag1 = 0
63 |
64 | # obstacle 2
65 | dx_second = []
66 | dy_second = []
67 | for i in range(8):
68 | dx1 = abs(self.obstacle2_x[i] - state_position[0])
69 | dx_second.append(dx1)
70 | dy1 = abs(self.obstacle2_y[i] - state_position[1])
71 | dy_second.append(dy1)
72 | mindx = min(dx_second)
73 | mindy = min(dy_second)
74 | if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y:
75 | flag2 = 0
76 |
77 | if flag1 == 0 and flag2 == 0:
78 | flag = 0
79 |
80 | # collide edge
81 | if state_position[0] > 1080 or \
82 | state_position[0] < 0 or \
83 | state_position[1] > 810 or \
84 | state_position[1] < 0:
85 | flag = 1
86 |
87 | return flag
88 |
89 | def find(self, state_position):
90 | flag = 0
91 | if abs(state_position[0] - self.bird_female_init_position[0]) < \
92 | self.limit_distance_x and \
93 | abs(state_position[1] - self.bird_female_init_position[1]) < \
94 | self.limit_distance_y:
95 | flag = 1
96 | return flag
97 |
98 | def state_to_position(self, state):
99 | i = int(state / 10)
100 | j = state % 10
101 | position = [0, 0]
102 |         position[0] = 120 * j
103 |         position[1] = 90 * i
104 | return position
105 |
106 | def position_to_state(self, position):
107 | i = position[0] / 120
108 | j = position[1] / 90
109 | return int(i + 10 * j)
110 |
111 | def reset(self):
112 | # 随机产生一个初始位置
113 | flag1 = 1
114 | flag2 = 1
115 | while flag1 or flag2 == 1:
116 | state = self.states[int(random.random() * len(self.states))]
117 | state_position = self.state_to_position(state)
118 | flag1 = self.collide(state_position)
119 | flag2 = self.find(state_position)
120 | return state
121 |
122 | def transform(self, state, action):
123 | current_position = self.state_to_position(state)
124 | next_position = [0, 0]
125 | flag_collide = 0
126 | flag_find = 0
127 |
128 | flag_collide = self.collide(current_position)
129 | flag_find = self.find(current_position)
130 | if flag_collide == 1:
131 | return state, -1, True
132 | if flag_find == 1:
133 | return state, 1, True
134 |
135 | if action == 'e':
136 | next_position[0] = current_position[0] + 120
137 | next_position[1] = current_position[1]
138 | if action == 's':
139 | next_position[0] = current_position[0]
140 | next_position[1] = current_position[1] + 90
141 | if action == 'w':
142 | next_position[0] = current_position[0] - 120
143 | next_position[1] = current_position[1]
144 | if action == 'n':
145 | next_position[0] = current_position[0]
146 | next_position[1] = current_position[1] - 90
147 |
148 | flag_collide = self.collide(next_position)
149 | if flag_collide == 1:
150 | return self.position_to_state(current_position), -1, True
151 |
152 | flag_find = self.find(next_position)
153 | if flag_find == 1:
154 | return self.position_to_state(next_position), 1, True
155 |
156 | return self.position_to_state(next_position), 0, False
157 |
158 | def gameover(self):
159 | for event in pygame.event.get():
160 | if event.type == pygame.QUIT:
161 | exit()
162 |
163 | def render(self):
164 | if self.viewer is None:
165 | pygame.init()
166 |
167 | self.viewer = pygame.display.set_mode(self.screen_size, 0, 32)
168 | pygame.display.set_caption("yuanyang")
169 | # load pic
170 | self.bird_male = load_bird_male()
171 | self.bird_female = load_bird_female()
172 | self.background = load_background()
173 | self.obstacle = load_obstacle()
174 |
175 | # self.viewer.blit(self.bird_female, self.bird_female_init_position)
176 | # self.viewer.blit(self.bird_male, self.bird_male_init_position)
177 |
178 | self.viewer.blit(self.background, (0, 0))
179 | self.font = pygame.font.SysFont('times', 15)
180 |
181 | self.viewer.blit(self.background, (0, 0))
182 | for i in range(11):
183 | pygame.draw.lines(self.viewer,
184 | (255, 255, 255),
185 | True,
186 | ((120 * i, 0), (120 * i, 900)),
187 | 1
188 | )
189 | pygame.draw.lines(self.viewer,
190 | (255, 255, 255),
191 | True,
192 | ((0, 90 * i), (1200, 90 * i)),
193 | 1
194 | )
195 |
196 | for i in range(8):
197 | self.viewer.blit(self.obstacle, (self.obstacle1_x[i], self.obstacle1_y[i]))
198 | self.viewer.blit(self.obstacle, (self.obstacle2_x[i], self.obstacle2_y[i]))
199 |
200 | self.viewer.blit(self.bird_female, self.bird_female_init_position)
201 | self.viewer.blit(self.bird_male, self.bird_male_init_position)
202 |
203 | for i in range(10):
204 | for j in range(10):
205 | surface = self.font.render(str(
206 | round(float(self.value[i, j]), 3)), True, (0, 0, 0)
207 | )
208 | self.viewer.blit(surface, (120 * i + 5, 90 * j + 70))
209 |
210 | pygame.display.update()
211 | self.gameover()
212 | self.FPSCLOCK.tick(30)
213 |
214 |
215 | if __name__ == "__main__":
216 | yy = YuanYangEnv()
217 | yy.render()
218 | while True:
219 | for event in pygame.event.get():
220 | if event.type == pygame.QUIT:
221 | exit()
222 |
223 |
--------------------------------------------------------------------------------
/ch_1/sec_3/dp_policy_iter.py:
--------------------------------------------------------------------------------
1 | import random
2 | import time
3 | from yuan_yang_env import YuanYangEnv
4 |
5 |
6 | class DP_Policy_Iter:
7 | def __init__(self, yuanyang):
8 | self.states = yuanyang.states
9 | self.actions = yuanyang.actions
10 | self.v = [0.0 for i in range(len(self.states) + 1)]
11 | self.pi = dict()
12 | self.yuanyang = yuanyang
13 | self.gamma = yuanyang.gamma
14 |
15 | # 初始化策略
16 | for state in self.states:
17 | flag1 = 0
18 | flag2 = 0
19 | flag1 = yuanyang.collide(yuanyang.state_to_position(state))
20 | flag2 = yuanyang.find(yuanyang.state_to_position(state))
21 | if flag1 == 1 or flag2 == 1:
22 | continue
23 | self.pi[state] = self.actions[int(random.random() * len(self.actions))]
24 |
25 | def policy_evaluate(self):
26 | # 策略评估在计算值函数
27 | for i in range(100):
28 | delta = 0.0
29 | for state in self.states:
30 | flag1 = 0
31 | flag2 = 0
32 | flag1 = self.yuanyang.collide(self.yuanyang.state_to_position(state))
33 | flag2 = self.yuanyang.find(self.yuanyang.state_to_position(state))
34 | if flag1 == 1 or flag2 == 1:
35 | continue
36 | action = self.pi[state]
37 | s, r, t = self.yuanyang.transform(state, action)
38 | # 更新值
39 | new_v = r + self.gamma * self.v[s]
40 | delta += abs(self.v[state] - new_v)
41 | # 更新值替换原来的值函数
42 | self.v[state] = new_v
43 | if delta < 1e-6:
44 | print('策略评估迭代次数:{}'.format(i))
45 | break
46 |
47 | def policy_improve(self):
48 | # 利用更新后的值函数进行策略改善
49 | for state in self.states:
50 | flag1 = 0
51 | flag2 = 0
52 | flag1 = self.yuanyang.collide(self.yuanyang.state_to_position(state))
53 | flag2 = self.yuanyang.find(self.yuanyang.state_to_position(state))
54 | if flag1 == 1 or flag2 == 1:
55 | continue
56 | a1 = self.actions[0]
57 | s, r, t = self.yuanyang.transform(state, a1)
58 | v1 = r + self.gamma * self.v[s]
59 | # 找状态 s 时,采用哪种动作,值函数最大
60 | for action in self.actions:
61 | s, r, t = self.yuanyang.transform(state, action)
62 | if v1 < r + self.gamma * self.v[s]:
63 | a1 = action
64 | v1 = r + self.gamma * self.v[s]
65 | # 贪婪策略,进行更新
66 | self.pi[state] = a1
67 |
68 | def policy_iterate(self):
69 | for i in range(100):
70 | # 策略评估,变的是 v
71 | self.policy_evaluate()
72 | # 策略改善
73 | pi_old = self.pi.copy()
74 | # 改变 pi
75 | self.policy_improve()
76 | if (self.pi == pi_old):
77 | print('策略改善次数:{}'.format(i))
78 | break
79 |
80 |
81 | if __name__ == "__main__":
82 | yuanyang = YuanYangEnv()
83 | policy_value = DP_Policy_Iter(yuanyang)
84 | policy_value.policy_iterate()
85 |
86 | # 打印
87 | flag = 1
88 | s = 0
89 | path = []
90 | # 将 v 值打印出来
91 | for state in range(100):
92 | i = int(state / 10)
93 | j = state % 10
94 | yuanyang.value[j, i] = policy_value.v[state]
95 |
96 | step_num = 0
97 |
98 | # 将最优路径打印出来
99 | while flag:
100 | path.append(s)
101 | yuanyang.path = path
102 | a = policy_value.pi[s]
103 | print('%d->%s\t'%(s, a))
104 | yuanyang.bird_male_position = yuanyang.state_to_position(s)
105 | yuanyang.render()
106 | time.sleep(0.2)
107 | step_num += 1
108 | s_, r, t = yuanyang.transform(s, a)
109 | if t == True or step_num > 200:
110 | flag = 0
111 | s = s_
112 |
113 | # 渲染最后的路径点
114 | yuanyang.bird_male_position = yuanyang.state_to_position(s)
115 | path.append(s)
116 | yuanyang.render()
117 | while True:
118 | yuanyang.render()
119 |
--------------------------------------------------------------------------------
/ch_1/sec_3/dp_value_iter.py:
--------------------------------------------------------------------------------
1 | import random
2 | import time
3 | from yuan_yang_env import YuanYangEnv
4 |
5 | class DP_Value_Iter:
6 | def __init__(self, yuanyang):
7 | self.states = yuanyang.states
8 | self.actions = yuanyang.actions
9 | self.v = [0.0 for i in range(len(self.states) + 1)]
10 | self.pi = dict()
11 | self.yuanyang = yuanyang
12 |
13 | self.gamma = yuanyang.gamma
14 | for state in self.states:
15 | flag1 = 0
16 | flag2 = 0
17 | flag1 = yuanyang.collide(yuanyang.state_to_position(state))
18 | flag2 = yuanyang.find(yuanyang.state_to_position(state))
19 | if flag1 == 1 or flag2 == 1:
20 | continue
21 | self.pi[state] = self.actions[int(random.random() * len(self.actions))]
22 |
23 | def value_iteration(self):
24 | for i in range(1000):
25 | delta = 0.0
26 | for state in self.states:
27 | flag1 = 0
28 | flag2 = 0
29 | flag1 = self.yuanyang.collide(self.yuanyang.state_to_position(state))
30 | flag2 = self.yuanyang.find(self.yuanyang.state_to_position(state))
31 |                 if flag1 == 1 or flag2 == 1:
32 | continue
33 | a1 = self.actions[int(random.random() * 4)]
34 | s, r, t = self.yuanyang.transform(state, a1)
35 | # 策略评估
36 | v1 = r + self.gamma * self.v[s]
37 | # 策略改善
38 | for action in self.actions:
39 | s, r, t = self.yuanyang.transform(state, action)
40 | if v1 < r + self.gamma * self.v[s]:
41 | a1 = action
42 | v1 = r + self.gamma * self.v[s]
43 | delta += abs(v1 - self.v[state])
44 | self.pi[state] = a1
45 | self.v[state] = v1
46 | if delta < 1e-6:
47 | print('迭代次数为:{}'.format(i))
48 | break
49 |
50 |
51 | if __name__ == "__main__":
52 | yuanyang = YuanYangEnv()
53 | policy_value = DP_Value_Iter(yuanyang)
54 | policy_value.value_iteration()
55 |
56 | # 打印
57 | flag = 1
58 | s = 0
59 | path = []
60 | # 将 v 值打印出来
61 | for state in range(100):
62 | i = int(state / 10)
63 | j = state % 10
64 | yuanyang.value[j, i] = policy_value.v[state]
65 |
66 | step_num = 0
67 |
68 | # 将最优路径打印出来
69 | while flag:
70 | path.append(s)
71 | yuanyang.path = path
72 | a = policy_value.pi[s]
73 | print('%d->%s\t'%(s, a))
74 | yuanyang.bird_male_position = yuanyang.state_to_position(s)
75 | yuanyang.render()
76 | time.sleep(0.2)
77 | step_num += 1
78 | s_, r, t = yuanyang.transform(s, a)
79 | if t == True or step_num > 200:
80 | flag = 0
81 | s = s_
82 |
83 | # 渲染最后的路径点
84 | yuanyang.bird_male_position = yuanyang.state_to_position(s)
85 | path.append(s)
86 | yuanyang.render()
87 | while True:
88 | yuanyang.render()
89 |
--------------------------------------------------------------------------------
/ch_1/sec_3/load_images.py:
--------------------------------------------------------------------------------
1 | import pygame
2 | import os.path as osp
3 |
4 | path_file = osp.abspath(__file__)
5 | path_images = osp.join(path_file, '../../..', 'ch_0/sec_2/images')
6 |
7 | def load_bird_male():
8 | obj = 'bird_male.png'
9 | obj_path = osp.join(path_images, obj)
10 | return pygame.image.load(obj_path)
11 |
12 | def load_bird_female():
13 | obj = 'bird_female.png'
14 | obj_path = osp.join(path_images, obj)
15 | return pygame.image.load(obj_path)
16 |
17 | def load_background():
18 | obj = 'background.jpg'
19 | obj_path = osp.join(path_images, obj)
20 | return pygame.image.load(obj_path)
21 |
22 | def load_obstacle():
23 | obj = 'obstacle.png'
24 | obj_path = osp.join(path_images, obj)
25 | return pygame.image.load(obj_path)
26 |
--------------------------------------------------------------------------------
/ch_1/sec_3/yuan_yang_env.py:
--------------------------------------------------------------------------------
1 | import pygame
2 | import random
3 | from load_images import *
4 | import numpy as np
5 |
6 | class YuanYangEnv:
7 | def __init__(self):
8 | self.states = []
9 | for i in range(0, 100):
10 | self.states.append(i)
11 | self.actions = ['e', 's', 'w', 'n']
12 | self.gamma = 0.8
13 | self.value = np.zeros((10, 10))
14 |
15 | self.viewer = None
16 | self.FPSCLOCK = pygame.time.Clock()
17 |
18 | self.screen_size = (1200, 900)
19 | self.bird_position = (0, 0)
20 | self.limit_distance_x = 120
21 | self.limit_distance_y = 90
22 | self.obstacle_size = [120, 90]
23 | self.obstacle1_x = []
24 | self.obstacle1_y = []
25 | self.obstacle2_x = []
26 | self.obstacle2_y = []
27 |
28 | for i in range(8):
29 | # obstacle 1
30 | self.obstacle1_x.append(360)
31 | if i <= 3:
32 | self.obstacle1_y.append(90 * i)
33 | else:
34 | self.obstacle1_y.append(90 * (i + 2))
35 | # obstacle 2
36 | self.obstacle2_x.append(720)
37 | if i <= 4:
38 | self.obstacle2_y.append(90 * i)
39 | else:
40 | self.obstacle2_y.append(90 * (i + 2))
41 |
42 | self.bird_male_init_position = [0.0, 0.0]
43 | self.bird_male_position = [0, 0]
44 | self.bird_female_init_position = [1080, 0]
45 |
46 | self.path = []
47 |
48 | def collide(self, state_position):
49 | flag = 1
50 | flag1 = 1
51 | flag2 = 1
52 |
53 | # obstacle 1
54 | dx = []
55 | dy = []
56 | for i in range(8):
57 | dx1 = abs(self.obstacle1_x[i] - state_position[0])
58 | dx.append(dx1)
59 | dy1 = abs(self.obstacle1_y[i] - state_position[1])
60 | dy.append(dy1)
61 | mindx = min(dx)
62 | mindy = min(dy)
63 | if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y:
64 | flag1 = 0
65 |
66 | # obstacle 2
67 | dx_second = []
68 | dy_second = []
69 | for i in range(8):
70 | dx1 = abs(self.obstacle2_x[i] - state_position[0])
71 | dx_second.append(dx1)
72 | dy1 = abs(self.obstacle2_y[i] - state_position[1])
73 | dy_second.append(dy1)
74 | mindx = min(dx_second)
75 | mindy = min(dy_second)
76 | if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y:
77 | flag2 = 0
78 |
79 | if flag1 == 0 and flag2 == 0:
80 | flag = 0
81 |
82 | # collide edge
83 | if state_position[0] > 1080 or \
84 | state_position[0] < 0 or \
85 | state_position[1] > 810 or \
86 | state_position[1] < 0:
87 | flag = 1
88 |
89 | return flag
90 |
91 | def find(self, state_position):
92 | flag = 0
93 | if abs(state_position[0] - self.bird_female_init_position[0]) < \
94 | self.limit_distance_x and \
95 | abs(state_position[1] - self.bird_female_init_position[1]) < \
96 | self.limit_distance_y:
97 | flag = 1
98 | return flag
99 |
100 | def state_to_position(self, state):
101 | i = int(state / 10)
102 | j = state % 10
103 | position = [0, 0]
104 | position[0] = 120 * j
105 | position[1] = 90 * i
106 | return position
107 |
108 | def position_to_state(self, position):
109 | i = position[0] / 120
110 | j = position[1] / 90
111 | return int(i + 10 * j)
112 |
113 | def reset(self):
114 | # 随机产生一个初始位置
115 | flag1 = 1
116 | flag2 = 1
117 | while flag1 or flag2 == 1:
118 | state = self.states[int(random.random() * len(self.states))]
119 | state_position = self.state_to_position(state)
120 | flag1 = self.collide(state_position)
121 | flag2 = self.find(state_position)
122 | return state
123 |
124 | def transform(self, state, action):
125 | current_position = self.state_to_position(state)
126 | next_position = [0, 0]
127 | flag_collide = 0
128 | flag_find = 0
129 |
130 | flag_collide = self.collide(current_position)
131 | flag_find = self.find(current_position)
132 | if flag_collide == 1:
133 | return state, -1, True
134 | if flag_find == 1:
135 | return state, 1, True
136 |
137 | if action == 'e':
138 | next_position[0] = current_position[0] + 120
139 | next_position[1] = current_position[1]
140 | if action == 's':
141 | next_position[0] = current_position[0]
142 | next_position[1] = current_position[1] + 90
143 | if action == 'w':
144 | next_position[0] = current_position[0] - 120
145 | next_position[1] = current_position[1]
146 | if action == 'n':
147 | next_position[0] = current_position[0]
148 | next_position[1] = current_position[1] - 90
149 |
150 | flag_collide = self.collide(next_position)
151 | if flag_collide == 1:
152 | return self.position_to_state(current_position), -1, True
153 |
154 | flag_find = self.find(next_position)
155 | if flag_find == 1:
156 | return self.position_to_state(next_position), 1, True
157 |
158 | return self.position_to_state(next_position), 0, False
159 |
160 | def gameover(self):
161 | for event in pygame.event.get():
162 | if event.type == pygame.QUIT:
163 | exit()
164 |
165 | def render(self):
166 | if self.viewer is None:
167 | pygame.init()
168 |
169 | self.viewer = pygame.display.set_mode(self.screen_size, 0, 32)
170 | pygame.display.set_caption("yuanyang")
171 | # load pic
172 | self.bird_male = load_bird_male()
173 | self.bird_female = load_bird_female()
174 | self.background = load_background()
175 | self.obstacle = load_obstacle()
176 |
177 | # self.viewer.blit(self.bird_female, self.bird_female_init_position)
178 | # self.viewer.blit(self.bird_male, self.bird_male_init_position)
179 |
180 | self.viewer.blit(self.background, (0, 0))
181 | self.font = pygame.font.SysFont('times', 15)
182 |
183 | self.viewer.blit(self.background, (0, 0))
184 | for i in range(11):
185 | pygame.draw.lines(self.viewer,
186 | (255, 255, 255),
187 | True,
188 | ((120 * i, 0), (120 * i, 900)),
189 | 1
190 | )
191 | pygame.draw.lines(self.viewer,
192 | (255, 255, 255),
193 | True,
194 | ((0, 90 * i), (1200, 90 * i)),
195 | 1
196 | )
197 |
198 | for i in range(8):
199 | self.viewer.blit(self.obstacle, (self.obstacle1_x[i], self.obstacle1_y[i]))
200 | self.viewer.blit(self.obstacle, (self.obstacle2_x[i], self.obstacle2_y[i]))
201 |
202 | self.viewer.blit(self.bird_female, self.bird_female_init_position)
203 | self.viewer.blit(self.bird_male, self.bird_male_init_position)
204 |
205 | for i in range(10):
206 | for j in range(10):
207 | surface = self.font.render(str(
208 | round(float(self.value[i, j]), 3)), True, (0, 0, 0)
209 | )
210 | self.viewer.blit(surface, (120 * i + 5, 90 * j + 70))
211 |
212 | # draw the path cells
213 | for i in range(len(self.path)):
214 | rec_position = self.state_to_position(self.path[i])
215 | pygame.draw.rect(self.viewer, [255, 0, 0], [rec_position[0], rec_position[1], 120, 90], 3)
216 | surface = self.font.render(str(i), True, (255, 0, 0))
217 | self.viewer.blit(surface, (rec_position[0] + 5, rec_position[1] + 5))
218 |
219 | pygame.display.update()
220 | self.gameover()
221 | self.FPSCLOCK.tick(30)
222 |
223 |
224 | if __name__ == "__main__":
225 | yy = YuanYangEnv()
226 | yy.render()
227 | while True:
228 | for event in pygame.event.get():
229 | if event.type == pygame.QUIT:
230 | exit()
231 |
232 |
--------------------------------------------------------------------------------
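A quick sanity-check driver for the grid environment above (a hypothetical sketch, not a file in the repo). It assumes the file is importable as yuan_yang_env and that, as in the sec_4/sec_5 variants further below, transform(state, action) can be called without opening a window; the male bird is pushed east from state 0 until it hits the first obstacle column.

# hypothetical usage sketch for the environment above
from yuan_yang_env import YuanYangEnv   # assumed module name

env = YuanYangEnv()
s, done = 0, False
while not done:
    s, r, done = env.transform(s, 'e')   # keep moving east
    print(s, r, done)                    # ends with reward -1 when the obstacle column is reached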
/ch_1/sec_4/load_images.py:
--------------------------------------------------------------------------------
1 | import pygame
2 | import os.path as osp
3 |
4 | path_file = osp.abspath(__file__)
5 | path_images = osp.join(path_file, '../../..', 'ch_0/sec_2/images')
6 |
7 | def load_bird_male():
8 | obj = 'bird_male.png'
9 | obj_path = osp.join(path_images, obj)
10 | return pygame.image.load(obj_path)
11 |
12 | def load_bird_female():
13 | obj = 'bird_female.png'
14 | obj_path = osp.join(path_images, obj)
15 | return pygame.image.load(obj_path)
16 |
17 | def load_background():
18 | obj = 'background.jpg'
19 | obj_path = osp.join(path_images, obj)
20 | return pygame.image.load(obj_path)
21 |
22 | def load_obstacle():
23 | obj = 'obstacle.png'
24 | obj_path = osp.join(path_images, obj)
25 | return pygame.image.load(obj_path)
26 |
--------------------------------------------------------------------------------
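Note on the path handling above: path_file is the file itself, not its directory, so the joined path contains the file name followed by three '..' components; the OS (and pygame.image.load) resolves this to the shared image folder under ch_0/sec_2. A small illustration with a made-up repo location:

import os.path as osp

# hypothetical absolute path of this load_images.py
path_file = '/repo/ch_1/sec_4/load_images.py'
path_images = osp.join(path_file, '../../..', 'ch_0/sec_2/images')
print(osp.normpath(path_images))   # /repo/ch_0/sec_2/images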
/ch_1/sec_4/mc_rl.py:
--------------------------------------------------------------------------------
1 | import pygame
2 | import time
3 | import random
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from yuan_yang_env_mc import YuanYangEnv
7 |
8 | class MC_RL:
9 | def __init__(self, yuanyang):
10 | # initialize the action-value function
11 | self.qvalue = np.zeros((len(yuanyang.states), len(yuanyang.actions)))
12 | # initialize the visit counts (small positive prior)
13 | self.n = 0.001 * np.ones(
14 | (len(yuanyang.states), len(yuanyang.actions))
15 | )
16 | self.actions = yuanyang.actions
17 | self.yuanyang = yuanyang
18 | self.gamma = yuanyang.gamma
19 |
20 | # greedy policy
21 | def greedy_policy(self, qfun, state):
22 | amax = qfun[state, :].argmax()
23 | return self.actions[amax]
24 |
25 | def epsilon_greedy_policy(self, qfun, state, epsilon):
26 | amax = qfun[state, :].argmax()
27 | # with probability 1 - epsilon ...
28 | if np.random.uniform() < 1 - epsilon:
29 | # ... take the greedy action
30 | return self.actions[amax]
31 | else:
32 | return self.actions[int(random.random() * len(self.actions))]
33 |
34 | # find the index of an action
35 | def find_anum(self, a):
36 | for i in range(len(self.actions)):
37 | if a == self.actions[i]:
38 | return i
39 |
40 | def mc_learning_ei(self, num_iter):
41 | self.qvalue = np.zeros((len(self.yuanyang.states), len(self.yuanyang.actions)))
42 | self.n = 0.001 * np.ones((len(self.yuanyang.states), len(self.yuanyang.actions)))
43 | # run num_iter learning episodes
44 | for iter1 in range(num_iter):
45 | # sampled states
46 | s_sample = []
47 | # sampled actions
48 | a_sample = []
49 | # sampled rewards
50 | r_sample = []
51 | # random initial state (exploring starts)
52 | s = self.yuanyang.reset()
53 | a = self.actions[int(random.random() * len(self.actions))]
54 | done = False
55 | step_num = 0
56 |
57 | if self.mc_test() == 1:
58 | print("exploring starts: iterations needed to first complete the task: {}".format(iter1))
59 | break
60 |
61 | # collect a trajectory s0-a0-s1-a1-...-terminal state
62 | while done == False and step_num < 30:
63 | # interact with the environment
64 | s_next, r, done = self.yuanyang.transform(s, a)
65 | a_num = self.find_anum(a)
66 | # penalize revisiting a state of this trajectory
67 | if s_next in s_sample:
68 | r = -2
69 | # store the sampled data
70 | s_sample.append(s)
71 | r_sample.append(r)
72 | a_sample.append(a_num)
73 | step_num += 1
74 | # move to the next state and continue the episode
75 | s = s_next
76 | a = self.greedy_policy(self.qvalue, s)
77 |
78 | # compute the discounted return from the samples: g(s_0) = r_0 + gamma * r_1 + ... + gamma^T * v(s_T)
79 | a = self.greedy_policy(self.qvalue, s)
80 | g = self.qvalue[s, self.find_anum(a)]
81 | for i in range(len(s_sample) - 1, -1, -1):
82 | g *= self.gamma
83 | g += r_sample[i]
84 | # g now holds G(s_0, a_0); peel it back state by state below
85 | for i in range(len(s_sample)):
86 | # count visits of the state-action pair (s, a)
87 | self.n[s_sample[i], a_sample[i]] += 1.0
88 | # incremental update of the action-value function
89 | self.qvalue[s_sample[i], a_sample[i]] = (self.qvalue[s_sample[i], a_sample[i]] * (self.n[s_sample[i], a_sample[i]] - 1) + g) / self.n[s_sample[i], a_sample[i]]
90 | g -= r_sample[i]
91 | g /= self.gamma
92 |
93 | return self.qvalue
94 |
95 | def mc_test(self):
96 | s = 0
97 | s_sample = []
98 | done = False
99 | flag = 0
100 | step_num = 0
101 | while done == False and step_num < 30:
102 | a = self.greedy_policy(self.qvalue, s)
103 | # interact with the environment
104 | s_next, r, done = self.yuanyang.transform(s, a)
105 | s_sample.append(s)
106 | s = s_next
107 | step_num += 1
108 | if s == 9:
109 | flag = 1
110 | return flag
111 |
112 | def mc_learning_on_policy(self, num_iter, epsilon):
113 | self.qvalue = np.zeros((len(self.yuanyang.states), len(self.yuanyang.actions)))
114 | self.n = 0.001 * np.ones((len(self.yuanyang.states), len(self.yuanyang.actions)))
115 | # run num_iter learning episodes
116 | for iter1 in range(num_iter):
117 | # sampled states
118 | s_sample = []
119 | # sampled actions
120 | a_sample = []
121 | # sampled rewards
122 | r_sample = []
123 | # fixed initial state
124 | s = 0
125 | done = False
126 | step_num = 0
127 | epsilon = epsilon * np.exp(-iter1 / 1000)
128 |
129 | # collect a trajectory s0-a0-s1-a1-...-terminal state
130 | while done == False and step_num < 30:
131 | a = self.epsilon_greedy_policy(self.qvalue, s, epsilon)
132 | # interact with the environment
133 | s_next, r, done = self.yuanyang.transform(s, a)
134 | a_num = self.find_anum(a)
135 | # penalize revisiting a state of this trajectory
136 | if s_next in s_sample:
137 | r = -2
138 | # store the sampled data
139 | s_sample.append(s)
140 | r_sample.append(r)
141 | a_sample.append(a_num)
142 | step_num += 1
143 | # move to the next state and continue the episode
144 | s = s_next
145 |
146 | if s == 9:
147 | print('on-policy MC: iterations needed to first complete the task: {}'.format(iter1))
148 | break
149 |
150 | # compute the discounted return from the samples: g(s_0) = r_0 + gamma * r_1 + gamma^2 * r_2 + ... + gamma^T * v(s_T)
151 | a = self.epsilon_greedy_policy(self.qvalue, s, epsilon)
152 | g = self.qvalue[s, self.find_anum(a)]
153 | # discounted return of the first state of this episode
154 | for i in range(len(s_sample) - 1, -1, -1):
155 | g *= self.gamma
156 | g += r_sample[i]
157 | # g now holds G(s_0, a_0); compute the returns of the other states below
158 | for i in range(len(s_sample)):
159 | # count visits of the state-action pair (s, a)
160 | self.n[s_sample[i], a_sample[i]] += 1.0
161 | # incremental update of the action-value function
162 | self.qvalue[s_sample[i], a_sample[i]] = (
163 | self.qvalue[s_sample[i], a_sample[i]] * (self.n[s_sample[i], a_sample[i]] - 1) + g
164 | ) / self.n[s_sample[i], a_sample[i]]
165 | g -= r_sample[i]
166 | g /= self.gamma
167 |
168 | return self.qvalue
169 |
170 |
171 | def mc_learning_ei():
172 | yuanyang = YuanYangEnv()
173 | brain = MC_RL(yuanyang)
174 | # exploring-starts method
175 | qvalue1 = brain.mc_learning_ei(num_iter=10000)
176 |
177 | # visualization
178 | flag = 1
179 | s = 0
180 | path = []
181 | # show the learned action values on the board
182 | yuanyang.action_value = qvalue1
183 | step_num = 0
184 |
185 | # trace out the greedy (optimal) path
186 | while flag:
187 | path.append(s)
188 | yuanyang.path = path
189 | a = brain.greedy_policy(qvalue1, s)
190 | print('%d->%s\t'%(s, a), qvalue1[s, 0], qvalue1[s, 1], qvalue1[s, 2], qvalue1[s, 3])
191 | yuanyang.bird_male_position = yuanyang.state_to_position(s)
192 | yuanyang.render()
193 | time.sleep(0.25)
194 | step_num += 1
195 | s_, r, t = yuanyang.transform(s, a)
196 | if t == True or step_num > 200:
197 | flag = 0
198 | s = s_
199 |
200 | # render the final path cell
201 | yuanyang.bird_male_position = yuanyang.state_to_position(s)
202 | path.append(s)
203 | yuanyang.render()
204 | while True:
205 | yuanyang.render()
206 |
207 | def mc_learning_epsilon():
208 | yuanyang = YuanYangEnv()
209 | brain = MC_RL(yuanyang)
210 | # on-policy (epsilon-greedy) method
211 | qvalue2 = brain.mc_learning_on_policy(num_iter=10000, epsilon=0.2)
212 |
213 | # visualization
214 | flag = 1
215 | s = 0
216 | path = []
217 | # show the learned action values on the board
218 | yuanyang.action_value = qvalue2
219 | step_num = 0
220 |
221 | # trace out the greedy (optimal) path
222 | while flag:
223 | path.append(s)
224 | yuanyang.path = path
225 | a = brain.greedy_policy(qvalue2, s)
226 | print('%d->%s\t'%(s, a), qvalue2[s, 0], qvalue2[s, 1], qvalue2[s, 2], qvalue2[s, 3])
227 | yuanyang.bird_male_position = yuanyang.state_to_position(s)
228 | yuanyang.render()
229 | time.sleep(0.25)
230 | step_num += 1
231 | s_, r, t = yuanyang.transform(s, a)
232 | if t == True or step_num > 200:
233 | flag = 0
234 | s = s_
235 |
236 | # render the final path cell
237 | yuanyang.bird_male_position = yuanyang.state_to_position(s)
238 | path.append(s)
239 | yuanyang.render()
240 | while True:
241 | yuanyang.render()
242 |
243 | if __name__ == "__main__":
244 | # mc_learning_ei()
245 |
246 | mc_learning_epsilon()
--------------------------------------------------------------------------------
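Both mc_learning_ei and mc_learning_on_policy update the table with the same incremental average, Q(s, a) <- (Q(s, a) * (N(s, a) - 1) + G) / N(s, a), while the return G is built backwards over the episode and then peeled off state by state. A self-contained sketch of that running-mean update with toy numbers (the names here are illustrative, not part of the repo):

import numpy as np

# Incremental Monte Carlo average, as used in mc_rl.py:
# q_new = (q_old * (n - 1) + g) / n  is the running mean of all returns seen so far.
q = np.zeros((3, 2))          # toy table: 3 states, 2 actions
n = 0.001 * np.ones((3, 2))   # small prior count, mirroring the code above

def mc_update(s, a, g):
    n[s, a] += 1.0
    q[s, a] = (q[s, a] * (n[s, a] - 1) + g) / n[s, a]

for g in [1.0, 0.0, 1.0, 1.0]:   # four sampled returns for the same state-action pair
    mc_update(0, 1, g)
print(q[0, 1])   # close to the plain mean of the sampled returns (~0.75 with the tiny prior)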
/ch_1/sec_4/yuan_yang_env_mc.py:
--------------------------------------------------------------------------------
1 | import pygame
2 | import random
3 | from load_images import *
4 | import numpy as np
5 |
6 | class YuanYangEnv:
7 | def __init__(self):
8 | self.states = []
9 | for i in range(0, 100):
10 | self.states.append(i)
11 | self.actions = ['e', 's', 'w', 'n']
12 | # Monte Carlo uses a larger gamma so long-term returns do not decay too fast
13 | self.gamma = 0.95
14 | self.action_value = np.zeros((100, 4))
15 |
16 | self.viewer = None
17 | self.FPSCLOCK = pygame.time.Clock()
18 |
19 | self.screen_size = (1200, 900)
20 | self.bird_position = (0, 0)
21 | self.limit_distance_x = 120
22 | self.limit_distance_y = 90
23 | self.obstacle_size = [120, 90]
24 | self.obstacle1_x = []
25 | self.obstacle1_y = []
26 | self.obstacle2_x = []
27 | self.obstacle2_y = []
28 |
29 | for i in range(8):
30 | # obstacle 1
31 | self.obstacle1_x.append(360)
32 | if i <= 3:
33 | self.obstacle1_y.append(90 * i)
34 | else:
35 | self.obstacle1_y.append(90 * (i + 2))
36 | # obstacle 2
37 | self.obstacle2_x.append(720)
38 | if i <= 4:
39 | self.obstacle2_y.append(90 * i)
40 | else:
41 | self.obstacle2_y.append(90 * (i + 2))
42 |
43 | self.bird_male_init_position = [0.0, 0.0]
44 | self.bird_male_position = [0, 0]
45 | self.bird_female_init_position = [1080, 0]
46 |
47 | self.path = []
48 |
49 | def collide(self, state_position):
50 | flag = 1
51 | flag1 = 1
52 | flag2 = 1
53 |
54 | # obstacle 1
55 | dx = []
56 | dy = []
57 | for i in range(8):
58 | dx1 = abs(self.obstacle1_x[i] - state_position[0])
59 | dx.append(dx1)
60 | dy1 = abs(self.obstacle1_y[i] - state_position[1])
61 | dy.append(dy1)
62 | mindx = min(dx)
63 | mindy = min(dy)
64 | if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y:
65 | flag1 = 0
66 |
67 | # obstacle 2
68 | dx_second = []
69 | dy_second = []
70 | for i in range(8):
71 | dx1 = abs(self.obstacle2_x[i] - state_position[0])
72 | dx_second.append(dx1)
73 | dy1 = abs(self.obstacle2_y[i] - state_position[1])
74 | dy_second.append(dy1)
75 | mindx = min(dx_second)
76 | mindy = min(dy_second)
77 | if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y:
78 | flag2 = 0
79 |
80 | if flag1 == 0 and flag2 == 0:
81 | flag = 0
82 |
83 | # collide edge
84 | if state_position[0] > 1080 or \
85 | state_position[0] < 0 or \
86 | state_position[1] > 810 or \
87 | state_position[1] < 0:
88 | flag = 1
89 |
90 | return flag
91 |
92 | def find(self, state_position):
93 | flag = 0
94 | if abs(state_position[0] - self.bird_female_init_position[0]) < \
95 | self.limit_distance_x and \
96 | abs(state_position[1] - self.bird_female_init_position[1]) < \
97 | self.limit_distance_y:
98 | flag = 1
99 | return flag
100 |
101 | def state_to_position(self, state):
102 | i = int(state / 10)
103 | j = state % 10
104 | position = [0, 0]
105 | position[0] = 120 * j
106 | position[1] = 90 * i
107 | return position
108 |
109 | def position_to_state(self, position):
110 | i = position[0] / 120
111 | j = position[1] / 90
112 | return int(i + 10 * j)
113 |
114 | def reset(self):
115 | # randomly generate a collision-free, non-terminal initial position
116 | flag1 = 1
117 | flag2 = 1
118 | while flag1 == 1 or flag2 == 1:
119 | state = self.states[int(random.random() * len(self.states))]
120 | state_position = self.state_to_position(state)
121 | flag1 = self.collide(state_position)
122 | flag2 = self.find(state_position)
123 | return state
124 |
125 | def transform(self, state, action):
126 | current_position = self.state_to_position(state)
127 | next_position = [0, 0]
128 | flag_collide = 0
129 | flag_find = 0
130 |
131 | flag_collide = self.collide(current_position)
132 | flag_find = self.find(current_position)
133 | if flag_collide == 1:
134 | return state, -10, True
135 | if flag_find == 1:
136 | return state, 10, True
137 |
138 | if action == 'e':
139 | next_position[0] = current_position[0] + 120
140 | next_position[1] = current_position[1]
141 | if action == 's':
142 | next_position[0] = current_position[0]
143 | next_position[1] = current_position[1] + 90
144 | if action == 'w':
145 | next_position[0] = current_position[0] - 120
146 | next_position[1] = current_position[1]
147 | if action == 'n':
148 | next_position[0] = current_position[0]
149 | next_position[1] = current_position[1] - 90
150 |
151 | flag_collide = self.collide(next_position)
152 | if flag_collide == 1:
153 | return self.position_to_state(current_position), -10, True
154 |
155 | flag_find = self.find(next_position)
156 | if flag_find == 1:
157 | return self.position_to_state(next_position), 10, True
158 |
159 | return self.position_to_state(next_position), -2, False
160 |
161 | def gameover(self):
162 | for event in pygame.event.get():
163 | if event.type == pygame.QUIT:
164 | exit()
165 |
166 | def render(self):
167 | if self.viewer is None:
168 | pygame.init()
169 |
170 | self.viewer = pygame.display.set_mode(self.screen_size, 0, 32)
171 | pygame.display.set_caption("yuanyang")
172 | # load pic
173 | self.bird_male = load_bird_male()
174 | self.bird_female = load_bird_female()
175 | self.background = load_background()
176 | self.obstacle = load_obstacle()
177 |
178 | # self.viewer.blit(self.bird_female, self.bird_female_init_position)
179 | # self.viewer.blit(self.bird_male, self.bird_male_init_position)
180 |
181 | self.viewer.blit(self.background, (0, 0))
182 | self.font = pygame.font.SysFont('times', 15)
183 |
184 | self.viewer.blit(self.background, (0, 0))
185 | for i in range(11):
186 | pygame.draw.lines(self.viewer,
187 | (255, 255, 255),
188 | True,
189 | ((120 * i, 0), (120 * i, 900)),
190 | 1
191 | )
192 | pygame.draw.lines(self.viewer,
193 | (255, 255, 255),
194 | True,
195 | ((0, 90 * i), (1200, 90 * i)),
196 | 1
197 | )
198 |
199 | for i in range(8):
200 | self.viewer.blit(self.obstacle, (self.obstacle1_x[i], self.obstacle1_y[i]))
201 | self.viewer.blit(self.obstacle, (self.obstacle2_x[i], self.obstacle2_y[i]))
202 |
203 | self.viewer.blit(self.bird_female, self.bird_female_init_position)
204 | self.viewer.blit(self.bird_male, self.bird_male_init_position)
205 |
206 | # draw the action-value function
207 | for i in range(100):
208 | y = int(i / 10)
209 | x = i % 10
210 | # value of moving east
211 | surface = self.font.render(str(round(float(self.action_value[i, 0]), 2)), True, (0, 0, 0))
212 | self.viewer.blit(surface, (120 * x + 80, 90 * y + 45))
213 | # value of moving south
214 | surface = self.font.render(str(round(float(self.action_value[i, 1]), 2)), True, (0, 0, 0))
215 | self.viewer.blit(surface, (120 * x + 50, 90 * y + 70))
216 | # value of moving west
217 | surface = self.font.render(str(round(float(self.action_value[i, 2]), 2)), True, (0, 0, 0))
218 | self.viewer.blit(surface, (120 * x + 10, 90 * y + 45))
219 | # value of moving north
220 | surface = self.font.render(str(round(float(self.action_value[i, 3]), 2)), True, (0, 0, 0))
221 | self.viewer.blit(surface, (120 * x + 50, 90 * y + 10))
222 |
223 | # draw the path cells
224 | for i in range(len(self.path)):
225 | rec_position = self.state_to_position(self.path[i])
226 | pygame.draw.rect(self.viewer, [255, 0, 0], [rec_position[0], rec_position[1], 120, 90], 3)
227 | surface = self.font.render(str(i), True, (255, 0, 0))
228 | self.viewer.blit(surface, (rec_position[0] + 5, rec_position[1] + 5))
229 |
230 | pygame.display.update()
231 | self.gameover()
232 | self.FPSCLOCK.tick(30)
233 |
234 |
235 | if __name__ == "__main__":
236 | yy = YuanYangEnv()
237 | yy.render()
238 | while True:
239 | for event in pygame.event.get():
240 | if event.type == pygame.QUIT:
241 | exit()
242 |
243 |
--------------------------------------------------------------------------------
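The environment encodes each 120x90 cell as a single integer: state = 10 * row + column, which is why reaching the female bird's cell shows up as s == 9 in mc_rl.py. A standalone round-trip check of that encoding (helper names here are hypothetical):

# Round-trip check of the state encoding used by the environment:
# state = 10 * row + col, position = (120 * col, 90 * row).
def state_to_position(state):
    row, col = divmod(state, 10)
    return [120 * col, 90 * row]

def position_to_state(position):
    col = position[0] // 120
    row = position[1] // 90
    return col + 10 * row

assert all(position_to_state(state_to_position(s)) == s for s in range(100))
print(state_to_position(9))   # [1080, 0] -- the female bird's cell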
/ch_1/sec_5/TD_RL.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import os
4 | import pygame
5 | import time
6 | import matplotlib.pyplot as plt
7 | from yuan_yang_env_td import YuanYangEnv
8 |
9 |
10 | class TD_RL:
11 | def __init__(self, yuanyang):
12 | self.gamma = yuanyang.gamma
13 | self.yuanyang = yuanyang
14 | # initial action-value function
15 | self.qvalue = np.zeros(
16 | (len(self.yuanyang.states), len(self.yuanyang.actions))
17 | )
18 |
19 | # greedy policy
20 | def greedy_policy(self, qfun, state):
21 | amax = qfun[state, :].argmax()
22 | return self.yuanyang.actions[amax]
23 |
24 | # epsilon-greedy policy
25 | def epsilon_greedy_policy(self, qfun, state, epsilon):
26 | amax = qfun[state, :].argmax()
27 | # with probability 1 - epsilon ...
28 | if np.random.uniform() < 1 - epsilon:
29 | # ... take the greedy action
30 | return self.yuanyang.actions[amax]
31 | else:
32 | return self.yuanyang.actions[
33 | int(random.random() * len(self.yuanyang.actions))
34 | ]
35 |
36 | # find the index of an action
37 | def find_anum(self, a):
38 | for i in range(len(self.yuanyang.actions)):
39 | if a == self.yuanyang.actions[i]:
40 | return i
41 |
42 | def sarsa(self, num_iter, alpha, epsilon):
43 | iter_num = []
44 | self.qvalue = np.zeros((len(self.yuanyang.states), len(self.yuanyang.actions)))
45 | # outer loop: one pass per episode
46 | for iter in range(num_iter):
47 | # decay epsilon and reset the episode
48 | epsilon = epsilon * 0.99
49 | s_sample = []
50 | # fixed initial state
51 | s = 0
52 | flag = self.greedy_test()
53 | if flag == 1:
54 | iter_num.append(iter)
55 | if len(iter_num) < 2:
56 | print('SARSA: iterations needed to first complete the task: {}'.format(iter_num[0]))
57 | if flag == 2:
58 | print('SARSA: iterations needed to first find the shortest path: {}'.format(iter))
59 | break
60 | # choose the initial action with the epsilon-greedy policy
61 | a = self.epsilon_greedy_policy(self.qvalue, s, epsilon)
62 | t = False
63 | count = 0
64 |
65 | # inner loop: one episode, s0-s1-s2-...-terminal
66 | while t == False and count < 30:
67 | # interact with the environment to get the next state
68 | s_next, r, t = self.yuanyang.transform(s, a)
69 | a_num = self.find_anum(a)
70 | # penalize states already visited in this trajectory
71 | if s_next in s_sample:
72 | r = -2
73 | s_sample.append(s)
74 | # check for a terminal state
75 | if t == True:
76 | q_target = r
77 | else:
78 | # next action sampled from the same epsilon-greedy behavior policy (on-policy)
79 | a1 = self.epsilon_greedy_policy(self.qvalue, s_next, epsilon)
80 | a1_num = self.find_anum(a1)
81 | # SARSA (TD(0)) update target
82 | q_target = r + self.gamma * self.qvalue[s_next, a1_num]
83 | # TD update of the action-value function with step size alpha
84 | self.qvalue[s, a_num] = self.qvalue[s, a_num] + alpha * (q_target - self.qvalue[s, a_num])
85 | # move to the next state
86 | s = s_next
87 | # behavior policy chooses the next action
88 | a = self.epsilon_greedy_policy(self.qvalue, s, epsilon)
89 | count += 1
90 |
91 | return self.qvalue
92 |
93 | def greedy_test(self):
94 | s = 0
95 | s_sample = []
96 | done = False
97 | flag = 0
98 | step_num = 0
99 | while done == False and step_num < 30:
100 | a = self.greedy_policy(self.qvalue, s)
101 | # interact with the environment
102 | s_next, r, done = self.yuanyang.transform(s, a)
103 | s_sample.append(s)
104 | s = s_next
105 | step_num += 1
106 |
107 | if s == 9:
108 | flag = 1
109 | if s == 9 and step_num < 21:
110 | flag = 2
111 |
112 | return flag
113 |
114 | def qlearning(self, num_iter, alpha, epsilon):
115 | iter_num = []
116 | self.qvalue = np.zeros((len(self.yuanyang.states), len(self.yuanyang.actions)))
117 | # outer loop: one pass per episode
118 | for iter in range(num_iter):
119 | # decay epsilon and reset the episode
120 | epsilon = epsilon * 0.99
121 | s_sample = []
122 | # fixed initial state
123 | s = 0
124 | flag = self.greedy_test()
125 | if flag == 1:
126 | iter_num.append(iter)
127 | if len(iter_num) < 2:
128 | print('q-learning: iterations needed to first complete the task: {}'.format(iter_num[0]))
129 | if flag == 2:
130 | print('q-learning: iterations needed to first find the shortest path: {}'.format(iter))
131 | break
132 | # choose the initial action with the epsilon-greedy policy
133 | a = self.epsilon_greedy_policy(self.qvalue, s, epsilon)
134 | t = False
135 | count = 0
136 |
137 | # inner loop: one episode, s0-s1-s2-...-terminal
138 | while t == False and count < 30:
139 | # interact with the environment to get the next state
140 | s_next, r, t = self.yuanyang.transform(s, a)
141 | a_num = self.find_anum(a)
142 | # penalize states already visited in this trajectory
143 | if s_next in s_sample:
144 | r = -2
145 | s_sample.append(s)
146 | # check for a terminal state
147 | if t == True:
148 | q_target = r
149 | else:
150 | # greedy action at the next state (off-policy target)
151 | a1 = self.greedy_policy(self.qvalue, s_next)
152 | a1_num = self.find_anum(a1)
153 | # Q-learning TD(0) update target
154 | q_target = r + self.gamma * self.qvalue[s_next, a1_num]
155 | # TD update of the action-value function with step size alpha
156 | self.qvalue[s, a_num] = self.qvalue[s, a_num] + alpha * (q_target - self.qvalue[s, a_num])
157 | # move to the next state
158 | s = s_next
159 | # behavior policy chooses the next action
160 | a = self.epsilon_greedy_policy(self.qvalue, s, epsilon)
161 | count += 1
162 |
163 | return self.qvalue
164 |
165 | def sarsa():
166 | yuanyang = YuanYangEnv()
167 | brain = TD_RL(yuanyang)
168 | qvalue1 = brain.sarsa(num_iter=5000, alpha=0.1, epsilon=0.8)
169 |
170 | # visualization
171 | flag = 1
172 | s = 0
173 | path = []
174 | # show the learned action values on the board
175 | yuanyang.action_value = qvalue1
176 | step_num = 0
177 |
178 | # trace out the greedy (optimal) path
179 | while flag:
180 | path.append(s)
181 | yuanyang.path = path
182 | a = brain.greedy_policy(qvalue1, s)
183 | print('%d->%s\t'%(s, a), qvalue1[s, 0], qvalue1[s, 1], qvalue1[s, 2], qvalue1[s, 3])
184 | yuanyang.bird_male_position = yuanyang.state_to_position(s)
185 | yuanyang.render()
186 | time.sleep(0.1)
187 | step_num += 1
188 | s_, r, t = yuanyang.transform(s, a)
189 | if t == True or step_num > 200:
190 | flag = 0
191 | s = s_
192 |
193 | # render the final path cell
194 | yuanyang.bird_male_position = yuanyang.state_to_position(s)
195 | path.append(s)
196 | yuanyang.render()
197 | while True:
198 | yuanyang.render()
199 |
200 | def qlearning():
201 | yuanyang = YuanYangEnv()
202 | brain = TD_RL(yuanyang)
203 | qvalue2 = brain.qlearning(num_iter=5000, alpha=0.1, epsilon=0.1)
204 |
205 | # visualization
206 | flag = 1
207 | s = 0
208 | path = []
209 | # show the learned action values on the board
210 | yuanyang.action_value = qvalue2
211 | step_num = 0
212 |
213 | # trace out the greedy (optimal) path
214 | while flag:
215 | path.append(s)
216 | yuanyang.path = path
217 | a = brain.greedy_policy(qvalue2, s)
218 | print('%d->%s\t'%(s, a), qvalue2[s, 0], qvalue2[s, 1], qvalue2[s, 2], qvalue2[s, 3])
219 | yuanyang.bird_male_position = yuanyang.state_to_position(s)
220 | yuanyang.render()
221 | time.sleep(0.1)
222 | step_num += 1
223 | s_, r, t = yuanyang.transform(s, a)
224 | if t == True or step_num > 200:
225 | flag = 0
226 | s = s_
227 |
228 | # render the final path cell
229 | yuanyang.bird_male_position = yuanyang.state_to_position(s)
230 | path.append(s)
231 | yuanyang.render()
232 | while True:
233 | yuanyang.render()
234 |
235 |
236 | if __name__ == "__main__":
237 | # sarsa()
238 | qlearning()
--------------------------------------------------------------------------------
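The sarsa and qlearning methods above differ only in how the bootstrap action at s_next is picked: SARSA samples it from the same epsilon-greedy behavior policy (on-policy), while Q-learning takes the greedy action (off-policy). A toy side-by-side sketch of the two targets (all numbers made up):

import numpy as np

gamma = 0.95
q_next = np.array([0.2, 0.8, 0.5, 0.1])   # toy Q-values at s_next for ['e', 's', 'w', 'n']
r = -2

# Q-learning target: bootstrap on the greedy action at s_next.
target_qlearning = r + gamma * np.max(q_next)

# SARSA target: bootstrap on the action the epsilon-greedy behavior policy actually picks.
epsilon = 0.2
if np.random.uniform() < 1 - epsilon:
    a1 = int(np.argmax(q_next))
else:
    a1 = np.random.randint(len(q_next))
target_sarsa = r + gamma * q_next[a1]

print(target_qlearning, target_sarsa)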
/ch_1/sec_5/load_images.py:
--------------------------------------------------------------------------------
1 | import pygame
2 | import os.path as osp
3 |
4 | path_file = osp.abspath(__file__)
5 | path_images = osp.join(path_file, '../../..', 'ch_0/sec_2/images')
6 |
7 | def load_bird_male():
8 | obj = 'bird_male.png'
9 | obj_path = osp.join(path_images, obj)
10 | return pygame.image.load(obj_path)
11 |
12 | def load_bird_female():
13 | obj = 'bird_female.png'
14 | obj_path = osp.join(path_images, obj)
15 | return pygame.image.load(obj_path)
16 |
17 | def load_background():
18 | obj = 'background.jpg'
19 | obj_path = osp.join(path_images, obj)
20 | return pygame.image.load(obj_path)
21 |
22 | def load_obstacle():
23 | obj = 'obstacle.png'
24 | obj_path = osp.join(path_images, obj)
25 | return pygame.image.load(obj_path)
26 |
--------------------------------------------------------------------------------
/ch_1/sec_5/yuan_yang_env_td.py:
--------------------------------------------------------------------------------
1 | import pygame
2 | import random
3 | from load_images import *
4 | import numpy as np
5 |
6 | class YuanYangEnv:
7 | def __init__(self):
8 | self.states = []
9 | for i in range(0, 100):
10 | self.states.append(i)
11 | self.actions = ['e', 's', 'w', 'n']
12 | # keep gamma relatively large so long-term returns do not decay too fast
13 | self.gamma = 0.95
14 | self.action_value = np.zeros((100, 4))
15 |
16 | self.viewer = None
17 | self.FPSCLOCK = pygame.time.Clock()
18 |
19 | self.screen_size = (1200, 900)
20 | self.bird_position = (0, 0)
21 | self.limit_distance_x = 120
22 | self.limit_distance_y = 90
23 | self.obstacle_size = [120, 90]
24 | self.obstacle1_x = []
25 | self.obstacle1_y = []
26 | self.obstacle2_x = []
27 | self.obstacle2_y = []
28 |
29 | for i in range(8):
30 | # obstacle 1
31 | self.obstacle1_x.append(360)
32 | if i <= 3:
33 | self.obstacle1_y.append(90 * i)
34 | else:
35 | self.obstacle1_y.append(90 * (i + 2))
36 | # obstacle 2
37 | self.obstacle2_x.append(720)
38 | if i <= 4:
39 | self.obstacle2_y.append(90 * i)
40 | else:
41 | self.obstacle2_y.append(90 * (i + 2))
42 |
43 | self.bird_male_init_position = [0.0, 0.0]
44 | self.bird_male_position = [0, 0]
45 | self.bird_female_init_position = [1080, 0]
46 |
47 | self.path = []
48 |
49 | def collide(self, state_position):
50 | flag = 1
51 | flag1 = 1
52 | flag2 = 1
53 |
54 | # obstacle 1
55 | dx = []
56 | dy = []
57 | for i in range(8):
58 | dx1 = abs(self.obstacle1_x[i] - state_position[0])
59 | dx.append(dx1)
60 | dy1 = abs(self.obstacle1_y[i] - state_position[1])
61 | dy.append(dy1)
62 | mindx = min(dx)
63 | mindy = min(dy)
64 | if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y:
65 | flag1 = 0
66 |
67 | # obstacle 2
68 | dx_second = []
69 | dy_second = []
70 | for i in range(8):
71 | dx1 = abs(self.obstacle2_x[i] - state_position[0])
72 | dx_second.append(dx1)
73 | dy1 = abs(self.obstacle2_y[i] - state_position[1])
74 | dy_second.append(dy1)
75 | mindx = min(dx_second)
76 | mindy = min(dy_second)
77 | if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y:
78 | flag2 = 0
79 |
80 | if flag1 == 0 and flag2 == 0:
81 | flag = 0
82 |
83 | # collide edge
84 | if state_position[0] > 1080 or \
85 | state_position[0] < 0 or \
86 | state_position[1] > 810 or \
87 | state_position[1] < 0:
88 | flag = 1
89 |
90 | return flag
91 |
92 | def find(self, state_position):
93 | flag = 0
94 | if abs(state_position[0] - self.bird_female_init_position[0]) < \
95 | self.limit_distance_x and \
96 | abs(state_position[1] - self.bird_female_init_position[1]) < \
97 | self.limit_distance_y:
98 | flag = 1
99 | return flag
100 |
101 | def state_to_position(self, state):
102 | i = int(state / 10)
103 | j = state % 10
104 | position = [0, 0]
105 | position[0] = 120 * j
106 | position[1] = 90 * i
107 | return position
108 |
109 | def position_to_state(self, position):
110 | i = position[0] / 120
111 | j = position[1] / 90
112 | return int(i + 10 * j)
113 |
114 | def reset(self):
115 | # randomly generate a collision-free, non-terminal initial position
116 | flag1 = 1
117 | flag2 = 1
118 | while flag1 == 1 or flag2 == 1:
119 | state = self.states[int(random.random() * len(self.states))]
120 | state_position = self.state_to_position(state)
121 | flag1 = self.collide(state_position)
122 | flag2 = self.find(state_position)
123 | return state
124 |
125 | def transform(self, state, action):
126 | current_position = self.state_to_position(state)
127 | next_position = [0, 0]
128 | flag_collide = 0
129 | flag_find = 0
130 |
131 | flag_collide = self.collide(current_position)
132 | flag_find = self.find(current_position)
133 | if flag_collide == 1:
134 | return state, -10, True
135 | if flag_find == 1:
136 | return state, 10, True
137 |
138 | if action == 'e':
139 | next_position[0] = current_position[0] + 120
140 | next_position[1] = current_position[1]
141 | if action == 's':
142 | next_position[0] = current_position[0]
143 | next_position[1] = current_position[1] + 90
144 | if action == 'w':
145 | next_position[0] = current_position[0] - 120
146 | next_position[1] = current_position[1]
147 | if action == 'n':
148 | next_position[0] = current_position[0]
149 | next_position[1] = current_position[1] - 90
150 |
151 | flag_collide = self.collide(next_position)
152 | if flag_collide == 1:
153 | return self.position_to_state(current_position), -1000, True
154 |
155 | flag_find = self.find(next_position)
156 | if flag_find == 1:
157 | return self.position_to_state(next_position), 10, True
158 |
159 | return self.position_to_state(next_position), -2, False
160 |
161 | def gameover(self):
162 | for event in pygame.event.get():
163 | if event.type == pygame.QUIT:
164 | exit()
165 |
166 | def render(self):
167 | if self.viewer is None:
168 | pygame.init()
169 |
170 | self.viewer = pygame.display.set_mode(self.screen_size, 0, 32)
171 | pygame.display.set_caption("yuanyang")
172 | # load pic
173 | self.bird_male = load_bird_male()
174 | self.bird_female = load_bird_female()
175 | self.background = load_background()
176 | self.obstacle = load_obstacle()
177 |
178 | # self.viewer.blit(self.bird_female, self.bird_female_init_position)
179 | # self.viewer.blit(self.bird_male, self.bird_male_init_position)
180 |
181 | self.viewer.blit(self.background, (0, 0))
182 | self.font = pygame.font.SysFont('times', 15)
183 |
184 | self.viewer.blit(self.background, (0, 0))
185 | for i in range(11):
186 | pygame.draw.lines(self.viewer,
187 | (255, 255, 255),
188 | True,
189 | ((120 * i, 0), (120 * i, 900)),
190 | 1
191 | )
192 | pygame.draw.lines(self.viewer,
193 | (255, 255, 255),
194 | True,
195 | ((0, 90 * i), (1200, 90 * i)),
196 | 1
197 | )
198 |
199 | for i in range(8):
200 | self.viewer.blit(self.obstacle, (self.obstacle1_x[i], self.obstacle1_y[i]))
201 | self.viewer.blit(self.obstacle, (self.obstacle2_x[i], self.obstacle2_y[i]))
202 |
203 | self.viewer.blit(self.bird_female, self.bird_female_init_position)
204 | self.viewer.blit(self.bird_male, self.bird_male_init_position)
205 |
206 | # draw the action-value function
207 | for i in range(100):
208 | y = int(i / 10)
209 | x = i % 10
210 | # value of moving east
211 | surface = self.font.render(str(round(float(self.action_value[i, 0]), 2)), True, (0, 0, 0))
212 | self.viewer.blit(surface, (120 * x + 80, 90 * y + 45))
213 | # value of moving south
214 | surface = self.font.render(str(round(float(self.action_value[i, 1]), 2)), True, (0, 0, 0))
215 | self.viewer.blit(surface, (120 * x + 50, 90 * y + 70))
216 | # value of moving west
217 | surface = self.font.render(str(round(float(self.action_value[i, 2]), 2)), True, (0, 0, 0))
218 | self.viewer.blit(surface, (120 * x + 10, 90 * y + 45))
219 | # value of moving north
220 | surface = self.font.render(str(round(float(self.action_value[i, 3]), 2)), True, (0, 0, 0))
221 | self.viewer.blit(surface, (120 * x + 50, 90 * y + 10))
222 |
223 | # draw the path cells
224 | for i in range(len(self.path)):
225 | rec_position = self.state_to_position(self.path[i])
226 | pygame.draw.rect(self.viewer, [255, 0, 0], [rec_position[0], rec_position[1], 120, 90], 3)
227 | surface = self.font.render(str(i), True, (255, 0, 0))
228 | self.viewer.blit(surface, (rec_position[0] + 5, rec_position[1] + 5))
229 |
230 | pygame.display.update()
231 | self.gameover()
232 | self.FPSCLOCK.tick(30)
233 |
234 |
235 | if __name__ == "__main__":
236 | yy = YuanYangEnv()
237 | yy.render()
238 | while True:
239 | for event in pygame.event.get():
240 | if event.type == pygame.QUIT:
241 | exit()
242 |
243 |
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/audio/die.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/audio/die.ogg
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/audio/die.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/audio/die.wav
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/audio/hit.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/audio/hit.ogg
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/audio/hit.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/audio/hit.wav
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/audio/point.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/audio/point.ogg
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/audio/point.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/audio/point.wav
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/audio/swoosh.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/audio/swoosh.ogg
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/audio/swoosh.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/audio/swoosh.wav
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/audio/wing.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/audio/wing.ogg
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/audio/wing.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/audio/wing.wav
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/sprites/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/0.png
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/sprites/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/1.png
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/sprites/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/2.png
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/sprites/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/3.png
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/sprites/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/4.png
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/sprites/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/5.png
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/sprites/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/6.png
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/sprites/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/7.png
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/sprites/8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/8.png
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/sprites/9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/9.png
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/sprites/background-black.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/background-black.png
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/sprites/base.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/base.png
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/sprites/pipe-green.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/pipe-green.png
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/sprites/redbird-downflap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/redbird-downflap.png
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/sprites/redbird-midflap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/redbird-midflap.png
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/assets/sprites/redbird-upflap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/redbird-upflap.png
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/dqn_agent.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow import keras
3 | import tensorflow.keras.layers as kl
4 | import tensorflow.keras.losses as kls
5 | import tensorflow.keras.optimizers as ko
6 | import numpy as np
7 | import cv2
8 | import sys
9 | import os.path as osp
10 | dirname = osp.dirname(__file__)
11 | sys.path.append(dirname)
12 | from game.wrapped_flappy_bird import GameState
13 | import random
14 |
15 | GAME = 'flappy bird'
16 | ACTIONS = 2
17 | GAMMA = 0.99
18 | OBSERVE = 250. # steps of pure observation before training starts
19 | EXPLORE = 3.0e6 # frames over which epsilon is annealed
20 | FINAL_EPSILON = 1.0e-4 # final exploration rate
21 | INITIAL_EPSILON = 0.1 # initial exploration rate
22 | REPLAY_MEMORY = 50000 # replay memory size
23 | BATCH = 32 # mini-batch size
24 | FRAME_PER_ACTION = 1 # frame skip
25 |
26 | class Experience_Buffer:
27 | def __init__(self, buffer_size=REPLAY_MEMORY):
28 | super().__init__()
29 | self.buffer = []
30 | self.buffer_size = buffer_size
31 |
32 | def add_experience(self, experience):
33 | if len(self.buffer) + len(experience) >= self.buffer_size:
34 | self.buffer[0:len(self.buffer) + len(experience) - self.buffer_size] = []
35 | self.buffer.extend(experience)
36 |
37 | def sample(self, samples_num):
38 | samples_data = random.sample(self.buffer, samples_num)
39 | train_s = [d[0] for d in samples_data]
40 | train_s = np.asarray(train_s)
41 | train_a = [d[1] for d in samples_data]
42 | train_a = np.asarray(train_a)
43 | train_r = [d[2] for d in samples_data]
44 | train_r = np.asarray(train_r)
45 | train_s_ = [d[3] for d in samples_data]
46 | train_s_ = np.asarray(train_s_)
47 | train_terminal = [d[4] for d in samples_data]
48 | train_terminal = np.asarray(train_terminal)
49 | return train_s, train_a, train_r, train_s_, train_terminal
50 |
51 | class Deep_Q_N:
52 | def __init__(self, lr=1.0e-6, model_file=None):
53 | self.gamma = GAMMA
54 | self.tau = 0.01
55 | # tf
56 | self.learning_rate = lr
57 | self.q_model = self.build_q_net()
58 | self.q_target_model = self.build_q_net()
59 | if model_file is not None:
60 | self.restore_model(model_file)
61 |
62 | def save_model(self, model_path):
63 | self.q_model.save(model_path + '0')
64 | self.q_target_model.save(model_path + '1')
65 |
66 | def restore_model(self, model_path):
67 | self.q_model.load_weights(model_path + '0')
68 | self.q_target_model.load_weights(model_path + '1')
69 |
70 | def build_q_net(self):
71 | model = keras.Sequential()
72 | h_conv1 = kl.Conv2D(
73 | input_shape=(80, 80, 4),
74 | filters=32, kernel_size=8,
75 | data_format='channels_last',
76 | strides=4, padding='same',
77 | activation='relu'
78 | )
79 | h_pool1 = kl.MaxPool2D(
80 | pool_size=2, strides=2, padding='same'
81 | )
82 | h_conv2 = kl.Conv2D(
83 | filters=64, kernel_size=4,
84 | strides=2, padding='same',
85 | activation='relu'
86 | )
87 | h_conv3 = kl.Conv2D(
88 | filters=64, kernel_size=3,
89 | strides=1, padding='same',
90 | activation='relu'
91 | )
92 | h_conv3_flat = kl.Flatten()
93 | h_fc1 = kl.Dense(512, activation='relu')
94 | qout = kl.Dense(ACTIONS)
95 | model.add(h_conv1)
96 | model.add(h_pool1)
97 | model.add(h_conv2)
98 | model.add(h_conv3)
99 | model.add(h_conv3_flat)
100 | model.add(h_fc1)
101 | model.add(qout)
102 |
103 | model.compile(
104 | optimizer=ko.Adam(lr=self.learning_rate),
105 | loss=[self._get_mse_for_action]
106 | )
107 |
108 | return model
109 |
110 | def _get_mse_for_action(self, target_and_action, current_prediction):
111 | targets, one_hot_action = tf.split(target_and_action, [1, 2], axis=1)
112 | active_q_value = tf.expand_dims(tf.reduce_sum(current_prediction * one_hot_action, axis=1), axis=-1)
113 | return kls.mean_squared_error(targets, active_q_value)
114 |
115 | def _update_target(self):
116 | q_weights = self.q_model.get_weights()
117 | q_target_weights = self.q_target_model.get_weights()
118 |
119 | q_weights = [self.tau * w for w in q_weights]
120 | q_target_weights = [(1. - self.tau) * w for w in q_target_weights]
121 | new_weights = [
122 | q_weights[i] + q_target_weights[i]
123 | for i in range(len(q_weights))
124 | ]
125 | self.q_target_model.set_weights(new_weights)
126 |
127 | def _one_hot_action(self, actions):
128 | action_index = np.array(actions)
129 | batch_size = len(actions)
130 | result = np.zeros((batch_size, 2))
131 | result[np.arange(batch_size), action_index] = 1.
132 | return result
133 |
134 | def epsilon_greedy(self, s_t, epsilon):
135 | s_t = s_t.reshape(-1, 80, 80, 4)
136 | amax = np.argmax(self.q_model.predict(s_t)[0])
137 | # with probability 1 - epsilon ...
138 | if np.random.uniform() < 1 - epsilon:
139 | # ... take the greedy action
140 | a_t = amax
141 | else:
142 | a_t = random.randint(0, 1)
143 | return a_t
144 |
145 | def train_Network(self, experience_buffer):
146 | # open a game state to communicate with the emulator
147 | game_state = GameState(fps=100)
148 | # get the first state and preprocess the image
149 | do_nothing = 0
150 | # take one step in the game
151 | x_t, r_0, terminal = game_state.frame_step(do_nothing)
152 | x_t = cv2.cvtColor(cv2.resize(x_t, (80, 80)), cv2.COLOR_BGR2GRAY)
153 | ret, x_t = cv2.threshold(x_t, 1, 255, cv2.THRESH_BINARY)
154 | s_t = np.stack((x_t, x_t, x_t, x_t), axis=2)
155 | # start training
156 | epsilon = INITIAL_EPSILON
157 | t = 0
158 | while "flappy bird" != "angry bird":
159 | a_t = self.epsilon_greedy(s_t, epsilon)
160 | # anneal epsilon
161 | if epsilon > FINAL_EPSILON and t > OBSERVE:
162 | epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE
163 | # apply the action and step the environment once
164 | x_t1_colored, r_t, terminal = game_state.frame_step(a_t)
165 | x_t1 = cv2.cvtColor(cv2.resize(x_t1_colored, (80, 80)), cv2.COLOR_BGR2GRAY)
166 | ret, x_t1 = cv2.threshold(x_t1, 1, 255, cv2.THRESH_BINARY)
167 | x_t1 = np.reshape(x_t1, (80, 80, 1))
168 | s_t1 = np.append(x_t1, s_t[:, :, :3], axis=2)
169 | # store the transition in the replay buffer (dtype=object because the entries are ragged)
170 | experience = np.reshape(np.array([s_t, a_t, r_t, s_t1, terminal], dtype=object), [1, 5])
171 | experience_buffer.add_experience(experience)
172 | # start training once the observation phase is over
173 | if t > OBSERVE:
174 | # sample a mini-batch from the replay buffer
175 | train_s, train_a, train_r, train_s_, train_terminal = experience_buffer.sample(BATCH)
176 | target_q = []
177 | read_target_Q = self.q_target_model.predict(train_s_)
178 | for i in range(len(train_r)):
179 | if train_terminal[i]:
180 | target_q.append(train_r[i])
181 | else:
182 | target_q.append(train_r[i] + GAMMA * np.max(read_target_Q[i]))
183 | # one training step
184 | one_hot_actions = self._one_hot_action(train_a)
185 | target_q = np.asarray(target_q)
186 | target_and_actions = np.concatenate((target_q[:, None], one_hot_actions), axis=1)
187 | loss = self.q_model.train_on_batch(train_s, target_and_actions)
188 | # soft-update the target network
189 | self._update_target()
190 | # advance one step
191 | s_t = s_t1
192 | t += 1
193 | # save the model every 10000 iterations
194 | if t % 10000 == 0:
195 | dirname = osp.dirname(__file__)
196 | self.save_model(osp.join(dirname, 'saved_networks'))
197 | if t <= OBSERVE:
198 | print("OBSERVE", t)
199 | else:
200 | if t % 1 == 0:
201 | print("train, steps", t, "/epsilon", epsilon, "/action_index", a_t, "/reward", r_t)
202 |
203 |
204 | if __name__=="__main__":
205 | buffer = Experience_Buffer()
206 | brain = Deep_Q_N()
207 | brain.train_Network(buffer)
208 |
209 |
--------------------------------------------------------------------------------
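In dqn_agent.py the per-sample loss only regresses the Q-value of the action that was actually taken: _get_mse_for_action splits the packed [target, one-hot action] tensor and masks the network output with the one-hot vector. A small standalone sketch of that masking with dummy tensors (assumes TensorFlow 2.x; all values are made up):

import tensorflow as tf

# Dummy batch of 2, packed the same way as in dqn_agent.py:
# column 0 is the TD target, columns 1-2 are the one-hot action.
target_and_action = tf.constant([[1.5, 1.0, 0.0],
                                 [0.3, 0.0, 1.0]], dtype=tf.float32)
current_prediction = tf.constant([[1.2, 0.7],
                                  [0.9, 0.4]], dtype=tf.float32)   # Q(s, .) for both actions

targets, one_hot_action = tf.split(target_and_action, [1, 2], axis=1)
active_q = tf.expand_dims(tf.reduce_sum(current_prediction * one_hot_action, axis=1), axis=-1)
loss = tf.keras.losses.mean_squared_error(targets, active_q)
print(active_q.numpy().ravel())   # [1.2, 0.4] -- only the taken action's Q-value
print(loss.numpy())               # per-sample squared errors against the targets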
/ch_1/sec_6/flappy_bird/game/flappy_bird_utils.py:
--------------------------------------------------------------------------------
1 | import pygame
2 | import sys
3 | import os.path as osp
4 |
5 | path = osp.dirname(__file__)
6 | path = osp.join(path, '..')
7 |
8 | def load():
9 | # path of player with different states
10 | PLAYER_PATH = (
11 | path + '/assets/sprites/redbird-upflap.png',
12 | path + '/assets/sprites/redbird-midflap.png',
13 | path + '/assets/sprites/redbird-downflap.png'
14 | )
15 |
16 | # path of background
17 | BACKGROUND_PATH = path + '/assets/sprites/background-black.png'
18 |
19 | # path of pipe
20 | PIPE_PATH = path + '/assets/sprites/pipe-green.png'
21 |
22 | IMAGES, SOUNDS, HITMASKS = {}, {}, {}
23 |
24 | # numbers sprites for score display
25 | IMAGES['numbers'] = (
26 | pygame.image.load(path + '/assets/sprites/0.png').convert_alpha(),
27 | pygame.image.load(path + '/assets/sprites/1.png').convert_alpha(),
28 | pygame.image.load(path + '/assets/sprites/2.png').convert_alpha(),
29 | pygame.image.load(path + '/assets/sprites/3.png').convert_alpha(),
30 | pygame.image.load(path + '/assets/sprites/4.png').convert_alpha(),
31 | pygame.image.load(path + '/assets/sprites/5.png').convert_alpha(),
32 | pygame.image.load(path + '/assets/sprites/6.png').convert_alpha(),
33 | pygame.image.load(path + '/assets/sprites/7.png').convert_alpha(),
34 | pygame.image.load(path + '/assets/sprites/8.png').convert_alpha(),
35 | pygame.image.load(path + '/assets/sprites/9.png').convert_alpha()
36 | )
37 |
38 | # base (ground) sprite
39 | IMAGES['base'] = pygame.image.load(path + '/assets/sprites/base.png').convert_alpha()
40 |
41 | # sounds
42 | if sys.platform.startswith('win'):  # note: 'win' in sys.platform would also match 'darwin'
43 | soundExt = '.wav'
44 | else:
45 | soundExt = '.ogg'
46 |
47 | SOUNDS['die'] = pygame.mixer.Sound(path + '/assets/audio/die' + soundExt)
48 | SOUNDS['hit'] = pygame.mixer.Sound(path + '/assets/audio/hit' + soundExt)
49 | SOUNDS['point'] = pygame.mixer.Sound(path + '/assets/audio/point' + soundExt)
50 | SOUNDS['swoosh'] = pygame.mixer.Sound(path + '/assets/audio/swoosh' + soundExt)
51 | SOUNDS['wing'] = pygame.mixer.Sound(path + '/assets/audio/wing' + soundExt)
52 |
53 | # background sprite
54 | IMAGES['background'] = pygame.image.load(BACKGROUND_PATH).convert()
55 |
56 | # player sprites (up / mid / down flap)
57 | IMAGES['player'] = (
58 | pygame.image.load(PLAYER_PATH[0]).convert_alpha(),
59 | pygame.image.load(PLAYER_PATH[1]).convert_alpha(),
60 | pygame.image.load(PLAYER_PATH[2]).convert_alpha(),
61 | )
62 |
63 | # pipe sprites (the upper pipe is the lower one rotated 180 degrees)
64 | IMAGES['pipe'] = (
65 | pygame.transform.rotate(
66 | pygame.image.load(PIPE_PATH).convert_alpha(), 180),
67 | pygame.image.load(PIPE_PATH).convert_alpha(),
68 | )
69 |
70 | # hitmask for pipes
71 | HITMASKS['pipe'] = (
72 | getHitmask(IMAGES['pipe'][0]),
73 | getHitmask(IMAGES['pipe'][1]),
74 | )
75 |
76 | # hitmask for player
77 | HITMASKS['player'] = (
78 | getHitmask(IMAGES['player'][0]),
79 | getHitmask(IMAGES['player'][1]),
80 | getHitmask(IMAGES['player'][2]),
81 | )
82 |
83 | return IMAGES, SOUNDS, HITMASKS
84 |
85 | def getHitmask(image):
86 | """returns a hitmask using an image's alpha."""
87 | mask = []
88 | for x in range(image.get_width()):
89 | mask.append([])
90 | for y in range(image.get_height()):
91 | mask[x].append(bool(image.get_at((x,y))[3]))
92 | return mask
93 |
--------------------------------------------------------------------------------
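getHitmask above treats a pixel as solid whenever its alpha channel is non-zero, producing a column-major list of booleans indexed as mask[x][y]. A tiny synthetic example of the same logic (hypothetical, not part of the repo):

import pygame

pygame.init()
# A 2x2 RGBA surface with one fully transparent pixel.
surf = pygame.Surface((2, 2), pygame.SRCALPHA, 32)
surf.fill((255, 0, 0, 255))          # opaque red everywhere
surf.set_at((1, 1), (0, 0, 0, 0))    # one transparent pixel

# Same per-pixel alpha test as getHitmask.
mask = [[bool(surf.get_at((x, y))[3]) for y in range(surf.get_height())]
        for x in range(surf.get_width())]
print(mask)   # [[True, True], [True, False]]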
/ch_1/sec_6/flappy_bird/game/keyboard_agent.py:
--------------------------------------------------------------------------------
1 | import pygame
2 | import sys
3 | from wrapped_flappy_bird import GameState
4 |
5 | game_state = GameState(sound=True)
6 |
7 | ACTIONS = [0, 1]
8 |
9 | while True:
10 | action = ACTIONS[0]
11 | for event in pygame.event.get():
12 | if event.type == pygame.QUIT:
13 | sys.exit()
14 | if event.type == pygame.KEYDOWN:
15 | if event.key == pygame.K_h:
16 | action = ACTIONS[1]
17 | game_state.frame_step(action)
18 | pygame.quit()
19 |
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/game/wrapped_flappy_bird.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sys
3 | import os.path as osp
4 | import random
5 | import pygame
6 | dirname = osp.dirname(__file__)
7 | sys.path.append(dirname)
8 | import flappy_bird_utils
9 | import pygame.surfarray as surfarray
10 | from pygame.locals import *
11 | from itertools import cycle
12 |
13 | FPS = 30
14 | SCREENWIDTH = 288
15 | SCREENHEIGHT = 512
16 |
17 | pygame.init()
18 | FPSCLOCK = pygame.time.Clock()
19 | SCREEN = pygame.display.set_mode((SCREENWIDTH, SCREENHEIGHT))
20 | pygame.display.set_caption('Flappy Bird')
21 |
22 | IMAGES, SOUNDS, HITMASKS = flappy_bird_utils.load()
23 | PIPEGAPSIZE = 100 # gap between upper and lower part of pipe
24 | BASEY = SCREENHEIGHT * 0.79
25 |
26 | PLAYER_WIDTH = IMAGES['player'][0].get_width()
27 | PLAYER_HEIGHT = IMAGES['player'][0].get_height()
28 | PIPE_WIDTH = IMAGES['pipe'][0].get_width()
29 | PIPE_HEIGHT = IMAGES['pipe'][0].get_height()
30 | BACKGROUND_WIDTH = IMAGES['background'].get_width()
31 |
32 | PLAYER_INDEX_GEN = cycle([0, 1, 2, 1])
33 |
34 |
35 | class GameState:
36 | def __init__(self, sound=False, fps=FPS):
37 | self.score = self.playerIndex = self.loopIter = 0
38 | self.playerx = int(SCREENWIDTH * 0.2)
39 | self.playery = int((SCREENHEIGHT - PLAYER_HEIGHT) / 2)
40 | self.basex = 0
41 | self.baseShift = IMAGES['base'].get_width() - BACKGROUND_WIDTH
42 | self.sound = sound
43 | self.fps = fps
44 |
45 | newPipe1 = getRandomPipe()
46 | newPipe2 = getRandomPipe()
47 | self.upperPipes = [
48 | {'x': SCREENWIDTH, 'y': newPipe1[0]['y']},
49 | {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[0]['y']},
50 | ]
51 | self.lowerPipes = [
52 | {'x': SCREENWIDTH, 'y': newPipe1[1]['y']},
53 | {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[1]['y']},
54 | ]
55 |
56 | # player velocity, max velocity, downward acceleration, acceleration on flap
57 | self.pipeVelX = -4
58 | self.playerVelY = 0 # player's velocity along Y, default same as playerFlapped
59 | self.playerMaxVelY = 10 # max vel along Y, max descend speed
60 | self.playerMinVelY = -8 # min vel along Y, max ascend speed
61 | self.playerAccY = 1 # player's downward acceleration
62 | self.playerFlapAcc = -9 # player's speed on flapping
63 | self.playerFlapped = False # True when player flaps
64 |
65 | def frame_step(self, action):
66 | pygame.event.pump()
67 |
68 | reward = 0.1
69 | terminal = False
70 |
71 | # if sum(input_actions) != 1:
72 | # raise ValueError('Multiple input actions!')
73 |
74 | # input_actions[0] == 1: do nothing
75 | # input_actions[1] == 1: flap the bird
76 | if action == 1:
77 | if self.playery > -2 * PLAYER_HEIGHT:
78 | self.playerVelY = self.playerFlapAcc
79 | self.playerFlapped = True
80 | if self.sound:
81 | SOUNDS['wing'].play()
82 |
83 | # check for score
84 | playerMidPos = self.playerx + PLAYER_WIDTH / 2
85 | for pipe in self.upperPipes:
86 | pipeMidPos = pipe['x'] + PIPE_WIDTH / 2
87 | if pipeMidPos <= playerMidPos < pipeMidPos + 4:
88 | self.score += 1
89 | if self.sound:
90 | SOUNDS['point'].play()
91 | reward = 1
92 |
93 | # playerIndex basex change
94 | if (self.loopIter + 1) % 3 == 0:
95 | self.playerIndex = next(PLAYER_INDEX_GEN)
96 | self.loopIter = (self.loopIter + 1) % 30
97 | self.basex = -((-self.basex + 100) % self.baseShift)
98 |
99 | # player's movement
100 | if self.playerVelY < self.playerMaxVelY and not self.playerFlapped:
101 | self.playerVelY += self.playerAccY
102 | if self.playerFlapped:
103 | self.playerFlapped = False
104 | self.playery += min(self.playerVelY, BASEY - self.playery - PLAYER_HEIGHT)
105 | if self.playery < 0:
106 | self.playery = 0
107 |
108 | # move pipes to left
109 | for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes):
110 | uPipe['x'] += self.pipeVelX
111 | lPipe['x'] += self.pipeVelX
112 |
113 | # add new pipe when first pipe is about to touch left of screen
114 | if 0 < self.upperPipes[0]['x'] < 5:
115 | newPipe = getRandomPipe()
116 | self.upperPipes.append(newPipe[0])
117 | self.lowerPipes.append(newPipe[1])
118 |
119 | # remove the first pipe once it is off the screen
120 | if self.upperPipes[0]['x'] < -PIPE_WIDTH:
121 | self.upperPipes.pop(0)
122 | self.lowerPipes.pop(0)
123 |
124 | # check if crash here
125 | isCrash = checkCrash({'x': self.playerx, 'y': self.playery,
126 | 'index': self.playerIndex},
127 | self.upperPipes, self.lowerPipes)
128 | if isCrash:
129 | if self.sound:
130 | SOUNDS['hit'].play()
131 | SOUNDS['die'].play()
132 | terminal = True
133 | self.__init__(sound=self.sound)
134 | reward = -1
135 |
136 | # draw sprites
137 | SCREEN.blit(IMAGES['background'], (0,0))
138 |
139 | for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes):
140 | SCREEN.blit(IMAGES['pipe'][0], (uPipe['x'], uPipe['y']))
141 | SCREEN.blit(IMAGES['pipe'][1], (lPipe['x'], lPipe['y']))
142 |
143 | SCREEN.blit(IMAGES['base'], (self.basex, BASEY))
144 | # print score so player overlaps the score
145 | showScore(self.score)
146 | SCREEN.blit(IMAGES['player'][self.playerIndex],
147 | (self.playerx, self.playery))
148 |
149 | image_data = pygame.surfarray.array3d(pygame.display.get_surface())
150 | pygame.display.update()
151 | FPSCLOCK.tick(self.fps)
152 | # print(self.upperPipes[0]['y'] + PIPE_HEIGHT - int(BASEY * 0.2))
153 | return image_data, reward, terminal
154 |
155 | def getRandomPipe():
156 | """returns a randomly generated pipe"""
157 | # y of gap between upper and lower pipe
158 | gapYs = [20, 30, 40, 50, 60, 70, 80, 90]
159 | index = random.randint(0, len(gapYs)-1)
160 | gapY = gapYs[index]
161 |
162 | gapY += int(BASEY * 0.2)
163 | pipeX = SCREENWIDTH + 10
164 |
165 | return [
166 | {'x': pipeX, 'y': gapY - PIPE_HEIGHT}, # upper pipe
167 | {'x': pipeX, 'y': gapY + PIPEGAPSIZE}, # lower pipe
168 | ]
169 |
170 |
171 | def showScore(score):
172 | """displays score in center of screen"""
173 | scoreDigits = [int(x) for x in list(str(score))]
174 | totalWidth = 0 # total width of all numbers to be printed
175 |
176 | for digit in scoreDigits:
177 | totalWidth += IMAGES['numbers'][digit].get_width()
178 |
179 | Xoffset = (SCREENWIDTH - totalWidth) / 2
180 |
181 | for digit in scoreDigits:
182 | SCREEN.blit(IMAGES['numbers'][digit], (Xoffset, SCREENHEIGHT * 0.1))
183 | Xoffset += IMAGES['numbers'][digit].get_width()
184 |
185 |
186 | def checkCrash(player, upperPipes, lowerPipes):
187 | """returns True if player collders with base or pipes."""
188 | pi = player['index']
189 | player['w'] = IMAGES['player'][0].get_width()
190 | player['h'] = IMAGES['player'][0].get_height()
191 |
192 | # if player crashes into ground
193 | if player['y'] + player['h'] >= BASEY - 1:
194 | return True
195 | else:
196 |
197 | playerRect = pygame.Rect(player['x'], player['y'],
198 | player['w'], player['h'])
199 |
200 | for uPipe, lPipe in zip(upperPipes, lowerPipes):
201 | # upper and lower pipe rects
202 | uPipeRect = pygame.Rect(uPipe['x'], uPipe['y'], PIPE_WIDTH, PIPE_HEIGHT)
203 | lPipeRect = pygame.Rect(lPipe['x'], lPipe['y'], PIPE_WIDTH, PIPE_HEIGHT)
204 |
205 | # player and upper/lower pipe hitmasks
206 | pHitMask = HITMASKS['player'][pi]
207 | uHitmask = HITMASKS['pipe'][0]
208 | lHitmask = HITMASKS['pipe'][1]
209 |
210 | # if bird collided with upipe or lpipe
211 | uCollide = pixelCollision(playerRect, uPipeRect, pHitMask, uHitmask)
212 | lCollide = pixelCollision(playerRect, lPipeRect, pHitMask, lHitmask)
213 |
214 | if uCollide or lCollide:
215 | return True
216 |
217 | return False
218 |
219 | def pixelCollision(rect1, rect2, hitmask1, hitmask2):
220 | """Checks if two objects collide and not just their rects"""
221 | rect = rect1.clip(rect2)
222 |
223 | if rect.width == 0 or rect.height == 0:
224 | return False
225 |
226 | x1, y1 = rect.x - rect1.x, rect.y - rect1.y
227 | x2, y2 = rect.x - rect2.x, rect.y - rect2.y
228 |
229 | for x in range(rect.width):
230 | for y in range(rect.height):
231 | if hitmask1[x1+x][y1+y] and hitmask2[x2+x][y2+y]:
232 | return True
233 | return False
234 |
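frame_step() returns the raw RGB frame, a shaped reward (+0.1 per surviving frame, +1 per pipe passed, -1 on a crash) and a terminal flag, and GameState re-initialises itself after a crash, so an agent only has to feed integer actions back in. A minimal random-agent sketch (the module name and step count are illustrative assumptions):

import random
import wrapped_flappy_bird as fb

game = fb.GameState(sound=False)
for step in range(1000):
    action = random.choice([0, 1])                 # 0 = do nothing, 1 = flap
    frame, reward, terminal = game.frame_step(action)
    # frame is a (288, 512, 3) uint8 array; DQN pipelines typically
    # grayscale, crop and resize it before stacking consecutive frames
    if terminal:
        pass                                       # the game has already reset itself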
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/images/flappy_bird_demp.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/images/flappy_bird_demp.gif
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/images/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/images/network.png
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/images/preprocess.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/images/preprocess.png
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/logs_bird/hidden.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/logs_bird/hidden.txt
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/logs_bird/readout.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/logs_bird/readout.txt
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/saved_networks0/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/saved_networks0/saved_model.pb
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/saved_networks0/variables/variables.data-00000-of-00002:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/saved_networks0/variables/variables.data-00000-of-00002
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/saved_networks0/variables/variables.data-00001-of-00002:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/saved_networks0/variables/variables.data-00001-of-00002
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/saved_networks0/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/saved_networks0/variables/variables.index
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/saved_networks1/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/saved_networks1/saved_model.pb
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/saved_networks1/variables/variables.data-00000-of-00002:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/saved_networks1/variables/variables.data-00000-of-00002
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/saved_networks1/variables/variables.data-00001-of-00002:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/saved_networks1/variables/variables.data-00001-of-00002
--------------------------------------------------------------------------------
/ch_1/sec_6/flappy_bird/saved_networks1/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/saved_networks1/variables/variables.index
--------------------------------------------------------------------------------
/ch_1/sec_6/lfa_rl.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import os
4 | import pygame
5 | import time
6 | import matplotlib.pyplot as plt
7 | from yuan_yang_env_fa import YuanYangEnv
8 |
9 |
10 | class LFA_RL:
11 | def __init__(self, yuanyang):
12 | self.gamma = yuanyang.gamma
13 | self.yuanyang = yuanyang
14 | # parameters for the tabular (one-hot) feature representation
15 | self.theta_tr = np.zeros((400, 1)) * 0.1
16 | # parameters for the fixed sparse representation
17 | self.theta_fsr = np.zeros((80, 1)) * 0.1
18 |
19 | def feature_tr(self, s, a):
20 | phi_s_a = np.zeros((1, 400))
21 | phi_s_a[0, 100 * a + s] = 1
22 | return phi_s_a
23 |
24 | # greedy policy
25 | def greedy_policy_tr(self, state):
26 | qfun = np.array([0, 0, 0, 0]) * 0.1
27 | # compute the action-value function
28 | for i in range(4):
29 | qfun[i] = np.dot(self.feature_tr(state, i), self.theta_tr)
30 | amax = qfun.argmax()
31 | return self.yuanyang.actions[amax]
32 |
33 | # epsilon-greedy policy
34 | def epsilon_greedy_policy_tr(self, state, epsilon):
35 | qfun = np.array([0, 0, 0, 0]) * 0.1
36 | # compute the action-value function
37 | for i in range(4):
38 | qfun[i] = np.dot(self.feature_tr(state, i), self.theta_tr)
39 | amax = qfun.argmax()
40 | # exploration branch
41 | if np.random.uniform() < 1 - epsilon:
42 | # greedy action
43 | return self.yuanyang.actions[amax]
44 | else:
45 | return self.yuanyang.actions[
46 | int(random.random() * len(self.yuanyang.actions))
47 | ]
48 |
49 | # find the index of a given action
50 | def find_anum(self, a):
51 | for i in range(len(self.yuanyang.actions)):
52 | if a == self.yuanyang.actions[i]:
53 | return i
54 |
55 | def qlearning_lfa_tr(self, num_iter, alpha, epsilon):
56 | iter_num = []
57 | self.theta_tr = np.zeros((400, 1)) * 0.1
58 | # outer loop over episodes
59 | for iter in range(num_iter):
60 | # decay the exploration rate every episode
61 | epsilon = epsilon * 0.99
62 | s_sample = []
63 | # initial state
64 | s = 0
65 | flag = self.greedy_test_tr()
66 | if flag == 1:
67 | iter_num.append(iter)
68 | if len(iter_num) < 2:
69 | print('q-learning-tr: iterations needed to first complete the task: {}'.format(iter_num[0]))
70 | if flag == 2:
71 | print('q-learning-tr: iterations needed to first find the shortest path: {}'.format(iter))
72 | break
73 | # choose the initial action with the epsilon-greedy policy
74 | a = self.epsilon_greedy_policy_tr(s, epsilon)
75 | t = False
76 | count = 0
77 |
78 | # inner loop: one episode, e.g. s0-s1-s2-s1-s2-s_terminal
79 | while t == False and count < 30:
80 | # interact with the environment to get the next state
81 | s_next, r, t = self.yuanyang.transform(s, a)
82 | a_num = self.find_anum(a)
83 | # state already visited in this trajectory: give a negative reward
84 | if s_next in s_sample:
85 | r = -2
86 | s_sample.append(s)
87 | # check for the terminal state
88 | if t == True:
89 | q_target = r
90 | else:
91 | # greedy (max) action at the next state; the max makes Q-learning off-policy
92 | a1 = self.greedy_policy_tr(s_next)
93 | a1_num = self.find_anum(a1)
94 | # Q-learning TD target
95 | q_target = r + self.gamma * np.dot(self.feature_tr(s_next, a1_num), self.theta_tr)
96 | # update the parameters by gradient descent
97 | self.theta_tr = self.theta_tr + alpha * (q_target - np.dot(self.feature_tr(s, a_num), self.theta_tr))[0, 0] * np.transpose(self.feature_tr(s, a_num))
98 | # move to the next state
99 | s = s_next
100 | # behaviour policy
101 | a = self.epsilon_greedy_policy_tr(s, epsilon)
102 | count += 1
103 |
104 | return self.theta_tr
105 |
106 | def greedy_test_tr(self):
107 | s = 0
108 | s_sample = []
109 | done = False
110 | flag = 0
111 | step_num = 0
112 | while done == False and step_num < 30:
113 | a = self.greedy_policy_tr(s)
114 | # interact with the environment
115 | s_next, r, done = self.yuanyang.transform(s, a)
116 | s_sample.append(s)
117 | s = s_next
118 | step_num += 1
119 |
120 | if s == 9:
121 | flag = 1
122 | if s == 9 and step_num < 21:
123 | flag = 2
124 |
125 | return flag
126 |
127 | def feature_fsr(self, s, a):
128 | phi_s_a = np.zeros((1, 80))
129 | y = int(s / 10)
130 | x = s - 10 * y
131 | phi_s_a[0, 20 * a + x] = 1
132 | phi_s_a[0, 20 * a + 10 + y] = 1
133 | return phi_s_a
134 |
135 | def greedy_policy_fsr(self, state):
136 | qfun = np.array([0, 0, 0, 0]) * 0.1
137 | # compute the action-value function
138 | for i in range(4):
139 | qfun[i] = np.dot(self.feature_fsr(state, i), self.theta_fsr)
140 | amax = qfun.argmax()
141 | return self.yuanyang.actions[amax]
142 |
143 | # epsilon-greedy policy
144 | def epsilon_greedy_policy_fsr(self, state, epsilon):
145 | qfun = np.array([0, 0, 0, 0]) * 0.1
146 | # compute the action-value function
147 | for i in range(4):
148 | qfun[i] = np.dot(self.feature_fsr(state, i), self.theta_fsr)
149 | amax = qfun.argmax()
150 | # exploration branch
151 | if np.random.uniform() < 1 - epsilon:
152 | # greedy action
153 | return self.yuanyang.actions[amax]
154 | else:
155 | return self.yuanyang.actions[
156 | int(random.random() * len(self.yuanyang.actions))
157 | ]
158 |
159 | def greedy_test_fsr(self):
160 | s = 0
161 | s_sample = []
162 | done = False
163 | flag = 0
164 | step_num = 0
165 | while done == False and step_num < 30:
166 | a = self.greedy_policy_fsr(s)
167 | # interact with the environment
168 | s_next, r, done = self.yuanyang.transform(s, a)
169 | s_sample.append(s)
170 | s = s_next
171 | step_num += 1
172 |
173 | if s == 9:
174 | flag = 1
175 | if s == 9 and step_num < 21:
176 | flag = 2
177 |
178 | return flag
179 |
180 | def qlearning_lfa_fsr(self, num_iter, alpha, epsilon):
181 | iter_num = []
182 | self.theta_fsr = np.zeros((80, 1)) * 0.1
183 | # outer loop over episodes
184 | for iter in range(num_iter):
185 | # decay the exploration rate every episode
186 | epsilon = epsilon * 0.99
187 | s_sample = []
188 | # initial state
189 | s = 0
190 | flag = self.greedy_test_fsr()
191 | if flag == 1:
192 | iter_num.append(iter)
193 | if len(iter_num) < 2:
194 | print('q-learning-fsr: iterations needed to first complete the task: {}'.format(iter_num[0]))
195 | if flag == 2:
196 | print('q-learning-fsr: iterations needed to first find the shortest path: {}'.format(iter))
197 | break
198 | # choose the initial action with the epsilon-greedy policy
199 | a = self.epsilon_greedy_policy_fsr(s, epsilon)
200 | t = False
201 | count = 0
202 |
203 | # inner loop: one episode, e.g. s0-s1-s2-s1-s2-s_terminal
204 | while t == False and count < 30:
205 | # interact with the environment to get the next state
206 | s_next, r, t = self.yuanyang.transform(s, a)
207 | a_num = self.find_anum(a)
208 | # state already visited in this trajectory: give a negative reward
209 | if s_next in s_sample:
210 | r = -2
211 | s_sample.append(s)
212 | # check for the terminal state
213 | if t == True:
214 | q_target = r
215 | else:
216 | # greedy (max) action at the next state; the max makes Q-learning off-policy
217 | a1 = self.greedy_policy_fsr(s_next)
218 | a1_num = self.find_anum(a1)
219 | # Q-learning TD target
220 | q_target = r + self.gamma * np.dot(self.feature_fsr(s_next, a1_num), self.theta_fsr)
221 | # update the parameters by gradient descent
222 | self.theta_fsr = self.theta_fsr + alpha * (q_target - np.dot(self.feature_fsr(s, a_num), self.theta_fsr))[0, 0] * np.transpose(self.feature_fsr(s, a_num))
223 | # move to the next state
224 | s = s_next
225 | # behaviour policy
226 | a = self.epsilon_greedy_policy_fsr(s, epsilon)
227 | count += 1
228 |
229 | return self.theta_fsr
230 |
231 |
232 | def qlearning_lfa_tr():
233 | yuanyang = YuanYangEnv()
234 | brain = LFA_RL(yuanyang)
235 | brain.qlearning_lfa_tr(num_iter=5000, alpha=0.1, epsilon=0.8)
236 |
237 | # print the result
238 | flag = 1
239 | s = 0
240 | path = []
241 | # print the Q values
242 | qvalue1 = np.zeros((100, 4))
243 | for i in range(400):
244 | y = int(i / 100)
245 | x = i - 100 * y
246 | qvalue1[x, y] = np.dot(brain.feature_tr(x, y), brain.theta_tr)
247 | yuanyang.action_value = qvalue1
248 | step_num = 0
249 |
250 | # render the greedy path
251 | while flag:
252 | path.append(s)
253 | yuanyang.path = path
254 | a = brain.greedy_policy_tr(s)
255 | print('%d->%s\t'%(s, a), qvalue1[s, 0], qvalue1[s, 1], qvalue1[s, 2], qvalue1[s, 3])
256 | yuanyang.bird_male_position = yuanyang.state_to_position(s)
257 | yuanyang.render()
258 | time.sleep(0.1)
259 | step_num += 1
260 | s_, r, t = yuanyang.transform(s, a)
261 | if t == True or step_num > 200:
262 | flag = 0
263 | s = s_
264 |
265 | # render the final path cell
266 | yuanyang.bird_male_position = yuanyang.state_to_position(s)
267 | path.append(s)
268 | yuanyang.render()
269 | while True:
270 | yuanyang.render()
271 |
272 | def qlearning_lfa_fsr():
273 | yuanyang = YuanYangEnv()
274 | brain = LFA_RL(yuanyang)
275 | brain.qlearning_lfa_fsr(num_iter=5000, alpha=0.1, epsilon=0.1)
276 |
277 | # print the result
278 | flag = 1
279 | s = 0
280 | path = []
281 | # print the Q values
287 | qvalue2 = np.zeros((100, 4))
288 | for i in range(400):
289 | y = int(i / 100)
290 | x = i - 100 * y
291 | qvalue2[x, y] = np.dot(brain.feature_fsr(x, y), brain.theta_fsr)
292 | yuanyang.action_value = qvalue2
293 | step_num = 0
294 |
295 | # render the greedy path
296 | while flag:
297 | path.append(s)
298 | yuanyang.path = path
299 | a = brain.greedy_policy_fsr(s)
300 | print('%d->%s\t'%(s, a), qvalue2[s, 0], qvalue2[s, 1], qvalue2[s, 2], qvalue2[s, 3])
301 | yuanyang.bird_male_position = yuanyang.state_to_position(s)
302 | yuanyang.render()
303 | time.sleep(0.1)
304 | step_num += 1
305 | s_, r, t = yuanyang.transform(s, a)
306 | if t == True or step_num > 200:
307 | flag = 0
308 | s = s_
309 |
310 | # render the final path cell
311 | yuanyang.bird_male_position = yuanyang.state_to_position(s)
312 | path.append(s)
313 | yuanyang.render()
314 | while True:
315 | yuanyang.render()
316 |
317 |
318 | if __name__ == "__main__":
319 | # qlearning_lfa_tr()
320 | qlearning_lfa_fsr()
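Both variants above perform the same semi-gradient Q-learning update and differ only in the features: feature_tr is a 400-dimensional one-hot over the (state, action) pair, while feature_fsr gets by with 80 parameters by encoding the column and row of a state separately for each action. Since Q(s, a) = phi(s, a) . theta, the gradient of the prediction with respect to theta is simply phi(s, a), giving theta <- theta + alpha * (r + gamma * max_a' phi(s', a') . theta - phi(s, a) . theta) * phi(s, a)^T. A standalone sketch of one such update with the tabular features (the transition and hyper-parameters below are illustrative):

import numpy as np

def feature_tr(s, a, n_states=100, n_actions=4):
    # one-hot feature over the (state, action) pair
    phi = np.zeros((1, n_states * n_actions))
    phi[0, n_states * a + s] = 1
    return phi

theta = np.zeros((400, 1))
gamma, alpha = 0.95, 0.1
s, a, r, s_next = 0, 1, 0.0, 10                    # one illustrative transition

q_next = max((feature_tr(s_next, b) @ theta)[0, 0] for b in range(4))
td_target = r + gamma * q_next
td_error = td_target - (feature_tr(s, a) @ theta)[0, 0]
theta = theta + alpha * td_error * feature_tr(s, a).T   # gradient of phi.theta is phi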
--------------------------------------------------------------------------------
/ch_1/sec_6/load_images.py:
--------------------------------------------------------------------------------
1 | import pygame
2 | import os.path as osp
3 |
4 | path_file = osp.abspath(__file__)
5 | path_images = osp.join(path_file, '../../..', 'ch_0/sec_2/images')
6 |
7 | def load_bird_male():
8 | obj = 'bird_male.png'
9 | obj_path = osp.join(path_images, obj)
10 | return pygame.image.load(obj_path)
11 |
12 | def load_bird_female():
13 | obj = 'bird_female.png'
14 | obj_path = osp.join(path_images, obj)
15 | return pygame.image.load(obj_path)
16 |
17 | def load_background():
18 | obj = 'background.jpg'
19 | obj_path = osp.join(path_images, obj)
20 | return pygame.image.load(obj_path)
21 |
22 | def load_obstacle():
23 | obj = 'obstacle.png'
24 | obj_path = osp.join(path_images, obj)
25 | return pygame.image.load(obj_path)
26 |
--------------------------------------------------------------------------------
/ch_1/sec_6/yuan_yang_env_fa.py:
--------------------------------------------------------------------------------
1 | import pygame
2 | import random
3 | from load_images import *
4 | import numpy as np
5 |
6 | class YuanYangEnv:
7 | def __init__(self):
8 | self.states = []
9 | for i in range(0, 100):
10 | self.states.append(i)
11 | self.actions = ['e', 's', 'w', 'n']
12 | # use a relatively large gamma so that long-term returns do not decay too quickly
13 | self.gamma = 0.95
14 | self.action_value = np.zeros((100, 4))
15 |
16 | self.viewer = None
17 | self.FPSCLOCK = pygame.time.Clock()
18 |
19 | self.screen_size = (1200, 900)
20 | self.bird_position = (0, 0)
21 | self.limit_distance_x = 120
22 | self.limit_distance_y = 90
23 | self.obstacle_size = [120, 90]
24 | self.obstacle1_x = []
25 | self.obstacle1_y = []
26 | self.obstacle2_x = []
27 | self.obstacle2_y = []
28 |
29 | for i in range(8):
30 | # obstacle 1
31 | self.obstacle1_x.append(360)
32 | if i <= 3:
33 | self.obstacle1_y.append(90 * i)
34 | else:
35 | self.obstacle1_y.append(90 * (i + 2))
36 | # obstacle 2
37 | self.obstacle2_x.append(720)
38 | if i <= 4:
39 | self.obstacle2_y.append(90 * i)
40 | else:
41 | self.obstacle2_y.append(90 * (i + 2))
42 |
43 | self.bird_male_init_position = [0.0, 0.0]
44 | self.bird_male_position = [0, 0]
45 | self.bird_female_init_position = [1080, 0]
46 |
47 | self.path = []
48 |
49 | def collide(self, state_position):
50 | flag = 1
51 | flag1 = 1
52 | flag2 = 1
53 |
54 | # obstacle 1
55 | dx = []
56 | dy = []
57 | for i in range(8):
58 | dx1 = abs(self.obstacle1_x[i] - state_position[0])
59 | dx.append(dx1)
60 | dy1 = abs(self.obstacle1_y[i] - state_position[1])
61 | dy.append(dy1)
62 | mindx = min(dx)
63 | mindy = min(dy)
64 | if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y:
65 | flag1 = 0
66 |
67 | # obstacle 2
68 | dx_second = []
69 | dy_second = []
70 | for i in range(8):
71 | dx1 = abs(self.obstacle2_x[i] - state_position[0])
72 | dx_second.append(dx1)
73 | dy1 = abs(self.obstacle2_y[i] - state_position[1])
74 | dy_second.append(dy1)
75 | mindx = min(dx_second)
76 | mindy = min(dy_second)
77 | if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y:
78 | flag2 = 0
79 |
80 | if flag1 == 0 and flag2 == 0:
81 | flag = 0
82 |
83 | # collide edge
84 | if state_position[0] > 1080 or \
85 | state_position[0] < 0 or \
86 | state_position[1] > 810 or \
87 | state_position[1] < 0:
88 | flag = 1
89 |
90 | return flag
91 |
92 | def find(self, state_position):
93 | flag = 0
94 | if abs(state_position[0] - self.bird_female_init_position[0]) < \
95 | self.limit_distance_x and \
96 | abs(state_position[1] - self.bird_female_init_position[1]) < \
97 | self.limit_distance_y:
98 | flag = 1
99 | return flag
100 |
101 | def state_to_position(self, state):
102 | i = int(state / 10)
103 | j = state % 10
104 | position = [0, 0]
105 | position[0] = 120 * j
106 | position[1] = 90 * i
107 | return position
108 |
109 | def position_to_state(self, position):
110 | i = position[0] / 120
111 | j = position[1] / 90
112 | return int(i + 10 * j)
113 |
114 | def reset(self):
115 | # sample a random initial position
116 | flag1 = 1
117 | flag2 = 1
118 | while flag1 == 1 or flag2 == 1:
119 | state = self.states[int(random.random() * len(self.states))]
120 | state_position = self.state_to_position(state)
121 | flag1 = self.collide(state_position)
122 | flag2 = self.find(state_position)
123 | return state
124 |
125 | def transform(self, state, action):
126 | current_position = self.state_to_position(state)
127 | next_position = [0, 0]
128 | flag_collide = 0
129 | flag_find = 0
130 |
131 | flag_collide = self.collide(current_position)
132 | flag_find = self.find(current_position)
133 | if flag_collide == 1:
134 | return state, -10, True
135 | if flag_find == 1:
136 | return state, 10, True
137 |
138 | if action == 'e':
139 | next_position[0] = current_position[0] + 120
140 | next_position[1] = current_position[1]
141 | if action == 's':
142 | next_position[0] = current_position[0]
143 | next_position[1] = current_position[1] + 90
144 | if action == 'w':
145 | next_position[0] = current_position[0] - 120
146 | next_position[1] = current_position[1]
147 | if action == 'n':
148 | next_position[0] = current_position[0]
149 | next_position[1] = current_position[1] - 90
150 |
151 | flag_collide = self.collide(next_position)
152 | if flag_collide == 1:
153 | return self.position_to_state(current_position), -10, True
154 |
155 | flag_find = self.find(next_position)
156 | if flag_find == 1:
157 | return self.position_to_state(next_position), 10, True
158 |
159 | return self.position_to_state(next_position), 0, False
160 |
161 | def gameover(self):
162 | for event in pygame.event.get():
163 | if event.type == pygame.QUIT:
164 | exit()
165 |
166 | def render(self):
167 | if self.viewer is None:
168 | pygame.init()
169 |
170 | self.viewer = pygame.display.set_mode(self.screen_size, 0, 32)
171 | pygame.display.set_caption("yuanyang")
172 | # load pic
173 | self.bird_male = load_bird_male()
174 | self.bird_female = load_bird_female()
175 | self.background = load_background()
176 | self.obstacle = load_obstacle()
177 |
178 | # self.viewer.blit(self.bird_female, self.bird_female_init_position)
179 | # self.viewer.blit(self.bird_male, self.bird_male_init_position)
180 |
181 | self.viewer.blit(self.background, (0, 0))
182 | self.font = pygame.font.SysFont('times', 15)
183 |
184 | self.viewer.blit(self.background, (0, 0))
185 | for i in range(11):
186 | pygame.draw.lines(self.viewer,
187 | (255, 255, 255),
188 | True,
189 | ((120 * i, 0), (120 * i, 900)),
190 | 1
191 | )
192 | pygame.draw.lines(self.viewer,
193 | (255, 255, 255),
194 | True,
195 | ((0, 90 * i), (1200, 90 * i)),
196 | 1
197 | )
198 |
199 | for i in range(8):
200 | self.viewer.blit(self.obstacle, (self.obstacle1_x[i], self.obstacle1_y[i]))
201 | self.viewer.blit(self.obstacle, (self.obstacle2_x[i], self.obstacle2_y[i]))
202 |
203 | self.viewer.blit(self.bird_female, self.bird_female_init_position)
204 | self.viewer.blit(self.bird_male, self.bird_male_init_position)
205 |
206 | # draw the action-value function
207 | for i in range(100):
208 | y = int(i / 10)
209 | x = i % 10
210 | # Q value for moving east
211 | surface = self.font.render(str(round(float(self.action_value[i, 0]), 2)), True, (0, 0, 0))
212 | self.viewer.blit(surface, (120 * x + 80, 90 * y + 45))
213 | # Q value for moving south
214 | surface = self.font.render(str(round(float(self.action_value[i, 1]), 2)), True, (0, 0, 0))
215 | self.viewer.blit(surface, (120 * x + 50, 90 * y + 70))
216 | # Q value for moving west
217 | surface = self.font.render(str(round(float(self.action_value[i, 2]), 2)), True, (0, 0, 0))
218 | self.viewer.blit(surface, (120 * x + 10, 90 * y + 45))
219 | # Q value for moving north
220 | surface = self.font.render(str(round(float(self.action_value[i, 3]), 2)), True, (0, 0, 0))
221 | self.viewer.blit(surface, (120 * x + 50, 90 * y + 10))
222 |
223 | # draw the path cells
224 | for i in range(len(self.path)):
225 | rec_position = self.state_to_position(self.path[i])
226 | pygame.draw.rect(self.viewer, [255, 0, 0], [rec_position[0], rec_position[1], 120, 90], 3)
227 | surface = self.font.render(str(i), True, (255, 0, 0))
228 | self.viewer.blit(surface, (rec_position[0] + 5, rec_position[1] + 5))
229 |
230 | pygame.display.update()
231 | self.gameover()
232 | self.FPSCLOCK.tick(30)
233 |
234 |
235 | if __name__ == "__main__":
236 | yy = YuanYangEnv()
237 | yy.render()
238 | while True:
239 | for event in pygame.event.get():
240 | if event.type == pygame.QUIT:
241 | exit()
242 |
243 |
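The environment is a 10 x 10 grid of 120 x 90 pixel cells: state s maps to column s % 10 and row s // 10. A standalone round-trip check of the two mappings above (an illustrative sketch, not part of the environment):

def state_to_position(state):
    i, j = state // 10, state % 10                 # row, column
    return [120 * j, 90 * i]

def position_to_state(position):
    return int(position[0] / 120 + 10 * (position[1] / 90))

assert all(position_to_state(state_to_position(s)) == s for s in range(100))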
--------------------------------------------------------------------------------
/pyTorch_learn/a_1.py:
--------------------------------------------------------------------------------
1 | # a.1.1
2 | import torch
3 | x = torch.Tensor(2, 3)
4 | print(x)
5 |
6 | x = torch.Tensor([[1, 2, 3], [4, 5, 6]])
7 | print(x)
8 |
9 | x = torch.rand(2, 3)
10 | print(x)
11 |
12 | x = torch.zeros(2, 3)
13 | print(x)
14 |
15 | x = torch.ones(2, 3)
16 | print(x)
17 |
18 | print("")
19 | print("x.size: {}".format(x.size()))
20 | print("x.size[0]: {}".format(x.size()[0]))
21 |
22 | # a.1.2
23 | x = torch.ones(2, 3)
24 | y = torch.ones(2, 3) * 2
25 | print("")
26 | print("a.1.2")
27 | print("x + y = {}".format(x + y))
28 |
29 | print(torch.add(x, y))
30 |
31 | x.add_(y)
32 | print(x)
33 | print("Note: in PyTorch, operations that modify a Tensor in place\
34 | have a trailing underscore in their name, e.g. copy_(), t_()")
35 |
36 | print("x.zero_(): {}".format(x.zero_()))
37 | print("x: {}".format(x))
38 |
39 | print("")
40 | print("Tensors also support the slicing operations familiar from NumPy")
41 |
42 | x[:, 1] = x[:, 1] + 2
43 | print(x)
44 |
45 | print("torch.view()相当于numpy中的reshape()")
46 | print("x.view(1, 6): {}".format(x.view(1, 6)))
47 |
48 | # a.1.3
49 | print("")
50 | print("a.1.3")
51 |
52 | print("A Tensor and a NumPy array can be converted to each other, but they share the same memory")
53 | import numpy as np
54 | x = torch.ones(2, 3)
55 | print(x)
56 |
57 | y = x.numpy()
58 | print("y = x.numpy(): {}".format(y))
59 |
60 | print("x.add_(2): {}".format(x.add_(2)))
61 |
62 | print("y: {}".format(y))
63 |
64 | z = torch.from_numpy(y)
65 | print("z = torch.from_numpy(y): {}".format(z))
66 |
67 | # a.1.4
68 | print("")
69 | print("a.1.4")
70 |
71 | print("Autograd 实现自动梯度")
72 |
73 | from torch.autograd import Variable
74 |
75 | x = Variable(torch.ones(2, 2)*2, requires_grad=True)
76 |
77 | print(x)
78 |
79 | print("x.data {}, \n x's type: {}\n".format(x.data, type(x)))
80 |
81 | y = 2 * (x * x) + 5 * x
82 | y = y.sum()
83 | print("y: {}, \n y's type: {}\n".format(y, type(y)))
84 |
85 | print("y 可视为关于 x 的函数")
86 | print("y 应该是一个标量,y.backward()自动计算梯度")
87 |
88 | y.backward()
89 | print("x.grad: {}".format(x.grad))
90 | print("x.grad 中自动保存梯度")
91 |
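The gradient in the last example can be checked by hand: y = sum(2x^2 + 5x), so dy/dx = 4x + 5, which is 13 at x = 2 and matches x.grad. On PyTorch 0.4+ the Variable wrapper is no longer needed; a minimal modern equivalent (a sketch, not part of the original script):

import torch

x = torch.full((2, 2), 2.0, requires_grad=True)
y = (2 * x * x + 5 * x).sum()
y.backward()
print(x.grad)    # every entry equals 4 * 2 + 5 = 13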
--------------------------------------------------------------------------------
/pyTorch_learn/a_2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 |
6 | class Net(nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 | self.conv1 = nn.Conv2d(3, 6, 5)
10 | self.conv2 = nn.Conv2d(6, 16, 5)
11 | self.fc1 = nn.Linear(16 * 5 * 5, 120)
12 | self.fc2 = nn.Linear(120, 84)
13 | self.fc3 = nn.Linear(84, 10)
14 |
15 | def forward(self, x):
16 | '''
17 | __init__() only declares the layers; it does not fix how they are wired.
18 | The input-output relation is defined in forward().
19 | '''
20 | x = F.max_pool2d(F.relu(self.conv1(x)), 2)
21 | x = F.max_pool2d(F.relu(self.conv2(x)), 2)
22 | '''
23 | Note: torch.nn expects a mini-batch as input.
24 | Each image is 3-dimensional, so the input x is 4-dimensional;
25 | x.view(-1, ...) therefore flattens it to 2 dimensions before the fully connected layers.
26 | '''
27 | x = x.view(-1, 16 * 5 * 5)
28 | x = F.relu(self.fc1(x))
29 | x = F.relu(self.fc2(x))
30 | x = self.fc3(x)
31 | return x
32 |
33 |
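# Shape bookkeeping behind the 16 * 5 * 5 flatten in forward():
# CIFAR-10 input 3x32x32 -> conv1 (5x5) -> 6x28x28 -> max_pool/2 -> 6x14x14
#                        -> conv2 (5x5) -> 16x10x10 -> max_pool/2 -> 16x5x5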
34 | def seeNet():
35 | net = Net()
36 | print(net)
37 | params = list(net.parameters())
38 | print(len(params))
39 | print("weights of the first layer:")
40 | print(params[0].size())
41 | print(net.conv1.weight.size())
42 | print("bias of the first layer:")
43 | print(params[1].size())
44 | print(net.conv1.bias.size())
45 | print("")
46 | print(net.conv1.weight.requires_grad)
47 |
48 | '''
49 | The network's inputs and outputs should be Variables
50 | '''
51 | inputV = Variable(torch.rand(1, 3, 32, 32))
52 | # net.__call__()
53 | output = net(inputV)
54 | print(output)
55 |
56 | # seeNet()
57 |
58 | # train
59 | def trainNet():
60 | net = Net()
61 | inputV = Variable(torch.rand(1, 3, 32, 32))
62 | output = net(inputV)
63 |
64 | criterion = nn.CrossEntropyLoss()
65 | label = Variable(torch.LongTensor([4]))
66 | loss = criterion(output, label)
67 | # all data type should be 'Variable'
68 | print(loss)
69 |
70 | import torch.optim as optim
71 | optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
72 | print("\nbefore optim: {}".format(net.conv1.bias))
73 |
74 | optimizer.zero_grad() # zeros the gradient buffers
75 | loss.backward()
76 | optimizer.step() # Does the update
77 | '''
78 | the parameters change, though possibly only very slightly
79 | '''
80 | print("\nafter optim: {}".format(net.conv1.bias))
81 |
82 | # trainNet()
83 |
84 | # hands-on example: CIFAR-10
85 | import torchvision
86 | import torchvision.transforms as transforms
87 | import torch.optim as optim
88 |
89 | transform = transforms.Compose(
90 | [transforms.ToTensor(),
91 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
92 | )
93 | trainset = torchvision.datasets.CIFAR10(
94 | root='./pyTorch_learn/data',
95 | train=True,
96 | download=True,
97 | transform=transform
98 | )
99 | trainloader = torch.utils.data.DataLoader(
100 | trainset,
101 | batch_size=4,
102 | shuffle=True,
103 | num_workers=0 # using 0 worker processes is safest on Windows
104 | )
105 |
106 | testset = torchvision.datasets.CIFAR10(
107 | root='./pyTorch_learn/data',
108 | train=False,
109 | download=True,
110 | transform=transform
111 | )
112 | testloader = torch.utils.data.DataLoader(
113 | testset,
114 | batch_size=4,
115 | shuffle=False,
116 | num_workers=0 # using 0 worker processes is safest on Windows
117 | )
118 |
119 | classes = ('plane', 'car', 'bird', 'cat', 'deer',
120 | 'dog', 'frog', 'horse', 'ship', 'truck'
121 | )
122 |
123 | def cifar_10():
124 | net = Net()
125 | criterion = nn.CrossEntropyLoss()
126 | optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
127 |
128 | for epoch in range(5):
129 |
130 | running_loss = 0.0
131 | for i, data in enumerate(trainloader, 0):
132 | inputs, labels = data
133 | inputs, labels = Variable(inputs), Variable(labels)
134 | optimizer.zero_grad()
135 | outputs = net(inputs)
136 | loss = criterion(outputs, labels)
137 | loss.backward()
138 | optimizer.step()
139 | running_loss += loss.item()
140 | if i % 6000 == 5999:
141 | print('[%d, %5d] loss: %.3f' %
142 | (epoch + 1, i + 1, running_loss / 6000))
143 | running_loss = 0.0
144 |
145 | print('Finished Training')
146 | torch.save(net.state_dict(), './pyTorch_learn/data/' + 'model.pt')
147 | net.load_state_dict(torch.load('./pyTorch_learn/data/' + 'model.pt'))
148 |
149 | correct = 0
150 | total = 0
151 | for data in testloader:
152 | images, labels = data
153 | outputs = net(Variable(images))
154 | # index of the largest score -> predicted label
155 | _, predicted = torch.max(outputs, 1)
156 | total += labels.size(0)
157 | correct += (predicted == labels).sum().item()
158 |
159 | print('Accuracy of the network on the 10000 test images: %d %%' % (
160 | 100 * correct / total
161 | ))
162 |
163 | class_correct = list(0. for i in range(10))
164 | class_total = list(0. for i in range(10))
165 | for data in testloader:
166 | images, labels = data
167 | outputs = net(Variable(images))
168 | _, predicted = torch.max(outputs.data, 1)
169 | c = (predicted == labels).squeeze()
170 | for i in range(4): # mini-batch's size = 4
171 | label = labels[i]
172 | class_correct[label] += c[i].item()
173 | class_total[label] += 1
174 |
175 | for i in range(10):
176 | print('Accuracy of %5s : %2d %%' % (
177 | classes[i], 100 * class_correct[i] / class_total[i]
178 | ))
179 |
180 | # save net
181 | print(net.state_dict().keys())
182 | print(net.state_dict()['conv1.bias'])
183 |
184 | # torch.save(net.state_dict(), 'model.pt')
185 | # net.load_state_dict(torch.load('model.pt'))
186 |
187 | cifar_10()
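One note on the evaluation loops in cifar_10(): on PyTorch 0.4+ the Variable wrapper is a no-op, and inference is usually wrapped in torch.no_grad() so no autograd graph is built. A sketch of the same accuracy computation under those assumptions, reusing the testloader defined above and any trained Net instance:

import torch

def evaluate(net, loader):
    correct = total = 0
    net.eval()
    with torch.no_grad():
        for images, labels in loader:
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100.0 * correct / total

# example: print('accuracy: %.1f %%' % evaluate(net, testloader))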
--------------------------------------------------------------------------------