├── LOG.md ├── README.md ├── lec01 ├── LOG.md ├── greedy_1.png ├── greedy_10.png ├── greedy_100.png ├── greedy_3.png ├── greedy_3_reorder.png ├── optimistic_1.png ├── optimistic_10.png ├── optimistic_100.png ├── optimistic_3.png ├── rewards_1000.png ├── rewards_1000_3.png ├── rewards_20000.png ├── rewards_20000_2.png ├── rewards_20000_3.png ├── t.py ├── ucb1.png └── ucb1_long.png ├── policy_gradient.py ├── sarsa.png ├── sarsa_alphas.png ├── sarsa_illegal.png ├── t.py ├── tictactoe.py └── tiger.py /LOG.md: -------------------------------------------------------------------------------- 1 | # ex1 2 | random to random 3 | Counter({-1: 5063, 1: 4757, 0: 180}) 4 | 5 | # ex2 6 | <48661191875666868481x256 sparse matrix of type '' 7 | with 97956 stored elements in Dictionary Of Keys format> 8 | 9 | # ex3 10 | 1000000 trials 11 | 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reinforcement Learning 2 | 3 | ## Lecture Note 4 | 5 | ### Japanese 6 | 7 | - lec1: http://www.slideshare.net/nishio/1-70974083 8 | - lec2: http://www.slideshare.net/nishio/2-71708934 9 | - lec3: http://www.slideshare.net/nishio/3-71708970 10 | - lec4: http://www.slideshare.net/nishio/4-71709532 11 | - lec5: https://www.slideshare.net/nishio/5-79364176 12 | -------------------------------------------------------------------------------- /lec01/LOG.md: -------------------------------------------------------------------------------- 1 | 2 | ## ex2-2 3 | 4 | ``` 5 | greedy_1 6 | 0 7 | 3 8 | 130 9 | 467 10 | 11 | greedy_3 12 | 0 13 | 9 14 | 250 15 | 341 16 | 17 | greedy_10 18 | 0 19 | 14 20 | 338 21 | 248 22 | ``` 23 | 24 | # ex2-3 25 | 26 | ``` 27 | ucb1 28 | 27 29 | 89 30 | 385 31 | 99 32 | 33 | optimistic_1 34 | 0 35 | 9 36 | 361 37 | 230 38 | 39 | optimistic_3 40 | 0 41 | 4 42 | 563 43 | 33 44 | 45 | optimistic_10 46 | 2 47 | 8 48 | 587 49 | 3 50 | ``` 51 | 52 | # ex2-4 53 | 54 | ``` 55 | ucb1_long 56 | 6 57 | 19 58 | 556 59 | 19 60 | ``` 61 | 62 | # ex2-5 63 | 64 | ``` 65 | greedy_100 66 | 0 67 | 2 68 | 590 69 | 8 70 | optimistic_100 71 | 71 72 | 103 73 | 296 74 | 130 75 | ``` 76 | 77 | # ex2-6 78 | 79 | ``` 80 | greedy_3_reorder 81 | 385 82 | 0 83 | 14 84 | 201 85 | ``` -------------------------------------------------------------------------------- /lec01/greedy_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishio/reinforcement_learning/c7a88036d83d1094b0052378e083d7e41a5cb648/lec01/greedy_1.png -------------------------------------------------------------------------------- /lec01/greedy_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishio/reinforcement_learning/c7a88036d83d1094b0052378e083d7e41a5cb648/lec01/greedy_10.png -------------------------------------------------------------------------------- /lec01/greedy_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishio/reinforcement_learning/c7a88036d83d1094b0052378e083d7e41a5cb648/lec01/greedy_100.png -------------------------------------------------------------------------------- /lec01/greedy_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishio/reinforcement_learning/c7a88036d83d1094b0052378e083d7e41a5cb648/lec01/greedy_3.png 
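A note on the lec01/LOG.md tallies above (reading based on `ex2` in lec01/t.py further down): each block of four numbers counts how many of the NUM_SERIES = 600 independent runs are pulling arm 0, 1, 2, 3 at the final iteration, so each block sums to 600. With the arm order used in `run_ex2`, index 2 (`action_150`) is the optimal arm; ex2-6 reorders the arms so the optimal arm becomes index 3. For reference, the arms' expected rewards, taken from the `action_*` definitions in lec01/t.py (the dict below is just an illustration, not part of the repo):

```python
# Expected reward of each bandit arm in lec01/t.py: payoff 500 with the stated
# probability, except action_none, which always pays 100.
expected_rewards = {
    "action_50":   0.1 * 500,  # = 50
    "action_100":  0.2 * 500,  # = 100
    "action_150":  0.3 * 500,  # = 150, the best arm (index 2 in run_ex2)
    "action_none": 100.0,      # deterministic
}
```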
-------------------------------------------------------------------------------- /lec01/greedy_3_reorder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishio/reinforcement_learning/c7a88036d83d1094b0052378e083d7e41a5cb648/lec01/greedy_3_reorder.png -------------------------------------------------------------------------------- /lec01/optimistic_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishio/reinforcement_learning/c7a88036d83d1094b0052378e083d7e41a5cb648/lec01/optimistic_1.png -------------------------------------------------------------------------------- /lec01/optimistic_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishio/reinforcement_learning/c7a88036d83d1094b0052378e083d7e41a5cb648/lec01/optimistic_10.png -------------------------------------------------------------------------------- /lec01/optimistic_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishio/reinforcement_learning/c7a88036d83d1094b0052378e083d7e41a5cb648/lec01/optimistic_100.png -------------------------------------------------------------------------------- /lec01/optimistic_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishio/reinforcement_learning/c7a88036d83d1094b0052378e083d7e41a5cb648/lec01/optimistic_3.png -------------------------------------------------------------------------------- /lec01/rewards_1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishio/reinforcement_learning/c7a88036d83d1094b0052378e083d7e41a5cb648/lec01/rewards_1000.png -------------------------------------------------------------------------------- /lec01/rewards_1000_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishio/reinforcement_learning/c7a88036d83d1094b0052378e083d7e41a5cb648/lec01/rewards_1000_3.png -------------------------------------------------------------------------------- /lec01/rewards_20000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishio/reinforcement_learning/c7a88036d83d1094b0052378e083d7e41a5cb648/lec01/rewards_20000.png -------------------------------------------------------------------------------- /lec01/rewards_20000_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishio/reinforcement_learning/c7a88036d83d1094b0052378e083d7e41a5cb648/lec01/rewards_20000_2.png -------------------------------------------------------------------------------- /lec01/rewards_20000_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishio/reinforcement_learning/c7a88036d83d1094b0052378e083d7e41a5cb648/lec01/rewards_20000_3.png -------------------------------------------------------------------------------- /lec01/t.py: -------------------------------------------------------------------------------- 1 | from random import random 2 | import numpy as np 3 | 4 | def action_50(): 5 | if random() < 0.1: 6 | return 500 7 | return 0 8 | 9 | def action_100(): 10 | if random() < 0.2: 11 | return 500 12 
| return 0 13 | 14 | def action_150(): 15 | if random() < 0.3: 16 | return 500 17 | return 0 18 | 19 | def action_none(): 20 | return 100 21 | 22 | actions = [action_50, action_100, action_150, action_none] 23 | num_actions = len(actions) 24 | 25 | def policy_greedy(num_use, sum_reward): 26 | for i in range(num_actions): 27 | if num_use[i] < 3: 28 | return i 29 | return np.argmax(sum_reward / num_use) 30 | 31 | 32 | def policy_builder_greedy(threshold): 33 | def policy_greedy(num_use, sum_reward): 34 | for i in range(num_actions): 35 | if num_use[i] < threshold: 36 | return i 37 | return np.argmax(sum_reward / num_use) 38 | return policy_greedy 39 | 40 | 41 | def policy_builder_optimistic(num_offset, offset_weight=500): 42 | def policy_optimistic(num_use, sum_reward): 43 | return np.argmax( 44 | (sum_reward + num_offset * offset_weight) 45 | / (num_use + num_offset)) 46 | return policy_optimistic 47 | 48 | 49 | def policy_ucb1(num_use, sum_reward): 50 | for i in range(num_actions): 51 | if num_use[i] < 1: 52 | return i 53 | mu = sum_reward / num_use 54 | ub = 500 * np.sqrt(np.log(num_use.sum()) * 2 / num_use) 55 | return np.argmax(mu + ub) 56 | 57 | 58 | def ex1(): 59 | for i in range(20): 60 | i_action = policy_greedy(num_use, sum_reward) 61 | action = actions[i_action] 62 | reward = action() 63 | print action.__name__, reward 64 | num_use[i_action] += 1 65 | sum_reward[i_action] += reward 66 | 67 | 68 | NUM_SERIES = 600 69 | NUM_ITERATION = 1000 70 | def ex2(name, policy, num_iteration=NUM_ITERATION): 71 | choices = np.zeros((NUM_SERIES, num_iteration)) 72 | rewards = np.zeros((NUM_SERIES, num_iteration)) 73 | for j in range(NUM_SERIES): 74 | num_use = np.zeros(num_actions) 75 | sum_reward = np.zeros(num_actions) 76 | for i in range(num_iteration): 77 | i_action = policy(num_use, sum_reward) 78 | action = actions[i_action] 79 | reward = action() 80 | #print action.__name__, reward 81 | num_use[i_action] += 1 82 | sum_reward[i_action] += reward 83 | choices[j, i] = i_action 84 | rewards[j, i] = reward 85 | 86 | 87 | last = np.sort(choices[:, -1]) 88 | print name 89 | print sum(last == 0) 90 | print sum(last == 1) 91 | print sum(last == 2) 92 | print sum(last == 3) 93 | 94 | # visualization 95 | import matplotlib.pyplot as plt 96 | choices.sort(axis=0) 97 | if num_iteration == 10000: 98 | plt.imshow(choices[:, ::10]) 99 | else: 100 | plt.imshow(choices) 101 | plt.savefig('{}.png'.format(name)) 102 | 103 | 104 | def run_ex2(): 105 | ex2("greedy_1", policy_builder_greedy(1)) 106 | ex2("greedy_3", policy_builder_greedy(3)) 107 | ex2("greedy_10", policy_builder_greedy(10)) 108 | 109 | ex2("ucb1", policy_ucb1) 110 | 111 | ex2("optimistic_1", policy_builder_optimistic(1)) 112 | ex2("optimistic_3", policy_builder_optimistic(3)) 113 | ex2("optimistic_10", policy_builder_optimistic(10)) 114 | 115 | ex2("ucb1_long", policy_ucb1, num_iteration=10000) 116 | 117 | 118 | def ex3(name, policy, num_iteration=NUM_ITERATION): 119 | choices = np.zeros((NUM_SERIES, num_iteration)) 120 | rewards = np.zeros((NUM_SERIES, num_iteration)) 121 | for j in range(NUM_SERIES): 122 | num_use = np.zeros(num_actions) 123 | sum_reward = np.zeros(num_actions) 124 | for i in range(num_iteration): 125 | i_action = policy(num_use, sum_reward) 126 | action = actions[i_action] 127 | reward = action() 128 | #print action.__name__, reward 129 | num_use[i_action] += 1 130 | sum_reward[i_action] += reward 131 | choices[j, i] = i_action 132 | rewards[j, i] = reward 133 | 134 | return rewards.mean(axis=0) 135 | # visualization 136 | 
#import matplotlib.pyplot as plt 137 | #plt.plot(rewards.sum(axis=0)) 138 | #plt.savefig('{}_reward.png'.format(name)) 139 | 140 | 141 | def ex3_vis(p=100, N=1000): 142 | r1 = ex3( 143 | "greedy_{}".format(p), 144 | policy_builder_greedy(p), N) 145 | r2 = ex3( 146 | "optimistic_{}".format(p), 147 | policy_builder_optimistic(p), N) 148 | r3 = ex3("ucb1", policy_ucb1, N) 149 | 150 | if N == 1000: 151 | w = 30 152 | elif N == 20000: 153 | w = 300 154 | else: 155 | raise NotImplemented 156 | 157 | kernel = np.ones(w) / w 158 | def smooth(x): 159 | return np.convolve(x, kernel, mode='valid') 160 | 161 | import matplotlib.pyplot as plt 162 | plt.plot(smooth(r1), label = "greedy_{}".format(p)) 163 | plt.plot(smooth(r2), label = "optimistic_{}".format(p)) 164 | plt.plot(smooth(r3), label = "ucb1") 165 | 166 | plt.xlabel("iteration") 167 | plt.ylabel("reward") 168 | plt.legend(loc = 4) 169 | 170 | plt.savefig('rewards_{}_3.png'.format(N)) 171 | 172 | #ex3_vis(N=20000) 173 | 174 | if not"ex2-5": 175 | ex2("greedy_100", policy_builder_greedy(100)) 176 | ex2("optimistic_100", policy_builder_optimistic(100)) 177 | 178 | actions = [action_none, action_50, action_100, action_150] 179 | ex2("greedy_3_reorder", policy_builder_greedy(3)) 180 | -------------------------------------------------------------------------------- /lec01/ucb1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishio/reinforcement_learning/c7a88036d83d1094b0052378e083d7e41a5cb648/lec01/ucb1.png -------------------------------------------------------------------------------- /lec01/ucb1_long.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishio/reinforcement_learning/c7a88036d83d1094b0052378e083d7e41a5cb648/lec01/ucb1_long.png -------------------------------------------------------------------------------- /policy_gradient.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import Counter 3 | WORLD_WIDTH = 300 4 | WORLD_HEIGHT = 200 5 | START_X = 50.0 6 | START_Y = 100.0 7 | INITIAL_VELOCITY = 0.0 # 0.1 8 | INERTIA = 0.0 # 0.99 9 | VELOCITY_LIMIT = 1.0 # 0.1 10 | SIGMA = 1.0 11 | X_BONUS = 10.0 12 | class Environment(object): 13 | def __init__(self): 14 | self.result_log =[] 15 | self.init_state() 16 | 17 | def init_state(self): 18 | if INITIAL_VELOCITY: 19 | self.state = np.array([ 20 | START_X, START_Y, 21 | np.random.normal(0, INITIAL_VELOCITY), 22 | np.random.normal(0, INITIAL_VELOCITY), 23 | 1.0]) 24 | else: 25 | self.state = np.array([ 26 | START_X, START_Y, 0.0, 0.0, 1.0]) 27 | 28 | self.time = 300 29 | 30 | def get_state(self): 31 | return self.state.copy() 32 | 33 | def update(self, action): 34 | m = np.linalg.norm(action) 35 | if m > 1: 36 | action /= m 37 | self.state[2:4] = self.state[2:4] * INERTIA + action * VELOCITY_LIMIT 38 | #m = np.linalg.norm(self.state[2:]) 39 | #if m > 10: 40 | # self.state[2:] /= (m / 10) 41 | 42 | self.state[:2] += self.state[2:4] 43 | x, y = self.state[:2] 44 | 45 | if x < 0: 46 | self.init_state() 47 | self.result_log.append('left') 48 | return -1.0 + X_BONUS * x / WORLD_WIDTH 49 | 50 | if y < 0: 51 | self.init_state() 52 | self.result_log.append('top') 53 | return -1.0 + X_BONUS * x / WORLD_WIDTH 54 | 55 | if y > WORLD_HEIGHT: 56 | self.init_state() 57 | self.result_log.append('bottom') 58 | return -1.0 + X_BONUS * x / WORLD_WIDTH 59 | 60 | if x > WORLD_WIDTH: 61 | 
self.init_state() 62 | self.result_log.append('goal') 63 | return 10.0 64 | 65 | if 100 < x < 200 and 30 < y < 150: 66 | self.init_state() 67 | self.result_log.append('middle') 68 | return -1.0 + X_BONUS * x / WORLD_WIDTH 69 | 70 | self.time -= 1 71 | if self.time == 0: 72 | self.init_state() 73 | self.result_log.append('timeout') 74 | return -1.0 + X_BONUS * x / WORLD_WIDTH 75 | 76 | return 0.0 77 | 78 | def policy_random(state): 79 | return np.random.normal(size=2) 80 | 81 | 82 | class Policy(object): 83 | def __init__(self): 84 | self.theta = np.random.normal(scale=0.001, size=(5, 2)) * 0 85 | 86 | def __call__(self, state): 87 | mean = state.dot(self.theta) 88 | a = np.random.normal(mean, SIGMA) 89 | return a 90 | 91 | def grad(self, state, action): 92 | t1 = action - state.dot(self.theta) 93 | # 2 94 | t2 = -state 95 | # 4 96 | g = np.outer(t2, t1) 97 | return g 98 | 99 | 100 | def play(policy, num_plays=100, to_print=False): 101 | env = Environment() 102 | result = 0 103 | for i in range(num_plays): 104 | while True: 105 | s = env.get_state() 106 | a = policy(s) 107 | r = env.update(a) 108 | if r: 109 | break 110 | #print env.result_log[-1] 111 | return env 112 | 113 | 114 | def reinforce(policy, num_plays=100, to_print=False): 115 | env = Environment() 116 | result = 0 117 | samples = [] 118 | sum_t = 0 119 | sum_r = 0.0 120 | for i in range(num_plays): 121 | t = 0 122 | SARs = [] 123 | while True: 124 | s = env.get_state() 125 | a = policy(s) 126 | r = env.update(a) 127 | 128 | t += 1 129 | sum_r += r 130 | SARs.append((s, a, r)) 131 | if r: 132 | break 133 | samples.append((t, SARs)) 134 | sum_t += t 135 | 136 | baseline = float(sum_r) / sum_t 137 | grad = np.zeros((5, 2)) 138 | for (t, SARs) in samples: 139 | tmp_grad = np.zeros((5, 2)) 140 | for (s, a, r) in SARs: 141 | g = policy.grad(s, a) 142 | tmp_grad += g * (r - baseline) 143 | grad += tmp_grad / t 144 | grad /= num_plays 145 | #policy.theta /= np.linalg.norm(policy.theta) 146 | if np.linalg.norm(grad) > 1: 147 | grad /= np.linalg.norm(grad) 148 | print 'theta' 149 | print policy.theta 150 | print 'grad' 151 | print grad 152 | policy.theta -= 0.01 * grad 153 | print baseline, sum_t 154 | return env, samples 155 | 156 | 157 | #print Counter(play(policy_random).result_log) 158 | #print Counter(play(Policy()).result_log) 159 | policy = Policy() 160 | for i in range(10000): 161 | env, samples = reinforce(policy, 100) 162 | print Counter(env.result_log) 163 | 164 | from PIL import Image, ImageDraw 165 | im = Image.new('RGB', (300, 200), color=(255,255,255)) 166 | d = ImageDraw.Draw(im) 167 | for t, SARs in samples: 168 | points = [(START_X, START_Y)] 169 | for s, a, r in SARs: 170 | points.append(tuple(s[:2])) 171 | d.line(points, fill=0) 172 | d.rectangle((100, 30, 200, 150), fill=(128, 128, 128)) 173 | im.save('reinforce{:04d}.png'.format(i)) 174 | 175 | -------------------------------------------------------------------------------- /sarsa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishio/reinforcement_learning/c7a88036d83d1094b0052378e083d7e41a5cb648/sarsa.png -------------------------------------------------------------------------------- /sarsa_alphas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishio/reinforcement_learning/c7a88036d83d1094b0052378e083d7e41a5cb648/sarsa_alphas.png -------------------------------------------------------------------------------- 
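A note on policy_gradient.py above: `Policy` is a linear-Gaussian policy (action ~ N(theta^T state, SIGMA^2 I) over a 5-dimensional state that includes a bias feature), and `reinforce()` is a REINFORCE update with the average per-step reward as a constant baseline. As far as I can tell, `Policy.grad` folds a minus sign into the score function (`t2 = -state`), which the `theta -= 0.01 * grad` line then cancels, so the net effect is gradient ascent on expected reward. Below is a minimal, self-contained sketch of the same estimator written with the conventional signs; the function names and the `lr`/`sigma` parameters are mine, not the repo's.

```python
import numpy as np

def log_pi_grad(theta, s, a, sigma=1.0):
    # Score function of a linear-Gaussian policy a ~ N(s.theta, sigma^2 I):
    # d/dtheta log pi(a|s) = outer(s, a - s.theta) / sigma^2
    return np.outer(s, a - s.dot(theta)) / sigma ** 2

def reinforce_update(theta, episodes, lr=0.01, sigma=1.0):
    """episodes: list of trajectories, each a list of (state, action, reward).
    Uses the average per-step reward as a constant baseline, like the repo code."""
    steps = sum(len(ep) for ep in episodes)
    baseline = sum(r for ep in episodes for (_, _, r) in ep) / float(steps)
    grad = np.zeros_like(theta)
    for ep in episodes:
        g = np.zeros_like(theta)
        for (s, a, r) in ep:
            g += log_pi_grad(theta, s, a, sigma) * (r - baseline)
        grad += g / float(len(ep))   # per-episode average, as in reinforce()
    grad /= float(len(episodes))
    return theta + lr * grad         # gradient *ascent* on expected reward
```

Unlike the repo's `reinforce()`, this sketch omits the gradient-norm clipping step; otherwise the averaging over steps and episodes mirrors it.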
/sarsa_illegal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishio/reinforcement_learning/c7a88036d83d1094b0052378e083d7e41a5cb648/sarsa_illegal.png -------------------------------------------------------------------------------- /t.py: -------------------------------------------------------------------------------- 1 | """ 2 | Quarto 3 | 4 | 0: vacant 5 | 1-16: occupied 6 | """ 7 | import numpy as np 8 | from collections import Counter 9 | 10 | def board_to_int(board): 11 | s = 0L 12 | for i in range(16): 13 | s += long(board[i]) * (17 ** i) 14 | return s 15 | 16 | def board_to_possible_hands(board): 17 | return [i for i in range(16) if board[i] == 0] 18 | 19 | def init_board(): 20 | return np.zeros(16, dtype=np.int) 21 | 22 | def init_Q(): 23 | from scipy.sparse import dok_matrix 24 | return dok_matrix((17 ** 16, 16 * 16)) 25 | 26 | LINES = [ 27 | [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15], 28 | [0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15], 29 | [0, 5, 10, 15], [3, 6, 9, 12] 30 | ] 31 | def is_win(board): 32 | for line in LINES: 33 | xs = board[line] 34 | if any(x == 0 for x in xs): continue 35 | a, b, c, d = xs - 1 36 | if a & b & c & d != 0: 37 | return 1 38 | if a | b | c | d != 15: 39 | return 1 40 | return 0 41 | 42 | def print_board(board): 43 | """ 44 | >>> print_board(range(16)) 45 | . o x o | . o o x | . o o o | . o o o 46 | x o x o | x o o x | o x x x | o o o o 47 | x o x o | x o o x | x o o o | o x x x 48 | x o x o | x o o x | o x x x | x x x x 49 | """ 50 | m = np.zeros((16, 4), dtype=np.int) 51 | for i in range(16): 52 | if board[i] == 0: 53 | m[i, :] = 0 54 | else: 55 | v = board[i] - 1 56 | for bit in range(4): # nth bit 57 | m[i, bit] = ((v >> bit) & 1) + 1 58 | 59 | for y in range(4): 60 | print ' | '.join( 61 | ' '.join( 62 | ['.ox'[v] for v in m[y * 4 : (y + 1) * 4, bit]] 63 | ) 64 | for bit in range(4)) 65 | print 66 | 67 | 68 | def policy_random(env): 69 | from random import choice 70 | position = choice(board_to_possible_hands(env.board)) 71 | piece = choice(env.available_pieces) 72 | return (position, piece) 73 | 74 | 75 | class Environment(object): 76 | def __init__(self, policy=policy_random): 77 | self.op_policy = policy 78 | self.result_log =[] 79 | self.init_env() 80 | 81 | def init_env(self): 82 | self.board = init_board() 83 | self.available_pieces= range(2, 17) 84 | self.selected_piece = 1 85 | 86 | def _update(self, action, k=1, to_print=False): 87 | position, piece = action 88 | 89 | if self.board[position] != 0: 90 | # illegal move 91 | print 'illegal pos' 92 | self.init_env() 93 | self.result_log.append(-1 * k) 94 | return (self.board, -1 * k) 95 | 96 | if piece not in self.available_pieces: 97 | # illegal move 98 | print 'illegal piece' 99 | self.init_env() 100 | self.result_log.append(-1 * k) 101 | return (self.board, -1 * k) 102 | 103 | self.board[position] = self.selected_piece 104 | self.available_pieces.remove(piece) 105 | self.selected_piece = piece 106 | 107 | if to_print: 108 | print k, action 109 | print_board(self.board) 110 | 111 | b = is_win(self.board) 112 | if b: 113 | self.init_env() 114 | self.result_log.append(+1 * k) 115 | return (self.board, +1 * k) 116 | 117 | if not self.available_pieces: 118 | # put selected piece 119 | self.board[self.board==0] = self.selected_piece 120 | b = is_win(self.board) 121 | if to_print: 122 | print 'last move' 123 | print_board(self.board) 124 | 125 | self.init_env() 126 | if b: 127 | # opponent 
win 128 | self.result_log.append(-1 * k) 129 | return (self.board, -1 * k) 130 | else: 131 | # tie 132 | self.result_log.append(0) 133 | return (self.board, -1) 134 | 135 | return None 136 | 137 | def __call__(self, action, to_print=False): 138 | ret = self._update(action, k=1, to_print=to_print) 139 | if ret: return ret 140 | op_action = self.op_policy(self) 141 | ret = self._update(op_action, k=-1, to_print=to_print) 142 | if ret: return ret 143 | return (self.board, 0) 144 | 145 | 146 | def play(policy1, policy2=policy_random, to_print=False): 147 | env = Environment() 148 | result = 0 149 | 150 | for i in range(9): 151 | a = policy1(env) 152 | s, r = env(a, to_print=to_print) 153 | if r != 0: break 154 | if to_print: 155 | print env.result_log[-1] 156 | return env.result_log[-1] 157 | 158 | #play(policy_random, to_print=True) 159 | 160 | 161 | class Greedy(object): 162 | def __init__(self): 163 | self.Qtable = init_Q() 164 | 165 | def __call__(self, env): 166 | from random import choice 167 | s = board_to_int(env.board) 168 | actions = (action_to_int((pos, piece)) 169 | for pos in board_to_possible_hands(env.board) 170 | for piece in env.available_pieces 171 | ) 172 | qa = [(self.Qtable[s, a], a) for a in actions] 173 | bestQ, bestA = max(qa) 174 | bextQ, bestA = choice([(q, a) for (q, a) in qa if q == bestQ]) 175 | return int_to_action(bestA) 176 | 177 | 178 | class EpsilonGreedy(object): 179 | def __init__(self, eps=0.1): 180 | self.Qtable = init_Q() 181 | self.eps = eps 182 | 183 | def __call__(self, env): 184 | from random import choice, random 185 | s = board_to_int(env.board) 186 | if random() < self.eps: 187 | pos = choice(board_to_possible_hands(env.board)) 188 | piece = choice(env.available_pieces) 189 | return (pos, piece) 190 | 191 | actions = (action_to_int((pos, piece)) 192 | for pos in board_to_possible_hands(env.board) 193 | for piece in env.available_pieces 194 | ) 195 | qa = [(self.Qtable[s, a], a) for a in actions] 196 | bestQ, bestA = max(qa) 197 | bextQ, bestA = choice([(q, a) for (q, a) in qa if q == bestQ]) 198 | return int_to_action(bestA) 199 | 200 | 201 | def board_to_state(board): 202 | return board_to_int(board) 203 | 204 | def action_to_int(action): 205 | pos, piece = action 206 | return pos * 16 + (piece - 1) 207 | 208 | def int_to_action(i): 209 | assert 0 <= i < 16 * 16 210 | return (i / 16, i % 16 + 1) 211 | 212 | 213 | from kagura.utils import Digest 214 | digest = Digest(1) 215 | battle_per_seconds = [] 216 | 217 | def sarsa(alpha, policyClass=Greedy): 218 | global environment, policy 219 | gamma = 0.9 220 | num_result = batch_width * num_batch 221 | environment = Environment() 222 | policy = policyClass() 223 | 224 | action = policy(environment) 225 | state = board_to_state(environment.board) 226 | while True: 227 | next_board, reward = environment(action) 228 | next_state = board_to_state(next_board) 229 | 230 | # determine a' 231 | next_action = policy(environment) 232 | nextQ = policy.Qtable[next_state, action_to_int(next_action)] 233 | 234 | # update Q(s, a) 235 | s_a = (state, action_to_int(action)) 236 | Qsa = policy.Qtable[s_a] 237 | estimated_reward = reward + gamma * nextQ 238 | diff = estimated_reward - Qsa 239 | policy.Qtable[s_a] += alpha * diff 240 | 241 | state = next_state 242 | action = next_action 243 | if len(environment.result_log) == num_result: 244 | break 245 | t = digest.digest(len(environment.result_log)) 246 | if t: 247 | battle_per_seconds.append(t) 248 | 249 | vs = [] 250 | for i in range(num_batch): 251 | c = 
Counter(environment.result_log[batch_width * i : batch_width * (i + 1)]) 252 | print c 253 | vs.append(float(c[1]) / batch_width) 254 | return vs 255 | 256 | 257 | def qlearn(alpha, policyClass=Greedy): 258 | global environment, policy 259 | gamma = 0.9 260 | num_result = batch_width * num_batch 261 | environment = Environment() 262 | policy = policyClass() 263 | 264 | state = board_to_state(environment.board) 265 | while True: 266 | action = policy(environment) 267 | next_board, reward = environment(action) 268 | next_state = board_to_state(next_board) 269 | 270 | # update Q(s, a) 271 | maxQ = max(policy.Qtable[next_state, a] for a in board_to_possible_hands(next_board)) 272 | s_a = (state, action_to_int(action)) 273 | 274 | Qsa = policy.Qtable[s_a] 275 | estimated_reward = reward + gamma * maxQ 276 | diff = estimated_reward - Qsa 277 | policy.Qtable[s_a] += alpha * diff 278 | 279 | state = next_state 280 | 281 | if len(environment.result_log) == num_result: 282 | break 283 | t = digest.digest(len(environment.result_log)) 284 | if t: 285 | battle_per_seconds.append(t) 286 | 287 | vs = [] 288 | for i in range(num_batch): 289 | c = Counter(environment.result_log[batch_width * i : batch_width * (i + 1)]) 290 | print c 291 | vs.append(float(c[1]) / batch_width) 292 | return vs 293 | 294 | 295 | 296 | def plot_log(): 297 | from kagura import load 298 | result_log = load("sarsa_0.05_result_log") 299 | batch_width = 1000 300 | num_batch = 1000 301 | vs = [] 302 | for i in range(num_batch): 303 | c = Counter(result_log[batch_width * i : batch_width * (i + 1)]) 304 | print c 305 | vs.append(float(c[1]) / batch_width) 306 | 307 | label = 'Sarsa(0.05)' 308 | imgname = 'sarsa_0.05.png' 309 | plot() 310 | 311 | def plot(): 312 | import matplotlib.pyplot as plt 313 | plt.clf() 314 | plt.plot([0.475] * len(vs), label = "baseline") 315 | plt.plot(vs, label=label) 316 | plt.xlabel("iteration") 317 | plt.ylabel("Prob. 
of win") 318 | plt.legend(loc = 4) 319 | plt.savefig(imgname) 320 | 321 | 322 | def f(n, m): 323 | if m == 1: return n + 1 324 | return n * f(n - 1, m - 1) + f(n, m - 1) 325 | 326 | 327 | if not'ex1': 328 | from collections import Counter 329 | print Counter( 330 | play(policy_random) for i in range(10000)) 331 | elif not'ex2': 332 | batch_width = 1000 333 | num_batch = 100 334 | vs = sarsa(0.5) 335 | elif not'ex3': 336 | batch_width = 1000 337 | num_batch = 1000 338 | vs = sarsa(0.5) 339 | 340 | if 0: 341 | batch_width = 1000 342 | num_batch = 1000 343 | vs = qlearn(0.5) 344 | label = 'Qlearn(0.5)' 345 | imgname = 'qlearn.png' 346 | elif 0: 347 | batch_width = 1000 348 | num_batch = 1000 349 | vs = qlearn(0.05) 350 | label = 'Qlearn(0.05)' 351 | imgname = 'qlearn_0.05.png' 352 | 353 | 354 | from kagura import dump 355 | if 0: 356 | batch_width = 1000 357 | num_batch = 1000 358 | vs = sarsa(0.5, policyClass=EpsilonGreedy) 359 | label = 'Sarsa(0.5, eps=0.1)' 360 | imgname = 'sarsa_0.5_eps0.1.png' 361 | dump(environment.result_log, imgname.replace('.png', '_result_log')) 362 | elif 0: 363 | batch_width = 1000 364 | num_batch = 1000 365 | vs = sarsa(0.05, policyClass=EpsilonGreedy) 366 | label = 'Sarsa(0.05, eps=0.1)' 367 | imgname = 'sarsa_0.05_eps0.1.png' 368 | dump(environment.result_log, imgname.replace('.png', '_result_log')) 369 | 370 | if 0: 371 | batch_width = 100 372 | num_batch = 1000 373 | vs = sarsa(0.05, policyClass=Greedy) 374 | label = 'Sarsa(0.05)' 375 | imgname = 'sarsa_0.05_2.png' 376 | dump(environment.result_log, imgname.replace('.png', '_result_log')) 377 | 378 | 379 | batch_width = 1000 380 | num_batch = 100 381 | vs = sarsa(0.5) 382 | label = 'Sarsa(0.5)' 383 | imgname = 'sarsa_0.5_2.png' 384 | 385 | plot() 386 | 387 | -------------------------------------------------------------------------------- /tictactoe.py: -------------------------------------------------------------------------------- 1 | """ 2 | tic tac toe 3 | 4 | 0: vacant 5 | 1: first player 6 | 2: second player 7 | """ 8 | import numpy as np 9 | 10 | def _board_to_int(board): 11 | s = 0 12 | for i in range(9): 13 | s += board[i] * (3 ** i) 14 | return s 15 | 16 | def board_to_possible_hands(board): 17 | return [i for i in range(9) if board[i] == 0] 18 | 19 | def init_board(): 20 | return np.zeros(9, dtype=np.int) 21 | 22 | class QRaw(object): 23 | def __init__(self): 24 | #self.value = [0] * (3 ** 9 * 9) 25 | self.value = {} 26 | def get(self, info, a): 27 | return self.value.get(self.get_key(info, a), 0) 28 | def set(self, info, a, v): 29 | self.value[self.get_key(info, a)] = v 30 | def get_key(self, info, a): 31 | s = _board_to_int(info) 32 | return s * 9 + a 33 | 34 | def array_to_int(xs, m, n): 35 | "convert array to [0, m ** n) integer" 36 | assert len(xs) == n 37 | s = 0 38 | for i in range(n): 39 | assert 0 <= xs[i] < m 40 | s += xs[i] * (m ** i) 41 | return s 42 | 43 | 44 | class QSym(QRaw): 45 | "Compress states using symmetry" 46 | def get_key(self, info, a): 47 | board = info.copy() 48 | board[a] += 3 # where to place 49 | buf = [] 50 | for i in range(4): 51 | buf.append(array_to_int(board, 6, 9)) 52 | # rotate 90 ccw 53 | board = board[[2, 5, 8, 1, 4, 7, 0, 3, 6]] 54 | # mirror 55 | board = board[[2, 1, 0, 5, 4, 3, 8, 7, 6]] 56 | for i in range(4): 57 | buf.append(array_to_int(board, 6, 9)) 58 | # rotate 90 ccw 59 | board = board[[2, 5, 8, 1, 4, 7, 0, 3, 6]] 60 | 61 | return min(buf) 62 | 63 | 64 | class QLine1(QRaw): 65 | def to_lines(self, info, a): 66 | board = info.copy() 67 | board[a] += 3 
# where to place 68 | lines = np.array([board[line] for line in LINES]) 69 | lines.sort(axis=1) 70 | lines.sort(axis=0) 71 | return lines 72 | 73 | def get_key(self, info, a): 74 | lines = self.to_lines(info, a) 75 | # list(sorted(set(tuple(sorted((a, b, c))) for a in range(6) for b in range(6) for c in range(6) if one(x > 2 for x in (a, b, c))))) 76 | MAP = [(0, 0, 3), (0, 0, 4), (0, 0, 5), (0, 1, 3), 77 | (0, 1, 4), (0, 1, 5), (0, 2, 3), (0, 2, 4), 78 | (0, 2, 5), (1, 1, 3), (1, 1, 4), (1, 1, 5), 79 | (1, 2, 3), (1, 2, 4), (1, 2, 5), (2, 2, 3), 80 | (2, 2, 4), (2, 2, 5)] 81 | MAP = dict((MAP[i], i + 1) for i in range(len(MAP))) 82 | lines = [MAP.get(tuple(line), 0) for line in lines] 83 | lines.sort() 84 | return tuple(lines) 85 | 86 | 87 | class QLine2(QLine1): 88 | def get_key(self, info, a): 89 | lines = self.to_lines(info, a) 90 | MAP = {(2, 2, 3): 1, (1, 1, 3): 2, 91 | (0, 2, 3): 3, (0, 1, 3): 4, 92 | (0, 0, 3): 5 93 | } 94 | lines = [MAP.get(tuple(line), 0) for line in lines] 95 | lines.sort() 96 | print lines 97 | return tuple(lines) 98 | 99 | 100 | def policy_random(env): 101 | from random import choice 102 | actions = board_to_possible_hands(env.board) 103 | return choice(actions) 104 | 105 | LINES = [ 106 | [0, 1, 2], [3, 4, 5], [6, 7, 8], 107 | [0, 3, 6], [1, 4, 7], [2, 5, 8], 108 | [0, 4, 8], [2, 4, 6] 109 | ] 110 | def is_win(board): 111 | for line in LINES: 112 | a, b, c = board[line] 113 | if a != 0 and a == b and a == c: 114 | return a 115 | return 0 116 | 117 | def print_board(board): 118 | s = ['.ox'[x] for x in board] 119 | print ' '.join(s[0:3]) 120 | print ' '.join(s[3:6]) 121 | print ' '.join(s[6:9]) 122 | print 123 | 124 | 125 | class Environment(object): 126 | def __init__(self, policy=policy_random): 127 | self.board = init_board() 128 | self.op_policy = policy 129 | self.result_log =[] 130 | 131 | def __call__(self, action): 132 | if self.board[action] != 0: 133 | # illegal move 134 | self.board = init_board() 135 | self.result_log.append(2) 136 | return (self, -1) 137 | 138 | self.board[action] = 1 139 | b = is_win(self.board) 140 | if b: 141 | self.board = init_board() 142 | self.result_log.append(1) 143 | return (self, +1) 144 | 145 | if not board_to_possible_hands(self.board): 146 | self.board = init_board() 147 | self.result_log.append(0) 148 | return (self, -1) 149 | 150 | op_action = self.op_policy(self) 151 | self.board[op_action] = 2 152 | b = is_win(self.board) 153 | if b: 154 | self.result_log.append(2) 155 | self.board = init_board() 156 | return (self, -1) 157 | 158 | return (self, 0) 159 | 160 | def get_info(self): 161 | return self.board.copy() 162 | 163 | 164 | def play(policy1, policy2=policy_random, to_print=False): 165 | env = Environment() 166 | result = 0 167 | for i in range(9): 168 | if i % 2 == 0: 169 | a = policy1(env) 170 | env.board[a] = 1 171 | else: 172 | a = policy2(env) 173 | env.board[a] = 2 174 | 175 | if to_print: 176 | print_board(env.board) 177 | 178 | b = is_win(env.board) 179 | if b: 180 | result = b 181 | break 182 | return result 183 | 184 | 185 | from collections import Counter 186 | if not"ex1": 187 | print Counter( 188 | play(policy_random) for i in range(10000)) 189 | 190 | 191 | class Greedy(object): 192 | def __init__(self, QClass=QRaw): 193 | self.Qtable = QClass() 194 | 195 | def __call__(self, env): 196 | from random import choice 197 | actions = get_available_actions(env) 198 | qa = [(self.Qtable.get(env.get_info(), a), a) for a in actions] 199 | bestQ, bestA = max(qa) 200 | bextQ, bestA = choice([(q, a) for (q, 
a) in qa if q == bestQ]) 201 | return bestA 202 | 203 | 204 | def get_available_actions(env): 205 | return range(9) 206 | # return board_to_possible_hands(board) 207 | 208 | 209 | class EpsilonGreedy(object): 210 | def __init__(self, eps=0.1): 211 | self.Qtable = init_Q() 212 | self.eps = eps 213 | 214 | def __call__(self, env): 215 | from random import choice, random 216 | s = board_to_int(env.board) 217 | if random() < self.eps: 218 | actions = get_available_actions(env) 219 | return choice(actions) 220 | 221 | actions = get_available_actions(env) 222 | qa = [(self.Qtable[get_Qkey(s, a)], a) for a in actions] 223 | bestQ, bestA = max(qa) 224 | bextQ, bestA = choice([(q, a) for (q, a) in qa if q == bestQ]) 225 | return bestA 226 | 227 | 228 | def sarsa(alpha, policyClass=Greedy, extra_play=False): 229 | global environment, policy 230 | gamma = 0.9 231 | num_result = batch_width * num_batch 232 | environment = Environment() 233 | policy = policyClass() 234 | action = policy(environment) 235 | prev_num_battle = 0 236 | win_ratios = [] 237 | while True: 238 | old = environment.get_info() 239 | next_env, reward = environment(action) 240 | 241 | # determine a' 242 | next_action = policy(next_env) 243 | nextQ = policy.Qtable.get(next_env.get_info(), next_action) 244 | 245 | # update Q(s, a) 246 | Qsa = policy.Qtable.get(old, action) 247 | estimated_reward = reward + gamma * nextQ 248 | diff = estimated_reward - Qsa 249 | policy.Qtable.set(old, action, Qsa + alpha * diff) 250 | 251 | action = next_action 252 | 253 | num_battle = len(environment.result_log) 254 | if extra_play and num_battle > prev_num_battle: 255 | if num_battle % batch_width == 0: 256 | v = calc_win_ratio_from_plays(policy, policy_random) 257 | win_ratios.append(v) 258 | if num_battle == num_result: 259 | break 260 | prev_num_battle = num_battle 261 | 262 | if not extra_play: 263 | win_ratios = calc_win_ratios_from_result_log(environment) 264 | return win_ratios 265 | 266 | def calc_win_ratio_from_plays(policy, op_policy=policy_random, N=1000): 267 | v = calc_win_ratio( 268 | [play(policy, op_policy) for i in range(1000)]) 269 | return v 270 | 271 | def calc_win_ratio(xs): 272 | c = Counter(xs) 273 | return float(c[1]) / len(xs) 274 | 275 | 276 | def calc_win_ratios_from_result_log(env): 277 | vs = [] 278 | for i in range(num_batch): 279 | v = calc_win_ratio( 280 | environment.result_log[ 281 | batch_width * i : 282 | batch_width * (i + 1)]) 283 | vs.append(v) 284 | return vs 285 | 286 | 287 | def qlearn(alpha, policyClass=Greedy, extra_play=False): 288 | global environment, policy 289 | gamma = 0.9 290 | num_result = batch_width * num_batch 291 | environment = Environment() 292 | policy = policyClass() 293 | 294 | state = policy.Qtable.board_to_state(environment.board) 295 | prev_num_battle = 0 296 | win_ratios = [] 297 | while True: 298 | action = policy(environment) 299 | next_env, reward = environment(action) 300 | next_state = policy.Qtable.board_to_state(next_env.board) 301 | 302 | # update Q(s, a) 303 | maxQ = max(policy.Qtable.get(next_state, a) 304 | for a in get_available_actions(next_env)) 305 | 306 | Qsa = policy.Qtable.get(state, action) 307 | estimated_reward = reward + gamma * maxQ 308 | diff = estimated_reward - Qsa 309 | policy.Qtable.set(state, action, Qsa + alpha * diff) 310 | 311 | state = next_state 312 | 313 | num_battle = len(environment.result_log) 314 | if extra_play and num_battle > prev_num_battle: 315 | if num_battle % batch_width == 0: 316 | v = calc_win_ratio_from_plays(policy, policy_random) 317 | 
win_ratios.append(v) 318 | if num_battle == num_result: 319 | break 320 | prev_num_battle = num_battle 321 | 322 | if not extra_play: 323 | win_ratios = calc_win_ratios_from_result_log(environment) 324 | return win_ratios 325 | 326 | 327 | batch_width = 100 328 | num_batch = 100 329 | 330 | import matplotlib.pyplot as plt 331 | 332 | def plot(seq, name, baseline=None): 333 | """seq: [(values, label)]""" 334 | plt.clf() 335 | if baseline != None: 336 | plt.plot([0.58] * len(seq[0][0]), label = "baseline") 337 | for (vs, label) in seq: 338 | plt.plot(vs, label=label) 339 | plt.xlabel("iteration") 340 | plt.ylabel("Prob. of win") 341 | plt.legend(loc = 4) 342 | plt.savefig(name) 343 | 344 | 345 | #batch_width = 10 346 | #num_batch = 100 347 | #result = [ 348 | # (sarsa(0.05), "Sarsa(0.05)"), 349 | # (sarsa(0.05, lambda: Greedy(QSym)), "Sarsa(0.05)+Sym"), 350 | # (sarsa(0.05, lambda: Greedy(QLine1)), "Sarsa(0.05)+Line1"), 351 | # (sarsa(0.05, lambda: Greedy(QLine2)), "Sarsa(0.05)+Line2"), 352 | # (qlearn(0.05), "Qlearn(0.05)"), 353 | # (qlearn(0.05, lambda: Greedy(QSym)), "Qlearn(0.05)+Sym"), 354 | #] 355 | #plot(result, 'out.png', baseline=0.58) 356 | 357 | 358 | batch_width = 300 359 | num_batch = 1 360 | def foo(f): 361 | win_ratios = [] 362 | for i in range(100): 363 | f() 364 | v = calc_win_ratio_from_plays(policy, policy_random) 365 | win_ratios.append(v) 366 | print i, v 367 | s = np.array(win_ratios) 368 | print len(win_ratios) 369 | print "{:.2f}+-{:.2f}".format(s.mean(), s.std() * 2) 370 | 371 | #foo(lambda: sarsa(0.05)) 372 | #foo(lambda: sarsa(0.05, lambda: Greedy(QSym))) 373 | #foo(lambda: sarsa(0.05, lambda: Greedy(QLine1))) 374 | foo(lambda: sarsa(0.05, lambda: Greedy(QLine2))) 375 | -------------------------------------------------------------------------------- /tiger.py: -------------------------------------------------------------------------------- 1 | """ 2 | POMDP 3 | """ 4 | import numpy as np 5 | from matplotlib import pyplot as plt 6 | 7 | actions = ['Aleft', 'Aright', 'Alisten'] 8 | states = ['Sleft', 'Sright'] 9 | observations = ['Oleft', 'Oright'] 10 | 11 | # R(s, a) 12 | reward = np.array([ 13 | [+10.0, -100.0, -1.0], 14 | [-100.0, +10.0, -1.0], 15 | ]) 16 | 17 | # T(s, a, s') 18 | transition = np.array([ 19 | [[0.5, 0.5], [0.5, 0.5],[1.0, 0.0]], 20 | [[0.5, 0.5], [0.5, 0.5],[0.0, 1.0]] 21 | ]) 22 | 23 | # O(s', a, o) 24 | observation = np.array([ 25 | [[0.5, 0.5], [0.5, 0.5],[0.85, 0.15]], 26 | [[0.5, 0.5], [0.5, 0.5],[0.15, 0.85]] 27 | ]) 28 | 29 | # initial beleif 30 | b0 = np.array([0.5, 0.5]) 31 | 32 | V0 = np.array([[0.0, 0.0]]) 33 | 34 | gamma = 1.0 35 | 36 | def _common_part(V0): 37 | g_a_star = reward 38 | g_a_o = np.zeros((len(actions), len(observations), len(V0), len(states))) 39 | for i in range(len(V0)): 40 | for a in range(len(actions)): 41 | for o in range(len(observations)): 42 | buf = np.zeros(len(states)) 43 | for s in range(len(states)): 44 | sum_s2 = 0.0 45 | for s2 in range(len(states)): 46 | sum_s2 += transition[s, a, s2] * observation[s2, a, o] * V0[i][s2] 47 | buf[s] = gamma * sum_s2 48 | g_a_o[a, o, i, :] = buf 49 | return g_a_star, g_a_o 50 | 51 | 52 | def exact_backup(V0): 53 | g_a_star, g_a_o = _common_part(V0) 54 | #print 'g_a_o' 55 | #print g_a_o 56 | # exact 57 | tmp = g_a_star.copy() 58 | g_a = np.zeros((len(actions), len(V0) ** len(observations), len(states))) 59 | for a in range(len(actions)): 60 | #print tmp[:, a] 61 | lhs = [tmp[:, a]] 62 | 63 | for o in range(len(observations)): 64 | #print g_a_o[a, o] 65 | rhs= g_a_o[a, o] 66 | 
buf = [] 67 | #print "lhs: {}, rhs: {}".format(lhs, rhs) 68 | for x in lhs: 69 | for y in rhs: 70 | buf.append(x + y) 71 | lhs = buf 72 | #print lhs 73 | g_a[a,:] = lhs 74 | #print g_a 75 | V1 = g_a.reshape(len(actions) * len(V0) ** len(observations), len(states)) 76 | return V1 77 | 78 | 79 | def pbvi_backup(V0, B): 80 | g_a_star, g_a_o = _common_part(V0) 81 | g_a_b = np.zeros((len(actions), len(B), len(states))) 82 | for a in range(len(actions)): 83 | for b in range(len(B)): 84 | tmp = g_a_star[:, a].copy() 85 | for o in range(len(observations)): 86 | i = np.argmax([alpha.dot(B[b]) for alpha in g_a_o[a, o]]) 87 | tmp += g_a_o[a, o, i] 88 | g_a_b[a, b] = tmp 89 | 90 | V1 = np.array([ 91 | g_a_b[np.argmax([g_a_b[a, b].dot(B[b]) for a in range(len(actions))]), b] 92 | for b in range(len(B)) 93 | ]) 94 | return V1 95 | 96 | 97 | def plot(V): 98 | for alpha in V: 99 | plt.plot(alpha) 100 | plt.show() 101 | 102 | 103 | def prune(V): 104 | ret = [] 105 | N = 10000 106 | for i in range(N + 1): 107 | s = float(i) / N 108 | b = [s, 1 - s] 109 | ret.append(np.argmax([x.dot(b) for x in V])) 110 | #print ret 111 | return V[list(set(ret))] 112 | 113 | 114 | def get_max(V): 115 | vs = [] 116 | N = 10000 117 | for i in range(N + 1): 118 | s = float(i) / N 119 | b = [s, 1 - s] 120 | vs.append(np.max([x.dot(b) for x in V])) 121 | return vs 122 | 123 | 124 | 125 | def update_belief(b, a, o): 126 | ret = np.zeros(len(states)) 127 | for s2 in range(len(states)): 128 | tmp = np.sum([b[s] * transition[s, a, s2] for s in range(len(states))]) 129 | ret[s2] = tmp * observation[s2, a, o] 130 | z = ret.sum() 131 | return ret / z 132 | 133 | 134 | from random import random 135 | def belief_point_set_expansion(B0): 136 | print 137 | ret = list(B0) 138 | for b in B0: 139 | buf = [] 140 | s = 1 if random() < b[1] else 0 141 | for a in range(len(actions)): 142 | p = transition[s, a][1] 143 | s2 = 1 if random() < p else 0 144 | p = observation[s2, a][1] 145 | o = 1 if random() < p else 0 146 | 147 | b2 = update_belief(b, a, o) 148 | buf.append(b2) 149 | dists = [np.min([np.linalg.norm(x - y) for y in B0]) for x in buf] 150 | print b 151 | print dists 152 | i = np.argmax(dists) 153 | if dists[i] > 0.0: 154 | ret.append(buf[i]) 155 | 156 | return ret 157 | 158 | 159 | def get_B(): 160 | b = [np.array([0.5, 0.5])] 161 | while True: 162 | b = belief_point_set_expansion(b) 163 | print b 164 | if len(b) > 10: break 165 | return b 166 | 167 | 168 | if 0: 169 | N = 10 170 | B = [] 171 | for i in range(N + 1): 172 | s = float(i) / N 173 | b = [s, 1 - s] 174 | B.append(b) 175 | B = np.array(B) 176 | else: 177 | B = get_B() 178 | 179 | if 'pbvi': 180 | backup = lambda V: pbvi_backup(V, B) 181 | else: 182 | backup = exact_backup 183 | 184 | if 0: 185 | from kagura import stopwatch 186 | s = stopwatch.Stopwatch(quiet=True) 187 | Vprev = V0 188 | for i in range(10): 189 | s.start() 190 | V = backup(Vprev) 191 | beforePrune = len(V) 192 | V = prune(V) 193 | print "V{}, {}->{}, {}".format(i + 1, beforePrune, len(V), s.get()) 194 | s.end() 195 | Vprev = V 196 | 197 | 198 | import kagura 199 | if 1: 200 | N = 10000 201 | xs = [float(i) / N for i in range(N + 1)] 202 | plt.plot(xs, get_max(kagura.load('exactVI')), label='exact') 203 | plt.plot(xs, get_max(kagura.load('PBVI')), label='PBVI') 204 | plt.plot(xs, get_max(kagura.load('PBVI2')), label='PBVI2') 205 | V = kagura.load('PBVI2') 206 | plt.scatter(np.array(B)[:, 1], [np.max([b.dot(v) for v in V]) for b in B]) 207 | plt.legend(loc = 4) 208 | plt.xlim(0, 1) 209 | plt.show() 210 | 
--------------------------------------------------------------------------------
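A note on tiger.py above: it implements exact value-iteration backups (`exact_backup`) and point-based value iteration backups (`pbvi_backup`) for the tiger POMDP, and `update_belief` is the standard Bayes filter b'(s') ∝ O(s', a, o) · Σ_s T(s, a, s') b(s). Below is a self-contained sketch of that belief update with a worked check; the T and O matrices are copied from tiger.py, while the vectorized one-liner is my rewrite of the repo's explicit loop.

```python
import numpy as np

# Tiger-problem dynamics, copied from tiger.py.
# T(s, a, s'): opening a door resets to 50/50; listening (a=2) leaves the tiger in place.
transition = np.array([
    [[0.5, 0.5], [0.5, 0.5], [1.0, 0.0]],
    [[0.5, 0.5], [0.5, 0.5], [0.0, 1.0]],
])
# O(s', a, o): listening is 85% accurate; other actions give no information.
observation = np.array([
    [[0.5, 0.5], [0.5, 0.5], [0.85, 0.15]],
    [[0.5, 0.5], [0.5, 0.5], [0.15, 0.85]],
])

def update_belief(b, a, o):
    # b'(s') proportional to O(s', a, o) * sum_s T(s, a, s') * b(s)
    b2 = observation[:, a, o] * transition[:, a, :].T.dot(b)
    return b2 / b2.sum()

# Worked check: from the uniform prior, listen (a=2) and hear the tiger on the
# left (o=0); the posterior should be [0.85, 0.15].
b1 = update_belief(np.array([0.5, 0.5]), a=2, o=0)
assert np.allclose(b1, [0.85, 0.15])
```

Because listening does not move the tiger and the observation rows are 0.85/0.15, a single listen from the uniform prior lands exactly on [0.85, 0.15], which is why the belief-point expansion in `get_B()` quickly collects points away from the center of the simplex.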