├── .gitignore ├── images ├── nn_sarsa_rms_vs_lambda.png ├── nn_sarsa_rms_vs_iteration.jpg ├── nn_sarsa_rms_vs_iteration.png ├── table_sarsa_rms_vs_lambda.png ├── table_sarsa_rms_vs_iteration.png ├── q3a.eps ├── q4a.eps └── q5a.eps ├── util ├── io_util.lua ├── tensorutil.lua ├── mdputil.lua ├── util_for_unittests.lua └── util.lua ├── rl_constants.lua ├── Policy.lua ├── TableSarsaFactory.lua ├── Explorer.lua ├── TestPolicy.lua ├── SarsaFactory.lua ├── AllActionsEqualPolicy.lua ├── test ├── unittest_TestPolicy.lua ├── unittest_io_util.lua ├── unittest_TableSarsaFactory.lua ├── unittest_TestMdp.lua ├── unittest_TestSAFE.lua ├── unittest_LinSarsaFactory.lua ├── unittest_AllActionsEqualPolicy.lua ├── unittest_VHash.lua ├── unittest_NNSarsaFactory.lua ├── unittest_tensorutil.lua ├── unittest_GreedyPolicy.lua ├── unittest_QHash.lua ├── unittest_TestMdpQVAnalyzer.lua ├── unittest_LinSarsa.lua ├── unittest_EpisodeBuilder.lua ├── unittest_MonteCarloControl.lua ├── unittest_NNSarsa.lua ├── unittest_QLin.lua ├── unittest_QNN.lua ├── unittest_MdpSampler.lua ├── unittest_util.lua └── unittest_TableSarsa.lua ├── ControlFactory.lua ├── ConstExplorer.lua ├── TestSAFE.lua ├── VFunc.lua ├── DecayTableExplorer.lua ├── CMakeLists.txt ├── ValueIteration.lua ├── MdpConfig.lua ├── doc ├── montecarlo.md ├── valuefunctions.md ├── policy.md ├── mdp.md ├── sarsa.md └── index.md ├── VHash.lua ├── GreedyPolicy.lua ├── Control.lua ├── SAFeatureExtractor.lua ├── QFunc.lua ├── Evaluator.lua ├── QHash.lua ├── run_tests.lua ├── QLin.lua ├── NNSarsaFactory.lua ├── LinSarsaFactory.lua ├── TestMdp.lua ├── QVAnalyzer.lua ├── MdpSampler.lua ├── Mdp.lua ├── rl.lua ├── MonteCarloControl.lua ├── QApprox.lua ├── LinSarsa.lua ├── EpisodeBuilder.lua ├── NNSarsa.lua ├── TestMdpQVAnalyzer.lua ├── Sarsa.lua ├── TableSarsa.lua ├── rocks └── rl-0.1-1.rockspec ├── rl-0.2-5.rockspec ├── QNN.lua ├── README.md └── SarsaAnalyzer.lua /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.rock 3 | -------------------------------------------------------------------------------- /images/nn_sarsa_rms_vs_lambda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vitchyr/torch-rl/HEAD/images/nn_sarsa_rms_vs_lambda.png -------------------------------------------------------------------------------- /images/nn_sarsa_rms_vs_iteration.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vitchyr/torch-rl/HEAD/images/nn_sarsa_rms_vs_iteration.jpg -------------------------------------------------------------------------------- /images/nn_sarsa_rms_vs_iteration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vitchyr/torch-rl/HEAD/images/nn_sarsa_rms_vs_iteration.png -------------------------------------------------------------------------------- /images/table_sarsa_rms_vs_lambda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vitchyr/torch-rl/HEAD/images/table_sarsa_rms_vs_lambda.png -------------------------------------------------------------------------------- /images/table_sarsa_rms_vs_iteration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vitchyr/torch-rl/HEAD/images/table_sarsa_rms_vs_iteration.png 
-------------------------------------------------------------------------------- /util/io_util.lua: -------------------------------------------------------------------------------- 1 | function rl.util.save_q(fname, q) 2 | torch.save(fname, q) 3 | end 4 | 5 | function rl.util.load_q(fname) 6 | return torch.load(fname) 7 | end 8 | -------------------------------------------------------------------------------- /rl_constants.lua: -------------------------------------------------------------------------------- 1 | rl.EVALUATOR_DEFAULT_N_ITERS = 1000 2 | rl.MONTECARLOCONTROL_DEFAULT_N0 = 100 3 | 4 | -- Tolerance for floating point comparison 5 | rl.FLOAT_EPS = 1e-10 6 | -------------------------------------------------------------------------------- /Policy.lua: -------------------------------------------------------------------------------- 1 | local Policy = torch.class('rl.Policy') 2 | 3 | function Policy:__init() 4 | end 5 | 6 | --- Return an action given a state 7 | function Policy:get_action(s) 8 | error('Policy must implement get_action') 9 | end 10 | -------------------------------------------------------------------------------- /TableSarsaFactory.lua: -------------------------------------------------------------------------------- 1 | local TableSarsaFactory, parent = 2 | torch.class('rl.TableSarsaFactory', 'rl.SarsaFactory') 3 | 4 | function TableSarsaFactory:get_control() 5 | return rl.TableSarsa( 6 | self.mdp_config, 7 | self.lambda) 8 | end 9 | 10 | -------------------------------------------------------------------------------- /Explorer.lua: -------------------------------------------------------------------------------- 1 | local Explorer = torch.class('rl.Explorer') 2 | 3 | function Explorer:__init() 4 | end 5 | 6 | --- Return epsilon, the probability of exploring randomly, given a state 7 | function Explorer:get_eps(s) 8 | error('Explorer must implement get_eps') 9 | end 10 | -------------------------------------------------------------------------------- /TestPolicy.lua: -------------------------------------------------------------------------------- 1 | local TestPolicy, parent = torch.class('rl.TestPolicy', 'rl.Policy') 2 | 3 | function TestPolicy:__init(action) 4 | parent.__init(self) 5 | self.action = action 6 | end 7 | 8 | function TestPolicy:get_action(s) 9 | return self.action 10 | end 11 | -------------------------------------------------------------------------------- /SarsaFactory.lua: -------------------------------------------------------------------------------- 1 | local SarsaFactory, parent = torch.class('rl.SarsaFactory', 'rl.ControlFactory') 2 | 3 | function SarsaFactory:__init(mdp_config, lambda) 4 | parent.__init(self, mdp_config) 5 | self.lambda = lambda 6 | end 7 | 8 | function SarsaFactory:set_lambda(lambda) 9 | self.lambda = lambda 10 | end 11 | -------------------------------------------------------------------------------- /AllActionsEqualPolicy.lua: -------------------------------------------------------------------------------- 1 | local AAEP, parent = torch.class('rl.AllActionsEqualPolicy', 'rl.Policy') 2 | 3 | function AAEP:__init(mdp) 4 | parent:__init(self) 5 | self.mdp = mdp 6 | end 7 | 8 | function AAEP:get_action(s) 9 | actions = self.mdp:get_all_actions() 10 | return actions[math.random(1, #actions)] 11 | end 12 | 13 | -------------------------------------------------------------------------------- /test/unittest_TestPolicy.lua: -------------------------------------------------------------------------------- 1 | require 'rl' 2 | local 
tester = torch.Tester() 3 | 4 | local TestTestPolicy = {} 5 | function TestTestPolicy.test_correct_action() 6 | local a = 2 7 | local policy = rl.TestPolicy(a) 8 | tester:asserteq(a, policy:get_action(3)) 9 | end 10 | 11 | tester:add(TestTestPolicy) 12 | 13 | tester:run() 14 | 15 | -------------------------------------------------------------------------------- /test/unittest_io_util.lua: -------------------------------------------------------------------------------- 1 | require 'rl' 2 | 3 | local tester = torch.Tester() 4 | 5 | local TestIOUtil = {} 6 | function TestIOUtil.test_save_load() 7 | local q = rl.QHash(rl.TestMdp()) 8 | rl.util.save_q('/tmp/q', q) 9 | local q2 = rl.util.load_q('/tmp/q') 10 | tester:asserteq(q, q2) 11 | end 12 | tester:add(TestIOUtil) 13 | 14 | tester:run() 15 | -------------------------------------------------------------------------------- /ControlFactory.lua: -------------------------------------------------------------------------------- 1 | local ControlFactory = torch.class('rl.ControlFactory') 2 | 3 | function ControlFactory:__init(mdp_config) 4 | self.mdp_config = mdp_config 5 | end 6 | 7 | function ControlFactory:set_mdp_config(mdp_config) 8 | self.mdp_config = mdp_config 9 | end 10 | 11 | function ControlFactory:get_control() 12 | error('Must implement get_control') 13 | end 14 | -------------------------------------------------------------------------------- /util/tensorutil.lua: -------------------------------------------------------------------------------- 1 | -- Thanks to http://stackoverflow.com/questions/34123291/torch-apply-function-over-dimension 2 | function rl.util.apply_to_slices(tensor, dimension, func, ...) 3 | for i, slice in ipairs(tensor:split(1, dimension)) do 4 | func(slice, i, ...) 5 | end 6 | return tensor 7 | end 8 | 9 | function rl.util.fill_range(tensor, i, offset) 10 | tensor:fill(i + offset) 11 | end 12 | 13 | -------------------------------------------------------------------------------- /ConstExplorer.lua: -------------------------------------------------------------------------------- 1 | --- Explore with a fixed probability 2 | local ConstExplorer = torch.class('rl.ConstExplorer', 'rl.Explorer') 3 | 4 | function ConstExplorer:__init(p) 5 | self.p = p 6 | end 7 | 8 | function ConstExplorer:get_eps(s) 9 | return self.p 10 | end 11 | 12 | function ConstExplorer:__eq(other) 13 | return torch.typename(self) == torch.typename(other) 14 | and self.p == other.p 15 | end 16 | -------------------------------------------------------------------------------- /TestSAFE.lua: -------------------------------------------------------------------------------- 1 | local TestSAFE, parent = torch.class('rl.TestSAFE', 'rl.SAFeatureExtractor') 2 | 3 | -- simple feature extractor that returns the sum and difference of s and a for 4 | -- TestMdp 5 | function TestSAFE:get_sa_features(s, a) 6 | return torch.Tensor{s+a, s-a} 7 | end 8 | 9 | function TestSAFE:get_sa_features_dim() 10 | return 2 11 | end 12 | 13 | function TestSAFE:get_sa_num_features() 14 | return 2 15 | end 16 | -------------------------------------------------------------------------------- /VFunc.lua: -------------------------------------------------------------------------------- 1 | local VFunc = torch.class('rl.VFunc') 2 | 3 | function VFunc:__init(mdp) 4 | self.mdp = mdp 5 | end 6 | 7 | function VFunc:get_value(s) 8 | error('Must implement get_value method') 9 | end 10 | 11 | function VFunc:__eq(other) 12 | for _, s in pairs(self.mdp:get_all_states()) do 13 | if
self:get_value(s) ~= other:get_value(s) then 14 | return false 15 | end 16 | end 17 | return true 18 | end 19 | -------------------------------------------------------------------------------- /DecayTableExplorer.lua: -------------------------------------------------------------------------------- 1 | --- Choose epsilon to be N0 / N0 + (# times visited a state) 2 | local DecayTableExplorer, parent = 3 | torch.class('rl.DecayTableExplorer', 'rl.Explorer') 4 | 5 | function DecayTableExplorer:__init(N0, state_table) 6 | parent.__init(self) 7 | self.N0 = N0 8 | self.state_table = state_table 9 | end 10 | 11 | function DecayTableExplorer:get_eps(s) 12 | return self.N0 / (self.N0 + self.state_table:get_value(s)) 13 | end 14 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR) 2 | CMAKE_POLICY(VERSION 2.6) 3 | IF(LUAROCKS_PREFIX) 4 | MESSAGE(STATUS "Installing Torch through Luarocks") 5 | STRING(REGEX REPLACE "(.*)lib/luarocks/rocks.*" "\\1" CMAKE_INSTALL_PREFIX "${LUAROCKS_PREFIX}") 6 | MESSAGE(STATUS "Prefix inferred from Luarocks: ${CMAKE_INSTALL_PREFIX}") 7 | ENDIF() 8 | FIND_PACKAGE(Torch REQUIRED) 9 | 10 | FILE(GLOB luasrc *.lua) 11 | ADD_TORCH_PACKAGE(rl "" "${luasrc}" "Reinforcement Learning Package") 12 | 13 | -------------------------------------------------------------------------------- /ValueIteration.lua: -------------------------------------------------------------------------------- 1 | local ValueIteration = torch.class('rl.ValueIteration', 'rl.Control') 2 | 3 | -- ValueIteration captures an algorithm that optimizes a policy by alternating 4 | -- between policy evaluation for one step and policy iteration for one step 5 | function ValueIteration:improve_policy() 6 | self:optimize_policy() 7 | self:evaluate_policy() 8 | end 9 | 10 | function ValueIteration:optimize_policy() 11 | error('Must implement optimize_policy.') 12 | end 13 | 14 | function ValueIteration:evaluate_policy() 15 | error('Must implement evaluate_policy.') 16 | end 17 | -------------------------------------------------------------------------------- /test/unittest_TableSarsaFactory.lua: -------------------------------------------------------------------------------- 1 | require 'rl' 2 | local tester = torch.Tester() 3 | 4 | local mdp = rl.TestMdp() 5 | local discount_factor = 0.12 6 | local TestTableSarsaFactory = {} 7 | function TestTableSarsaFactory.test_get_control() 8 | local mdp_config = rl.MdpConfig(mdp, discount_factor) 9 | local lambda = 0.126 10 | 11 | local table_sarsa = rl.TableSarsa(mdp_config, lambda) 12 | local factory = rl.TableSarsaFactory(mdp_config, lambda) 13 | tester:assert(factory:get_control() == table_sarsa) 14 | end 15 | 16 | tester:add(TestTableSarsaFactory) 17 | 18 | tester:run() 19 | 20 | -------------------------------------------------------------------------------- /MdpConfig.lua: -------------------------------------------------------------------------------- 1 | local MdpConfig = torch.class('rl.MdpConfig') 2 | 3 | function MdpConfig:__init(mdp, discount_factor) 4 | self.mdp = mdp 5 | if discount_factor < 0 or discount_factor > 1 then 6 | error('Gamma must be between 0 and 1, inclusive.') 7 | end 8 | self.discount_factor = discount_factor 9 | end 10 | 11 | function MdpConfig:get_mdp() 12 | return self.mdp 13 | end 14 | 15 | function MdpConfig:get_discount_factor() 16 | return self.discount_factor 17 | end 
18 | 19 | function MdpConfig:get_description() 20 | return self.mdp:get_description() .. "-Gamma="..self.discount_factor 21 | end 22 | -------------------------------------------------------------------------------- /doc/montecarlo.md: -------------------------------------------------------------------------------- 1 | ## Monte Carlo Control 2 | Monte Carlo (MC) estimates the value of a state-action pair under a given 3 | policy by sampling and taking the average. MC Control alternates between this 4 | Q-function estimation and epsilon-greedy policy improvement. 5 | 6 | Example use: 7 | 8 | ```lua 9 | local mdp = TestMdp() 10 | local discount_factor = 0.9 11 | local n_iters = 1000 12 | 13 | local mdp_config = MdpConfig(mdp, discount_factor) 14 | local mc = MonteCarloControl(mdp_config) 15 | mc:improve_policy_for_n_iters(n_iters) 16 | 17 | local policy = mc:get_policy() 18 | local learned_action = policy:get_action(mdp:get_start_state()) 19 | ``` 20 | 21 | -------------------------------------------------------------------------------- /test/unittest_TestMdp.lua: -------------------------------------------------------------------------------- 1 | require 'rl' 2 | local tester = torch.Tester() 3 | 4 | local TestTestMdp = {} 5 | function TestTestMdp.test_terminates() 6 | local mdp = rl.TestMdp() 7 | local s = mdp:get_start_state() 8 | s = mdp:step(s, 1) 9 | s = mdp:step(s, 1) 10 | tester:assert(mdp:is_terminal(s)) 11 | end 12 | 13 | function TestTestMdp.test_reward() 14 | local mdp = rl.TestMdp() 15 | local old_s = 1 16 | local a = 1 17 | local new_s, r = mdp:step(old_s, a) 18 | tester:asserteq(r, -1) 19 | 20 | a = 3 21 | local _, r = mdp:step(new_s, a) 22 | tester:asserteq(r, 1) 23 | end 24 | 25 | tester:add(TestTestMdp) 26 | 27 | tester:run() 28 | 29 | -------------------------------------------------------------------------------- /VHash.lua: -------------------------------------------------------------------------------- 1 | -- A slow implementation of a state value function using hashes and tables 2 | 3 | local VHash, parent = torch.class('rl.VHash', 'rl.VFunc') 4 | 5 | function VHash:__init(mdp) 6 | parent.__init(self, mdp) 7 | self.v_table = rl.util.get_all_states_map(mdp) 8 | self.hs = function (s) return mdp:hash_s(s) end 9 | end 10 | 11 | function VHash:get_value(s) 12 | return self.v_table[self.hs(s)] 13 | end 14 | 15 | function VHash:mult(s, value) 16 | self.v_table[self.hs(s)] = self.v_table[self.hs(s)] * value 17 | end 18 | 19 | function VHash:add(s, delta) 20 | self.v_table[self.hs(s)] = self.v_table[self.hs(s)] + delta 21 | end 22 | 23 | VHash.__eq = parent.__eq 24 | -------------------------------------------------------------------------------- /util/mdputil.lua: -------------------------------------------------------------------------------- 1 | -- Get a table, where all the state keys have been initialized to 0 2 | function rl.util.get_all_states_map(mdp) 3 | local all_states_map = {} 4 | for k, s in pairs(mdp:get_all_states()) do 5 | all_states_map[mdp:hash_s(s)] = 0 6 | end 7 | return all_states_map 8 | end 9 | 10 | function rl.util.get_all_states_action_map(mdp) 11 | local all_states_actions_map = {} 12 | for k, s in pairs(mdp:get_all_states()) do 13 | local actions_map = {} 14 | for _, a in pairs(mdp:get_all_actions()) do 15 | actions_map[mdp:hash_a(a)] = 0 16 | end 17 | all_states_actions_map[mdp:hash_s(s)] = actions_map 18 | end 19 | return all_states_actions_map 20 | end 21 | -------------------------------------------------------------------------------- /GreedyPolicy.lua:
-------------------------------------------------------------------------------- 1 | local GreedyPolicy, parent = torch.class('rl.GreedyPolicy', 'rl.Policy') 2 | 3 | function GreedyPolicy:__init(q, exploration_strat, actions) 4 | parent.__init(self) 5 | self.q = q 6 | self.exploration_strat = exploration_strat 7 | self.actions = actions 8 | self.n_actions = #actions 9 | end 10 | 11 | function GreedyPolicy:get_action(s) 12 | local eps = self.exploration_strat:get_eps(s) 13 | 14 | pi = {} 15 | for k, a in pairs(self.actions) do 16 | pi[a] = eps / self.n_actions 17 | end 18 | 19 | -- Add 1-eps to best action 20 | best_a = self.q:get_best_action(s) 21 | pi[best_a] = pi[best_a] + 1 - eps 22 | 23 | return rl.util.weighted_random_choice(pi) 24 | end 25 | -------------------------------------------------------------------------------- /Control.lua: -------------------------------------------------------------------------------- 1 | local Control = torch.class("rl.Control") 2 | 3 | -- Control captures an algorithm that optimizes a policy for a given MDP. 4 | function Control:__init(mdp_config) 5 | self.mdp = mdp_config:get_mdp() 6 | self.policy = rl.AllActionsEqualPolicy(self.mdp) 7 | self.sampler = rl.MdpSampler(mdp_config) 8 | end 9 | 10 | function Control:improve_policy_for_n_iters(n_iters) 11 | for i = 1, n_iters do 12 | self:improve_policy() 13 | end 14 | end 15 | 16 | function Control:improve_policy() 17 | error('Must implement improve_policy') 18 | end 19 | 20 | function Control:set_policy(policy) 21 | self.policy = policy 22 | end 23 | 24 | function Control:get_policy() 25 | return self.policy 26 | end 27 | 28 | -------------------------------------------------------------------------------- /SAFeatureExtractor.lua: -------------------------------------------------------------------------------- 1 | -- Feature extractor for a state-action pair. 
2 | local SAFeatureExtractor = torch.class('rl.SAFeatureExtractor') 3 | 4 | function SAFeatureExtractor:__init() 5 | end 6 | 7 | -- returns a Tensor that represents the features of a given state-action pair 8 | function SAFeatureExtractor:get_sa_features(s, a) 9 | error('Must implement get_sa_features') 10 | end 11 | 12 | -- returns the dimension of the output of get_sa_features 13 | function SAFeatureExtractor:get_sa_features_dim() 14 | error('Must implement get_sa_features_dim') 15 | end 16 | 17 | -- returns the num of elements in the tensor returned by get_sa_features 18 | function SAFeatureExtractor:get_sa_num_features() 19 | error('Must implement get_sa_num_features') 20 | end 21 | -------------------------------------------------------------------------------- /QFunc.lua: -------------------------------------------------------------------------------- 1 | local QFunc = torch.class('rl.QFunc') 2 | 3 | function QFunc:__init(mdp) 4 | self.mdp = mdp 5 | end 6 | 7 | function QFunc:get_value(s, a) 8 | error('Must implement get_value method') 9 | end 10 | 11 | function QFunc:get_best_action(s) 12 | local actions = self.mdp:get_all_actions() 13 | -- use get_value so this works for any subclass (table-based or approximate) 14 | local best_a, best_i = rl.util.max( 15 | actions, 16 | function (a) return self:get_value(s, a) end) 17 | return best_a 18 | end 19 | 20 | function QFunc:__eq(other) 21 | for _, s in pairs(self.mdp:get_all_states()) do 22 | for _, a in pairs(self.mdp:get_all_actions()) do 23 | if self:get_value(s, a) ~= other:get_value(s, a) then 24 | return false 25 | end 26 | end 27 | end 28 | return true 29 | end 30 | -------------------------------------------------------------------------------- /Evaluator.lua: -------------------------------------------------------------------------------- 1 | local Evaluator = torch.class('rl.Evaluator') 2 | 3 | function Evaluator.__init(self, mdp_config) 4 | self.sampler = rl.MdpSampler(mdp_config) 5 | self.mdp_description = mdp_config:get_description() 6 | end 7 | 8 | function Evaluator:get_policy_avg_return(policy, n_iters) -- returns the total reward over n_iters episodes (callers divide by n_iters) 9 | local total_r = 0 10 | for i = 1, n_iters do 11 | total_r = total_r + self.sampler:sample_total_reward(policy) 12 | end 13 | return total_r 14 | end 15 | 16 | function Evaluator:display_metrics(policy, description, n_iters) 17 | n_iters = n_iters or rl.EVALUATOR_DEFAULT_N_ITERS 18 | local total_r = self:get_policy_avg_return(policy, n_iters) 19 | print('Avg Reward for <' .. description .. '> policy for ' .. 20 | self.mdp_description .. ': ' .. 21 | total_r .. '/' .. n_iters .. ' = ' .. total_r/n_iters) 22 | end 23 | -------------------------------------------------------------------------------- /QHash.lua: -------------------------------------------------------------------------------- 1 | -- A simple implementation of a state-action value function using hashes and 2 | -- tables 3 | local QHash, parent = torch.class('rl.QHash', 'rl.QFunc') 4 | 5 | function QHash:__init(mdp) 6 | parent.__init(self, mdp) 7 | self.hs = function (s) return mdp:hash_s(s) end 8 | self.ha = function (a) return mdp:hash_a(a) end 9 | self.q_table = rl.util.get_all_states_action_map(mdp) 10 | end 11 | 12 | function QHash:get_value(s, a) 13 | return self.q_table[self.hs(s)][self.ha(a)] 14 | end 15 | 16 | function QHash:mult(s, a, value) 17 | self.q_table[self.hs(s)][self.ha(a)] = self.q_table[self.hs(s)][self.ha(a)] * value 18 | end 19 | 20 | function QHash:add(s, a, delta) 21 | self.q_table[self.hs(s)][self.ha(a)] = self.q_table[self.hs(s)][self.ha(a)] + delta 22 | end 23 | 24 | -- Weird.
For some reason this is needed 25 | QHash.__eq = parent.__eq 26 | -------------------------------------------------------------------------------- /run_tests.lua: -------------------------------------------------------------------------------- 1 | require 'test.unittest_util' 2 | require 'test.unittest_tensorutil' 3 | require 'test.unittest_io_util' 4 | 5 | require 'test.unittest_MdpSampler' 6 | require 'test.unittest_EpisodeBuilder' 7 | 8 | require 'test.unittest_AllActionsEqualPolicy' 9 | require 'test.unittest_GreedyPolicy' 10 | 11 | require 'test.unittest_QHash' 12 | require 'test.unittest_VHash' 13 | require 'test.unittest_MonteCarloControl' 14 | require 'test.unittest_TableSarsa' 15 | require 'test.unittest_TableSarsaFactory' 16 | 17 | require 'test.unittest_QLin' 18 | require 'test.unittest_LinSarsa' 19 | require 'test.unittest_LinSarsaFactory' 20 | 21 | require 'test.unittest_QNN' 22 | require 'test.unittest_NNSarsa' 23 | require 'test.unittest_NNSarsaFactory' 24 | 25 | require 'test.unittest_TestMdp' 26 | require 'test.unittest_TestPolicy' 27 | require 'test.unittest_TestSAFE' 28 | require 'test.unittest_TestMdpQVAnalyzer' 29 | -------------------------------------------------------------------------------- /test/unittest_TestSAFE.lua: -------------------------------------------------------------------------------- 1 | require 'rl' 2 | local tester = torch.Tester() 3 | 4 | local TestTestSAFE = {} 5 | local fe = rl.TestSAFE() 6 | function TestTestSAFE.test_get_sa_features() 7 | local s = 5 8 | local a = 2 9 | local expected = torch.Tensor{7, 3} 10 | tester:assertTensorEq(fe:get_sa_features(s, a), expected, 0) 11 | end 12 | 13 | function TestTestSAFE.test_get_sa_features_dim() 14 | local s = 5 15 | local a = 2 16 | local blank = torch.Tensor(fe:get_sa_features_dim()) 17 | local features = fe:get_sa_features(s, a) 18 | tester:assert(rl.util.are_tensors_same_shape(blank, features)) 19 | end 20 | 21 | function TestTestSAFE.test_get_sa_num_features() 22 | local s = 5 23 | local a = 2 24 | local features = fe:get_sa_features(s, a) 25 | tester:asserteq(fe:get_sa_num_features(), features:numel()) 26 | end 27 | 28 | tester:add(TestTestSAFE) 29 | 30 | tester:run() 31 | 32 | -------------------------------------------------------------------------------- /QLin.lua: -------------------------------------------------------------------------------- 1 | -- Implementation of a state-action value function approx using linear function 2 | -- of features 3 | local QLin, parent = torch.class('rl.QLin', 'rl.QApprox') 4 | 5 | function QLin:__init(mdp, feature_extractor) 6 | parent.__init(self, mdp, feature_extractor) 7 | self.weights = torch.zeros(feature_extractor:get_sa_features_dim()) 8 | end 9 | 10 | function QLin:clear() 11 | self.weights = torch.zeros(self.feature_extractor:get_sa_features_dim()) 12 | end 13 | 14 | function QLin:get_value(s, a) 15 | return self.weights:dot(self.feature_extractor:get_sa_features(s, a)) 16 | end 17 | 18 | function QLin:add(d_weights) 19 | self.weights = self.weights + d_weights 20 | end 21 | 22 | function QLin:mult(factor) 23 | self.weights = self.weights * factor 24 | end 25 | 26 | function QLin:get_weight_vector() 27 | return self.weights 28 | end 29 | 30 | QLin.__eq = parent.__eq -- force inheritance of this 31 | -------------------------------------------------------------------------------- /NNSarsaFactory.lua: -------------------------------------------------------------------------------- 1 | local NNSarsaFactory, parent = 2 | torch.class('rl.NNSarsaFactory', 
'rl.SarsaFactory') 3 | 4 | function NNSarsaFactory:__init( 5 | mdp_config, 6 | lambda, 7 | explorer, 8 | feature_extractor, 9 | step_size) 10 | parent.__init(self, mdp_config, lambda) 11 | self.explorer = explorer 12 | self.feature_extractor = feature_extractor 13 | self.step_size = step_size 14 | end 15 | 16 | function NNSarsaFactory:set_explorer(explorer) 17 | self.explorer = explorer 18 | end 19 | 20 | function NNSarsaFactory:set_feature_extractor(feature_extractor) 21 | self.feature_extractor = feature_extractor 22 | end 23 | 24 | function NNSarsaFactory:get_control() 25 | return rl.NNSarsa(self.mdp_config, 26 | self.lambda, 27 | self.explorer, 28 | self.feature_extractor, 29 | self.step_size) 30 | end 31 | 32 | -------------------------------------------------------------------------------- /test/unittest_LinSarsaFactory.lua: -------------------------------------------------------------------------------- 1 | require 'rl' 2 | local tester = torch.Tester() 3 | 4 | local mdp = rl.TestMdp() 5 | local discount_factor = 0.631 6 | local TestLinSarsaFactory = {} 7 | function TestLinSarsaFactory.test_get_control() 8 | local mdp_config = rl.MdpConfig(mdp, discount_factor) 9 | local lambda = 0.65 10 | local eps = 0.2437 11 | local explorer = rl.ConstExplorer(eps) 12 | local feature_extractor = rl.TestSAFE() 13 | local step_size = 0.92 14 | 15 | local lin_sarsa = rl.LinSarsa( 16 | mdp_config, 17 | lambda, 18 | explorer, 19 | feature_extractor, 20 | step_size) 21 | local factory = rl.LinSarsaFactory( 22 | mdp_config, 23 | lambda, 24 | explorer, 25 | feature_extractor, 26 | step_size) 27 | tester:assert(factory:get_control() == lin_sarsa) 28 | end 29 | 30 | tester:add(TestLinSarsaFactory) 31 | 32 | tester:run() 33 | 34 | -------------------------------------------------------------------------------- /LinSarsaFactory.lua: -------------------------------------------------------------------------------- 1 | local LinSarsaFactory, parent = 2 | torch.class('rl.LinSarsaFactory', 'rl.SarsaFactory') 3 | 4 | function LinSarsaFactory:__init( 5 | mdp_config, 6 | lambda, 7 | explorer, 8 | feature_extractor, 9 | step_size) 10 | parent.__init(self, mdp_config, lambda) 11 | self.explorer = explorer 12 | self.feature_extractor = feature_extractor 13 | self.step_size = step_size 14 | end 15 | 16 | function LinSarsaFactory:set_explorer(explorer) 17 | self.explorer = explorer 18 | end 19 | 20 | function LinSarsaFactory:set_feature_extractor(feature_extractor) 21 | self.feature_extractor = feature_extractor 22 | end 23 | 24 | function LinSarsaFactory:get_control() 25 | return rl.LinSarsa(self.mdp_config, 26 | self.lambda, 27 | self.explorer, 28 | self.feature_extractor, 29 | self.step_size) 30 | end 31 | 32 | -------------------------------------------------------------------------------- /TestMdp.lua: -------------------------------------------------------------------------------- 1 | -- Dummy MDP for testing 2 | -- state = either 1, 2, or 3 3 | -- action = either 1, 2, or 3 4 | local TestMdp, parent = torch.class('rl.TestMdp', 'rl.Mdp') 5 | 6 | local TERMINAL = 3 7 | 8 | function TestMdp:step(state, action) 9 | if TestMdp:is_terminal(state) then 10 | error('MDP is done.') 11 | end 12 | local reward = -1 13 | if state + action >= 4 then 14 | reward = 1 15 | end 16 | return state + 1, reward 17 | end 18 | 19 | function TestMdp:get_start_state() 20 | return 1 21 | end 22 | 23 | function TestMdp:is_terminal(s) 24 | return s == TERMINAL 25 | end 26 | 27 | function TestMdp:get_all_states() 28 | return {1, 2, 3} 29 | end 
30 | 31 | function TestMdp:get_all_actions() 32 | return {1, 2, 3} 33 | end 34 | 35 | function TestMdp:hash_s(state) 36 | return state 37 | end 38 | 39 | function TestMdp:hash_a(action) 40 | return action 41 | end 42 | 43 | function TestMdp:get_description() 44 | return 'Test MDP' 45 | end 46 | -------------------------------------------------------------------------------- /test/unittest_AllActionsEqualPolicy.lua: -------------------------------------------------------------------------------- 1 | require 'rl' 2 | local tester = torch.Tester() 3 | 4 | local TestAllActionsEqualPolicy = {} 5 | 6 | local function all_actions_have_good_freq( 7 | action_history, 8 | all_actions) 9 | local expected_p = 1. / #all_actions 10 | for _, action in pairs(all_actions) do 11 | if not rl.util.elem_has_good_freq(action, action_history, expected_p) then 12 | return false 13 | end 14 | end 15 | return true 16 | end 17 | 18 | function TestAllActionsEqualPolicy.test_policy() 19 | local mdp = rl.TestMdp() 20 | local policy = rl.AllActionsEqualPolicy(mdp) 21 | local N = 100000 22 | local action_history = {} 23 | for i = 1, N do 24 | action_history[i] = policy:get_action(nil) 25 | end 26 | 27 | local all_actions = mdp:get_all_actions() 28 | tester:assert( 29 | all_actions_have_good_freq(action_history, all_actions)) 30 | end 31 | 32 | 33 | tester:add(TestAllActionsEqualPolicy) 34 | 35 | tester:run() 36 | -------------------------------------------------------------------------------- /test/unittest_VHash.lua: -------------------------------------------------------------------------------- 1 | require 'rl' 2 | 3 | local tester = torch.Tester() 4 | local TestVHash = {} 5 | 6 | function TestVHash.test_add_once() 7 | local v = rl.VHash(rl.TestMdp()) 8 | local s = 1 9 | local val = 2 10 | v:add(s, val) 11 | 12 | tester:asserteq(v:get_value(s), val) 13 | end 14 | 15 | function TestVHash.test_mult() 16 | local v = rl.VHash(rl.TestMdp()) 17 | local s = 2 18 | v:add(s, 1) 19 | v:mult(s, 3) 20 | v:mult(s, 3) 21 | v:mult(s, 3) 22 | 23 | tester:asserteq(v:get_value(s), 27) 24 | end 25 | 26 | function TestVHash.test_equality() 27 | local v1 = rl.VHash(rl.TestMdp()) 28 | local s = 2 29 | v1:add(s, 1) 30 | v1:mult(s, 3) 31 | v1:mult(s, 3) 32 | v1:mult(s, 3) 33 | 34 | local v2 = rl.VHash(rl.TestMdp()) 35 | local s = 2 36 | v2:add(s, 5) 37 | v2:mult(s, 0) 38 | v2:add(s, 2) 39 | v2:add(s, -1) 40 | v2:mult(s, 3) 41 | v2:mult(s, 3) 42 | v2:mult(s, 3) 43 | 44 | tester:assert(v1 == v2) 45 | end 46 | 47 | tester:add(TestVHash) 48 | 49 | tester:run() 50 | -------------------------------------------------------------------------------- /QVAnalyzer.lua: -------------------------------------------------------------------------------- 1 | local gnuplot = require 'gnuplot' 2 | 3 | local QVAnalyzer = torch.class('rl.QVAnalyzer') 4 | 5 | function QVAnalyzer:__init(mdp) 6 | self.mdp = mdp 7 | end 8 | function QVAnalyzer:get_v_tensor(v) 9 | error('Must implement get_v_tensor.') 10 | end 11 | 12 | -- Plot the state value function 13 | function QVAnalyzer:plot_v(v) 14 | error('Must implement plot_v.') 15 | end 16 | 17 | function QVAnalyzer:plot_best_action(q, method_description) 18 | error('Must implement plot_best_action.') 19 | end 20 | 21 | function QVAnalyzer:v_from_q(q) 22 | error('Must implement v_from_q') 23 | end 24 | 25 | function QVAnalyzer:get_q_tensor(q) 26 | error('Must implement get_q_tensor.') 27 | end 28 | 29 | function QVAnalyzer:q_rms(q1, q2) 30 | local t1 = self:get_q_tensor(q1) 31 | local t2 = self:get_q_tensor(q2) 32 | return 
math.sqrt(torch.mean(torch.pow(t1 - t2, 2))) 33 | end 34 | 35 | function QVAnalyzer:v_rms(v1, v2) 36 | local t1 = self:get_v_tensor(v1) 37 | local t2 = self:get_v_tensor(v2) 38 | return math.sqrt(torch.mean(torch.pow(t1 - t2, 2))) 39 | end 40 | -------------------------------------------------------------------------------- /doc/valuefunctions.md: -------------------------------------------------------------------------------- 1 | ## Value Functions 2 | All Q value functions implement: 3 | `get_value(s, a)` 4 | 5 | `get_best_action(s)` 6 | 7 | All V value functions implement: 8 | `get_value(s)` 9 | 10 | ### (Hash)Tables 11 | These are the simplest types of data structures. VHash and QHash implement hash 12 | tables over the state and state-action spaces, respectively. Only use 13 | these hash tables for small state/action spaces. 14 | 15 | ### Function Approximation 16 | For large state/action spaces, using a look-up table becomes intractable. An 17 | alternative is to approximate the value of a state or state-action pair by using 18 | a function approximator. See below for how features are extracted. 19 | 20 | ## State-Action Feature Extractors (SAFE) 21 | SAFeatureExtractor defines an interface for classes that extract features out of 22 | a given state-action pair. SAFEs have to implement: 23 | 24 | `[Tensor] get_sa_features(s, a)` 25 | `[number or tuple of numbers] get_sa_features_dim()` 26 | which returns the dimensions of the tensor returned by `get_sa_features`, and `[number] get_sa_num_features()`, which returns the number of elements in that tensor. 27 | -------------------------------------------------------------------------------- /test/unittest_NNSarsaFactory.lua: -------------------------------------------------------------------------------- 1 | require 'rl' 2 | 3 | local tester = torch.Tester() 4 | 5 | local mdp = rl.TestMdp() 6 | local discount_factor = 0.631 7 | local TestNNSarsaFactory = {} 8 | function TestNNSarsaFactory.test_get_control() 9 | local mdp_config = rl.MdpConfig(mdp, discount_factor) 10 | local lambda = 0.65 11 | local eps = 0.2437 12 | local explorer = rl.ConstExplorer(eps) 13 | local feature_extractor = rl.TestSAFE() 14 | local step_size = 0.92 15 | 16 | local nn_sarsa = rl.NNSarsa( 17 | mdp_config, 18 | lambda, 19 | explorer, 20 | feature_extractor, 21 | step_size) 22 | local factory = rl.NNSarsaFactory( 23 | mdp_config, 24 | lambda, 25 | explorer, 26 | feature_extractor, 27 | step_size) 28 | local nn_sarsa2 = factory:get_control() 29 | -- Hack: the two q networks start with different random weights, so force 30 | -- them to share the same q before comparing 31 | nn_sarsa2.q = nn_sarsa.q 32 | tester:assert(nn_sarsa2 == nn_sarsa) 33 | end 34 | 35 | tester:add(TestNNSarsaFactory) 36 | 37 | tester:run() 38 | 39 | -------------------------------------------------------------------------------- /test/unittest_tensorutil.lua: -------------------------------------------------------------------------------- 1 | require 'rl' 2 | local tester = torch.Tester() 3 | 4 | local TestTensorUtil = {} 5 | function TestTensorUtil.test_apply_to_slices() 6 | local function power_fill(tensor, i, power) 7 | power = power or 1 8 | tensor:fill(i ^ power) 9 | end 10 | local A = torch.Tensor(2, 2) 11 | local B = torch.Tensor(2, 2) 12 | B[1][1] = 1 13 | B[1][2] = 1 14 | B[2][1] = 2 15 | B[2][2] = 2 16 | tester:assertTensorEq( 17 | rl.util.apply_to_slices(A, 1, power_fill), 18 | B, 19 | 0) 20 | end 21 | 22 | function TestTensorUtil.fill_range() 23 | local A = torch.Tensor(2, 2) 24 | local B = torch.Tensor(2, 2) 25 | B[1][1] = -1 26 | B[1][2] = -1 27 | B[2][1] = 0 28 | B[2][2] = 0 29 |
tester:assertTensorEq( 30 | rl.util.apply_to_slices(A, 1, rl.util.fill_range, -2), 31 | B, 32 | 0) 33 | 34 | B[1][1] = 123 35 | B[2][1] = 123 36 | B[1][2] = 124 37 | B[2][2] = 124 38 | tester:assertTensorEq( 39 | rl.util.apply_to_slices(A, 2, rl.util.fill_range, 122), 40 | B, 41 | 0) 42 | end 43 | 44 | tester:add(TestTensorUtil) 45 | 46 | tester:run() 47 | -------------------------------------------------------------------------------- /test/unittest_GreedyPolicy.lua: -------------------------------------------------------------------------------- 1 | require 'rl' 2 | local tester = torch.Tester() 3 | 4 | local TestGreedyPolicy = {} 5 | local mdp = rl.TestMdp() 6 | local q = rl.QFunc() 7 | 8 | local function get_policy(best_action, eps) 9 | local explorer = rl.ConstExplorer(eps) 10 | q.get_best_action = function (s) 11 | return best_action 12 | end 13 | return rl.GreedyPolicy(q, explorer, mdp:get_all_actions()) 14 | end 15 | 16 | function TestGreedyPolicy.test_greedy() 17 | local eps = 0.7 18 | local policy = get_policy(2, eps) 19 | local expected_probabilities = { 20 | eps/3, 21 | 1 - 2*eps/3, 22 | eps/3 23 | } 24 | tester:assert(rl.util.are_testmdp_policy_probabilities_good( 25 | policy, 26 | expected_probabilities)) 27 | end 28 | 29 | function TestGreedyPolicy.test_greedy2() 30 | local eps = 0.05 31 | local policy = get_policy(1, eps) 32 | local expected_probabilities = { 33 | 1 - 2*eps/3, 34 | eps/3, 35 | eps/3 36 | } 37 | tester:assert(rl.util.are_testmdp_policy_probabilities_good( 38 | policy, 39 | expected_probabilities)) 40 | end 41 | tester:add(TestGreedyPolicy) 42 | 43 | tester:run() 44 | -------------------------------------------------------------------------------- /MdpSampler.lua: -------------------------------------------------------------------------------- 1 | local MdpSampler = torch.class('rl.MdpSampler') 2 | 3 | function MdpSampler:__init(mdp_config) 4 | self.mdp = mdp_config:get_mdp() 5 | self.discount_factor = mdp_config:get_discount_factor() 6 | end 7 | 8 | function MdpSampler:sample_total_reward(policy) 9 | local s = self.mdp:get_start_state() 10 | local total_r, r = 0, 0 11 | while not self.mdp:is_terminal(s) do 12 | s, r = self.mdp:step(s, policy:get_action(s)) 13 | total_r = total_r + r 14 | end 15 | return total_r 16 | end 17 | 18 | -- Episode: list of {state, action, discounted return, reward}. Indexed by time, 19 | -- starting at time = 1. 
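-- Illustrative shape only (values below are hypothetical): for rewards r1, r2, ...
-- and discount factor gamma, the first entry looks like
--   episode[1] = { state = s1, action = a1, reward = r1,
--                  discounted_return = r1 + gamma*r2 + gamma^2*r3 + ... }
-- The field names match those set by rl.EpisodeBuilder.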
20 | function MdpSampler:get_episode(policy) 21 | local s = self.mdp:get_start_state() 22 | local r = 0 23 | local a = nil 24 | local next_s = nil 25 | local episode_builder = rl.EpisodeBuilder(self.discount_factor) 26 | 27 | while not self.mdp:is_terminal(s) do 28 | a = policy:get_action(s) 29 | next_s, r = self.mdp:step(s, a) 30 | episode_builder:add_state_action_reward_step(s, a, r) 31 | s = next_s 32 | end 33 | 34 | return episode_builder:get_episode() 35 | end 36 | function MdpSampler:get_mdp() 37 | return self.mdp 38 | end 39 | -------------------------------------------------------------------------------- /test/unittest_QHash.lua: -------------------------------------------------------------------------------- 1 | require 'rl' 2 | 3 | local tester = torch.Tester() 4 | local TestQHash = {} 5 | 6 | local mdp = rl.TestMdp() 7 | function TestQHash.test_add_once() 8 | local q = rl.QHash(mdp) 9 | local s = 1 10 | local a = 1 11 | local val = 2 12 | q:add(s, a, val) 13 | 14 | tester:asserteq(q:get_value(s, a), val) 15 | tester:asserteq(q:get_best_action(s), a) 16 | end 17 | 18 | function TestQHash.test_mult() 19 | local q = rl.QHash(rl.TestMdp()) 20 | local s = 2 21 | local a = 3 22 | q:add(s, a, 1) 23 | q:mult(s, a, 3) 24 | q:mult(s, a, 3) 25 | q:mult(s, a, 3) 26 | 27 | tester:asserteq(q:get_value(s, a), 27) 28 | end 29 | 30 | function TestQHash.test_equality() 31 | local q1 = rl.QHash(rl.TestMdp()) 32 | local s = 2 33 | local a = 3 34 | q1:add(s, a, 1) 35 | q1:mult(s, a, 3) 36 | q1:mult(s, a, 3) 37 | q1:mult(s, a, 3) 38 | 39 | local q2 = rl.QHash(rl.TestMdp()) 40 | local s = 2 41 | local a = 3 42 | q2:add(s, a, 1) 43 | q2:mult(s, a, 0) 44 | q2:add(s, a, 1) 45 | q2:mult(s, a, 3) 46 | q2:mult(s, a, 3) 47 | q2:mult(s, a, 3) 48 | 49 | tester:assert(q1 == q2) 50 | end 51 | 52 | tester:add(TestQHash) 53 | 54 | tester:run() 55 | -------------------------------------------------------------------------------- /Mdp.lua: -------------------------------------------------------------------------------- 1 | local Mdp = torch.class('rl.Mdp') 2 | 3 | function Mdp:__init() 4 | end 5 | 6 | function Mdp:step(state, action) 7 | error('Must implement step') 8 | end 9 | 10 | function Mdp:get_start_state() 11 | error('Must implement get_start_state') 12 | end 13 | 14 | function Mdp:is_terminal(state) 15 | error('Must implement is_terminal') 16 | end 17 | 18 | -- The next two functions shoul return all states/actions in a list (i.e. a 19 | -- Table with numbers as keys). These two methods might be really expensive to 20 | -- compute. It's the responsibility of the caller to take that into 21 | -- consideration. 22 | function Mdp:get_all_states() 23 | error('Must implement get_all_states') 24 | end 25 | 26 | function Mdp:get_all_actions() 27 | error('Must implement get_all_actions') 28 | end 29 | 30 | -- These hash functions for the state and action are used if you plan on using 31 | -- TableSarsa. Otherwise, use feature extractors. 32 | -- 33 | -- TODO: Move this to its own class, like SAFeatureExtractor. 
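-- For a small discrete MDP the hash can simply be the state or action itself;
-- rl.TestMdp, for example, returns `state` from hash_s and `action` from hash_a.
-- A composite state such as {dealer, player} might instead hash to a string like
-- dealer .. ',' .. player (a hypothetical example, not something defined in this repo).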
34 | function Mdp:hash_s(state) 35 | error('Must implement hash_s') 36 | end 37 | 38 | function Mdp:hash_a(action) 39 | error('Must implement hash_a') 40 | end 41 | 42 | function Mdp:get_description() 43 | return 'Base MDP' 44 | end 45 | -------------------------------------------------------------------------------- /test/unittest_TestMdpQVAnalyzer.lua: -------------------------------------------------------------------------------- 1 | require 'rl' 2 | local tester = torch.Tester() 3 | 4 | local TestTestMdpQVAnalyzer = {} 5 | local qva = rl.TestMdpQVAnalyzer() 6 | 7 | local v = {} -- mock VFunc 8 | function v:get_value(state) 9 | return state 10 | end 11 | 12 | local q = {} -- mock QFunc 13 | function q:get_value(state, action) 14 | return state + action 15 | end 16 | 17 | function q:get_best_action(state) 18 | return 3 19 | end 20 | 21 | function TestTestMdpQVAnalyzer.test_get_v_tensor() 22 | local expected = torch.Tensor{1, 2, 3} 23 | tester:assertTensorEq(expected, qva:get_v_tensor(v), 0) 24 | end 25 | 26 | function TestTestMdpQVAnalyzer.test_get_q_tensor() 27 | local expected_table = { -- row = state, col = action 28 | {2, 3, 4}, 29 | {3, 4, 5}, 30 | {4, 5, 6}, 31 | } 32 | local expected = torch.Tensor(expected_table) 33 | tester:assertTensorEq(expected, qva:get_q_tensor(q), 0) 34 | end 35 | 36 | function TestTestMdpQVAnalyzer.test_v_from_q() 37 | local expected_v = {} 38 | function expected_v:get_value(state) 39 | return state + 3 40 | end 41 | 42 | local result_v = qva:v_from_q(q) 43 | 44 | tester:assert(result_v:__eq(expected_v)) 45 | end 46 | 47 | tester:add(TestTestMdpQVAnalyzer) 48 | 49 | tester:run() 50 | 51 | -------------------------------------------------------------------------------- /doc/policy.md: -------------------------------------------------------------------------------- 1 | ## Policy 2 | A Policy implements one method: 3 | 4 | `[action] get_action(state)` 5 | 6 | ### EpsilonGreedyPolicy 7 | The epsilon-greedy policy (implemented here as `GreedyPolicy`) is a simple policy that balances exploration and 8 | exploitation. The idea of epsilon greedy policies is to choose a random action 9 | with some small probability (epsilon). This encourages exploration. Otherwise, 10 | choose the best action, to exploit our knowledge so far. 11 | 12 | This is currently the only non-trivial policy implemented. 13 | 14 | ### Explorer 15 | This is used to choose how to balance exploration vs. exploitation. 16 | Specifically, it gives the probability of exploring. So, it implements 17 | 18 | `[number from 0 to 1] get_eps(s)` 19 | 20 | which returns epsilon for the epsilon greedy policy. 21 | 22 | ### DecayTableExplorer 23 | This type of explorer chooses epsilon to be 24 | 25 | `N0 / (N0 + N(s))` 26 | 27 | where `N0` is some constant, and `N(s)` is the number of times state `s` has 28 | been visited. This type of exploration strategy with an epsilon-greedy policy is 29 | guaranteed to converge to the optimal policy since each state is explored, but 30 | eventually the best action is exploited. This is because as the number of times 31 | a state has been visited grows (i.e. the more exploration we've done), the 32 | smaller epsilon becomes (i.e. we don't bother exploring as much).
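
As a rough sketch, assuming the `rl` package is installed and using `TestMdp` purely
for illustration (with `visit_counts` expected to be updated by whatever control
algorithm is running), the pieces compose like this:

```lua
require 'rl'

local mdp = rl.TestMdp()
local q = rl.QHash(mdp)                 -- state-action values, all zero initially
local visit_counts = rl.VHash(mdp)      -- N(s), assumed to be incremented elsewhere
local N0 = 100

local explorer = rl.DecayTableExplorer(N0, visit_counts)
local policy = rl.GreedyPolicy(q, explorer, mdp:get_all_actions())

-- With N(s) = 0, eps = N0 / (N0 + 0) = 1, so actions are picked uniformly at
-- random; as N(s) grows, eps shrinks and the best action under q dominates.
local a = policy:get_action(mdp:get_start_state())
```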
33 | -------------------------------------------------------------------------------- /rl.lua: -------------------------------------------------------------------------------- 1 | module('rl', package.seeall) 2 | 3 | module('rl.util', package.seeall) 4 | require('rl.util.io_util') 5 | require('rl.util.util') 6 | require('rl.util.mdputil') 7 | require('rl.util.tensorutil') 8 | require('rl.util.util_for_unittests') 9 | 10 | require('rl.rl_constants') 11 | 12 | require('rl.Policy') 13 | require('rl.GreedyPolicy') 14 | require('rl.AllActionsEqualPolicy') 15 | require('rl.TestPolicy') 16 | 17 | require('rl.Control') 18 | require('rl.ValueIteration') 19 | require('rl.MonteCarloControl') 20 | require('rl.Sarsa') 21 | require('rl.TableSarsa') 22 | require('rl.LinSarsa') 23 | require('rl.NNSarsa') 24 | 25 | require('rl.ControlFactory') 26 | require('rl.SarsaFactory') 27 | require('rl.TableSarsaFactory') 28 | require('rl.LinSarsaFactory') 29 | require('rl.NNSarsaFactory') 30 | 31 | require('rl.Mdp') 32 | require('rl.TestMdp') 33 | 34 | require('rl.SAFeatureExtractor') 35 | require('rl.TestSAFE') 36 | 37 | require('rl.QFunc') 38 | require('rl.QHash') 39 | require('rl.QApprox') 40 | require('rl.QLin') 41 | require('rl.QNN') 42 | 43 | require('rl.VFunc') 44 | require('rl.VHash') 45 | 46 | require('rl.MdpConfig') 47 | require('rl.MdpSampler') 48 | require('rl.EpisodeBuilder') 49 | 50 | require('rl.Explorer') 51 | require('rl.ConstExplorer') 52 | require('rl.DecayTableExplorer') 53 | 54 | require('rl.Evaluator') 55 | require('rl.QVAnalyzer') 56 | require('rl.SarsaAnalyzer') 57 | require('rl.TestMdpQVAnalyzer') 58 | -------------------------------------------------------------------------------- /MonteCarloControl.lua: -------------------------------------------------------------------------------- 1 | local MonteCarloControl, parent = 2 | torch.class('rl.MonteCarloControl', 'rl.ValueIteration') 3 | 4 | function MonteCarloControl:__init(mdp_config, N0) 5 | parent.__init(self, mdp_config) 6 | self.q = rl.QHash(self.mdp) 7 | self.Ns = rl.VHash(self.mdp) 8 | self.Nsa = rl.QHash(self.mdp) 9 | self.N0 = N0 or rl.MONTECARLOCONTROL_DEFAULT_N0 10 | self.actions = self.mdp.get_all_actions() 11 | end 12 | 13 | function MonteCarloControl:optimize_policy() 14 | self.policy = rl.GreedyPolicy( 15 | self.q, 16 | rl.DecayTableExplorer(self.N0, self.Ns), 17 | self.actions 18 | ) 19 | end 20 | 21 | function MonteCarloControl:evaluate_policy() 22 | local episode = self.sampler:get_episode(self.policy) 23 | for t, data in pairs(episode) do 24 | local s = data.state 25 | local a = data.action 26 | local Gt = data.discounted_return 27 | 28 | self.Ns:add(s, 1) 29 | self.Nsa:add(s, a, 1) 30 | 31 | local alpha = 1. 
/ self.Nsa:get_value(s, a) 32 | self.q:add(s, a, alpha * (Gt - self.q:get_value(s, a))) 33 | end 34 | end 35 | 36 | function MonteCarloControl:get_q() 37 | return self.q 38 | end 39 | 40 | function MonteCarloControl:__eq(other) 41 | return torch.typename(self) == torch.typename(other) 42 | and self.q == other.q 43 | and self.Ns == other.Ns 44 | and self.Nsa == other.Nsa 45 | and self.N0 == other.N0 46 | end 47 | -------------------------------------------------------------------------------- /QApprox.lua: -------------------------------------------------------------------------------- 1 | -- Abstract class for a Q function approximation class 2 | local QApprox, parent = torch.class('rl.QApprox', 'rl.QFunc') 3 | 4 | function QApprox:__init(mdp, feature_extractor) 5 | parent.__init(self, mdp) 6 | self.feature_extractor = feature_extractor 7 | end 8 | 9 | function QApprox:clear() 10 | error('Must implement clear method') 11 | end 12 | 13 | function QApprox:get_value(s, a) 14 | error('Must implement get_Value method') 15 | end 16 | 17 | function QApprox:add(d_weights) 18 | error('Must implement add method') 19 | end 20 | 21 | function QApprox:mult_all(factor) 22 | error('Must implement mult_all method') 23 | end 24 | 25 | function QApprox:get_weight_vector() 26 | error('Must implement get_weight_vector method') 27 | end 28 | 29 | function QApprox:get_q_tensor() 30 | local value = torch.zeros(N_DEALER_STATES, N_PLAYER_STATES, N_ACTIONS) 31 | for dealer = 1, N_DEALER_STATES do 32 | for player = 1, N_PLAYER_STATES do 33 | for a = 1, N_ACTIONS do 34 | s = {dealer, player} 35 | value[s][a] = self:get_value(s, a) 36 | end 37 | end 38 | end 39 | return value 40 | end 41 | 42 | function QApprox:get_best_action(s) 43 | local actions = self.mdp:get_all_actions() 44 | local best_a, best_i = rl.util.max( 45 | actions, 46 | function (a) return self:get_value(s, a) end) 47 | return best_a 48 | end 49 | 50 | QApprox.__eq = parent.__eq -- force inheritance of this 51 | -------------------------------------------------------------------------------- /LinSarsa.lua: -------------------------------------------------------------------------------- 1 | -- Implement SARSA algorithm using a linear function approximator for on-line 2 | -- policy control 3 | local LinSarsa, parent = torch.class('rl.LinSarsa', 'rl.Sarsa') 4 | 5 | function LinSarsa:__init(mdp_config, lambda, explorer, feature_extractor, step_size) 6 | parent.__init(self, mdp_config, lambda) 7 | self.explorer = explorer 8 | self.feature_extractor = feature_extractor 9 | self.step_size = step_size 10 | self.q = rl.QLin(self.mdp, self.feature_extractor) 11 | self.eligibility = rl.QLin(self.mdp, self.feature_extractor) 12 | end 13 | 14 | function LinSarsa:get_new_q() 15 | return rl.QLin(self.mdp, self.feature_extractor) 16 | end 17 | 18 | function LinSarsa:reset_eligibility() 19 | self.eligibility = rl.QLin(self.mdp, self.feature_extractor) 20 | end 21 | 22 | function LinSarsa:update_eligibility(s, a) 23 | local features = self.feature_extractor:get_sa_features(s, a) 24 | self.eligibility:mult(self.discount_factor*self.lambda) 25 | self.eligibility:add(features) 26 | end 27 | 28 | function LinSarsa:td_update(td_error) 29 | self.q:add(self.eligibility:get_weight_vector() * self.step_size * td_error) 30 | end 31 | 32 | function LinSarsa:update_policy() 33 | self.policy = rl.GreedyPolicy( 34 | self.q, 35 | self.explorer, 36 | self.actions 37 | ) 38 | end 39 | 40 | function LinSarsa:__eq(other) 41 | return torch.typename(self) == torch.typename(other) 42 | and 
self.explorer == other.explorer 43 | and self.feature_extractor == other.feature_extractor 44 | and self.step_size == other.step_size 45 | and self.q == other.q 46 | end 47 | -------------------------------------------------------------------------------- /util/util_for_unittests.lua: -------------------------------------------------------------------------------- 1 | function rl.util.are_testmdp_policy_probabilities_good( 2 | policy, 3 | expected_probabilities) 4 | local n_times_action = {0, 0, 0} 5 | local n_iters = 10000 6 | for i = 1, n_iters do 7 | local state = math.random(1, 2) 8 | local a = policy:get_action(state) 9 | n_times_action[a] = n_times_action[a] + 1 10 | end 11 | 12 | for action = 1, 3 do 13 | if not rl.util.is_prob_good( 14 | n_times_action[action], 15 | expected_probabilities[action], 16 | n_iters) then 17 | return false 18 | end 19 | end 20 | return true 21 | end 22 | 23 | function rl.util.are_tensors_same_shape(t1, t2) 24 | if t1:dim() ~= t2:dim() then 25 | return false 26 | end 27 | for d = 1, t1:dim() do 28 | if (#t1)[d] ~= (#t2)[d] then 29 | return false 30 | end 31 | end 32 | return true 33 | end 34 | 35 | function rl.util.do_qtable_qfunc_match(mdp, q_table, qfunc) 36 | for _, state in pairs(mdp:get_all_states()) do 37 | for _, action in pairs(mdp:get_all_actions()) do 38 | local sa_value = qfunc:get_value(state, action) 39 | if math.abs(q_table[state][action] - sa_value) > rl.FLOAT_EPS then 40 | return false 41 | end 42 | end 43 | end 44 | return true 45 | end 46 | 47 | function rl.util.do_vtable_vfunc_match(mdp, v_table, vfunc) 48 | for _, state in pairs(mdp:get_all_states()) do 49 | local state_value = vfunc:get_value(state) 50 | if v_table[state] ~= state_value then 51 | return false 52 | end 53 | end 54 | return true 55 | end 56 | -------------------------------------------------------------------------------- /EpisodeBuilder.lua: -------------------------------------------------------------------------------- 1 | -- Episode: list of {state, action, discounted return, reward}, indexed by time. 2 | -- Time starts at 1 (going along with Lua conventions). 3 | -- This class builds this list intelligentally based on discount_factor, the discount 4 | -- factor. 
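-- Worked example (hypothetical numbers): with discount_factor = 0.5 and rewards
-- {1, 2, 4} added in order, get_episode() yields discounted returns
--   {1 + 0.5*2 + 0.25*4, 2 + 0.5*4, 4} = {3, 4, 4}.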
5 | local EpisodeBuilder = torch.class('rl.EpisodeBuilder') 6 | 7 | function EpisodeBuilder:__init(discount_factor) 8 | if discount_factor < 0 or discount_factor > 1 then 9 | error('Gamma must be between 0 and 1, inclusive.') 10 | end 11 | self.t = 1 12 | self.episode = {} 13 | self.discount_factor = discount_factor 14 | self.built = false 15 | end 16 | 17 | function EpisodeBuilder:add_state_action_reward_step(state, action, reward) 18 | self.episode[self.t] = { 19 | state = state, 20 | action = action, 21 | discounted_return = reward, 22 | reward = reward 23 | } 24 | self.t = self.t + 1 25 | self.built = false 26 | end 27 | 28 | local function is_built(self) 29 | return self.discount_factor == 0 or self.built 30 | end 31 | 32 | function EpisodeBuilder:get_episode() 33 | if is_built(self) then 34 | return self.episode 35 | end 36 | 37 | local t = self.t - 1 38 | local discounted_future_return = self.discount_factor * self.episode[t].reward 39 | t = t - 1 40 | for i = 1, #self.episode - 1 do 41 | self.episode[t].discounted_return = self.episode[t].discounted_return + 42 | discounted_future_return 43 | discounted_future_return = self.discount_factor * (discounted_future_return + 44 | self.episode[t].reward) 45 | t = t - 1 46 | end 47 | self.built = true 48 | return self.episode 49 | end 50 | -------------------------------------------------------------------------------- /NNSarsa.lua: -------------------------------------------------------------------------------- 1 | -- Implement SARSA algorithm using a neural network function approximator for 2 | -- on-line policy control 3 | local NNSarsa, parent = torch.class('rl.NNSarsa', 'rl.Sarsa') 4 | 5 | function NNSarsa:__init(mdp_config, lambda, explorer, feature_extractor, step_size) 6 | parent.__init(self, mdp_config, lambda) 7 | self.explorer = explorer 8 | self.feature_extractor = feature_extractor 9 | self.step_size = step_size 10 | self.q = rl.QNN(mdp_config:get_mdp(), feature_extractor) 11 | self.last_state = nil 12 | self.last_action = nil 13 | self.momentum = self.lambda * self.discount_factor 14 | end 15 | 16 | function NNSarsa:get_new_q() 17 | return rl.QNN(self.mdp, self.feature_extractor) 18 | end 19 | 20 | function NNSarsa:reset_eligibility() 21 | self.last_state = nil 22 | self.last_action = nil 23 | self.q:reset_momentum() 24 | end 25 | 26 | function NNSarsa:update_eligibility(s, a) 27 | self.last_state = s 28 | self.last_action = a 29 | end 30 | 31 | function NNSarsa:td_update(td_error) 32 | local learning_rate = td_error * self.step_size 33 | self.q:backward( 34 | self.last_state, 35 | self.last_action, 36 | learning_rate, 37 | self.momentum) 38 | end 39 | 40 | function NNSarsa:update_policy() 41 | self.policy = rl.GreedyPolicy( 42 | self.q, 43 | self.explorer, 44 | self.actions 45 | ) 46 | end 47 | 48 | function NNSarsa:__eq(other) 49 | return torch.typename(self) == torch.typename(other) 50 | and self.explorer == other.explorer 51 | and self.feature_extractor == other.feature_extractor 52 | and self.step_size == other.step_size 53 | and self.q == other.q 54 | and self.last_state == other.last_state 55 | and self.last_action == other.last_action 56 | end 57 | -------------------------------------------------------------------------------- /TestMdpQVAnalyzer.lua: -------------------------------------------------------------------------------- 1 | local TestMdpQVAnalyzer, parent = 2 | torch.class('rl.TestMdpQVAnalyzer', 'rl.QVAnalyzer') 3 | 4 | function TestMdpQVAnalyzer:__init() 5 | parent.__init(self, rl.TestMdp()) 6 | self.n_states =
#self.mdp:get_all_states() 7 | self.n_actions = #self.mdp:get_all_actions() 8 | end 9 | 10 | function TestMdpQVAnalyzer:get_v_tensor(v) 11 | local tensor = torch.zeros(self.n_states) 12 | for s = 1, self.n_states do 13 | tensor[s] = v:get_value(s) 14 | end 15 | return tensor 16 | end 17 | 18 | function TestMdpQVAnalyzer:plot_v(v) 19 | local tensor = self:get_v_tensor(v) 20 | local x = torch.Tensor(self.n_states) 21 | x = rl.util.apply_to_slices(x, 1, rl.util.fill_range, 0) 22 | gnuplot.plot(x, tensor) 23 | gnuplot.xlabel('Dealer Showing') 24 | gnuplot.ylabel('State Value') 25 | gnuplot.title('Monte-Carlo State Value Function') 26 | end 27 | 28 | function TestMdpQVAnalyzer:get_q_tensor(q) 29 | local tensor = torch.zeros(self.n_states, self.n_actions) 30 | for s = 1, self.n_states do 31 | for a = 1, self.n_actions do 32 | tensor[s][a] = q:get_value(s, a) 33 | end 34 | end 35 | return tensor 36 | end 37 | 38 | function TestMdpQVAnalyzer:plot_best_action(q) 39 | local best_action_at_state = torch.Tensor(self.n_states) 40 | for s = 1, self.n_states do 41 | best_action_at_state[s] = q:get_best_action(s) 42 | end 43 | local x = torch.Tensor(self.n_actions) 44 | x = rl.util.apply_to_slices(x, 1, rl.util.fill_range, 0) 45 | gnuplot.plot(x, best_action_at_state) 46 | gnuplot.xlabel('State') 47 | gnuplot.zlabel('Best Action') 48 | gnuplot.title('Learned Best Action Based on q') 49 | end 50 | 51 | function TestMdpQVAnalyzer:v_from_q(q) 52 | local v = rl.VHash(self.mdp) 53 | for s = 1, self.n_states do 54 | local a = q:get_best_action(s) 55 | v:add(s, q:get_value(s, a)) 56 | end 57 | return v 58 | end 59 | -------------------------------------------------------------------------------- /doc/mdp.md: -------------------------------------------------------------------------------- 1 | ## Markov Decision Process (MDP) 2 | Markov Decision Processes (MDPs) are at the heart of the RL algorithms 3 | implemented. Here, they are represented as a class. The definition of the MDP 4 | class will depend on the particular problem. 5 | 6 | The key property of MDPs is that they are memoryless. The state of an MDP should 7 | be enough to determine what happens next. 8 | 9 | ### MDP Config 10 | Most other classes will require a `MdpConfig` instance instead of a `Mdp` 11 | instance. `MdpConfig` is a wrapper data structure that contains an MDP and 12 | configuration, such as the discount factor. A common pattern is the following: 13 | 14 | ```lua 15 | local mdp = TestMdp() 16 | local discount_factor = 0.9 17 | 18 | local mdp_config = MdpConfig(mdp, discount_factor) 19 | --- use mdp_config for future calls 20 | ``` 21 | 22 | ### MdpSampler 23 | The MdpSampler is a wrapper around an Mdp that provides some convenience 24 | methods for sampling the MDP, namely: 25 | 26 | * `[number] sample_reward(policy)` 27 | * `[episode] get_episode(policy)` 28 | 29 | An episode is a table of {state, action, discounted return, reward}, indexed by 30 | time. Time starts at 1 (going along with Lua conventions). 31 | 32 | 33 | ### Creating Your Own MDP 34 | To create an MDP, extend the base MDP class using torch: 35 | 36 | ```lua 37 | require 'Mdp' 38 | local MyMdp, parent = torch.class('MyMdp', 'Mdp') 39 | 40 | function MyMdp:__init(arg1) 41 | parent.__init(self) 42 | end 43 | ``` 44 | 45 | The main functions that an MDP needs to implement are 46 | 47 | * `[next_state, reward] step(state, action)` Note that state should capture 48 | everything needed to compute the next state and reward, given an action.
49 | * `[state] get_start_state()` 50 | * `[boolean] is_terminal(state)` 51 | 52 | Check out [Mdp.lua](../Mdp.lua) for detail on other the functions that you may 53 | want to implement. See [Blackjack.lua](../BlackJack.lua) for an example. 54 | 55 | -------------------------------------------------------------------------------- /test/unittest_LinSarsa.lua: -------------------------------------------------------------------------------- 1 | local tester = torch.Tester() 2 | 3 | local discount_factor = 0.95 4 | local mdp = rl.TestMdp() 5 | local mdp_config = rl.MdpConfig(mdp, discount_factor) 6 | local fe = rl.TestSAFE() 7 | 8 | local TestLinSarsa = {} 9 | 10 | function TestLinSarsa.test_update_eligibility_one_step() 11 | local lambda = 1 12 | local eps = 0.032 13 | local explorer = rl.ConstExplorer(eps) 14 | local step_size = 0.05 15 | local sarsa = rl.LinSarsa(mdp_config, lambda, explorer, fe, step_size) 16 | 17 | local s = 2 18 | local a = 1 19 | sarsa:update_eligibility(s, a) 20 | 21 | local eligibility_expected = rl.QLin(mdp, fe) 22 | eligibility_expected.weights[1] = s + a 23 | eligibility_expected.weights[2] = s - a 24 | 25 | tester:assert(sarsa.eligibility == eligibility_expected) 26 | end 27 | 28 | function TestLinSarsa.test_update_eligibility_many_steps() 29 | local lambda = 0.5 30 | local eps = 0.032 31 | local explorer = rl.ConstExplorer(eps) 32 | local step_size = 0.05 33 | local sarsa = rl.LinSarsa(mdp_config, lambda, explorer, fe, step_size) 34 | 35 | local s = 2 36 | local a = 1 37 | sarsa:update_eligibility(s, a) 38 | 39 | local eligibility_expected = rl.QLin(mdp, fe) 40 | eligibility_expected.weights = eligibility_expected.weights 41 | + torch.Tensor{s+a, s-a} 42 | 43 | local decay_factor = discount_factor * lambda 44 | eligibility_expected.weights = eligibility_expected.weights * decay_factor 45 | 46 | s = 2 47 | a = 2 48 | sarsa:update_eligibility(s, a) 49 | eligibility_expected.weights = 50 | eligibility_expected.weights + torch.Tensor{s+a, s-a} 51 | 52 | eligibility_expected.weights = eligibility_expected.weights * decay_factor 53 | 54 | s = 2 55 | a = 1 56 | sarsa:update_eligibility(s, a) 57 | eligibility_expected.weights = 58 | eligibility_expected.weights + torch.Tensor{s+a, s-a} 59 | 60 | tester:assert(sarsa.eligibility == eligibility_expected) 61 | end 62 | 63 | tester:add(TestLinSarsa) 64 | 65 | tester:run() 66 | 67 | -------------------------------------------------------------------------------- /Sarsa.lua: -------------------------------------------------------------------------------- 1 | -- Abstract class for implementing SARSA. 2 | -- See end of file for functions that must be implemented. 
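--
-- At each step of an episode the TD error is computed as
--   td_error = r + discount_factor * Q(s', a') - Q(s, a)
-- (just r - Q(s, a) when s' is terminal), and the subclass decides how the
-- eligibility trace and Q are updated from it. For the tabular case
-- (TableSarsa.lua) this is roughly the textbook Sarsa(lambda) update:
--   e <- discount_factor * lambda * e    (decay all traces)
--   e(s, a) <- e(s, a) + 1
--   Q <- Q + step_size * td_error * e    (for every state-action pair)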
3 | local Sarsa, parent = torch.class('rl.Sarsa', 'rl.Control') 4 | 5 | function Sarsa:__init(mdp_config, lambda) 6 | parent.__init(self, mdp_config) 7 | self.lambda = lambda 8 | self.q = nil 9 | self.actions = self.mdp:get_all_actions() 10 | self.discount_factor = mdp_config:get_discount_factor() 11 | end 12 | 13 | function Sarsa:run_episode(s, a) 14 | self:reset_eligibility() 15 | -- Can't use sampler because we're updating policy at each step 16 | while not self.mdp:is_terminal(s) do 17 | local s_new, r = self.mdp:step(s, a) 18 | local td_error, a_new = nil, nil 19 | if s_new == nil then 20 | td_error = r - self.q:get_value(s, a) 21 | else 22 | a_new = self.policy:get_action(s_new) 23 | td_error = r + self.discount_factor*self.q:get_value(s_new, a_new) 24 | - self.q:get_value(s, a) 25 | end 26 | self:update_eligibility(s, a) 27 | self:td_update(td_error) 28 | self:update_policy() 29 | 30 | s = s_new 31 | a = a_new 32 | end 33 | end 34 | 35 | function Sarsa:improve_policy() 36 | local s = self.mdp:get_start_state() 37 | local a = self.policy:get_action(s) 38 | self:run_episode(s, a) 39 | end 40 | 41 | function Sarsa:get_q() 42 | return self.q 43 | end 44 | 45 | -- Return an instance of a Q class. 46 | function Sarsa:get_new_q() 47 | error('Must implement get_new_q') 48 | end 49 | 50 | -- Clear self.eligibility for a new episode 51 | function Sarsa:reset_eligibility() 52 | error('Must implement reset_eligibility') 53 | end 54 | 55 | -- Update self.eligibility after visiting state s and taking action a 56 | function Sarsa:update_eligibility(s, a) 57 | error('Must implement update_eligibility') 58 | end 59 | 60 | -- Implement the TD update rule, given a TD error. 61 | function Sarsa:td_update(td_error) 62 | error('Must implement td_error') 63 | end 64 | 65 | -- Update self.policy. 
66 | function Sarsa:update_policy() 67 | error('Must implement update_policy') 68 | end 69 | -------------------------------------------------------------------------------- /test/unittest_EpisodeBuilder.lua: -------------------------------------------------------------------------------- 1 | require 'rl' 2 | local tester = torch.Tester() 3 | 4 | local TestEpisodeBuilder = {} 5 | function are_discounted_return_good( 6 | discount_factor, 7 | rewards, 8 | expected_discounted_returns) 9 | local builder = rl.EpisodeBuilder(discount_factor) 10 | local state = 5 11 | local action = 9 12 | 13 | for _, r in pairs(rewards) do 14 | builder:add_state_action_reward_step(state, action, r) 15 | end 16 | 17 | local expected = {} 18 | for i, expected_discounted_return in pairs(expected_discounted_returns) do 19 | expected[i] = { 20 | state = state, 21 | action = action, 22 | discounted_return = expected_discounted_return, 23 | reward = rewards[i] 24 | } 25 | end 26 | 27 | return rl.util.deepcompare(expected, builder:get_episode()) 28 | end 29 | function TestEpisodeBuilder.test_gamma_one() 30 | local discount_factor = 1 31 | local rewards = {1, 1, 1, 1} 32 | local expected_discounted_returns = {4, 3, 2, 1} 33 | tester:assert(are_discounted_return_good(discount_factor, 34 | rewards, 35 | expected_discounted_returns)) 36 | end 37 | 38 | function TestEpisodeBuilder.test_gamma_zero() 39 | local discount_factor = 0 40 | local rewards = {1, 1, 1, 1} 41 | local expected_discounted_returns = {1, 1, 1, 1} 42 | tester:assert(are_discounted_return_good(discount_factor, 43 | rewards, 44 | expected_discounted_returns)) 45 | end 46 | 47 | function TestEpisodeBuilder.test_gamma_fraction() 48 | local discount_factor = 0.5 49 | local rewards = {1, 1, 1, 1} 50 | local expected_discounted_returns = { 51 | 1 + 0.5 * (1 + 0.5 * (1+0.5)), 52 | 1 + 0.5 * (1 + 0.5), 53 | 1 + 0.5, 54 | 1} 55 | tester:assert(are_discounted_return_good(discount_factor, 56 | rewards, 57 | expected_discounted_returns)) 58 | end 59 | 60 | tester:add(TestEpisodeBuilder) 61 | 62 | tester:run() 63 | -------------------------------------------------------------------------------- /doc/sarsa.md: -------------------------------------------------------------------------------- 1 | ## Sarsa-lambda 2 | See [Sutton and 3 | Barto](https://webdocs.cs.ualberta.ca/~sutton/book/ebook/node77.html) for 4 | an explanation of the algorithm. One major difference is that we use discounted 5 | rewards. 6 | 7 | The main step that is abstracted away is how Q(s, a) is updated given the TD 8 | error. The underlying structure of Q (e.g. hash table vs. function approximator) 9 | will determine how it is handled. The eligibility update will also change based 10 | on the structure. 11 | 12 | Three versions are implemented: 13 | * TableSarsa - Use lookup tables to store Q 14 | * LinSarsa - Use a linear dot product of features to approximate Q 15 | * NNSarsa - Use a neural network to approximate Q 16 | 17 | ### Linear Approximation 18 | Documentation is a TODO - see LinSarsa.lua for now. 19 | 20 | ### Neural Network Approximation 21 | Documentation is a TODO - see NNSarsa.lua for now. 22 | 23 | ### Sarsa-lambda Analysis Scripts 24 | Below are more interesting scripts that compare how the Sarsa-lambda algorithms perform 25 | relative to Monte Carlo (MC) Control. 26 | * `analyze_table_sarsa.lua` 27 | * `analyze_lin_sarsa.lua` 28 | * `analyze_nn_sarsa.lua` 29 | 30 | MC Control is used as a baseline because it gives an unbiased estimate of the 31 | true Q (state-action value) function.
In each of the scripts, two plots get 32 | generated: (1) root mean square (RMS) error of the estimated Q function vs 33 | lambda. (2) RMS error of the estimated Q function vs # iterations for lambda = 0 34 | and lambda = 1. 35 | 36 | To save time, you can generate the Q function from MC Control, save 37 | it, and then load it back up in the above scripts. Generate an MC Q file with 38 | 39 | `$ th generate_q_mc.lua -saveqto .dat` 40 | 41 | and use this file when running the above scripts with the following. 42 | 43 | `$ th analyze_table_sarsa.lua -loadqfrom .dat` 44 | 45 | Run these scripts with the -h option for more help. 46 | 47 | ### Example Plots 48 | 49 | For `analyze_table_sarsa` with the number of iterations set to 10^5, you get the 50 | following plots: 51 | 52 | ![RMS vs Lambda for TableSarsa](../images/table_sarsa_rms_vs_lambda.png "RMS vs 53 | Lambda for TableSarsa") 54 | 55 | ![RMS vs Episode for TableSarsa, lambda = 0 and 56 | 1](../images/table_sarsa_rms_vs_iteration.png "RMS vs Iteration for TableSarsa, 57 | lambda = 0 and 1") 58 | -------------------------------------------------------------------------------- /TableSarsa.lua: -------------------------------------------------------------------------------- 1 | -- Implement SARSA algorithm using a lookup table for on-line 2 | -- policy control 3 | local TableSarsa, parent = torch.class('rl.TableSarsa', 'rl.Sarsa') 4 | function TableSarsa:__init(mdp_config, lambda) 5 | parent.__init(self, mdp_config, lambda) 6 | self.Ns = rl.VHash(self.mdp) 7 | self.Nsa = rl.QHash(self.mdp) 8 | self.q = rl.QHash(self.mdp) 9 | self.eligibility = rl.QHash(self.mdp) 10 | end 11 | 12 | function TableSarsa:get_new_q() 13 | return rl.QHash(self.mdp) 14 | end 15 | 16 | function TableSarsa:reset_eligibility() 17 | self.eligibility = rl.QHash(self.mdp) 18 | end 19 | 20 | function TableSarsa:update_eligibility(s, a) 21 | for _, state in pairs(self.mdp:get_all_states()) do 22 | for _, action in pairs(self.mdp:get_all_actions()) do 23 | self.eligibility:mult( 24 | state, 25 | action, 26 | self.discount_factor*self.lambda) 27 | end 28 | end 29 | self.eligibility:add(s, a, 1) 30 | self.Ns:add(s, 1) 31 | self.Nsa:add(s, a, 1) 32 | end 33 | 34 | local function get_step_size(self, state, action) 35 | local value = self.Nsa:get_value(state, action) 36 | if value == 0 then 37 | return value 38 | end 39 | return 1.
/ value 40 | end 41 | 42 | function TableSarsa:td_update(td_error) 43 | for _, state in pairs(self.mdp:get_all_states()) do 44 | for _, action in pairs(self.mdp:get_all_actions()) do 45 | local step_size = get_step_size(self, state, action) 46 | local eligibility = self.eligibility:get_value(state, action) 47 | self.q:add( 48 | state, 49 | action, 50 | step_size * td_error * eligibility) 51 | end 52 | end 53 | end 54 | 55 | function TableSarsa:update_policy() 56 | self.explorer = rl.DecayTableExplorer( 57 | rl.MONTECARLOCONTROL_DEFAULT_N0, 58 | self.Ns) 59 | self.policy = rl.GreedyPolicy( 60 | self.q, 61 | self.explorer, 62 | self.actions 63 | ) 64 | end 65 | 66 | function TableSarsa:__eq(other) 67 | return torch.typename(self) == torch.typename(other) 68 | and self.Ns == other.Ns 69 | and self.Nsa == other.Nsa 70 | and self.q == other.q 71 | and self.eligibility == other.eligibility 72 | end 73 | -------------------------------------------------------------------------------- /test/unittest_MonteCarloControl.lua: -------------------------------------------------------------------------------- 1 | require 'rl' 2 | 3 | math.randomseed(os.time()) 4 | local tester = torch.Tester() 5 | 6 | local TestMonteCarloControl = {} 7 | function TestMonteCarloControl.test_evalute_policy() 8 | local mdp = rl.TestMdp() 9 | local discount_factor = 1 10 | -- With this policy, the episode will be: 11 | -- Step 1 12 | -- state: 1 13 | -- action: 1 14 | -- reward: -1 15 | -- Gt: -2 16 | -- 17 | -- Step 2 18 | -- state 2: 2 19 | -- action 1: 1 20 | -- reward 1: -1 21 | -- Gt 1: -1 22 | local policy = rl.TestPolicy(1) 23 | local config = rl.MdpConfig(mdp, discount_factor) 24 | local mcc = rl.MonteCarloControl(config) 25 | mcc:set_policy(policy) 26 | mcc:evaluate_policy() 27 | 28 | local expected = rl.MonteCarloControl(config) 29 | local q = rl.QHash(mdp) 30 | q:add(1, 1, -2) 31 | q:add(2, 1, -1) 32 | local Ns = rl.VHash(mdp) 33 | Ns:add(1, 1) 34 | Ns:add(2, 1) 35 | local Nsa = rl.QHash(mdp) 36 | Nsa:add(1, 1, 1) 37 | Nsa:add(2, 1, 1) 38 | local N0 = rl.MONTECARLOCONTROL_DEFAULT_N0 39 | expected.q = q 40 | expected.Ns = Ns 41 | expected.Nsa = Nsa 42 | expected.N0 = N0 43 | 44 | tester:assert(mcc == expected) 45 | end 46 | 47 | function TestMonteCarloControl.test_optimize_policy() 48 | local mdp = rl.TestMdp() 49 | local discount_factor = 1 50 | local config = rl.MdpConfig(mdp, discount_factor) 51 | local mcc = rl.MonteCarloControl(config) 52 | local q = rl.QHash(mdp) 53 | q:add(1, 1, 100) -- make action "1" be the best action 54 | q:add(2, 1, 100) 55 | local Ns = rl.VHash(mdp) 56 | local n_times_states_visited = 10 -- make other actions seem really explored 57 | Ns:add(1, n_times_states_visited) 58 | Ns:add(2, n_times_states_visited) 59 | local N0 = 1 60 | 61 | mcc.q = q 62 | mcc.Ns = Ns 63 | mcc.N0 = N0 64 | mcc:optimize_policy() 65 | 66 | local policy = mcc:get_policy() 67 | 68 | local expected_eps = N0 / (N0 + n_times_states_visited) 69 | local expected_probabilities = { 70 | 1 - 2* expected_eps / 3, 71 | expected_eps / 3, 72 | expected_eps / 3 73 | } 74 | 75 | tester:assert(rl.util.are_testmdp_policy_probabilities_good( 76 | policy, 77 | expected_probabilities)) 78 | end 79 | 80 | tester:add(TestMonteCarloControl) 81 | 82 | tester:run() 83 | -------------------------------------------------------------------------------- /rocks/rl-0.1-1.rockspec: -------------------------------------------------------------------------------- 1 | package = "rl" 2 | version = "0.1-1" 3 | 4 | source = { 5 | url = 
"git://github.com/vpong/torch-rl.git", 6 | tag = "v0.1" 7 | } 8 | 9 | description = { 10 | summary = "A package for basic reinforcement learning algorithms.", 11 | detailed = [[ 12 | A package for basic reinforcement learning algorithms 13 | ]], 14 | homepage = "https://github.com/vpong/torch-rl" 15 | } 16 | 17 | dependencies = { 18 | "lua ~> 5.1", 19 | "torch >= 7.0" 20 | } 21 | 22 | build = { 23 | type = "builtin", 24 | modules = { 25 | AllActionsEqualPolicy = "AllActionsEqualPolicy.lua", 26 | constants = "constants.lua", 27 | ConstExplorer = "ConstExplorer.lua", 28 | ControlFactory = "ControlFactory.lua", 29 | Control = "Control.lua", 30 | DecayTableExplorer = "DecayTableExplorer.lua", 31 | EpisodeBuilder = "EpisodeBuilder.lua", 32 | Evaluator = "Evaluator.lua", 33 | Explorer = "Explorer.lua", 34 | GreedyPolicy = "GreedyPolicy.lua", 35 | LinSarsaFactory = "LinSarsaFactory.lua", 36 | LinSarsa = "LinSarsa.lua", 37 | MdpConfig = "MdpConfig.lua", 38 | Mdp = "Mdp.lua", 39 | MdpSampler = "MdpSampler.lua", 40 | MonteCarloControl = "MonteCarloControl.lua", 41 | NNSarsaFactory = "NNSarsaFactory.lua", 42 | NNSarsa = "NNSarsa.lua", 43 | Policy = "Policy.lua", 44 | QApprox = "QApprox.lua", 45 | QFunc = "QFunc.lua", 46 | QHash = "QHash.lua", 47 | QLin = "QLin.lua", 48 | QNN = "QNN.lua", 49 | QVAnalyzer = "QVAnalyzer.lua", 50 | rl = "rl.lua", 51 | SAFeatureExtractor = "SAFeatureExtractor.lua", 52 | SarsaAnalyzer = "SarsaAnalyzer.lua", 53 | SarsaFactory = "SarsaFactory.lua", 54 | Sarsa = "Sarsa.lua", 55 | TableSarsaFactory = "TableSarsaFactory.lua", 56 | TableSarsa = "TableSarsa.lua", 57 | TestMdp = "TestMdp.lua", 58 | TestMdpQVAnalyzer = "TestMdpQVAnalyzer.lua", 59 | TestPolicy = "TestPolicy.lua", 60 | TestSAFE = "TestSAFE.lua", 61 | ValueIteration = "ValueIteration.lua", 62 | VFunc = "VFunc.lua", 63 | VHash = "VHash.lua", 64 | ["util.io_util"] = "util/io_util.lua", 65 | ["util.mdputil"] = "util/mdputil.lua", 66 | ["util.tensorutil"] = "util/tensorutil.lua", 67 | ["util.util_for_unittests"] = "util/util_for_unittests.lua", 68 | ["util.util"] = "util/util.lua" 69 | }, 70 | copy_directories = { "doc" , "test"} 71 | } 72 | -------------------------------------------------------------------------------- /rl-0.2-5.rockspec: -------------------------------------------------------------------------------- 1 | package = "rl" 2 | version = "0.2-5" 3 | 4 | source = { 5 | url = "git://github.com/vpong/torch-rl.git", 6 | tag = "v0.2.3" 7 | } 8 | 9 | description = { 10 | summary = "A package for basic reinforcement learning algorithms.", 11 | detailed = [[ 12 | A package for basic reinforcement learning algorithms 13 | ]], 14 | homepage = "https://github.com/vpong/torch-rl" 15 | } 16 | 17 | dependencies = { 18 | "lua ~> 5.1", 19 | "torch >= 7.0" 20 | } 21 | 22 | build = { 23 | type = "builtin", 24 | modules = { 25 | ["rl.AllActionsEqualPolicy"] = "AllActionsEqualPolicy.lua", 26 | ["rl.ConstExplorer"] = "ConstExplorer.lua", 27 | ["rl.ControlFactory"] = "ControlFactory.lua", 28 | ["rl.Control"] = "Control.lua", 29 | ["rl.DecayTableExplorer"] = "DecayTableExplorer.lua", 30 | ["rl.EpisodeBuilder"] = "EpisodeBuilder.lua", 31 | ["rl.Evaluator"] = "Evaluator.lua", 32 | ["rl.Explorer"] = "Explorer.lua", 33 | ["rl.GreedyPolicy"] = "GreedyPolicy.lua", 34 | ["rl.LinSarsaFactory"] = "LinSarsaFactory.lua", 35 | ["rl.LinSarsa"] = "LinSarsa.lua", 36 | ["rl.MdpConfig"] = "MdpConfig.lua", 37 | ["rl.Mdp"] = "Mdp.lua", 38 | ["rl.MdpSampler"] = "MdpSampler.lua", 39 | ["rl.MonteCarloControl"] = "MonteCarloControl.lua", 40 | 
["rl.NNSarsaFactory"] = "NNSarsaFactory.lua", 41 | ["rl.NNSarsa"] = "NNSarsa.lua", 42 | ["rl.Policy"] = "Policy.lua", 43 | ["rl.QApprox"] = "QApprox.lua", 44 | ["rl.QFunc"] = "QFunc.lua", 45 | ["rl.QHash"] = "QHash.lua", 46 | ["rl.QLin"] = "QLin.lua", 47 | ["rl.QNN"] = "QNN.lua", 48 | ["rl.QVAnalyzer"] = "QVAnalyzer.lua", 49 | ["rl"] = "rl.lua", 50 | ["rl.rl_constants"] = "rl_constants.lua", 51 | ["rl.SAFeatureExtractor"] = "SAFeatureExtractor.lua", 52 | ["rl.SarsaAnalyzer"] = "SarsaAnalyzer.lua", 53 | ["rl.SarsaFactory"] = "SarsaFactory.lua", 54 | ["rl.Sarsa"] = "Sarsa.lua", 55 | ["rl.TableSarsaFactory"] = "TableSarsaFactory.lua", 56 | ["rl.TableSarsa"] = "TableSarsa.lua", 57 | ["rl.TestMdp"] = "TestMdp.lua", 58 | ["rl.TestMdpQVAnalyzer"] = "TestMdpQVAnalyzer.lua", 59 | ["rl.TestPolicy"] = "TestPolicy.lua", 60 | ["rl.TestSAFE"] = "TestSAFE.lua", 61 | ["rl.ValueIteration"] = "ValueIteration.lua", 62 | ["rl.VFunc"] = "VFunc.lua", 63 | ["rl.VHash"] = "VHash.lua", 64 | ["rl.util.io_util"] = "util/io_util.lua", 65 | ["rl.util.mdputil"] = "util/mdputil.lua", 66 | ["rl.util.tensorutil"] = "util/tensorutil.lua", 67 | ["rl.util.util_for_unittests"] = "util/util_for_unittests.lua", 68 | ["rl.util.util"] = "util/util.lua" 69 | }, 70 | copy_directories = { "doc" , "test"} 71 | } 72 | -------------------------------------------------------------------------------- /QNN.lua: -------------------------------------------------------------------------------- 1 | local nn = require 'nn' 2 | local nngraph = require 'nngraph' 3 | local dpnn = require 'dpnn' 4 | 5 | -- Implementation of a state-action value function approx using a neural network 6 | local QNN, parent = torch.class('rl.QNN', 'rl.QApprox') 7 | 8 | local function get_module(self) 9 | local x = nn.Identity()() 10 | local l1 = nn.Linear(self.n_features, 1)(x) 11 | return nn.gModule({x}, {l1}) 12 | end 13 | 14 | function QNN:is_linear() 15 | return true 16 | end 17 | 18 | function QNN:__init(mdp, feature_extractor) 19 | parent.__init(self, mdp, feature_extractor) 20 | self.n_features = feature_extractor:get_sa_num_features() 21 | self.module = get_module(self) 22 | self.is_first_update = true 23 | end 24 | 25 | -- This took forever to figure out. See 26 | -- https://github.com/Element-Research/dpnn/blob/165ce5ff37d0bb77c207e82f5423ade08593d020/Module.lua#L488 27 | -- for detail. 28 | local function reset_momentum(net) 29 | net.momGradParams = nil 30 | if net.modules then 31 | for _, child in pairs(net.modules) do 32 | reset_momentum(child) 33 | end 34 | end 35 | end 36 | 37 | function QNN:reset_momentum() 38 | self.is_first_update = true 39 | self.module:zeroGradParameters() 40 | reset_momentum(self.module) 41 | end 42 | 43 | function QNN:get_value(s, a) 44 | local input = self.feature_extractor:get_sa_features(s, a) 45 | return self.module:forward(input)[1] 46 | end 47 | 48 | -- For now, we're ignoring eligibility. This means that the update rule to the 49 | -- parameters W of the network is: 50 | -- 51 | -- W <- W + step_size * td_error * dQ(s,a)/dW 52 | -- 53 | -- We can force the network to update this way by recognizing that the "loss" 54 | -- is literally just the output of the network. This makes it so that 55 | -- 56 | -- dLoss/dW = dQ(s, a)/dW 57 | -- 58 | -- So, the network will update correctly if we just tell it that the output is 59 | -- the loss, i.e. set grad_out = 1. 
60 | -- 61 | -- For more detail on the update rule, see 62 | -- https://webdocs.cs.ualberta.ca/~sutton/book/ebook/node89.html 63 | -- 64 | -- For more detail on how nn works, see 65 | -- https://github.com/torch/nn/blob/master/doc/module.md 66 | function QNN:backward(s, a, learning_rate, momentum) 67 | -- forward to make sure input is set correctly 68 | local input = self.feature_extractor:get_sa_features(s, a) 69 | local output = self.module:forward(input) 70 | -- backward 71 | local grad_out = torch.ones(#output) 72 | self.module:zeroGradParameters() 73 | self.module:backward(input, grad_out) 74 | -- update 75 | -- This check is necessary because of the way updateGradParameters is 76 | -- implemented. This makes sure that the first update doesn't apply 77 | -- momentum to itself. However, future updates should rely on the momentum of 78 | -- previous calls. 79 | -- 80 | -- Also, we can't just put the call to updateGradParameters() before the 81 | -- call to backward() because the zeroGradParameters() call messes it up. 82 | if self.is_first_update then 83 | self.is_first_update = false 84 | else 85 | self.module:updateGradParameters(momentum, 0, false) -- momentum (dpnn) 86 | end 87 | self.module:updateParameters(-learning_rate) -- W = W - rate * dL/dW 88 | end 89 | 90 | QNN.__eq = parent.__eq -- force inheritance of this 91 | -------------------------------------------------------------------------------- /doc/index.md: -------------------------------------------------------------------------------- 1 | # torch-rl 2 | This is a Torch 7 package that implements a few reinforcement learning 3 | algorithms. So far, we've only implemented Q-learning. 4 | 5 | This documentation is intended mostly to give a high-level idea of what each 6 | abstract class does. We start with a summary of the important files, and then 7 | give a slightly more detailed description afterwards. For more detail, see the 8 | source code. For examples on how to use the functions, see the unit tests. 9 | 10 | ## Summary of files 11 | Files that start with upper-case letters are classes. Every other file is a script, 12 | except for the constants and util file. This gives a summary of the most 13 | important files. 14 | 15 | ### Interfaces/abstract classes 16 | * `Control.lua` - Represents an algorithm that improves a policy. 17 | * `Mdp.lua` - A Markov Decision Process that represents an environment. 18 | * `Policy.lua` - A way of deciding what action to do given a state. 19 | * `Sarsa.lua` - A specific Control algorithm. Technically, it's Sarsa-lambda. 20 | * `SAFeatureExtractor.lua` - Represents a way of extracting features from a given 21 | [S]tate-[A]ction pair. 22 | * `ControlFactory.lua` - Used to create new Control instances. 23 | * `Explorer.lua` - Used to get the epsilon value for EpsilonGreedy. 24 | 25 | ### Concrete classes 26 | * `Evaluator.lua` - Used to measure the performance of a policy. 27 | * `MdpConfig.lua` - A way of configuring an Mdp. 28 | * `MdpSampler.lua` - A useful wrapper around an Mdp. 29 | * `QVAnalyzer.lua` - Used to get measurements out of Control algorithms. 30 | 31 | ### Specific implementations 32 | * `EpsilonGreedyPolicy.lua` - Implements epsilon greedy policy. 33 | * `DecayTableExplorer.lua` - A way of decaying epsilon for epsilon greedy policy 34 | to ensure convergence.
35 | * `NNSarsa.lua` - Implements Sarsa-lambda using neural networks as a function 36 | approximator 37 | * `LinSarsa.lua` - Implements Sarsa-lambda using linear weighting as a function 38 | approximator 39 | * `TableSarsa.lua` - Implements Sarsa-lambda using a lookup table. 40 | * `MonteCarloControl.lua` - Implements Monte Carlo control. 41 | 42 | ### Test Files 43 | * `unittest_*.lua` - Unit tests. Can be run directly with `th unittest_*.lua`. 44 | * `run_rl_unittests.lua` - Run all unit tests related to this package. 45 | * `run_BlackJack_unittests.lua` - Run all unit tests related to Black Jack. 46 | * `run_all_unittests.lua` - Run all unit tests in this package. 47 | * `TestMdp.lua` - An MDP used for testing. 48 | * `TestPolicy.lua` - A policy for TestMdp used for testing. 49 | * `TestSAFE.lua` - A feature extractor used for testing. 50 | 51 | ## Read More 52 | * [MDP](doc/mdp.md) - Markov Decision Processes are the foundation of how 53 | reinforcement learning models the world. 54 | * [Policy](doc/policy.md) - Policies are mappings from state to action. 55 | * [Sarsa](doc/sarsa.md) - Read about the Sarsa-lambda algorithm and scripts that 56 | test them. 57 | * [Monte Carlo Control](doc/montecarlo.md) - Read about Monte Carlo Control and 58 | how to use it. 59 | * [Value Functions](doc/valuefunctions.md) - Value functions represent how 60 | valuable certains states and/or actions are. 61 | * [Black Jack](doc/blackjack.md) - An example MDP that is a simplified version 62 | of black jack. 63 | 64 | ## A note on Abstract Classes and private methods 65 | Torch doesn't implement interfaces nor abstract classes natively, but this 66 | packages tries to implement them by defining functions and raising an error if 67 | you try to implement it. (We'll call everything an abstract class 68 | just for simplicity.) 69 | 70 | Also, Torch classes don't provide a way of making private methods, so this is 71 | faked with the following: 72 | 73 | ```lua 74 | 75 | local function private_method(self, arg1, ...) 76 | 77 | end 78 | 79 | -- ... 80 | 81 | function Foo:bar() 82 | self:public_method(arg1, ...) -- Call public methods normally 83 | private_method(self, arg1, ...) 
-- Use this to call private methods 84 | end 85 | ``` 86 | -------------------------------------------------------------------------------- /test/unittest_NNSarsa.lua: -------------------------------------------------------------------------------- 1 | require 'rl' 2 | 3 | local tester = torch.Tester() 4 | 5 | local discount_factor = 0.95 6 | local mdp = rl.TestMdp() 7 | local mdp_config = rl.MdpConfig(mdp, discount_factor) 8 | local fe = rl.TestSAFE() 9 | local lambda = 1 10 | local eps = 0.032 11 | local explorer = rl.ConstExplorer(eps) 12 | local step_size = 0.05 13 | 14 | local function get_sarsa() 15 | return rl.NNSarsa(mdp_config, lambda, explorer, fe, step_size) 16 | end 17 | 18 | local TestNNSarsa = {} 19 | function TestNNSarsa.test_update_eligibility_one_step() 20 | local sarsa = get_sarsa() 21 | 22 | local s = 2 23 | local a = 1 24 | sarsa:update_eligibility(s, a) 25 | 26 | tester:assert(sarsa.last_state == s and sarsa.last_action == a) 27 | end 28 | 29 | function TestNNSarsa.test_td_update_once() 30 | local sarsa = get_sarsa() 31 | local expected_module = sarsa.q.module:clone() 32 | 33 | local s = 2 34 | local a = 1 35 | local td_error = -0.4 36 | sarsa:update_eligibility(s, a) 37 | sarsa:td_update(td_error) 38 | 39 | local input = fe:get_sa_features(s, a) 40 | expected_module:forward(input) 41 | local grad_out = torch.Tensor{1} 42 | expected_module:backward(input, grad_out) 43 | local momentum = lambda * discount_factor 44 | local learning_rate = step_size * td_error 45 | expected_module:updateParameters(-learning_rate) 46 | 47 | local expected_sarsa = rl.NNSarsa(mdp_config, lambda, explorer, fe, step_size) 48 | expected_sarsa.q.module = expected_module 49 | expected_sarsa.last_state = s 50 | expected_sarsa.last_action = a 51 | 52 | tester:assert(sarsa == expected_sarsa) 53 | end 54 | 55 | function TestNNSarsa.test_td_update_many_times() 56 | local sarsa = get_sarsa() 57 | local expected_module = sarsa.q.module:clone() 58 | 59 | local s = 2 60 | local a = 1 61 | local td_error = -0.4 62 | sarsa:update_eligibility(s, a) 63 | sarsa:td_update(td_error) 64 | 65 | local input = fe:get_sa_features(s, a) 66 | expected_module:forward(input) 67 | local grad_out = torch.Tensor{1} 68 | expected_module:backward(input, grad_out) 69 | local momentum = lambda * discount_factor 70 | local learning_rate = step_size * td_error 71 | expected_module:updateParameters(-learning_rate) 72 | 73 | s = 2 74 | a = 2 75 | td_error = -0.6 76 | sarsa:update_eligibility(s, a) 77 | sarsa:td_update(td_error) 78 | 79 | input = fe:get_sa_features(s, a) 80 | expected_module:forward(input) 81 | grad_out = torch.Tensor{1} 82 | expected_module:zeroGradParameters() 83 | expected_module:backward(input, grad_out) 84 | momentum = lambda * discount_factor 85 | expected_module:updateGradParameters(momentum, 0, false) 86 | local learning_rate = step_size * td_error 87 | expected_module:updateParameters(-learning_rate) 88 | 89 | local expected_sarsa = rl.NNSarsa(mdp_config, lambda, explorer, fe, step_size) 90 | expected_sarsa.q.module = expected_module 91 | expected_sarsa.last_state = s 92 | expected_sarsa.last_action = a 93 | 94 | tester:assert(sarsa == expected_sarsa) 95 | end 96 | 97 | function TestNNSarsa.test_reset_eligibility() 98 | local sarsa = get_sarsa() 99 | local expected_module = sarsa.q.module:clone() 100 | 101 | if not sarsa.q:is_linear() then 102 | return 103 | end 104 | 105 | local s = 2 106 | local a = 1 107 | local td_error = -0.4 108 | local old_value = sarsa.q:get_value(s, a) 109 | 
sarsa:update_eligibility(s, a) 110 | sarsa:td_update(td_error) 111 | 112 | local new_value1 = sarsa.q:get_value(s, a) 113 | local d_value_1 = new_value1 - old_value 114 | 115 | sarsa:reset_eligibility(s, a) 116 | sarsa:update_eligibility(s, a) 117 | sarsa:td_update(td_error) 118 | local new_value2 = sarsa.q:get_value(s, a) 119 | local d_value_2 = new_value2 - new_value1 120 | 121 | tester:assert(math.abs(d_value_1 - d_value_2) < rl.FLOAT_EPS) 122 | end 123 | 124 | tester:add(TestNNSarsa) 125 | 126 | tester:run() 127 | 128 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # torch-rl 2 | This is a Torch 7 package that implements a few reinforcement learning (RL) 3 | algorithms. So far, only Q-learning is implemented. 4 | 5 | ## Installation 6 | #### Dependencies 7 | 8 | 0. (Optional) 9 | [LuaRocks](https://github.com/keplerproject/luarocks/wiki/Download) 10 | - Highly recommended for anyone who wants to use Lua. Also, install this before 11 | installing Torch, as Torch has weird configurations for LuaRocks. 12 | 1. [Torch](http://torch.ch/docs/getting-started.html) 13 | 2. [Lua 5.1](http://www.lua.org/download.html) - Installing Torch automatically 14 | installs Lua 15 | 16 | #### LuaRocks - Automatic Installation (Recommended) 17 | 18 | 1. Install [LuaRocks](https://github.com/keplerproject/luarocks/wiki/Download). 19 | 2. From terminal, run 20 | ``` 21 | $ luarocks install rl 22 | ``` 23 | 24 | Error finding the module? Try 25 | ``` 26 | $ luarocks install --server=https://luarocks.org/ rl 27 | ``` 28 | 29 | #### LuaRocks - Manual Installation 30 | 1. `$ git clone git@github.com:vpong/torch-rl.git` 31 | 2. `$ luarocks make` 32 | 33 | #### Totally Manually 34 | ``` 35 | $ git clone git@github.com:vpong/torch-rl.git 36 | ``` 37 | Note that you'll basically have to copy the files into any project in which you want 38 | to use them. 39 | 40 | ## Reinforcement Learning Topics 41 | * [MDP](doc/mdp.md) - A Markov Decision Process (MDP) models the world. Read 42 | about useful MDP functions and [how to create your own 43 | MDP](doc/mdp.md#create_mdp). 44 | * [Policy](doc/policy.md) - Policies are mappings from state to action. 45 | * [Sarsa](doc/sarsa.md) - Read about the Sarsa-lambda algorithm and scripts that 46 | test them. 47 | * [Monte Carlo Control](doc/montecarlo.md) - Read about Monte Carlo Control and 48 | how to use it. 49 | * [Value Functions](doc/valuefunctions.md) - Value functions represent how 50 | valuable certain states and/or actions are. 51 | * [Black Jack](https://github.com/vpong/rl-example) - An example repository that 52 | shows how to use RL algorithms to learn to play (a simplified version of) 53 | black jack. 54 | 55 | ## Summary of files 56 | This gives a summary of the most important files. Files that start with upper-case 57 | letters are classes. For more detail, see the source code. For examples on how to 58 | use the functions, see the unit tests. 59 | 60 | #### Interfaces/abstract classes 61 | * `Control.lua` - Represents an algorithm that improves a policy. 62 | * `Mdp.lua` - A Markov Decision Process that represents an environment. 63 | * `Policy.lua` - A way of deciding what action to do given a state. 64 | * `Sarsa.lua` - A specific Control algorithm. Technically, it's Sarsa-lambda. 65 | * `SAFeatureExtractor.lua` - Represents a way of extracting features from a given 66 | [S]tate-[A]ction pair.
67 | * `ControlFactory.lua` - Used to create new Control instances 68 | * `Explorer.lua` - Used to get the epsilon value for EpsilonGreedy 69 | 70 | #### Concrete classes 71 | * `Evaluator.lua` - Used to measure the performance of a policy 72 | * `MdpConfig.lua` - A way of configuring an Mdp. 73 | * `MdpSampler.lua` - A useful wrapper around an Mdp. 74 | * `QVAnalyzer.lua` - Used to get measurements out of Control algorithms 75 | 76 | #### Specific implementations 77 | * `EpsilonGreedyPolicy.lua` - Implements epsilon greedy policy. 78 | * `DecayTableExplorer.lua` - A way of decaying epsilon for epsilon greedy policy 79 | to ensure convergence. 80 | * `NNSarsa.lua` - Implements Sarsa-lambda using neural networks as a function 81 | approximator 82 | * `LinSarsa.lua` - Implements Sarsa-lambda using linear weighting as a function 83 | approximator 84 | * `TableSarsa.lua` - Implements Sarsa-lambda using a lookup table. 85 | * `MonteCarloControl.lua` - Implements Monte Carlo control. 86 | 87 | #### Other Files 88 | * `util/*.lua` - Utility functions used throughout the project. 89 | * `test/unittest_*.lua` - Unit tests. Can be run individual tests with `th 90 | unittest_*.lua`. 91 | * `run_tests.lua` - Run all unit tests in `test/`directory. 92 | * `TestMdp.lua` - An MDP used for testing. 93 | * `TestPolicy.lua` - A policy for TestMdp used for testing. 94 | * `TestSAFE.lua` - A feature extractor used for testing. 95 | 96 | ## A note on Abstract Classes and private methods 97 | Torch doesn't implement interfaces nor abstract classes natively, but this 98 | packages tries to implement them by defining functions and raising an error if 99 | you try to implement it. (We'll call everything an abstract class 100 | just for simplicity.) 101 | 102 | Also, Torch classes don't provide a way of making private methods, so this is 103 | faked with the following: 104 | 105 | ```lua 106 | 107 | local function private_method(self, arg1, ...) 108 | ... 109 | end 110 | 111 | function Foo:bar() 112 | self:public_method(arg1, ...) -- Call public methods normally 113 | private_method(self, arg1, ...) 
-- Use this to call private methods 114 | end 115 | ``` 116 | -------------------------------------------------------------------------------- /test/unittest_QLin.lua: -------------------------------------------------------------------------------- 1 | require 'rl' 2 | local tester = torch.Tester() 3 | 4 | local TestQLin = {} 5 | 6 | local mdp = rl.TestMdp() 7 | function TestQLin.test_add_once() 8 | local fe = rl.TestSAFE() 9 | 10 | local q = rl.QLin(mdp, fe) 11 | local feature_dim = fe:get_sa_features_dim() 12 | q.weights = torch.zeros(feature_dim) 13 | local d_weights = torch.zeros(feature_dim) 14 | d_weights[1] = 1 15 | q:add(d_weights) 16 | 17 | local s = 1 18 | local a = 1 19 | tester:asserteq(q:get_value(s, a), 2, "Wrong state-action value.") 20 | tester:asserteq(q:get_best_action(s), 3, "Wrong best action.") 21 | 22 | local expected_q_table = { -- row = state, colm = action 23 | [1] = {1+1, 1+2, 1+3}, 24 | [2] = {2+1, 2+2, 2+3}, 25 | [3] = {3+1, 3+2, 3+3} 26 | } 27 | tester:assert(rl.util.do_qtable_qfunc_match(mdp, expected_q_table, q)) 28 | end 29 | 30 | function TestQLin.test_add_complex() 31 | local fe = rl.TestSAFE() 32 | 33 | local q = rl.QLin(mdp, fe) 34 | local feature_dim = fe:get_sa_features_dim() 35 | q.weights = torch.zeros(feature_dim) 36 | local d_weights = torch.zeros(feature_dim) 37 | local weight_1 = 3 38 | local weight_2 = 0.5 39 | d_weights[1] = weight_1 40 | d_weights[2] = weight_2 41 | q:add(d_weights) 42 | 43 | local s = 3 44 | local a = 1 45 | tester:asserteq(q:get_value(s, a), 12+1, "Wrong state-action value.") 46 | tester:asserteq(q:get_best_action(s), 3, "Wrong best action.") 47 | 48 | local expected_q_table = { -- row = state, colm = action 49 | [1] = { 50 | weight_1*(1+1) + weight_2*(1-1), 51 | weight_1*(1+2) + weight_2*(1-2), 52 | weight_1*(1+3) + weight_2*(1-3) 53 | }, 54 | [2] = { 55 | weight_1*(2+1) + weight_2*(2-1), 56 | weight_1*(2+2) + weight_2*(2-2), 57 | weight_1*(2+3) + weight_2*(2-3) 58 | }, 59 | [3] = { 60 | weight_1*(3+1) + weight_2*(3-1), 61 | weight_1*(3+2) + weight_2*(3-2), 62 | weight_1*(3+3) + weight_2*(3-3) 63 | } 64 | } 65 | tester:assert(rl.util.do_qtable_qfunc_match(mdp, expected_q_table, q)) 66 | end 67 | 68 | function TestQLin.test_add_and_multiply() 69 | local fe = rl.TestSAFE() 70 | 71 | local q = rl.QLin(mdp, fe) 72 | local feature_dim = fe:get_sa_features_dim() 73 | q.weights = torch.zeros(feature_dim) 74 | local d_weights = torch.zeros(feature_dim) 75 | local weight_1 = 3 76 | local weight_2 = 0.5 77 | d_weights[1] = weight_1 78 | d_weights[2] = weight_2 79 | q:add(d_weights) 80 | 81 | local factor = 0.9 82 | q:mult(factor) 83 | 84 | local s = 3 85 | local a = 1 86 | tester:asserteq( 87 | q:get_value(s, a), 88 | factor*(12+1), 89 | "Wrong state-action value.") 90 | tester:asserteq(q:get_best_action(s), 3, "Wrong best action.") 91 | 92 | local expected_q_table = { -- row = state, colm = action 93 | [1] = { 94 | factor * (weight_1*(1+1) + weight_2*(1-1)), 95 | factor * (weight_1*(1+2) + weight_2*(1-2)), 96 | factor * (weight_1*(1+3) + weight_2*(1-3)) 97 | }, 98 | [2] = { 99 | factor * (weight_1*(2+1) + weight_2*(2-1)), 100 | factor * (weight_1*(2+2) + weight_2*(2-2)), 101 | factor * (weight_1*(2+3) + weight_2*(2-3)) 102 | }, 103 | [3] = { 104 | factor * (weight_1*(3+1) + weight_2*(3-1)), 105 | factor * (weight_1*(3+2) + weight_2*(3-2)), 106 | factor * (weight_1*(3+3) + weight_2*(3-3)) 107 | } 108 | } 109 | tester:assert(rl.util.do_qtable_qfunc_match(mdp, expected_q_table, q)) 110 | end 111 | 112 | function TestQLin.test_clear() 
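-- After clear(), re-adding d_weights = {1, 0} should reproduce exactly the
-- values checked in test_add_once above.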
113 | local fe = rl.TestSAFE() 114 | 115 | local q = rl.QLin(mdp, fe) 116 | local feature_dim = fe:get_sa_features_dim() 117 | q.weights = torch.zeros(feature_dim) 118 | local d_weights = torch.zeros(feature_dim) 119 | local weight_1 = 3 120 | local weight_2 = 0.5 121 | d_weights[1] = weight_1 122 | d_weights[2] = weight_2 123 | q:add(d_weights) 124 | 125 | local factor = 0.9 126 | q:mult(factor) 127 | 128 | -- Now it should be the same as test_add_once 129 | q:clear() 130 | 131 | d_weights[1] = 1 132 | d_weights[2] = 0 133 | q:add(d_weights) 134 | 135 | local s = 1 136 | local a = 1 137 | tester:asserteq(q:get_value(s, a), 2, "Wrong state-action value.") 138 | tester:asserteq(q:get_best_action(s), 3, "Wrong best action.") 139 | 140 | local expected_q_table = { -- row = state, colm = action 141 | [1] = {1+1, 1+2, 1+3}, 142 | [2] = {2+1, 2+2, 2+3}, 143 | [3] = {3+1, 3+2, 3+3} 144 | } 145 | tester:assert(rl.util.do_qtable_qfunc_match(mdp, expected_q_table, q)) 146 | end 147 | tester:add(TestQLin) 148 | 149 | tester:run() 150 | -------------------------------------------------------------------------------- /test/unittest_QNN.lua: -------------------------------------------------------------------------------- 1 | require 'rl' 2 | local tester = torch.Tester() 3 | 4 | local TestQNN = {} 5 | 6 | local mdp = rl.TestMdp() 7 | local fe = rl.TestSAFE() 8 | 9 | function TestQNN.test_backward() 10 | local q = rl.QNN(mdp, fe) 11 | local module = q.module:clone() 12 | 13 | local s = 2 14 | local a = 4 15 | 16 | local step_size = 0.5 17 | local lambda = 0 18 | local discount_factor = 0.5 19 | local td_error = 1 20 | local learning_rate = step_size * td_error 21 | local momentum = lambda * discount_factor 22 | q:backward(s, a, learning_rate, momentum) 23 | local new_params = q.module:parameters() 24 | 25 | -- This is kinda cheating since this is basically the same code as the 26 | -- function, but I also don't see a better way to do this. 
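-- (The cloned module replays the forward/backward/update sequence by hand;
-- its resulting parameters should match those of q.module after q:backward().)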
27 | local input = fe:get_sa_features(s, a) 28 | local grad_out = torch.Tensor{1} 29 | local _ = module:forward(input) 30 | module:backward(input, grad_out) 31 | module:updateGradParameters(momentum, 0, false) 32 | module:updateParameters(-step_size*td_error) 33 | local expected_params = module:parameters() 34 | 35 | tester:assertTensorEq(expected_params[1], new_params[1], 0) 36 | tester:assertTensorEq(expected_params[2], new_params[2], 0) 37 | end 38 | 39 | function TestQNN.test_backward_no_momentum() 40 | local q = rl.QNN(mdp, fe) 41 | if not q:is_linear() then 42 | return 43 | end 44 | 45 | local module = q.module:clone() 46 | 47 | local s = 2 48 | local a = 4 49 | local old_value = q:get_value(s, a) 50 | 51 | local step_size = 0.5 52 | local td_error = 1 53 | local learning_rate = step_size * td_error 54 | local momentum = 0 55 | 56 | q:backward(s, a, learning_rate, momentum) 57 | local new_value1 = q:get_value(s, a) 58 | local d_value_1 = new_value1 - old_value 59 | 60 | q:backward(s, a, learning_rate, momentum) 61 | local new_value2 = q:get_value(s, a) 62 | local d_value_2 = new_value2 - new_value1 63 | 64 | tester:assert(math.abs(d_value_1 - d_value_2) < rl.FLOAT_EPS) 65 | end 66 | 67 | function TestQNN.test_backward_with_momentum() 68 | local q = rl.QNN(mdp, fe) 69 | if not q:is_linear() then 70 | return 71 | end 72 | local module = q.module:clone() 73 | 74 | local s = 2 75 | local a = 4 76 | local old_value = q:get_value(s, a) 77 | 78 | local step_size = 0.5 79 | local lambda = 1 80 | local discount_factor = 1 81 | local td_error = 1 82 | local learning_rate = step_size * td_error 83 | local momentum = lambda * discount_factor 84 | 85 | q:backward(s, a, learning_rate, momentum) 86 | local new_value1 = q:get_value(s, a) 87 | local d_value_1 = new_value1 - old_value 88 | 89 | q:backward(s, a, learning_rate, momentum) 90 | local new_value2 = q:get_value(s, a) 91 | local d_value_2 = new_value2 - new_value1 92 | 93 | tester:assert(math.abs((1+momentum)*d_value_1 - d_value_2) < rl.FLOAT_EPS) 94 | end 95 | 96 | function TestQNN.test_momentum_exists() 97 | local q = rl.QNN(mdp, fe) 98 | -- If q is non-linear, then all bets are off on whether or not the momentum 99 | -- will change things. 
100 | if not q:is_linear() then 101 | return 102 | end 103 | local module = q.module:clone() 104 | 105 | local s = 2 106 | local a = 4 107 | local old_value = q:get_value(s, a) 108 | 109 | local step_size = 0.5 110 | local lambda = 1 111 | local discount_factor = 1 112 | local td_error = 1 113 | local learning_rate = step_size * td_error 114 | local momentum = lambda * discount_factor 115 | 116 | q:backward(s, a, learning_rate, momentum) 117 | local new_value1 = q:get_value(s, a) 118 | local d_value_1 = new_value1 - old_value 119 | 120 | q:backward(s, a, learning_rate, momentum) 121 | local new_value2 = q:get_value(s, a) 122 | local d_value_2 = new_value2 - new_value1 123 | 124 | tester:assert(math.abs(d_value_1 - d_value_2) > rl.FLOAT_EPS) 125 | end 126 | 127 | function TestQNN.test_backward_reset_momentum() 128 | local q = rl.QNN(mdp, fe) 129 | if not q:is_linear() then 130 | return 131 | end 132 | local module = q.module:clone() 133 | 134 | local s = 2 135 | local a = 4 136 | local old_value = q:get_value(s, a) 137 | 138 | local step_size = 0.5 139 | local td_error = 1 140 | local learning_rate = step_size * td_error 141 | local momentum = 1 142 | 143 | q:backward(s, a, learning_rate, momentum) 144 | local new_value1 = q:get_value(s, a) 145 | local d_value_1 = new_value1 - old_value 146 | 147 | q:reset_momentum() 148 | 149 | q:backward(s, a, learning_rate, momentum) 150 | local new_value2 = q:get_value(s, a) 151 | local d_value_2 = new_value2 - new_value1 152 | 153 | tester:assert(math.abs(d_value_1 - d_value_2) < rl.FLOAT_EPS) 154 | end 155 | 156 | tester:add(TestQNN) 157 | 158 | tester:run() 159 | -------------------------------------------------------------------------------- /util/util.lua: -------------------------------------------------------------------------------- 1 | -- Thanks to 2 | -- https://scriptinghelpers.org/questions/11242/how-to-make-a-weighted-selection 3 | function rl.util.weighted_random_choice(items) 4 | -- Sum all weights 5 | local total_weight = 0 6 | for item, weight in pairs(items) do 7 | total_weight = total_weight + weight 8 | end 9 | 10 | -- Pick random value 11 | rand = math.random() * total_weight 12 | choice = nil 13 | 14 | -- Search for the interval [0, w1] [w1, w1 + w2] [w1 + w2, w1 + w2 + w3] ... 15 | -- that `rand` belongs to 16 | -- and select the corresponding choice 17 | for item, weight in pairs(items) do 18 | if rand < weight then 19 | choice = item 20 | break 21 | else 22 | rand = rand - weight 23 | end 24 | end 25 | 26 | return choice 27 | end 28 | 29 | -- Not called fold left/right because order of pairs is not guaranteed. 30 | function rl.util.fold(fn) 31 | return function (acc) 32 | return function (list) 33 | for k, v in pairs(list) do 34 | acc = fn(acc, v) 35 | end 36 | return acc 37 | end 38 | end 39 | end 40 | 41 | -- Do this to avoid accumulator being reused 42 | local sum_from = rl.util.fold(function (a, b) return a + b end) 43 | function rl.util.sum(lst) 44 | return sum_from(0)(lst) 45 | end 46 | 47 | function rl.util.fold_with_key(fn) 48 | return function (acc) 49 | return function (list) 50 | for k, v in pairs(list) do 51 | acc = fn(acc, k, v) 52 | end 53 | return acc 54 | end 55 | end 56 | end 57 | 58 | -- Return the (element with max value, key of that element) of a table 59 | -- value_of is a function that given an element in the list, returns its value. 
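-- For example (illustrative values):
--   rl.util.max({a = 1, b = 5, c = 3}, function (v) return v end)
--   --> returns 5, 'b'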
60 | function rl.util.max(tab, value_of) 61 | max_elem = nil 62 | maxK = nil 63 | maxV = 0 64 | for k, elem in pairs(tab) do 65 | curr_v = value_of(elem) 66 | if max_elem == nil or curr_v > maxV then 67 | max_elem = elem 68 | maxK = k 69 | maxV = curr_v 70 | end 71 | end 72 | return max_elem, maxK 73 | end 74 | 75 | -- Thanks to https://gist.github.com/tylerneylon/81333721109155b2d244 76 | function rl.util.copy_simply(obj) 77 | if type(obj) ~= 'table' then return obj end 78 | local res = {} 79 | for k, v in pairs(obj) do res[rl.util.copy_simply(k)] = rl.util.copy_simply(v) end 80 | return res 81 | end 82 | 83 | -- cache the results of a function call 84 | function rl.util.memoize(f) 85 | local cache = nil 86 | return ( 87 | function () 88 | if cache == nil then 89 | cache = f() 90 | end 91 | return cache 92 | end 93 | ) 94 | end 95 | 96 | -- Compare two tables, optionally ignore meta table (default FALSE) 97 | -- source: https://web.archive.org/web/20131225070434/http://snippets.luacode.org/snippets/Deep_Comparison_of_Two_Values_3 98 | local function deepcompare(t1, t2, ignore_mt) 99 | local ty1 = type(t1) 100 | local ty2 = type(t2) 101 | if ty1 ~= ty2 then 102 | return false 103 | end 104 | 105 | -- non-table types can be directly compared 106 | if ty1 ~= 'table' and ty2 ~= 'table' then 107 | return t1 == t2 108 | end 109 | 110 | -- as well as tables which have the metamethod __eq 111 | local mt = getmetatable(t1) 112 | if not ignore_mt and mt and mt.__eq then 113 | return t1 == t2 114 | end 115 | 116 | for k1,v1 in pairs(t1) do 117 | local v2 = t2[k1] 118 | if v2 == nil or not rl.util.deepcompare(v1,v2) then 119 | return false 120 | end 121 | end 122 | for k2,v2 in pairs(t2) do 123 | local v1 = t1[k2] 124 | if v1 == nil or not rl.util.deepcompare(v1,v2) then 125 | return false 126 | end 127 | end 128 | return true 129 | end 130 | 131 | function rl.util.deepcompare(t1, t2) 132 | return deepcompare(t1, t2, true) 133 | end 134 | 135 | function rl.util.deepcompare_with_meta(t1, t2) 136 | return deepcompare(t1, t2, false) 137 | end 138 | 139 | -- Get # of times elem is in list 140 | function rl.util.get_count(elem, list) 141 | local count = 0 142 | for _, e in pairs(list) do 143 | if e == elem then 144 | count = count + 1 145 | end 146 | end 147 | return count 148 | end 149 | 150 | -- check if Bernounilli trial results is reasonable 151 | function rl.util.is_prob_good(n, p, N) 152 | if p == 0 then 153 | return n == 0 154 | end 155 | if p < 0 or p > 1 then 156 | error('Invalid probability: ' .. 
p) 157 | end 158 | local std = math.sqrt(N * p * (1-p)) 159 | local mean = N * p 160 | return (mean - 3*std < n and n < mean + 3*std) 161 | end 162 | 163 | -- Check if # times elem is in list is reasonable, assuming it had a fixed 164 | -- probability of being in that list 165 | function rl.util.elem_has_good_freq(elem, list, expected_p) 166 | local n = rl.util.get_count(elem, list) 167 | return rl.util.is_prob_good(n, expected_p, #list) 168 | end 169 | -------------------------------------------------------------------------------- /test/unittest_MdpSampler.lua: -------------------------------------------------------------------------------- 1 | require 'rl' 2 | 3 | local tester = torch.Tester() 4 | 5 | local TestMdpSampler = {} 6 | 7 | local function get_sampler(discount_factor) 8 | local config = rl.MdpConfig(rl.TestMdp(), discount_factor) 9 | return rl.MdpSampler(config) 10 | end 11 | 12 | local function get_policy_episode(policy, discount_factor) 13 | local sampler = get_sampler(discount_factor) 14 | return sampler:get_episode(policy) 15 | end 16 | 17 | local function is_discount_good(episode, discount_factor) 18 | local data = episode[1] 19 | local last_Gt = data.discounted_return 20 | local last_r = data.reward 21 | for t = 2, #episode do 22 | data = episode[t] 23 | 24 | local s = data.state 25 | local a = data.action 26 | local Gt = data.discounted_return 27 | local r = data.reward 28 | if last_Gt ~= last_r + discount_factor * Gt then 29 | return false 30 | end 31 | 32 | last_Gt = Gt 33 | last_r = r 34 | end 35 | return true 36 | end 37 | 38 | function TestMdpSampler.test_get_episode_discounted_reward() 39 | local discount_factor = 1 40 | local episode = get_policy_episode(rl.TestPolicy(1), discount_factor) 41 | tester:assert(is_discount_good(episode, discount_factor)) 42 | 43 | local discount_factor = 0.5 44 | local episode = get_policy_episode(rl.TestPolicy(1), discount_factor) 45 | tester:assert(is_discount_good(episode, discount_factor)) 46 | 47 | local discount_factor = 0 48 | local episode = get_policy_episode(rl.TestPolicy(1), discount_factor) 49 | tester:assert(is_discount_good(episode, discount_factor)) 50 | end 51 | 52 | function TestMdpSampler.test_discounted_reward_error() 53 | local discount_factor = 2 54 | local get_config = function () 55 | return rl.MdpConfig(TestMdp, discount_factor) 56 | end 57 | tester:assertError(get_config) 58 | 59 | discount_factor = -1 60 | get_config = function () 61 | return rl.MdpConfig(TestMdp, discount_factor) 62 | end 63 | tester:assertError(get_config) 64 | end 65 | 66 | function TestMdpSampler.test_sample_return_always_one() 67 | local policy = rl.TestPolicy(1) 68 | 69 | local discount_factor = 1 70 | local sampler = get_sampler(discount_factor) 71 | tester:asserteq(sampler:sample_total_reward(policy), -2) 72 | 73 | local discount_factor = 0 74 | local sampler = get_sampler(discount_factor) 75 | tester:asserteq(sampler:sample_total_reward(policy), -2) 76 | 77 | local discount_factor = 0.5 78 | local sampler = get_sampler(discount_factor) 79 | tester:asserteq(sampler:sample_total_reward(policy), -2) 80 | end 81 | 82 | function TestMdpSampler.test_sample_return_always_two() 83 | local policy = rl.TestPolicy(2) 84 | 85 | local discount_factor = 1 86 | local sampler = get_sampler(discount_factor) 87 | tester:asserteq(sampler:sample_total_reward(policy), 0) 88 | 89 | local discount_factor = 0 90 | local sampler = get_sampler(discount_factor) 91 | tester:asserteq(sampler:sample_total_reward(policy), 0) 92 | 93 | local discount_factor = 0.5 
94 | local sampler = get_sampler(discount_factor) 95 | tester:asserteq(sampler:sample_total_reward(policy), 0) 96 | end 97 | 98 | function TestMdpSampler.test_sample_return_always_three() 99 | local policy = rl.TestPolicy(3) 100 | 101 | local discount_factor = 1 102 | local sampler = get_sampler(discount_factor) 103 | tester:asserteq(sampler:sample_total_reward(policy), 2) 104 | 105 | local discount_factor = 0 106 | local sampler = get_sampler(discount_factor) 107 | tester:asserteq(sampler:sample_total_reward(policy), 2) 108 | 109 | local discount_factor = 0.5 110 | local sampler = get_sampler(discount_factor) 111 | tester:asserteq(sampler:sample_total_reward(policy), 2) 112 | end 113 | 114 | local function is_action_good(episode, expected) 115 | local policy = rl.TestPolicy(expected) 116 | local discount_factor = 1 117 | 118 | local episode = get_policy_episode(policy, discount_factor) 119 | for t, data in pairs(episode) do 120 | if data.action ~= expected then 121 | return false 122 | end 123 | end 124 | return true 125 | end 126 | 127 | function TestMdpSampler.test_action_is_good() 128 | tester:assert(is_action_good(episode, 1)) 129 | end 130 | 131 | function TestMdpSampler.test_action_is_good2() 132 | tester:assert(is_action_good(episode, 2)) 133 | end 134 | 135 | function TestMdpSampler.test_action_is_good3() 136 | tester:assert(is_action_good(episode, 3)) 137 | end 138 | 139 | 140 | function TestMdpSampler.test_episode() 141 | local policy = rl.TestPolicy(1) 142 | local discount_factor = 1 143 | local episode = get_policy_episode(policy, discount_factor) 144 | 145 | tester:assert(#episode == 2) 146 | tester:assert(episode[1].state == 1) 147 | tester:assert(episode[1].action == 1) 148 | tester:assert(episode[1].discounted_return == -2) 149 | tester:assert(episode[1].reward == -1) 150 | tester:assert(episode[2].state == 2) 151 | tester:assert(episode[2].action == 1) 152 | tester:assert(episode[2].discounted_return == -1) 153 | tester:assert(episode[2].reward == -1) 154 | end 155 | 156 | tester:add(TestMdpSampler) 157 | 158 | tester:run() 159 | -------------------------------------------------------------------------------- /test/unittest_util.lua: -------------------------------------------------------------------------------- 1 | require 'rl' 2 | 3 | local tester = torch.Tester() 4 | 5 | local TestUtil = {} 6 | function TestUtil.test_get_count() 7 | local list = {1, 1, 1, 2} 8 | tester:asserteq(rl.util.get_count(1, list), 3) 9 | end 10 | 11 | function TestUtil.test_is_prob_good_good() 12 | local n = 5 13 | local p = 0.5 14 | local N = 10 15 | tester:assert(rl.util.is_prob_good(n, p, N)) 16 | tester:assert(rl.util.is_prob_good(0, 0, N)) 17 | end 18 | 19 | function TestUtil.test_is_prob_good_bad() 20 | local n = 1 21 | local p = 0.5 22 | local N = 100 23 | tester:assert(not rl.util.is_prob_good(n, p, N)) 24 | tester:assert(not rl.util.is_prob_good(1, 0, N)) 25 | end 26 | 27 | function TestUtil.test_elem_has_good_freq() 28 | local list = {1, 1, 1, 2} 29 | tester:assert(rl.util.elem_has_good_freq(1, list, 0.75)) 30 | tester:assert(rl.util.elem_has_good_freq(2, list, 0.25)) 31 | tester:assert(rl.util.elem_has_good_freq(3, list, 0)) 32 | tester:assert(not rl.util.elem_has_good_freq(2, list, 0)) 33 | end 34 | 35 | function TestUtil.test_fold() 36 | local t = { 37 | 1, 38 | 1, 39 | 1, 40 | 3, 41 | 1, 42 | 1, 43 | 1, 44 | 3 45 | } 46 | 47 | tester:asserteq(rl.util.sum(t), 12) 48 | tester:asserteq(rl.util.fold(function(a, b) return a - b end)(0)(t), -12) 49 | 
tester:asserteq(rl.util.fold(function(a, b) return a - b end)(-8)(t), -20) 50 | end 51 | 52 | function TestUtil.test_fold_with_key() 53 | local t = { 54 | 1, 55 | 1, 56 | 1, 57 | 3, 58 | } 59 | 60 | tester:asserteq( 61 | rl.util.fold_with_key(function(a, k, b) return a + k + b end)(0)(t), 62 | 16) 63 | tester:asserteq( 64 | rl.util.fold_with_key(function(a, k, b) return a + k + b end)(-6)(t), 65 | 10) 66 | end 67 | 68 | function TestUtil.test_weighted_random_choice_only_one() 69 | local t = { 70 | a = 0, 71 | b = 0, 72 | c = 0, 73 | d = 1 74 | } 75 | local function choice_is_always_good() 76 | for i = 1, 100 do 77 | if rl.util.weighted_random_choice(t) ~= 'd' then 78 | return false 79 | end 80 | end 81 | return true 82 | end 83 | tester:assert(choice_is_always_good()) 84 | end 85 | 86 | function TestUtil.test_weighted_random_choice() 87 | t = { 88 | 1, 89 | 1, 90 | 1, 91 | 3, 92 | 1, 93 | 1, 94 | 1, 95 | 3 96 | } 97 | local N = 100000 98 | local denom = rl.util.sum(t) 99 | local nums = torch.zeros(8) 100 | for i = 1, N do 101 | result = rl.util.weighted_random_choice(t) 102 | nums[result] = nums[result] + 1 103 | end 104 | local function prob_is_good() 105 | for i = 1, nums:numel() do 106 | if not rl.util.is_prob_good(nums[i], t[i] / denom, N) then 107 | return false 108 | end 109 | end 110 | return true 111 | end 112 | tester:assert(prob_is_good()) 113 | end 114 | 115 | function TestUtil.test_max() 116 | local t = { 117 | a = 1, 118 | b = 2, 119 | c = 3, 120 | d = -6 121 | } 122 | local maxElem, maxK = rl.util.max(t, function (v) return v end) 123 | tester:asserteq(maxElem , 3) 124 | tester:asserteq(maxK , 'c') 125 | 126 | maxElem, maxK = rl.util.max(t, function (v) return -v end) 127 | tester:asserteq(maxElem , -6) 128 | tester:asserteq(maxK , 'd') 129 | end 130 | 131 | function TestUtil.test_deepcompare() 132 | local t1 = { 133 | a = 1, 134 | b = {x=4, y=3, z=1}, 135 | c = {1, 2, 3}, 136 | d = -6 137 | } 138 | local t2 = { 139 | b = {x=4, y=3, z=1}, 140 | a = 1, 141 | d = -6, 142 | c = {1, 2, 3} 143 | } 144 | tester:assert(rl.util.deepcompare(t1, t2)) 145 | end 146 | 147 | function TestUtil.test_deepcompare_int_keys() 148 | local t1 = { 149 | [4] = 1, 150 | [-3] = 2, 151 | [0] = {1, 2, 3}, 152 | [1] = -6 153 | } 154 | local t2 = { 155 | [-3] = 2, 156 | [4] = 1, 157 | [1] = -6, 158 | [0] = {1, 2, 3} 159 | } 160 | tester:assert(rl.util.deepcompare(t1, t2)) 161 | end 162 | 163 | function TestUtil.test_deepcompare_with_meta() 164 | local t1 = { 165 | a = 1, 166 | b = {x=4, y=3, z=1}, 167 | c = {1, 2, 3}, 168 | d = -6 169 | } 170 | local t2 = { 171 | b = {x=4, y=3, z=1}, 172 | a = 1, 173 | d = -6, 174 | c = {1, 2, 3} 175 | } 176 | local mt = { 177 | __eq = function (lhs, rhs) 178 | return true 179 | end 180 | } 181 | setmetatable(t1, mt) 182 | setmetatable(t2, mt) 183 | tester:assert(rl.util.deepcompare_with_meta(t1, t2)) 184 | end 185 | 186 | function TestUtil.test_deepcompare_with_meta_false() 187 | local t1 = { 188 | a = 1, 189 | b = {x=4, y=3, z=1}, 190 | c = {1, 2, 3}, 191 | d = -6 192 | } 193 | local t2 = { 194 | b = {x=4, y=3, z=1}, 195 | a = 1, 196 | d = -6, 197 | c = {1, 2, 3} 198 | } 199 | local mt = { 200 | __eq = function (lhs, rhs) 201 | return false 202 | end 203 | } 204 | setmetatable(t1, mt) 205 | setmetatable(t2, mt) 206 | 207 | tester:assert(not rl.util.deepcompare_with_meta(t1, t2)) 208 | end 209 | 210 | tester:add(TestUtil) 211 | 212 | tester:run() 213 | -------------------------------------------------------------------------------- /SarsaAnalyzer.lua: 
-------------------------------------------------------------------------------- 1 | -- Analyze different control algorithms. 2 | local gnuplot = require 'gnuplot' 3 | 4 | local SarsaAnalyzer = torch.class('rl.SarsaAnalyzer') 5 | 6 | function SarsaAnalyzer:__init(opt, mdp_config, qvanalyzer, sarsa_factory) 7 | self.loadqfrom = opt.loadqfrom 8 | self.save = opt.save 9 | self.show = opt.show 10 | self.rms_num_points = opt.rms_num_points 11 | self.n_iters = opt.n_iters or N_ITERS 12 | 13 | self.mdp_config = mdp_config 14 | self.qvanalyzer = qvanalyzer 15 | self.sarsa_factory = sarsa_factory 16 | 17 | self.q_mc = nil 18 | end 19 | 20 | function SarsaAnalyzer:get_true_q(n_iters) 21 | if self.loadqfrom ~= nil and self.loadqfrom ~= '' then 22 | print('Loading q_mc from ' .. self.loadqfrom) 23 | return rl.util.load_q(self.loadqfrom) 24 | end 25 | 26 | self.n_iters = n_iters or self.n_iters 27 | local mc = rl.MonteCarloControl(self.mdp_config) 28 | print('Computing Q from Monte Carlo. # iters = ' .. self.n_iters) 29 | mc:improve_policy_for_n_iters(self.n_iters) 30 | 31 | return mc:get_q() 32 | end 33 | 34 | local function plot_rms_lambda_data(self, data) 35 | gnuplot.plot(data) 36 | gnuplot.grid(true) 37 | gnuplot.xlabel('lambda') 38 | gnuplot.ylabel('RMS between Q-MC and Q-SARSA') 39 | gnuplot.title('Q RMS episodes vs lambda, after ' 40 | .. self.n_iters .. ' iterations') 41 | end 42 | 43 | local function get_sarsa(self, lambda) 44 | self.sarsa_factory:set_lambda(lambda) 45 | return self.sarsa_factory:get_control() 46 | end 47 | 48 | local function plot_results(self, plot_function, image_fname) 49 | if self.show then 50 | gnuplot.figure() 51 | plot_function() 52 | end 53 | if self.save then 54 | gnuplot.epsfigure(image_fname) 55 | print('Saving plot to: ' .. image_fname) 56 | plot_function() 57 | gnuplot.plotflush() 58 | end 59 | end 60 | 61 | local function get_lambda_data(self) 62 | local rms_lambda_data = torch.Tensor(11, 2) 63 | local i = 1 64 | print('Generating data/plot for varying lambdas.') 65 | for lambda = 0, 1, 0.1 do 66 | print('Processing SARSA for lambda = ' .. lambda) 67 | local sarsa = get_sarsa(self, lambda) 68 | sarsa:improve_policy(self.n_iters) 69 | local q = sarsa:get_q() 70 | rms_lambda_data[i][1] = lambda 71 | rms_lambda_data[i][2] = self.qvanalyzer:q_rms(q, self.q_mc) 72 | i = i + 1 73 | end 74 | return rms_lambda_data 75 | end 76 | 77 | -- For a given control algorithm, see how the RMS changes with lambda. 78 | -- Sweeps and plots the performance for lambda = 0, 0.1, 0.2, ..., 1.0 79 | function SarsaAnalyzer:eval_lambdas( 80 | image_fname, 81 | n_iters) 82 | self.q_mc = self.q_mc or self:get_true_q() 83 | self.n_iters = n_iters or self.n_iters 84 | local rms_lambda_data = torch.Tensor(11, 2) 85 | local i = 1 86 | print('Generating data/plot for varying lambdas.') 87 | for lambda = 0, 1, 0.1 do 88 | print('Processing SARSA for lambda = ' .. 
lambda) 89 | local sarsa = get_sarsa(self, lambda) 90 | sarsa:improve_policy_for_n_iters(self.n_iters) 91 | local q = sarsa:get_q() 92 | rms_lambda_data[i][1] = lambda 93 | rms_lambda_data[i][2] = self.qvanalyzer:q_rms(q, self.q_mc) 94 | i = i + 1 95 | end 96 | 97 | plot_results(self, 98 | function () 99 | plot_rms_lambda_data(self, rms_lambda_data) 100 | end, 101 | image_fname) 102 | end 103 | 104 | local function plot_rms_episode_data(self, data_table) 105 | for lambda, data in pairs(data_table) do 106 | gnuplot.plot({tostring(lambda), data}) 107 | end 108 | 109 | gnuplot.plot({'0', data[0]}, 110 | {'1', data[1]}) 111 | gnuplot.grid(true) 112 | gnuplot.xlabel('Episode') 113 | gnuplot.ylabel('RMS between Q-MC and Q-SARSA') 114 | gnuplot.title('Q RMS vs Episode, lambda = 0 and 1, after ' 115 | .. self.n_iters .. ' iterations') 116 | end 117 | 118 | -- hack to get around that torch doesn't seem to allow private class methods 119 | local function get_rms_episode_data(self, lambda) 120 | local rms_episode_data = torch.Tensor(self.rms_num_points, 2) 121 | local sarsa = get_sarsa(self, lambda) 122 | sarsa:improve_policy() 123 | local q = sarsa:get_q() 124 | rms_episode_data[1][1] = 1 125 | rms_episode_data[1][2] = self.qvanalyzer:q_rms(q, self.q_mc) 126 | local n_iters_per_data_point = self.n_iters / self.rms_num_points 127 | local i = n_iters_per_data_point 128 | for j = 2, self.rms_num_points do 129 | sarsa:improve_policy_for_n_iters(n_iters_per_data_point) 130 | q = sarsa:get_q() 131 | rms_episode_data[j][1] = i 132 | rms_episode_data[j][2] = self.qvanalyzer:q_rms(q, self.q_mc) 133 | i = i + n_iters_per_data_point 134 | end 135 | return rms_episode_data 136 | end 137 | 138 | -- For a given control algorithm, see how the RMS changes with # of episodes for 139 | -- lambda = 0 and lambda = 1. 
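-- Usage sketch (not part of the original file): a driver that exercises both
-- eval_lambdas and eval_l0_l1_rms might look like the following. The opt
-- fields mirror those read in __init; the TestMdpQVAnalyzer constructor
-- arguments and the output file names are assumptions for illustration.
--
--   local mdp_config = rl.MdpConfig(rl.TestMdp(), 0.9)
--   local analyzer = rl.SarsaAnalyzer(
--       {save = true, show = false, rms_num_points = 10, n_iters = 1000},
--       mdp_config,
--       rl.TestMdpQVAnalyzer(),              -- constructor args assumed
--       rl.TableSarsaFactory(mdp_config, 0))
--   analyzer:eval_lambdas('images/table_sarsa_rms_vs_lambda.eps')
--   analyzer:eval_l0_l1_rms('images/table_sarsa_rms_vs_iteration.eps')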
140 | function SarsaAnalyzer:eval_l0_l1_rms( 141 | image_fname, 142 | n_iters) 143 | self.q_mc = self.q_mc or self:get_true_q() 144 | n_iters = n_iters or self.n_iters 145 | 146 | print('Generating data for RMS vs episode') 147 | local l0_data = get_rms_episode_data(self, 0) 148 | local l1_data = get_rms_episode_data(self, 1) 149 | data = {} 150 | data[0] = l0_data 151 | data[1] = l1_data 152 | 153 | print('Generating plots for RMS vs episode') 154 | plot_results(self, 155 | function () 156 | plot_rms_episode_data(self, data) 157 | end, 158 | image_fname) 159 | end 160 | -------------------------------------------------------------------------------- /test/unittest_TableSarsa.lua: -------------------------------------------------------------------------------- 1 | 2 | local tester = torch.Tester() 3 | 4 | local TestTableSarsa = {} 5 | local discount_factor = 0.9 6 | local mdp = rl.TestMdp() 7 | local mdp_config = rl.MdpConfig(mdp, discount_factor) 8 | 9 | local function non_q_params_match( 10 | sarsa, 11 | Ns_expected, 12 | Nsa_expected, 13 | eligibility_expected) 14 | local Ns = sarsa.Ns 15 | local Nsa = sarsa.Nsa 16 | local elig = sarsa.eligibility 17 | 18 | return rl.util.do_vtable_vfunc_match(mdp, Ns_expected, Ns) 19 | and rl.util.do_qtable_qfunc_match(mdp, Nsa_expected, Nsa) 20 | and rl.util.do_qtable_qfunc_match(mdp, eligibility_expected, elig) 21 | end 22 | 23 | function TestTableSarsa.test_update_eligibility_one_step() 24 | local lambda = 1 25 | local sarsa = rl.TableSarsa(mdp_config, lambda) 26 | 27 | local s = 2 28 | local a = 1 29 | sarsa:update_eligibility(s, a) 30 | 31 | local Ns_expected = {0, 1, 0} 32 | local Nsa_expected = { -- row = state, colm = action 33 | [1] = {0, 0, 0}, 34 | [2] = {1, 0, 0}, 35 | [3] = {0, 0, 0} 36 | } 37 | local eligibility_expected = Nsa_expected 38 | local correct = non_q_params_match( 39 | sarsa, 40 | Ns_expected, 41 | Nsa_expected, 42 | eligibility_expected) 43 | tester:assert(correct) 44 | end 45 | 46 | function TestTableSarsa.test_update_eligibility_lambda1() 47 | local lambda = 1 48 | local sarsa = rl.TableSarsa(mdp_config, lambda) 49 | 50 | local s = 2 51 | local a = 1 52 | sarsa:update_eligibility(s, a) 53 | 54 | local Ns_expected = { 55 | [1] = 0, 56 | [2] = 1, 57 | [3] = 0, 58 | } 59 | local Nsa_expected = { -- row = state, colm = action 60 | [1] = {0, 0, 0}, 61 | [2] = {1, 0, 0}, 62 | [3] = {0, 0, 0} 63 | } 64 | local eligibility_expected = Nsa_expected 65 | local correct = non_q_params_match( 66 | sarsa, 67 | Ns_expected, 68 | Nsa_expected, 69 | eligibility_expected) 70 | tester:assert(correct) 71 | end 72 | 73 | function TestTableSarsa:test_update_eligibility_lambda_frac() 74 | local lambda = 0.5 75 | local sarsa = rl.TableSarsa(mdp_config, lambda) 76 | 77 | local s = 2 78 | local a = 1 79 | sarsa:update_eligibility(s, a) 80 | sarsa:update_eligibility(s, a) 81 | sarsa:update_eligibility(s, a) 82 | 83 | local decay_factor = lambda * discount_factor 84 | 85 | local Ns_expected = {0, 3, 0} 86 | local Nsa_expected = { -- row = state, colm = action 87 | [1] = {0, 0, 0}, 88 | [2] = {3, 0, 0}, 89 | [3] = {0, 0, 0} 90 | } 91 | local eligibility_expected = { 92 | [1] = {0, 0, 0}, 93 | [2] = {1 + decay_factor * (1 + decay_factor*1), 0, 0}, 94 | [3] = {0, 0, 0} 95 | } 96 | local correct = non_q_params_match( 97 | sarsa, 98 | Ns_expected, 99 | Nsa_expected, 100 | eligibility_expected) 101 | tester:assert(correct) 102 | end 103 | 104 | function TestTableSarsa:test_update_eligibility_lambda0() 105 | local lambda = 0 106 | local sarsa = 
rl.TableSarsa(mdp_config, lambda) 107 | 108 | local s = 2 109 | local a = 1 110 | sarsa:update_eligibility(s, a) 111 | sarsa:update_eligibility(s, a) 112 | sarsa:update_eligibility(s, a) 113 | 114 | local decay_factor = lambda * discount_factor 115 | 116 | local Ns_expected = {0, 3, 0} 117 | local Nsa_expected = { -- row = state, colm = action 118 | [1] = {0, 0, 0}, 119 | [2] = {3, 0, 0}, 120 | [3] = {0, 0, 0} 121 | } 122 | local eligibility_expected = { 123 | [1] = {0, 0, 0}, 124 | [2] = {1 + decay_factor * (1 + decay_factor*1), 0, 0}, 125 | [3] = {0, 0, 0} 126 | } 127 | local correct = non_q_params_match( 128 | sarsa, 129 | Ns_expected, 130 | Nsa_expected, 131 | eligibility_expected) 132 | tester:assert(correct) 133 | end 134 | 135 | function TestTableSarsa:test_update_eligibility_mixed_updates() 136 | local lambda = 0.4 137 | local sarsa = rl.TableSarsa(mdp_config, lambda) 138 | 139 | local s = 2 140 | local a = 1 141 | sarsa:update_eligibility(s, a) 142 | s = 1 143 | a = 2 144 | sarsa:update_eligibility(s, a) 145 | s = 2 146 | a = 3 147 | sarsa:update_eligibility(s, a) 148 | 149 | local decay_factor = lambda * discount_factor 150 | 151 | local Ns_expected = {1, 2, 0} 152 | local Nsa_expected = { -- row = state, colm = action 153 | [1] = {0, 1, 0}, 154 | [2] = {1, 0, 1}, 155 | [3] = {0, 0, 0} 156 | } 157 | local eligibility_expected = { 158 | [1] = {0, decay_factor, 0}, 159 | [2] = {decay_factor^2, 0, 1}, 160 | [3] = {0, 0, 0} 161 | } 162 | local correct = non_q_params_match( 163 | sarsa, 164 | Ns_expected, 165 | Nsa_expected, 166 | eligibility_expected) 167 | tester:assert(correct) 168 | end 169 | 170 | function TestTableSarsa:test_td_update_one_update() 171 | local lambda = 0.4 172 | local sarsa = rl.TableSarsa(mdp_config, lambda) 173 | 174 | local s = 2 175 | local a = 1 176 | sarsa:update_eligibility(s, a) 177 | local td_error = 5 178 | sarsa:td_update(td_error) 179 | local q_expected = { -- row = state, colm = action 180 | [1] = {0, 0, 0}, 181 | [2] = {5, 0, 0}, 182 | [3] = {0, 0, 0} 183 | } 184 | tester:assert(rl.util.do_qtable_qfunc_match(mdp, q_expected, sarsa.q)) 185 | end 186 | 187 | function TestTableSarsa:test_td_update_many_updates() 188 | local lambda = 0.4 189 | local sarsa = rl.TableSarsa(mdp_config, lambda) 190 | 191 | local s = 2 192 | local a = 1 193 | sarsa:update_eligibility(s, a) 194 | local td_error = 5 195 | sarsa:td_update(td_error) 196 | sarsa:update_eligibility(s, a) 197 | sarsa:td_update(td_error) 198 | 199 | s = 3 200 | a = 3 201 | sarsa:update_eligibility(s, a) 202 | td_error = -10 203 | sarsa:td_update(td_error) 204 | 205 | local decay_factor = lambda * discount_factor 206 | local q_expected = { -- row = state, colm = action 207 | [1] = {0, 0, 0}, 208 | [2] = {5+5*(1+decay_factor)/2-10*decay_factor*(1+decay_factor)/2, 0, 0}, 209 | [3] = {0, 0, -10} 210 | } 211 | tester:assert(rl.util.do_qtable_qfunc_match(mdp, q_expected, sarsa.q)) 212 | end 213 | 214 | tester:add(TestTableSarsa) 215 | 216 | tester:run() 217 | 218 | -------------------------------------------------------------------------------- /images/q3a.eps: -------------------------------------------------------------------------------- 1 | %!PS-Adobe-2.0 EPSF-2.0 2 | %%Title: q3a.eps 3 | %%Creator: gnuplot 4.6 patchlevel 4 4 | %%CreationDate: Wed Jan 20 18:03:19 2016 5 | %%DocumentFonts: (atend) 6 | %%BoundingBox: 50 50 410 302 7 | %%EndComments 8 | %%BeginProlog 9 | /gnudict 256 dict def 10 | gnudict begin 11 | % 12 | % The following true/false flags may be edited by hand if desired. 
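The EPS plots under images/ (q3a.eps, q4a.eps) appear to be written with the same gnuplot calls that SarsaAnalyzer uses in plot_results (epsfigure, plot, grid, xlabel, ylabel, title, plotflush). A minimal, self-contained sketch, not taken from the repository and using made-up data points, that writes a similar EPS file:

    require 'torch'
    local gnuplot = require 'gnuplot'
    -- hypothetical (lambda, RMS) pairs; the real values come from QVAnalyzer:q_rms
    local rms_lambda_data = torch.Tensor{{0, 38}, {0.5, 44}, {1, 58}}
    gnuplot.epsfigure('rms_vs_lambda_sketch.eps')  -- hypothetical output path
    gnuplot.plot(rms_lambda_data)                  -- Nx2 tensor: column 1 = x, column 2 = y
    gnuplot.grid(true)
    gnuplot.xlabel('lambda')
    gnuplot.ylabel('RMS between Q-MC and Q-SARSA')
    gnuplot.title('Q RMS after 1000 episodes vs lambda')
    gnuplot.plotflush()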
[Remainder of q3a.eps omitted: gnuplot 4.6 PostScript prolog and plot commands. Recoverable figure information: title "Q RMS after 1000 episodes vs lambda"; x-axis "lambda" (ticks 0 to 1); y-axis "RMS between Q-MC and Q-SARSA" (ticks 30 to 60).]
-------------------------------------------------------------------------------- /images/q4a.eps: --------------------------------------------------------------------------------
[q4a.eps: a second gnuplot 4.6 EPS figure. Its PostScript prolog and axis setup (y-axis ticks 44 to 58, x-axis ticks 0 to 1) are omitted here; the remaining plot commands follow.]
MCshow 701 | 1.000 UL 702 | LTb 703 | 0.13 0.13 0.13 C 1.000 UL 704 | LTa 705 | LCa setrgbcolor 706 | 6947 448 M 707 | 0 4171 V 708 | stroke 709 | LTb 710 | 0.13 0.13 0.13 C 6947 448 M 711 | 0 63 V 712 | 0 4108 R 713 | 0 -63 V 714 | stroke 715 | 6947 308 M 716 | [ [(Helvetica) 140.0 0.0 true true 0 ( 1)] 717 | ] -46.7 MCshow 718 | 1.000 UL 719 | LTb 720 | 0.13 0.13 0.13 C 1.000 UL 721 | LTb 722 | 0.13 0.13 0.13 C 602 4619 N 723 | 602 448 L 724 | 6345 0 V 725 | 0 4171 V 726 | -6345 0 V 727 | Z stroke 728 | LCb setrgbcolor 729 | 112 2533 M 730 | currentpoint gsave translate -270 rotate 0 0 moveto 731 | [ [(Helvetica) 140.0 0.0 true true 0 (RMS between Q-MC and Q-SARSA)] 732 | ] -46.7 MCshow 733 | grestore 734 | LTb 735 | LCb setrgbcolor 736 | 3774 98 M 737 | [ [(Helvetica) 140.0 0.0 true true 0 (lambda)] 738 | ] -46.7 MCshow 739 | LTb 740 | 3774 4829 M 741 | [ [(Helvetica) 140.0 0.0 true true 0 (Q RMS after 1000 episodes vs lambda)] 742 | ] -46.7 MCshow 743 | 1.000 UP 744 | 1.000 UL 745 | LTb 746 | 0.13 0.13 0.13 C % Begin plot #1 747 | 1.000 UP 748 | 2.000 UL 749 | LT0 750 | 0.11 0.27 0.60 C 602 1646 M 751 | 635 1139 V 752 | 634 -696 V 753 | 2506 1038 L 754 | 634 1373 V 755 | 635 -82 V 756 | 634 1844 V 757 | 635 -626 V 758 | 634 95 V 759 | 635 -518 V 760 | 634 1402 V 761 | 602 1646 CircleF 762 | 1237 2785 CircleF 763 | 1871 2089 CircleF 764 | 2506 1038 CircleF 765 | 3140 2411 CircleF 766 | 3775 2329 CircleF 767 | 4409 4173 CircleF 768 | 5044 3547 CircleF 769 | 5678 3642 CircleF 770 | 6313 3124 CircleF 771 | 6947 4526 CircleF 772 | % End plot #1 773 | 1.000 UL 774 | LTb 775 | 0.13 0.13 0.13 C 602 4619 N 776 | 602 448 L 777 | 6345 0 V 778 | 0 4171 V 779 | -6345 0 V 780 | Z stroke 781 | 1.000 UP 782 | 1.000 UL 783 | LTb 784 | 0.13 0.13 0.13 C stroke 785 | grestore 786 | end 787 | showpage 788 | %%Trailer 789 | %%DocumentFonts: Helvetica 790 | -------------------------------------------------------------------------------- /images/q5a.eps: -------------------------------------------------------------------------------- 1 | %!PS-Adobe-2.0 EPSF-2.0 2 | %%Title: q5a.eps 3 | %%Creator: gnuplot 4.6 patchlevel 4 4 | %%CreationDate: Thu Jan 21 00:24:42 2016 5 | %%DocumentFonts: (atend) 6 | %%BoundingBox: 50 50 410 302 7 | %%EndComments 8 | %%BeginProlog 9 | /gnudict 256 dict def 10 | gnudict begin 11 | % 12 | % The following true/false flags may be edited by hand if desired. 13 | % The unit line width and grayscale image gamma correction may also be changed. 
14 | % 15 | /Color true def 16 | /Blacktext false def 17 | /Solid false def 18 | /Dashlength 1 def 19 | /Landscape false def 20 | /Level1 false def 21 | /Rounded false def 22 | /ClipToBoundingBox false def 23 | /SuppressPDFMark false def 24 | /TransparentPatterns false def 25 | /gnulinewidth 5.000 def 26 | /userlinewidth gnulinewidth def 27 | /Gamma 1.0 def 28 | /BackgroundColor {-1.000 -1.000 -1.000} def 29 | % 30 | /vshift -46 def 31 | /dl1 { 32 | 10.0 Dashlength mul mul 33 | Rounded { currentlinewidth 0.75 mul sub dup 0 le { pop 0.01 } if } if 34 | } def 35 | /dl2 { 36 | 10.0 Dashlength mul mul 37 | Rounded { currentlinewidth 0.75 mul add } if 38 | } def 39 | /hpt_ 31.5 def 40 | /vpt_ 31.5 def 41 | /hpt hpt_ def 42 | /vpt vpt_ def 43 | /doclip { 44 | ClipToBoundingBox { 45 | newpath 50 50 moveto 410 50 lineto 410 302 lineto 50 302 lineto closepath 46 | clip 47 | } if 48 | } def 49 | % 50 | % Gnuplot Prolog Version 4.6 (September 2012) 51 | % 52 | %/SuppressPDFMark true def 53 | % 54 | /M {moveto} bind def 55 | /L {lineto} bind def 56 | /R {rmoveto} bind def 57 | /V {rlineto} bind def 58 | /N {newpath moveto} bind def 59 | /Z {closepath} bind def 60 | /C {setrgbcolor} bind def 61 | /f {rlineto fill} bind def 62 | /g {setgray} bind def 63 | /Gshow {show} def % May be redefined later in the file to support UTF-8 64 | /vpt2 vpt 2 mul def 65 | /hpt2 hpt 2 mul def 66 | /Lshow {currentpoint stroke M 0 vshift R 67 | Blacktext {gsave 0 setgray show grestore} {show} ifelse} def 68 | /Rshow {currentpoint stroke M dup stringwidth pop neg vshift R 69 | Blacktext {gsave 0 setgray show grestore} {show} ifelse} def 70 | /Cshow {currentpoint stroke M dup stringwidth pop -2 div vshift R 71 | Blacktext {gsave 0 setgray show grestore} {show} ifelse} def 72 | /UP {dup vpt_ mul /vpt exch def hpt_ mul /hpt exch def 73 | /hpt2 hpt 2 mul def /vpt2 vpt 2 mul def} def 74 | /DL {Color {setrgbcolor Solid {pop []} if 0 setdash} 75 | {pop pop pop 0 setgray Solid {pop []} if 0 setdash} ifelse} def 76 | /BL {stroke userlinewidth 2 mul setlinewidth 77 | Rounded {1 setlinejoin 1 setlinecap} if} def 78 | /AL {stroke userlinewidth 2 div setlinewidth 79 | Rounded {1 setlinejoin 1 setlinecap} if} def 80 | /UL {dup gnulinewidth mul /userlinewidth exch def 81 | dup 1 lt {pop 1} if 10 mul /udl exch def} def 82 | /PL {stroke userlinewidth setlinewidth 83 | Rounded {1 setlinejoin 1 setlinecap} if} def 84 | 3.8 setmiterlimit 85 | % Default Line colors 86 | /LCw {1 1 1} def 87 | /LCb {0 0 0} def 88 | /LCa {0 0 0} def 89 | /LC0 {1 0 0} def 90 | /LC1 {0 1 0} def 91 | /LC2 {0 0 1} def 92 | /LC3 {1 0 1} def 93 | /LC4 {0 1 1} def 94 | /LC5 {1 1 0} def 95 | /LC6 {0 0 0} def 96 | /LC7 {1 0.3 0} def 97 | /LC8 {0.5 0.5 0.5} def 98 | % Default Line Types 99 | /LTw {PL [] 1 setgray} def 100 | /LTb {BL [] LCb DL} def 101 | /LTa {AL [1 udl mul 2 udl mul] 0 setdash LCa setrgbcolor} def 102 | /LT0 {PL [] LC0 DL} def 103 | /LT1 {PL [4 dl1 2 dl2] LC1 DL} def 104 | /LT2 {PL [2 dl1 3 dl2] LC2 DL} def 105 | /LT3 {PL [1 dl1 1.5 dl2] LC3 DL} def 106 | /LT4 {PL [6 dl1 2 dl2 1 dl1 2 dl2] LC4 DL} def 107 | /LT5 {PL [3 dl1 3 dl2 1 dl1 3 dl2] LC5 DL} def 108 | /LT6 {PL [2 dl1 2 dl2 2 dl1 6 dl2] LC6 DL} def 109 | /LT7 {PL [1 dl1 2 dl2 6 dl1 2 dl2 1 dl1 2 dl2] LC7 DL} def 110 | /LT8 {PL [2 dl1 2 dl2 2 dl1 2 dl2 2 dl1 2 dl2 2 dl1 4 dl2] LC8 DL} def 111 | /Pnt {stroke [] 0 setdash gsave 1 setlinecap M 0 0 V stroke grestore} def 112 | /Dia {stroke [] 0 setdash 2 copy vpt add M 113 | hpt neg vpt neg V hpt vpt neg V 114 | hpt vpt V hpt neg vpt V closepath stroke 
115 | Pnt} def 116 | /Pls {stroke [] 0 setdash vpt sub M 0 vpt2 V 117 | currentpoint stroke M 118 | hpt neg vpt neg R hpt2 0 V stroke 119 | } def 120 | /Box {stroke [] 0 setdash 2 copy exch hpt sub exch vpt add M 121 | 0 vpt2 neg V hpt2 0 V 0 vpt2 V 122 | hpt2 neg 0 V closepath stroke 123 | Pnt} def 124 | /Crs {stroke [] 0 setdash exch hpt sub exch vpt add M 125 | hpt2 vpt2 neg V currentpoint stroke M 126 | hpt2 neg 0 R hpt2 vpt2 V stroke} def 127 | /TriU {stroke [] 0 setdash 2 copy vpt 1.12 mul add M 128 | hpt neg vpt -1.62 mul V 129 | hpt 2 mul 0 V 130 | hpt neg vpt 1.62 mul V closepath stroke 131 | Pnt} def 132 | /Star {2 copy Pls Crs} def 133 | /BoxF {stroke [] 0 setdash exch hpt sub exch vpt add M 134 | 0 vpt2 neg V hpt2 0 V 0 vpt2 V 135 | hpt2 neg 0 V closepath fill} def 136 | /TriUF {stroke [] 0 setdash vpt 1.12 mul add M 137 | hpt neg vpt -1.62 mul V 138 | hpt 2 mul 0 V 139 | hpt neg vpt 1.62 mul V closepath fill} def 140 | /TriD {stroke [] 0 setdash 2 copy vpt 1.12 mul sub M 141 | hpt neg vpt 1.62 mul V 142 | hpt 2 mul 0 V 143 | hpt neg vpt -1.62 mul V closepath stroke 144 | Pnt} def 145 | /TriDF {stroke [] 0 setdash vpt 1.12 mul sub M 146 | hpt neg vpt 1.62 mul V 147 | hpt 2 mul 0 V 148 | hpt neg vpt -1.62 mul V closepath fill} def 149 | /DiaF {stroke [] 0 setdash vpt add M 150 | hpt neg vpt neg V hpt vpt neg V 151 | hpt vpt V hpt neg vpt V closepath fill} def 152 | /Pent {stroke [] 0 setdash 2 copy gsave 153 | translate 0 hpt M 4 {72 rotate 0 hpt L} repeat 154 | closepath stroke grestore Pnt} def 155 | /PentF {stroke [] 0 setdash gsave 156 | translate 0 hpt M 4 {72 rotate 0 hpt L} repeat 157 | closepath fill grestore} def 158 | /Circle {stroke [] 0 setdash 2 copy 159 | hpt 0 360 arc stroke Pnt} def 160 | /CircleF {stroke [] 0 setdash hpt 0 360 arc fill} def 161 | /C0 {BL [] 0 setdash 2 copy moveto vpt 90 450 arc} bind def 162 | /C1 {BL [] 0 setdash 2 copy moveto 163 | 2 copy vpt 0 90 arc closepath fill 164 | vpt 0 360 arc closepath} bind def 165 | /C2 {BL [] 0 setdash 2 copy moveto 166 | 2 copy vpt 90 180 arc closepath fill 167 | vpt 0 360 arc closepath} bind def 168 | /C3 {BL [] 0 setdash 2 copy moveto 169 | 2 copy vpt 0 180 arc closepath fill 170 | vpt 0 360 arc closepath} bind def 171 | /C4 {BL [] 0 setdash 2 copy moveto 172 | 2 copy vpt 180 270 arc closepath fill 173 | vpt 0 360 arc closepath} bind def 174 | /C5 {BL [] 0 setdash 2 copy moveto 175 | 2 copy vpt 0 90 arc 176 | 2 copy moveto 177 | 2 copy vpt 180 270 arc closepath fill 178 | vpt 0 360 arc} bind def 179 | /C6 {BL [] 0 setdash 2 copy moveto 180 | 2 copy vpt 90 270 arc closepath fill 181 | vpt 0 360 arc closepath} bind def 182 | /C7 {BL [] 0 setdash 2 copy moveto 183 | 2 copy vpt 0 270 arc closepath fill 184 | vpt 0 360 arc closepath} bind def 185 | /C8 {BL [] 0 setdash 2 copy moveto 186 | 2 copy vpt 270 360 arc closepath fill 187 | vpt 0 360 arc closepath} bind def 188 | /C9 {BL [] 0 setdash 2 copy moveto 189 | 2 copy vpt 270 450 arc closepath fill 190 | vpt 0 360 arc closepath} bind def 191 | /C10 {BL [] 0 setdash 2 copy 2 copy moveto vpt 270 360 arc closepath fill 192 | 2 copy moveto 193 | 2 copy vpt 90 180 arc closepath fill 194 | vpt 0 360 arc closepath} bind def 195 | /C11 {BL [] 0 setdash 2 copy moveto 196 | 2 copy vpt 0 180 arc closepath fill 197 | 2 copy moveto 198 | 2 copy vpt 270 360 arc closepath fill 199 | vpt 0 360 arc closepath} bind def 200 | /C12 {BL [] 0 setdash 2 copy moveto 201 | 2 copy vpt 180 360 arc closepath fill 202 | vpt 0 360 arc closepath} bind def 203 | /C13 {BL [] 0 setdash 2 copy 
moveto 204 | 2 copy vpt 0 90 arc closepath fill 205 | 2 copy moveto 206 | 2 copy vpt 180 360 arc closepath fill 207 | vpt 0 360 arc closepath} bind def 208 | /C14 {BL [] 0 setdash 2 copy moveto 209 | 2 copy vpt 90 360 arc closepath fill 210 | vpt 0 360 arc} bind def 211 | /C15 {BL [] 0 setdash 2 copy vpt 0 360 arc closepath fill 212 | vpt 0 360 arc closepath} bind def 213 | /Rec {newpath 4 2 roll moveto 1 index 0 rlineto 0 exch rlineto 214 | neg 0 rlineto closepath} bind def 215 | /Square {dup Rec} bind def 216 | /Bsquare {vpt sub exch vpt sub exch vpt2 Square} bind def 217 | /S0 {BL [] 0 setdash 2 copy moveto 0 vpt rlineto BL Bsquare} bind def 218 | /S1 {BL [] 0 setdash 2 copy vpt Square fill Bsquare} bind def 219 | /S2 {BL [] 0 setdash 2 copy exch vpt sub exch vpt Square fill Bsquare} bind def 220 | /S3 {BL [] 0 setdash 2 copy exch vpt sub exch vpt2 vpt Rec fill Bsquare} bind def 221 | /S4 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt Square fill Bsquare} bind def 222 | /S5 {BL [] 0 setdash 2 copy 2 copy vpt Square fill 223 | exch vpt sub exch vpt sub vpt Square fill Bsquare} bind def 224 | /S6 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt vpt2 Rec fill Bsquare} bind def 225 | /S7 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt vpt2 Rec fill 226 | 2 copy vpt Square fill Bsquare} bind def 227 | /S8 {BL [] 0 setdash 2 copy vpt sub vpt Square fill Bsquare} bind def 228 | /S9 {BL [] 0 setdash 2 copy vpt sub vpt vpt2 Rec fill Bsquare} bind def 229 | /S10 {BL [] 0 setdash 2 copy vpt sub vpt Square fill 2 copy exch vpt sub exch vpt Square fill 230 | Bsquare} bind def 231 | /S11 {BL [] 0 setdash 2 copy vpt sub vpt Square fill 2 copy exch vpt sub exch vpt2 vpt Rec fill 232 | Bsquare} bind def 233 | /S12 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt2 vpt Rec fill Bsquare} bind def 234 | /S13 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt2 vpt Rec fill 235 | 2 copy vpt Square fill Bsquare} bind def 236 | /S14 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt2 vpt Rec fill 237 | 2 copy exch vpt sub exch vpt Square fill Bsquare} bind def 238 | /S15 {BL [] 0 setdash 2 copy Bsquare fill Bsquare} bind def 239 | /D0 {gsave translate 45 rotate 0 0 S0 stroke grestore} bind def 240 | /D1 {gsave translate 45 rotate 0 0 S1 stroke grestore} bind def 241 | /D2 {gsave translate 45 rotate 0 0 S2 stroke grestore} bind def 242 | /D3 {gsave translate 45 rotate 0 0 S3 stroke grestore} bind def 243 | /D4 {gsave translate 45 rotate 0 0 S4 stroke grestore} bind def 244 | /D5 {gsave translate 45 rotate 0 0 S5 stroke grestore} bind def 245 | /D6 {gsave translate 45 rotate 0 0 S6 stroke grestore} bind def 246 | /D7 {gsave translate 45 rotate 0 0 S7 stroke grestore} bind def 247 | /D8 {gsave translate 45 rotate 0 0 S8 stroke grestore} bind def 248 | /D9 {gsave translate 45 rotate 0 0 S9 stroke grestore} bind def 249 | /D10 {gsave translate 45 rotate 0 0 S10 stroke grestore} bind def 250 | /D11 {gsave translate 45 rotate 0 0 S11 stroke grestore} bind def 251 | /D12 {gsave translate 45 rotate 0 0 S12 stroke grestore} bind def 252 | /D13 {gsave translate 45 rotate 0 0 S13 stroke grestore} bind def 253 | /D14 {gsave translate 45 rotate 0 0 S14 stroke grestore} bind def 254 | /D15 {gsave translate 45 rotate 0 0 S15 stroke grestore} bind def 255 | /DiaE {stroke [] 0 setdash vpt add M 256 | hpt neg vpt neg V hpt vpt neg V 257 | hpt vpt V hpt neg vpt V closepath stroke} def 258 | /BoxE {stroke [] 0 setdash exch hpt sub exch vpt add M 259 | 0 vpt2 neg V hpt2 0 V 0 vpt2 V 260 | hpt2 neg 0 V 
closepath stroke} def 261 | /TriUE {stroke [] 0 setdash vpt 1.12 mul add M 262 | hpt neg vpt -1.62 mul V 263 | hpt 2 mul 0 V 264 | hpt neg vpt 1.62 mul V closepath stroke} def 265 | /TriDE {stroke [] 0 setdash vpt 1.12 mul sub M 266 | hpt neg vpt 1.62 mul V 267 | hpt 2 mul 0 V 268 | hpt neg vpt -1.62 mul V closepath stroke} def 269 | /PentE {stroke [] 0 setdash gsave 270 | translate 0 hpt M 4 {72 rotate 0 hpt L} repeat 271 | closepath stroke grestore} def 272 | /CircE {stroke [] 0 setdash 273 | hpt 0 360 arc stroke} def 274 | /Opaque {gsave closepath 1 setgray fill grestore 0 setgray closepath} def 275 | /DiaW {stroke [] 0 setdash vpt add M 276 | hpt neg vpt neg V hpt vpt neg V 277 | hpt vpt V hpt neg vpt V Opaque stroke} def 278 | /BoxW {stroke [] 0 setdash exch hpt sub exch vpt add M 279 | 0 vpt2 neg V hpt2 0 V 0 vpt2 V 280 | hpt2 neg 0 V Opaque stroke} def 281 | /TriUW {stroke [] 0 setdash vpt 1.12 mul add M 282 | hpt neg vpt -1.62 mul V 283 | hpt 2 mul 0 V 284 | hpt neg vpt 1.62 mul V Opaque stroke} def 285 | /TriDW {stroke [] 0 setdash vpt 1.12 mul sub M 286 | hpt neg vpt 1.62 mul V 287 | hpt 2 mul 0 V 288 | hpt neg vpt -1.62 mul V Opaque stroke} def 289 | /PentW {stroke [] 0 setdash gsave 290 | translate 0 hpt M 4 {72 rotate 0 hpt L} repeat 291 | Opaque stroke grestore} def 292 | /CircW {stroke [] 0 setdash 293 | hpt 0 360 arc Opaque stroke} def 294 | /BoxFill {gsave Rec 1 setgray fill grestore} def 295 | /Density { 296 | /Fillden exch def 297 | currentrgbcolor 298 | /ColB exch def /ColG exch def /ColR exch def 299 | /ColR ColR Fillden mul Fillden sub 1 add def 300 | /ColG ColG Fillden mul Fillden sub 1 add def 301 | /ColB ColB Fillden mul Fillden sub 1 add def 302 | ColR ColG ColB setrgbcolor} def 303 | /BoxColFill {gsave Rec PolyFill} def 304 | /PolyFill {gsave Density fill grestore grestore} def 305 | /h {rlineto rlineto rlineto gsave closepath fill grestore} bind def 306 | % 307 | % PostScript Level 1 Pattern Fill routine for rectangles 308 | % Usage: x y w h s a XX PatternFill 309 | % x,y = lower left corner of box to be filled 310 | % w,h = width and height of box 311 | % a = angle in degrees between lines and x-axis 312 | % XX = 0/1 for no/yes cross-hatch 313 | % 314 | /PatternFill {gsave /PFa [ 9 2 roll ] def 315 | PFa 0 get PFa 2 get 2 div add PFa 1 get PFa 3 get 2 div add translate 316 | PFa 2 get -2 div PFa 3 get -2 div PFa 2 get PFa 3 get Rec 317 | TransparentPatterns {} {gsave 1 setgray fill grestore} ifelse 318 | clip 319 | currentlinewidth 0.5 mul setlinewidth 320 | /PFs PFa 2 get dup mul PFa 3 get dup mul add sqrt def 321 | 0 0 M PFa 5 get rotate PFs -2 div dup translate 322 | 0 1 PFs PFa 4 get div 1 add floor cvi 323 | {PFa 4 get mul 0 M 0 PFs V} for 324 | 0 PFa 6 get ne { 325 | 0 1 PFs PFa 4 get div 1 add floor cvi 326 | {PFa 4 get mul 0 2 1 roll M PFs 0 V} for 327 | } if 328 | stroke grestore} def 329 | % 330 | /languagelevel where 331 | {pop languagelevel} {1} ifelse 332 | 2 lt 333 | {/InterpretLevel1 true def} 334 | {/InterpretLevel1 Level1 def} 335 | ifelse 336 | % 337 | % PostScript level 2 pattern fill definitions 338 | % 339 | /Level2PatternFill { 340 | /Tile8x8 {/PaintType 2 /PatternType 1 /TilingType 1 /BBox [0 0 8 8] /XStep 8 /YStep 8} 341 | bind def 342 | /KeepColor {currentrgbcolor [/Pattern /DeviceRGB] setcolorspace} bind def 343 | << Tile8x8 344 | /PaintProc {0.5 setlinewidth pop 0 0 M 8 8 L 0 8 M 8 0 L stroke} 345 | >> matrix makepattern 346 | /Pat1 exch def 347 | << Tile8x8 348 | /PaintProc {0.5 setlinewidth pop 0 0 M 8 8 L 0 8 M 8 0 L stroke 349 | 0 
4 M 4 8 L 8 4 L 4 0 L 0 4 L stroke} 350 | >> matrix makepattern 351 | /Pat2 exch def 352 | << Tile8x8 353 | /PaintProc {0.5 setlinewidth pop 0 0 M 0 8 L 354 | 8 8 L 8 0 L 0 0 L fill} 355 | >> matrix makepattern 356 | /Pat3 exch def 357 | << Tile8x8 358 | /PaintProc {0.5 setlinewidth pop -4 8 M 8 -4 L 359 | 0 12 M 12 0 L stroke} 360 | >> matrix makepattern 361 | /Pat4 exch def 362 | << Tile8x8 363 | /PaintProc {0.5 setlinewidth pop -4 0 M 8 12 L 364 | 0 -4 M 12 8 L stroke} 365 | >> matrix makepattern 366 | /Pat5 exch def 367 | << Tile8x8 368 | /PaintProc {0.5 setlinewidth pop -2 8 M 4 -4 L 369 | 0 12 M 8 -4 L 4 12 M 10 0 L stroke} 370 | >> matrix makepattern 371 | /Pat6 exch def 372 | << Tile8x8 373 | /PaintProc {0.5 setlinewidth pop -2 0 M 4 12 L 374 | 0 -4 M 8 12 L 4 -4 M 10 8 L stroke} 375 | >> matrix makepattern 376 | /Pat7 exch def 377 | << Tile8x8 378 | /PaintProc {0.5 setlinewidth pop 8 -2 M -4 4 L 379 | 12 0 M -4 8 L 12 4 M 0 10 L stroke} 380 | >> matrix makepattern 381 | /Pat8 exch def 382 | << Tile8x8 383 | /PaintProc {0.5 setlinewidth pop 0 -2 M 12 4 L 384 | -4 0 M 12 8 L -4 4 M 8 10 L stroke} 385 | >> matrix makepattern 386 | /Pat9 exch def 387 | /Pattern1 {PatternBgnd KeepColor Pat1 setpattern} bind def 388 | /Pattern2 {PatternBgnd KeepColor Pat2 setpattern} bind def 389 | /Pattern3 {PatternBgnd KeepColor Pat3 setpattern} bind def 390 | /Pattern4 {PatternBgnd KeepColor Landscape {Pat5} {Pat4} ifelse setpattern} bind def 391 | /Pattern5 {PatternBgnd KeepColor Landscape {Pat4} {Pat5} ifelse setpattern} bind def 392 | /Pattern6 {PatternBgnd KeepColor Landscape {Pat9} {Pat6} ifelse setpattern} bind def 393 | /Pattern7 {PatternBgnd KeepColor Landscape {Pat8} {Pat7} ifelse setpattern} bind def 394 | } def 395 | % 396 | % 397 | %End of PostScript Level 2 code 398 | % 399 | /PatternBgnd { 400 | TransparentPatterns {} {gsave 1 setgray fill grestore} ifelse 401 | } def 402 | % 403 | % Substitute for Level 2 pattern fill codes with 404 | % grayscale if Level 2 support is not selected. 
405 | % 406 | /Level1PatternFill { 407 | /Pattern1 {0.250 Density} bind def 408 | /Pattern2 {0.500 Density} bind def 409 | /Pattern3 {0.750 Density} bind def 410 | /Pattern4 {0.125 Density} bind def 411 | /Pattern5 {0.375 Density} bind def 412 | /Pattern6 {0.625 Density} bind def 413 | /Pattern7 {0.875 Density} bind def 414 | } def 415 | % 416 | % Now test for support of Level 2 code 417 | % 418 | Level1 {Level1PatternFill} {Level2PatternFill} ifelse 419 | % 420 | /Symbol-Oblique /Symbol findfont [1 0 .167 1 0 0] makefont 421 | dup length dict begin {1 index /FID eq {pop pop} {def} ifelse} forall 422 | currentdict end definefont pop 423 | /MFshow { 424 | { dup 5 get 3 ge 425 | { 5 get 3 eq {gsave} {grestore} ifelse } 426 | {dup dup 0 get findfont exch 1 get scalefont setfont 427 | [ currentpoint ] exch dup 2 get 0 exch R dup 5 get 2 ne {dup dup 6 428 | get exch 4 get {Gshow} {stringwidth pop 0 R} ifelse }if dup 5 get 0 eq 429 | {dup 3 get {2 get neg 0 exch R pop} {pop aload pop M} ifelse} {dup 5 430 | get 1 eq {dup 2 get exch dup 3 get exch 6 get stringwidth pop -2 div 431 | dup 0 R} {dup 6 get stringwidth pop -2 div 0 R 6 get 432 | show 2 index {aload pop M neg 3 -1 roll neg R pop pop} {pop pop pop 433 | pop aload pop M} ifelse }ifelse }ifelse } 434 | ifelse } 435 | forall} def 436 | /Gswidth {dup type /stringtype eq {stringwidth} {pop (n) stringwidth} ifelse} def 437 | /MFwidth {0 exch { dup 5 get 3 ge { 5 get 3 eq { 0 } { pop } ifelse } 438 | {dup 3 get{dup dup 0 get findfont exch 1 get scalefont setfont 439 | 6 get Gswidth pop add} {pop} ifelse} ifelse} forall} def 440 | /MLshow { currentpoint stroke M 441 | 0 exch R 442 | Blacktext {gsave 0 setgray MFshow grestore} {MFshow} ifelse } bind def 443 | /MRshow { currentpoint stroke M 444 | exch dup MFwidth neg 3 -1 roll R 445 | Blacktext {gsave 0 setgray MFshow grestore} {MFshow} ifelse } bind def 446 | /MCshow { currentpoint stroke M 447 | exch dup MFwidth -2 div 3 -1 roll R 448 | Blacktext {gsave 0 setgray MFshow grestore} {MFshow} ifelse } bind def 449 | /XYsave { [( ) 1 2 true false 3 ()] } bind def 450 | /XYrestore { [( ) 1 2 true false 4 ()] } bind def 451 | Level1 SuppressPDFMark or 452 | {} { 453 | /SDict 10 dict def 454 | systemdict /pdfmark known not { 455 | userdict /pdfmark systemdict /cleartomark get put 456 | } if 457 | SDict begin [ 458 | /Title (q5a.eps) 459 | /Subject (gnuplot plot) 460 | /Creator (gnuplot 4.6 patchlevel 4) 461 | /Author (vitchyr) 462 | % /Producer (gnuplot) 463 | % /Keywords () 464 | /CreationDate (Thu Jan 21 00:24:42 2016) 465 | /DOCINFO pdfmark 466 | end 467 | } ifelse 468 | end 469 | %%EndProlog 470 | %%Page: 1 1 471 | gnudict begin 472 | gsave 473 | doclip 474 | 50 50 translate 475 | 0.050 0.050 scale 476 | 0 setgray 477 | newpath 478 | (Helvetica) findfont 140 scalefont setfont 479 | BackgroundColor 0 lt 3 1 roll 0 lt exch 0 lt or or not {BackgroundColor C 1.000 0 0 7200.00 5040.00 BoxColFill} if 480 | 1.000 UL 481 | LTb 482 | 0.13 0.13 0.13 C 1.000 UL 483 | LTa 484 | LCa setrgbcolor 485 | 602 448 M 486 | 6345 0 V 487 | stroke 488 | LTb 489 | 0.13 0.13 0.13 C 602 448 M 490 | 63 0 V 491 | 6282 0 R 492 | -63 0 V 493 | stroke 494 | 518 448 M 495 | [ [(Helvetica) 140.0 0.0 true true 0 ( 28)] 496 | ] -46.7 MRshow 497 | 1.000 UL 498 | LTb 499 | 0.13 0.13 0.13 C 1.000 UL 500 | LTa 501 | LCa setrgbcolor 502 | 602 969 M 503 | 6345 0 V 504 | stroke 505 | LTb 506 | 0.13 0.13 0.13 C 602 969 M 507 | 63 0 V 508 | 6282 0 R 509 | -63 0 V 510 | stroke 511 | 518 969 M 512 | [ [(Helvetica) 140.0 0.0 true true 0 ( 30)] 
513 | ] -46.7 MRshow 514 | 1.000 UL 515 | LTb 516 | 0.13 0.13 0.13 C 1.000 UL 517 | LTa 518 | LCa setrgbcolor 519 | 602 1491 M 520 | 6345 0 V 521 | stroke 522 | LTb 523 | 0.13 0.13 0.13 C 602 1491 M 524 | 63 0 V 525 | 6282 0 R 526 | -63 0 V 527 | stroke 528 | 518 1491 M 529 | [ [(Helvetica) 140.0 0.0 true true 0 ( 32)] 530 | ] -46.7 MRshow 531 | 1.000 UL 532 | LTb 533 | 0.13 0.13 0.13 C 1.000 UL 534 | LTa 535 | LCa setrgbcolor 536 | 602 2012 M 537 | 6345 0 V 538 | stroke 539 | LTb 540 | 0.13 0.13 0.13 C 602 2012 M 541 | 63 0 V 542 | 6282 0 R 543 | -63 0 V 544 | stroke 545 | 518 2012 M 546 | [ [(Helvetica) 140.0 0.0 true true 0 ( 34)] 547 | ] -46.7 MRshow 548 | 1.000 UL 549 | LTb 550 | 0.13 0.13 0.13 C 1.000 UL 551 | LTa 552 | LCa setrgbcolor 553 | 602 2534 M 554 | 6345 0 V 555 | stroke 556 | LTb 557 | 0.13 0.13 0.13 C 602 2534 M 558 | 63 0 V 559 | 6282 0 R 560 | -63 0 V 561 | stroke 562 | 518 2534 M 563 | [ [(Helvetica) 140.0 0.0 true true 0 ( 36)] 564 | ] -46.7 MRshow 565 | 1.000 UL 566 | LTb 567 | 0.13 0.13 0.13 C 1.000 UL 568 | LTa 569 | LCa setrgbcolor 570 | 602 3055 M 571 | 6345 0 V 572 | stroke 573 | LTb 574 | 0.13 0.13 0.13 C 602 3055 M 575 | 63 0 V 576 | 6282 0 R 577 | -63 0 V 578 | stroke 579 | 518 3055 M 580 | [ [(Helvetica) 140.0 0.0 true true 0 ( 38)] 581 | ] -46.7 MRshow 582 | 1.000 UL 583 | LTb 584 | 0.13 0.13 0.13 C 1.000 UL 585 | LTa 586 | LCa setrgbcolor 587 | 602 3576 M 588 | 6345 0 V 589 | stroke 590 | LTb 591 | 0.13 0.13 0.13 C 602 3576 M 592 | 63 0 V 593 | 6282 0 R 594 | -63 0 V 595 | stroke 596 | 518 3576 M 597 | [ [(Helvetica) 140.0 0.0 true true 0 ( 40)] 598 | ] -46.7 MRshow 599 | 1.000 UL 600 | LTb 601 | 0.13 0.13 0.13 C 1.000 UL 602 | LTa 603 | LCa setrgbcolor 604 | 602 4098 M 605 | 6345 0 V 606 | stroke 607 | LTb 608 | 0.13 0.13 0.13 C 602 4098 M 609 | 63 0 V 610 | 6282 0 R 611 | -63 0 V 612 | stroke 613 | 518 4098 M 614 | [ [(Helvetica) 140.0 0.0 true true 0 ( 42)] 615 | ] -46.7 MRshow 616 | 1.000 UL 617 | LTb 618 | 0.13 0.13 0.13 C 1.000 UL 619 | LTa 620 | LCa setrgbcolor 621 | 602 4619 M 622 | 6345 0 V 623 | stroke 624 | LTb 625 | 0.13 0.13 0.13 C 602 4619 M 626 | 63 0 V 627 | 6282 0 R 628 | -63 0 V 629 | stroke 630 | 518 4619 M 631 | [ [(Helvetica) 140.0 0.0 true true 0 ( 44)] 632 | ] -46.7 MRshow 633 | 1.000 UL 634 | LTb 635 | 0.13 0.13 0.13 C 1.000 UL 636 | LTa 637 | LCa setrgbcolor 638 | 602 448 M 639 | 0 4171 V 640 | stroke 641 | LTb 642 | 0.13 0.13 0.13 C 602 448 M 643 | 0 63 V 644 | 0 4108 R 645 | 0 -63 V 646 | stroke 647 | 602 308 M 648 | [ [(Helvetica) 140.0 0.0 true true 0 ( 0)] 649 | ] -46.7 MCshow 650 | 1.000 UL 651 | LTb 652 | 0.13 0.13 0.13 C 1.000 UL 653 | LTa 654 | LCa setrgbcolor 655 | 1871 448 M 656 | 0 4171 V 657 | stroke 658 | LTb 659 | 0.13 0.13 0.13 C 1871 448 M 660 | 0 63 V 661 | 0 4108 R 662 | 0 -63 V 663 | stroke 664 | 1871 308 M 665 | [ [(Helvetica) 140.0 0.0 true true 0 ( 0.2)] 666 | ] -46.7 MCshow 667 | 1.000 UL 668 | LTb 669 | 0.13 0.13 0.13 C 1.000 UL 670 | LTa 671 | LCa setrgbcolor 672 | 3140 448 M 673 | 0 4171 V 674 | stroke 675 | LTb 676 | 0.13 0.13 0.13 C 3140 448 M 677 | 0 63 V 678 | 0 4108 R 679 | 0 -63 V 680 | stroke 681 | 3140 308 M 682 | [ [(Helvetica) 140.0 0.0 true true 0 ( 0.4)] 683 | ] -46.7 MCshow 684 | 1.000 UL 685 | LTb 686 | 0.13 0.13 0.13 C 1.000 UL 687 | LTa 688 | LCa setrgbcolor 689 | 4409 448 M 690 | 0 4171 V 691 | stroke 692 | LTb 693 | 0.13 0.13 0.13 C 4409 448 M 694 | 0 63 V 695 | 0 4108 R 696 | 0 -63 V 697 | stroke 698 | 4409 308 M 699 | [ [(Helvetica) 140.0 0.0 true true 0 ( 0.6)] 700 | ] -46.7 MCshow 701 
| 1.000 UL 702 | LTb 703 | 0.13 0.13 0.13 C 1.000 UL 704 | LTa 705 | LCa setrgbcolor 706 | 5678 448 M 707 | 0 4171 V 708 | stroke 709 | LTb 710 | 0.13 0.13 0.13 C 5678 448 M 711 | 0 63 V 712 | 0 4108 R 713 | 0 -63 V 714 | stroke 715 | 5678 308 M 716 | [ [(Helvetica) 140.0 0.0 true true 0 ( 0.8)] 717 | ] -46.7 MCshow 718 | 1.000 UL 719 | LTb 720 | 0.13 0.13 0.13 C 1.000 UL 721 | LTa 722 | LCa setrgbcolor 723 | 6947 448 M 724 | 0 4171 V 725 | stroke 726 | LTb 727 | 0.13 0.13 0.13 C 6947 448 M 728 | 0 63 V 729 | 0 4108 R 730 | 0 -63 V 731 | stroke 732 | 6947 308 M 733 | [ [(Helvetica) 140.0 0.0 true true 0 ( 1)] 734 | ] -46.7 MCshow 735 | 1.000 UL 736 | LTb 737 | 0.13 0.13 0.13 C 1.000 UL 738 | LTb 739 | 0.13 0.13 0.13 C 602 4619 N 740 | 602 448 L 741 | 6345 0 V 742 | 0 4171 V 743 | -6345 0 V 744 | Z stroke 745 | LCb setrgbcolor 746 | 112 2533 M 747 | currentpoint gsave translate -270 rotate 0 0 moveto 748 | [ [(Helvetica) 140.0 0.0 true true 0 (RMS between Q-MC and Q-SARSA)] 749 | ] -46.7 MCshow 750 | grestore 751 | LTb 752 | LCb setrgbcolor 753 | 3774 98 M 754 | [ [(Helvetica) 140.0 0.0 true true 0 (lambda)] 755 | ] -46.7 MCshow 756 | LTb 757 | 3774 4829 M 758 | [ [(Helvetica) 140.0 0.0 true true 0 (Q RMS after 1000 episodes vs lambda)] 759 | ] -46.7 MCshow 760 | 1.000 UP 761 | 1.000 UL 762 | LTb 763 | 0.13 0.13 0.13 C % Begin plot #1 764 | 1.000 UP 765 | 2.000 UL 766 | LT0 767 | 0.11 0.27 0.60 C 602 910 M 768 | 635 2723 V 769 | 1871 1339 L 770 | 2506 517 L 771 | 634 3912 V 772 | 3775 1021 L 773 | 4409 813 L 774 | 635 526 V 775 | 634 780 V 776 | 6313 934 L 777 | 634 35 V 778 | 602 910 CircleF 779 | 1237 3633 CircleF 780 | 1871 1339 CircleF 781 | 2506 517 CircleF 782 | 3140 4429 CircleF 783 | 3775 1021 CircleF 784 | 4409 813 CircleF 785 | 5044 1339 CircleF 786 | 5678 2119 CircleF 787 | 6313 934 CircleF 788 | 6947 969 CircleF 789 | % End plot #1 790 | 1.000 UL 791 | LTb 792 | 0.13 0.13 0.13 C 602 4619 N 793 | 602 448 L 794 | 6345 0 V 795 | 0 4171 V 796 | -6345 0 V 797 | Z stroke 798 | 1.000 UP 799 | 1.000 UL 800 | LTb 801 | 0.13 0.13 0.13 C stroke 802 | grestore 803 | end 804 | showpage 805 | %%Trailer 806 | %%DocumentFonts: Helvetica 807 | --------------------------------------------------------------------------------