├── README.md
├── catch.py
├── lua
│   ├── catch.lua
│   ├── common.lua
│   ├── easy21.lua
│   ├── monte_carlo.lua
│   ├── nnlearner.lua
│   ├── qlearner.lua
│   ├── random.lua
│   └── testbed.lua
├── nbandit.py
└── screenshots
    ├── catch.png
    └── nbandit.png

/README.md:
--------------------------------------------------------------------------------
 1 | # Reinforcement Learning: An Introduction
 2 | 
 3 | Implementing exercises for [Reinforcement Learning: An Introduction](http://webdocs.cs.ualberta.ca/~sutton/book/the-book.html).
 4 | 
 5 | ## Chapter 2 - Bandit Problems
 6 | 
 7 | ![Greedy, 0.1 and 0.01 epsilon greedy agents](screenshots/nbandit.png)
 8 | 
 9 | `nbandit.py` implements greedy and epsilon-greedy agents for the n-armed bandit problem. For an explanation of how they work, read the book ;)
10 | 
11 | ## Playing Catch
12 | 
13 | As a more interesting test, I next tried my hand at a very simple game: Catch.
14 | 
15 | A ball starts at a random position at the top of a 5x5 playing field and moves down one row each round. The player controls a bat at the bottom of the field, which can move left, move right or stand still, and has to catch the ball with it. Catching the ball gives a reward of +1, missing it a reward of -1.
16 | 
17 | A naive table-based agent learns to play perfectly after ~500 episodes; the neural-network-based agents (with 1 and 2 hidden layers) take quite a bit longer, about 3000 episodes:
18 | 
19 | ![Table based Q-learning, 1 and 2 hidden layer neural networks, random agent](screenshots/catch.png)
20 | 
--------------------------------------------------------------------------------
/catch.py:
--------------------------------------------------------------------------------
  1 | import abc
  2 | import collections
  3 | import functools
  4 | import random
  5 | 
  6 | import matplotlib.pyplot as plt
  7 | 
  8 | 
  9 | class Environment(object):
 10 |   __metaclass__ = abc.ABCMeta
 11 | 
 12 |   @abc.abstractmethod
 13 |   def available_actions(self):
 14 |     return []
 15 | 
 16 |   @abc.abstractmethod
 17 |   def reset(self):
 18 |     pass
 19 | 
 20 |   @abc.abstractmethod
 21 |   def observe(self):
 22 |     return []
 23 | 
 24 |   @abc.abstractmethod
 25 |   def act(self, action):
 26 |     """Returns (reward, is_terminal)."""
 27 |     return 0.0, True
 28 | 
 29 | 
 30 | class Catch(Environment):
 31 | 
 32 |   def __init__(self, width):
 33 |     self.width = width
 34 |     self.height = width
 35 |     self.reset()
 36 | 
 37 |   def available_actions(self):
 38 |     return [-1, 0, 1]
 39 | 
 40 |   def reset(self):
 41 |     self.ball_position = [0, random.randint(0, self.width - 1)]
 42 |     self.bat_position = random.randint(0, self.width - 1)
 43 | 
 44 |   def observe(self):
 45 |     return self._render(self.ball_position[0], self.ball_position[1],
 46 |                         self.bat_position)
 47 | 
 48 |   def _render(self, ball_h, ball_w, bat_w):
 49 |     screen = []
 50 |     for h in range(self.height):
 51 |       row = []
 52 |       for w in range(self.width):
 53 |         if h == ball_h and w == ball_w:
 54 |           row.append('*')
 55 |         elif w == bat_w and h == self.height - 1:
 56 |           row.append('_')
 57 |         else:
 58 |           row.append(' ')
 59 |       screen.append(row)
 60 |     return screen
 61 | 
 62 |   def act(self, action):
 63 |     self.bat_position = max(0, min(self.width - 1, self.bat_position + action))
 64 |     self.ball_position[0] += 1
 65 |     if self.ball_position[0] == self.height - 1:
 66 |       if self.bat_position == self.ball_position[1]:
 67 |         return 1, True
 68 |       else:
 69 |         return -1, True
 70 |     return 0, False
 71 | 
 72 | 
 73 | class QLearner(object):
 74 |   """Off-Policy TD Control, page 158."""
 75 | 
 76 |   def __init__(self, available_actions, learning_rate, discount_factor):
 77 |     # action_value_estimates
 78 |     self.Q = collections.defaultdict(lambda: collections.defaultdict(int))
 79 |     self.available_actions = available_actions
 80 |     self.learning_rate = learning_rate
 81 |     self.discount_factor = discount_factor
 82 | 
 83 |   def act(self, observation):
 84 |     q_action = self.Q[self._condense_observation(observation)]
 85 |     return functools.reduce(lambda a, b: a if q_action[a] > q_action[b] else b,
 86 |                             self.available_actions)
 87 | 
 88 |   def learn(self, observation_raw, action, new_observation_raw, reward):
 89 |     observation = self._condense_observation(observation_raw)
 90 |     new_observation = self._condense_observation(new_observation_raw)
 91 |     max_q = max([self.Q[new_observation][a] for a in self.available_actions])
 92 |     self.Q[observation][action] += self.learning_rate * \
 93 |         (reward + self.discount_factor * max_q - self.Q[observation][action])
 94 | 
 95 |   def _condense_observation(self, observation):
 96 |     return '\n'.join([''.join(row) for row in observation])
 97 | 
 98 |   def __str__(self):
 99 |     return 'QLearner'
100 | 
101 | num_episodes = 350
102 | num_repetitions = 200
103 | 
104 | plt.ion()
105 | x = range(num_episodes)
106 | rewards = [0 for _ in range(num_episodes)]
107 | fig = plt.figure()
108 | ax = fig.add_subplot(111)
109 | ax.set_ylim([-1.1, 1.1])
110 | line = ax.plot(x, rewards, '-')[0]
111 | 
112 | 
113 | for repetition in range(num_repetitions):
114 |   catch = Catch(5)
115 |   agent = QLearner(catch.available_actions(), learning_rate=0.1,
116 |                    discount_factor=0.9)
117 |   reward_ma = 0
118 |   for i in range(num_episodes):
119 |     terminal = False
120 |     catch.reset()
121 | 
122 |     while not terminal:
123 |       observation = catch.observe()
124 |       action = agent.act(observation)
125 |       reward, terminal = catch.act(action)
126 |       new_observation = catch.observe()
127 |       agent.learn(observation, action, new_observation, reward)
128 | 
129 |     reward_ma = 0.05 * reward + 0.95 * reward_ma
130 | 
131 |     rewards[i] += 1.0 / (repetition + 1) * (reward_ma - rewards[i])
132 |   if repetition % 3 == 0:
133 |     line.set_ydata(rewards)
134 |     fig.canvas.draw()
135 | 
136 | raw_input()
137 | 
--------------------------------------------------------------------------------
/lua/catch.lua:
--------------------------------------------------------------------------------
 1 | require 'torch'
 2 | 
 3 | local Catch = {}
 4 | Catch.__index = Catch
 5 | 
 6 | function Catch.create(width)
 7 |   local catch = {}
 8 |   setmetatable(catch, Catch)
 9 |   catch.width = width
10 |   catch.height = width
11 |   return catch
12 | end
13 | 
14 | -- actions
15 | function Catch:availableActions()
16 |   return {-1, 0, 1}
17 | end
18 | 
19 | function Catch:reset()
20 |   self.ballPosition = {0, math.floor(torch.uniform(0, self.width))}
21 |   self.batPosition = math.floor(torch.uniform(0, self.width))
22 | end
23 | 
24 | -- observations
25 | function Catch:observe()
26 |   return {self.ballPosition[1], self.ballPosition[2], self.batPosition}
27 | end
28 | 
29 | -- reward, isTerminal
30 | function Catch:act(action)
31 |   self.batPosition = math.max(0, math.min(self.width - 1, self.batPosition + action))
32 |   self.ballPosition[1] = self.ballPosition[1] + 1
33 |   if self.ballPosition[1] == self.height - 1 then
34 |     if self.batPosition == self.ballPosition[2] then
35 |       return 1.0, true
36 |     else
37 |       return -1.0, true
38 |     end
39 |   end
40 |   return 0.0, false
41 | end
42 | 
43 | return Catch
44 | 
--------------------------------------------------------------------------------
/lua/common.lua:
--------------------------------------------------------------------------------
 1 | function 
defaultdict(default_value_factory) 2 | local t = {} 3 | local metatable = {} 4 | metatable.__index = function(t, key) 5 | if not rawget(t, key) then 6 | rawset(t, key, default_value_factory(key)) 7 | end 8 | return rawget(t, key) 9 | end 10 | return setmetatable(t, metatable) 11 | end 12 | 13 | function reduce(f, xs) 14 | acc = xs[1] 15 | for i = 2, #xs do 16 | acc = f(acc, xs[i]) 17 | end 18 | return acc 19 | end 20 | 21 | function observationToString(observation) 22 | return table.concat(observation, "") 23 | end 24 | -------------------------------------------------------------------------------- /lua/easy21.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | 3 | local Easy21 = {} 4 | Easy21.__index = Easy21 5 | 6 | function Easy21.create(width) 7 | local easy21 = {} 8 | setmetatable(easy21, Easy21) 9 | return easy21 10 | end 11 | 12 | -- actions 13 | function Easy21:availableActions() 14 | return {'stick', 'hit'} 15 | end 16 | 17 | function Easy21:reset() 18 | -- start off with one black card each 19 | self.player = math.ceil(torch.uniform(0, 10)) 20 | self.dealer = math.ceil(torch.uniform(0, 10)) 21 | end 22 | 23 | -- observations 24 | function Easy21:observe() 25 | return {self.player, self.dealer} 26 | end 27 | 28 | function drawCard(initial_score) 29 | local card_value = math.ceil(torch.uniform(0, 10)) 30 | if torch.uniform() < 1 / 3 then 31 | -- red card, subtract 32 | return -1 * card_value 33 | else 34 | -- black card, add 35 | return 1 * card_value 36 | end 37 | end 38 | 39 | function isBust(score) 40 | return score < 1 or score > 21 41 | end 42 | 43 | -- reward, isTerminal 44 | function Easy21:act(action) 45 | if action == 'hit' then 46 | self.player = self.player + drawCard() 47 | if isBust(self.player) then 48 | return -1, true 49 | else 50 | return 0, false 51 | end 52 | elseif action == 'stick' then 53 | -- play out dealer. Dealer sticks on 17 or higher, hits otherwise. 54 | while self.dealer < 17 do 55 | self.dealer = self.dealer + drawCard() 56 | end 57 | if isBust(self.dealer) or self.player > self.dealer then 58 | return 1, true 59 | elseif self.dealer > self.player then 60 | return -1, true 61 | else 62 | return 0, true 63 | end 64 | else 65 | error('unknown action: ' .. 
action) 66 | end 67 | end 68 | 69 | return Easy21 70 | -------------------------------------------------------------------------------- /lua/monte_carlo.lua: -------------------------------------------------------------------------------- 1 | require 'common' 2 | 3 | local MonteCarlo = {} 4 | MonteCarlo.__index = MonteCarlo 5 | 6 | function MonteCarlo.create(availableActions, args) 7 | args = args or {} 8 | local learner = {} 9 | setmetatable(learner, MonteCarlo) 10 | learner.Q = defaultdict(function(_) return defaultdict(function(_) return 0 end) end) 11 | learner.availableActions = availableActions 12 | learner.n0 = args.n0 or 100 13 | learner.n_state = defaultdict(function(_) return 0 end) 14 | learner.n_action = defaultdict(function(_) return defaultdict(function(_) return 0 end) end) 15 | learner.visited = {} 16 | learner.reward = 0 17 | return learner 18 | end 19 | 20 | function MonteCarlo:act(observation) 21 | local state = observationToString(observation) 22 | self.n_state[state] = self.n_state[state] + 1 23 | local epsilon = self.n0 / (self.n0 + self.n_state[state]) 24 | 25 | if torch.uniform() < epsilon then 26 | -- pick random action 27 | return self.availableActions[math.floor( 28 | torch.uniform(1, #self.availableActions + 1))] 29 | else 30 | local q_act = self.Q[state] 31 | -- pick highest value action 32 | return reduce(function(a, b) return q_act[a] > q_act[b] and a or b end, 33 | self.availableActions) 34 | end 35 | end 36 | 37 | function MonteCarlo:learn(observationRaw, action, newObservationRaw, reward, terminal) 38 | self.visited[#self.visited + 1] = {observationRaw, action} 39 | self.reward = self.reward + reward 40 | if not terminal then 41 | -- Monte Carlo only updates at the end of the episode, towards the actual 42 | -- return. (In contrast to TD, which updates at every step towards the 43 | -- estimated return). 
44 | return 45 | end 46 | 47 | 48 | for i, visited in ipairs(self.visited) do 49 | local observationRaw, action = unpack(visited) 50 | local state = observationToString(observationRaw) 51 | self.n_action[state][action] = self.n_action[state][action] + 1 52 | local alpha = 1 / self.n_action[state][action] 53 | local q = self.Q[state][action] 54 | self.Q[state][action] = q + alpha * (self.reward - q) 55 | end 56 | 57 | self.visited = {} 58 | self.reward = 0 59 | end 60 | 61 | 62 | function MonteCarlo:visualizeQState() 63 | local q = torch.zeros(21, 21) 64 | for player_value = 1, 21 do 65 | for dealer_value = 1, 21 do 66 | local state = observationToString{player_value, dealer_value} 67 | local q_act = self.Q[state] 68 | local max_action = reduce( 69 | function(a, b) return q_act[a] > q_act[b] and a or b end, 70 | self.availableActions) 71 | q[player_value][dealer_value] = q_act[max_action] 72 | end 73 | end 74 | 75 | 76 | gnuplot.xlabel('Player') 77 | gnuplot.ylabel('Dealer') 78 | gnuplot.axis({1, 21, 1, 21}) 79 | gnuplot.figure(1) 80 | gnuplot.splot(q) 81 | end 82 | 83 | return MonteCarlo 84 | -------------------------------------------------------------------------------- /lua/nnlearner.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'torch' 3 | 4 | local NNLearner = {} 5 | NNLearner.__index = NNLearner 6 | 7 | function NNLearner.create(availableActions, args) 8 | local learner = {} 9 | setmetatable(learner, NNLearner) 10 | learner.availableActions = availableActions 11 | learner.learningRate = args.learningRate or 0.2 12 | learner.discountFactor = args.discountFactor or 0.9 13 | learner.explorationRate = args.explorationRate or 0.05 14 | 15 | learner.screen = torch.zeros(25) 16 | learner.criterion = nn.MSECriterion() 17 | learner.target = torch.zeros(3) 18 | 19 | learner.mlp = nn.Sequential(); -- make a multi-layer perceptron 20 | local inputs = 25 21 | local outputs = #availableActions 22 | local HUs = args.HUs or 20 23 | learner.mlp:add(nn.Linear(inputs, HUs)) 24 | learner.mlp:add(args.transfer or nn.ReLU()) 25 | for i = 1, args.hiddenLayers - 1 do 26 | learner.mlp:add(nn.Linear(HUs, HUs)) 27 | learner.mlp:add(args.transfer or nn.ReLU()) 28 | end 29 | learner.mlp:add(nn.Linear(HUs, outputs)) 30 | return learner 31 | end 32 | 33 | function NNLearner:q(observation) 34 | for i = 1, 25 do 35 | self.screen[i] = 0 36 | end 37 | self.screen[observation[1] * 5 + observation[2] + 1] = 1 38 | self.screen[21 + observation[3]] = -1 39 | return self.mlp:forward(self.screen) 40 | end 41 | 42 | local function maxAction(q, availableActions) 43 | local max_q = - math.huge 44 | local action_index = -1 45 | for i, a in pairs(availableActions) do 46 | if q[i] > max_q then 47 | max_q = q[i] 48 | action_index = i 49 | end 50 | end 51 | return max_q, action_index 52 | end 53 | 54 | function NNLearner:act(observation, rewardMa) 55 | if torch.uniform() < math.min(self.explorationRate, (1 - rewardMa) / 4) then 56 | return self.availableActions[math.floor( 57 | torch.uniform(1, #self.availableActions + 1))] 58 | end 59 | local _, action_index = maxAction(self:q(observation), self.availableActions) 60 | return self.availableActions[action_index] 61 | end 62 | 63 | function NNLearner:learn(observation, action, newObservation, reward, terminal) 64 | local q = self:q(newObservation) 65 | local max_q, _ = maxAction(q, self.availableActions) 66 | local action_index = -1 67 | for i, a in pairs(self.availableActions) do 68 | if a == action then 69 | 
action_index = i 70 | end 71 | end 72 | 73 | 74 | local pred = self:q(observation) 75 | for i = 1, pred:size(1) do 76 | self.target[i] = pred[i] + self.learningRate * (q[i] - pred[i]) 77 | end 78 | if terminal then 79 | self.target[action_index] = reward 80 | else 81 | self.target[action_index] = pred[action_index] + self.learningRate * (reward + self.discountFactor * max_q - pred[action_index]) 82 | end 83 | 84 | local err = self.criterion:forward(pred, self.target) 85 | local gradCriterion = self.criterion:backward(pred, self.target) 86 | 87 | -- train over this example in 3 steps 88 | -- (1) zero the accumulation of the gradients 89 | self.mlp:zeroGradParameters() 90 | -- (2) accumulate gradients 91 | self.mlp:backward(self.screen, gradCriterion) 92 | -- (3) update parameters with learning rate 93 | self.mlp:updateParameters(self.learningRate) 94 | end 95 | 96 | function NNLearner:visualize() 97 | local bat = 2 98 | for action = 1, 3 do 99 | local gradient = torch.zeros(5, 5) 100 | for ballH = 0, 4 do 101 | for ballW = 0, 4 do 102 | local q = self:q({ballH, ballW, bat}) 103 | gradient[ballH + 1][ballW + 1] = q[action] 104 | end 105 | end 106 | gnuplot.figure(action + 1) 107 | gnuplot.imagesc(gradient, 'color') 108 | end 109 | -- body 110 | end 111 | 112 | return NNLearner 113 | -------------------------------------------------------------------------------- /lua/qlearner.lua: -------------------------------------------------------------------------------- 1 | require 'common' 2 | 3 | local QLearner = {} 4 | QLearner.__index = QLearner 5 | 6 | function QLearner.create(availableActions, learningRate, discountFactor) 7 | local learner = {} 8 | setmetatable(learner, QLearner) 9 | learner.Q = defaultdict(function(_) return defaultdict(function(_) return 0 end) end) 10 | learner.availableActions = availableActions 11 | learner.learningRate = learningRate 12 | learner.discountFactor = discountFactor 13 | return learner 14 | end 15 | 16 | function QLearner:act(observation) 17 | q_action = self.Q[observationToString(observation)] 18 | return reduce(function(a, b) return q_action[a] > q_action[b] and a or b end, 19 | self.availableActions) 20 | end 21 | 22 | function QLearner:learn(observationRaw, action, newObservationRaw, reward, terminal) 23 | local max_q = - math.huge 24 | for i, a in pairs(self.availableActions) do 25 | local q = self.Q[observationToString(newObservationRaw)][a] 26 | if q > max_q then 27 | max_q = q 28 | end 29 | end 30 | local observation = observationToString(observationRaw) 31 | self.Q[observation][action] = (self.Q[observation][action] + self.learningRate * 32 | (reward + self.discountFactor * max_q - self.Q[observation][action])) 33 | end 34 | 35 | return QLearner 36 | -------------------------------------------------------------------------------- /lua/random.lua: -------------------------------------------------------------------------------- 1 | require 'common' 2 | 3 | local Random = {} 4 | Random.__index = Random 5 | 6 | function Random.create(availableActions) 7 | local learner = {} 8 | setmetatable(learner, Random) 9 | learner.availableActions = availableActions 10 | return learner 11 | end 12 | 13 | 14 | function Random:act(observation) 15 | return self.availableActions[math.floor( 16 | torch.uniform(1, #self.availableActions + 1))] 17 | end 18 | 19 | function Random:learn(observationRaw, action, newObservationRaw, reward, terminal) 20 | end 21 | 22 | return Random 23 | -------------------------------------------------------------------------------- 
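All of the Lua agents above — `qlearner.lua`, `nnlearner.lua`, `monte_carlo.lua` and `random.lua` — expose the same duck-typed contract that `testbed.lua` (next file) drives: a `create(availableActions, ...)` factory plus `act(observation)` and `learn(observation, action, newObservation, reward, terminal)` (the testbed also passes a reward moving average as a second argument to `act`, which only `nnlearner.lua` uses). As a rough sketch of that contract — the `FirstAction` name and behaviour below are made up for illustration and are not part of the repository — a do-nothing agent would look like this:

```lua
-- Hypothetical skeleton agent, only to illustrate the interface shared by the
-- learners above and driven by testbed.lua below. Not part of the repository.
local FirstAction = {}
FirstAction.__index = FirstAction

function FirstAction.create(availableActions)
  local learner = {}
  setmetatable(learner, FirstAction)
  learner.availableActions = availableActions
  return learner
end

function FirstAction:act(observation)
  -- Always pick the first available action; a real agent would inspect the observation.
  return self.availableActions[1]
end

function FirstAction:learn(observationRaw, action, newObservationRaw, reward, terminal)
  -- Nothing to update for a fixed policy.
end

return FirstAction
```

Such an agent could be wired into `testbed.lua` by adding an entry to its `agentDefinitions` table, just like the commented-out `QLearner` and `Random` entries there.

--------------------------------------------------------------------------------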
/lua/testbed.lua: -------------------------------------------------------------------------------- 1 | require 'gnuplot' 2 | require 'torch' 3 | 4 | -- Environments 5 | local catch = require 'catch' 6 | local easy21 = require 'easy21' 7 | 8 | -- Agents 9 | local nnlearner = require 'nnlearner' 10 | local randomlearner = require 'random' 11 | local qlearner = require 'qlearner' 12 | local monteCarlo = require 'monte_carlo' 13 | 14 | 15 | local numRepetitions = 50 16 | local numEpisodes = 20000000 17 | local agentDefinitions = { 18 | -- { 19 | -- name = 'QLearner', 20 | -- factory = function(availableActions) 21 | -- return qlearner.create(availableActions, 0.1, 0.9) 22 | -- end, 23 | -- }, 24 | -- { 25 | -- name = 'Random', 26 | -- factory = function(availableActions) 27 | -- return randomlearner.create(availableActions) 28 | -- end, 29 | -- }, 30 | { 31 | name = 'MonteCarlo', 32 | factory = function(availableActions) 33 | return monteCarlo.create(availableActions) 34 | end, 35 | }, 36 | -- { 37 | -- name = 'NNLearner 1', 38 | -- factory = function(availableActions) 39 | -- return nnlearner.create(availableActions, {hiddenLayers=1}) 40 | -- end, 41 | -- }, 42 | -- { 43 | -- name = 'NNLearner 2', 44 | -- factory = function(availableActions) 45 | -- return nnlearner.create(availableActions, {hiddenLayers=2}) 46 | -- end, 47 | -- }, 48 | } 49 | local rewards = torch.zeros(#agentDefinitions, numEpisodes) 50 | 51 | for repetition = 1, numRepetitions do 52 | -- local environment = catch.create(5) 53 | local environment = easy21.create() 54 | local agents = {} 55 | for _, definition in pairs(agentDefinitions) do 56 | agents[#agents + 1] = definition.factory(environment:availableActions()) 57 | end 58 | 59 | print(repetition) 60 | 61 | for agent_i = 1, #agents do 62 | local timer = torch.Timer() 63 | local agent = agents[agent_i] 64 | local reward_ma = 0 65 | for i = 1, numEpisodes do 66 | environment:reset() 67 | reward = 0.0 68 | terminal = false 69 | while not terminal do 70 | local observation = environment:observe() 71 | local action = agent:act(observation, reward_ma) 72 | reward, terminal = environment:act(action) 73 | local new_observation = environment:observe() 74 | agent:learn(observation, action, new_observation, reward, terminal) 75 | end 76 | 77 | reward_ma = 0.01 * reward + 0.99 * reward_ma 78 | 79 | -- iterative mean from http://www.heikohoffmann.de/htmlthesis/node134.html 80 | rewards[agent_i][i] = (rewards[agent_i][i] + 1.0 / repetition * 81 | (reward_ma - rewards[agent_i][i])) 82 | 83 | if i % 10000 == 0 then 84 | local name = agentDefinitions[agent_i].name 85 | print(name .. ': ' .. 
rewards[agent_i][i]) 86 | if name == 'MonteCarlo' then 87 | agent:visualizeQState() 88 | end 89 | end 90 | 91 | end 92 | 93 | if repetition % 1 == 0 then 94 | data = {} 95 | for i = 1, #agents do 96 | local reward = rewards[i] 97 | table.insert(data, { 98 | agentDefinitions[i]['name'], 99 | torch.range(1, numEpisodes), 100 | reward, 101 | '-', 102 | }) 103 | end 104 | gnuplot.axis({0, numEpisodes, -1, 1}) 105 | gnuplot.figure(0) 106 | gnuplot.plot(data) 107 | end 108 | 109 | end 110 | end 111 | -------------------------------------------------------------------------------- /nbandit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python2 2 | import abc 3 | import commands 4 | import functools 5 | import math 6 | import random 7 | import time 8 | 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | class Environment(object): 13 | __metaclass__ = abc.ABCMeta 14 | 15 | @abc.abstractmethod 16 | def available_actions(self): 17 | return [] 18 | 19 | @abc.abstractmethod 20 | def act(self, action): 21 | return 0.0 22 | 23 | 24 | class Bandit(Environment): 25 | 26 | def __init__(self, rewards, noise_std): 27 | self.rewards = rewards 28 | self.n = len(rewards) 29 | self.noise_std = noise_std 30 | 31 | def available_actions(self): 32 | return range(self.n) 33 | 34 | def act(self, action): 35 | return self.rewards[action] + random.gauss(0, self.noise_std) 36 | 37 | def __str__(self): 38 | return 'Bandit(%s, noise_std=%s)' % (self.rewards, self.noise_std) 39 | 40 | 41 | class GreedyAgent(object): 42 | 43 | def __init__(self, num_actions): 44 | # action_value_estimates 45 | self.Q = [0 for _ in range(num_actions)] 46 | self.action_count = [1 for _ in range(num_actions)] 47 | 48 | def act(self): 49 | return functools.reduce(lambda a, b: a if self.Q[a] > self.Q[b] else b, 50 | range(len(self.Q))) 51 | 52 | def learn(self, action, reward): 53 | self.Q[action] += 1.0 / self.action_count[action] * \ 54 | (reward - self.Q[action]) 55 | self.action_count[action] += 1 56 | 57 | def __str__(self): 58 | return 'GreedyAgent' 59 | 60 | 61 | class EpsilonGreedyAgent(GreedyAgent): 62 | 63 | def __init__(self, num_actions, epsilon): 64 | super(EpsilonGreedyAgent, self).__init__(num_actions) 65 | self.epsilon = epsilon 66 | self.num_actions = num_actions 67 | 68 | def act(self): 69 | if random.random() < self.epsilon: 70 | return random.randint(0, self.num_actions - 1) 71 | else: 72 | return super(EpsilonGreedyAgent, self).act() 73 | 74 | def __str__(self): 75 | return 'EpsilonGreedyAgent(epsilon=%f)' % self.epsilon 76 | 77 | 78 | class SoftmaxGreedyAgent(GreedyAgent): 79 | 80 | def __init__(self, num_actions, temperature): 81 | super(SoftmaxGreedyAgent, self).__init__(num_actions) 82 | self.temperature = temperature 83 | self.num_actions = num_actions 84 | 85 | def act(self): 86 | base = sum([math.exp(q / self.temperature) for q in self.Q]) 87 | r = random.random() 88 | for i, q in enumerate(self.Q): 89 | p = math.exp(q / self.temperature) / base 90 | if r < p: 91 | return i 92 | r -= p 93 | 94 | def __str__(self): 95 | return 'SoftmaxGreedyAgent(temperature=%f)' % self.temperature 96 | 97 | 98 | class Testbed(object): 99 | 100 | def __init__(self, environment_fn, agent_fns): 101 | self.environment_fn = environment_fn 102 | self.agent_fns = agent_fns 103 | 104 | def evaluate(self, num_episodes, episode_length, show_plot=False, 105 | save_animation=False): 106 | if show_plot or save_animation: 107 | plt.ion() 108 | x = range(episode_length) 109 | self.rewards = [ 110 | 
[0 for _ in range(episode_length)] for _ in self.agent_fns] 111 | self.fig = plt.figure() 112 | ax = self.fig.add_subplot(111) 113 | ax.set_ylim([0, 1.6]) 114 | 115 | # Create some fake agents for the labels. 116 | environment = self.environment_fn() 117 | agents = [f(len(environment.available_actions())) 118 | for f in self.agent_fns] 119 | self.lines = [ax.plot(x, rewards, '-', label=str(agent))[0] 120 | for rewards, agent in zip(self.rewards, agents)] 121 | ax.legend(loc=8) 122 | 123 | start = time.time() 124 | for episode in xrange(num_episodes): 125 | environment = self.environment_fn() 126 | agents = [f(len(environment.available_actions())) 127 | for f in self.agent_fns] 128 | for step in xrange(episode_length): 129 | for i, agent in enumerate(agents): 130 | action = agent.act() 131 | reward = environment.act(action) 132 | agent.learn(action, reward) 133 | 134 | if show_plot or save_animation: 135 | self.rewards[i][step] += ((1.0 / (episode + 1)) * 136 | (reward - self.rewards[i][step])) 137 | 138 | if (show_plot or save_animation) and episode % 20 == 0: 139 | for line, rewards in zip(self.lines, self.rewards): 140 | line.set_ydata(rewards) 141 | self.fig.canvas.draw() 142 | 143 | if save_animation: 144 | self.fig.savefig('episode%04d.png' % episode) 145 | 146 | if episode and episode % 100 == 0: 147 | end_rewards = ', '.join( 148 | ['%s: %0.2f' % (a, r[-1]) for a, r in zip(agents, self.rewards)]) 149 | eps = episode / (time.time() - start) 150 | print 'Episode %d: %s (%.2f e/s)' % (episode, end_rewards, eps) 151 | 152 | if save_animation: 153 | print 'Creating gif...' 154 | commands.getoutput( 155 | 'convert -delay 10 -loop 0 -colors 15 -quality 50% -resize 50% ' 156 | 'episode*.png training.gif') 157 | commands.getoutput('rm episode*.png') 158 | 159 | testbed = Testbed( 160 | environment_fn=lambda: Bandit([random.gauss(0, 1) for _ in range(10)], 1), 161 | agent_fns=[ 162 | lambda n: GreedyAgent(n), 163 | lambda n: EpsilonGreedyAgent(n, 0.1), 164 | lambda n: SoftmaxGreedyAgent(n, 0.2), 165 | ]) 166 | testbed.evaluate( 167 | num_episodes=2000, episode_length=1000, show_plot=True) 168 | 169 | print 'Done, press enter to exit.' 170 | raw_input() 171 | -------------------------------------------------------------------------------- /screenshots/catch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mononofu/reinforcement-learning/8ea3b189e8ffad2158e8fd991db8cfd5748d281d/screenshots/catch.png -------------------------------------------------------------------------------- /screenshots/nbandit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mononofu/reinforcement-learning/8ea3b189e8ffad2158e8fd991db8cfd5748d281d/screenshots/nbandit.png --------------------------------------------------------------------------------
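For reference, the table-based agent mentioned in the README (`catch.py`'s `QLearner` and `lua/qlearner.lua`) uses the standard Q-learning update, Q(s, a) += alpha * (reward + gamma * max_a' Q(s', a') - Q(s, a)), with alpha = 0.1 and gamma = 0.9. A small worked example of a single update — the numbers are hypothetical and not the output of any script above:

```lua
-- One Q-learning update step, as performed in lua/qlearner.lua and catch.py.
-- All numbers below are made up for illustration.
local alpha, gamma = 0.1, 0.9   -- learning rate and discount factor used in the repo
local q_sa = 0.0                -- current estimate Q(s, a)
local reward = 1.0              -- e.g. the bat caught the ball
local max_q_next = 0.0          -- max over a' of Q(s', a'); 0 here since s' is unvisited

local td_target = reward + gamma * max_q_next   -- 1.0
q_sa = q_sa + alpha * (td_target - q_sa)        -- 0.0 + 0.1 * (1.0 - 0.0) = 0.1
print(q_sa)                                     -- prints 0.1
```

Repeated over many episodes, these small corrections toward the TD target are what produce the table-based learning curve shown in `screenshots/catch.png`.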