├── .gitignore ├── AlphaZero ├── current_policy_0.model ├── current_policy_1.model ├── entropy.npy ├── game.py ├── loss.npy ├── main.py ├── mcts_alphaZero.py ├── play.py ├── policy_value_net_pytorch.py └── train.py ├── README.md ├── ch_0 ├── sec_1 │ ├── kb_game.py │ └── main.py └── sec_2 │ ├── images │ ├── background.jpg │ ├── bird_female.png │ ├── bird_male.png │ └── obstacle.png │ ├── load_images.py │ └── yuan_yang_env.py ├── ch_1 ├── sec_3 │ ├── dp_policy_iter.py │ ├── dp_value_iter.py │ ├── load_images.py │ └── yuan_yang_env.py ├── sec_4 │ ├── load_images.py │ ├── mc_rl.py │ └── yuan_yang_env_mc.py ├── sec_5 │ ├── TD_RL.py │ ├── load_images.py │ └── yuan_yang_env_td.py └── sec_6 │ ├── flappy_bird │ ├── assets │ │ ├── audio │ │ │ ├── die.ogg │ │ │ ├── die.wav │ │ │ ├── hit.ogg │ │ │ ├── hit.wav │ │ │ ├── point.ogg │ │ │ ├── point.wav │ │ │ ├── swoosh.ogg │ │ │ ├── swoosh.wav │ │ │ ├── wing.ogg │ │ │ └── wing.wav │ │ └── sprites │ │ │ ├── 0.png │ │ │ ├── 1.png │ │ │ ├── 2.png │ │ │ ├── 3.png │ │ │ ├── 4.png │ │ │ ├── 5.png │ │ │ ├── 6.png │ │ │ ├── 7.png │ │ │ ├── 8.png │ │ │ ├── 9.png │ │ │ ├── background-black.png │ │ │ ├── base.png │ │ │ ├── pipe-green.png │ │ │ ├── redbird-downflap.png │ │ │ ├── redbird-midflap.png │ │ │ └── redbird-upflap.png │ ├── dqn_agent.py │ ├── game │ │ ├── flappy_bird_utils.py │ │ ├── keyboard_agent.py │ │ └── wrapped_flappy_bird.py │ ├── images │ │ ├── flappy_bird_demp.gif │ │ ├── network.png │ │ └── preprocess.png │ ├── logs_bird │ │ ├── hidden.txt │ │ └── readout.txt │ ├── saved_networks0 │ │ ├── saved_model.pb │ │ └── variables │ │ │ ├── variables.data-00000-of-00002 │ │ │ ├── variables.data-00001-of-00002 │ │ │ └── variables.index │ └── saved_networks1 │ │ ├── saved_model.pb │ │ └── variables │ │ ├── variables.data-00000-of-00002 │ │ ├── variables.data-00001-of-00002 │ │ └── variables.index │ ├── lfa_rl.py │ ├── load_images.py │ └── yuan_yang_env_fa.py └── pyTorch_learn ├── a_1.py └── a_2.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | 3 | /pyTorch_learn/data/ 4 | 5 | *.psd -------------------------------------------------------------------------------- /AlphaZero/current_policy_0.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/AlphaZero/current_policy_0.model -------------------------------------------------------------------------------- /AlphaZero/current_policy_1.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/AlphaZero/current_policy_1.model -------------------------------------------------------------------------------- /AlphaZero/entropy.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/AlphaZero/entropy.npy -------------------------------------------------------------------------------- /AlphaZero/game.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Board: 4 | def __init__(self, **kwargs): 5 | self.width = int(kwargs.get('width', 8)) 6 | self.height = int(kwargs.get('height', 8)) 7 | self.n_in_row = int(kwargs.get('n_in_row', 5)) 8 | 
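        # The two sides are identified by the integers 1 and 2; init_board() decides who moves first.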
self.players = (1, 2) 9 | self.states = {} 10 | 11 | def init_board(self, start_player=0): 12 | if self.width < self.n_in_row or self.height < self.n_in_row: 13 | raise Exception('board width and can not be \ 14 | less than {}'.format(self.n_in_row)) 15 | # 当前 player 编号 16 | self.current_player = self.players[start_player] 17 | self.available = list(range(self.width * self.height)) 18 | self.states = {} 19 | self.last_move = -1 20 | 21 | def move_to_location(self, move): 22 | """ 23 | 3*3 board's moves like: 24 | 6 7 8 25 | 3 4 5 26 | 0 1 2 27 | and move 5's location is (1, 2) 28 | """ 29 | h = move // self.width 30 | w = move % self.width 31 | return [h, w] 32 | 33 | def location_to_move(self, location): 34 | if len(location) != 2: 35 | return -1 36 | h = location[0] 37 | w = location[1] 38 | move = h * self.width + w 39 | if move not in range(self.width * self.height): 40 | return -1 41 | return move 42 | 43 | def do_move(self, move): 44 | self.states[move] = self.current_player 45 | self.available.remove(move) 46 | self.current_player = ( 47 | self.players[0] if self.current_player == self.players[1] 48 | else self.players[1] 49 | ) 50 | self.last_move = move 51 | 52 | def get_current_player(self): 53 | return self.current_player 54 | 55 | def current_state(self): 56 | square_state = np.zeros((4, self.width, self.height)) 57 | if self.states: 58 | moves, players = np.array(list(zip(*self.states.items()))) 59 | move_curr = moves[players == self.current_player] 60 | move_oppo = moves[players != self.current_player] 61 | square_state[0][move_curr // self.width, move_curr % self.height] = 1.0 62 | square_state[1][move_oppo // self.width, move_oppo % self.height] = 1.0 63 | # 第 3 个平面,整个棋盘只有一个 1 ,表达上一手下的哪个 64 | square_state[2][self.last_move // self.width, self.last_move % self.height] = 1.0 65 | if len(self.states) % 2 == 0: 66 | # 第 4 个平面,全 1 或者全 0 ,描述现在是哪一方 67 | square_state[3][:, :] = 1.0 68 | # 为什么这里要反转? 
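        # A plausible answer to the question above: graphic() below prints the board
        # bottom-up (row 0 ends up on the last printed line), so reversing the row axis
        # keeps the network's input planes aligned with that on-screen orientation.
        # Since the flip is applied to every state in exactly the same way, the encoding
        # stays self-consistent for training either way -- this is an interpretation,
        # not something documented in the original code.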
69 | return square_state[:, ::-1, :] 70 | 71 | def has_a_winner(self): 72 | width = self.width 73 | height = self.height 74 | states = self.states 75 | n = self.n_in_row 76 | 77 | # 是否满足双方至少下了 5 子(粗略地) 78 | moved = list(set(range(width * height)) - set(self.available)) 79 | if len(moved) < self.n_in_row + 2: 80 | return False, -1 81 | 82 | for m in moved: 83 | h = m // width 84 | w = m % width 85 | player = states[m] 86 | 87 | if w in range(width - n + 1) \ 88 | and len(set(states.get(i, -1) for i in range(m, m+n))) == 1: 89 | return True, player 90 | 91 | if h in range(height - n + 1) \ 92 | and len(set(states.get(i, -1) for i in range(m, m+n*width, width))) == 1: 93 | return True, player 94 | 95 | if w in range(width - n + 1) and h in range(height - n + 1) \ 96 | and len(set(states.get(i, -1) for i in range(m, m+n*(width+1), width+1))) == 1: 97 | return True, player 98 | 99 | if w in range(n - 1, width) and h in range(height - n + 1) \ 100 | and len(set(states.get(i, -1) for i in range(m, m+n*(width-1), width-1))) == 1: 101 | return True, player 102 | 103 | return False, -1 104 | 105 | def game_end(self): 106 | win, winner = self.has_a_winner() 107 | if win: 108 | return True, winner 109 | elif not len(self.available): 110 | return True, -1 111 | return False, -1 112 | 113 | class Game: 114 | def __init__(self, board: Board): 115 | self.board = board 116 | 117 | def start_self_play(self, player, is_shown=False, temp=1e-3): 118 | self.board.init_board() 119 | p1, p2 = self.board.players 120 | states, mcts_probs, current_players = [], [], [] 121 | while True: 122 | move, move_probs = player.get_action(self.board, temp=temp, return_prob=1) 123 | # 保存 self-play 数据 124 | states.append(self.board.current_state()) 125 | mcts_probs.append(move_probs) 126 | current_players.append(self.board.current_player) 127 | # 执行一步落子 128 | self.board.do_move(move) 129 | if is_shown: 130 | self.graphic(self.board, p1, p2) 131 | end, winner = self.board.game_end() 132 | if end: 133 | # 从每一个 state 对应的 player 的视角保存胜负信息 134 | winner_z = np.zeros(len(current_players)) 135 | if winner != -1: 136 | winner_z[np.array(current_players) == winner] = 1.0 137 | winner_z[np.array(current_players) != winner] = - 1.0 138 | player.reset_player() 139 | if is_shown: 140 | if winner != -1: 141 | print("Game end. Winner is player: ", winner) 142 | else: 143 | print("Game end. Tie") 144 | return winner, zip(states, mcts_probs, winner_z) 145 | 146 | def start_play(self, player1, player2, start_player=0, is_shown=1): 147 | if start_player not in (0, 1): 148 | raise Exception('start_player should be either 0 or 1') 149 | self.board.init_board(start_player) 150 | p1, p2 = self.board.players 151 | player1.set_player_ind(p1) 152 | player2.set_player_ind(p2) 153 | players = {p1: player1, p2: player2} 154 | if is_shown: 155 | self.graphic(self.board, player1.player, player2.player) 156 | while True: 157 | current_player = self.board.get_current_player() 158 | player_in_turn = players[current_player] 159 | move = player_in_turn.get_action(self.board) 160 | self.board.do_move(move) 161 | if is_shown: 162 | self.graphic(self.board, player1.player, player2.player) 163 | end, winner = self.board.game_end() 164 | if end: 165 | if is_shown: 166 | if winner != -1: 167 | print("Game end. Winner is ", players[winner]) 168 | else: 169 | print("Game end. 
Tie") 170 | return winner 171 | 172 | def graphic(self, board, player1, player2): 173 | width = board.width 174 | height = board.height 175 | 176 | print("Player", player1, "with X".rjust(3)) 177 | print("Player", player2, "with O".rjust(3)) 178 | print() 179 | for x in range(width): 180 | print("{0:8}".format(x), end='') 181 | print('\r\n') 182 | for i in range(height - 1, -1, -1): 183 | print("{0:4d}".format(i), end='') 184 | for j in range(width): 185 | loc = i * width + j 186 | p = board.states.get(loc, -1) 187 | if p == player1: 188 | print('x'.center(8), end='') 189 | elif p == player2: 190 | print('O'.center(8), end='') 191 | else: 192 | print('_'.center(8), end='') 193 | print('\r\n\r\n') 194 | -------------------------------------------------------------------------------- /AlphaZero/loss.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/AlphaZero/loss.npy -------------------------------------------------------------------------------- /AlphaZero/main.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | DIRNAME = osp.dirname(__file__) 4 | sys.path.append(DIRNAME + '/..') 5 | 6 | from AlphaZero.train import TrainPipeline 7 | 8 | if __name__ == "__main__": 9 | training_pipeline = TrainPipeline(DIRNAME + '/current_policy.model') 10 | training_pipeline.run() 11 | 12 | -------------------------------------------------------------------------------- /AlphaZero/mcts_alphaZero.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | 4 | def Softmax(x): 5 | probs = np.exp(x - np.max(x)) 6 | probs /= np.sum(probs) 7 | return probs 8 | 9 | class TreeNode: 10 | """ A node in the MCTS tree. """ 11 | def __init__(self, parent, prior_p): 12 | self._parent = parent 13 | self._children = {} 14 | self._n_visits = 0 15 | self._Q = 0 16 | self._u = 0 17 | self._P = prior_p 18 | 19 | def select(self, c_puct): 20 | """ Return: A tuple of (action, next_node) """ 21 | return max( 22 | self._children.items(), 23 | key=lambda act_node: act_node[1].get_value(c_puct) 24 | ) 25 | 26 | def get_value(self, c_puct): 27 | self._u = (c_puct * self._P * np.sqrt(self._parent._n_visits) / (1 + self._n_visits)) 28 | return self._Q + self._u 29 | 30 | def expand(self, action_priors): 31 | for action, prob in action_priors: 32 | if action not in self._children: 33 | self._children[action] = TreeNode(self, prob) 34 | 35 | def update(self, leaf_value): 36 | self._n_visits += 1 37 | self._Q += 1.0 * (leaf_value - self._Q) / self._n_visits 38 | 39 | def update_recursive(self, leaf_value): 40 | if self._parent: 41 | self._parent.update_recursive(-leaf_value) 42 | self.update(leaf_value) 43 | 44 | def is_leaf(self): 45 | return self._children == {} 46 | 47 | def is_root(self): 48 | return self._parent is None 49 | 50 | 51 | class MCTS(object): 52 | """ An implementation of Monte Carlo Tree Search. 
""" 53 | def __init__(self, policy_value_fn, c_puct=5, n_playout=10000): 54 | self._root = TreeNode(None, 1.0) 55 | self._policy = policy_value_fn 56 | self._c_puct = c_puct 57 | self._n_playout = n_playout 58 | 59 | def _playout(self, state): 60 | """ 完整的执行选择、扩展评估和回传更新等步骤 """ 61 | node = self._root 62 | # 选择 63 | while True: 64 | if node.is_leaf(): 65 | break 66 | action, node = node.select(self._c_puct) 67 | state.do_move(action) 68 | # 扩展及评估 69 | action_probs, leaf_value = self._policy(state) 70 | end, winner = state.game_end() 71 | if not end: 72 | node.expand(action_probs) 73 | else: 74 | if winner == -1: # 平局 75 | leaf_value = 0.0 76 | else: 77 | leaf_value = ( 78 | 1.0 if winner == state.get_current_player() else -1.0 79 | ) 80 | # 回传更新 81 | node.update_recursive(-leaf_value) 82 | 83 | def get_move_probs(self, state, temp=1e-3): 84 | for n in range(self._n_playout): 85 | state_copy = copy.deepcopy(state) 86 | self._playout(state_copy) 87 | 88 | act_visits = [ 89 | (act, node._n_visits) for act, node in self._root._children.items() 90 | ] 91 | acts, visits = zip(*act_visits) 92 | # 注意,这里是根据 visits 算出的动作选择概率 93 | act_probs = Softmax(1.0 / temp * np.log(np.array(visits) + 1e-10)) 94 | 95 | return acts, act_probs 96 | 97 | def update_with_move(self, last_move): 98 | if last_move in self._root._children: 99 | self._root = self._root._children[last_move] 100 | self._root._parent = None 101 | else: 102 | self._root = TreeNode(None, 1.0) 103 | 104 | class MCTSPlayer: 105 | """ AI player based on MCTS """ 106 | def __init__(self, policy_value_function, c_puct=5, n_playout=2000, is_selfplay=0): 107 | self.mcts = MCTS(policy_value_function, c_puct, n_playout) 108 | self._is_selfplay = is_selfplay 109 | 110 | def get_action(self, board, temp=1e-3, return_prob=0): 111 | sensible_moves = board.available 112 | move_probs = np.zeros(board.width * board.height) 113 | if len(sensible_moves) > 0: 114 | acts, probs = self.mcts.get_move_probs(board, temp) 115 | move_probs[list(acts)] = probs 116 | if self._is_selfplay: 117 | move = np.random.choice( 118 | acts, p = 0.75 * probs + 0.25 * np.random.dirichlet(0.3 * np.ones(len(probs))) 119 | ) 120 | # 更新根节点,复用搜索子树 121 | self.mcts.update_with_move(move) 122 | else: 123 | move = np.random.choice(acts, p=probs) 124 | # 重置根节点 125 | self.mcts.update_with_move(-1) 126 | location = board.move_to_location(move) 127 | print("AI move: %d, %d\n".format(location[0], location[1])) 128 | if return_prob: 129 | return move, move_probs 130 | else: 131 | return move 132 | else: 133 | print("WARNING: the board is full") 134 | 135 | def set_player_ind(self, p): 136 | self.player = p 137 | 138 | def reset_player(self): 139 | self.mcts.update_with_move(-1) 140 | 141 | def __str__(self): 142 | return "MCTS {}".format(self.player) 143 | -------------------------------------------------------------------------------- /AlphaZero/play.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | DIRNAME = osp.dirname(__file__) 4 | sys.path.append(DIRNAME + '/..') 5 | 6 | from AlphaZero.game import Board, Game 7 | from AlphaZero.mcts_alphaZero import MCTSPlayer 8 | from AlphaZero.policy_value_net_pytorch import PolicyValueNet 9 | 10 | import argparse 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('-t', '--trained', action='store_true') 13 | args = parser.parse_args() 14 | 15 | """ 16 | input location as '3,3' to play 17 | """ 18 | 19 | class Human: 20 | """ human player """ 21 | def __init__(self): 
22 | self.player = None 23 | 24 | def set_player_ind(self, p): 25 | self.player = p 26 | 27 | def get_action(self, board): 28 | try: 29 | location = input("Your move: ") 30 | if isinstance(location, str): # for Python3 31 | location = [int(n, 10) for n in location.split(",")] 32 | move = board.location_to_move(location) 33 | except Exception as e: 34 | move = -1 35 | if move == -1 or move not in board.available: 36 | print("invalid move") 37 | move = self.get_action(board) 38 | return move 39 | 40 | def __str__(self): 41 | return "Human {}".format(self.player) 42 | 43 | def run(): 44 | n = 5 45 | width, height = 8, 8 46 | model_file = DIRNAME + '/current_policy' + ('_1' if args.trained else '_0') + '.model' 47 | try: 48 | board = Board(width=width, height=height, n_in_row=n) 49 | game = Game(board) 50 | # 创建 AI player 51 | best_policy = PolicyValueNet(width, height, model_file=model_file) 52 | mcts_player = MCTSPlayer(best_policy.policy_value_fn, c_puct=5, n_playout=400) 53 | # 创建 Human player ,输入样例: 2,3 54 | human = Human() 55 | # 设置 start_player = 0 可以让人类先手 56 | game.start_play(human, mcts_player, start_player=1, is_shown=1) 57 | except KeyboardInterrupt: 58 | print('\n\rquit') 59 | 60 | if __name__ == "__main__": 61 | run() -------------------------------------------------------------------------------- /AlphaZero/policy_value_net_pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import numpy as np 7 | 8 | def set_learning_rate(optimizer, lr): 9 | """ Set the learning rate to the given value """ 10 | for param_group in optimizer.param_groups: 11 | param_group['lr'] = lr 12 | 13 | class Net(nn.Module): 14 | """ 定义策略价值网络结构 """ 15 | def __init__(self, board_width, board_height): 16 | super().__init__() 17 | self.board_width = board_width 18 | self.board_height = board_height 19 | # common layers 20 | self.conv1 = nn.Conv2d(4, 32, kernel_size=3, padding=1) 21 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 22 | self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 23 | # action policy layers 24 | self.act_conv1 = nn.Conv2d(128, 4, kernel_size=1) 25 | self.act_fc1 = nn.Linear(4 * board_width * board_height, board_width * board_height) 26 | # state value layers 27 | self.val_conv1 = nn.Conv2d(128, 2, kernel_size=1) 28 | self.val_fc1 = nn.Linear(2 * board_width * board_height, 64) 29 | self.val_fc2 = nn.Linear(64, 1) 30 | 31 | def forward(self, state_input): 32 | # common layers 33 | x = F.relu(self.conv1(state_input)) 34 | x = F.relu(self.conv2(x)) 35 | x = F.relu(self.conv3(x)) 36 | # action policy layers 37 | x_act = F.relu(self.act_conv1(x)) 38 | x_act = x_act.view(-1, 4 * self.board_width * self.board_height) 39 | x_act = F.log_softmax(self.act_fc1(x_act), dim=1) 40 | # state value layers 41 | x_val = F.relu(self.val_conv1(x)) 42 | x_val = x_val.view(-1, 2 * self.board_width * self.board_height) 43 | x_val = F.relu(self.val_fc1(x_val)) 44 | x_val = torch.tanh(self.val_fc2(x_val)) 45 | return x_act, x_val 46 | 47 | class PolicyValueNet: 48 | def __init__(self, board_width, board_height, model_file=None): 49 | self.board_width = board_width 50 | self.board_height = board_height 51 | self.l2_const = 1e-4 52 | self.policy_value_net = Net(board_width, board_height) 53 | self.optimizer = optim.Adam(self.policy_value_net.parameters(), weight_decay=self.l2_const) 54 | if model_file: 55 | 
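            # Restore previously trained weights: save_model() below stores the
            # state_dict with torch.save, so torch.load returns a dict that can be
            # passed straight to load_state_dict.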
net_params = torch.load(model_file) 56 | self.policy_value_net.load_state_dict(net_params) 57 | 58 | def policy_value_fn(self, board): 59 | legal_positions = board.available 60 | current_state = np.ascontiguousarray( 61 | board.current_state().reshape( 62 | -1, 4, self.board_width, self.board_height 63 | ) 64 | ) 65 | log_act_probs, value = self.policy_value_net( 66 | Variable(torch.from_numpy(current_state)).float() 67 | ) 68 | act_probs = np.exp(log_act_probs.data.numpy().flatten()) 69 | act_probs = zip(legal_positions, act_probs[legal_positions]) 70 | value = value.data[0][0] 71 | return act_probs, value 72 | 73 | def train_step(self, state_batch, mcts_probs, winner_batch, lr): 74 | """ perform a training step """ 75 | state_batch = Variable(torch.FloatTensor(state_batch)) 76 | mcts_probs = Variable(torch.FloatTensor(mcts_probs)) 77 | winner_batch = Variable(torch.FloatTensor(winner_batch)) 78 | # zero the parameter gradients 79 | self.optimizer.zero_grad() 80 | # set learning rate 81 | set_learning_rate(self.optimizer, lr) 82 | # forward 83 | log_act_probs, value = self.policy_value_net(state_batch) 84 | # define the loss 85 | value_loss = F.mse_loss(value.view(-1), winner_batch) 86 | policy_loss = - torch.mean( 87 | torch.sum(mcts_probs * log_act_probs, 1) 88 | ) 89 | loss = value_loss + policy_loss 90 | # backward and optimize 91 | loss.backward() 92 | self.optimizer.step() 93 | # policy entropy, for monitoring only 94 | entropy = - torch.mean( 95 | torch.sum(torch.exp(log_act_probs) * log_act_probs, 1) 96 | ) 97 | return loss.item(), entropy.item() 98 | 99 | def get_policy_param(self): 100 | net_params = self.policy_value_net.state_dict() 101 | return net_params 102 | 103 | def save_model(self, model_file): 104 | """ save model params to file """ 105 | net_params = self.get_policy_param() 106 | torch.save(net_params, model_file) 107 | 108 | -------------------------------------------------------------------------------- /AlphaZero/train.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from collections import deque 4 | import os.path as osp 5 | from .game import Board, Game 6 | from .mcts_alphaZero import MCTSPlayer 7 | from .policy_value_net_pytorch import PolicyValueNet 8 | 9 | DIRNAME = osp.dirname(__file__) 10 | 11 | class TrainPipeline: 12 | def __init__(self, init_model=None): 13 | # 棋盘相关参数 14 | self.board_width = 8 15 | self.board_height = 8 16 | self.n_in_row = 5 17 | self.board = Board( 18 | width=self.board_width, 19 | height=self.board_height, 20 | n_in_row=self.n_in_row 21 | ) 22 | self.game = Game(self.board) 23 | # 自我对弈相关参数 24 | self.temp = 1.0 25 | self.c_puct = 5 26 | self.n_playout = 400 27 | # 训练更新相关参数 28 | self.learn_rate = 2e-3 29 | self.buffer_size = 10000 30 | self.batch_size = 512 31 | self.data_buffer = deque(maxlen=self.buffer_size) 32 | self.check_freq = 2 # 保存模型的概率 33 | self.game_batch_num = 3000 # 训练更新的次数 34 | if init_model: 35 | # 如果提供了初始模型,则加载其用于初始化策略价值网络 36 | self.policy_value_net = PolicyValueNet( 37 | self.board_width, 38 | self.board_height, 39 | model_file=init_model 40 | ) 41 | else: 42 | # 随机初始化策略价值网络 43 | self.policy_value_net = PolicyValueNet( 44 | self.board_width, 45 | self.board_height 46 | ) 47 | self.mcts_player = MCTSPlayer( 48 | self.policy_value_net.policy_value_fn, 49 | c_puct=self.c_puct, 50 | n_playout=self.n_playout, 51 | is_selfplay=1 52 | ) 53 | 54 | def run(self): 55 | """ 执行完整的训练流程 """ 56 | for i in range(self.game_batch_num): 57 | episode_len = 
self.collect_selfplay_data() 58 | if len(self.data_buffer) > self.batch_size: 59 | loss, entropy = self.policy_update() 60 | print(( 61 | "batch i:{}, " 62 | "episode_len:{}, " 63 | "loss:{:.4f}, " 64 | "entropy:{:.4f}" 65 | ).format(i+1, episode_len, loss, entropy)) 66 | # save performance per update 67 | loss_array = np.load(DIRNAME + '/loss.npy') 68 | entropy_array = np.load(DIRNAME + '/entropy.npy') 69 | loss_array = np.append(loss_array, loss) 70 | entropy_array = np.append(entropy_array, entropy) 71 | np.save(DIRNAME + '/loss.npy', loss_array) 72 | np.save(DIRNAME + '/entropy.npy', entropy_array) 73 | del loss_array 74 | del entropy_array 75 | else: 76 | print("batch i:{}, episode_len:{}".format(i+1, episode_len)) 77 | # 定期保存模型 78 | if (i+1) % self.check_freq == 0: 79 | self.policy_value_net.save_model( 80 | DIRNAME + '/current_policy.model' 81 | ) 82 | 83 | def collect_selfplay_data(self): 84 | """ collect self-play data for training """ 85 | winner, play_data = self.game.start_self_play(self.mcts_player, temp=self.temp) 86 | play_data = list(play_data)[:] 87 | episode_len = len(play_data) 88 | # augment the data 89 | play_data = self.get_equi_data(play_data) 90 | self.data_buffer.extend(play_data) 91 | return episode_len 92 | 93 | def get_equi_data(self, play_data): 94 | """ play_data: [(state, mcts_prob, winner_z), ...] """ 95 | extend_data = [] 96 | for state, mcts_prob, winner in play_data: 97 | for i in [1, 2, 3, 4]: 98 | # 逆时针旋转 99 | equi_state = np.array([np.rot90(s, i) for s in state]) 100 | equi_mcts_prob = np.rot90(np.flipud( 101 | mcts_prob.reshape(self.board_height, self.board_width) 102 | ), i) 103 | extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner)) 104 | # 水平翻转 105 | equi_state = np.array([np.fliplr(s) for s in equi_state]) 106 | equi_mcts_prob = np.fliplr(equi_mcts_prob) 107 | extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner)) 108 | return extend_data 109 | 110 | def policy_update(self): 111 | """ update the policy-value net """ 112 | mini_batch = random.sample(self.data_buffer, self.batch_size) 113 | state_batch = [data[0] for data in mini_batch] 114 | mcts_probs_batch = [data[1] for data in mini_batch] 115 | winner_batch = [data[2] for data in mini_batch] 116 | loss, entropy = self.policy_value_net.train_step( 117 | state_batch, mcts_probs_batch, winner_batch, self.learn_rate 118 | ) 119 | return loss, entropy 120 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 强化学习练手场地 2 | 这个仓库本来用于复现某本国内强化学习教材的案例,奈何这本书写得实在太差了,因此弃坑此书。我的书评在这里:[豆瓣书评](https://book.douban.com/review/12673161/)。 3 | 4 | 现在,这个仓库用于存储一些强化学习练手小项目与算法实验。具体来讲,就是不至于单独成一个 repo 的项目,但是又值得拿出来讨论的代码。 5 | 6 | ### 我的笔记分布 7 | - 🥊 入门学习 / 读书笔记 [GitHub链接:PiperLiu/Reinforcement-Learning-practice-zh](https://github.com/PiperLiu/Reinforcement-Learning-practice-zh) 8 | - 💻 阅读论文 / 视频课程的笔记 [GitHub链接:PiperLiu/introRL](https://github.com/PiperLiu/introRL) 9 | - ✨ 大小算法 / 练手操场 [GitHub链接:PiperLiu/Approachable-Reinforcement-Learning](https://github.com/PiperLiu/Approachable-Reinforcement-Learning) 10 | 11 | ### 仓库目录 12 | - 强化学习基础复现案例 [转到摘要与索引](#强化学习基础复现案例) 13 | 14 | 15 | **** 16 | 17 | ## 强化学习基础复现案例 18 | 19 | 参考《深入浅出强化学习:编程实战》内容。奈何实在写得不行,因此照着复现了第0、1章与AlphaZero后,弃坑。 20 | 21 | - [附录 A PyTorch 入门](#A) 22 | - 第 0 篇 先导篇 23 | - - [1 一个及其简单的强化学习实例](#sec_1) 24 | - - [2 马尔可夫决策过程](#sec_2) 25 | - 第 1 篇 基于值函数的方法 26 | - - [3 基于动态规划的方法](#sec_3) 27 | - - [4 
基于蒙特卡洛的方法](#sec_4) 28 | - - [5 基于时间差分的方法](#sec_5) 29 | - - [6 基于函数逼近的方法](#sec_6) 30 | - AlphaZero 实战:五子棋 31 | - - [链接](#sec_11) 32 | 33 | ### 附录 A PyTorch 入门 34 | 35 | 36 | [./pyTorch_learn/](./pyTorch_learn/) 37 | 38 | 介绍了 PyTorch 的基本使用,主要实例为:构建了一个极其单薄简单的卷积神经网络,数据集为 `CIFAR10` 。学习完附录 A ,我主要收获: 39 | - 输入 PyTorch 的 `nn.Module` 应该是 `mini-batch` ,即比正常数据多一个维度; 40 | - 输入 `nn.Module` 应该是 `Variable` 包裹的; 41 | - 在网络类中, `__init__()` 并没有真正定义网络结构的关系,网络结构的输入输出关系在 `forward()` 中定义。 42 | 43 | 此外,让我们梳理一下神经网络的“学习”过程: 44 | ```python 45 | import torch 46 | import torch.nn as nn 47 | import torch.nn.functional as F 48 | from torch.autograd import Variabl 49 | import torch.optim as optim 50 | 51 | # 神经网络对象 52 | net = Net() 53 | # 损失函数:交叉熵 54 | criterion = nn.CrossEntropyLoss() 55 | # 优化方式 56 | optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) 57 | # 遍历所有数据 5 次 58 | for epoch in range(5): 59 | # data 是一个 mini-batch , batch-size 由 trainloader 决定 60 | for i, data in enumerate(trainloader, 0): 61 | # 特征值与标签:注意用 Variable 进行包裹 62 | inputs, labels = data 63 | inputs, labels = Variable(inputs), Variable(labels) 64 | # pytorch 中 backward() 函数执行时,梯度是累积计算, 65 | # 而不是被替换; 66 | # 但在处理每一个 batch 时并不需要与其他 batch 的梯度 67 | # 混合起来累积计算, 68 | # 因此需要对每个 batch 调用一遍 zero_grad() 69 | # 将参数梯度置0。 70 | optimizer.zero_grad() 71 | # 现在的输出值 72 | outputs = net(inputs) 73 | # 求误差 74 | loss = criterion(outputs, labels) 75 | # 在初始化 optimizer 时,我们明确告诉它应该更新模型的 76 | # 哪些参数;一旦调用了 loss.backward() ,梯度就会被 77 | # torch对象“存储”(它们具有grad和requires_grad属性); 78 | # 在计算模型中所有张量的梯度后,调用 optimizer.step() 79 | # 会使优化器迭代它应该更新的所有参数(张量), 80 | # 并使用它们内部存储的 grad 来更新它们的值。 81 | loss.backward() 82 | optimizer.step() 83 | 84 | # 模型的保存与加载 85 | torch.save(net.state_dict(), './pyTorch_learn/data/' + 'model.pt') 86 | net.load_state_dict(torch.load('./pyTorch_learn/data/' + 'model.pt')) 87 | ``` 88 | 89 | ### 第 0 篇 先导篇 90 | 91 | #### 1 一个及其简单的强化学习实例 92 | 93 | 94 | [./ch_0/sec_1/](./ch_0/sec_1/) 95 | 96 | 很简答的一个实例,探讨“探索”与“利用”间的博弈平衡。 97 | 98 | 我对原书的代码进行了一些改进与 typo 。 99 | 100 | #### 2 马尔可夫决策过程 101 | 102 | 103 | [./ch_0/sec_2/](./ch_0/sec_2/) 104 | 105 | 优势函数(Advantage Function):$A(s, a) = q_\pi (s, a) - v_\pi (s)$ 106 | 107 | 造了一个交互环境,以后测试可以用到。典型的“网格世界”。 108 | 109 | ### 第 1 篇 基于值函数的方法 110 | 111 | #### 3 基于动态规划的方法 112 | 113 | 114 | [./ch_1/sec_3/](./ch_1/sec_3/) 115 | 116 | - 修改了“鸳鸯环境”:[yuan_yang_env.py](./ch_1/sec_3/yuan_yang_env.py) 117 | - 策略迭代:[dp_policy_iter.py](./ch_1/sec_3/dp_policy_iter.py) 118 | - 价值迭代:[dp_value_iter.py](./ch_1/sec_3/dp_value_iter.py) 119 | 120 | #### 4 基于蒙特卡洛的方法 121 | 122 | 123 | [./ch_1/sec_4/](./ch_1/sec_4/) 124 | 125 | 基于蒙特卡洛方法,分别实现了: 126 | - 纯贪心策略; 127 | - 同轨策略下的 Epsilon-贪心策略。 128 | 129 | #### 5 基于时间差分的方法 130 | 131 | 132 | [./ch_1/sec_5/](./ch_1/sec_5/) 133 | 134 | 本书让我对“离轨策略”的概念更加清晰: 135 | - 离轨策略可以使用重复的数据、不同策略下产生的数据; 136 | - 因为,离轨策略更新时,只用到了(s, a, r, s'),并不需要使用 a' 这个数据; 137 | - 换句话说,并不需要数据由本策略产生。 138 | 139 | **发现了一个神奇的现象:关于奖励机制的设置。** 140 | 141 | 书上说,本节使用的 env 可以与蒙特卡洛同,这是不对的。 142 | 先说上节蒙特卡洛的奖励机制: 143 | - 小鸟撞到墙:-10; 144 | - 小鸟到达目的地:+10; 145 | - 小鸟走一步,什么都没有发生:-2。 146 | 147 | 如果你运行试验,你会发现,TD(0) 方法下,小鸟会 **畏惧前行** : 148 | - 在蒙特卡洛方法下,如此的奖励机制有效的,因为训练是在整整一幕结束之后; 149 | - 在 TD(0) 方法下,小鸟状态墙得到的奖励是 -10 ,而它行走五步的奖励也是 -10 (不考虑折扣); 150 | - 但在这个环境中,小鸟要抵达目的地,至少要走 20 步; 151 | - 因此,与其“披荆斩棘走到目的地”,还不如“在一开始就一头撞在墙上撞死”来的奖励多呢。 152 | 153 | 可以用如下几个例子印证,对于 [./ch_1/sec_5/yuan_yang_env_td.py](./ch_1/sec_5/yuan_yang_env_td.py) 的第 151-159 行: 154 | ```python 155 | flag_collide = self.collide(next_position) 156 | if flag_collide == 1: 157 | return 
self.position_to_state(current_position), -10, True 158 | 159 | flag_find = self.find(next_position) 160 | if flag_find == 1: 161 | return self.position_to_state(next_position), 10, True 162 | 163 | return self.position_to_state(next_position), -2, False 164 | ``` 165 | 166 | 现在 (撞墙奖励, 到达重点奖励, 走路奖励) 分别为 (-10, 10, -2) 。现在的实验结果是:小鸟宁可撞死,也不出门。 167 | 168 | 我们把出门走路的痛感与抵达目的地的快乐进行更改,: 169 | - (-10, 10, -1),出门走路没有那么疼了,小鸟倾向于抵达目的地; 170 | - (-10, 10, -1.2),出门走路的痛感上升,小鸟倾向于宁可开局就撞死自己; 171 | - (-10, 100, -1.2),出门走路虽然疼,但是到达目的地的快乐是很大很大的,小鸟多次尝试,掐指一算,还是出门合适,毕竟它是一只深谋远虑的鸟。 172 | 173 | 运行试验,我们发现,上述试验并不稳定,这是因为每次训练小鸟做出的随机决策不同,且“走路痛感”与“抵达快乐”很不悬殊,小鸟左右为难。 174 | 175 | 因此,一个好的解决方案是:**别管那么多!我们不是希望小鸟抵达目的地吗?那就不要让它走路感到疼痛!** 176 | - 设置奖励为(-10, 10, 0),我保证小鸟会“愿意出门”,并抵达目的地! 177 | 178 | 这也提醒了我:**作为一个强化学习算法实践者,不应该过多干涉智能体决策!告诉智能体几个简单的规则,剩下的交给其自己学习。过于复杂的奖励机制,会出现意想不到的状况!** 179 | 180 | 对于此问题,其实还可以另一个思路:调高撞墙的痛感。 181 | - 设置奖励为(-1000, 10, -2),如次,小鸟便不敢撞墙:因为撞墙太疼了!!! 182 | - 并且,小鸟也不会“多走路”,因为多走路也有痛感。你会发现,如此得到的结果,小鸟总是能找到最优方案(最短的路径)。 183 | 184 | #### 6 基于函数逼近的方法 185 | 186 | 187 | [./ch_1/sec_6/](./ch_1/sec_6/)a 188 | 189 | 本节前半部分最后一次使用“鸳鸯系统”,我发现: 190 | - 无论是正常向量状态表示,还是固定稀疏状态表示,书中都将 `epsilon = epsilon * 0.99` 在迭代中去掉; 191 | - 事实证明,不应该将其去掉,尤其是第一组情况。第一组情况其实就是上节的表格型 q-learning ; 192 | - 固定稀疏表示中,可以不加探索欲望的收敛(在这个环境中)。 193 | 194 | 此外,还发现: 195 | - 固定稀疏中,鸟倾向于走直线; 196 | - 我认为这是因为固定稀疏矩阵中,抽取了特征,同一个 x 或同一个 y 对应的状态,其价值更趋同。 197 | 198 | 本节后半部分:非线性函数逼近。 199 | 200 | 书中没有给代码地址,我 Google 到作者应该是借鉴了这个:[https://github.com/yenchenlin/DeepLearningFlappyBird](https://github.com/yenchenlin/DeepLearningFlappyBird) 201 | - 我将这个项目写在了:[./ch_1/sec_6/flappy_bird/](./ch_1/sec_6/flappy_bird/) 202 | - - 我添加了手动操作体验游戏的部分,按 H 键可以煽动翅膀:[./ch_1/sec_6/flappy_bird/game/keyboard_agent.py](./ch_1/sec_6/flappy_bird/game/keyboard_agent.py) 203 | - - 书上是 tf 1 的代码,我使用 tf 2 重写,这个过程中参考了:[https://github.com/tomjur/TF2.0DQN](https://github.com/tomjur/TF2.0DQN) 204 | - `python -u "d:\GitHub\rl\Approachable-Reinforcement-Learning\ch_1\sec_6\flappy_bird\dqn_agent.py"`以训练 205 | 206 | ### AlphaZero 实战:五子棋 207 | 208 | 209 | 代码在 [./AlphaZero](./AlphaZero) 。 210 | -------------------------------------------------------------------------------- /ch_0/sec_1/kb_game.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | class KB_Game: 5 | def __init__(self): 6 | self.q = np.array([0.0, 0.0, 0.0]) 7 | self.action_counts = np.array([0, 0, 0]) 8 | self.current_cumulative_rewards = 0.0 9 | self.actions = [1, 2, 3] 10 | self.counts = 0 11 | self.counts_history = [] 12 | self.cumulative_rewards_history = [] 13 | self.average_rewards_history = [] 14 | self.a = 1 15 | self.reward = 0 16 | 17 | def step(self, a): 18 | r = 0 19 | if a == 1: 20 | r = np.random.normal(1, 1) 21 | elif a == 2: 22 | r = np.random.normal(2, 1) 23 | elif a == 3: 24 | r = np.random.normal(1.5, 1) 25 | return r 26 | 27 | def choose_action(self, policy, **kwargs): 28 | action = 0 29 | if policy == 'e_greedy': 30 | if np.random.random() < kwargs['epsilon']: 31 | action = np.random.randint(1, 4) 32 | else: 33 | action = np.argmax(self.q) + 1 34 | if policy == 'ucb': 35 | c_ratio = kwargs['c_ratio'] 36 | if 0 in self.action_counts: 37 | action = np.where(self.action_counts==0)[0][0] + 1 38 | else: 39 | value = self.q + c_ratio * np.sqrt(np.log(self.counts) /\ 40 | self.action_counts) 41 | action = np.argmax(value) + 1 42 | if policy == 'boltzmann': 43 | tau = kwargs['temperature'] 44 | p = np.exp(self.q / tau) / (np.sum(np.exp(self.q / tau))) 45 | 
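            # Sample an arm from the Boltzmann (softmax) distribution over the current
            # Q estimates: a low temperature concentrates probability on the best arm,
            # a high temperature behaves almost uniformly at random.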
action = np.random.choice([1, 2, 3], p=p.ravel()) 46 | return action 47 | 48 | def train(self, play_total, policy, **kwargs): 49 | reward_1 = [] 50 | reward_2 = [] 51 | reward_3 = [] 52 | for i in range(play_total): 53 | action = 0 54 | if policy == 'e_greedy': 55 | action = \ 56 | self.choose_action(policy, epsilon=kwargs['epsilon']) 57 | if policy == 'ucb': 58 | action = \ 59 | self.choose_action(policy, c_ratio=kwargs['c_ratio']) 60 | if policy == 'boltzmann': 61 | action = \ 62 | self.choose_action(policy, temperature=kwargs['temperature']) 63 | self.a = action 64 | 65 | self.r = self.step(self.a) 66 | self.counts +=1 67 | 68 | self.q[self.a-1] = (self.q[self.a-1] * self.action_counts[self.a-1] + self.r) /\ 69 | (self.action_counts[self.a-1] + 1) 70 | self.action_counts[self.a-1] += 1 71 | 72 | reward_1.append(self.q[0]) 73 | reward_2.append(self.q[1]) 74 | reward_3.append(self.q[2]) 75 | self.current_cumulative_rewards += self.r 76 | self.cumulative_rewards_history.append(self.current_cumulative_rewards) 77 | self.average_rewards_history.append(self.current_cumulative_rewards /\ 78 | (len(self.average_rewards_history) + 1)) 79 | self.counts_history.append(i) 80 | 81 | def reset(self): 82 | self.q = np.array([0.0, 0.0, 0.0]) 83 | self.action_counts = np.array([0, 0, 0]) 84 | self.current_cumulative_rewards = 0.0 85 | self.actions = [1, 2, 3] 86 | self.counts = 0 87 | self.counts_history = [] 88 | self.cumulative_rewards_history = [] 89 | self.average_rewards_history = [] 90 | self.a = 1 91 | self.reward = 0 92 | 93 | def plot(self, colors, policy, style): 94 | plt.figure(1) 95 | plt.plot(self.counts_history, self.average_rewards_history, colors+style, label=policy) 96 | plt.legend() 97 | plt.xlabel('n', fontsize=18) 98 | plt.ylabel('average rewards', fontsize=18) 99 | -------------------------------------------------------------------------------- /ch_0/sec_1/main.py: -------------------------------------------------------------------------------- 1 | from kb_game import KB_Game 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | np.random.seed(0) 7 | k_gamble = KB_Game() 8 | total = 2000 9 | 10 | k_gamble.train(play_total=total, policy='e_greedy', epsilon=0.05) 11 | k_gamble.plot(colors='r', policy='e_greedy', style='-.') 12 | k_gamble.reset() 13 | 14 | k_gamble.train(play_total=total, policy='boltzmann', temperature=1) 15 | k_gamble.plot(colors='b', policy='boltzmann', style='--') 16 | k_gamble.reset() 17 | 18 | k_gamble.train(play_total=total, policy='ucb', c_ratio=0.5) 19 | k_gamble.plot(colors='g', policy='ucb', style='-') 20 | k_gamble.reset() 21 | 22 | plt.show() -------------------------------------------------------------------------------- /ch_0/sec_2/images/background.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_0/sec_2/images/background.jpg -------------------------------------------------------------------------------- /ch_0/sec_2/images/bird_female.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_0/sec_2/images/bird_female.png -------------------------------------------------------------------------------- /ch_0/sec_2/images/bird_male.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_0/sec_2/images/bird_male.png -------------------------------------------------------------------------------- /ch_0/sec_2/images/obstacle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_0/sec_2/images/obstacle.png -------------------------------------------------------------------------------- /ch_0/sec_2/load_images.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import os.path as osp 3 | 4 | path_file = osp.abspath(__file__) 5 | path_images = osp.join(path_file, '..', 'images') 6 | 7 | def load_bird_male(): 8 | obj = 'bird_male.png' 9 | obj_path = osp.join(path_images, obj) 10 | return pygame.image.load(obj_path) 11 | 12 | def load_bird_female(): 13 | obj = 'bird_female.png' 14 | obj_path = osp.join(path_images, obj) 15 | return pygame.image.load(obj_path) 16 | 17 | def load_background(): 18 | obj = 'background.jpg' 19 | obj_path = osp.join(path_images, obj) 20 | return pygame.image.load(obj_path) 21 | 22 | def load_obstacle(): 23 | obj = 'obstacle.png' 24 | obj_path = osp.join(path_images, obj) 25 | return pygame.image.load(obj_path) 26 | -------------------------------------------------------------------------------- /ch_0/sec_2/yuan_yang_env.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import random 3 | from load_images import * 4 | import numpy as np 5 | 6 | class YuanYangEnv: 7 | def __init__(self): 8 | self.states = [] 9 | for i in range(0, 100): 10 | self.states.append(i) 11 | self.actions = ['e', 's', 'w', 'n'] 12 | self.gamma = 0.8 13 | self.value = np.zeros((10, 10)) 14 | 15 | self.viewer = None 16 | self.FPSCLOCK = pygame.time.Clock() 17 | 18 | self.screen_size = (1200, 900) 19 | self.bird_position = (0, 0) 20 | self.limit_distance_x = 120 21 | self.limit_distance_y = 90 22 | self.obstacle_size = [120, 90] 23 | self.obstacle1_x = [] 24 | self.obstacle1_y = [] 25 | self.obstacle2_x = [] 26 | self.obstacle2_y = [] 27 | 28 | for i in range(8): 29 | # obstacle 1 30 | self.obstacle1_x.append(360) 31 | if i <= 3: 32 | self.obstacle1_y.append(90 * i) 33 | else: 34 | self.obstacle1_y.append(90 * (i + 2)) 35 | # obstacle 2 36 | self.obstacle2_x.append(720) 37 | if i <= 4: 38 | self.obstacle2_y.append(90 * i) 39 | else: 40 | self.obstacle2_y.append(90 * (i + 2)) 41 | 42 | self.bird_male_init_position = [0.0, 0.0] 43 | self.bird_male_position = [0, 0] 44 | self.bird_female_init_position = [1080, 0] 45 | 46 | def collide(self, state_position): 47 | flag = 1 48 | flag1 = 1 49 | flag2 = 1 50 | 51 | # obstacle 1 52 | dx = [] 53 | dy = [] 54 | for i in range(8): 55 | dx1 = abs(self.obstacle1_x[i] - state_position[0]) 56 | dx.append(dx1) 57 | dy1 = abs(self.obstacle1_y[i] - state_position[1]) 58 | dy.append(dy1) 59 | mindx = min(dx) 60 | mindy = min(dy) 61 | if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y: 62 | flag1 = 0 63 | 64 | # obstacle 2 65 | dx_second = [] 66 | dy_second = [] 67 | for i in range(8): 68 | dx1 = abs(self.obstacle2_x[i] - state_position[0]) 69 | dx_second.append(dx1) 70 | dy1 = abs(self.obstacle2_y[i] - state_position[1]) 71 | dy_second.append(dy1) 72 | mindx = min(dx_second) 73 | mindy = min(dy_second) 74 | if mindx >= self.limit_distance_x or mindy >= 
self.limit_distance_y: 75 | flag2 = 0 76 | 77 | if flag1 == 0 and flag2 == 0: 78 | flag = 0 79 | 80 | # collide edge 81 | if state_position[0] > 1080 or \ 82 | state_position[0] < 0 or \ 83 | state_position[1] > 810 or \ 84 | state_position[1] < 0: 85 | flag = 1 86 | 87 | return flag 88 | 89 | def find(self, state_position): 90 | flag = 0 91 | if abs(state_position[0] - self.bird_female_init_position[0]) < \ 92 | self.limit_distance_x and \ 93 | abs(state_position[1] - self.bird_female_init_position[1]) < \ 94 | self.limit_distance_y: 95 | flag = 1 96 | return flag 97 | 98 | def state_to_position(self, state): 99 | i = int(state / 10) 100 | j = state % 10 101 | position = [0, 0] 102 | postion[0] = 120 * j 103 | postion[1] = 90 * i 104 | return position 105 | 106 | def position_to_state(self, position): 107 | i = position[0] / 120 108 | j = position[1] / 90 109 | return int(i + 10 * j) 110 | 111 | def reset(self): 112 | # 随机产生一个初始位置 113 | flag1 = 1 114 | flag2 = 1 115 | while flag1 or flag2 == 1: 116 | state = self.states[int(random.random() * len(self.states))] 117 | state_position = self.state_to_position(state) 118 | flag1 = self.collide(state_position) 119 | flag2 = self.find(state_position) 120 | return state 121 | 122 | def transform(self, state, action): 123 | current_position = self.state_to_position(state) 124 | next_position = [0, 0] 125 | flag_collide = 0 126 | flag_find = 0 127 | 128 | flag_collide = self.collide(current_position) 129 | flag_find = self.find(current_position) 130 | if flag_collide == 1: 131 | return state, -1, True 132 | if flag_find == 1: 133 | return state, 1, True 134 | 135 | if action == 'e': 136 | next_position[0] = current_position[0] + 120 137 | next_position[1] = current_position[1] 138 | if action == 's': 139 | next_position[0] = current_position[0] 140 | next_position[1] = current_position[1] + 90 141 | if action == 'w': 142 | next_position[0] = current_position[0] - 120 143 | next_position[1] = current_position[1] 144 | if action == 'n': 145 | next_position[0] = current_position[0] 146 | next_position[1] = current_position[1] - 90 147 | 148 | flag_collide = self.collide(next_position) 149 | if flag_collide == 1: 150 | return self.position_to_state(current_position), -1, True 151 | 152 | flag_find = self.find(next_position) 153 | if flag_find == 1: 154 | return self.position_to_state(next_position), 1, True 155 | 156 | return self.position_to_state(next_position), 0, False 157 | 158 | def gameover(self): 159 | for event in pygame.event.get(): 160 | if event.type == pygame.QUIT: 161 | exit() 162 | 163 | def render(self): 164 | if self.viewer is None: 165 | pygame.init() 166 | 167 | self.viewer = pygame.display.set_mode(self.screen_size, 0, 32) 168 | pygame.display.set_caption("yuanyang") 169 | # load pic 170 | self.bird_male = load_bird_male() 171 | self.bird_female = load_bird_female() 172 | self.background = load_background() 173 | self.obstacle = load_obstacle() 174 | 175 | # self.viewer.blit(self.bird_female, self.bird_female_init_position) 176 | # self.viewer.blit(self.bird_male, self.bird_male_init_position) 177 | 178 | self.viewer.blit(self.background, (0, 0)) 179 | self.font = pygame.font.SysFont('times', 15) 180 | 181 | self.viewer.blit(self.background, (0, 0)) 182 | for i in range(11): 183 | pygame.draw.lines(self.viewer, 184 | (255, 255, 255), 185 | True, 186 | ((120 * i, 0), (120 * i, 900)), 187 | 1 188 | ) 189 | pygame.draw.lines(self.viewer, 190 | (255, 255, 255), 191 | True, 192 | ((0, 90 * i), (1200, 90 * i)), 193 | 1 194 | ) 195 | 196 | 
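        # Blit the eight blocks of each obstacle wall, then both birds on top of the grid.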
for i in range(8): 197 | self.viewer.blit(self.obstacle, (self.obstacle1_x[i], self.obstacle1_y[i])) 198 | self.viewer.blit(self.obstacle, (self.obstacle2_x[i], self.obstacle2_y[i])) 199 | 200 | self.viewer.blit(self.bird_female, self.bird_female_init_position) 201 | self.viewer.blit(self.bird_male, self.bird_male_init_position) 202 | 203 | for i in range(10): 204 | for j in range(10): 205 | surface = self.font.render(str( 206 | round(float(self.value[i, j]), 3)), True, (0, 0, 0) 207 | ) 208 | self.viewer.blit(surface, (120 * i + 5, 90 * j + 70)) 209 | 210 | pygame.display.update() 211 | self.gameover() 212 | self.FPSCLOCK.tick(30) 213 | 214 | 215 | if __name__ == "__main__": 216 | yy = YuanYangEnv() 217 | yy.render() 218 | while True: 219 | for event in pygame.event.get(): 220 | if event.type == pygame.QUIT: 221 | exit() 222 | 223 | -------------------------------------------------------------------------------- /ch_1/sec_3/dp_policy_iter.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | from yuan_yang_env import YuanYangEnv 4 | 5 | 6 | class DP_Policy_Iter: 7 | def __init__(self, yuanyang): 8 | self.states = yuanyang.states 9 | self.actions = yuanyang.actions 10 | self.v = [0.0 for i in range(len(self.states) + 1)] 11 | self.pi = dict() 12 | self.yuanyang = yuanyang 13 | self.gamma = yuanyang.gamma 14 | 15 | # 初始化策略 16 | for state in self.states: 17 | flag1 = 0 18 | flag2 = 0 19 | flag1 = yuanyang.collide(yuanyang.state_to_position(state)) 20 | flag2 = yuanyang.find(yuanyang.state_to_position(state)) 21 | if flag1 == 1 or flag2 == 1: 22 | continue 23 | self.pi[state] = self.actions[int(random.random() * len(self.actions))] 24 | 25 | def policy_evaluate(self): 26 | # 策略评估在计算值函数 27 | for i in range(100): 28 | delta = 0.0 29 | for state in self.states: 30 | flag1 = 0 31 | flag2 = 0 32 | flag1 = self.yuanyang.collide(self.yuanyang.state_to_position(state)) 33 | flag2 = self.yuanyang.find(self.yuanyang.state_to_position(state)) 34 | if flag1 == 1 or flag2 == 1: 35 | continue 36 | action = self.pi[state] 37 | s, r, t = self.yuanyang.transform(state, action) 38 | # 更新值 39 | new_v = r + self.gamma * self.v[s] 40 | delta += abs(self.v[state] - new_v) 41 | # 更新值替换原来的值函数 42 | self.v[state] = new_v 43 | if delta < 1e-6: 44 | print('策略评估迭代次数:{}'.format(i)) 45 | break 46 | 47 | def policy_improve(self): 48 | # 利用更新后的值函数进行策略改善 49 | for state in self.states: 50 | flag1 = 0 51 | flag2 = 0 52 | flag1 = self.yuanyang.collide(self.yuanyang.state_to_position(state)) 53 | flag2 = self.yuanyang.find(self.yuanyang.state_to_position(state)) 54 | if flag1 == 1 or flag2 == 1: 55 | continue 56 | a1 = self.actions[0] 57 | s, r, t = self.yuanyang.transform(state, a1) 58 | v1 = r + self.gamma * self.v[s] 59 | # 找状态 s 时,采用哪种动作,值函数最大 60 | for action in self.actions: 61 | s, r, t = self.yuanyang.transform(state, action) 62 | if v1 < r + self.gamma * self.v[s]: 63 | a1 = action 64 | v1 = r + self.gamma * self.v[s] 65 | # 贪婪策略,进行更新 66 | self.pi[state] = a1 67 | 68 | def policy_iterate(self): 69 | for i in range(100): 70 | # 策略评估,变的是 v 71 | self.policy_evaluate() 72 | # 策略改善 73 | pi_old = self.pi.copy() 74 | # 改变 pi 75 | self.policy_improve() 76 | if (self.pi == pi_old): 77 | print('策略改善次数:{}'.format(i)) 78 | break 79 | 80 | 81 | if __name__ == "__main__": 82 | yuanyang = YuanYangEnv() 83 | policy_value = DP_Policy_Iter(yuanyang) 84 | policy_value.policy_iterate() 85 | 86 | # 打印 87 | flag = 1 88 | s = 0 89 | path = [] 90 | # 将 v 值打印出来 91 | for state in 
range(100): 92 | i = int(state / 10) 93 | j = state % 10 94 | yuanyang.value[j, i] = policy_value.v[state] 95 | 96 | step_num = 0 97 | 98 | # 将最优路径打印出来 99 | while flag: 100 | path.append(s) 101 | yuanyang.path = path 102 | a = policy_value.pi[s] 103 | print('%d->%s\t'%(s, a)) 104 | yuanyang.bird_male_position = yuanyang.state_to_position(s) 105 | yuanyang.render() 106 | time.sleep(0.2) 107 | step_num += 1 108 | s_, r, t = yuanyang.transform(s, a) 109 | if t == True or step_num > 200: 110 | flag = 0 111 | s = s_ 112 | 113 | # 渲染最后的路径点 114 | yuanyang.bird_male_position = yuanyang.state_to_position(s) 115 | path.append(s) 116 | yuanyang.render() 117 | while True: 118 | yuanyang.render() 119 | -------------------------------------------------------------------------------- /ch_1/sec_3/dp_value_iter.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | from yuan_yang_env import YuanYangEnv 4 | 5 | class DP_Value_Iter: 6 | def __init__(self, yuanyang): 7 | self.states = yuanyang.states 8 | self.actions = yuanyang.actions 9 | self.v = [0.0 for i in range(len(self.states) + 1)] 10 | self.pi = dict() 11 | self.yuanyang = yuanyang 12 | 13 | self.gamma = yuanyang.gamma 14 | for state in self.states: 15 | flag1 = 0 16 | flag2 = 0 17 | flag1 = yuanyang.collide(yuanyang.state_to_position(state)) 18 | flag2 = yuanyang.find(yuanyang.state_to_position(state)) 19 | if flag1 == 1 or flag2 == 1: 20 | continue 21 | self.pi[state] = self.actions[int(random.random() * len(self.actions))] 22 | 23 | def value_iteration(self): 24 | for i in range(1000): 25 | delta = 0.0 26 | for state in self.states: 27 | flag1 = 0 28 | flag2 = 0 29 | flag1 = self.yuanyang.collide(self.yuanyang.state_to_position(state)) 30 | flag2 = self.yuanyang.find(self.yuanyang.state_to_position(state)) 31 | if flag1 == 1 or flag2 == 2: 32 | continue 33 | a1 = self.actions[int(random.random() * 4)] 34 | s, r, t = self.yuanyang.transform(state, a1) 35 | # 策略评估 36 | v1 = r + self.gamma * self.v[s] 37 | # 策略改善 38 | for action in self.actions: 39 | s, r, t = self.yuanyang.transform(state, action) 40 | if v1 < r + self.gamma * self.v[s]: 41 | a1 = action 42 | v1 = r + self.gamma * self.v[s] 43 | delta += abs(v1 - self.v[state]) 44 | self.pi[state] = a1 45 | self.v[state] = v1 46 | if delta < 1e-6: 47 | print('迭代次数为:{}'.format(i)) 48 | break 49 | 50 | 51 | if __name__ == "__main__": 52 | yuanyang = YuanYangEnv() 53 | policy_value = DP_Value_Iter(yuanyang) 54 | policy_value.value_iteration() 55 | 56 | # 打印 57 | flag = 1 58 | s = 0 59 | path = [] 60 | # 将 v 值打印出来 61 | for state in range(100): 62 | i = int(state / 10) 63 | j = state % 10 64 | yuanyang.value[j, i] = policy_value.v[state] 65 | 66 | step_num = 0 67 | 68 | # 将最优路径打印出来 69 | while flag: 70 | path.append(s) 71 | yuanyang.path = path 72 | a = policy_value.pi[s] 73 | print('%d->%s\t'%(s, a)) 74 | yuanyang.bird_male_position = yuanyang.state_to_position(s) 75 | yuanyang.render() 76 | time.sleep(0.2) 77 | step_num += 1 78 | s_, r, t = yuanyang.transform(s, a) 79 | if t == True or step_num > 200: 80 | flag = 0 81 | s = s_ 82 | 83 | # 渲染最后的路径点 84 | yuanyang.bird_male_position = yuanyang.state_to_position(s) 85 | path.append(s) 86 | yuanyang.render() 87 | while True: 88 | yuanyang.render() 89 | -------------------------------------------------------------------------------- /ch_1/sec_3/load_images.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import os.path as osp 3 | 4 | 
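# Resolve the sprite directory relative to this file so the chapter-1 scripts reuse
# the images shipped with ch_0/sec_2 instead of duplicating them.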
path_file = osp.abspath(__file__) 5 | path_images = osp.join(path_file, '../../..', 'ch_0/sec_2/images') 6 | 7 | def load_bird_male(): 8 | obj = 'bird_male.png' 9 | obj_path = osp.join(path_images, obj) 10 | return pygame.image.load(obj_path) 11 | 12 | def load_bird_female(): 13 | obj = 'bird_female.png' 14 | obj_path = osp.join(path_images, obj) 15 | return pygame.image.load(obj_path) 16 | 17 | def load_background(): 18 | obj = 'background.jpg' 19 | obj_path = osp.join(path_images, obj) 20 | return pygame.image.load(obj_path) 21 | 22 | def load_obstacle(): 23 | obj = 'obstacle.png' 24 | obj_path = osp.join(path_images, obj) 25 | return pygame.image.load(obj_path) 26 | -------------------------------------------------------------------------------- /ch_1/sec_3/yuan_yang_env.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import random 3 | from load_images import * 4 | import numpy as np 5 | 6 | class YuanYangEnv: 7 | def __init__(self): 8 | self.states = [] 9 | for i in range(0, 100): 10 | self.states.append(i) 11 | self.actions = ['e', 's', 'w', 'n'] 12 | self.gamma = 0.8 13 | self.value = np.zeros((10, 10)) 14 | 15 | self.viewer = None 16 | self.FPSCLOCK = pygame.time.Clock() 17 | 18 | self.screen_size = (1200, 900) 19 | self.bird_position = (0, 0) 20 | self.limit_distance_x = 120 21 | self.limit_distance_y = 90 22 | self.obstacle_size = [120, 90] 23 | self.obstacle1_x = [] 24 | self.obstacle1_y = [] 25 | self.obstacle2_x = [] 26 | self.obstacle2_y = [] 27 | 28 | for i in range(8): 29 | # obstacle 1 30 | self.obstacle1_x.append(360) 31 | if i <= 3: 32 | self.obstacle1_y.append(90 * i) 33 | else: 34 | self.obstacle1_y.append(90 * (i + 2)) 35 | # obstacle 2 36 | self.obstacle2_x.append(720) 37 | if i <= 4: 38 | self.obstacle2_y.append(90 * i) 39 | else: 40 | self.obstacle2_y.append(90 * (i + 2)) 41 | 42 | self.bird_male_init_position = [0.0, 0.0] 43 | self.bird_male_position = [0, 0] 44 | self.bird_female_init_position = [1080, 0] 45 | 46 | self.path = [] 47 | 48 | def collide(self, state_position): 49 | flag = 1 50 | flag1 = 1 51 | flag2 = 1 52 | 53 | # obstacle 1 54 | dx = [] 55 | dy = [] 56 | for i in range(8): 57 | dx1 = abs(self.obstacle1_x[i] - state_position[0]) 58 | dx.append(dx1) 59 | dy1 = abs(self.obstacle1_y[i] - state_position[1]) 60 | dy.append(dy1) 61 | mindx = min(dx) 62 | mindy = min(dy) 63 | if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y: 64 | flag1 = 0 65 | 66 | # obstacle 2 67 | dx_second = [] 68 | dy_second = [] 69 | for i in range(8): 70 | dx1 = abs(self.obstacle2_x[i] - state_position[0]) 71 | dx_second.append(dx1) 72 | dy1 = abs(self.obstacle2_y[i] - state_position[1]) 73 | dy_second.append(dy1) 74 | mindx = min(dx_second) 75 | mindy = min(dy_second) 76 | if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y: 77 | flag2 = 0 78 | 79 | if flag1 == 0 and flag2 == 0: 80 | flag = 0 81 | 82 | # collide edge 83 | if state_position[0] > 1080 or \ 84 | state_position[0] < 0 or \ 85 | state_position[1] > 810 or \ 86 | state_position[1] < 0: 87 | flag = 1 88 | 89 | return flag 90 | 91 | def find(self, state_position): 92 | flag = 0 93 | if abs(state_position[0] - self.bird_female_init_position[0]) < \ 94 | self.limit_distance_x and \ 95 | abs(state_position[1] - self.bird_female_init_position[1]) < \ 96 | self.limit_distance_y: 97 | flag = 1 98 | return flag 99 | 100 | def state_to_position(self, state): 101 | i = int(state / 10) 102 | j = state % 10 103 | position = [0, 0] 104 | 
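        # Grid cells are 120 px wide and 90 px tall: the column (state % 10) gives x and
        # the row (state // 10) gives y, which is exactly the mapping position_to_state inverts.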
position[0] = 120 * j 105 | position[1] = 90 * i 106 | return position 107 | 108 | def position_to_state(self, position): 109 | i = position[0] / 120 110 | j = position[1] / 90 111 | return int(i + 10 * j) 112 | 113 | def reset(self): 114 | # 随机产生一个初始位置 115 | flag1 = 1 116 | flag2 = 1 117 | while flag1 or flag2 == 1: 118 | state = self.states[int(random.random() * len(self.states))] 119 | state_position = self.state_to_position(state) 120 | flag1 = self.collide(state_position) 121 | flag2 = self.find(state_position) 122 | return state 123 | 124 | def transform(self, state, action): 125 | current_position = self.state_to_position(state) 126 | next_position = [0, 0] 127 | flag_collide = 0 128 | flag_find = 0 129 | 130 | flag_collide = self.collide(current_position) 131 | flag_find = self.find(current_position) 132 | if flag_collide == 1: 133 | return state, -1, True 134 | if flag_find == 1: 135 | return state, 1, True 136 | 137 | if action == 'e': 138 | next_position[0] = current_position[0] + 120 139 | next_position[1] = current_position[1] 140 | if action == 's': 141 | next_position[0] = current_position[0] 142 | next_position[1] = current_position[1] + 90 143 | if action == 'w': 144 | next_position[0] = current_position[0] - 120 145 | next_position[1] = current_position[1] 146 | if action == 'n': 147 | next_position[0] = current_position[0] 148 | next_position[1] = current_position[1] - 90 149 | 150 | flag_collide = self.collide(next_position) 151 | if flag_collide == 1: 152 | return self.position_to_state(current_position), -1, True 153 | 154 | flag_find = self.find(next_position) 155 | if flag_find == 1: 156 | return self.position_to_state(next_position), 1, True 157 | 158 | return self.position_to_state(next_position), 0, False 159 | 160 | def gameover(self): 161 | for event in pygame.event.get(): 162 | if event.type == pygame.QUIT: 163 | exit() 164 | 165 | def render(self): 166 | if self.viewer is None: 167 | pygame.init() 168 | 169 | self.viewer = pygame.display.set_mode(self.screen_size, 0, 32) 170 | pygame.display.set_caption("yuanyang") 171 | # load pic 172 | self.bird_male = load_bird_male() 173 | self.bird_female = load_bird_female() 174 | self.background = load_background() 175 | self.obstacle = load_obstacle() 176 | 177 | # self.viewer.blit(self.bird_female, self.bird_female_init_position) 178 | # self.viewer.blit(self.bird_male, self.bird_male_init_position) 179 | 180 | self.viewer.blit(self.background, (0, 0)) 181 | self.font = pygame.font.SysFont('times', 15) 182 | 183 | self.viewer.blit(self.background, (0, 0)) 184 | for i in range(11): 185 | pygame.draw.lines(self.viewer, 186 | (255, 255, 255), 187 | True, 188 | ((120 * i, 0), (120 * i, 900)), 189 | 1 190 | ) 191 | pygame.draw.lines(self.viewer, 192 | (255, 255, 255), 193 | True, 194 | ((0, 90 * i), (1200, 90 * i)), 195 | 1 196 | ) 197 | 198 | for i in range(8): 199 | self.viewer.blit(self.obstacle, (self.obstacle1_x[i], self.obstacle1_y[i])) 200 | self.viewer.blit(self.obstacle, (self.obstacle2_x[i], self.obstacle2_y[i])) 201 | 202 | self.viewer.blit(self.bird_female, self.bird_female_init_position) 203 | self.viewer.blit(self.bird_male, self.bird_male_init_position) 204 | 205 | for i in range(10): 206 | for j in range(10): 207 | surface = self.font.render(str( 208 | round(float(self.value[i, j]), 3)), True, (0, 0, 0) 209 | ) 210 | self.viewer.blit(surface, (120 * i + 5, 90 * j + 70)) 211 | 212 | # 画路径点 213 | for i in range(len(self.path)): 214 | rec_position = self.state_to_position(self.path[i]) 215 | 
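            # Outline every cell visited so far in red and label it with its step index,
            # so the learned route can be read directly off the rendered board.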
pygame.draw.rect(self.viewer, [255, 0, 0], [rec_position[0], rec_position[1], 120, 90], 3) 216 | surface = self.font.render(str(i), True, (255, 0, 0)) 217 | self.viewer.blit(surface, (rec_position[0] + 5, rec_position[1] + 5)) 218 | 219 | pygame.display.update() 220 | self.gameover() 221 | self.FPSCLOCK.tick(30) 222 | 223 | 224 | if __name__ == "__main__": 225 | yy = YuanYangEnv() 226 | yy.render() 227 | while True: 228 | for event in pygame.event.get(): 229 | if event.type == pygame.QUIT: 230 | exit() 231 | 232 | -------------------------------------------------------------------------------- /ch_1/sec_4/load_images.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import os.path as osp 3 | 4 | path_file = osp.abspath(__file__) 5 | path_images = osp.join(path_file, '../../..', 'ch_0/sec_2/images') 6 | 7 | def load_bird_male(): 8 | obj = 'bird_male.png' 9 | obj_path = osp.join(path_images, obj) 10 | return pygame.image.load(obj_path) 11 | 12 | def load_bird_female(): 13 | obj = 'bird_female.png' 14 | obj_path = osp.join(path_images, obj) 15 | return pygame.image.load(obj_path) 16 | 17 | def load_background(): 18 | obj = 'background.jpg' 19 | obj_path = osp.join(path_images, obj) 20 | return pygame.image.load(obj_path) 21 | 22 | def load_obstacle(): 23 | obj = 'obstacle.png' 24 | obj_path = osp.join(path_images, obj) 25 | return pygame.image.load(obj_path) 26 | -------------------------------------------------------------------------------- /ch_1/sec_4/mc_rl.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import time 3 | import random 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from yuan_yang_env_mc import YuanYangEnv 7 | 8 | class MC_RL: 9 | def __init__(self, yuanyang): 10 | # 行为值函数的初始化 11 | self.qvalue = np.zeros((len(yuanyang.states), len(yuanyang.actions))) * 0.1 12 | # 次数初始化 13 | self.n = 0.001 * np.ones( 14 | (len(yuanyang.states), len(yuanyang.actions)) 15 | ) 16 | self.actions = yuanyang.actions 17 | self.yuanyang = yuanyang 18 | self.gamma = yuanyang.gamma 19 | 20 | # 定义贪婪策略 21 | def greedy_policy(self, qfun, state): 22 | amax = qfun[state, :].argmax() 23 | return self.actions[amax] 24 | 25 | def epsilon_greedy_policy(self, qfun, state, epsilon): 26 | amax = qfun[state, :].argmax() 27 | # 概率部分 28 | if np.random.uniform() < 1 - epsilon: 29 | # 最优动作 30 | return self.actions[amax] 31 | else: 32 | return self.actions[int(random.random() * len(self.actions))] 33 | 34 | # 找到动作所对应的序号 35 | def find_anum(self, a): 36 | for i in range(len(self.actions)): 37 | if a == self.actions[i]: 38 | return i 39 | 40 | def mc_learning_ei(self, num_iter): 41 | self.qvalue = np.zeros((len(self.yuanyang.states), len(self.yuanyang.actions))) 42 | self.n = 0.001 * np.ones((len(self.yuanyang.states), len(self.yuanyang.actions))) 43 | # 学习 num_iter 次 44 | for iter1 in range(num_iter): 45 | # 采集状态样本 46 | s_sample = [] 47 | # 采集动作样本 48 | a_sample = [] 49 | # 采集回报样本 50 | r_sample = [] 51 | # 随机初始化状态 52 | s = self.yuanyang.reset() 53 | a = self.actions[int(random.random() * len(self.actions))] 54 | done = False 55 | step_num = 0 56 | 57 | if self.mc_test() == 1: 58 | print("探索初始化第1次完成任务需要的次数:{}".format(iter1)) 59 | break 60 | 61 | # 采集数据 s0-a1-s1-s1-a2-s2...terminate state 62 | while done == False and step_num < 30: 63 | # 与环境交互 64 | s_next, r, done = self.yuanyang.transform(s, a) 65 | a_num = self.find_anum(a) 66 | # 往回走给予惩罚 67 | if s_next in s_sample: 68 | r = -2 69 | # 
存储数据,采样数据 70 | s_sample.append(s) 71 | r_sample.append(r) 72 | a_sample.append(a_num) 73 | step_num += 1 74 | # 转移到下一状态,继续试验,s0-s1-s2 75 | s = s_next 76 | a = self.greedy_policy(self.qvalue, s) 77 | 78 | # 从样本中计算折扣累计回报,g(s_0) = r_0 + gamma * r_1 + ... + v(sT) 79 | a = self.greedy_policy(self.qvalue, s) 80 | g = self.qvalue[s, self.find_anum(a)] 81 | for i in range(len(s_sample) - 1, -1, -1): 82 | g *= self.gamma 83 | g += r_sample[i] 84 | # g = G(s1, a),开始对其他状态累计回报 85 | for i in range(len(s_sample)): 86 | # 计算状态-行为对(s, a)的次数,s, a1...s, a2 87 | self.n[s_sample[i], a_sample[i]] += 1.0 88 | # 利用增量式方法更新值函数 89 | self.qvalue[s_sample[i], a_sample[i]] = (self.qvalue[s_sample[i], a_sample[i]] * (self.n[s_sample[i], a_sample[i]] - 1) + g) / self.n[s_sample[i], a_sample[i]] 90 | g -= r_sample[i] 91 | g /= self.gamma 92 | 93 | return self.qvalue 94 | 95 | def mc_test(self): 96 | s = 0 97 | s_sample = [] 98 | done = False 99 | flag = 0 100 | step_num = 0 101 | while done == False and step_num < 30: 102 | a = self.greedy_policy(self.qvalue, s) 103 | # 与环境交互 104 | s_next, r, done = self.yuanyang.transform(s, a) 105 | s_sample.append(s) 106 | s = s_next 107 | step_num += 1 108 | if s == 9: 109 | flag = 1 110 | return flag 111 | 112 | def mc_learning_on_policy(self, num_iter, epsilon): 113 | self.qvalue = np.zeros((len(self.yuanyang.states), len(self.yuanyang.actions))) 114 | self.n = 0.001 * np.ones((len(self.yuanyang.states), len(self.yuanyang.actions))) 115 | # 学习 num_iter 次 116 | for iter1 in range(num_iter): 117 | # 采集状态样本 118 | s_sample = [] 119 | # 采集动作样本 120 | a_sample = [] 121 | # 采集回报样本 122 | r_sample = [] 123 | # 固定初始状态 124 | s = 0 125 | done = False 126 | step_num = 0 127 | epsilon = epsilon * np.exp(-iter1 / 1000) 128 | 129 | # 采集数据 s0-a1-s1-s1-a2-s2...terminate state 130 | while done == False and step_num < 30: 131 | a = self.epsilon_greedy_policy(self.qvalue, s, epsilon) 132 | # 与环境交互 133 | s_next, r, done = self.yuanyang.transform(s, a) 134 | a_num = self.find_anum(a) 135 | # 往回走给予惩罚 136 | if s_next in s_sample: 137 | r = -2 138 | # 存储数据,采样数据 139 | s_sample.append(s) 140 | r_sample.append(r) 141 | a_sample.append(a_num) 142 | step_num += 1 143 | # 转移到下一状态,继续试验,s0-s1-s2 144 | s = s_next 145 | 146 | if s == 9: 147 | print('同轨策略第1次完成任务需要次数:{}'.format(iter1)) 148 | break 149 | 150 | # 从样本中计算折扣累计回报 g(s_0) = r_0 + gamma * r_1 + gamma ^ 3 * r3 + v(sT) 151 | a = self.epsilon_greedy_policy(self.qvalue, s, epsilon) 152 | g = self.qvalue[s, self.find_anum(a)] 153 | # 计算该序列第1状态的折扣累计回报 154 | for i in range(len(s_sample) - 1, -1, -1): 155 | g *= self.gamma 156 | g += r_sample[i] 157 | # g = G(s1, a),开始计算其他状态的折扣累计回报 158 | for i in range(len(s_sample)): 159 | # 计算状态-行为对 (s, a) 的次数,s, a1...s, a2 160 | self.n[s_sample[i], a_sample[i]] += 1.0 161 | # 利用增量式方法更新值函数 162 | self.qvalue[s_sample[i], a_sample[i]] = ( 163 | self.qvalue[s_sample[i], a_sample[i]] * (self.n[s_sample[i], a_sample[i]] - 1) + g 164 | ) / self.n[s_sample[i], a_sample[i]] 165 | g -= r_sample[i] 166 | g /= self.gamma 167 | 168 | return self.qvalue 169 | 170 | 171 | def mc_learning_ei(): 172 | yuanyang = YuanYangEnv() 173 | brain = MC_RL(yuanyang) 174 | # 探索初始化方法 175 | qvalue1 = brain.mc_learning_ei(num_iter=10000) 176 | 177 | # 打印 178 | flag = 1 179 | s = 0 180 | path = [] 181 | # 将 v 值打印出来 182 | yuanyang.action_value = qvalue1 183 | step_num = 0 184 | 185 | # 将最优路径打印出来 186 | while flag: 187 | path.append(s) 188 | yuanyang.path = path 189 | a = brain.greedy_policy(qvalue1, s) 190 | print('%d->%s\t'%(s, a), qvalue1[s, 0], qvalue1[s, 1], 
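# A minimal, standalone sketch of the incremental-mean return update used in
# mc_learning_ei / mc_learning_on_policy above. The episode below is made up
# for illustration (it is not sampled from the environment), and the bootstrap
# term the code above seeds g with is omitted for brevity.
import numpy as np

gamma = 0.95
s_sample = [0, 1, 2]              # visited states
a_sample = [0, 1, 1]              # indices of the actions taken there
r_sample = [-2.0, -2.0, 10.0]     # rewards observed after each action

qvalue = np.zeros((100, 4))
n = 0.001 * np.ones((100, 4))     # visit counts (small prior avoids division by zero)

# Discounted return of the whole episode, accumulated backwards:
# g = r_0 + gamma * r_1 + gamma^2 * r_2
g = 0.0
for r in reversed(r_sample):
    g = r + gamma * g

# Walk forwards, credit each (s, a) with the return observed from that point
# on, and fold it into the running mean Q(s, a).
for s, a, r in zip(s_sample, a_sample, r_sample):
    n[s, a] += 1.0
    qvalue[s, a] = (qvalue[s, a] * (n[s, a] - 1) + g) / n[s, a]
    g = (g - r) / gamma           # return from the next state onwards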
qvalue1[s, 2], qvalue1[s, 3]) 191 | yuanyang.bird_male_position = yuanyang.state_to_position(s) 192 | yuanyang.render() 193 | time.sleep(0.25) 194 | step_num += 1 195 | s_, r, t = yuanyang.transform(s, a) 196 | if t == True or step_num > 200: 197 | flag = 0 198 | s = s_ 199 | 200 | # 渲染最后的路径点 201 | yuanyang.bird_male_position = yuanyang.state_to_position(s) 202 | path.append(s) 203 | yuanyang.render() 204 | while True: 205 | yuanyang.render() 206 | 207 | def mc_learning_epsilon(): 208 | yuanyang = YuanYangEnv() 209 | brain = MC_RL(yuanyang) 210 | # 探索初始化方法 211 | qvalue2 = brain.mc_learning_on_policy(num_iter=10000, epsilon=0.2) 212 | 213 | # 打印 214 | flag = 1 215 | s = 0 216 | path = [] 217 | # 将 v 值打印出来 218 | yuanyang.action_value = qvalue2 219 | step_num = 0 220 | 221 | # 将最优路径打印出来 222 | while flag: 223 | path.append(s) 224 | yuanyang.path = path 225 | a = brain.greedy_policy(qvalue2, s) 226 | print('%d->%s\t'%(s, a), qvalue2[s, 0], qvalue2[s, 1], qvalue2[s, 2], qvalue2[s, 3]) 227 | yuanyang.bird_male_position = yuanyang.state_to_position(s) 228 | yuanyang.render() 229 | time.sleep(0.25) 230 | step_num += 1 231 | s_, r, t = yuanyang.transform(s, a) 232 | if t == True or step_num > 200: 233 | flag = 0 234 | s = s_ 235 | 236 | # 渲染最后的路径点 237 | yuanyang.bird_male_position = yuanyang.state_to_position(s) 238 | path.append(s) 239 | yuanyang.render() 240 | while True: 241 | yuanyang.render() 242 | 243 | if __name__ == "__main__": 244 | # mc_learning_ei() 245 | 246 | mc_learning_epsilon() -------------------------------------------------------------------------------- /ch_1/sec_4/yuan_yang_env_mc.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import random 3 | from load_images import * 4 | import numpy as np 5 | 6 | class YuanYangEnv: 7 | def __init__(self): 8 | self.states = [] 9 | for i in range(0, 100): 10 | self.states.append(i) 11 | self.actions = ['e', 's', 'w', 'n'] 12 | # 蒙特卡洛需要修改 gamma ,防止长远回报过快衰减 13 | self.gamma = 0.95 14 | self.action_value = np.zeros((100, 4)) 15 | 16 | self.viewer = None 17 | self.FPSCLOCK = pygame.time.Clock() 18 | 19 | self.screen_size = (1200, 900) 20 | self.bird_position = (0, 0) 21 | self.limit_distance_x = 120 22 | self.limit_distance_y = 90 23 | self.obstacle_size = [120, 90] 24 | self.obstacle1_x = [] 25 | self.obstacle1_y = [] 26 | self.obstacle2_x = [] 27 | self.obstacle2_y = [] 28 | 29 | for i in range(8): 30 | # obstacle 1 31 | self.obstacle1_x.append(360) 32 | if i <= 3: 33 | self.obstacle1_y.append(90 * i) 34 | else: 35 | self.obstacle1_y.append(90 * (i + 2)) 36 | # obstacle 2 37 | self.obstacle2_x.append(720) 38 | if i <= 4: 39 | self.obstacle2_y.append(90 * i) 40 | else: 41 | self.obstacle2_y.append(90 * (i + 2)) 42 | 43 | self.bird_male_init_position = [0.0, 0.0] 44 | self.bird_male_position = [0, 0] 45 | self.bird_female_init_position = [1080, 0] 46 | 47 | self.path = [] 48 | 49 | def collide(self, state_position): 50 | flag = 1 51 | flag1 = 1 52 | flag2 = 1 53 | 54 | # obstacle 1 55 | dx = [] 56 | dy = [] 57 | for i in range(8): 58 | dx1 = abs(self.obstacle1_x[i] - state_position[0]) 59 | dx.append(dx1) 60 | dy1 = abs(self.obstacle1_y[i] - state_position[1]) 61 | dy.append(dy1) 62 | mindx = min(dx) 63 | mindy = min(dy) 64 | if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y: 65 | flag1 = 0 66 | 67 | # obstacle 2 68 | dx_second = [] 69 | dy_second = [] 70 | for i in range(8): 71 | dx1 = abs(self.obstacle2_x[i] - state_position[0]) 72 | dx_second.append(dx1) 73 | dy1 = 
abs(self.obstacle2_y[i] - state_position[1]) 74 | dy_second.append(dy1) 75 | mindx = min(dx_second) 76 | mindy = min(dy_second) 77 | if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y: 78 | flag2 = 0 79 | 80 | if flag1 == 0 and flag2 == 0: 81 | flag = 0 82 | 83 | # collide edge 84 | if state_position[0] > 1080 or \ 85 | state_position[0] < 0 or \ 86 | state_position[1] > 810 or \ 87 | state_position[1] < 0: 88 | flag = 1 89 | 90 | return flag 91 | 92 | def find(self, state_position): 93 | flag = 0 94 | if abs(state_position[0] - self.bird_female_init_position[0]) < \ 95 | self.limit_distance_x and \ 96 | abs(state_position[1] - self.bird_female_init_position[1]) < \ 97 | self.limit_distance_y: 98 | flag = 1 99 | return flag 100 | 101 | def state_to_position(self, state): 102 | i = int(state / 10) 103 | j = state % 10 104 | position = [0, 0] 105 | position[0] = 120 * j 106 | position[1] = 90 * i 107 | return position 108 | 109 | def position_to_state(self, position): 110 | i = position[0] / 120 111 | j = position[1] / 90 112 | return int(i + 10 * j) 113 | 114 | def reset(self): 115 | # 随机产生一个初始位置 116 | flag1 = 1 117 | flag2 = 1 118 | while flag1 or flag2 == 1: 119 | state = self.states[int(random.random() * len(self.states))] 120 | state_position = self.state_to_position(state) 121 | flag1 = self.collide(state_position) 122 | flag2 = self.find(state_position) 123 | return state 124 | 125 | def transform(self, state, action): 126 | current_position = self.state_to_position(state) 127 | next_position = [0, 0] 128 | flag_collide = 0 129 | flag_find = 0 130 | 131 | flag_collide = self.collide(current_position) 132 | flag_find = self.find(current_position) 133 | if flag_collide == 1: 134 | return state, -10, True 135 | if flag_find == 1: 136 | return state, 10, True 137 | 138 | if action == 'e': 139 | next_position[0] = current_position[0] + 120 140 | next_position[1] = current_position[1] 141 | if action == 's': 142 | next_position[0] = current_position[0] 143 | next_position[1] = current_position[1] + 90 144 | if action == 'w': 145 | next_position[0] = current_position[0] - 120 146 | next_position[1] = current_position[1] 147 | if action == 'n': 148 | next_position[0] = current_position[0] 149 | next_position[1] = current_position[1] - 90 150 | 151 | flag_collide = self.collide(next_position) 152 | if flag_collide == 1: 153 | return self.position_to_state(current_position), -10, True 154 | 155 | flag_find = self.find(next_position) 156 | if flag_find == 1: 157 | return self.position_to_state(next_position), 10, True 158 | 159 | return self.position_to_state(next_position), -2, False 160 | 161 | def gameover(self): 162 | for event in pygame.event.get(): 163 | if event.type == pygame.QUIT: 164 | exit() 165 | 166 | def render(self): 167 | if self.viewer is None: 168 | pygame.init() 169 | 170 | self.viewer = pygame.display.set_mode(self.screen_size, 0, 32) 171 | pygame.display.set_caption("yuanyang") 172 | # load pic 173 | self.bird_male = load_bird_male() 174 | self.bird_female = load_bird_female() 175 | self.background = load_background() 176 | self.obstacle = load_obstacle() 177 | 178 | # self.viewer.blit(self.bird_female, self.bird_female_init_position) 179 | # self.viewer.blit(self.bird_male, self.bird_male_init_position) 180 | 181 | self.viewer.blit(self.background, (0, 0)) 182 | self.font = pygame.font.SysFont('times', 15) 183 | 184 | self.viewer.blit(self.background, (0, 0)) 185 | for i in range(11): 186 | pygame.draw.lines(self.viewer, 187 | (255, 255, 255), 188 | True, 
189 | ((120 * i, 0), (120 * i, 900)), 190 | 1 191 | ) 192 | pygame.draw.lines(self.viewer, 193 | (255, 255, 255), 194 | True, 195 | ((0, 90 * i), (1200, 90 * i)), 196 | 1 197 | ) 198 | 199 | for i in range(8): 200 | self.viewer.blit(self.obstacle, (self.obstacle1_x[i], self.obstacle1_y[i])) 201 | self.viewer.blit(self.obstacle, (self.obstacle2_x[i], self.obstacle2_y[i])) 202 | 203 | self.viewer.blit(self.bird_female, self.bird_female_init_position) 204 | self.viewer.blit(self.bird_male, self.bird_male_init_position) 205 | 206 | # 画动作-值函数 207 | for i in range(100): 208 | y = int(i / 10) 209 | x = i % 10 210 | # 往东的值函数 211 | surface = self.font.render(str(round(float(self.action_value[i, 0]), 2)), True, (0, 0, 0)) 212 | self.viewer.blit(surface, (120 * x + 80, 90 * y + 45)) 213 | # 往南的值函数 214 | surface = self.font.render(str(round(float(self.action_value[i, 1]), 2)), True, (0, 0, 0)) 215 | self.viewer.blit(surface, (120 * x + 50, 90 * y + 70)) 216 | # 往西的值函数 217 | surface = self.font.render(str(round(float(self.action_value[i, 2]), 2)), True, (0, 0, 0)) 218 | self.viewer.blit(surface, (120 * x + 10, 90 * y + 45)) 219 | # 往北的值函数 220 | surface = self.font.render(str(round(float(self.action_value[i, 3]), 2)), True, (0, 0, 0)) 221 | self.viewer.blit(surface, (120 * x + 50, 90 * y + 10)) 222 | 223 | # 画路径点 224 | for i in range(len(self.path)): 225 | rec_position = self.state_to_position(self.path[i]) 226 | pygame.draw.rect(self.viewer, [255, 0, 0], [rec_position[0], rec_position[1], 120, 90], 3) 227 | surface = self.font.render(str(i), True, (255, 0, 0)) 228 | self.viewer.blit(surface, (rec_position[0] + 5, rec_position[1] + 5)) 229 | 230 | pygame.display.update() 231 | self.gameover() 232 | self.FPSCLOCK.tick(30) 233 | 234 | 235 | if __name__ == "__main__": 236 | yy = YuanYangEnv() 237 | yy.render() 238 | while True: 239 | for event in pygame.event.get(): 240 | if event.type == pygame.QUIT: 241 | exit() 242 | 243 | -------------------------------------------------------------------------------- /ch_1/sec_5/TD_RL.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import os 4 | import pygame 5 | import time 6 | import matplotlib.pyplot as plt 7 | from yuan_yang_env_td import YuanYangEnv 8 | 9 | 10 | class TD_RL: 11 | def __init__(self, yuanyang): 12 | self.gamma = yuanyang.gamma 13 | self.yuanyang = yuanyang 14 | # 值函数的初始值 15 | self.qvalue = np.zeros( 16 | (len(self.yuanyang.states), len(self.yuanyang.actions)) 17 | ) 18 | 19 | # 定义贪婪策略 20 | def greedy_policy(self, qfun, state): 21 | amax = qfun[state, :].argmax() 22 | return self.yuanyang.actions[amax] 23 | 24 | # 定义 epsilon 贪婪策略 25 | def epsilon_greedy_policy(self, qfun, state, epsilon): 26 | amax = qfun[state, :].argmax() 27 | # 概率部分 28 | if np.random.uniform() < 1 - epsilon: 29 | # 最优动作 30 | return self.yuanyang.actions[amax] 31 | else: 32 | return self.yuanyang.actions[ 33 | int(random.random() * len(self.yuanyang.actions)) 34 | ] 35 | 36 | # 找到动作所对应的序号 37 | def find_anum(self, a): 38 | for i in range(len(self.yuanyang.actions)): 39 | if a == self.yuanyang.actions[i]: 40 | return i 41 | 42 | def sarsa(self, num_iter, alpha, epsilon): 43 | iter_num = [] 44 | self.qvalue = np.zeros((len(self.yuanyang.states), len(self.yuanyang.actions))) 45 | # 第1个大循环,产生了多少实验 46 | for iter in range(num_iter): 47 | # 随机初始化状态 48 | epsilon = epsilon * 0.99 49 | s_sample = [] 50 | # 初始状态 51 | s = 0 52 | flag = self.greedy_test() 53 | if flag == 1: 54 | iter_num.append(iter) 55 | if 
len(iter_num) < 2: 56 | print('sarsa 第 1 次完成任务需要的迭代次数为 {}'.format(iter_num[0])) 57 | if flag == 2: 58 | print('sarsa 第 1 次实现最短路径需要的迭代次数为 {}'.format(iter)) 59 | break 60 | # 利用 epsilon-greedy 策略选初始动作 61 | a = self.epsilon_greedy_policy(self.qvalue, s, epsilon) 62 | t = False 63 | count = 0 64 | 65 | # 第 2 个循环, 1 个实验, s0-s1-s2-s1-s2-s_terminate 66 | while t == False and count < 30: 67 | # 与环境交互得到下一状态 68 | s_next, r, t = self.yuanyang.transform(s, a) 69 | a_num = self.find_anum(a) 70 | # 本轨迹中已有,给出负回报 71 | if s_next in s_sample: 72 | r = -2 73 | s_sample.append(s) 74 | # 判断是否是终止状态 75 | if t == True: 76 | q_target = r 77 | else: 78 | # 下一状态处的最大动作,体现同轨策略 79 | a1 = self.epsilon_greedy_policy(self.qvalue, s_next, epsilon) 80 | a1_num = self.find_anum(a1) 81 | # Q-learning 的更新公式(SARSA) 82 | q_target = r + self.gamma * self.qvalue[s_next, a1_num] 83 | # 利用 td 方法更新动作值函数 alpha 84 | self.qvalue[s, a_num] = self.qvalue[s, a_num] + alpha * (q_target - self.qvalue[s, a_num]) 85 | # 转到下一状态 86 | s = s_next 87 | # 行为策略 88 | a = self.epsilon_greedy_policy(self.qvalue, s, epsilon) 89 | count += 1 90 | 91 | return self.qvalue 92 | 93 | def greedy_test(self): 94 | s = 0 95 | s_sample = [] 96 | done = False 97 | flag = 0 98 | step_num = 0 99 | while done == False and step_num < 30: 100 | a = self.greedy_policy(self.qvalue, s) 101 | # 与环境交互 102 | s_next, r, done = self.yuanyang.transform(s, a) 103 | s_sample.append(s) 104 | s = s_next 105 | step_num += 1 106 | 107 | if s == 9: 108 | flag = 1 109 | if s == 9 and step_num < 21: 110 | flag = 2 111 | 112 | return flag 113 | 114 | def qlearning(self, num_iter, alpha, epsilon): 115 | iter_num = [] 116 | self.qvalue = np.zeros((len(self.yuanyang.states), len(self.yuanyang.actions))) 117 | # 第1个大循环,产生了多少实验 118 | for iter in range(num_iter): 119 | # 随机初始化状态 120 | epsilon = epsilon * 0.99 121 | s_sample = [] 122 | # 初始状态 123 | s = 0 124 | flag = self.greedy_test() 125 | if flag == 1: 126 | iter_num.append(iter) 127 | if len(iter_num) < 2: 128 | print('q-learning 第 1 次完成任务需要的迭代次数为 {}'.format(iter_num[0])) 129 | if flag == 2: 130 | print('q-learning 第 1 次实现最短路径需要的迭代次数为 {}'.format(iter)) 131 | break 132 | # 随机选取初始动作 133 | a = self.epsilon_greedy_policy(self.qvalue, s, epsilon) 134 | t = False 135 | count = 0 136 | 137 | # 第 2 个循环, 1 个实验, s0-s1-s2-s1-s2-s_terminate 138 | while t == False and count < 30: 139 | # 与环境交互得到下一状态 140 | s_next, r, t = self.yuanyang.transform(s, a) 141 | a_num = self.find_anum(a) 142 | # 本轨迹中已有,给出负回报 143 | if s_next in s_sample: 144 | r = -2 145 | s_sample.append(s) 146 | # 判断是否是终止状态 147 | if t == True: 148 | q_target = r 149 | else: 150 | # 下一状态处的最大动作,体现同轨策略 151 | a1 = self.greedy_policy(self.qvalue, s_next) 152 | a1_num = self.find_anum(a1) 153 | # Q-learning 的更新公式 TD(0) 154 | q_target = r + self.gamma * self.qvalue[s_next, a1_num] 155 | # 利用 td 方法更新动作值函数 alpha 156 | self.qvalue[s, a_num] = self.qvalue[s, a_num] + alpha * (q_target - self.qvalue[s, a_num]) 157 | # 转到下一状态 158 | s = s_next 159 | # 行为策略 160 | a = self.epsilon_greedy_policy(self.qvalue, s, epsilon) 161 | count += 1 162 | 163 | return self.qvalue 164 | 165 | def sarsa(): 166 | yuanyang = YuanYangEnv() 167 | brain = TD_RL(yuanyang) 168 | qvalue1 = brain.sarsa(num_iter=5000, alpha=0.1, epsilon=0.8) 169 | 170 | # 打印 171 | flag = 1 172 | s = 0 173 | path = [] 174 | # 将 v 值打印出来 175 | yuanyang.action_value = qvalue1 176 | step_num = 0 177 | 178 | # 将最优路径打印出来 179 | while flag: 180 | path.append(s) 181 | yuanyang.path = path 182 | a = brain.greedy_policy(qvalue1, s) 183 | print('%d->%s\t'%(s, a), 
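# The two methods above differ only in the action used for the bootstrap term:
# sarsa() bootstraps on the action the epsilon-greedy behaviour policy will
# actually take (on-policy), while qlearning() bootstraps on the greedy action
# (off-policy). A standalone sketch with toy numbers, not taken from the
# yuanyang environment:
import numpy as np

gamma, alpha, epsilon = 0.95, 0.1, 0.2
qvalue = np.random.rand(100, 4)
s, a, r, s_next, terminal = 0, 0, -2.0, 1, False

# SARSA target: r + gamma * Q(s', a') with a' drawn epsilon-greedily.
if np.random.uniform() < 1 - epsilon:
    a_next = int(qvalue[s_next].argmax())
else:
    a_next = np.random.randint(4)
sarsa_target = r if terminal else r + gamma * qvalue[s_next, a_next]

# Q-learning target: r + gamma * max_a' Q(s', a'), regardless of which action
# the behaviour policy takes next.
qlearning_target = r if terminal else r + gamma * qvalue[s_next].max()

# Both plug into the same TD update rule.
qvalue[s, a] += alpha * (qlearning_target - qvalue[s, a])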
qvalue1[s, 0], qvalue1[s, 1], qvalue1[s, 2], qvalue1[s, 3]) 184 | yuanyang.bird_male_position = yuanyang.state_to_position(s) 185 | yuanyang.render() 186 | time.sleep(0.1) 187 | step_num += 1 188 | s_, r, t = yuanyang.transform(s, a) 189 | if t == True or step_num > 200: 190 | flag = 0 191 | s = s_ 192 | 193 | # 渲染最后的路径点 194 | yuanyang.bird_male_position = yuanyang.state_to_position(s) 195 | path.append(s) 196 | yuanyang.render() 197 | while True: 198 | yuanyang.render() 199 | 200 | def qlearning(): 201 | yuanyang = YuanYangEnv() 202 | brain = TD_RL(yuanyang) 203 | qvalue2 = brain.qlearning(num_iter=5000, alpha=0.1, epsilon=0.1) 204 | 205 | # 打印 206 | flag = 1 207 | s = 0 208 | path = [] 209 | # 将 v 值打印出来 210 | yuanyang.action_value = qvalue2 211 | step_num = 0 212 | 213 | # 将最优路径打印出来 214 | while flag: 215 | path.append(s) 216 | yuanyang.path = path 217 | a = brain.greedy_policy(qvalue2, s) 218 | print('%d->%s\t'%(s, a), qvalue2[s, 0], qvalue2[s, 1], qvalue2[s, 2], qvalue2[s, 3]) 219 | yuanyang.bird_male_position = yuanyang.state_to_position(s) 220 | yuanyang.render() 221 | time.sleep(0.1) 222 | step_num += 1 223 | s_, r, t = yuanyang.transform(s, a) 224 | if t == True or step_num > 200: 225 | flag = 0 226 | s = s_ 227 | 228 | # 渲染最后的路径点 229 | yuanyang.bird_male_position = yuanyang.state_to_position(s) 230 | path.append(s) 231 | yuanyang.render() 232 | while True: 233 | yuanyang.render() 234 | 235 | 236 | if __name__ == "__main__": 237 | # sarsa() 238 | qlearning() -------------------------------------------------------------------------------- /ch_1/sec_5/load_images.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import os.path as osp 3 | 4 | path_file = osp.abspath(__file__) 5 | path_images = osp.join(path_file, '../../..', 'ch_0/sec_2/images') 6 | 7 | def load_bird_male(): 8 | obj = 'bird_male.png' 9 | obj_path = osp.join(path_images, obj) 10 | return pygame.image.load(obj_path) 11 | 12 | def load_bird_female(): 13 | obj = 'bird_female.png' 14 | obj_path = osp.join(path_images, obj) 15 | return pygame.image.load(obj_path) 16 | 17 | def load_background(): 18 | obj = 'background.jpg' 19 | obj_path = osp.join(path_images, obj) 20 | return pygame.image.load(obj_path) 21 | 22 | def load_obstacle(): 23 | obj = 'obstacle.png' 24 | obj_path = osp.join(path_images, obj) 25 | return pygame.image.load(obj_path) 26 | -------------------------------------------------------------------------------- /ch_1/sec_5/yuan_yang_env_td.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import random 3 | from load_images import * 4 | import numpy as np 5 | 6 | class YuanYangEnv: 7 | def __init__(self): 8 | self.states = [] 9 | for i in range(0, 100): 10 | self.states.append(i) 11 | self.actions = ['e', 's', 'w', 'n'] 12 | # 蒙特卡洛需要修改 gamma ,防止长远回报过快衰减 13 | self.gamma = 0.95 14 | self.action_value = np.zeros((100, 4)) 15 | 16 | self.viewer = None 17 | self.FPSCLOCK = pygame.time.Clock() 18 | 19 | self.screen_size = (1200, 900) 20 | self.bird_position = (0, 0) 21 | self.limit_distance_x = 120 22 | self.limit_distance_y = 90 23 | self.obstacle_size = [120, 90] 24 | self.obstacle1_x = [] 25 | self.obstacle1_y = [] 26 | self.obstacle2_x = [] 27 | self.obstacle2_y = [] 28 | 29 | for i in range(8): 30 | # obstacle 1 31 | self.obstacle1_x.append(360) 32 | if i <= 3: 33 | self.obstacle1_y.append(90 * i) 34 | else: 35 | self.obstacle1_y.append(90 * (i + 2)) 36 | # obstacle 2 37 | 
self.obstacle2_x.append(720) 38 | if i <= 4: 39 | self.obstacle2_y.append(90 * i) 40 | else: 41 | self.obstacle2_y.append(90 * (i + 2)) 42 | 43 | self.bird_male_init_position = [0.0, 0.0] 44 | self.bird_male_position = [0, 0] 45 | self.bird_female_init_position = [1080, 0] 46 | 47 | self.path = [] 48 | 49 | def collide(self, state_position): 50 | flag = 1 51 | flag1 = 1 52 | flag2 = 1 53 | 54 | # obstacle 1 55 | dx = [] 56 | dy = [] 57 | for i in range(8): 58 | dx1 = abs(self.obstacle1_x[i] - state_position[0]) 59 | dx.append(dx1) 60 | dy1 = abs(self.obstacle1_y[i] - state_position[1]) 61 | dy.append(dy1) 62 | mindx = min(dx) 63 | mindy = min(dy) 64 | if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y: 65 | flag1 = 0 66 | 67 | # obstacle 2 68 | dx_second = [] 69 | dy_second = [] 70 | for i in range(8): 71 | dx1 = abs(self.obstacle2_x[i] - state_position[0]) 72 | dx_second.append(dx1) 73 | dy1 = abs(self.obstacle2_y[i] - state_position[1]) 74 | dy_second.append(dy1) 75 | mindx = min(dx_second) 76 | mindy = min(dy_second) 77 | if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y: 78 | flag2 = 0 79 | 80 | if flag1 == 0 and flag2 == 0: 81 | flag = 0 82 | 83 | # collide edge 84 | if state_position[0] > 1080 or \ 85 | state_position[0] < 0 or \ 86 | state_position[1] > 810 or \ 87 | state_position[1] < 0: 88 | flag = 1 89 | 90 | return flag 91 | 92 | def find(self, state_position): 93 | flag = 0 94 | if abs(state_position[0] - self.bird_female_init_position[0]) < \ 95 | self.limit_distance_x and \ 96 | abs(state_position[1] - self.bird_female_init_position[1]) < \ 97 | self.limit_distance_y: 98 | flag = 1 99 | return flag 100 | 101 | def state_to_position(self, state): 102 | i = int(state / 10) 103 | j = state % 10 104 | position = [0, 0] 105 | position[0] = 120 * j 106 | position[1] = 90 * i 107 | return position 108 | 109 | def position_to_state(self, position): 110 | i = position[0] / 120 111 | j = position[1] / 90 112 | return int(i + 10 * j) 113 | 114 | def reset(self): 115 | # 随机产生一个初始位置 116 | flag1 = 1 117 | flag2 = 1 118 | while flag1 or flag2 == 1: 119 | state = self.states[int(random.random() * len(self.states))] 120 | state_position = self.state_to_position(state) 121 | flag1 = self.collide(state_position) 122 | flag2 = self.find(state_position) 123 | return state 124 | 125 | def transform(self, state, action): 126 | current_position = self.state_to_position(state) 127 | next_position = [0, 0] 128 | flag_collide = 0 129 | flag_find = 0 130 | 131 | flag_collide = self.collide(current_position) 132 | flag_find = self.find(current_position) 133 | if flag_collide == 1: 134 | return state, -10, True 135 | if flag_find == 1: 136 | return state, 10, True 137 | 138 | if action == 'e': 139 | next_position[0] = current_position[0] + 120 140 | next_position[1] = current_position[1] 141 | if action == 's': 142 | next_position[0] = current_position[0] 143 | next_position[1] = current_position[1] + 90 144 | if action == 'w': 145 | next_position[0] = current_position[0] - 120 146 | next_position[1] = current_position[1] 147 | if action == 'n': 148 | next_position[0] = current_position[0] 149 | next_position[1] = current_position[1] - 90 150 | 151 | flag_collide = self.collide(next_position) 152 | if flag_collide == 1: 153 | return self.position_to_state(current_position), -1000, True 154 | 155 | flag_find = self.find(next_position) 156 | if flag_find == 1: 157 | return self.position_to_state(next_position), 10, True 158 | 159 | return 
self.position_to_state(next_position), -2, False 160 | 161 | def gameover(self): 162 | for event in pygame.event.get(): 163 | if event.type == pygame.QUIT: 164 | exit() 165 | 166 | def render(self): 167 | if self.viewer is None: 168 | pygame.init() 169 | 170 | self.viewer = pygame.display.set_mode(self.screen_size, 0, 32) 171 | pygame.display.set_caption("yuanyang") 172 | # load pic 173 | self.bird_male = load_bird_male() 174 | self.bird_female = load_bird_female() 175 | self.background = load_background() 176 | self.obstacle = load_obstacle() 177 | 178 | # self.viewer.blit(self.bird_female, self.bird_female_init_position) 179 | # self.viewer.blit(self.bird_male, self.bird_male_init_position) 180 | 181 | self.viewer.blit(self.background, (0, 0)) 182 | self.font = pygame.font.SysFont('times', 15) 183 | 184 | self.viewer.blit(self.background, (0, 0)) 185 | for i in range(11): 186 | pygame.draw.lines(self.viewer, 187 | (255, 255, 255), 188 | True, 189 | ((120 * i, 0), (120 * i, 900)), 190 | 1 191 | ) 192 | pygame.draw.lines(self.viewer, 193 | (255, 255, 255), 194 | True, 195 | ((0, 90 * i), (1200, 90 * i)), 196 | 1 197 | ) 198 | 199 | for i in range(8): 200 | self.viewer.blit(self.obstacle, (self.obstacle1_x[i], self.obstacle1_y[i])) 201 | self.viewer.blit(self.obstacle, (self.obstacle2_x[i], self.obstacle2_y[i])) 202 | 203 | self.viewer.blit(self.bird_female, self.bird_female_init_position) 204 | self.viewer.blit(self.bird_male, self.bird_male_init_position) 205 | 206 | # 画动作-值函数 207 | for i in range(100): 208 | y = int(i / 10) 209 | x = i % 10 210 | # 往东的值函数 211 | surface = self.font.render(str(round(float(self.action_value[i, 0]), 2)), True, (0, 0, 0)) 212 | self.viewer.blit(surface, (120 * x + 80, 90 * y + 45)) 213 | # 往南的值函数 214 | surface = self.font.render(str(round(float(self.action_value[i, 1]), 2)), True, (0, 0, 0)) 215 | self.viewer.blit(surface, (120 * x + 50, 90 * y + 70)) 216 | # 往西的值函数 217 | surface = self.font.render(str(round(float(self.action_value[i, 2]), 2)), True, (0, 0, 0)) 218 | self.viewer.blit(surface, (120 * x + 10, 90 * y + 45)) 219 | # 往北的值函数 220 | surface = self.font.render(str(round(float(self.action_value[i, 3]), 2)), True, (0, 0, 0)) 221 | self.viewer.blit(surface, (120 * x + 50, 90 * y + 10)) 222 | 223 | # 画路径点 224 | for i in range(len(self.path)): 225 | rec_position = self.state_to_position(self.path[i]) 226 | pygame.draw.rect(self.viewer, [255, 0, 0], [rec_position[0], rec_position[1], 120, 90], 3) 227 | surface = self.font.render(str(i), True, (255, 0, 0)) 228 | self.viewer.blit(surface, (rec_position[0] + 5, rec_position[1] + 5)) 229 | 230 | pygame.display.update() 231 | self.gameover() 232 | self.FPSCLOCK.tick(30) 233 | 234 | 235 | if __name__ == "__main__": 236 | yy = YuanYangEnv() 237 | yy.render() 238 | while True: 239 | for event in pygame.event.get(): 240 | if event.type == pygame.QUIT: 241 | exit() 242 | 243 | -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/audio/die.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/audio/die.ogg -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/audio/die.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/audio/die.wav -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/audio/hit.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/audio/hit.ogg -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/audio/hit.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/audio/hit.wav -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/audio/point.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/audio/point.ogg -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/audio/point.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/audio/point.wav -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/audio/swoosh.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/audio/swoosh.ogg -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/audio/swoosh.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/audio/swoosh.wav -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/audio/wing.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/audio/wing.ogg -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/audio/wing.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/audio/wing.wav -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/sprites/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/0.png 
-------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/sprites/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/1.png -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/sprites/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/2.png -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/sprites/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/3.png -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/sprites/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/4.png -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/sprites/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/5.png -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/sprites/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/6.png -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/sprites/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/7.png -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/sprites/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/8.png -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/sprites/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/9.png -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/sprites/background-black.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/background-black.png -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/sprites/base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/base.png -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/sprites/pipe-green.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/pipe-green.png -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/sprites/redbird-downflap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/redbird-downflap.png -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/sprites/redbird-midflap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/redbird-midflap.png -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/assets/sprites/redbird-upflap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/assets/sprites/redbird-upflap.png -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/dqn_agent.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras 3 | import tensorflow.keras.layers as kl 4 | import tensorflow.keras.losses as kls 5 | import tensorflow.keras.optimizers as ko 6 | import numpy as np 7 | import cv2 8 | import sys 9 | import os.path as osp 10 | dirname = osp.dirname(__file__) 11 | sys.path.append(dirname) 12 | from game.wrapped_flappy_bird import GameState 13 | import random 14 | 15 | GAME = 'flappy bird' 16 | ACTIONS = 2 17 | GAMMA = 0.99 18 | OBSERVE = 250. 
# 训练前观察的步长 19 | EXPLORE = 3.0e6 # 随机探索的时间 20 | FINAL_EPSILON = 1.0e-4 # 最终探索率 21 | INITIAL_EPSILON = 0.1 # 初始探索率 22 | REPLAY_MEMORY = 50000 # 经验池的大小 23 | BATCH = 32 # mini-batch 的大小 24 | FRAME_PER_ACTION = 1 # 跳帧 25 | 26 | class Experience_Buffer: 27 | def __init__(self, buffer_size=REPLAY_MEMORY): 28 | super().__init__() 29 | self.buffer = [] 30 | self.buffer_size = buffer_size 31 | 32 | def add_experience(self, experience): 33 | if len(self.buffer) + len(experience) >= self.buffer_size: 34 | self.buffer[0:len(self.buffer) + len(experience) - self.buffer_size] = [] 35 | self.buffer.extend(experience) 36 | 37 | def sample(self, samples_num): 38 | samples_data = random.sample(self.buffer, samples_num) 39 | train_s = [d[0] for d in samples_data] 40 | train_s = np.asarray(train_s) 41 | train_a = [d[1] for d in samples_data] 42 | train_a = np.asarray(train_a) 43 | train_r = [d[2] for d in samples_data] 44 | train_r = np.asarray(train_r) 45 | train_s_ = [d[3] for d in samples_data] 46 | train_s_ = np.asarray(train_s_) 47 | train_terminal = [d[4] for d in samples_data] 48 | train_terminal = np.asarray(train_terminal) 49 | return train_s, train_a, train_r, train_s_, train_terminal 50 | 51 | class Deep_Q_N: 52 | def __init__(self, lr=1.0e-6, model_file=None): 53 | self.gamma = GAMMA 54 | self.tau = 0.01 55 | # tf 56 | self.learning_rate = lr 57 | self.q_model = self.build_q_net() 58 | self.q_target_model = self.build_q_net() 59 | if model_file is not None: 60 | self.restore_model(model_file) 61 | 62 | def save_model(self, model_path): 63 | self.q_model.save(model_path + '0') 64 | self.q_target_model.save(model_path + '1') 65 | 66 | def restore_model(self, model_path): 67 | self.q_model.load_weights(model_path + '0') 68 | self.q_target_model.load_weights(model_path + '1') 69 | 70 | def build_q_net(self): 71 | model = keras.Sequential() 72 | h_conv1 = kl.Conv2D( 73 | input_shape=(80, 80, 4), 74 | filters=32, kernel_size=8, 75 | data_format='channels_last', 76 | strides=4, padding='same', 77 | activation='relu' 78 | ) 79 | h_pool1 = kl.MaxPool2D( 80 | pool_size=2, strides=2, padding='same' 81 | ) 82 | h_conv2 = kl.Conv2D( 83 | filters=64, kernel_size=4, 84 | strides=2, padding='same', 85 | activation='relu' 86 | ) 87 | h_conv3 = kl.Conv2D( 88 | filters=64, kernel_size=3, 89 | strides=1, padding='same', 90 | activation='relu' 91 | ) 92 | h_conv3_flat = kl.Flatten() 93 | h_fc1 = kl.Dense(512, activation='relu') 94 | qout = kl.Dense(ACTIONS) 95 | model.add(h_conv1) 96 | model.add(h_pool1) 97 | model.add(h_conv2) 98 | model.add(h_conv3) 99 | model.add(h_conv3_flat) 100 | model.add(h_fc1) 101 | model.add(qout) 102 | 103 | model.compile( 104 | optimizer=ko.Adam(lr=self.learning_rate), 105 | loss=[self._get_mse_for_action] 106 | ) 107 | 108 | return model 109 | 110 | def _get_mse_for_action(self, target_and_action, current_prediction): 111 | targets, one_hot_action = tf.split(target_and_action, [1, 2], axis=1) 112 | active_q_value = tf.expand_dims(tf.reduce_sum(current_prediction * one_hot_action, axis=1), axis=-1) 113 | return kls.mean_squared_error(targets, active_q_value) 114 | 115 | def _update_target(self): 116 | q_weights = self.q_model.get_weights() 117 | q_target_weights = self.q_target_model.get_weights() 118 | 119 | q_weights = [self.tau * w for w in q_weights] 120 | q_target_weights = [(1. 
- self.tau) * w for w in q_target_weights] 121 | new_weights = [ 122 | q_weights[i] + q_target_weights[i] 123 | for i in range(len(q_weights)) 124 | ] 125 | self.q_target_model.set_weights(new_weights) 126 | 127 | def _one_hot_action(self, actions): 128 | action_index = np.array(actions) 129 | batch_size = len(actions) 130 | result = np.zeros((batch_size, 2)) 131 | result[np.arange(batch_size), action_index] = 1. 132 | return result 133 | 134 | def epsilon_greedy(self, s_t, epsilon): 135 | s_t = s_t.reshape(-1, 80, 80, 4) 136 | amax = np.argmax(self.q_model.predict(s_t)[0]) 137 | # 概率部分 138 | if np.random.uniform() < 1 - epsilon: 139 | # 最优动作 140 | a_t = amax 141 | else: 142 | a_t = random.randint(0, 1) 143 | return a_t 144 | 145 | def train_Network(self, experience_buffer): 146 | # 打开游戏状态与模拟器进行通信 147 | game_state = GameState(fps=100) 148 | # 获得第1个状态并将图形进行预处理 149 | do_nothing = 0 150 | # 与游戏交互1次 151 | x_t, r_0, terminal = game_state.frame_step(do_nothing) 152 | x_t = cv2.cvtColor(cv2.resize(x_t, (80, 80)), cv2.COLOR_BGR2GRAY) 153 | ret, x_t = cv2.threshold(x_t, 1, 255, cv2.THRESH_BINARY) 154 | s_t = np.stack((x_t, x_t, x_t, x_t), axis=2) 155 | # 开始训练 156 | epsilon = INITIAL_EPSILON 157 | t = 0 158 | while "flappy bird" != "angry bird": 159 | a_t = self.epsilon_greedy(s_t, epsilon) 160 | # epsilon 递减 161 | if epsilon > FINAL_EPSILON and t > OBSERVE: 162 | epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE 163 | # 运用动作与环境交互1次 164 | x_t1_colored, r_t, terminal = game_state.frame_step(a_t) 165 | x_t1 = cv2.cvtColor(cv2.resize(x_t1_colored, (80, 80)), cv2.COLOR_BGR2GRAY) 166 | ret, x_t1 = cv2.threshold(x_t1, 1, 255, cv2.THRESH_BINARY) 167 | x_t1 = np.reshape(x_t1, (80, 80, 1)) 168 | s_t1 = np.append(x_t1, s_t[:, :, :3], axis=2) 169 | # 将数据存储到经验池中 170 | experience = np.reshape(np.array([s_t, a_t, r_t, s_t1, terminal]), [1, 5]) 171 | experience_buffer.add_experience(experience) 172 | # 在观察结束后进行训练 173 | if t > OBSERVE: 174 | # 采集样本 175 | train_s, train_a, train_r, train_s_, train_terminal = experience_buffer.sample(BATCH) 176 | target_q = [] 177 | read_target_Q = self.q_target_model.predict(train_s_) 178 | for i in range(len(train_r)): 179 | if train_terminal[i]: 180 | target_q.append(train_r[i]) 181 | else: 182 | target_q.append(train_r[i] * GAMMA * np.max(read_target_Q[i])) 183 | # 训练 1 步 184 | one_hot_actions = self._one_hot_action(train_a) 185 | target_q = np.asarray(target_q) 186 | target_and_actions = np.concatenate((target_q[:, None], one_hot_actions), axis=1) 187 | loss = self.q_model.train_on_batch(train_s, target_and_actions) 188 | # 更新旧的网络 189 | self._update_target() 190 | # 往前推进1步 191 | s_t = s_t1 192 | t += 1 193 | # 每 10000 次迭代保存1次 194 | if t % 10000 == 0: 195 | dirname = osp.dirname(__file__) 196 | self.save_model(dirname + '\\saved_networks') 197 | if t <= OBSERVE: 198 | print("OBSERVER", t) 199 | else: 200 | if t % 1 == 0: 201 | print("train, steps", t, "/epsion ", epsilon, "action_index", a_t, "/reward", r_t) 202 | 203 | 204 | if __name__=="__main__": 205 | buffer = Experience_Buffer() 206 | brain = Deep_Q_N() 207 | brain.train_Network(buffer) 208 | 209 | -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/game/flappy_bird_utils.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import sys 3 | import os.path as osp 4 | 5 | path = osp.dirname(__file__) 6 | path = osp.join(path, '..') 7 | 8 | def load(): 9 | # path of player with different states 10 | PLAYER_PATH = ( 11 
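# A standalone sketch of the one-step DQN target that the training loop in
# dqn_agent.py above assembles from a replay mini-batch. The textbook form is
# r + GAMMA * max_a' Q_target(s', a') for non-terminal transitions and just r
# for terminal ones. Array names mirror the output of Experience_Buffer.sample();
# the target network's predictions are stubbed with random numbers so the
# snippet runs on its own.
import numpy as np

GAMMA = 0.99
BATCH = 32
train_r = np.random.uniform(-1.0, 1.0, size=BATCH)    # rewards
train_terminal = np.random.rand(BATCH) < 0.1          # done flags
read_target_Q = np.random.rand(BATCH, 2)              # Q_target(s', .) per action

target_q = np.where(
    train_terminal,
    train_r,                                          # terminal: target is the reward alone
    train_r + GAMMA * read_target_Q.max(axis=1)       # otherwise bootstrap on max Q'
)
assert target_q.shape == (BATCH,)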
| path + '/assets/sprites/redbird-upflap.png', 12 | path + '/assets/sprites/redbird-midflap.png', 13 | path + '/assets/sprites/redbird-downflap.png' 14 | ) 15 | 16 | # path of background 17 | BACKGROUND_PATH = path + '/assets/sprites/background-black.png' 18 | 19 | # path of pipe 20 | PIPE_PATH = path + '/assets/sprites/pipe-green.png' 21 | 22 | IMAGES, SOUNDS, HITMASKS = {}, {}, {} 23 | 24 | # numbers sprites for score display 25 | IMAGES['numbers'] = ( 26 | pygame.image.load(path + '/assets/sprites/0.png').convert_alpha(), 27 | pygame.image.load(path + '/assets/sprites/1.png').convert_alpha(), 28 | pygame.image.load(path + '/assets/sprites/2.png').convert_alpha(), 29 | pygame.image.load(path + '/assets/sprites/3.png').convert_alpha(), 30 | pygame.image.load(path + '/assets/sprites/4.png').convert_alpha(), 31 | pygame.image.load(path + '/assets/sprites/5.png').convert_alpha(), 32 | pygame.image.load(path + '/assets/sprites/6.png').convert_alpha(), 33 | pygame.image.load(path + '/assets/sprites/7.png').convert_alpha(), 34 | pygame.image.load(path + '/assets/sprites/8.png').convert_alpha(), 35 | pygame.image.load(path + '/assets/sprites/9.png').convert_alpha() 36 | ) 37 | 38 | # base (ground) sprite 39 | IMAGES['base'] = pygame.image.load(path + '/assets/sprites/base.png').convert_alpha() 40 | 41 | # sounds 42 | if 'win' in sys.platform: 43 | soundExt = '.wav' 44 | else: 45 | soundExt = '.ogg' 46 | 47 | SOUNDS['die'] = pygame.mixer.Sound(path + '/assets/audio/die' + soundExt) 48 | SOUNDS['hit'] = pygame.mixer.Sound(path + '/assets/audio/hit' + soundExt) 49 | SOUNDS['point'] = pygame.mixer.Sound(path + '/assets/audio/point' + soundExt) 50 | SOUNDS['swoosh'] = pygame.mixer.Sound(path + '/assets/audio/swoosh' + soundExt) 51 | SOUNDS['wing'] = pygame.mixer.Sound(path + '/assets/audio/wing' + soundExt) 52 | 53 | # select random background sprites 54 | IMAGES['background'] = pygame.image.load(BACKGROUND_PATH).convert() 55 | 56 | # select random player sprites 57 | IMAGES['player'] = ( 58 | pygame.image.load(PLAYER_PATH[0]).convert_alpha(), 59 | pygame.image.load(PLAYER_PATH[1]).convert_alpha(), 60 | pygame.image.load(PLAYER_PATH[2]).convert_alpha(), 61 | ) 62 | 63 | # select random pipe sprites 64 | IMAGES['pipe'] = ( 65 | pygame.transform.rotate( 66 | pygame.image.load(PIPE_PATH).convert_alpha(), 180), 67 | pygame.image.load(PIPE_PATH).convert_alpha(), 68 | ) 69 | 70 | # hismask for pipes 71 | HITMASKS['pipe'] = ( 72 | getHitmask(IMAGES['pipe'][0]), 73 | getHitmask(IMAGES['pipe'][1]), 74 | ) 75 | 76 | # hitmask for player 77 | HITMASKS['player'] = ( 78 | getHitmask(IMAGES['player'][0]), 79 | getHitmask(IMAGES['player'][1]), 80 | getHitmask(IMAGES['player'][2]), 81 | ) 82 | 83 | return IMAGES, SOUNDS, HITMASKS 84 | 85 | def getHitmask(image): 86 | """returns a hitmask using an image's alpha.""" 87 | mask = [] 88 | for x in range(image.get_width()): 89 | mask.append([]) 90 | for y in range(image.get_height()): 91 | mask[x].append(bool(image.get_at((x,y))[3])) 92 | return mask 93 | -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/game/keyboard_agent.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import sys 3 | from wrapped_flappy_bird import GameState 4 | 5 | game_state = GameState(sound=True) 6 | 7 | ACTIONS = [0, 1] 8 | 9 | while True: 10 | action = ACTIONS[0] 11 | for event in pygame.event.get(): 12 | if event.type == pygame.QUIT: 13 | sys.exit() 14 | if event.type == 
pygame.KEYDOWN: 15 | if event.key == pygame.K_h: 16 | action = ACTIONS[1] 17 | game_state.frame_step(action) 18 | pygame.quit() 19 | -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/game/wrapped_flappy_bird.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | import os.path as osp 4 | import random 5 | import pygame 6 | dirname = osp.dirname(__file__) 7 | sys.path.append(dirname) 8 | import flappy_bird_utils 9 | import pygame.surfarray as surfarray 10 | from pygame.locals import * 11 | from itertools import cycle 12 | 13 | FPS = 30 14 | SCREENWIDTH = 288 15 | SCREENHEIGHT = 512 16 | 17 | pygame.init() 18 | FPSCLOCK = pygame.time.Clock() 19 | SCREEN = pygame.display.set_mode((SCREENWIDTH, SCREENHEIGHT)) 20 | pygame.display.set_caption('Flappy Bird') 21 | 22 | IMAGES, SOUNDS, HITMASKS = flappy_bird_utils.load() 23 | PIPEGAPSIZE = 100 # gap between upper and lower part of pipe 24 | BASEY = SCREENHEIGHT * 0.79 25 | 26 | PLAYER_WIDTH = IMAGES['player'][0].get_width() 27 | PLAYER_HEIGHT = IMAGES['player'][0].get_height() 28 | PIPE_WIDTH = IMAGES['pipe'][0].get_width() 29 | PIPE_HEIGHT = IMAGES['pipe'][0].get_height() 30 | BACKGROUND_WIDTH = IMAGES['background'].get_width() 31 | 32 | PLAYER_INDEX_GEN = cycle([0, 1, 2, 1]) 33 | 34 | 35 | class GameState: 36 | def __init__(self, sound=False, fps=FPS): 37 | self.score = self.playerIndex = self.loopIter = 0 38 | self.playerx = int(SCREENWIDTH * 0.2) 39 | self.playery = int((SCREENHEIGHT - PLAYER_HEIGHT) / 2) 40 | self.basex = 0 41 | self.baseShift = IMAGES['base'].get_width() - BACKGROUND_WIDTH 42 | self.sound = sound 43 | self.fps = fps 44 | 45 | newPipe1 = getRandomPipe() 46 | newPipe2 = getRandomPipe() 47 | self.upperPipes = [ 48 | {'x': SCREENWIDTH, 'y': newPipe1[0]['y']}, 49 | {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[0]['y']}, 50 | ] 51 | self.lowerPipes = [ 52 | {'x': SCREENWIDTH, 'y': newPipe1[1]['y']}, 53 | {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[1]['y']}, 54 | ] 55 | 56 | # player velocity, max velocity, downward accleration, accleration on flap 57 | self.pipeVelX = -4 58 | self.playerVelY = 0 # player's velocity along Y, default same as playerFlapped 59 | self.playerMaxVelY = 10 # max vel along Y, max descend speed 60 | self.playerMinVelY = -8 # min vel along Y, max ascend speed 61 | self.playerAccY = 1 # players downward accleration 62 | self.playerFlapAcc = -9 # players speed on flapping 63 | self.playerFlapped = False # True when player flaps 64 | 65 | def frame_step(self, action): 66 | pygame.event.pump() 67 | 68 | reward = 0.1 69 | terminal = False 70 | 71 | # if sum(input_actions) != 1: 72 | # raise ValueError('Multiple input actions!') 73 | 74 | # input_actions[0] == 1: do nothing 75 | # input_actions[1] == 1: flap the bird 76 | if action == 1: 77 | if self.playery > -2 * PLAYER_HEIGHT: 78 | self.playerVelY = self.playerFlapAcc 79 | self.playerFlapped = True 80 | if self.sound: 81 | SOUNDS['wing'].play() 82 | 83 | # check for score 84 | playerMidPos = self.playerx + PLAYER_WIDTH / 2 85 | for pipe in self.upperPipes: 86 | pipeMidPos = pipe['x'] + PIPE_WIDTH / 2 87 | if pipeMidPos <= playerMidPos < pipeMidPos + 4: 88 | self.score += 1 89 | if self.sound: 90 | SOUNDS['point'].play() 91 | reward = 1 92 | 93 | # playerIndex basex change 94 | if (self.loopIter + 1) % 3 == 0: 95 | self.playerIndex = next(PLAYER_INDEX_GEN) 96 | self.loopIter = (self.loopIter + 1) % 30 97 | self.basex = 
-((-self.basex + 100) % self.baseShift) 98 | 99 | # player's movement 100 | if self.playerVelY < self.playerMaxVelY and not self.playerFlapped: 101 | self.playerVelY += self.playerAccY 102 | if self.playerFlapped: 103 | self.playerFlapped = False 104 | self.playery += min(self.playerVelY, BASEY - self.playery - PLAYER_HEIGHT) 105 | if self.playery < 0: 106 | self.playery = 0 107 | 108 | # move pipes to left 109 | for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes): 110 | uPipe['x'] += self.pipeVelX 111 | lPipe['x'] += self.pipeVelX 112 | 113 | # add new pipe when first pipe is about to touch left of screen 114 | if 0 < self.upperPipes[0]['x'] < 5: 115 | newPipe = getRandomPipe() 116 | self.upperPipes.append(newPipe[0]) 117 | self.lowerPipes.append(newPipe[1]) 118 | 119 | # remove first pipe if its out of the screen 120 | if self.upperPipes[0]['x'] < -PIPE_WIDTH: 121 | self.upperPipes.pop(0) 122 | self.lowerPipes.pop(0) 123 | 124 | # check if crash here 125 | isCrash= checkCrash({'x': self.playerx, 'y': self.playery, 126 | 'index': self.playerIndex}, 127 | self.upperPipes, self.lowerPipes) 128 | if isCrash: 129 | if self.sound: 130 | SOUNDS['hit'].play() 131 | SOUNDS['die'].play() 132 | terminal = True 133 | self.__init__(sound=self.sound) 134 | reward = -1 135 | 136 | # draw sprites 137 | SCREEN.blit(IMAGES['background'], (0,0)) 138 | 139 | for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes): 140 | SCREEN.blit(IMAGES['pipe'][0], (uPipe['x'], uPipe['y'])) 141 | SCREEN.blit(IMAGES['pipe'][1], (lPipe['x'], lPipe['y'])) 142 | 143 | SCREEN.blit(IMAGES['base'], (self.basex, BASEY)) 144 | # print score so player overlaps the score 145 | showScore(self.score) 146 | SCREEN.blit(IMAGES['player'][self.playerIndex], 147 | (self.playerx, self.playery)) 148 | 149 | image_data = pygame.surfarray.array3d(pygame.display.get_surface()) 150 | pygame.display.update() 151 | FPSCLOCK.tick(self.fps) 152 | #print self.upperPipes[0]['y'] + PIPE_HEIGHT - int(BASEY * 0.2) 153 | return image_data, reward, terminal 154 | 155 | def getRandomPipe(): 156 | """returns a randomly generated pipe""" 157 | # y of gap between upper and lower pipe 158 | gapYs = [20, 30, 40, 50, 60, 70, 80, 90] 159 | index = random.randint(0, len(gapYs)-1) 160 | gapY = gapYs[index] 161 | 162 | gapY += int(BASEY * 0.2) 163 | pipeX = SCREENWIDTH + 10 164 | 165 | return [ 166 | {'x': pipeX, 'y': gapY - PIPE_HEIGHT}, # upper pipe 167 | {'x': pipeX, 'y': gapY + PIPEGAPSIZE}, # lower pipe 168 | ] 169 | 170 | 171 | def showScore(score): 172 | """displays score in center of screen""" 173 | scoreDigits = [int(x) for x in list(str(score))] 174 | totalWidth = 0 # total width of all numbers to be printed 175 | 176 | for digit in scoreDigits: 177 | totalWidth += IMAGES['numbers'][digit].get_width() 178 | 179 | Xoffset = (SCREENWIDTH - totalWidth) / 2 180 | 181 | for digit in scoreDigits: 182 | SCREEN.blit(IMAGES['numbers'][digit], (Xoffset, SCREENHEIGHT * 0.1)) 183 | Xoffset += IMAGES['numbers'][digit].get_width() 184 | 185 | 186 | def checkCrash(player, upperPipes, lowerPipes): 187 | """returns True if player collders with base or pipes.""" 188 | pi = player['index'] 189 | player['w'] = IMAGES['player'][0].get_width() 190 | player['h'] = IMAGES['player'][0].get_height() 191 | 192 | # if player crashes into ground 193 | if player['y'] + player['h'] >= BASEY - 1: 194 | return True 195 | else: 196 | 197 | playerRect = pygame.Rect(player['x'], player['y'], 198 | player['w'], player['h']) 199 | 200 | for uPipe, lPipe in zip(upperPipes, lowerPipes): 
201 | # upper and lower pipe rects 202 | uPipeRect = pygame.Rect(uPipe['x'], uPipe['y'], PIPE_WIDTH, PIPE_HEIGHT) 203 | lPipeRect = pygame.Rect(lPipe['x'], lPipe['y'], PIPE_WIDTH, PIPE_HEIGHT) 204 | 205 | # player and upper/lower pipe hitmasks 206 | pHitMask = HITMASKS['player'][pi] 207 | uHitmask = HITMASKS['pipe'][0] 208 | lHitmask = HITMASKS['pipe'][1] 209 | 210 | # if bird collided with upipe or lpipe 211 | uCollide = pixelCollision(playerRect, uPipeRect, pHitMask, uHitmask) 212 | lCollide = pixelCollision(playerRect, lPipeRect, pHitMask, lHitmask) 213 | 214 | if uCollide or lCollide: 215 | return True 216 | 217 | return False 218 | 219 | def pixelCollision(rect1, rect2, hitmask1, hitmask2): 220 | """Checks if two objects collide and not just their rects""" 221 | rect = rect1.clip(rect2) 222 | 223 | if rect.width == 0 or rect.height == 0: 224 | return False 225 | 226 | x1, y1 = rect.x - rect1.x, rect.y - rect1.y 227 | x2, y2 = rect.x - rect2.x, rect.y - rect2.y 228 | 229 | for x in range(rect.width): 230 | for y in range(rect.height): 231 | if hitmask1[x1+x][y1+y] and hitmask2[x2+x][y2+y]: 232 | return True 233 | return False 234 | -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/images/flappy_bird_demp.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/images/flappy_bird_demp.gif -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/images/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/images/network.png -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/images/preprocess.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/images/preprocess.png -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/logs_bird/hidden.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/logs_bird/hidden.txt -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/logs_bird/readout.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/logs_bird/readout.txt -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/saved_networks0/saved_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/saved_networks0/saved_model.pb -------------------------------------------------------------------------------- 
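A note on the environment above: GameState.frame_step(action) in game/wrapped_flappy_bird.py advances the game by one frame and returns the raw RGB screen array, a shaped reward (+0.1 for surviving the frame, +1 for passing a pipe, -1 on a crash) and a terminal flag. A minimal driving loop could look like the sketch below; the cv2-based preprocessing, the random policy and the import path are illustrative assumptions, while the repo's actual training pipeline lives in dqn_agent.py.

# sketch: a random agent driving wrapped_flappy_bird.GameState (illustrative only)
import sys
import numpy as np
import cv2                                       # assumed dependency; any resize/grayscale routine works

sys.path.append('ch_1/sec_6/flappy_bird/game')   # assumes the repo root as working directory
import wrapped_flappy_bird as flappy

def preprocess(frame):
    # frame_step() returns pygame.surfarray.array3d output, shape (288, 512, 3)
    gray = cv2.cvtColor(cv2.resize(frame, (80, 80)), cv2.COLOR_RGB2GRAY)
    _, binary = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
    return binary.astype(np.float32) / 255.0     # 80x80 input for a DQN

game = flappy.GameState(sound=False)
for _ in range(200):
    action = np.random.randint(2)                # 0: do nothing, 1: flap
    frame, reward, terminal = game.frame_step(action)
    state = preprocess(frame)                    # would be stacked/fed to the network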
/ch_1/sec_6/flappy_bird/saved_networks0/variables/variables.data-00000-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/saved_networks0/variables/variables.data-00000-of-00002 -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/saved_networks0/variables/variables.data-00001-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/saved_networks0/variables/variables.data-00001-of-00002 -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/saved_networks0/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/saved_networks0/variables/variables.index -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/saved_networks1/saved_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/saved_networks1/saved_model.pb -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/saved_networks1/variables/variables.data-00000-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/saved_networks1/variables/variables.data-00000-of-00002 -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/saved_networks1/variables/variables.data-00001-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/saved_networks1/variables/variables.data-00001-of-00002 -------------------------------------------------------------------------------- /ch_1/sec_6/flappy_bird/saved_networks1/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiperLiu/Approachable-Reinforcement-Learning/14d56297bacd1704bb63294c13fab7ffbb3fc58b/ch_1/sec_6/flappy_bird/saved_networks1/variables/variables.index -------------------------------------------------------------------------------- /ch_1/sec_6/lfa_rl.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import os 4 | import pygame 5 | import time 6 | import matplotlib.pyplot as plt 7 | from yuan_yang_env_fa import YuanYangEnv 8 | 9 | 10 | class LFA_RL: 11 | def __init__(self, yuanyang): 12 | self.gamma = yuanyang.gamma 13 | self.yuanyang = yuanyang 14 | # 基于特征表示的参数 15 | self.theta_tr = np.zeros((400, 1)) * 0.1 16 | # 基于固定稀疏所对应的参数 17 | self.theta_fsr = np.zeros((80, 1)) * 0.1 18 | 19 | def feature_tr(self, s, a): 
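        # "table" feature: a 1x400 one-hot row vector with the 1 at column
        # 100*a + s (100 grid states x 4 actions), so np.dot(phi_s_a, theta_tr)
        # is simply a lookup of Q(s, a), and the gradient step in
        # qlearning_lfa_tr() below only moves that single weight; with this
        # encoding, linear function approximation reduces to a Q-table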
20 | phi_s_a = np.zeros((1, 400)) 21 | phi_s_a[0, 100 * a + s] = 1 22 | return phi_s_a 23 | 24 | # 定义贪婪策略 25 | def greedy_policy_tr(self, state): 26 | qfun = np.array([0, 0, 0, 0]) * 0.1 27 | # 计算行为值函数 28 | for i in range(4): 29 | qfun[i] = np.dot(self.feature_tr(state, i), self.theta_tr) 30 | amax = qfun.argmax() 31 | return self.yuanyang.actions[amax] 32 | 33 | # 定义 epsilon 贪婪策略 34 | def epsilon_greedy_policy_tr(self, state, epsilon): 35 | qfun = np.array([0, 0, 0, 0]) * 0.1 36 | # 计算行为值函数 37 | for i in range(4): 38 | qfun[i] = np.dot(self.feature_tr(state, i), self.theta_tr) 39 | amax = qfun.argmax() 40 | # 概率部分 41 | if np.random.uniform() < 1 - epsilon: 42 | # 最优动作 43 | return self.yuanyang.actions[amax] 44 | else: 45 | return self.yuanyang.actions[ 46 | int(random.random() * len(self.yuanyang.actions)) 47 | ] 48 | 49 | # 找到动作所对应的序号 50 | def find_anum(self, a): 51 | for i in range(len(self.yuanyang.actions)): 52 | if a == self.yuanyang.actions[i]: 53 | return i 54 | 55 | def qlearning_lfa_tr(self, num_iter, alpha, epsilon): 56 | iter_num = [] 57 | self.theta_tr = np.zeros((400, 1)) * 0.1 58 | # 第1个大循环,产生了多少实验 59 | for iter in range(num_iter): 60 | # 随机初始化状态 61 | epsilon = epsilon * 0.99 62 | s_sample = [] 63 | # 初始状态 64 | s = 0 65 | flag = self.greedy_test_tr() 66 | if flag == 1: 67 | iter_num.append(iter) 68 | if len(iter_num) < 2: 69 | print('sarsa 第 1 次完成任务需要的迭代次数为 {}'.format(iter_num[0])) 70 | if flag == 2: 71 | print('sarsa 第 1 次实现最短路径需要的迭代次数为 {}'.format(iter)) 72 | break 73 | # 利用 epsilon-greedy 策略选初始动作 74 | a = self.epsilon_greedy_policy_tr(s, epsilon) 75 | t = False 76 | count = 0 77 | 78 | # 第 2 个循环, 1 个实验, s0-s1-s2-s1-s2-s_terminate 79 | while t == False and count < 30: 80 | # 与环境交互得到下一状态 81 | s_next, r, t = self.yuanyang.transform(s, a) 82 | a_num = self.find_anum(a) 83 | # 本轨迹中已有,给出负回报 84 | if s_next in s_sample: 85 | r = -2 86 | s_sample.append(s) 87 | # 判断是否是终止状态 88 | if t == True: 89 | q_target = r 90 | else: 91 | # 下一状态处的最大动作,体现同轨策略 92 | a1 = self.greedy_policy_tr(s_next) 93 | a1_num = self.find_anum(a1) 94 | # Q-learning 得到时间差分目标 95 | q_target = r + self.gamma * np.dot(self.feature_tr(s_next, a1_num), self.theta_tr) 96 | # 利用梯度下降的方法对参数进行学习 97 | self.theta_tr = self.theta_tr + alpha * (q_target - np.dot(self.feature_tr(s, a_num), self.theta_tr))[0, 0] * np.transpose(self.feature_tr(s, a_num)) 98 | # 转到下一状态 99 | s = s_next 100 | # 行为策略 101 | a = self.epsilon_greedy_policy_tr(s, epsilon) 102 | count += 1 103 | 104 | return self.theta_tr 105 | 106 | def greedy_test_tr(self): 107 | s = 0 108 | s_sample = [] 109 | done = False 110 | flag = 0 111 | step_num = 0 112 | while done == False and step_num < 30: 113 | a = self.greedy_policy_tr(s) 114 | # 与环境交互 115 | s_next, r, done = self.yuanyang.transform(s, a) 116 | s_sample.append(s) 117 | s = s_next 118 | step_num += 1 119 | 120 | if s == 9: 121 | flag = 1 122 | if s == 9 and step_num < 21: 123 | flag = 2 124 | 125 | return flag 126 | 127 | def feature_fsr(self, s, a): 128 | phi_s_a = np.zeros((1, 80)) 129 | y = int(s / 10) 130 | x = s - 10 * y 131 | phi_s_a[0, 20 * a + x] = 1 132 | phi_s_a[0, 20 * a + 10 + y] = 1 133 | return phi_s_a 134 | 135 | def greedy_policy_fsr(self, state): 136 | qfun = np.array([0, 0, 0, 0]) * 0.1 137 | # 计算行为值函数 138 | for i in range(4): 139 | qfun[i] = np.dot(self.feature_fsr(state, i), self.theta_fsr) 140 | amax = qfun.argmax() 141 | return self.yuanyang.actions[amax] 142 | 143 | # 定义 epsilon 贪婪策略 144 | def epsilon_greedy_policy_fsr(self, state, epsilon): 145 | qfun = np.array([0, 0, 0, 0]) * 0.1 146 
| # 计算行为值函数 147 | for i in range(4): 148 | qfun[i] = np.dot(self.feature_fsr(state, i), self.theta_fsr) 149 | amax = qfun.argmax() 150 | # 概率部分 151 | if np.random.uniform() < 1 - epsilon: 152 | # 最优动作 153 | return self.yuanyang.actions[amax] 154 | else: 155 | return self.yuanyang.actions[ 156 | int(random.random() * len(self.yuanyang.actions)) 157 | ] 158 | 159 | def greedy_test_fsr(self): 160 | s = 0 161 | s_sample = [] 162 | done = False 163 | flag = 0 164 | step_num = 0 165 | while done == False and step_num < 30: 166 | a = self.greedy_policy_fsr(s) 167 | # 与环境交互 168 | s_next, r, done = self.yuanyang.transform(s, a) 169 | s_sample.append(s) 170 | s = s_next 171 | step_num += 1 172 | 173 | if s == 9: 174 | flag = 1 175 | if s == 9 and step_num < 21: 176 | flag = 2 177 | 178 | return flag 179 | 180 | def qlearning_lfa_fsr(self, num_iter, alpha, epsilon): 181 | iter_num = [] 182 | self.theta_tr = np.zeros((80, 1)) * 0.1 183 | # 第1个大循环,产生了多少实验 184 | for iter in range(num_iter): 185 | # 随机初始化状态 186 | epsilon = epsilon * 0.99 187 | s_sample = [] 188 | # 初始状态 189 | s = 0 190 | flag = self.greedy_test_fsr() 191 | if flag == 1: 192 | iter_num.append(iter) 193 | if len(iter_num) < 2: 194 | print('sarsa 第 1 次完成任务需要的迭代次数为 {}'.format(iter_num[0])) 195 | if flag == 2: 196 | print('sarsa 第 1 次实现最短路径需要的迭代次数为 {}'.format(iter)) 197 | break 198 | # 利用 epsilon-greedy 策略选初始动作 199 | a = self.epsilon_greedy_policy_fsr(s, epsilon) 200 | t = False 201 | count = 0 202 | 203 | # 第 2 个循环, 1 个实验, s0-s1-s2-s1-s2-s_terminate 204 | while t == False and count < 30: 205 | # 与环境交互得到下一状态 206 | s_next, r, t = self.yuanyang.transform(s, a) 207 | a_num = self.find_anum(a) 208 | # 本轨迹中已有,给出负回报 209 | if s_next in s_sample: 210 | r = -2 211 | s_sample.append(s) 212 | # 判断是否是终止状态 213 | if t == True: 214 | q_target = r 215 | else: 216 | # 下一状态处的最大动作,体现同轨策略 217 | a1 = self.greedy_policy_fsr(s_next) 218 | a1_num = self.find_anum(a1) 219 | # Q-learning 得到时间差分目标 220 | q_target = r + self.gamma * np.dot(self.feature_fsr(s_next, a1_num), self.theta_fsr) 221 | # 利用梯度下降的方法对参数进行学习 222 | self.theta_fsr = self.theta_fsr + alpha * (q_target - np.dot(self.feature_fsr(s, a_num), self.theta_fsr))[0, 0] * np.transpose(self.feature_fsr(s, a_num)) 223 | # 转到下一状态 224 | s = s_next 225 | # 行为策略 226 | a = self.epsilon_greedy_policy_fsr(s, epsilon) 227 | count += 1 228 | 229 | return self.theta_fsr 230 | 231 | 232 | def qlearning_lfa_tr(): 233 | yuanyang = YuanYangEnv() 234 | brain = LFA_RL(yuanyang) 235 | brain.qlearning_lfa_tr(num_iter=5000, alpha=0.1, epsilon=0.8) 236 | 237 | # 打印 238 | flag = 1 239 | s = 0 240 | path = [] 241 | # 将 v 值打印出来 242 | qvalue1 = np.zeros((100, 4)) 243 | for i in range(400): 244 | y = int(i / 100) 245 | x = i - 100 * y 246 | qvalue1[x, y] = np.dot(brain.feature_tr(x, y), brain.theta_tr) 247 | yuanyang.action_value = qvalue1 248 | step_num = 0 249 | 250 | # 将最优路径打印出来 251 | while flag: 252 | path.append(s) 253 | yuanyang.path = path 254 | a = brain.greedy_policy_tr(s) 255 | print('%d->%s\t'%(s, a), qvalue1[s, 0], qvalue1[s, 1], qvalue1[s, 2], qvalue1[s, 3]) 256 | yuanyang.bird_male_position = yuanyang.state_to_position(s) 257 | yuanyang.render() 258 | time.sleep(0.1) 259 | step_num += 1 260 | s_, r, t = yuanyang.transform(s, a) 261 | if t == True or step_num > 200: 262 | flag = 0 263 | s = s_ 264 | 265 | # 渲染最后的路径点 266 | yuanyang.bird_male_position = yuanyang.state_to_position(s) 267 | path.append(s) 268 | yuanyang.render() 269 | while True: 270 | yuanyang.render() 271 | 272 | def qlearning_lfa_fsr(): 273 | yuanyang = 
YuanYangEnv() 274 | brain = LFA_RL(yuanyang) 275 | qvalue2 = brain.qlearning_lfa_fsr(num_iter=5000, alpha=0.1, epsilon=0.1) 276 | 277 | # 打印 278 | flag = 1 279 | s = 0 280 | path = [] 281 | # 将 v 值打印出来 282 | # 打印 283 | flag = 1 284 | s = 0 285 | path = [] 286 | # 将 v 值打印出来 287 | qvalue2 = np.zeros((100, 4)) 288 | for i in range(400): 289 | y = int(i / 100) 290 | x = i - 100 * y 291 | qvalue2[x, y] = np.dot(brain.feature_fsr(x, y), brain.theta_fsr) 292 | yuanyang.action_value = qvalue2 293 | step_num = 0 294 | 295 | # 将最优路径打印出来 296 | while flag: 297 | path.append(s) 298 | yuanyang.path = path 299 | a = brain.greedy_policy_fsr(s) 300 | print('%d->%s\t'%(s, a), qvalue2[s, 0], qvalue2[s, 1], qvalue2[s, 2], qvalue2[s, 3]) 301 | yuanyang.bird_male_position = yuanyang.state_to_position(s) 302 | yuanyang.render() 303 | time.sleep(0.1) 304 | step_num += 1 305 | s_, r, t = yuanyang.transform(s, a) 306 | if t == True or step_num > 200: 307 | flag = 0 308 | s = s_ 309 | 310 | # 渲染最后的路径点 311 | yuanyang.bird_male_position = yuanyang.state_to_position(s) 312 | path.append(s) 313 | yuanyang.render() 314 | while True: 315 | yuanyang.render() 316 | 317 | 318 | if __name__ == "__main__": 319 | # qlearning_lfa_tr() 320 | qlearning_lfa_fsr() -------------------------------------------------------------------------------- /ch_1/sec_6/load_images.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import os.path as osp 3 | 4 | path_file = osp.abspath(__file__) 5 | path_images = osp.join(path_file, '../../..', 'ch_0/sec_2/images') 6 | 7 | def load_bird_male(): 8 | obj = 'bird_male.png' 9 | obj_path = osp.join(path_images, obj) 10 | return pygame.image.load(obj_path) 11 | 12 | def load_bird_female(): 13 | obj = 'bird_female.png' 14 | obj_path = osp.join(path_images, obj) 15 | return pygame.image.load(obj_path) 16 | 17 | def load_background(): 18 | obj = 'background.jpg' 19 | obj_path = osp.join(path_images, obj) 20 | return pygame.image.load(obj_path) 21 | 22 | def load_obstacle(): 23 | obj = 'obstacle.png' 24 | obj_path = osp.join(path_images, obj) 25 | return pygame.image.load(obj_path) 26 | -------------------------------------------------------------------------------- /ch_1/sec_6/yuan_yang_env_fa.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import random 3 | from load_images import * 4 | import numpy as np 5 | 6 | class YuanYangEnv: 7 | def __init__(self): 8 | self.states = [] 9 | for i in range(0, 100): 10 | self.states.append(i) 11 | self.actions = ['e', 's', 'w', 'n'] 12 | # 蒙特卡洛需要修改 gamma ,防止长远回报过快衰减 13 | self.gamma = 0.95 14 | self.action_value = np.zeros((100, 4)) 15 | 16 | self.viewer = None 17 | self.FPSCLOCK = pygame.time.Clock() 18 | 19 | self.screen_size = (1200, 900) 20 | self.bird_position = (0, 0) 21 | self.limit_distance_x = 120 22 | self.limit_distance_y = 90 23 | self.obstacle_size = [120, 90] 24 | self.obstacle1_x = [] 25 | self.obstacle1_y = [] 26 | self.obstacle2_x = [] 27 | self.obstacle2_y = [] 28 | 29 | for i in range(8): 30 | # obstacle 1 31 | self.obstacle1_x.append(360) 32 | if i <= 3: 33 | self.obstacle1_y.append(90 * i) 34 | else: 35 | self.obstacle1_y.append(90 * (i + 2)) 36 | # obstacle 2 37 | self.obstacle2_x.append(720) 38 | if i <= 4: 39 | self.obstacle2_y.append(90 * i) 40 | else: 41 | self.obstacle2_y.append(90 * (i + 2)) 42 | 43 | self.bird_male_init_position = [0.0, 0.0] 44 | self.bird_male_position = [0, 0] 45 | self.bird_female_init_position = [1080, 0] 46 
| 47 | self.path = [] 48 | 49 | def collide(self, state_position): 50 | flag = 1 51 | flag1 = 1 52 | flag2 = 1 53 | 54 | # obstacle 1 55 | dx = [] 56 | dy = [] 57 | for i in range(8): 58 | dx1 = abs(self.obstacle1_x[i] - state_position[0]) 59 | dx.append(dx1) 60 | dy1 = abs(self.obstacle1_y[i] - state_position[1]) 61 | dy.append(dy1) 62 | mindx = min(dx) 63 | mindy = min(dy) 64 | if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y: 65 | flag1 = 0 66 | 67 | # obstacle 2 68 | dx_second = [] 69 | dy_second = [] 70 | for i in range(8): 71 | dx1 = abs(self.obstacle2_x[i] - state_position[0]) 72 | dx_second.append(dx1) 73 | dy1 = abs(self.obstacle2_y[i] - state_position[1]) 74 | dy_second.append(dy1) 75 | mindx = min(dx_second) 76 | mindy = min(dy_second) 77 | if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y: 78 | flag2 = 0 79 | 80 | if flag1 == 0 and flag2 == 0: 81 | flag = 0 82 | 83 | # collide edge 84 | if state_position[0] > 1080 or \ 85 | state_position[0] < 0 or \ 86 | state_position[1] > 810 or \ 87 | state_position[1] < 0: 88 | flag = 1 89 | 90 | return flag 91 | 92 | def find(self, state_position): 93 | flag = 0 94 | if abs(state_position[0] - self.bird_female_init_position[0]) < \ 95 | self.limit_distance_x and \ 96 | abs(state_position[1] - self.bird_female_init_position[1]) < \ 97 | self.limit_distance_y: 98 | flag = 1 99 | return flag 100 | 101 | def state_to_position(self, state): 102 | i = int(state / 10) 103 | j = state % 10 104 | position = [0, 0] 105 | position[0] = 120 * j 106 | position[1] = 90 * i 107 | return position 108 | 109 | def position_to_state(self, position): 110 | i = position[0] / 120 111 | j = position[1] / 90 112 | return int(i + 10 * j) 113 | 114 | def reset(self): 115 | # 随机产生一个初始位置 116 | flag1 = 1 117 | flag2 = 1 118 | while flag1 or flag2 == 1: 119 | state = self.states[int(random.random() * len(self.states))] 120 | state_position = self.state_to_position(state) 121 | flag1 = self.collide(state_position) 122 | flag2 = self.find(state_position) 123 | return state 124 | 125 | def transform(self, state, action): 126 | current_position = self.state_to_position(state) 127 | next_position = [0, 0] 128 | flag_collide = 0 129 | flag_find = 0 130 | 131 | flag_collide = self.collide(current_position) 132 | flag_find = self.find(current_position) 133 | if flag_collide == 1: 134 | return state, -10, True 135 | if flag_find == 1: 136 | return state, 10, True 137 | 138 | if action == 'e': 139 | next_position[0] = current_position[0] + 120 140 | next_position[1] = current_position[1] 141 | if action == 's': 142 | next_position[0] = current_position[0] 143 | next_position[1] = current_position[1] + 90 144 | if action == 'w': 145 | next_position[0] = current_position[0] - 120 146 | next_position[1] = current_position[1] 147 | if action == 'n': 148 | next_position[0] = current_position[0] 149 | next_position[1] = current_position[1] - 90 150 | 151 | flag_collide = self.collide(next_position) 152 | if flag_collide == 1: 153 | return self.position_to_state(current_position), -10, True 154 | 155 | flag_find = self.find(next_position) 156 | if flag_find == 1: 157 | return self.position_to_state(next_position), 10, True 158 | 159 | return self.position_to_state(next_position), 0, False 160 | 161 | def gameover(self): 162 | for event in pygame.event.get(): 163 | if event.type == pygame.QUIT: 164 | exit() 165 | 166 | def render(self): 167 | if self.viewer is None: 168 | pygame.init() 169 | 170 | self.viewer = 
pygame.display.set_mode(self.screen_size, 0, 32) 171 | pygame.display.set_caption("yuanyang") 172 | # load pic 173 | self.bird_male = load_bird_male() 174 | self.bird_female = load_bird_female() 175 | self.background = load_background() 176 | self.obstacle = load_obstacle() 177 | 178 | # self.viewer.blit(self.bird_female, self.bird_female_init_position) 179 | # self.viewer.blit(self.bird_male, self.bird_male_init_position) 180 | 181 | self.viewer.blit(self.background, (0, 0)) 182 | self.font = pygame.font.SysFont('times', 15) 183 | 184 | self.viewer.blit(self.background, (0, 0)) 185 | for i in range(11): 186 | pygame.draw.lines(self.viewer, 187 | (255, 255, 255), 188 | True, 189 | ((120 * i, 0), (120 * i, 900)), 190 | 1 191 | ) 192 | pygame.draw.lines(self.viewer, 193 | (255, 255, 255), 194 | True, 195 | ((0, 90 * i), (1200, 90 * i)), 196 | 1 197 | ) 198 | 199 | for i in range(8): 200 | self.viewer.blit(self.obstacle, (self.obstacle1_x[i], self.obstacle1_y[i])) 201 | self.viewer.blit(self.obstacle, (self.obstacle2_x[i], self.obstacle2_y[i])) 202 | 203 | self.viewer.blit(self.bird_female, self.bird_female_init_position) 204 | self.viewer.blit(self.bird_male, self.bird_male_init_position) 205 | 206 | # 画动作-值函数 207 | for i in range(100): 208 | y = int(i / 10) 209 | x = i % 10 210 | # 往东的值函数 211 | surface = self.font.render(str(round(float(self.action_value[i, 0]), 2)), True, (0, 0, 0)) 212 | self.viewer.blit(surface, (120 * x + 80, 90 * y + 45)) 213 | # 往南的值函数 214 | surface = self.font.render(str(round(float(self.action_value[i, 1]), 2)), True, (0, 0, 0)) 215 | self.viewer.blit(surface, (120 * x + 50, 90 * y + 70)) 216 | # 往西的值函数 217 | surface = self.font.render(str(round(float(self.action_value[i, 2]), 2)), True, (0, 0, 0)) 218 | self.viewer.blit(surface, (120 * x + 10, 90 * y + 45)) 219 | # 往北的值函数 220 | surface = self.font.render(str(round(float(self.action_value[i, 3]), 2)), True, (0, 0, 0)) 221 | self.viewer.blit(surface, (120 * x + 50, 90 * y + 10)) 222 | 223 | # 画路径点 224 | for i in range(len(self.path)): 225 | rec_position = self.state_to_position(self.path[i]) 226 | pygame.draw.rect(self.viewer, [255, 0, 0], [rec_position[0], rec_position[1], 120, 90], 3) 227 | surface = self.font.render(str(i), True, (255, 0, 0)) 228 | self.viewer.blit(surface, (rec_position[0] + 5, rec_position[1] + 5)) 229 | 230 | pygame.display.update() 231 | self.gameover() 232 | self.FPSCLOCK.tick(30) 233 | 234 | 235 | if __name__ == "__main__": 236 | yy = YuanYangEnv() 237 | yy.render() 238 | while True: 239 | for event in pygame.event.get(): 240 | if event.type == pygame.QUIT: 241 | exit() 242 | 243 | -------------------------------------------------------------------------------- /pyTorch_learn/a_1.py: -------------------------------------------------------------------------------- 1 | # a.1.1 2 | import torch 3 | x = torch.Tensor(2, 3) 4 | print(x) 5 | 6 | x = torch.Tensor([[1, 2, 3], [4, 5, 6]]) 7 | print(x) 8 | 9 | x = torch.rand(2, 3) 10 | print(x) 11 | 12 | x = torch.zeros(2, 3) 13 | print(x) 14 | 15 | x = torch.ones(2, 3) 16 | print(x) 17 | 18 | print("") 19 | print("x.size: {}".format(x.size())) 20 | print("x.size[0]: {}".format(x.size()[0])) 21 | 22 | # a.1.2 23 | x = torch.ones(2, 3) 24 | y = torch.ones(2, 3) * 2 25 | print("") 26 | print("a.1.2") 27 | print("x + y = {}".format(x + y)) 28 | 29 | print(torch.add(x, y)) 30 | 31 | x.add_(y) 32 | print(x) 33 | print("注意:PyTroch中修改 Tensor 内容的操作\ 34 | 都会在方法名后加一个下划线,如copy_()、t_()等") 35 | 36 | print("x.zero_(): {}".format(x.zero_())) 37 | print("x: 
{}".format(x)) 38 | 39 | print("") 40 | print("Tensor 也支持 NumPy 中的各种切片操作") 41 | 42 | x[:, 1] = x[:, 1] + 2 43 | print(x) 44 | 45 | print("torch.view()相当于numpy中的reshape()") 46 | print("x.view(1, 6): {}".format(x.view(1, 6))) 47 | 48 | # a.1.3 49 | print("") 50 | print("a.1.3") 51 | 52 | print("Tensor 与 NumPy 的 array 可以转化,但是共享地址") 53 | import numpy as np 54 | x = torch.ones(2, 3) 55 | print(x) 56 | 57 | y = x.numpy() 58 | print("y = x.numpy(): {}".format(y)) 59 | 60 | print("x.add_(2): {}".format(x.add_(2))) 61 | 62 | print("y: {}".format(y)) 63 | 64 | z = torch.from_numpy(y) 65 | print("z = torch.from_numpy(y): {}".format(z)) 66 | 67 | # a.1.4 68 | print("") 69 | print("a.1.4") 70 | 71 | print("Autograd 实现自动梯度") 72 | 73 | from torch.autograd import Variable 74 | 75 | x = Variable(torch.ones(2, 2)*2, requires_grad=True) 76 | 77 | print(x) 78 | 79 | print("x.data {}, \n x's type: {}\n".format(x.data, type(x))) 80 | 81 | y = 2 * (x * x) + 5 * x 82 | y = y.sum() 83 | print("y: {}, \n y's type: {}\n".format(y, type(y))) 84 | 85 | print("y 可视为关于 x 的函数") 86 | print("y 应该是一个标量,y.backward()自动计算梯度") 87 | 88 | y.backward() 89 | print("x.grad: {}".format(x.grad)) 90 | print("x.grad 中自动保存梯度") 91 | -------------------------------------------------------------------------------- /pyTorch_learn/a_2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | class Net(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | self.conv1 = nn.Conv2d(3, 6, 5) 10 | self.conv2 = nn.Conv2d(6, 16, 5) 11 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 12 | self.fc2 = nn.Linear(120, 84) 13 | self.fc3 = nn.Linear(84, 10) 14 | 15 | def forward(self, x): 16 | ''' 17 | 在 __init__() 中并没有真正定义网络结构的关系 18 | 输入输出关系在 forward() 中定义 19 | ''' 20 | x = F.max_pool2d(F.relu(self.conv1(x)), 2) 21 | x = F.max_pool2d(F.relu(self.conv2(x)), 2) 22 | ''' 23 | 注意, torch.nn 中要求输入的数据是一个 mini-batch , 24 | 由于数据图像是 3 维的,因此输入数据 x 是 4 维的 25 | 因此进入全连接层前 x.view(-1, ...) 
方法化为 2 维 26 | ''' 27 | x = x.view(-1, 16 * 5 * 5) 28 | x = F.relu(self.fc1(x)) 29 | x = F.relu(self.fc2(x)) 30 | x = self.fc3(x) 31 | return x 32 | 33 | 34 | def seeNet(): 35 | net = Net() 36 | print(net) 37 | params = list(net.parameters()) 38 | print(len(params)) 39 | print("第一层的 weight :") 40 | print(params[0].size()) 41 | print(net.conv1.weight.size()) 42 | print("第一层的 bias :") 43 | print(params[1].size()) 44 | print(net.conv1.bias.size()) 45 | print("") 46 | print(net.conv1.weight.requires_grad) 47 | 48 | ''' 49 | 神经网络的输入输出应该是 Variable 50 | ''' 51 | inputV = Variable(torch.rand(1, 3, 32, 32)) 52 | # net.__call__() 53 | output = net(inputV) 54 | print(output) 55 | 56 | # seeNet() 57 | 58 | # train 59 | def trainNet(): 60 | net = Net() 61 | inputV = Variable(torch.rand(1, 3, 32, 32)) 62 | output = net(inputV) 63 | 64 | criterion = nn.CrossEntropyLoss() 65 | label = Variable(torch.LongTensor([4])) 66 | loss = criterion(output, label) 67 | # all data type should be 'Variable' 68 | print(loss) 69 | 70 | import torch.optim as optim 71 | optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) 72 | print("\nbefore optim: {}".format(net.conv1.bias)) 73 | 74 | optimizer.zero_grad() # zeros the gradient buffers 75 | loss.backward() 76 | optimizer.step() # Does the update 77 | ''' 78 | 参数有变化,但可能很小 79 | ''' 80 | print("\nafter optim: {}".format(net.conv1.bias)) 81 | 82 | # trainNet() 83 | 84 | # 实战:CIFAR-10 85 | import torchvision 86 | import torchvision.transforms as transforms 87 | import torch.optim as optim 88 | 89 | transform = transforms.Compose( 90 | [transforms.ToTensor(), 91 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 92 | ) 93 | trainset = torchvision.datasets.CIFAR10( 94 | root='./pyTorch_learn/data', 95 | train=True, 96 | download=True, 97 | transform=transform 98 | ) 99 | trainloader = torch.utils.data.DataLoader( 100 | trainset, 101 | batch_size=4, 102 | shuffle=True, 103 | num_workers=0 # windows 下线程参数设为 0 安全 104 | ) 105 | 106 | testset = torchvision.datasets.CIFAR10( 107 | root='./pyTorch_learn/data', 108 | train=False, 109 | download=True, 110 | transform=transform 111 | ) 112 | testloader = torch.utils.data.DataLoader( 113 | testset, 114 | batch_size=4, 115 | shuffle=False, 116 | num_workers=0 # windows 下线程参数设为 0 安全 117 | ) 118 | 119 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 120 | 'dog', 'frog', 'horse', 'ship', 'truck' 121 | ) 122 | 123 | def cifar_10(): 124 | net = Net() 125 | criterion = nn.CrossEntropyLoss() 126 | optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) 127 | 128 | for epoch in range(5): 129 | 130 | running_loss = 0.0 131 | for i, data in enumerate(trainloader, 0): 132 | inputs, labels = data 133 | inputs, labels = Variable(inputs), Variable(labels) 134 | optimizer.zero_grad() 135 | outputs = net(inputs) 136 | loss = criterion(outputs, labels) 137 | loss.backward() 138 | optimizer.step() 139 | running_loss += loss.item() 140 | if i % 6000 == 5999: 141 | print('[%d, %5d] loss: %.3f' % 142 | (epoch + 1, i + 1, running_loss / 6000)) 143 | running_loss = 0.0 144 | 145 | print('Finished Training') 146 | torch.save(net.state_dict(), './pyTorch_learn/data/' + 'model.pt') 147 | net.load_state_dict(torch.load('./pyTorch_learn/data/' + 'model.pt')) 148 | 149 | correct = 0 150 | total = 0 151 | for data in testloader: 152 | images, labels = data 153 | outputs = net(Variable(images)) 154 | # 返回可能性最大的索引 -> 输出标签 155 | _, predicted = torch.max(outputs, 1) 156 | total += labels.size(0) 157 | correct += (predicted == labels).sum() 158 | 
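        # note: (predicted == labels).sum() above is a 0-dim tensor, so
        # "correct" accumulates as a tensor; correct.item() would give a plain
        # Python int and keep the percentage arithmetic below explicit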
159 | print('Accuracy of the network on the 10000 test images: %d %%' % ( 160 | 100 * correct / total 161 | )) 162 | 163 | class_correct = list(0. for i in range(10)) 164 | class_total = list(0. for i in range(10)) 165 | for data in testloader: 166 | images, labels = data 167 | outputs = net(Variable(images)) 168 | _, predicted = torch.max(outputs.data, 1) 169 | c = (predicted == labels).squeeze() 170 | for i in range(4): # mini-batch's size = 4 171 | label = labels[i] 172 | class_correct[label] += c[i] 173 | class_total[label] += 1 174 | 175 | for i in range(10): 176 | print('Accuracy of %5s : %2d %%' % ( 177 | classes[i], 100 * class_correct[i] / class_total[i] 178 | )) 179 | 180 | # save net 181 | print(net.state_dict().keys()) 182 | print(net.state_dict()['conv1.bias']) 183 | 184 | # torch.save(net.state_dict(), 'model.pt') 185 | # net.load_state_dict(torch.load('model.pt')) 186 | 187 | cifar_10() --------------------------------------------------------------------------------
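cifar_10() above already writes its weights to ./pyTorch_learn/data/model.pt. The sketch below is not part of the repo; it shows how those weights could be reloaded for inference using the current torch.no_grad() idiom instead of Variable, and it assumes it is appended to a_2.py so that Net, testloader and classes are in scope.

# sketch: reload the weights saved by cifar_10() and classify one test mini-batch
def predict_batch():
    net = Net()
    net.load_state_dict(torch.load('./pyTorch_learn/data/model.pt'))
    net.eval()                                   # inference mode; Net has no dropout/BN, but it is a good habit
    with torch.no_grad():                        # modern replacement for wrapping inputs in Variable
        images, labels = next(iter(testloader))  # one mini-batch of 4 test images
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)     # index of the most likely class
    for p, l in zip(predicted, labels):
        print('predicted: %5s  actual: %5s' % (classes[p.item()], classes[l.item()]))

# predict_batch()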