├── .gitignore
├── images
├── nn_sarsa_rms_vs_lambda.png
├── nn_sarsa_rms_vs_iteration.jpg
├── nn_sarsa_rms_vs_iteration.png
├── table_sarsa_rms_vs_lambda.png
├── table_sarsa_rms_vs_iteration.png
├── q3a.eps
├── q4a.eps
└── q5a.eps
├── util
├── io_util.lua
├── tensorutil.lua
├── mdputil.lua
├── util_for_unittests.lua
└── util.lua
├── rl_constants.lua
├── Policy.lua
├── TableSarsaFactory.lua
├── Explorer.lua
├── TestPolicy.lua
├── SarsaFactory.lua
├── AllActionsEqualPolicy.lua
├── test
├── unittest_TestPolicy.lua
├── unittest_io_util.lua
├── unittest_TableSarsaFactory.lua
├── unittest_TestMdp.lua
├── unittest_TestSAFE.lua
├── unittest_LinSarsaFactory.lua
├── unittest_AllActionsEqualPolicy.lua
├── unittest_VHash.lua
├── unittest_NNSarsaFactory.lua
├── unittest_tensorutil.lua
├── unittest_GreedyPolicy.lua
├── unittest_QHash.lua
├── unittest_TestMdpQVAnalyzer.lua
├── unittest_LinSarsa.lua
├── unittest_EpisodeBuilder.lua
├── unittest_MonteCarloControl.lua
├── unittest_NNSarsa.lua
├── unittest_QLin.lua
├── unittest_QNN.lua
├── unittest_MdpSampler.lua
├── unittest_util.lua
└── unittest_TableSarsa.lua
├── ControlFactory.lua
├── ConstExplorer.lua
├── TestSAFE.lua
├── VFunc.lua
├── DecayTableExplorer.lua
├── CMakeLists.txt
├── ValueIteration.lua
├── MdpConfig.lua
├── doc
├── montecarlo.md
├── valuefunctions.md
├── policy.md
├── mdp.md
├── sarsa.md
└── index.md
├── VHash.lua
├── GreedyPolicy.lua
├── Control.lua
├── SAFeatureExtractor.lua
├── QFunc.lua
├── Evaluator.lua
├── QHash.lua
├── run_tests.lua
├── QLin.lua
├── NNSarsaFactory.lua
├── LinSarsaFactory.lua
├── TestMdp.lua
├── QVAnalyzer.lua
├── MdpSampler.lua
├── Mdp.lua
├── rl.lua
├── MonteCarloControl.lua
├── QApprox.lua
├── LinSarsa.lua
├── EpisodeBuilder.lua
├── NNSarsa.lua
├── TestMdpQVAnalyzer.lua
├── Sarsa.lua
├── TableSarsa.lua
├── rocks
└── rl-0.1-1.rockspec
├── rl-0.2-5.rockspec
├── QNN.lua
├── README.md
└── SarsaAnalyzer.lua
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *.rock
3 |
--------------------------------------------------------------------------------
/images/nn_sarsa_rms_vs_lambda.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vitchyr/torch-rl/HEAD/images/nn_sarsa_rms_vs_lambda.png
--------------------------------------------------------------------------------
/images/nn_sarsa_rms_vs_iteration.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vitchyr/torch-rl/HEAD/images/nn_sarsa_rms_vs_iteration.jpg
--------------------------------------------------------------------------------
/images/nn_sarsa_rms_vs_iteration.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vitchyr/torch-rl/HEAD/images/nn_sarsa_rms_vs_iteration.png
--------------------------------------------------------------------------------
/images/table_sarsa_rms_vs_lambda.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vitchyr/torch-rl/HEAD/images/table_sarsa_rms_vs_lambda.png
--------------------------------------------------------------------------------
/images/table_sarsa_rms_vs_iteration.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vitchyr/torch-rl/HEAD/images/table_sarsa_rms_vs_iteration.png
--------------------------------------------------------------------------------
/util/io_util.lua:
--------------------------------------------------------------------------------
1 | function rl.util.save_q(fname, q)
2 | torch.save(fname, q)
3 | end
4 |
5 | function rl.util.load_q(fname)
6 | return torch.load(fname)
7 | end
8 |
--------------------------------------------------------------------------------
/rl_constants.lua:
--------------------------------------------------------------------------------
1 | rl.EVALUATOR_DEFAULT_N_ITERS = 1000
2 | rl.MONTECARLOCONTROL_DEFAULT_N0 = 100
3 |
4 | -- Tolerance for floating point comparison
5 | rl.FLOAT_EPS = 1e-10
6 |
--------------------------------------------------------------------------------
/Policy.lua:
--------------------------------------------------------------------------------
1 | local Policy = torch.class('rl.Policy')
2 |
3 | function Policy:__init()
4 | end
5 |
6 | --- Return an action given a state
7 | function Policy:get_action(s)
8 | error('Policy must implement get_action')
9 | end
10 |
--------------------------------------------------------------------------------
/TableSarsaFactory.lua:
--------------------------------------------------------------------------------
1 | local TableSarsaFactory, parent =
2 | torch.class('rl.TableSarsaFactory', 'rl.SarsaFactory')
3 |
4 | function TableSarsaFactory:get_control()
5 | return rl.TableSarsa(
6 | self.mdp_config,
7 | self.lambda)
8 | end
9 |
10 |
--------------------------------------------------------------------------------
/Explorer.lua:
--------------------------------------------------------------------------------
1 | local Explorer = torch.class('rl.Explorer')
2 |
3 | function Explorer:__init()
4 | end
5 |
6 | --- Return epsilon, the probability of exploring randomly, given a state
7 | function Explorer:get_eps(s)
8 | error('Explorer must implement get_eps')
9 | end
10 |
--------------------------------------------------------------------------------
/TestPolicy.lua:
--------------------------------------------------------------------------------
1 | local TestPolicy, parent = torch.class('rl.TestPolicy', 'rl.Policy')
2 |
3 | function TestPolicy:__init(action)
4 | parent.__init(self)
5 | self.action = action
6 | end
7 |
8 | function TestPolicy:get_action(s)
9 | return self.action
10 | end
11 |
--------------------------------------------------------------------------------
/SarsaFactory.lua:
--------------------------------------------------------------------------------
1 | local SarsaFactory, parent = torch.class('rl.SarsaFactory', 'rl.ControlFactory')
2 |
3 | function SarsaFactory:__init(mdp_config, lambda)
4 | parent.__init(self, mdp_config)
5 | self.lambda = lambda
6 | end
7 |
8 | function SarsaFactory:set_lambda(lambda)
9 | self.lambda = lambda
10 | end
11 |
--------------------------------------------------------------------------------
/AllActionsEqualPolicy.lua:
--------------------------------------------------------------------------------
1 | local AAEP, parent = torch.class('rl.AllActionsEqualPolicy', 'rl.Policy')
2 |
3 | function AAEP:__init(mdp)
4 | -- FIX: was `parent:__init(self)`, which passes the Policy class table as
5 | -- `self` and the new instance as a spurious argument; use `.` with an
6 | -- explicit self, as every other subclass in this package does.
7 | parent.__init(self)
8 | self.mdp = mdp
9 | end
10 | 
11 | --- Pick one of the MDP's actions uniformly at random; the state is ignored.
12 | function AAEP:get_action(s)
13 | -- FIX: `actions` leaked into the global namespace; declare it local.
14 | local actions = self.mdp:get_all_actions()
15 | return actions[math.random(1, #actions)]
16 | end
12 |
13 |
--------------------------------------------------------------------------------
/test/unittest_TestPolicy.lua:
--------------------------------------------------------------------------------
1 | require 'rl'
2 | local tester = torch.Tester()
3 |
4 | local TestTestPolicy = {}
5 | function TestTestPolicy.test_correct_action()
6 | local a = 2
7 | local policy = rl.TestPolicy(a)
8 | tester:asserteq(a, policy:get_action(3))
9 | end
10 |
11 | tester:add(TestTestPolicy)
12 |
13 | tester:run()
14 |
15 |
--------------------------------------------------------------------------------
/test/unittest_io_util.lua:
--------------------------------------------------------------------------------
1 | require 'rl'
2 |
3 | local tester = torch.Tester()
4 |
5 | local TestIOUtil = {}
6 | function TestIOUtil.test_save_load()
7 | -- FIX: was `rl.rl.QHash(...)`; QHash is registered directly on the `rl`
8 | -- table by torch.class('rl.QHash', ...), so `rl.rl` is nil.
9 | local q = rl.QHash(rl.TestMdp())
10 | rl.util.save_q('/tmp/q', q)
11 | local q2 = rl.util.load_q('/tmp/q')
12 | tester:asserteq(q, q2)
13 | end
12 | tester:add(TestIOUtil)
13 |
14 | tester:run()
15 |
--------------------------------------------------------------------------------
/ControlFactory.lua:
--------------------------------------------------------------------------------
1 | local ControlFactory = torch.class('rl.ControlFactory')
2 |
3 | function ControlFactory:__init(mdp_config)
4 | self.mdp_config = mdp_config
5 | end
6 |
7 | function ControlFactory:set_mdp_config(mdp_config)
8 | self.mdp_config = mdp_config
9 | end
10 |
11 | function ControlFactory:get_control()
12 | error('Must implement get_control')
13 | end
14 |
--------------------------------------------------------------------------------
/util/tensorutil.lua:
--------------------------------------------------------------------------------
1 | -- Thanks to http://stackoverflow.com/questions/34123291/torch-apply-function-over-dimension
2 | function rl.util.apply_to_slices(tensor, dimension, func, ...)
3 | for i, slice in ipairs(tensor:split(1, dimension)) do
4 | func(slice, i, ...)
5 | end
6 | return tensor
7 | end
8 |
9 | function rl.util.fill_range(tensor, i, offset)
10 | tensor:fill(i + offset)
11 | end
12 |
13 |
--------------------------------------------------------------------------------
/ConstExplorer.lua:
--------------------------------------------------------------------------------
1 | --- Explore with a fixed probability
2 | local ConstExplorer = torch.class('rl.ConstExplorer', 'rl.Explorer')
3 |
4 | function ConstExplorer:__init(p)
5 | self.p = p
6 | end
7 |
8 | function ConstExplorer:get_eps(s)
9 | return self.p
10 | end
11 |
12 | function ConstExplorer:__eq(other)
13 | return torch.typename(self) == torch.typename(other)
14 | and self.p == other.p
15 | end
16 |
--------------------------------------------------------------------------------
/TestSAFE.lua:
--------------------------------------------------------------------------------
1 | local TestSAFE, parent = torch.class('rl.TestSAFE', 'rl.SAFeatureExtractor')
2 |
3 | -- simple feature extractor that returns the sum and difference of s and a for
4 | -- TestMdp
5 | function TestSAFE:get_sa_features(s, a)
6 | return torch.Tensor{s+a, s-a}
7 | end
8 |
9 | function TestSAFE:get_sa_features_dim()
10 | return 2
11 | end
12 |
13 | function TestSAFE:get_sa_num_features()
14 | return 2
15 | end
16 |
--------------------------------------------------------------------------------
/VFunc.lua:
--------------------------------------------------------------------------------
1 | local VFunc = torch.class('rl.VFunc')
2 |
3 | function VFunc:__init(mdp)
4 | self.mdp = mdp
5 | end
6 |
7 | --- Return the value of state s. Abstract; subclasses must override.
8 | function VFunc:get_value(s)
9 | -- FIX: message said 'get_Value'; the method is named get_value.
10 | error('Must implement get_value method')
11 | end
10 |
11 | function VFunc:__eq(other)
12 | for _, s in pairs(self.mdp:get_all_states()) do
13 | if self:get_value(s) ~= other:get_value(s) then
14 | return false
15 | end
16 | end
17 | return true
18 | end
19 |
--------------------------------------------------------------------------------
/DecayTableExplorer.lua:
--------------------------------------------------------------------------------
1 | --- Choose epsilon to be N0 / N0 + (# times visited a state)
2 | local DecayTableExplorer, parent =
3 | torch.class('rl.DecayTableExplorer', 'rl.Explorer')
4 |
5 | function DecayTableExplorer:__init(N0, state_table)
6 | parent.__init(self)
7 | self.N0 = N0
8 | self.state_table = state_table
9 | end
10 |
11 | function DecayTableExplorer:get_eps(s)
12 | return self.N0 / (self.N0 + self.state_table:get_value(s))
13 | end
14 |
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR)
2 | CMAKE_POLICY(VERSION 2.6)
3 | IF(LUAROCKS_PREFIX)
4 | MESSAGE(STATUS "Installing Torch through Luarocks")
5 | STRING(REGEX REPLACE "(.*)lib/luarocks/rocks.*" "\\1" CMAKE_INSTALL_PREFIX "${LUAROCKS_PREFIX}")
6 | MESSAGE(STATUS "Prefix inferred from Luarocks: ${CMAKE_INSTALL_PREFIX}")
7 | ENDIF()
8 | FIND_PACKAGE(Torch REQUIRED)
9 |
10 | FILE(GLOB luasrc *.lua)
11 | ADD_TORCH_PACKAGE(rl "" "${luasrc}" "Reinforcement Learning Package")
12 |
13 |
--------------------------------------------------------------------------------
/ValueIteration.lua:
--------------------------------------------------------------------------------
1 | local ValueIteration = torch.class('rl.ValueIteration', 'rl.Control')
2 |
3 | -- ValueIteration captures an algorithm that optimizes a policy by alternating
4 | -- between policy evaluation for one step and policy iteration for one step
5 | function ValueIteration:improve_policy()
6 | self:optimize_policy()
7 | self:evaluate_policy()
8 | end
9 |
10 | function ValueIteration:optimize_policy()
11 | error('Must implement optimize_policy.')
12 | end
13 |
14 | function ValueIteration:evaluate_policy()
15 | error('Must implement evaluate_policy.')
16 | end
17 |
--------------------------------------------------------------------------------
/test/unittest_TableSarsaFactory.lua:
--------------------------------------------------------------------------------
1 | require 'rl'
2 | local tester = torch.Tester()
3 |
4 | local mdp = rl.TestMdp()
5 | local discount_factor = 0.12
6 | local TestTableSarsaFactory = {}
7 | function TestTableSarsaFactory.test_get_control()
8 | local mdp_config = rl.MdpConfig(mdp, discount_factor)
9 | local lambda = 0.126
10 |
11 | local table_sarsa = rl.TableSarsa(mdp_config, lambda)
12 | local factory = rl.TableSarsaFactory(mdp_config, lambda)
13 | tester:assert(factory:get_control() == table_sarsa)
14 | end
15 |
16 | tester:add(TestTableSarsaFactory)
17 |
18 | tester:run()
19 |
20 |
--------------------------------------------------------------------------------
/MdpConfig.lua:
--------------------------------------------------------------------------------
1 | local MdpConfig = torch.class('rl.MdpConfig')
2 |
3 | function MdpConfig:__init(mdp, discount_factor)
4 | self.mdp = mdp
5 | if discount_factor < 0 or discount_factor > 1 then
6 | error('Gamma must be between 0 and 1, inclusive.')
7 | end
8 | self.discount_factor = discount_factor
9 | end
10 |
11 | function MdpConfig:get_mdp()
12 | return self.mdp
13 | end
14 |
15 | function MdpConfig:get_discount_factor()
16 | return self.discount_factor
17 | end
18 |
19 | function MdpConfig:get_description()
20 | return self.mdp:get_description() .. "-Gamma="..self.discount_factor
21 | end
22 |
--------------------------------------------------------------------------------
/doc/montecarlo.md:
--------------------------------------------------------------------------------
1 | ## Monte Carlo Control
2 | Monte Carlo (MC) estimates the value of a state-action pair under a given
3 | policy by sampling and taking the average. MC Control alternates between this
4 | Q-function estimation and epsilon-greedy policy improvement.
5 |
6 | Example use:
7 |
8 | ```lua
9 | local mdp = TestMdp()
10 | local discount_factor = 0.9
11 | local n_iters = 1000
12 |
13 | local mdp_config = MdpConfig(mdp, discount_factor)
14 | local mc = MonteCarloControl(mdp_config)
15 | mc:improve_policy_for_n_iters(n_iters)
16 |
17 | local policy = mc:get_policy()
18 | local learned_action = policy.get_action(state)
19 | ```
20 |
21 |
--------------------------------------------------------------------------------
/test/unittest_TestMdp.lua:
--------------------------------------------------------------------------------
1 | require 'rl'
2 | local tester = torch.Tester()
3 |
4 | local TestTestMdp = {}
5 | function TestTestMdp.test_terminates()
6 | local mdp = rl.TestMdp()
7 | local s = mdp:get_start_state()
8 | s = mdp:step(s, 1)
9 | s = mdp:step(s, 1)
10 | tester:assert(mdp:is_terminal(s))
11 | end
12 |
13 | function TestTestMdp.test_reward()
14 | local mdp = rl.TestMdp()
15 | local old_s = 1
16 | local a = 1
17 | local new_s, r = mdp:step(old_s, a)
18 | tester:asserteq(r, -1)
19 |
20 | a = 3
21 | local _, r = mdp:step(new_s, a)
22 | tester:asserteq(r, 1)
23 | end
24 |
25 | tester:add(TestTestMdp)
26 |
27 | tester:run()
28 |
29 |
--------------------------------------------------------------------------------
/VHash.lua:
--------------------------------------------------------------------------------
1 | -- A slow implementation of a state value function using hashes and tables
2 |
3 | local VHash, parent = torch.class('rl.VHash', 'rl.VFunc')
4 |
5 | function VHash:__init(mdp)
6 | parent.__init(self, mdp)
7 | self.v_table = rl.util.get_all_states_map(mdp)
8 | self.hs = function (s) return mdp:hash_s(s) end
9 | end
10 |
11 | function VHash:get_value(s)
12 | return self.v_table[self.hs(s)]
13 | end
14 |
15 | function VHash:mult(s, value)
16 | self.v_table[self.hs(s)] = self.v_table[self.hs(s)] * value
17 | end
18 |
19 | function VHash:add(s, delta)
20 | self.v_table[self.hs(s)] = self.v_table[self.hs(s)] + delta
21 | end
22 |
23 | VHash.__eq = parent.__eq
24 |
--------------------------------------------------------------------------------
/util/mdputil.lua:
--------------------------------------------------------------------------------
1 | -- Get a table, where all the state keys have been initialized to 0
2 | function rl.util.get_all_states_map(mdp)
3 | -- FIX: `all_states_map` was an accidental global; declare it local.
4 | local all_states_map = {}
5 | for _, s in pairs(mdp:get_all_states()) do
6 | all_states_map[mdp:hash_s(s)] = 0
7 | end
8 | return all_states_map
9 | end
9 |
10 | -- Get a nested table mapping hash_s(s) -> { hash_a(a) -> 0 } for every
11 | -- state-action pair of the MDP.
12 | function rl.util.get_all_states_action_map(mdp)
13 | -- FIX: `all_states_actions_map` was an accidental global; declare it local.
14 | local all_states_actions_map = {}
15 | for _, s in pairs(mdp:get_all_states()) do
16 | local actions_map = {}
17 | for _, a in pairs(mdp:get_all_actions()) do
18 | actions_map[mdp:hash_a(a)] = 0
19 | end
20 | all_states_actions_map[mdp:hash_s(s)] = actions_map
21 | end
22 | return all_states_actions_map
23 | end
21 |
--------------------------------------------------------------------------------
/GreedyPolicy.lua:
--------------------------------------------------------------------------------
1 | local GreedyPolicy, parent = torch.class('rl.GreedyPolicy', 'rl.Policy')
2 |
3 | function GreedyPolicy:__init(q, exploration_strat, actions)
4 | parent.__init(self)
5 | self.q = q
6 | self.exploration_strat = exploration_strat
7 | self.actions = actions
8 | self.n_actions = #actions
9 | end
10 |
11 | --- Epsilon-greedy selection: each action gets eps/n probability mass and
12 | -- the greedy action (per self.q) gets the remaining 1 - eps.
13 | function GreedyPolicy:get_action(s)
14 | local eps = self.exploration_strat:get_eps(s)
15 | 
16 | -- FIX: `pi` and `best_a` leaked into the global namespace; declare local.
17 | local pi = {}
18 | for _, a in pairs(self.actions) do
19 | pi[a] = eps / self.n_actions
20 | end
21 | 
22 | -- Add 1-eps to best action
23 | local best_a = self.q:get_best_action(s)
24 | pi[best_a] = pi[best_a] + 1 - eps
25 | 
26 | return rl.util.weighted_random_choice(pi)
27 | end
25 |
--------------------------------------------------------------------------------
/Control.lua:
--------------------------------------------------------------------------------
1 | local Control = torch.class("rl.Control")
2 |
3 | -- Control captures an algorithm that optimizes a policy for a given MDP.
4 | function Control:__init(mdp_config)
5 | self.mdp = mdp_config:get_mdp()
6 | self.policy = rl.AllActionsEqualPolicy(self.mdp)
7 | self.sampler = rl.MdpSampler(mdp_config)
8 | end
9 |
10 | function Control:improve_policy_for_n_iters(n_iters)
11 | for i = 1, n_iters do
12 | self:improve_policy()
13 | end
14 | end
15 |
16 | function Control:improve_policy()
17 | error('Must implement improve_policy')
18 | end
19 |
20 | function Control:set_policy(policy)
21 | self.policy = policy
22 | end
23 |
24 | function Control:get_policy()
25 | return self.policy
26 | end
27 |
28 |
--------------------------------------------------------------------------------
/SAFeatureExtractor.lua:
--------------------------------------------------------------------------------
1 | -- Feature extractor for a state-action pair.
2 | local SAFeatureExtractor = torch.class('rl.SAFeatureExtractor')
3 |
4 | function SAFeatureExtractor:__init()
5 | end
6 |
7 | -- returns a Tensor that represents the features of a given state-action pair
8 | function SAFeatureExtractor:get_sa_features(s, a)
9 | error('Must implement get_sa_features')
10 | end
11 |
12 | -- returns the dimension of the output of get_sa_features
13 | function SAFeatureExtractor:get_sa_features_dim()
14 | error('Must implement get_sa_features_dim')
15 | end
16 |
17 | -- returns the num of elements in the tensor returned by get_sa_features
18 | function SAFeatureExtractor:get_sa_num_features()
19 | error('Must implement get_sa_num_features')
20 | end
21 |
--------------------------------------------------------------------------------
/QFunc.lua:
--------------------------------------------------------------------------------
1 | local QFunc = torch.class('rl.QFunc')
2 |
3 | function QFunc:__init(mdp)
4 | self.mdp = mdp
5 | end
6 |
7 | --- Return the value of state-action pair (s, a). Abstract; override required.
8 | function QFunc:get_value(s, a)
9 | -- FIX: message said 'get_Value'; the method is named get_value.
10 | error('Must implement get_value method')
11 | end
10 |
11 | function QFunc:get_best_action(s) -- NOTE(review): this base-class method reads self.q_table, self.hs and
12 | local actions = self.mdp:get_all_actions() -- self.ha, which only the QHash subclass defines (see QHash.lua);
13 | local Qs = self.q_table[self.hs(s)] -- on other QFunc subclasses this indexes nil. Consider moving it to QHash.
14 | local best_a, best_i = rl.util.max(
15 | actions,
16 | function (a) return Qs[self.ha(a)] end) -- best_i (the index) is unused here.
17 | return best_a
18 | end
19 |
20 | function QFunc:__eq(other)
21 | for _, s in pairs(self.mdp:get_all_states()) do
22 | for _, a in pairs(self.mdp:get_all_actions()) do
23 | if self:get_value(s, a) ~= other:get_value(s, a) then
24 | return false
25 | end
26 | end
27 | end
28 | return true
29 | end
30 |
--------------------------------------------------------------------------------
/Evaluator.lua:
--------------------------------------------------------------------------------
1 | local Evaluator = torch.class('rl.Evaluator')
2 |
3 | function Evaluator.__init(self, mdp_config)
4 | self.sampler = rl.MdpSampler(mdp_config)
5 | self.mdp_description = mdp_config:get_description()
6 | end
7 |
8 | function Evaluator:get_policy_avg_return(policy, n_iters) -- NOTE(review): despite the name, this returns the
9 | local total_r = 0 -- TOTAL reward over n_iters episodes, not the average;
10 | for i = 1, n_iters do -- display_metrics below divides by n_iters itself, so
11 | total_r = total_r + self.sampler:sample_total_reward(policy) -- changing this would double-divide.
12 | end
13 | return total_r
14 | end
15 |
16 | function Evaluator:display_metrics(policy, description, n_iters)
17 | n_iters = n_iters or rl.EVALUATOR_DEFAULT_N_ITERS
18 | local total_r = self:get_policy_avg_return(policy, n_iters)
19 | print('Avg Reward for <' .. description .. '> policy for ' ..
20 | self.mdp_description .. ': ' ..
21 | total_r .. '/' .. n_iters .. ' = ' .. total_r/n_iters)
22 | end
23 |
--------------------------------------------------------------------------------
/QHash.lua:
--------------------------------------------------------------------------------
1 | -- A simple implementation of a state-action value function using hashes and
2 | -- tables
3 | local QHash, parent = torch.class('rl.QHash', 'rl.QFunc')
4 |
5 | function QHash:__init(mdp)
6 | parent.__init(self, mdp)
7 | self.hs = function (s) return mdp:hash_s(s) end
8 | self.ha = function (a) return mdp:hash_a(a) end
9 | self.q_table = rl.util.get_all_states_action_map(mdp)
10 | end
11 |
12 | function QHash:get_value(s, a)
13 | return self.q_table[self.hs(s)][self.ha(a)]
14 | end
15 |
16 | function QHash:mult(s, a, value)
17 | self.q_table[self.hs(s)][self.ha(a)] = self.q_table[self.hs(s)][self.ha(a)] * value
18 | end
19 |
20 | function QHash:add(s, a, delta)
21 | self.q_table[self.hs(s)][self.ha(a)] = self.q_table[self.hs(s)][self.ha(a)] + delta
22 | end
23 |
24 | -- Metamethods such as __eq are looked up directly on a class's metatable, so torch.class subclasses do not inherit them; re-assign explicitly.
25 | QHash.__eq = parent.__eq
26 |
--------------------------------------------------------------------------------
/run_tests.lua:
--------------------------------------------------------------------------------
1 | require 'test.unittest_util'
2 | require 'test.unittest_tensorutil'
3 | require 'test.unittest_io_util'
4 |
5 | require 'test.unittest_MdpSampler'
6 | require 'test.unittest_EpisodeBuilder'
7 |
8 | require 'test.unittest_AllActionsEqualPolicy'
9 | require 'test.unittest_GreedyPolicy'
10 |
11 | require 'test.unittest_QHash'
12 | require 'test.unittest_VHash'
13 | require 'test.unittest_MonteCarloControl'
14 | require 'test.unittest_TableSarsa'
15 | require 'test.unittest_TableSarsaFactory'
16 |
17 | require 'test.unittest_QLin'
18 | require 'test.unittest_LinSarsa'
19 | require 'test.unittest_LinSarsaFactory'
20 |
21 | require 'test.unittest_QNN'
22 | require 'test.unittest_NNSarsa'
23 | require 'test.unittest_NNSarsaFactory'
24 |
25 | require 'test.unittest_TestMdp'
26 | require 'test.unittest_TestPolicy'
27 | require 'test.unittest_TestSAFE'
28 | require 'test.unittest_TestMdpQVAnalyzer'
29 |
--------------------------------------------------------------------------------
/test/unittest_TestSAFE.lua:
--------------------------------------------------------------------------------
1 | require 'rl'
2 | local tester = torch.Tester()
3 |
4 | local TestTestSAFE = {}
5 | local fe = rl.TestSAFE()
6 | function TestTestSAFE.test_get_sa_features()
7 | local s = 5
8 | local a = 2
9 | local expected = torch.Tensor{7, 3}
10 | tester:assertTensorEq(fe:get_sa_features(s, a), expected, 0)
11 | end
12 |
13 | function TestTestSAFE.test_get_sa_features_dim()
14 | local s = 5
15 | local a = 2
16 | local blank = torch.Tensor(fe:get_sa_features_dim())
17 | local features = fe:get_sa_features(s, a)
18 | tester:assert(rl.util.are_tensors_same_shape(blank, features))
19 | end
20 |
21 | function TestTestSAFE.test_get_sa_num_features()
22 | local s = 5
23 | local a = 2
24 | local features = fe:get_sa_features(s, a)
25 | tester:asserteq(fe:get_sa_num_features(), features:numel())
26 | end
27 |
28 | tester:add(TestTestSAFE)
29 |
30 | tester:run()
31 |
32 |
--------------------------------------------------------------------------------
/QLin.lua:
--------------------------------------------------------------------------------
1 | -- Implementation of a state-action value function approx using linear function
2 | -- of features
3 | local QLin, parent = torch.class('rl.QLin', 'rl.QApprox')
4 |
5 | function QLin:__init(mdp, feature_extractor)
6 | parent.__init(self, mdp, feature_extractor)
7 | self.weights = torch.zeros(feature_extractor:get_sa_features_dim())
8 | end
9 |
10 | function QLin:clear()
11 | self.weights = torch.zeros(self.feature_extractor:get_sa_features_dim())
12 | end
13 |
14 | function QLin:get_value(s, a)
15 | return self.weights:dot(self.feature_extractor:get_sa_features(s, a))
16 | end
17 |
18 | function QLin:add(d_weights)
19 | self.weights = self.weights + d_weights
20 | end
21 |
22 | function QLin:mult(factor)
23 | self.weights = self.weights * factor
24 | end
25 |
26 | function QLin:get_weight_vector()
27 | return self.weights
28 | end
29 |
30 | QLin.__eq = parent.__eq -- force inheritance of this
31 |
--------------------------------------------------------------------------------
/NNSarsaFactory.lua:
--------------------------------------------------------------------------------
1 | local NNSarsaFactory, parent =
2 | torch.class('rl.NNSarsaFactory', 'rl.SarsaFactory')
3 |
4 | function NNSarsaFactory:__init(
5 | mdp_config,
6 | lambda,
7 | explorer,
8 | feature_extractor,
9 | step_size)
10 | parent.__init(self, mdp_config, lambda)
11 | self.explorer = explorer
12 | self.feature_extractor = feature_extractor
13 | self.step_size = step_size
14 | end
15 |
16 | function NNSarsaFactory:set_explorer(explorer)
17 | self.explorer = explorer
18 | end
19 |
20 | function NNSarsaFactory:set_feature_extractor(feature_extractor)
21 | self.feature_extractor = feature_extractor
22 | end
23 |
24 | function NNSarsaFactory:get_control()
25 | return rl.NNSarsa(self.mdp_config,
26 | self.lambda,
27 | self.explorer,
28 | self.feature_extractor,
29 | self.step_size)
30 | end
31 |
32 |
--------------------------------------------------------------------------------
/test/unittest_LinSarsaFactory.lua:
--------------------------------------------------------------------------------
1 | require 'rl'
2 | local tester = torch.Tester()
3 |
4 | local mdp = rl.TestMdp()
5 | local discount_factor = 0.631
6 | local TestLinSarsaFactory = {}
7 | function TestLinSarsaFactory.test_get_control()
8 | local mdp_config = rl.MdpConfig(mdp, discount_factor)
9 | local lambda = 0.65
10 | local eps = 0.2437
11 | local explorer = rl.ConstExplorer(eps)
12 | local feature_extractor = rl.TestSAFE()
13 | local step_size = 0.92
14 |
15 | local lin_sarsa = rl.LinSarsa(
16 | mdp_config,
17 | lambda,
18 | explorer,
19 | feature_extractor,
20 | step_size)
21 | local factory = rl.LinSarsaFactory(
22 | mdp_config,
23 | lambda,
24 | explorer,
25 | feature_extractor,
26 | step_size)
27 | tester:assert(factory:get_control() == lin_sarsa)
28 | end
29 |
30 | tester:add(TestLinSarsaFactory)
31 |
32 | tester:run()
33 |
34 |
--------------------------------------------------------------------------------
/LinSarsaFactory.lua:
--------------------------------------------------------------------------------
1 | local LinSarsaFactory, parent =
2 | torch.class('rl.LinSarsaFactory', 'rl.SarsaFactory')
3 |
4 | function LinSarsaFactory:__init(
5 | mdp_config,
6 | lambda,
7 | explorer,
8 | feature_extractor,
9 | step_size)
10 | parent.__init(self, mdp_config, lambda)
11 | self.explorer = explorer
12 | self.feature_extractor = feature_extractor
13 | self.step_size = step_size
14 | end
15 |
16 | function LinSarsaFactory:set_explorer(explorer)
17 | self.explorer = explorer
18 | end
19 |
20 | function LinSarsaFactory:set_feature_extractor(feature_extractor)
21 | self.feature_extractor = feature_extractor
22 | end
23 |
24 | function LinSarsaFactory:get_control()
25 | return rl.LinSarsa(self.mdp_config,
26 | self.lambda,
27 | self.explorer,
28 | self.feature_extractor,
29 | self.step_size)
30 | end
31 |
32 |
--------------------------------------------------------------------------------
/TestMdp.lua:
--------------------------------------------------------------------------------
1 | -- Dummy MDP for testing
2 | -- state = either 1, 2, or 3
3 | -- action = either 1, 2, or 3
4 | local TestMdp, parent = torch.class('rl.TestMdp', 'rl.Mdp')
5 |
6 | local TERMINAL = 3
7 |
8 | --- Advance the MDP: returns (next_state, reward). Reward is +1 when
9 | -- state + action >= 4, otherwise -1; the state always advances by 1.
10 | function TestMdp:step(state, action)
11 | -- FIX: was `TestMdp:is_terminal(state)`, which passes the class table as
12 | -- self; it only worked because is_terminal never reads self. Use `self`.
13 | if self:is_terminal(state) then
14 | error('MDP is done.')
15 | end
16 | local reward = -1
17 | if state + action >= 4 then
18 | reward = 1
19 | end
20 | return state + 1, reward
21 | end
18 |
19 | function TestMdp:get_start_state()
20 | return 1
21 | end
22 |
23 | function TestMdp:is_terminal(s)
24 | return s == TERMINAL
25 | end
26 |
27 | function TestMdp:get_all_states()
28 | return {1, 2, 3}
29 | end
30 |
31 | function TestMdp:get_all_actions()
32 | return {1, 2, 3}
33 | end
34 |
35 | function TestMdp:hash_s(state)
36 | return state
37 | end
38 |
39 | function TestMdp:hash_a(action)
40 | return action
41 | end
42 |
43 | function TestMdp:get_description()
44 | return 'Test MDP'
45 | end
46 |
--------------------------------------------------------------------------------
/test/unittest_AllActionsEqualPolicy.lua:
--------------------------------------------------------------------------------
1 | require 'rl'
2 | local tester = torch.Tester()
3 |
4 | local TestAllActionsEqualPolicy = {}
5 |
6 | local function all_actions_have_good_freq(
7 | action_history,
8 | all_actions)
9 | local expected_p = 1. / #all_actions
10 | for _, action in pairs(all_actions) do
11 | if not rl.util.elem_has_good_freq(action, action_history, expected_p) then
12 | return false
13 | end
14 | end
15 | return true
16 | end
17 |
18 | function TestAllActionsEqualPolicy.test_policy()
19 | local mdp = rl.TestMdp()
20 | local policy = rl.AllActionsEqualPolicy(mdp)
21 | local N = 100000
22 | local action_history = {}
23 | for i = 1, N do
24 | action_history[i] = policy:get_action(nil)
25 | end
26 |
27 | local all_actions = mdp:get_all_actions()
28 | tester:assert(
29 | all_actions_have_good_freq(action_history, all_actions))
30 | end
31 |
32 |
33 | tester:add(TestAllActionsEqualPolicy)
34 |
35 | tester:run()
36 |
--------------------------------------------------------------------------------
/test/unittest_VHash.lua:
--------------------------------------------------------------------------------
require 'rl'

local tester = torch.Tester()
local TestVHash = {}

-- A value added to a state should be read back unchanged.
function TestVHash.test_add_once()
    local v = rl.VHash(rl.TestMdp())
    local state, value = 1, 2
    v:add(state, value)

    tester:asserteq(v:get_value(state), value)
end

-- Repeated multiplication should compose: 1 * 3 * 3 * 3 == 27.
function TestVHash.test_mult()
    local v = rl.VHash(rl.TestMdp())
    local state = 2
    v:add(state, 1)
    for _ = 1, 3 do
        v:mult(state, 3)
    end

    tester:asserteq(v:get_value(state), 27)
end

-- Two VHashes that hold the same values should compare equal, no matter
-- what sequence of operations produced those values.
function TestVHash.test_equality()
    local state = 2

    local v1 = rl.VHash(rl.TestMdp())
    v1:add(state, 1)
    for _ = 1, 3 do
        v1:mult(state, 3)
    end

    local v2 = rl.VHash(rl.TestMdp())
    v2:add(state, 5)
    v2:mult(state, 0)
    v2:add(state, 2)
    v2:add(state, -1)
    for _ = 1, 3 do
        v2:mult(state, 3)
    end

    tester:assert(v1 == v2)
end

tester:add(TestVHash)

tester:run()
50 |
--------------------------------------------------------------------------------
/QVAnalyzer.lua:
--------------------------------------------------------------------------------
local gnuplot = require 'gnuplot'

-- Abstract base class for analyzing Q (state-action) and V (state) value
-- functions: converting them to tensors, plotting them, and comparing them.
-- Concrete subclasses (e.g. rl.TestMdpQVAnalyzer) implement the abstract
-- methods for a specific MDP's state/action layout.
local QVAnalyzer = torch.class('rl.QVAnalyzer')

function QVAnalyzer:__init(mdp)
    self.mdp = mdp
end

-- Return the state value function v as a tensor.
function QVAnalyzer:get_v_tensor(v)
    error('Must implement get_v_tensor.')
end

-- Plot the state value function
function QVAnalyzer:plot_v(v)
    error('Must implement plot_v.')
end

-- Plot the greedy action of q per state.
function QVAnalyzer:plot_best_action(q, method_description)
    error('Must implement plot_best_action.')
end

-- Derive a state value function from an action value function.
function QVAnalyzer:v_from_q(q)
    error('Must implement v_from_q')
end

-- Return the action value function q as a tensor.
function QVAnalyzer:get_q_tensor(q)
    error('Must implement get_q_tensor.')
end

-- Root-mean-square difference between two Q functions.
-- BUG FIX: this previously returned the *sum* of squared differences,
-- not the RMS its name (and the RMS plots built on it) promises; it now
-- takes the mean and the square root.
function QVAnalyzer:q_rms(q1, q2)
    local t1 = self:get_q_tensor(q1)
    local t2 = self:get_q_tensor(q2)
    return math.sqrt(torch.mean(torch.pow(t1 - t2, 2)))
end

-- Root-mean-square difference between two V functions (see q_rms).
function QVAnalyzer:v_rms(v1, v2)
    local t1 = self:get_v_tensor(v1)
    local t2 = self:get_v_tensor(v2)
    return math.sqrt(torch.mean(torch.pow(t1 - t2, 2)))
end
--------------------------------------------------------------------------------
/doc/valuefunctions.md:
--------------------------------------------------------------------------------
1 | ## Value Functions
2 | All Q value functions implement:
3 | `get_value(s, a)`
4 |
5 | `get_best_action(s)`
6 |
7 | All V value functions implement:
8 | `get_value(s)`
9 |
10 | ### (Hash)Tables
These are the simplest types of data structures. VHash and QHash implement hash
tables over the state space and the state-action space, respectively. Only use
these hash tables for small state/action spaces.
14 |
15 | ### Function Approximation
16 | For large state/action spaces, using a look-up table becomes intractable. An
17 | alternative is to approximate the value of a state or state-action pair by using
18 | a function approximator. See below for how features are extracted.
19 |
20 | ## State-Action Feature Extractors (SAFE)
21 | SAFeatureExtractor defines an interface for classes that extract features out of
22 | a given state-action pair. SAFEs have to implement:
23 |
24 | `[Tensor] get_sa_features(s, a)`
`[number or tuple of numbers] get_sa_features_dim()`
26 | which returns the dimensions of the tensor returned by `get_sa_features`.
27 |
--------------------------------------------------------------------------------
/test/unittest_NNSarsaFactory.lua:
--------------------------------------------------------------------------------
require 'rl'

local tester = torch.Tester()

local mdp = rl.TestMdp()
local discount_factor = 0.631
local TestNNSarsaFactory = {}

-- The factory should produce a control equivalent to constructing
-- rl.NNSarsa directly with the same configuration.
function TestNNSarsaFactory.test_get_control()
    local mdp_config = rl.MdpConfig(mdp, discount_factor)
    local lambda = 0.65
    local eps = 0.2437
    local explorer = rl.ConstExplorer(eps)
    local feature_extractor = rl.TestSAFE()
    local step_size = 0.92

    local args = {mdp_config, lambda, explorer, feature_extractor, step_size}
    local nn_sarsa = rl.NNSarsa(unpack(args))
    local factory = rl.NNSarsaFactory(unpack(args))
    local nn_sarsa2 = factory:get_control()
    -- Not a clean way to prevent the q's from being initialized to different
    -- (random) parameters since they're neural network weights
    nn_sarsa2.q = nn_sarsa.q
    tester:assert(nn_sarsa2 == nn_sarsa)
end

tester:add(TestNNSarsaFactory)

tester:run()
--------------------------------------------------------------------------------
/test/unittest_tensorutil.lua:
--------------------------------------------------------------------------------
require 'rl' -- BUG FIX: was missing (unlike every sibling unittest), so this
             -- file could not run standalone even though it uses rl.util.

local tester = torch.Tester()

local TestTensorUtil = {}

-- apply_to_slices should invoke the callback on every slice along the chosen
-- dimension, passing the slice index (and extra args) through.
function TestTensorUtil.test_apply_to_slices()
    -- Fills a slice with its index raised to `power` (default 1).
    local function power_fill(tensor, i, power)
        power = power or 1
        tensor:fill(i ^ power)
    end
    local A = torch.Tensor(2, 2)
    local B = torch.Tensor(2, 2)
    B[1][1] = 1
    B[1][2] = 1
    B[2][1] = 2
    B[2][2] = 2
    tester:assertTensorEq(
        rl.util.apply_to_slices(A, 1, power_fill),
        B,
        0)
end

-- fill_range should fill successive slices with consecutive values starting
-- one past the given offset, along whichever dimension is sliced.
function TestTensorUtil.fill_range()
    local A = torch.Tensor(2, 2)
    local B = torch.Tensor(2, 2)
    B[1][1] = -1
    B[1][2] = -1
    B[2][1] = 0
    B[2][2] = 0
    tester:assertTensorEq(
        rl.util.apply_to_slices(A, 1, rl.util.fill_range, -2),
        B,
        0)

    B[1][1] = 123
    B[2][1] = 123
    B[1][2] = 124
    B[2][2] = 124
    tester:assertTensorEq(
        rl.util.apply_to_slices(A, 2, rl.util.fill_range, 122),
        B,
        0)
end

tester:add(TestTensorUtil)

tester:run()
--------------------------------------------------------------------------------
/test/unittest_GreedyPolicy.lua:
--------------------------------------------------------------------------------
require 'rl'
local tester = torch.Tester()

local TestGreedyPolicy = {}
local mdp = rl.TestMdp()
local q = rl.QFunc()

-- Build a GreedyPolicy whose Q function always reports `best_action` as
-- best, with constant exploration rate `eps`.
local function get_policy(best_action, eps)
    local explorer = rl.ConstExplorer(eps)
    q.get_best_action = function (s)
        return best_action
    end
    return rl.GreedyPolicy(q, explorer, mdp:get_all_actions())
end

-- Epsilon-greedy over 3 actions: each non-best action has probability
-- eps/3; the best action gets the remaining mass 1 - 2*eps/3.
function TestGreedyPolicy.test_greedy()
    local eps = 0.7
    local policy = get_policy(2, eps)
    tester:assert(rl.util.are_testmdp_policy_probabilities_good(
        policy,
        {eps/3, 1 - 2*eps/3, eps/3}))
end

-- Same check with a small epsilon and a different best action.
function TestGreedyPolicy.test_greedy2()
    local eps = 0.05
    local policy = get_policy(1, eps)
    tester:assert(rl.util.are_testmdp_policy_probabilities_good(
        policy,
        {1 - 2*eps/3, eps/3, eps/3}))
end
tester:add(TestGreedyPolicy)

tester:run()
44 |
--------------------------------------------------------------------------------
/MdpSampler.lua:
--------------------------------------------------------------------------------
-- Convenience wrapper around an MDP for sampling rewards and episodes
-- under a given policy.
local MdpSampler = torch.class('rl.MdpSampler')

function MdpSampler:__init(mdp_config)
    self.mdp = mdp_config:get_mdp()
    self.discount_factor = mdp_config:get_discount_factor()
end

-- Run one episode under `policy` and return the (undiscounted) sum of
-- rewards collected until a terminal state is reached.
function MdpSampler:sample_total_reward(policy)
    local state = self.mdp:get_start_state()
    local total_reward = 0
    while not self.mdp:is_terminal(state) do
        local reward
        state, reward = self.mdp:step(state, policy:get_action(state))
        total_reward = total_reward + reward
    end
    return total_reward
end

-- Episode: list of {state, action, discounted return, reward}. Indexed by time,
-- starting at time = 1.
function MdpSampler:get_episode(policy)
    local builder = rl.EpisodeBuilder(self.discount_factor)
    local state = self.mdp:get_start_state()

    while not self.mdp:is_terminal(state) do
        local action = policy:get_action(state)
        local next_state, reward = self.mdp:step(state, action)
        builder:add_state_action_reward_step(state, action, reward)
        state = next_state
    end

    return builder:get_episode()
end

function MdpSampler:get_mdp()
    return self.mdp
end
39 |
--------------------------------------------------------------------------------
/test/unittest_QHash.lua:
--------------------------------------------------------------------------------
require 'rl'

local tester = torch.Tester()
local TestQHash = {}

local mdp = rl.TestMdp()

-- A value added to a (state, action) pair should be read back unchanged,
-- and that action should become the best action for the state.
function TestQHash.test_add_once()
    local q = rl.QHash(mdp)
    local s, a, val = 1, 1, 2
    q:add(s, a, val)

    tester:asserteq(q:get_value(s, a), val)
    tester:asserteq(q:get_best_action(s), a)
end

-- Repeated multiplication should compose: 1 * 3 * 3 * 3 == 27.
function TestQHash.test_mult()
    local q = rl.QHash(rl.TestMdp())
    local s, a = 2, 3
    q:add(s, a, 1)
    for _ = 1, 3 do
        q:mult(s, a, 3)
    end

    tester:asserteq(q:get_value(s, a), 27)
end

-- Two QHashes holding the same values should compare equal, regardless of
-- the sequence of operations that produced those values.
function TestQHash.test_equality()
    local s, a = 2, 3

    local q1 = rl.QHash(rl.TestMdp())
    q1:add(s, a, 1)
    for _ = 1, 3 do
        q1:mult(s, a, 3)
    end

    local q2 = rl.QHash(rl.TestMdp())
    q2:add(s, a, 1)
    q2:mult(s, a, 0)
    q2:add(s, a, 1)
    for _ = 1, 3 do
        q2:mult(s, a, 3)
    end

    tester:assert(q1 == q2)
end

tester:add(TestQHash)

tester:run()
--------------------------------------------------------------------------------
/Mdp.lua:
--------------------------------------------------------------------------------
-- Abstract base class for Markov Decision Processes. Concrete MDPs override
-- the methods below; the stubs here raise errors if left unimplemented.
local Mdp = torch.class('rl.Mdp')

function Mdp:__init()
end

-- Take `action` in `state`; returns the next state and the reward.
function Mdp:step(state, action)
    error('Must implement step')
end

-- Return the state every episode begins in.
function Mdp:get_start_state()
    error('Must implement get_start_state')
end

-- Return true when `state` ends an episode.
function Mdp:is_terminal(state)
    error('Must implement is_terminal')
end

-- The next two functions should return all states/actions in a list (i.e. a
-- Table with numbers as keys). These two methods might be really expensive to
-- compute. It's the responsibility of the caller to take that into
-- consideration.
function Mdp:get_all_states()
    error('Must implement get_all_states')
end

function Mdp:get_all_actions()
    error('Must implement get_all_actions')
end

-- These hash functions for the state and action are used if you plan on using
-- TableSarsa. Otherwise, use feature extractors.
--
-- TODO: Move this to its own class, like SAFeatureExtractor.
function Mdp:hash_s(state)
    error('Must implement hash_s')
end

function Mdp:hash_a(action)
    error('Must implement hash_a')
end

-- Human-readable description; concrete MDPs may override.
function Mdp:get_description()
    return 'Base MDP'
end
45 |
--------------------------------------------------------------------------------
/test/unittest_TestMdpQVAnalyzer.lua:
--------------------------------------------------------------------------------
require 'rl'
local tester = torch.Tester()

local TestTestMdpQVAnalyzer = {}
local qva = rl.TestMdpQVAnalyzer()

-- Mock VFunc: the value of a state is the state itself.
local v = {}
function v:get_value(state)
    return state
end

-- Mock QFunc: the value of (state, action) is their sum; action 3 is
-- always reported as best.
local q = {}
function q:get_value(state, action)
    return state + action
end

function q:get_best_action(state)
    return 3
end

function TestTestMdpQVAnalyzer.test_get_v_tensor()
    tester:assertTensorEq(torch.Tensor{1, 2, 3}, qva:get_v_tensor(v), 0)
end

function TestTestMdpQVAnalyzer.test_get_q_tensor()
    -- row = state, col = action
    local expected = torch.Tensor{
        {2, 3, 4},
        {3, 4, 5},
        {4, 5, 6},
    }
    tester:assertTensorEq(expected, qva:get_q_tensor(q), 0)
end

function TestTestMdpQVAnalyzer.test_v_from_q()
    -- The best action is always 3, so v(s) should be q(s, 3) = s + 3.
    local expected_v = {}
    function expected_v:get_value(state)
        return state + 3
    end

    local result_v = qva:v_from_q(q)

    tester:assert(result_v:__eq(expected_v))
end

tester:add(TestTestMdpQVAnalyzer)

tester:run()
--------------------------------------------------------------------------------
/doc/policy.md:
--------------------------------------------------------------------------------
1 | ## Policy
2 | A Policy implements one method:
3 |
4 | `[action] get_action(state)`
5 |
6 | ### EpsilonGreedyPolicy
The EpsilonGreedy policy is a simple policy that balances exploration and
8 | exploitation. The idea of epsilon greedy policies is to choose a random action
9 | with some small probability (epsilon). This encourages exploration. Otherwise,
10 | choose the best action, to exploit our knowledge so far.
11 |
12 | This is currently the only non-trivial policy implemented.
13 |
14 | ### Explorer
15 | This is used to choose how to balance exploration vs. exploitation.
Specifically, it gives the probability of exploring. So, it implements
17 |
18 | `[number from 0 to 1] get_eps(s)`
19 |
20 | which returns epsilon for the epsilon greedy policy.
21 |
22 | ### DecayTableExplorer
23 | This type of explorer chooses epsilon to be
24 |
25 | `N0 / (N0 + N(s))`
26 |
27 | where `N0` is some constant, and `N(s)` is the number of times state `s` has
28 | been visited. This type of exploration strategy with EpsilonGreedyPolicy is
29 | guaranteed to converge to the optimal policy since each state is explored, but
eventually the best action is exploited. This is because as the number of times
a state has been visited grows (i.e. the more exploration we've done), the
smaller epsilon becomes (i.e. we don't bother exploring as much).
33 |
--------------------------------------------------------------------------------
/rl.lua:
--------------------------------------------------------------------------------
-- Top-level package file: declares the rl and rl.util namespaces and loads
-- every submodule into them.
-- NOTE(review): module() is deprecated after Lua 5.1; kept as-is because the
-- whole codebase depends on the globals it creates.
module('rl', package.seeall)

-- Utility helpers live under rl.util.
module('rl.util', package.seeall)
require('rl.util.io_util')
require('rl.util.util')
require('rl.util.mdputil')
require('rl.util.tensorutil')
require('rl.util.util_for_unittests')

require('rl.rl_constants')

-- Policies.
require('rl.Policy')
require('rl.GreedyPolicy')
require('rl.AllActionsEqualPolicy')
require('rl.TestPolicy')

-- Control algorithms.
require('rl.Control')
require('rl.ValueIteration')
require('rl.MonteCarloControl')
require('rl.Sarsa')
require('rl.TableSarsa')
require('rl.LinSarsa')
require('rl.NNSarsa')

-- Factories for the control algorithms above.
require('rl.ControlFactory')
require('rl.SarsaFactory')
require('rl.TableSarsaFactory')
require('rl.LinSarsaFactory')
require('rl.NNSarsaFactory')

-- MDPs.
require('rl.Mdp')
require('rl.TestMdp')

-- State-action feature extractors.
require('rl.SAFeatureExtractor')
require('rl.TestSAFE')

-- Q (action-value) functions.
require('rl.QFunc')
require('rl.QHash')
require('rl.QApprox')
require('rl.QLin')
require('rl.QNN')

-- V (state-value) functions.
require('rl.VFunc')
require('rl.VHash')

-- MDP configuration and sampling helpers.
require('rl.MdpConfig')
require('rl.MdpSampler')
require('rl.EpisodeBuilder')

-- Exploration strategies.
require('rl.Explorer')
require('rl.ConstExplorer')
require('rl.DecayTableExplorer')

-- Evaluation and analysis.
require('rl.Evaluator')
require('rl.QVAnalyzer')
require('rl.SarsaAnalyzer')
require('rl.TestMdpQVAnalyzer')
--------------------------------------------------------------------------------
/MonteCarloControl.lua:
--------------------------------------------------------------------------------
-- Monte-Carlo control: evaluates whole episodes and improves an
-- epsilon-greedy policy whose exploration decays with state visit counts.
local MonteCarloControl, parent =
    torch.class('rl.MonteCarloControl', 'rl.ValueIteration')

function MonteCarloControl:__init(mdp_config, N0)
    parent.__init(self, mdp_config)
    self.q = rl.QHash(self.mdp)
    self.Ns = rl.VHash(self.mdp)   -- N(s): visit count per state
    self.Nsa = rl.QHash(self.mdp)  -- N(s,a): visit count per state-action
    self.N0 = N0 or rl.MONTECARLOCONTROL_DEFAULT_N0
    -- BUG FIX: was `self.mdp.get_all_actions()` (a dot call, which passes no
    -- self to the method); use a colon call, consistent with Sarsa:__init.
    self.actions = self.mdp:get_all_actions()
end

-- Rebuild the epsilon-greedy policy from the current Q, with exploration
-- decaying as N0 / (N0 + N(s)).
function MonteCarloControl:optimize_policy()
    self.policy = rl.GreedyPolicy(
        self.q,
        rl.DecayTableExplorer(self.N0, self.Ns),
        self.actions
    )
end

-- Sample one episode with the current policy and move Q towards each step's
-- discounted return with step size 1/N(s,a) (every-visit Monte-Carlo).
function MonteCarloControl:evaluate_policy()
    local episode = self.sampler:get_episode(self.policy)
    for t, data in pairs(episode) do
        local s = data.state
        local a = data.action
        local Gt = data.discounted_return

        self.Ns:add(s, 1)
        self.Nsa:add(s, a, 1)

        local alpha = 1. / self.Nsa:get_value(s, a)
        self.q:add(s, a, alpha * (Gt - self.q:get_value(s, a)))
    end
end

function MonteCarloControl:get_q()
    return self.q
end

-- Equality on learned state: counts, Q, and the N0 hyperparameter.
function MonteCarloControl:__eq(other)
    return torch.typename(self) == torch.typename(other)
        and self.q == other.q
        and self.Ns == other.Ns
        and self.Nsa == other.Nsa
        and self.N0 == other.N0
end
47 |
--------------------------------------------------------------------------------
/QApprox.lua:
--------------------------------------------------------------------------------
1 | -- Abstract class for a Q function approximation class
2 | local QApprox, parent = torch.class('rl.QApprox', 'rl.QFunc')
3 |
4 | function QApprox:__init(mdp, feature_extractor)
5 | parent.__init(self, mdp)
6 | self.feature_extractor = feature_extractor
7 | end
8 |
9 | function QApprox:clear()
10 | error('Must implement clear method')
11 | end
12 |
13 | function QApprox:get_value(s, a)
14 | error('Must implement get_Value method')
15 | end
16 |
17 | function QApprox:add(d_weights)
18 | error('Must implement add method')
19 | end
20 |
21 | function QApprox:mult_all(factor)
22 | error('Must implement mult_all method')
23 | end
24 |
25 | function QApprox:get_weight_vector()
26 | error('Must implement get_weight_vector method')
27 | end
28 |
29 | function QApprox:get_q_tensor()
30 | local value = torch.zeros(N_DEALER_STATES, N_PLAYER_STATES, N_ACTIONS)
31 | for dealer = 1, N_DEALER_STATES do
32 | for player = 1, N_PLAYER_STATES do
33 | for a = 1, N_ACTIONS do
34 | s = {dealer, player}
35 | value[s][a] = self:get_value(s, a)
36 | end
37 | end
38 | end
39 | return value
40 | end
41 |
42 | function QApprox:get_best_action(s)
43 | local actions = self.mdp:get_all_actions()
44 | local best_a, best_i = rl.util.max(
45 | actions,
46 | function (a) return self:get_value(s, a) end)
47 | return best_a
48 | end
49 |
50 | QApprox.__eq = parent.__eq -- force inheritance of this
51 |
--------------------------------------------------------------------------------
/LinSarsa.lua:
--------------------------------------------------------------------------------
1 | -- Implement SARSA algorithm using a linear function approximator for on-line
2 | -- policy control
3 | local LinSarsa, parent = torch.class('rl.LinSarsa', 'rl.Sarsa')
4 |
5 | function LinSarsa:__init(mdp_config, lambda, explorer, feature_extractor, step_size)
6 | parent.__init(self, mdp_config, lambda)
7 | self.explorer = explorer
8 | self.feature_extractor = feature_extractor
9 | self.step_size = step_size
10 | self.q = rl.QLin(self.mdp, self.feature_extractor)
11 | self.eligibility = rl.QLin(self.mdp, self.feature_extractor)
12 | end
13 |
14 | function LinSarsa:get_new_q()
15 | return rl.QLin(self.mdp, self.feature_extractor)
16 | end
17 |
18 | function LinSarsa:reset_eligibility()
19 | self.eligibility = rl.QLin(self.mdp, self.feature_extractor)
20 | end
21 |
22 | function LinSarsa:update_eligibility(s, a)
23 | local features = self.feature_extractor:get_sa_features(s, a)
24 | self.eligibility:mult(self.discount_factor*self.lambda)
25 | self.eligibility:add(features)
26 | end
27 |
28 | function LinSarsa:td_update(td_error)
29 | self.q:add(self.eligibility:get_weight_vector() * self.step_size * td_error)
30 | end
31 |
32 | function LinSarsa:update_policy()
33 | self.policy = rl.GreedyPolicy(
34 | self.q,
35 | self.explorer,
36 | self.actions
37 | )
38 | end
39 |
40 | function LinSarsa:__eq(other)
41 | return torch.typename(self) == torch.typename(other)
42 | and self.explorer == other.explorer
43 | and self.feature_extractor == other.feature_extractor
44 | and self.step_size == other.step_size
45 | and self.q == other.q
46 | end
47 |
--------------------------------------------------------------------------------
/util/util_for_unittests.lua:
--------------------------------------------------------------------------------
-- Sample many actions from `policy` across the two non-terminal TestMdp
-- states and check the empirical frequency of each of the three actions
-- against `expected_probabilities` (a list indexed by action).
function rl.util.are_testmdp_policy_probabilities_good(
    policy,
    expected_probabilities)
    local n_iters = 10000
    local action_counts = {0, 0, 0}
    for _ = 1, n_iters do
        local state = math.random(1, 2)
        local action = policy:get_action(state)
        action_counts[action] = action_counts[action] + 1
    end

    for action = 1, 3 do
        local ok = rl.util.is_prob_good(
            action_counts[action],
            expected_probabilities[action],
            n_iters)
        if not ok then
            return false
        end
    end
    return true
end
22 |
-- True iff t1 and t2 are tensors with the same dimensionality and the same
-- size along every dimension.
function rl.util.are_tensors_same_shape(t1, t2)
    if t1:dim() ~= t2:dim() then
        return false
    end
    for d = 1, t1:dim() do
        if t1:size(d) ~= t2:size(d) then
            return false
        end
    end
    return true
end
34 |
-- True iff `qfunc` agrees with the nested table `q_table` (indexed as
-- q_table[state][action]) for every state-action pair of `mdp`, to within
-- rl.FLOAT_EPS.
function rl.util.do_qtable_qfunc_match(mdp, q_table, qfunc)
    for _, state in pairs(mdp:get_all_states()) do
        for _, action in pairs(mdp:get_all_actions()) do
            local diff = q_table[state][action] - qfunc:get_value(state, action)
            if math.abs(diff) > rl.FLOAT_EPS then
                return false
            end
        end
    end
    return true
end
46 |
-- True iff `vfunc` agrees with the table `v_table` (indexed by state) for
-- every state of `mdp`, to within rl.FLOAT_EPS.
-- BUG FIX: values were compared with exact equality, unlike the
-- rl.FLOAT_EPS tolerance used by do_qtable_qfunc_match; state values are
-- floats, so use the same tolerance for consistency and robustness.
function rl.util.do_vtable_vfunc_match(mdp, v_table, vfunc)
    for _, state in pairs(mdp:get_all_states()) do
        local state_value = vfunc:get_value(state)
        if math.abs(v_table[state] - state_value) > rl.FLOAT_EPS then
            return false
        end
    end
    return true
end
56 |
--------------------------------------------------------------------------------
/EpisodeBuilder.lua:
--------------------------------------------------------------------------------
1 | -- Episode: list of {state, action, discounted return, reward}, indexed by time.
2 | -- Time starts at 1 (going along with Lua conventions).
3 | -- This class builds this list intelligentally based on discount_factor, the discount
4 | -- factor.
5 | local EpisodeBuilder = torch.class('rl.EpisodeBuilder')
6 |
7 | function EpisodeBuilder:__init(discount_factor)
8 | if discount_factor < 0 or discount_factor > 1 then
9 | error('Gamma must be between 0 and 1, inclusive.')
10 | end
11 | self.t = 1
12 | self.episode = {}
13 | self.discount_factor = discount_factor
14 | self.built = false
15 | end
16 |
17 | function EpisodeBuilder:add_state_action_reward_step(state, action, reward)
18 | self.episode[self.t] = {
19 | state = state,
20 | action = action,
21 | discounted_return = reward,
22 | reward = reward
23 | }
24 | self.t = self.t + 1
25 | self.built = false
26 | end
27 |
28 | local function is_built(self)
29 | return self.discount_factor == 0 or self.built
30 | end
31 |
32 | function EpisodeBuilder:get_episode()
33 | if is_built(self) then
34 | return self.episode
35 | end
36 |
37 | local t = self.t - 1
38 | local discounted_future_return = self.discount_factor * self.episode[t].reward
39 | t = t - 1
40 | for i = 1, #self.episode - 1 do
41 | self.episode[t].discounted_return = self.episode[t].discounted_return +
42 | discounted_future_return
43 | discounted_future_return = self.discount_factor * (discounted_future_return +
44 | self.episode[t].reward)
45 | t = t - 1
46 | end
47 | self.built = true
48 | return self.episode
49 | end
50 |
--------------------------------------------------------------------------------
/NNSarsa.lua:
--------------------------------------------------------------------------------
1 | -- Implement SARSA algorithm using a neural network function approximator for
2 | -- on-line policy control
3 | local NNSarsa, parent = torch.class('rl.NNSarsa', 'rl.Sarsa')
4 |
5 | function NNSarsa:__init(mdp_config, lambda, explorer, feature_extractor, step_size)
6 | parent.__init(self, mdp_config, lambda)
7 | self.explorer = explorer
8 | self.feature_extractor = feature_extractor
9 | self.step_size = step_size
10 | self.q = rl.QNN(mdp_config:get_mdp(), feature_extractor)
11 | self.last_state = nil
12 | self.last_action = nil
13 | self.momentum = self.lambda * self.discount_factor
14 | end
15 |
16 | function NNSarsa:get_new_q()
17 | return q.QNN:new()
18 | end
19 |
20 | function NNSarsa:reset_eligibility()
21 | self.last_state = nil
22 | self.last_action = nil
23 | self.q:reset_momentum()
24 | end
25 |
26 | function NNSarsa:update_eligibility(s, a)
27 | self.last_state = s
28 | self.last_action = a
29 | end
30 |
31 | function NNSarsa:td_update(td_error)
32 | local learning_rate = td_error * self.step_size
33 | self.q:backward(
34 | self.last_state,
35 | self.last_action,
36 | learning_rate,
37 | self.momentum)
38 | end
39 |
40 | function NNSarsa:update_policy()
41 | self.policy = rl.GreedyPolicy(
42 | self.q,
43 | self.explorer,
44 | self.actions
45 | )
46 | end
47 |
48 | function NNSarsa:__eq(other)
49 | return torch.typename(self) == torch.typename(other)
50 | and self.explorer == other.explorer
51 | and self.feature_extractor == other.feature_extractor
52 | and self.step_size == other.step_size
53 | and self.q == other.q
54 | and self.last_state == other.last_state
55 | and self.last_action == other.last_action
56 | end
57 |
--------------------------------------------------------------------------------
/TestMdpQVAnalyzer.lua:
--------------------------------------------------------------------------------
-- QVAnalyzer specialized for the small discrete rl.TestMdp.
local TestMdpQVAnalyzer, parent =
    torch.class('rl.TestMdpQVAnalyzer', 'rl.QVAnalyzer')

function TestMdpQVAnalyzer:__init()
    parent.__init(self, rl.TestMdp())
    -- BUG FIX: these were dot calls (self.mdp.get_all_states()), which do
    -- not pass self to the method; use colon calls.
    self.n_states = #self.mdp:get_all_states()
    self.n_actions = #self.mdp:get_all_actions()
end

-- Tabulate v over all states as a 1-D tensor.
function TestMdpQVAnalyzer:get_v_tensor(v)
    local tensor = torch.zeros(self.n_states)
    for s = 1, self.n_states do
        tensor[s] = v:get_value(s)
    end
    return tensor
end

-- NOTE(review): this file never requires 'gnuplot'; the plot methods rely on
-- a global `gnuplot` being available -- confirm, or add a local require.
-- NOTE(review): the axis labels below look copied from a Blackjack analyzer
-- ('Dealer Showing'); left unchanged to preserve behavior.
function TestMdpQVAnalyzer:plot_v(v)
    local tensor = self:get_v_tensor(v)
    local x = torch.Tensor(self.n_states)
    x = rl.util.apply_to_slices(x, 1, rl.util.fill_range, 0)
    gnuplot.plot(x, tensor)
    gnuplot.xlabel('Dealer Showing')
    gnuplot.ylabel('State Value')
    gnuplot.title('Monte-Carlo State Value Function')
end

-- Tabulate q over all state-action pairs as a (state x action) tensor.
function TestMdpQVAnalyzer:get_q_tensor(q)
    local tensor = torch.zeros(self.n_states, self.n_actions)
    for s = 1, self.n_states do
        for a = 1, self.n_actions do
            tensor[s][a] = q:get_value(s, a)
        end
    end
    return tensor
end

-- Plot the greedy action of q for each state.
function TestMdpQVAnalyzer:plot_best_action(q)
    local best_action_at_state = torch.Tensor(self.n_states)
    for s = 1, self.n_states do
        best_action_at_state[s] = q:get_best_action(s)
    end
    -- BUG FIX: the x axis was sized n_actions; it indexes states, so it must
    -- have n_states entries (the two only coincided because TestMdp has 3 of
    -- each).
    local x = torch.Tensor(self.n_states)
    x = rl.util.apply_to_slices(x, 1, rl.util.fill_range, 0)
    gnuplot.plot(x, best_action_at_state)
    gnuplot.xlabel('State')
    gnuplot.zlabel('Best Action')
    gnuplot.title('Learned Best Action Based on q')
end

-- Build a VHash where v(s) is the value of the greedy action of q in s.
function TestMdpQVAnalyzer:v_from_q(q)
    local v = rl.VHash(self.mdp)
    for s = 1, self.n_states do
        local a = q:get_best_action(s)
        v:add(s, q:get_value(s, a))
    end
    return v
end
--------------------------------------------------------------------------------
/doc/mdp.md:
--------------------------------------------------------------------------------
## Markov Decision Process (MDP)
Markov Decision Processes (MDPs) are at the heart of the RL algorithms
3 | implemented. Here, they are represented as a class. The definition of the MDP
4 | class will depend on the particular problem.
5 |
6 | The biggest idea of MDPs is that they are memoryless. The state of an MDP should
7 | be enough to determine what happens next.
8 |
9 | ### MDP Config
10 | Most other classes will require a `MdpConfig` instance instead of a `Mdp`
11 | instance. `MdpConfig` is a wrapper data structure that contains an MDP and
12 | configuration, such as the discount factor. A common pattern is the following:
13 |
14 | ```lua
15 | local mdp = TestMdp()
16 | local discount_factor = 0.9
17 |
18 | local mdp_config = MdpConfig(mdp, discount_factor)
19 | --- use mdp_config for future calls
20 | ```
21 |
22 | ### MdpSampler
The MdpSampler is a wrapper around an Mdp that provides some convenience
24 | methods for sampling the MDP, namely:
25 |
* `[number] sample_total_reward(policy)`
27 | * `[episode] get_episode(policy)`
28 |
29 | An episode is a table of {state, action, discounted return, reward}, indexed by
30 | time. Time starts at 1 (going along with Lua conventions).
31 |
32 |
33 | ### Creating Your Own MDP
34 | To create a MDP, extend the base MDP class using torch:
35 |
36 | ```lua
37 | require 'Mdp'
38 | local MyMdp, parent = torch.class('MyMdp', 'Mdp')
39 |
40 | function MyMdp:__init(arg1)
41 | parent.__init(self)
42 | end
43 | ```
44 |
45 | The main functions that an MDP needs to be implemented are
46 |
47 | * `[next_state, reward] step(state, action)` Note that state should capture
48 | everything needed to compute the next state and reward, given an action.
49 | * `[state] get_start_state()`
50 | * `[boolean] is_terminal(state)`
51 |
Check out [Mdp.lua](../Mdp.lua) for details on the other functions that you may
53 | want to implement. See [Blackjack.lua](../BlackJack.lua) for an example.
54 |
55 |
--------------------------------------------------------------------------------
/test/unittest_LinSarsa.lua:
--------------------------------------------------------------------------------
require 'rl' -- BUG FIX: was missing (unlike sibling unittests), so this
             -- file could not run standalone even though it uses rl.*.

local tester = torch.Tester()

local discount_factor = 0.95
local mdp = rl.TestMdp()
local mdp_config = rl.MdpConfig(mdp, discount_factor)
local fe = rl.TestSAFE()

local TestLinSarsa = {}

-- After a single step, the eligibility trace should equal the feature
-- vector of the visited (s, a) pair (expected here to be {s+a, s-a}).
function TestLinSarsa.test_update_eligibility_one_step()
    local lambda = 1
    local eps = 0.032
    local explorer = rl.ConstExplorer(eps)
    local step_size = 0.05
    local sarsa = rl.LinSarsa(mdp_config, lambda, explorer, fe, step_size)

    local s = 2
    local a = 1
    sarsa:update_eligibility(s, a)

    local eligibility_expected = rl.QLin(mdp, fe)
    eligibility_expected.weights[1] = s + a
    eligibility_expected.weights[2] = s - a

    tester:assert(sarsa.eligibility == eligibility_expected)
end

-- Over several steps the trace should decay by gamma*lambda before each new
-- feature vector is accumulated.
function TestLinSarsa.test_update_eligibility_many_steps()
    local lambda = 0.5
    local eps = 0.032
    local explorer = rl.ConstExplorer(eps)
    local step_size = 0.05
    local sarsa = rl.LinSarsa(mdp_config, lambda, explorer, fe, step_size)

    local s = 2
    local a = 1
    sarsa:update_eligibility(s, a)

    local eligibility_expected = rl.QLin(mdp, fe)
    eligibility_expected.weights = eligibility_expected.weights
        + torch.Tensor{s+a, s-a}

    local decay_factor = discount_factor * lambda
    eligibility_expected.weights = eligibility_expected.weights * decay_factor

    s = 2
    a = 2
    sarsa:update_eligibility(s, a)
    eligibility_expected.weights =
        eligibility_expected.weights + torch.Tensor{s+a, s-a}

    eligibility_expected.weights = eligibility_expected.weights * decay_factor

    s = 2
    a = 1
    sarsa:update_eligibility(s, a)
    eligibility_expected.weights =
        eligibility_expected.weights + torch.Tensor{s+a, s-a}

    tester:assert(sarsa.eligibility == eligibility_expected)
end

tester:add(TestLinSarsa)

tester:run()
--------------------------------------------------------------------------------
/Sarsa.lua:
--------------------------------------------------------------------------------
-- Abstract class for implementing SARSA.
-- See end of file for functions that must be implemented.
local Sarsa, parent = torch.class('rl.Sarsa', 'rl.Control')

-- mdp_config : rl.MdpConfig wrapping the MDP and its discount factor.
-- lambda     : trace-decay parameter for Sarsa(lambda).
function Sarsa:__init(mdp_config, lambda)
    parent.__init(self, mdp_config)
    self.lambda = lambda
    self.q = nil -- concrete subclasses install a Q (see get_new_q below)
    self.actions = self.mdp:get_all_actions()
    self.discount_factor = mdp_config:get_discount_factor()
end
12 |
-- Run one full episode of Sarsa(lambda), starting from state s and action a.
-- After every environment step: compute the TD error, update the eligibility
-- traces, apply the TD update to Q, and rebuild the policy.
function Sarsa:run_episode(s, a)
    self:reset_eligibility()
    -- Can't use sampler because we're updating policy at each step
    while not self.mdp:is_terminal(s) do
        local s_new, r = self.mdp:step(s, a)
        local td_error, a_new = nil, nil
        if s_new == nil then
            -- No successor state: nothing to bootstrap from.
            td_error = r - self.q:get_value(s, a)
        else
            a_new = self.policy:get_action(s_new)
            td_error = r + self.discount_factor*self.q:get_value(s_new, a_new)
                - self.q:get_value(s, a)
        end
        self:update_eligibility(s, a)
        self:td_update(td_error)
        self:update_policy()

        -- NOTE(review): when s_new is nil, the next iteration calls
        -- is_terminal(nil); this assumes the MDP treats nil as terminal --
        -- confirm against the Mdp implementations.
        s = s_new
        a = a_new
    end
end
34 |
-- Improve the policy by running a single episode from the MDP's start state.
function Sarsa:improve_policy()
    local start = self.mdp:get_start_state()
    self:run_episode(start, self.policy:get_action(start))
end
40 |
-- Return the current state-action value estimate.
function Sarsa:get_q()
    return self.q
end
44 |
-- Abstract: return an instance of a concrete Q class (e.g. rl.QHash).
function Sarsa:get_new_q()
    error('Must implement get_new_q')
end

-- Abstract: clear self.eligibility for a new episode.
function Sarsa:reset_eligibility()
    error('Must implement reset_eligibility')
end

-- Abstract: update self.eligibility after visiting state s and taking action a.
function Sarsa:update_eligibility(s, a)
    error('Must implement update_eligibility')
end
59 |
-- Abstract: implement the TD update rule, given a TD error.
function Sarsa:td_update(td_error)
    -- Fix: the message previously said 'Must implement td_error' (copy-paste
    -- slip); the method that must be implemented is td_update.
    error('Must implement td_update')
end
64 |
-- Abstract: update self.policy (called after every TD step).
function Sarsa:update_policy()
    error('Must implement update_policy')
end
69 |
--------------------------------------------------------------------------------
/test/unittest_EpisodeBuilder.lua:
--------------------------------------------------------------------------------
require 'rl'
local tester = torch.Tester()

local TestEpisodeBuilder = {}

-- Build an episode with a constant (state, action) and the given rewards,
-- then check that EpisodeBuilder reports the expected discounted returns.
-- Fix: declared `local` -- the original leaked this helper into the global
-- environment, where it could collide with same-named helpers in other
-- test files run in the same interpreter.
local function are_discounted_return_good(
    discount_factor,
    rewards,
    expected_discounted_returns)
    local builder = rl.EpisodeBuilder(discount_factor)
    local state = 5
    local action = 9

    for _, r in pairs(rewards) do
        builder:add_state_action_reward_step(state, action, r)
    end

    local expected = {}
    for i, expected_discounted_return in pairs(expected_discounted_returns) do
        expected[i] = {
            state = state,
            action = action,
            discounted_return = expected_discounted_return,
            reward = rewards[i]
        }
    end

    return rl.util.deepcompare(expected, builder:get_episode())
end
-- With gamma = 1 the return is the plain sum of all future rewards.
function TestEpisodeBuilder.test_gamma_one()
    local rewards = {1, 1, 1, 1}
    local expected = {4, 3, 2, 1}
    tester:assert(are_discounted_return_good(1, rewards, expected))
end
37 |
-- With gamma = 0 each return equals the immediate reward only.
function TestEpisodeBuilder.test_gamma_zero()
    local rewards = {1, 1, 1, 1}
    local expected = {1, 1, 1, 1}
    tester:assert(are_discounted_return_good(0, rewards, expected))
end
46 |
-- With gamma = 0.5 returns follow the recursion G_t = r_t + gamma * G_{t+1}.
function TestEpisodeBuilder.test_gamma_fraction()
    local gamma = 0.5
    local rewards = {1, 1, 1, 1}
    -- Build the expected returns back-to-front via the recursion; with
    -- gamma = 0.5 every intermediate value is exact in binary floating point.
    local expected = {}
    local future = 0
    for i = #rewards, 1, -1 do
        future = rewards[i] + gamma * future
        expected[i] = future
    end
    tester:assert(are_discounted_return_good(gamma, rewards, expected))
end
59 |
-- Register the test suite and execute it.
tester:add(TestEpisodeBuilder)

tester:run()
63 |
--------------------------------------------------------------------------------
/doc/sarsa.md:
--------------------------------------------------------------------------------
1 | ## Sarsa-lambda
2 | See [Sutton and
3 | Barto](https://webdocs.cs.ualberta.ca/~sutton/book/ebook/node77.html) for
4 | explanation of algorithm. One major difference is that we used discounted
5 | reward.
6 |
7 | The main step that is abstracted away is how Q(s, a) is updated given the TD
8 | error. The underlying structure of Q (e.g. hash table vs. function approximator)
will determine how it is handled. The eligibility update will also change based
on the structure.
11 |
12 | Three versions are implemented:
13 | * TableSarsa - Use lookup tables to store Q
14 | * LinSarsa - Use a linear dot product of features to approximate Q
* NNSarsa - Use a neural network to approximate Q
16 |
17 | ### Linear Approximation
18 | Documentation is a TODO - see LinSarsa.lua for now.
19 |
20 | ### Neural Network Approximation
21 | Documentation is a TODO - see NNSarsa.lua for now.
22 |
23 | ### Sarsa-lambda Analysis Scripts
Below are more interesting scripts that compare how Sarsa-lambda algorithms
perform relative to Monte Carlo (MC) Control.
26 | * `analyze_table_sarsa.lua`
27 | * `analyze_lin_sarsa.lua`
28 | * `analyze_nn_sarsa.lua`
29 |
30 | MC Control is used as a baseline because it gives an unbiased estimate of the
31 | true Q (state-action value) function. In each of the scripts, two plots get
32 | generated: (1) root mean square (RMS) error of the estimated Q function vs
33 | lambda. (2) RMS error of the estimated Q function vs # iterations for lambda = 0
34 | and lambda = 1.
35 |
To save time, you can generate the Q function from MC Control, save
37 | it, and then load it back up in the above scripts. Generate a MC Q file with
38 |
39 | `$ th generate_q_mc.lua -saveqto .dat`
40 |
41 | and use this file when running the above scripts with the following.
42 |
43 | `$ th analyze_table_sarsa.lua -loadqfrom .dat`
44 |
45 | Run these scripts with the -h option for more help.
46 |
47 | ### Example Plots
48 |
49 | For `analyze_table_sarsa` with the number of iterations set to 10^5, you get the
50 | following plots
51 |
52 | 
54 |
55 | 
58 |
--------------------------------------------------------------------------------
/TableSarsa.lua:
--------------------------------------------------------------------------------
-- Implement SARSA algorithm using lookup tables (one entry per state-action
-- pair) for on-line policy control.
-- (The previous header said "linear function approximator", copied from
-- LinSarsa; this class is the tabular variant.)
local TableSarsa, parent = torch.class('rl.TableSarsa', 'rl.Sarsa')

-- Ns counts state visits (drives exploration decay); Nsa counts state-action
-- visits (drives the 1/N step size in td_update).
function TableSarsa:__init(mdp_config, lambda)
    parent.__init(self, mdp_config, lambda)
    self.Ns = rl.VHash(self.mdp)
    self.Nsa = rl.QHash(self.mdp)
    self.q = rl.QHash(self.mdp)
    self.eligibility = rl.QHash(self.mdp)
end
11 |
-- Return a fresh, empty Q lookup table.
function TableSarsa:get_new_q()
    return rl.QHash(self.mdp)
end

-- Clear all traces by replacing the table with an empty one.
function TableSarsa:reset_eligibility()
    self.eligibility = rl.QHash(self.mdp)
end
19 |
-- Decay every eligibility trace by discount_factor * lambda, then bump the
-- trace for the visited pair (s, a) and record the visit counts.
function TableSarsa:update_eligibility(s, a)
    -- Hoisted loop invariants: the decay factor and the state/action lists
    -- do not change inside the nested loop (previously recomputed per pair).
    local decay = self.discount_factor * self.lambda
    local states = self.mdp:get_all_states()
    local actions = self.mdp:get_all_actions()
    for _, state in pairs(states) do
        for _, action in pairs(actions) do
            self.eligibility:mult(state, action, decay)
        end
    end
    self.eligibility:add(s, a, 1)
    self.Ns:add(s, 1)
    self.Nsa:add(s, a, 1)
end
33 |
-- Step size is 1 / N(s, a); for unvisited pairs it is 0, so their Q entries
-- are left untouched by td_update (and division by zero is avoided).
local function get_step_size(self, state, action)
    local value = self.Nsa:get_value(state, action)
    if value == 0 then
        return value
    end
    return 1. / value
end
41 |
-- Apply the Sarsa(lambda) backup to every state-action pair:
--   Q(s, a) <- Q(s, a) + alpha(s, a) * td_error * E(s, a)
function TableSarsa:td_update(td_error)
    local states = self.mdp:get_all_states()
    local actions = self.mdp:get_all_actions()
    for _, s in pairs(states) do
        for _, a in pairs(actions) do
            local alpha = get_step_size(self, s, a)
            local trace = self.eligibility:get_value(s, a)
            self.q:add(s, a, alpha * td_error * trace)
        end
    end
end
54 |
-- Rebuild the policy as greedy over the current Q, with exploration decayed
-- by per-state visit counts (via DecayTableExplorer over Ns).
function TableSarsa:update_policy()
    self.explorer = rl.DecayTableExplorer(
        rl.MONTECARLOCONTROL_DEFAULT_N0,
        self.Ns)
    self.policy = rl.GreedyPolicy(
        self.q,
        self.explorer,
        self.actions
    )
end
65 |
-- Two TableSarsa instances are equal when they share the concrete type and
-- have identical visit counts, Q values, and eligibility traces.
function TableSarsa:__eq(other)
    if torch.typename(self) ~= torch.typename(other) then
        return false
    end
    return self.Ns == other.Ns
        and self.Nsa == other.Nsa
        and self.q == other.q
        and self.eligibility == other.eligibility
end
73 |
--------------------------------------------------------------------------------
/test/unittest_MonteCarloControl.lua:
--------------------------------------------------------------------------------
require 'rl'

-- Seed the RNG so stochastic parts of control vary between runs.
math.randomseed(os.time())
local tester = torch.Tester()

local TestMonteCarloControl = {}
-- Evaluate a fixed policy and compare the resulting Q and visit counts with
-- values computed by hand. (Name keeps the "evalute" typo so the reported
-- test name stays stable.)
function TestMonteCarloControl.test_evalute_policy()
    local mdp = rl.TestMdp()
    local discount_factor = 1
    -- With TestPolicy(1) the episode is:
    --   Step 1: state 1, action 1, reward -1, return G = -2
    --   Step 2: state 2, action 1, reward -1, return G = -1
    local policy = rl.TestPolicy(1)
    local config = rl.MdpConfig(mdp, discount_factor)
    local mcc = rl.MonteCarloControl(config)
    mcc:set_policy(policy)
    mcc:evaluate_policy()

    -- Build the expected post-evaluation state by hand.
    local expected = rl.MonteCarloControl(config)
    local q = rl.QHash(mdp)
    q:add(1, 1, -2)
    q:add(2, 1, -1)
    local Ns = rl.VHash(mdp)
    Ns:add(1, 1)
    Ns:add(2, 1)
    local Nsa = rl.QHash(mdp)
    Nsa:add(1, 1, 1)
    Nsa:add(2, 1, 1)
    local N0 = rl.MONTECARLOCONTROL_DEFAULT_N0
    expected.q = q
    expected.Ns = Ns
    expected.Nsa = Nsa
    expected.N0 = N0

    tester:assert(mcc == expected)
end
46 |
-- After optimize_policy, the greedy action (1) should carry the bulk of the
-- probability mass; the remainder is split evenly among the other actions.
function TestMonteCarloControl.test_optimize_policy()
    local mdp = rl.TestMdp()
    local config = rl.MdpConfig(mdp, 1)
    local mcc = rl.MonteCarloControl(config)

    -- Make action 1 clearly the best in both states.
    local q = rl.QHash(mdp)
    q:add(1, 1, 100)
    q:add(2, 1, 100)

    -- Pretend each state has been visited many times already so that the
    -- other actions seem well explored.
    local Ns = rl.VHash(mdp)
    local n_times_states_visited = 10
    Ns:add(1, n_times_states_visited)
    Ns:add(2, n_times_states_visited)

    local N0 = 1
    mcc.q = q
    mcc.Ns = Ns
    mcc.N0 = N0
    mcc:optimize_policy()

    local expected_eps = N0 / (N0 + n_times_states_visited)
    local expected_probabilities = {
        1 - 2 * expected_eps / 3,
        expected_eps / 3,
        expected_eps / 3
    }

    tester:assert(rl.util.are_testmdp_policy_probabilities_good(
        mcc:get_policy(),
        expected_probabilities))
end
79 |
-- Register the test suite and execute it.
tester:add(TestMonteCarloControl)

tester:run()
83 |
--------------------------------------------------------------------------------
/rocks/rl-0.1-1.rockspec:
--------------------------------------------------------------------------------
-- Rockspec for the historical 0.1-1 release, pinned to tag v0.1.
-- Modules are exposed under flat names (no "rl." prefix) in this version.
package = "rl"
version = "0.1-1"

source = {
    url = "git://github.com/vpong/torch-rl.git",
    tag = "v0.1"
}

description = {
    summary = "A package for basic reinforcement learning algorithms.",
    detailed = [[
    A package for basic reinforcement learning algorithms
    ]],
    homepage = "https://github.com/vpong/torch-rl"
}

dependencies = {
    "lua ~> 5.1",
    "torch >= 7.0"
}

build = {
    type = "builtin",
    modules = {
        AllActionsEqualPolicy = "AllActionsEqualPolicy.lua",
        -- NOTE(review): the current tree has rl_constants.lua, not
        -- constants.lua -- presumably correct for tag v0.1; verify.
        constants = "constants.lua",
        ConstExplorer = "ConstExplorer.lua",
        ControlFactory = "ControlFactory.lua",
        Control = "Control.lua",
        DecayTableExplorer = "DecayTableExplorer.lua",
        EpisodeBuilder = "EpisodeBuilder.lua",
        Evaluator = "Evaluator.lua",
        Explorer = "Explorer.lua",
        GreedyPolicy = "GreedyPolicy.lua",
        LinSarsaFactory = "LinSarsaFactory.lua",
        LinSarsa = "LinSarsa.lua",
        MdpConfig = "MdpConfig.lua",
        Mdp = "Mdp.lua",
        MdpSampler = "MdpSampler.lua",
        MonteCarloControl = "MonteCarloControl.lua",
        NNSarsaFactory = "NNSarsaFactory.lua",
        NNSarsa = "NNSarsa.lua",
        Policy = "Policy.lua",
        QApprox = "QApprox.lua",
        QFunc = "QFunc.lua",
        QHash = "QHash.lua",
        QLin = "QLin.lua",
        QNN = "QNN.lua",
        QVAnalyzer = "QVAnalyzer.lua",
        rl = "rl.lua",
        SAFeatureExtractor = "SAFeatureExtractor.lua",
        SarsaAnalyzer = "SarsaAnalyzer.lua",
        SarsaFactory = "SarsaFactory.lua",
        Sarsa = "Sarsa.lua",
        TableSarsaFactory = "TableSarsaFactory.lua",
        TableSarsa = "TableSarsa.lua",
        TestMdp = "TestMdp.lua",
        TestMdpQVAnalyzer = "TestMdpQVAnalyzer.lua",
        TestPolicy = "TestPolicy.lua",
        TestSAFE = "TestSAFE.lua",
        ValueIteration = "ValueIteration.lua",
        VFunc = "VFunc.lua",
        VHash = "VHash.lua",
        ["util.io_util"] = "util/io_util.lua",
        ["util.mdputil"] = "util/mdputil.lua",
        ["util.tensorutil"] = "util/tensorutil.lua",
        ["util.util_for_unittests"] = "util/util_for_unittests.lua",
        ["util.util"] = "util/util.lua"
    },
    copy_directories = { "doc" , "test"}
}
72 |
--------------------------------------------------------------------------------
/rl-0.2-5.rockspec:
--------------------------------------------------------------------------------
-- Rockspec for the 0.2-5 release, pinned to tag v0.2.3.
-- Modules are namespaced under "rl." in this version.
package = "rl"
version = "0.2-5"

source = {
    url = "git://github.com/vpong/torch-rl.git",
    tag = "v0.2.3"
}

description = {
    summary = "A package for basic reinforcement learning algorithms.",
    detailed = [[
    A package for basic reinforcement learning algorithms
    ]],
    homepage = "https://github.com/vpong/torch-rl"
}

dependencies = {
    "lua ~> 5.1",
    "torch >= 7.0"
}

build = {
    type = "builtin",
    modules = {
        ["rl.AllActionsEqualPolicy"] = "AllActionsEqualPolicy.lua",
        ["rl.ConstExplorer"] = "ConstExplorer.lua",
        ["rl.ControlFactory"] = "ControlFactory.lua",
        ["rl.Control"] = "Control.lua",
        ["rl.DecayTableExplorer"] = "DecayTableExplorer.lua",
        ["rl.EpisodeBuilder"] = "EpisodeBuilder.lua",
        ["rl.Evaluator"] = "Evaluator.lua",
        ["rl.Explorer"] = "Explorer.lua",
        ["rl.GreedyPolicy"] = "GreedyPolicy.lua",
        ["rl.LinSarsaFactory"] = "LinSarsaFactory.lua",
        ["rl.LinSarsa"] = "LinSarsa.lua",
        ["rl.MdpConfig"] = "MdpConfig.lua",
        ["rl.Mdp"] = "Mdp.lua",
        ["rl.MdpSampler"] = "MdpSampler.lua",
        ["rl.MonteCarloControl"] = "MonteCarloControl.lua",
        ["rl.NNSarsaFactory"] = "NNSarsaFactory.lua",
        ["rl.NNSarsa"] = "NNSarsa.lua",
        ["rl.Policy"] = "Policy.lua",
        ["rl.QApprox"] = "QApprox.lua",
        ["rl.QFunc"] = "QFunc.lua",
        ["rl.QHash"] = "QHash.lua",
        ["rl.QLin"] = "QLin.lua",
        ["rl.QNN"] = "QNN.lua",
        ["rl.QVAnalyzer"] = "QVAnalyzer.lua",
        ["rl"] = "rl.lua",
        ["rl.rl_constants"] = "rl_constants.lua",
        ["rl.SAFeatureExtractor"] = "SAFeatureExtractor.lua",
        ["rl.SarsaAnalyzer"] = "SarsaAnalyzer.lua",
        ["rl.SarsaFactory"] = "SarsaFactory.lua",
        ["rl.Sarsa"] = "Sarsa.lua",
        ["rl.TableSarsaFactory"] = "TableSarsaFactory.lua",
        ["rl.TableSarsa"] = "TableSarsa.lua",
        ["rl.TestMdp"] = "TestMdp.lua",
        ["rl.TestMdpQVAnalyzer"] = "TestMdpQVAnalyzer.lua",
        ["rl.TestPolicy"] = "TestPolicy.lua",
        ["rl.TestSAFE"] = "TestSAFE.lua",
        ["rl.ValueIteration"] = "ValueIteration.lua",
        ["rl.VFunc"] = "VFunc.lua",
        ["rl.VHash"] = "VHash.lua",
        ["rl.util.io_util"] = "util/io_util.lua",
        ["rl.util.mdputil"] = "util/mdputil.lua",
        ["rl.util.tensorutil"] = "util/tensorutil.lua",
        ["rl.util.util_for_unittests"] = "util/util_for_unittests.lua",
        ["rl.util.util"] = "util/util.lua"
    },
    copy_directories = { "doc" , "test"}
}
72 |
--------------------------------------------------------------------------------
/QNN.lua:
--------------------------------------------------------------------------------
local nn = require 'nn'
local nngraph = require 'nngraph'
-- dpnn is required for its side effects: it extends nn.Module with
-- updateGradParameters, used in QNN:backward below.
local dpnn = require 'dpnn'

-- Implementation of a state-action value function approx using a neural network
local QNN, parent = torch.class('rl.QNN', 'rl.QApprox')

-- Build the network: a single linear layer mapping the feature vector to a
-- scalar Q value.
local function get_module(self)
    local x = nn.Identity()()
    local l1 = nn.Linear(self.n_features, 1)(x)
    return nn.gModule({x}, {l1})
end
13 |
-- The network is a single linear layer, so Q is linear in the features.
function QNN:is_linear()
    return true
end

-- mdp               : the MDP this Q function is for.
-- feature_extractor : maps a (state, action) pair to a feature tensor.
function QNN:__init(mdp, feature_extractor)
    parent.__init(self, mdp, feature_extractor)
    self.n_features = feature_extractor:get_sa_num_features()
    self.module = get_module(self)
    self.is_first_update = true -- first backward() must not apply momentum
end
24 |
-- This took forever to figure out. See
-- https://github.com/Element-Research/dpnn/blob/165ce5ff37d0bb77c207e82f5423ade08593d020/Module.lua#L488
-- for detail.
-- Recursively clear dpnn's cached momentum gradients on net and its children.
local function reset_momentum(net)
    net.momGradParams = nil
    if net.modules then
        for _, child in pairs(net.modules) do
            reset_momentum(child)
        end
    end
end

-- Forget all momentum state so the next update starts from scratch.
function QNN:reset_momentum()
    self.is_first_update = true
    self.module:zeroGradParameters()
    reset_momentum(self.module)
end
42 |
-- Q(s, a): run the feature vector through the network and unwrap the scalar.
function QNN:get_value(s, a)
    local features = self.feature_extractor:get_sa_features(s, a)
    return self.module:forward(features)[1]
end
47 |
-- For now, we're ignoring eligibility. This means that the update rule to the
-- parameters W of the network is:
--
--      W <- W + step_size * td_error * dQ(s,a)/dW
--
-- We can force the network to update this way by recognizing that the "loss"
-- is literally just the output of the network. This makes it so that
--
--      dLoss/dW = dQ(s, a)/dW
--
-- So, the network will update correctly if we just tell it that the output is
-- the loss, i.e. set grad_out = 1.
--
-- For more detail on the update rule, see
-- https://webdocs.cs.ualberta.ca/~sutton/book/ebook/node89.html
--
-- For more detail on how nn works, see
-- https://github.com/torch/nn/blob/master/doc/module.md
--
-- s, a          : state-action pair whose gradient drives the update.
-- learning_rate : scales the parameter step.
-- momentum      : dpnn momentum coefficient blended into the gradients.
function QNN:backward(s, a, learning_rate, momentum)
    -- forward to make sure input is set correctly
    local input = self.feature_extractor:get_sa_features(s, a)
    local output = self.module:forward(input)
    -- backward
    -- NOTE(review): #output relies on the tensor supporting the length
    -- operator; confirm it yields the intended gradient shape here.
    local grad_out = torch.ones(#output)
    self.module:zeroGradParameters()
    self.module:backward(input, grad_out)
    -- update
    -- This check is necessary because of the way updateGradParameters is
    -- implemented. This makes sure that the first update doesn't give
    -- momentum itself. However, future updates should rely on the momentum of
    -- previous calls.
    --
    -- Also, we can't just put the call to updateGradParameters() before the
    -- call to backward() because the zeroGradParameters() call messes it up.
    if self.is_first_update then
        self.is_first_update = false
    else
        self.module:updateGradParameters(momentum, 0, false) -- momentum (dpnn)
    end
    self.module:updateParameters(-learning_rate) -- W = W - rate * dL/dW
end
89 |
90 | QNN.__eq = parent.__eq -- force inheritance of this
91 |
--------------------------------------------------------------------------------
/doc/index.md:
--------------------------------------------------------------------------------
1 | # torch-rl
2 | This is a Torch 7 package that implements a few reinforcement learning
3 | algorithms. So far, we've only implemented Q-learning.
4 |
This documentation is intended mostly to give a high-level idea of what each
6 | abstract class does. We start with a summary of the important files, and then
7 | give a slightly more detailed description afterwards. For more detail, see the
8 | source code. For examples on how to use the functions, see the unit tests.
9 |
10 | ## Summary of files
11 | Files that start with upper cases are classes. Every other file is a script,
12 | except for the constants and util file. This gives a summary of the most
13 | important files.
14 |
15 | ### Interfaces/abstract classes
16 | * `Control.lua` - Represents an algorithm that improves a policy.
* `Mdp.lua` - A Markov Decision Process that represents an environment.
18 | * `Policy.lua` - A way of deciding what action to do given a state.
19 | * `Sarsa.lua` - A specific Control algorithm. Technically, it's Sarsa-lambda
20 | * `SAFeatureExtractor.lua` - Represents a way of extracting features from a given
21 | [S]tate-[A]ction pair.
22 | * `ControlFactory.lua` - Used to create new Control instances
23 | * `Explorer.lua` - Used to get the epsilon value for EpsilonGreedy
24 |
25 | ### Concrete classes
26 | * `Evaluator.lua` - Used to measure the performance of a policy
27 | * `MdpConfig.lua` - A way of configuring an Mdp.
28 | * `MdpSampler.lua` - A useful wrapper around an Mdp.
29 | * `QVAnalyzer.lua` - Used to get measurements out of Control algorithms
30 |
31 | ### Specific implementations
* `GreedyPolicy.lua` - Implements an epsilon-greedy policy.
33 | * `DecayTableExplorer.lua` - A way of decaying epsilon for epsilon greedy policy
34 | to ensure convergence.
35 | * `NNSarsa.lua` - Implements Sarsa-lambda using neural networks as a function
36 | approximator
37 | * `LinSarsa.lua` - Implements Sarsa-lambda using linear weighting as a function
38 | approximator
39 | * `TableSarsa.lua` - Implements Sarsa-lambda using a lookup table.
40 | * `MonteCarloControl.lua` - Implements Monte Carlo control.
41 |
42 | ### Test Files
43 | * `unittest_*.lua` - Unit tests. Can be run directly with `th unittest_*.lua`.
44 | * `run_rl_unittests.lua` - Run all unit tests related to this package.
45 | * `run_BlackJack_unittests.lua` - Run all unit tests related to Black Jack.
46 | * `run_all_unittests.lua` - Run all unit tests in this package.
47 | * `TestMdp.lua` - An MDP used for testing.
48 | * `TestPolicy.lua` - A policy for TestMdp used for testing.
49 | * `TestSAFE.lua` - A feature extractor used for testing.
50 |
51 | ## Read More
52 | * [MDP](doc/mdp.md) - Markov Decision Processes are the foundation of how
53 | reinforcement learning models the world.
54 | * [Policy](doc/policy.md) - Policies are mappings from state to action.
55 | * [Sarsa](doc/sarsa.md) - Read about the Sarsa-lambda algorithm and scripts that
56 | test them.
57 | * [Monte Carlo Control](doc/montecarlo.md) - Read about Monte Carlo Control and
58 | how to use it.
59 | * [Value Functions](doc/valuefunctions.md) - Value functions represent how
valuable certain states and/or actions are.
61 | * [Black Jack](doc/blackjack.md) - An example MDP that is a simplified version
62 | of black jack.
63 |
64 | ## A note on Abstract Classes and private methods
65 | Torch doesn't implement interfaces nor abstract classes natively, but this
package tries to implement them by defining functions that raise an error
unless a subclass overrides them. (We'll call everything an abstract class
just for simplicity.)
69 |
70 | Also, Torch classes don't provide a way of making private methods, so this is
71 | faked with the following:
72 |
73 | ```lua
74 |
75 | local function private_method(self, arg1, ...)
76 |
77 | end
78 |
79 | -- ...
80 |
81 | function Foo:bar()
82 | self:public_method(arg1, ...) -- Call public methods normally
83 | private_method(self, arg1, ...) -- Use this to call private methods
84 | end
85 | ```
86 |
--------------------------------------------------------------------------------
/test/unittest_NNSarsa.lua:
--------------------------------------------------------------------------------
require 'rl'

local tester = torch.Tester()

-- Shared fixtures for all NNSarsa tests.
local discount_factor = 0.95
local mdp = rl.TestMdp()
local mdp_config = rl.MdpConfig(mdp, discount_factor)
local fe = rl.TestSAFE()
local lambda = 1
local eps = 0.032
local explorer = rl.ConstExplorer(eps)
local step_size = 0.05

-- Build a fresh NNSarsa wired to the shared fixtures.
local function get_sarsa()
    return rl.NNSarsa(mdp_config, lambda, explorer, fe, step_size)
end
local TestNNSarsa = {}

-- For the NN variant, update_eligibility records the last (s, a) visited.
function TestNNSarsa.test_update_eligibility_one_step()
    local sarsa = get_sarsa()

    local state, action = 2, 1
    sarsa:update_eligibility(state, action)

    tester:assert(sarsa.last_state == state and sarsa.last_action == action)
end
28 |
-- One td_update should shift the weights by step_size * td_error * dQ/dW,
-- with no momentum on the very first update.
function TestNNSarsa.test_td_update_once()
    local sarsa = get_sarsa()
    local expected_module = sarsa.q.module:clone()

    local s = 2
    local a = 1
    local td_error = -0.4
    sarsa:update_eligibility(s, a)
    sarsa:td_update(td_error)

    -- Replay the same update by hand on the cloned network. The first update
    -- applies no momentum, so updateGradParameters is not called here.
    -- (Removed an unused `momentum` local that was left over from the
    -- multi-step test below.)
    local input = fe:get_sa_features(s, a)
    expected_module:forward(input)
    local grad_out = torch.Tensor{1}
    expected_module:backward(input, grad_out)
    local learning_rate = step_size * td_error
    expected_module:updateParameters(-learning_rate)

    local expected_sarsa = rl.NNSarsa(mdp_config, lambda, explorer, fe, step_size)
    expected_sarsa.q.module = expected_module
    expected_sarsa.last_state = s
    expected_sarsa.last_action = a

    tester:assert(sarsa == expected_sarsa)
end
54 |
-- Two consecutive td_updates: the second must fold in momentum from the
-- first via dpnn's updateGradParameters.
function TestNNSarsa.test_td_update_many_times()
    local sarsa = get_sarsa()
    local expected_module = sarsa.q.module:clone()

    -- First update: no momentum applied yet.
    local s = 2
    local a = 1
    local td_error = -0.4
    sarsa:update_eligibility(s, a)
    sarsa:td_update(td_error)

    local input = fe:get_sa_features(s, a)
    expected_module:forward(input)
    local grad_out = torch.Tensor{1}
    expected_module:backward(input, grad_out)
    local momentum = lambda * discount_factor
    local learning_rate = step_size * td_error
    expected_module:updateParameters(-learning_rate)

    -- Second update: gradients are zeroed, recomputed, then blended with the
    -- previous step's momentum before the parameter step.
    s = 2
    a = 2
    td_error = -0.6
    sarsa:update_eligibility(s, a)
    sarsa:td_update(td_error)

    input = fe:get_sa_features(s, a)
    expected_module:forward(input)
    grad_out = torch.Tensor{1}
    expected_module:zeroGradParameters()
    expected_module:backward(input, grad_out)
    momentum = lambda * discount_factor
    expected_module:updateGradParameters(momentum, 0, false)
    local learning_rate = step_size * td_error -- (shadows the earlier local)
    expected_module:updateParameters(-learning_rate)

    local expected_sarsa = rl.NNSarsa(mdp_config, lambda, explorer, fe, step_size)
    expected_sarsa.q.module = expected_module
    expected_sarsa.last_state = s
    expected_sarsa.last_action = a

    tester:assert(sarsa == expected_sarsa)
end
96 |
-- For a linear Q network, resetting eligibility makes an identical update
-- produce the same change in Q as the first one did.
function TestNNSarsa.test_reset_eligibility()
    local sarsa = get_sarsa()
    -- (Removed an unused `expected_module` clone left over from the other
    -- tests in this file.)

    -- The equal-delta argument below only holds for a linear approximator.
    if not sarsa.q:is_linear() then
        return
    end

    local s = 2
    local a = 1
    local td_error = -0.4
    local old_value = sarsa.q:get_value(s, a)
    sarsa:update_eligibility(s, a)
    sarsa:td_update(td_error)

    local new_value1 = sarsa.q:get_value(s, a)
    local d_value_1 = new_value1 - old_value

    -- reset_eligibility takes no arguments (see Sarsa.lua); the stray (s, a)
    -- previously passed here were ignored.
    sarsa:reset_eligibility()
    sarsa:update_eligibility(s, a)
    sarsa:td_update(td_error)
    local new_value2 = sarsa.q:get_value(s, a)
    local d_value_2 = new_value2 - new_value1

    tester:assert(math.abs(d_value_1 - d_value_2) < rl.FLOAT_EPS)
end
123 |
-- Register the test suite and execute it.
tester:add(TestNNSarsa)

tester:run()
127 |
128 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # torch-rl
2 | This is a Torch 7 package that implements a few reinforcement learning (RL)
3 | algorithms. So far, only Q-learning is implemented.
4 |
5 | ## Installation
6 | #### Dependencies
7 |
8 | 0. (Optional)
9 | [LuaRocks](https://github.com/keplerproject/luarocks/wiki/Download)
- Highly recommended for anyone who wants to use Lua. Also, install this before
installing Torch, as Torch has weird configurations for LuaRocks.
12 | 1. [Torch](http://torch.ch/docs/getting-started.html)
13 | 2. [Lua 5.1](http://www.lua.org/download.html) - Installing torch automatically
14 | installs Lua
15 |
16 | #### LuaRocks - Automatic Installation (Recommended)
17 |
18 | 1. Install [LuaRocks](https://github.com/keplerproject/luarocks/wiki/Download).
19 | 2. From terminal, run
20 | ```
21 | $ luarocks install rl
22 | ```
23 |
24 | Error finding the module? Try
25 | ```
26 | $ luarocks install --server=https://luarocks.org/ rl
27 | ```
28 |
29 | #### LuaRocks - Manual Installation
30 | 1. `$ git clone git@github.com:vpong/torch-rl.git`
31 | 2. `$ luarocks make`
32 |
33 | #### Totally Manually
34 | ```
35 | $ git clone git@github.com:vpong/torch-rl.git
36 | ```
Note that you'll basically have to copy the files around to any project that
you want to use them in.
39 |
40 | ## Reinforcement Learning Topics
41 | * [MDP](doc/mdp.md) - A Markov Decision Process (MDP) models the world. Read
42 | about useful MDP functions and [how to create your own
43 | MDP](doc/mdp.md#create_mdp).
44 | * [Policy](doc/policy.md) - Policies are mappings from state to action.
45 | * [Sarsa](doc/sarsa.md) - Read about the Sarsa-lambda algorithm and scripts that
46 | test them.
47 | * [Monte Carlo Control](doc/montecarlo.md) - Read about Monte Carlo Control and
48 | how to use it.
49 | * [Value Functions](doc/valuefunctions.md) - Value functions represent how
valuable certain states and/or actions are.
51 | * [Black Jack](https://github.com/vpong/rl-example) - An example repository that
52 | shows how to use RL algorithms to learn to play (a simplified version of)
53 | black jack.
54 |
55 | ## Summary of files
56 | This gives a summary of the most important files. Files that start with upper
57 | cases are classes. For more detail, see the source code. For examples on how to
58 | use the functions, see the unit tests.
59 |
60 | #### Interfaces/abstract classes
61 | * `Control.lua` - Represents an algorithm that improves a policy.
* `Mdp.lua` - A Markov Decision Process that represents an environment.
63 | * `Policy.lua` - A way of deciding what action to do given a state.
64 | * `Sarsa.lua` - A specific Control algorithm. Technically, it's Sarsa-lambda
65 | * `SAFeatureExtractor.lua` - Represents a way of extracting features from a given
66 | [S]tate-[A]ction pair.
67 | * `ControlFactory.lua` - Used to create new Control instances
68 | * `Explorer.lua` - Used to get the epsilon value for EpsilonGreedy
69 |
70 | #### Concrete classes
71 | * `Evaluator.lua` - Used to measure the performance of a policy
72 | * `MdpConfig.lua` - A way of configuring an Mdp.
73 | * `MdpSampler.lua` - A useful wrapper around an Mdp.
74 | * `QVAnalyzer.lua` - Used to get measurements out of Control algorithms
75 |
76 | #### Specific implementations
* `GreedyPolicy.lua` - Implements an epsilon-greedy policy.
78 | * `DecayTableExplorer.lua` - A way of decaying epsilon for epsilon greedy policy
79 | to ensure convergence.
80 | * `NNSarsa.lua` - Implements Sarsa-lambda using neural networks as a function
81 | approximator
82 | * `LinSarsa.lua` - Implements Sarsa-lambda using linear weighting as a function
83 | approximator
84 | * `TableSarsa.lua` - Implements Sarsa-lambda using a lookup table.
85 | * `MonteCarloControl.lua` - Implements Monte Carlo control.
86 |
87 | #### Other Files
88 | * `util/*.lua` - Utility functions used throughout the project.
89 | * `test/unittest_*.lua` - Unit tests. Can be run individual tests with `th
90 | unittest_*.lua`.
91 | * `run_tests.lua` - Run all unit tests in `test/`directory.
92 | * `TestMdp.lua` - An MDP used for testing.
93 | * `TestPolicy.lua` - A policy for TestMdp used for testing.
94 | * `TestSAFE.lua` - A feature extractor used for testing.
95 |
96 | ## A note on Abstract Classes and private methods
97 | Torch doesn't implement interfaces or abstract classes natively, but this
98 | package tries to implement them by defining functions that raise an error if
99 | they are called without being implemented. (We'll call everything an abstract
100 | class just for simplicity.)
101 |
102 | Also, Torch classes don't provide a way of making private methods, so this is
103 | faked with the following:
104 |
105 | ```lua
106 |
107 | local function private_method(self, arg1, ...)
108 | ...
109 | end
110 |
111 | function Foo:bar()
112 | self:public_method(arg1, ...) -- Call public methods normally
113 | private_method(self, arg1, ...) -- Use this to call private methods
114 | end
115 | ```
116 |
--------------------------------------------------------------------------------
/test/unittest_QLin.lua:
--------------------------------------------------------------------------------
require 'rl'
local tester = torch.Tester()

local TestQLin = {}

-- Shared TestMdp instance used by every test in this file.
local mdp = rl.TestMdp()
-- After a single unit weight update on the first feature, QLin should value
-- each (s, a) pair as that feature alone: s + a for TestSAFE.
function TestQLin.test_add_once()
    local extractor = rl.TestSAFE()
    local q = rl.QLin(mdp, extractor)

    local dim = extractor:get_sa_features_dim()
    q.weights = torch.zeros(dim)
    local delta = torch.zeros(dim)
    delta[1] = 1
    q:add(delta)

    tester:asserteq(q:get_value(1, 1), 2, "Wrong state-action value.")
    tester:asserteq(q:get_best_action(1), 3, "Wrong best action.")

    -- expected[s][a] = s + a (value of the first TestSAFE feature)
    local expected = {}
    for state = 1, 3 do
        expected[state] = {state + 1, state + 2, state + 3}
    end
    tester:assert(rl.util.do_qtable_qfunc_match(mdp, expected, q))
end
29 |
-- With two non-zero weights, the value should be the weighted sum of the two
-- TestSAFE features: w1*(s+a) + w2*(s-a).
function TestQLin.test_add_complex()
    local extractor = rl.TestSAFE()
    local q = rl.QLin(mdp, extractor)

    local dim = extractor:get_sa_features_dim()
    q.weights = torch.zeros(dim)
    local w1, w2 = 3, 0.5
    local delta = torch.zeros(dim)
    delta[1] = w1
    delta[2] = w2
    q:add(delta)

    tester:asserteq(q:get_value(3, 1), 12+1, "Wrong state-action value.")
    tester:asserteq(q:get_best_action(3), 3, "Wrong best action.")

    local expected = {}
    for state = 1, 3 do
        expected[state] = {}
        for action = 1, 3 do
            expected[state][action] =
                w1 * (state + action) + w2 * (state - action)
        end
    end
    tester:assert(rl.util.do_qtable_qfunc_match(mdp, expected, q))
end
67 |
-- mult(factor) should scale every state-action value uniformly, leaving the
-- best action unchanged.
function TestQLin.test_add_and_multiply()
    local extractor = rl.TestSAFE()
    local q = rl.QLin(mdp, extractor)

    local dim = extractor:get_sa_features_dim()
    q.weights = torch.zeros(dim)
    local w1, w2 = 3, 0.5
    local delta = torch.zeros(dim)
    delta[1] = w1
    delta[2] = w2
    q:add(delta)

    local factor = 0.9
    q:mult(factor)

    tester:asserteq(
        q:get_value(3, 1),
        factor*(12+1),
        "Wrong state-action value.")
    tester:asserteq(q:get_best_action(3), 3, "Wrong best action.")

    local expected = {}
    for state = 1, 3 do
        expected[state] = {}
        for action = 1, 3 do
            expected[state][action] =
                factor * (w1 * (state + action) + w2 * (state - action))
        end
    end
    tester:assert(rl.util.do_qtable_qfunc_match(mdp, expected, q))
end
111 |
-- clear() should wipe all weights so that a subsequent unit update reproduces
-- the test_add_once state exactly.
function TestQLin.test_clear()
    local extractor = rl.TestSAFE()
    local q = rl.QLin(mdp, extractor)

    local dim = extractor:get_sa_features_dim()
    q.weights = torch.zeros(dim)
    local delta = torch.zeros(dim)
    delta[1] = 3
    delta[2] = 0.5
    q:add(delta)
    q:mult(0.9)

    -- Now it should be the same as test_add_once
    q:clear()
    delta[1] = 1
    delta[2] = 0
    q:add(delta)

    tester:asserteq(q:get_value(1, 1), 2, "Wrong state-action value.")
    tester:asserteq(q:get_best_action(1), 3, "Wrong best action.")

    local expected = {}
    for state = 1, 3 do
        expected[state] = {state + 1, state + 2, state + 3}
    end
    tester:assert(rl.util.do_qtable_qfunc_match(mdp, expected, q))
end
147 | tester:add(TestQLin)
148 |
149 | tester:run()
150 |
--------------------------------------------------------------------------------
/test/unittest_QNN.lua:
--------------------------------------------------------------------------------
1 | require 'rl'
2 | local tester = torch.Tester()
3 |
4 | local TestQNN = {}
5 |
6 | local mdp = rl.TestMdp()
7 | local fe = rl.TestSAFE()
8 |
-- Replays QNN:backward by hand on a clone of the untrained module and checks
-- that the resulting parameters match bit-for-bit.
function TestQNN.test_backward()
    local q = rl.QNN(mdp, fe)
    -- Clone before training so both copies start from identical parameters.
    local module = q.module:clone()

    local s = 2
    local a = 4

    local step_size = 0.5
    local lambda = 0
    local discount_factor = 0.5
    local td_error = 1
    local learning_rate = step_size * td_error
    local momentum = lambda * discount_factor
    q:backward(s, a, learning_rate, momentum)
    local new_params = q.module:parameters()

    -- This is kinda cheating since this is basically the same code as the
    -- function, but I also don't see a better way to do this.
    local input = fe:get_sa_features(s, a)
    local grad_out = torch.Tensor{1}
    local _ = module:forward(input)
    module:backward(input, grad_out)
    module:updateGradParameters(momentum, 0, false)
    -- updateParameters(lr) applies params <- params - lr*grad, so the
    -- negative rate makes this an ascent step in the gradient direction.
    module:updateParameters(-step_size*td_error)
    local expected_params = module:parameters()

    -- Tolerance 0: the replayed update must match exactly.
    tester:assertTensorEq(expected_params[1], new_params[1], 0)
    tester:assertTensorEq(expected_params[2], new_params[2], 0)
end
38 |
-- With zero momentum, two identical backward steps on a linear net must change
-- the value by the same amount each time.
function TestQNN.test_backward_no_momentum()
    local q = rl.QNN(mdp, fe)
    if not q:is_linear() then
        -- Non-linear nets need not respond linearly to repeated updates.
        return
    end

    local s = 2
    local a = 4
    local old_value = q:get_value(s, a)

    local step_size = 0.5
    local td_error = 1
    local learning_rate = step_size * td_error
    local momentum = 0

    q:backward(s, a, learning_rate, momentum)
    local new_value1 = q:get_value(s, a)
    local d_value_1 = new_value1 - old_value

    q:backward(s, a, learning_rate, momentum)
    local new_value2 = q:get_value(s, a)
    local d_value_2 = new_value2 - new_value1

    -- (Removed an unused q.module:clone() the old test never read.)
    tester:assert(math.abs(d_value_1 - d_value_2) < rl.FLOAT_EPS)
end
66 |
-- With momentum = lambda * gamma = 1, the second identical backward step on a
-- linear net should be amplified: d2 = (1 + momentum) * d1.
function TestQNN.test_backward_with_momentum()
    local q = rl.QNN(mdp, fe)
    if not q:is_linear() then
        -- The amplification identity only holds for linear nets.
        return
    end

    local s = 2
    local a = 4
    local old_value = q:get_value(s, a)

    local step_size = 0.5
    local lambda = 1
    local discount_factor = 1
    local td_error = 1
    local learning_rate = step_size * td_error
    local momentum = lambda * discount_factor

    q:backward(s, a, learning_rate, momentum)
    local new_value1 = q:get_value(s, a)
    local d_value_1 = new_value1 - old_value

    q:backward(s, a, learning_rate, momentum)
    local new_value2 = q:get_value(s, a)
    local d_value_2 = new_value2 - new_value1

    -- (Removed an unused q.module:clone() the old test never read.)
    tester:assert(math.abs((1+momentum)*d_value_1 - d_value_2) < rl.FLOAT_EPS)
end
95 |
-- Sanity check that momentum has an effect at all: with momentum = 1 two
-- identical backward calls must change the value by different amounts.
function TestQNN.test_momentum_exists()
    local q = rl.QNN(mdp, fe)
    -- If q is non-linear, then all bets are off on whether or not the momentum
    -- will change things.
    if not q:is_linear() then
        return
    end

    local s = 2
    local a = 4
    local old_value = q:get_value(s, a)

    local step_size = 0.5
    local lambda = 1
    local discount_factor = 1
    local td_error = 1
    local learning_rate = step_size * td_error
    local momentum = lambda * discount_factor

    q:backward(s, a, learning_rate, momentum)
    local new_value1 = q:get_value(s, a)
    local d_value_1 = new_value1 - old_value

    q:backward(s, a, learning_rate, momentum)
    local new_value2 = q:get_value(s, a)
    local d_value_2 = new_value2 - new_value1

    -- (Removed an unused q.module:clone() the old test never read.)
    tester:assert(math.abs(d_value_1 - d_value_2) > rl.FLOAT_EPS)
end
126 |
-- reset_momentum() should clear accumulated momentum so the repeated update
-- changes the value by the same amount as the first one.
function TestQNN.test_backward_reset_momentum()
    local q = rl.QNN(mdp, fe)
    if not q:is_linear() then
        return
    end

    local s = 2
    local a = 4
    local old_value = q:get_value(s, a)

    local step_size = 0.5
    local td_error = 1
    local learning_rate = step_size * td_error
    local momentum = 1

    q:backward(s, a, learning_rate, momentum)
    local new_value1 = q:get_value(s, a)
    local d_value_1 = new_value1 - old_value

    q:reset_momentum()

    q:backward(s, a, learning_rate, momentum)
    local new_value2 = q:get_value(s, a)
    local d_value_2 = new_value2 - new_value1

    -- (Removed an unused q.module:clone() the old test never read.)
    tester:assert(math.abs(d_value_1 - d_value_2) < rl.FLOAT_EPS)
end
155 |
156 | tester:add(TestQNN)
157 |
158 | tester:run()
159 |
--------------------------------------------------------------------------------
/util/util.lua:
--------------------------------------------------------------------------------
-- Pick a random key from `items`, where each value is that key's non-negative
-- weight. Returns nil for an empty table.
-- Thanks to
-- https://scriptinghelpers.org/questions/11242/how-to-make-a-weighted-selection
function rl.util.weighted_random_choice(items)
    -- Sum all weights
    local total_weight = 0
    for _, weight in pairs(items) do
        total_weight = total_weight + weight
    end

    -- Pick random value. `rand` and `choice` were accidental globals before;
    -- keeping them local avoids cross-call interference.
    local rand = math.random() * total_weight
    local choice = nil

    -- Search for the interval [0, w1] [w1, w1 + w2] [w1 + w2, w1 + w2 + w3] ...
    -- that `rand` belongs to and select the corresponding choice. Both loops
    -- iterate the same table with pairs(), so they see the same order.
    for item, weight in pairs(items) do
        if rand < weight then
            choice = item
            break
        else
            rand = rand - weight
        end
    end

    return choice
end
28 |
-- Not called fold left/right because order of pairs is not guaranteed.
-- Curried: rl.util.fold(fn)(init)(list). The accumulator is copied into a
-- fresh local on every invocation so that a partially-applied fold can be
-- reused safely; the old version mutated the captured `acc` upvalue, leaking
-- state between calls of the same closure.
function rl.util.fold(fn)
    return function (init)
        return function (list)
            local acc = init
            for _, v in pairs(list) do
                acc = fn(acc, v)
            end
            return acc
        end
    end
end
40 |
-- Sum every value in a table (any key set; order is irrelevant for addition).
local sum_from = rl.util.fold(function (total, value) return total + value end)
function rl.util.sum(lst)
    return sum_from(0)(lst)
end
46 |
-- Curried fold that also passes each key: rl.util.fold_with_key(fn)(init)(list)
-- with fn(acc, key, value). Uses a per-invocation local accumulator so a
-- partially-applied closure can be reused without leaking state (the old
-- version mutated the captured `acc` upvalue).
function rl.util.fold_with_key(fn)
    return function (init)
        return function (list)
            local acc = init
            for k, v in pairs(list) do
                acc = fn(acc, k, v)
            end
            return acc
        end
    end
end
57 |
-- Return the (element with max value, key of that element) of a table.
-- value_of is a function that given an element in the list, returns its value.
-- Returns nil, nil for an empty table. Ties keep the first element seen
-- (pairs() order, which is unspecified).
function rl.util.max(tab, value_of)
    -- These four were accidental globals; declare them local.
    local max_elem = nil
    local max_key = nil
    local max_value = nil
    for k, elem in pairs(tab) do
        local curr_v = value_of(elem)
        if max_elem == nil or curr_v > max_value then
            max_elem = elem
            max_key = k
            max_value = curr_v
        end
    end
    return max_elem, max_key
end
74 |
-- Thanks to https://gist.github.com/tylerneylon/81333721109155b2d244
-- Recursively copies table keys and values; metatables are not copied.
function rl.util.copy_simply(obj)
    if type(obj) ~= 'table' then
        return obj
    end
    local copy = {}
    for key, value in pairs(obj) do
        copy[rl.util.copy_simply(key)] = rl.util.copy_simply(value)
    end
    return copy
end
82 |
-- Cache the result of a no-argument function call: f runs at most once.
-- A separate `computed` flag (instead of testing `cache == nil`) makes this
-- correct even when f() returns nil or false, which the old version would
-- recompute on every call.
function rl.util.memoize(f)
    local computed = false
    local cache = nil
    return (
        function ()
            if not computed then
                cache = f()
                computed = true
            end
            return cache
        end
    )
end
95 |
-- Compare two tables, optionally ignore meta table (default FALSE)
-- source: https://web.archive.org/web/20131225070434/http://snippets.luacode.org/snippets/Deep_Comparison_of_Two_Values_3
-- Recurses through this local function (not rl.util.deepcompare) so that
-- `ignore_mt` is honored at every depth; the old code always ignored
-- metatables below the top level, breaking deepcompare_with_meta for nested
-- tables.
local function deepcompare(t1, t2, ignore_mt)
    local ty1 = type(t1)
    local ty2 = type(t2)
    if ty1 ~= ty2 then
        return false
    end

    -- non-table types can be directly compared
    if ty1 ~= 'table' and ty2 ~= 'table' then
        return t1 == t2
    end

    -- as well as tables which have the metamethod __eq
    local mt = getmetatable(t1)
    if not ignore_mt and mt and mt.__eq then
        return t1 == t2
    end

    -- every key of t1 must match in t2 ...
    for k1, v1 in pairs(t1) do
        local v2 = t2[k1]
        if v2 == nil or not deepcompare(v1, v2, ignore_mt) then
            return false
        end
    end
    -- ... and t2 must not have keys missing from t1
    for k2, v2 in pairs(t2) do
        local v1 = t1[k2]
        if v1 == nil or not deepcompare(v1, v2, ignore_mt) then
            return false
        end
    end
    return true
end
130 |
-- Structural equality of two values, ignoring any __eq metamethods.
function rl.util.deepcompare(t1, t2)
    return deepcompare(t1, t2, true)
end
134 |
-- Structural equality that defers to a __eq metamethod when one is present.
function rl.util.deepcompare_with_meta(t1, t2)
    return deepcompare(t1, t2, false)
end
138 |
-- Count how many values in `list` compare equal to `elem`.
function rl.util.get_count(elem, list)
    local occurrences = 0
    for _, value in pairs(list) do
        if value == elem then
            occurrences = occurrences + 1
        end
    end
    return occurrences
end
149 |
-- Check if a Bernoulli trial outcome is plausible: is observing n successes
-- in N trials within 3 standard deviations of probability p's prediction?
-- Degenerate probabilities are handled exactly (std would be 0, so the
-- strict 3-sigma inequality below would always fail for p == 1).
function rl.util.is_prob_good(n, p, N)
    if p < 0 or p > 1 then
        error('Invalid probability: ' .. p)
    end
    if p == 0 then
        return n == 0
    end
    if p == 1 then
        return n == N
    end
    local std = math.sqrt(N * p * (1-p))
    local mean = N * p
    return (mean - 3*std < n and n < mean + 3*std)
end
162 |
-- Check whether the number of occurrences of `elem` in `list` is statistically
-- consistent with each slot independently containing it with probability
-- `expected_p`.
function rl.util.elem_has_good_freq(elem, list, expected_p)
    local occurrences = rl.util.get_count(elem, list)
    return rl.util.is_prob_good(occurrences, expected_p, #list)
end
169 |
--------------------------------------------------------------------------------
/test/unittest_MdpSampler.lua:
--------------------------------------------------------------------------------
1 | require 'rl'
2 |
3 | local tester = torch.Tester()
4 |
5 | local TestMdpSampler = {}
6 |
-- Build an MdpSampler over a fresh TestMdp with the given discount factor.
local function get_sampler(discount_factor)
    return rl.MdpSampler(rl.MdpConfig(rl.TestMdp(), discount_factor))
end
11 |
-- Sample one full episode of `policy` under the given discount factor.
local function get_policy_episode(policy, discount_factor)
    return get_sampler(discount_factor):get_episode(policy)
end
16 |
-- Verify the recursive return identity G_t = r_t + gamma * G_{t+1} between
-- every pair of consecutive steps in an episode.
-- (Exact float comparison is intentional: the rewards and discount factors
-- used by the tests keep this arithmetic exact.)
local function is_discount_good(episode, discount_factor)
    local data = episode[1]
    local last_Gt = data.discounted_return
    local last_r = data.reward
    for t = 2, #episode do
        data = episode[t]

        -- (Removed unused locals for state/action the old code never read.)
        local Gt = data.discounted_return
        local r = data.reward
        if last_Gt ~= last_r + discount_factor * Gt then
            return false
        end

        last_Gt = Gt
        last_r = r
    end
    return true
end
37 |
-- The discounted-return identity must hold for episodes generated under
-- several discount factors.
function TestMdpSampler.test_get_episode_discounted_reward()
    for _, discount_factor in ipairs({1, 0.5, 0}) do
        local episode = get_policy_episode(rl.TestPolicy(1), discount_factor)
        tester:assert(is_discount_good(episode, discount_factor))
    end
end
51 |
-- MdpConfig must reject discount factors outside [0, 1].
-- The old version passed the undefined global `TestMdp`; use a real
-- rl.TestMdp() so any raised error is attributable to the bad discount
-- factor rather than a nil MDP.
function TestMdpSampler.test_discounted_reward_error()
    local function config_with(discount_factor)
        return function ()
            return rl.MdpConfig(rl.TestMdp(), discount_factor)
        end
    end
    tester:assertError(config_with(2))
    tester:assertError(config_with(-1))
end
65 |
-- sample_total_reward sums raw (undiscounted) rewards, so the result must not
-- depend on the discount factor. rl.TestPolicy(a) always chooses action a.
local function assert_total_reward(action, expected_total)
    for _, discount_factor in ipairs({1, 0, 0.5}) do
        local sampler = get_sampler(discount_factor)
        tester:asserteq(
            sampler:sample_total_reward(rl.TestPolicy(action)),
            expected_total)
    end
end

function TestMdpSampler.test_sample_return_always_one()
    assert_total_reward(1, -2)
end

function TestMdpSampler.test_sample_return_always_two()
    assert_total_reward(2, 0)
end

function TestMdpSampler.test_sample_return_always_three()
    assert_total_reward(3, 2)
end
113 |
-- Check that every action in an episode generated by rl.TestPolicy(expected)
-- equals `expected`.
-- NOTE(review): the `episode` parameter is ignored -- the local `episode`
-- below shadows it, and callers pass an undefined global. Consider removing
-- the first parameter.
local function is_action_good(episode, expected)
    local policy = rl.TestPolicy(expected)
    local discount_factor = 1

    local episode = get_policy_episode(policy, discount_factor)
    for t, data in pairs(episode) do
        if data.action ~= expected then
            return false
        end
    end
    return true
end
126 |
-- is_action_good ignores its first argument (it regenerates the episode
-- itself); pass nil explicitly instead of reading the undefined global
-- `episode` as the old tests did.
function TestMdpSampler.test_action_is_good()
    tester:assert(is_action_good(nil, 1))
end

function TestMdpSampler.test_action_is_good2()
    tester:assert(is_action_good(nil, 2))
end

function TestMdpSampler.test_action_is_good3()
    tester:assert(is_action_good(nil, 3))
end
138 |
139 |
-- Pins down the exact episode TestMdp produces under the always-action-1
-- policy: two steps, each with reward -1, visiting states 1 then 2.
function TestMdpSampler.test_episode()
    local policy = rl.TestPolicy(1)
    local discount_factor = 1
    local episode = get_policy_episode(policy, discount_factor)

    tester:assert(#episode == 2)
    tester:assert(episode[1].state == 1)
    tester:assert(episode[1].action == 1)
    -- With gamma = 1 the return is just the sum of remaining rewards.
    tester:assert(episode[1].discounted_return == -2)
    tester:assert(episode[1].reward == -1)
    tester:assert(episode[2].state == 2)
    tester:assert(episode[2].action == 1)
    tester:assert(episode[2].discounted_return == -1)
    tester:assert(episode[2].reward == -1)
end
155 |
156 | tester:add(TestMdpSampler)
157 |
158 | tester:run()
159 |
--------------------------------------------------------------------------------
/test/unittest_util.lua:
--------------------------------------------------------------------------------
1 | require 'rl'
2 |
3 | local tester = torch.Tester()
4 |
5 | local TestUtil = {}
-- get_count should tally repeated values.
function TestUtil.test_get_count()
    tester:asserteq(rl.util.get_count(1, {1, 1, 1, 2}), 3)
end
10 |
-- 5 successes out of 10 at p = 0.5 is exactly the mean: plausible.
function TestUtil.test_is_prob_good_good()
    tester:assert(rl.util.is_prob_good(5, 0.5, 10))
    tester:assert(rl.util.is_prob_good(0, 0, 10))
end

-- 1 success out of 100 at p = 0.5 is far outside 3 standard deviations.
function TestUtil.test_is_prob_good_bad()
    tester:assert(not rl.util.is_prob_good(1, 0.5, 100))
    tester:assert(not rl.util.is_prob_good(1, 0, 100))
end
26 |
-- In {1,1,1,2}: 1 occurs with frequency 3/4, 2 with 1/4, 3 never; claiming
-- probability 0 for an element that does occur must be rejected.
function TestUtil.test_elem_has_good_freq()
    local list = {1, 1, 1, 2}
    tester:assert(rl.util.elem_has_good_freq(1, list, 0.75))
    tester:assert(rl.util.elem_has_good_freq(2, list, 0.25))
    tester:assert(rl.util.elem_has_good_freq(3, list, 0))
    tester:assert(not rl.util.elem_has_good_freq(2, list, 0))
end
34 |
-- Folding {1,1,1,3,1,1,1,3}: the sum is 12; subtracting every value from 0
-- gives -12, and from -8 gives -20.
function TestUtil.test_fold()
    local values = {1, 1, 1, 3, 1, 1, 1, 3}
    local subtract_all = rl.util.fold(function (acc, v) return acc - v end)

    tester:asserteq(rl.util.sum(values), 12)
    tester:asserteq(subtract_all(0)(values), -12)
    tester:asserteq(subtract_all(-8)(values), -20)
end
51 |
-- For {1,1,1,3} the keys 1..4 sum to 10 and the values sum to 6, so folding
-- acc + key + value from 0 yields 16, and from -6 yields 10.
function TestUtil.test_fold_with_key()
    local values = {1, 1, 1, 3}
    local add_all = rl.util.fold_with_key(
        function (acc, key, value) return acc + key + value end)

    tester:asserteq(add_all(0)(values), 16)
    tester:asserteq(add_all(-6)(values), 10)
end
67 |
-- With all weight on 'd', the choice must always be 'd'.
function TestUtil.test_weighted_random_choice_only_one()
    local weights = {
        a = 0,
        b = 0,
        c = 0,
        d = 1
    }
    local always_d = true
    for _ = 1, 100 do
        if rl.util.weighted_random_choice(weights) ~= 'd' then
            always_d = false
            break
        end
    end
    tester:assert(always_d)
end
85 |
-- Draw many samples and check each index's empirical frequency against its
-- weight (within 3 sigma). `t` and `result` were accidental globals before;
-- they are now local.
function TestUtil.test_weighted_random_choice()
    local t = {1, 1, 1, 3, 1, 1, 1, 3}
    local N = 100000
    local denom = rl.util.sum(t)
    local nums = torch.zeros(8)
    for i = 1, N do
        local result = rl.util.weighted_random_choice(t)
        nums[result] = nums[result] + 1
    end
    local function prob_is_good()
        for i = 1, nums:numel() do
            if not rl.util.is_prob_good(nums[i], t[i] / denom, N) then
                return false
            end
        end
        return true
    end
    tester:assert(prob_is_good())
end
114 |
-- max returns (element, key): the identity picks c = 3, negation picks d = -6.
function TestUtil.test_max()
    local scores = {a = 1, b = 2, c = 3, d = -6}

    local elem, key = rl.util.max(scores, function (v) return v end)
    tester:asserteq(elem, 3)
    tester:asserteq(key, 'c')

    elem, key = rl.util.max(scores, function (v) return -v end)
    tester:asserteq(elem, -6)
    tester:asserteq(key, 'd')
end
130 |
-- deepcompare is structural and independent of key insertion order.
function TestUtil.test_deepcompare()
    local t1 = {
        a = 1,
        b = {x=4, y=3, z=1},
        c = {1, 2, 3},
        d = -6
    }
    local t2 = {
        b = {x=4, y=3, z=1},
        a = 1,
        d = -6,
        c = {1, 2, 3}
    }
    tester:assert(rl.util.deepcompare(t1, t2))
end

-- Integer keys (including zero and negatives) compare like any other key.
function TestUtil.test_deepcompare_int_keys()
    local t1 = {
        [4] = 1,
        [-3] = 2,
        [0] = {1, 2, 3},
        [1] = -6
    }
    local t2 = {
        [-3] = 2,
        [4] = 1,
        [1] = -6,
        [0] = {1, 2, 3}
    }
    tester:assert(rl.util.deepcompare(t1, t2))
end

-- deepcompare_with_meta defers to a __eq metamethod that reports equal.
function TestUtil.test_deepcompare_with_meta()
    local t1 = {
        a = 1,
        b = {x=4, y=3, z=1},
        c = {1, 2, 3},
        d = -6
    }
    local t2 = {
        b = {x=4, y=3, z=1},
        a = 1,
        d = -6,
        c = {1, 2, 3}
    }
    local mt = {
        __eq = function (lhs, rhs)
            return true
        end
    }
    setmetatable(t1, mt)
    setmetatable(t2, mt)
    tester:assert(rl.util.deepcompare_with_meta(t1, t2))
end

-- ... and honors a __eq metamethod that reports unequal, even though the
-- tables are structurally identical.
function TestUtil.test_deepcompare_with_meta_false()
    local t1 = {
        a = 1,
        b = {x=4, y=3, z=1},
        c = {1, 2, 3},
        d = -6
    }
    local t2 = {
        b = {x=4, y=3, z=1},
        a = 1,
        d = -6,
        c = {1, 2, 3}
    }
    local mt = {
        __eq = function (lhs, rhs)
            return false
        end
    }
    setmetatable(t1, mt)
    setmetatable(t2, mt)

    tester:assert(not rl.util.deepcompare_with_meta(t1, t2))
end
209 |
210 | tester:add(TestUtil)
211 |
212 | tester:run()
213 |
--------------------------------------------------------------------------------
/SarsaAnalyzer.lua:
--------------------------------------------------------------------------------
1 | -- Analyze different control algorithms.
2 | local gnuplot = require 'gnuplot'
3 |
4 | local SarsaAnalyzer = torch.class('rl.SarsaAnalyzer')
5 |
-- opt: option table with loadqfrom (path to a saved Q, may be nil/empty),
-- save/show (plot output flags), rms_num_points, and n_iters.
-- NOTE(review): falls back to the global N_ITERS when opt.n_iters is absent;
-- presumably defined by rl_constants -- confirm it is loaded first.
function SarsaAnalyzer:__init(opt, mdp_config, qvanalyzer, sarsa_factory)
    self.loadqfrom = opt.loadqfrom
    self.save = opt.save
    self.show = opt.show
    self.rms_num_points = opt.rms_num_points
    self.n_iters = opt.n_iters or N_ITERS

    self.mdp_config = mdp_config
    self.qvanalyzer = qvanalyzer
    self.sarsa_factory = sarsa_factory

    -- Ground-truth Q from Monte Carlo; computed lazily by get_true_q().
    self.q_mc = nil
end
19 |
-- Return the "ground truth" Q: either loaded from opt.loadqfrom, or computed
-- by running Monte Carlo control for n_iters iterations.
-- NOTE(review): overwrites self.n_iters when n_iters is given, which also
-- affects later iteration counts and plot titles -- confirm this is intended.
function SarsaAnalyzer:get_true_q(n_iters)
    if self.loadqfrom ~= nil and self.loadqfrom ~= '' then
        print('Loading q_mc from ' .. self.loadqfrom)
        return rl.util.load_q(self.loadqfrom)
    end

    self.n_iters = n_iters or self.n_iters
    local mc = rl.MonteCarloControl(self.mdp_config)
    print('Computing Q from Monte Carlo. # iters = ' .. self.n_iters)
    mc:improve_policy_for_n_iters(self.n_iters)

    return mc:get_q()
end
33 |
-- Render the RMS-vs-lambda curve on the current gnuplot figure.
local function plot_rms_lambda_data(self, data)
    gnuplot.plot(data)
    gnuplot.grid(true)
    gnuplot.xlabel('lambda')
    gnuplot.ylabel('RMS between Q-MC and Q-SARSA')
    local title = 'Q RMS episodes vs lambda, after '
        .. self.n_iters .. ' iterations'
    gnuplot.title(title)
end
42 |
-- Build a fresh Sarsa control from the factory, configured with `lambda`.
local function get_sarsa(self, lambda)
    local factory = self.sarsa_factory
    factory:set_lambda(lambda)
    return factory:get_control()
end
47 |
-- Draw plot_function on screen (when opt.show) and/or into an EPS file
-- (when opt.save).
local function plot_results(self, plot_function, image_fname)
    if self.show then
        gnuplot.figure()
        plot_function()
    end
    if not self.save then
        return
    end
    gnuplot.epsfigure(image_fname)
    print('Saving plot to: ' .. image_fname)
    plot_function()
    gnuplot.plotflush()
end
60 |
-- Sweep lambda over 0, 0.1, ..., 1.0 and collect (lambda, RMS vs q_mc) rows.
-- NOTE(review): changed improve_policy(self.n_iters) to
-- improve_policy_for_n_iters(self.n_iters) for consistency with the identical
-- loop in eval_lambdas; elsewhere in this file improve_policy is only ever
-- called with no arguments.
local function get_lambda_data(self)
    local rms_lambda_data = torch.Tensor(11, 2)
    local i = 1
    print('Generating data/plot for varying lambdas.')
    for lambda = 0, 1, 0.1 do
        print('Processing SARSA for lambda = ' .. lambda)
        local sarsa = get_sarsa(self, lambda)
        sarsa:improve_policy_for_n_iters(self.n_iters)
        local q = sarsa:get_q()
        rms_lambda_data[i][1] = lambda
        rms_lambda_data[i][2] = self.qvanalyzer:q_rms(q, self.q_mc)
        i = i + 1
    end
    return rms_lambda_data
end
76 |
-- For a given control algorithm, see how the RMS changes with lambda.
-- Sweeps and plots the performance for lambda = 0, 0.1, 0.2, ..., 1.0
-- NOTE(review): the loop body duplicates get_lambda_data above (which uses
-- improve_policy instead of improve_policy_for_n_iters) -- consider unifying.
function SarsaAnalyzer:eval_lambdas(
    image_fname,
    n_iters)
    -- Compute the Monte Carlo ground truth once and cache it.
    self.q_mc = self.q_mc or self:get_true_q()
    self.n_iters = n_iters or self.n_iters
    local rms_lambda_data = torch.Tensor(11, 2)
    local i = 1
    print('Generating data/plot for varying lambdas.')
    for lambda = 0, 1, 0.1 do
        print('Processing SARSA for lambda = ' .. lambda)
        local sarsa = get_sarsa(self, lambda)
        sarsa:improve_policy_for_n_iters(self.n_iters)
        local q = sarsa:get_q()
        rms_lambda_data[i][1] = lambda
        rms_lambda_data[i][2] = self.qvanalyzer:q_rms(q, self.q_mc)
        i = i + 1
    end

    plot_results(self,
        function ()
            plot_rms_lambda_data(self, rms_lambda_data)
        end,
        image_fname)
end
103 |
-- Render every (lambda -> RMS-vs-episode tensor) series from data_table in a
-- single plot. The old version plotted each series in its own gnuplot.plot
-- call (each replacing the previous) and then re-plotted series 0 and 1 by
-- reading the accidental global `data`; this draws all series at once and
-- touches no globals.
local function plot_rms_episode_data(self, data_table)
    local series = {}
    for lambda, curve in pairs(data_table) do
        series[#series + 1] = {tostring(lambda), curve}
    end
    local unpack_ = table.unpack or unpack -- Lua 5.2+ vs 5.1/LuaJIT
    gnuplot.plot(unpack_(series))
    gnuplot.grid(true)
    gnuplot.xlabel('Episode')
    gnuplot.ylabel('RMS between Q-MC and Q-SARSA')
    gnuplot.title('Q RMS vs Episode, lambda = 0 and 1, after '
        .. self.n_iters .. ' iterations')
end
117 |
-- hack to get around that torch doesn't seem to allow private class methods
-- Collect (iteration#, RMS vs q_mc) pairs for one lambda: one initial
-- improvement call, then rms_num_points - 1 batches of
-- self.n_iters / self.rms_num_points iterations each.
-- NOTE(review): assumes self.n_iters is divisible by self.rms_num_points;
-- otherwise the recorded iteration counts are fractional -- confirm callers
-- guarantee this.
local function get_rms_episode_data(self, lambda)
    local rms_episode_data = torch.Tensor(self.rms_num_points, 2)
    local sarsa = get_sarsa(self, lambda)
    sarsa:improve_policy()
    local q = sarsa:get_q()
    rms_episode_data[1][1] = 1
    rms_episode_data[1][2] = self.qvanalyzer:q_rms(q, self.q_mc)
    local n_iters_per_data_point = self.n_iters / self.rms_num_points
    local i = n_iters_per_data_point
    for j = 2, self.rms_num_points do
        sarsa:improve_policy_for_n_iters(n_iters_per_data_point)
        q = sarsa:get_q()
        rms_episode_data[j][1] = i
        rms_episode_data[j][2] = self.qvanalyzer:q_rms(q, self.q_mc)
        i = i + n_iters_per_data_point
    end
    return rms_episode_data
end
137 |
-- For a given control algorithm, see how the RMS changes with # of episodes for
-- lambda = 0 and lambda = 1.
function SarsaAnalyzer:eval_l0_l1_rms(
    image_fname,
    n_iters)
    self.q_mc = self.q_mc or self:get_true_q()
    -- Store on self, as eval_lambdas does: the data collection and plot
    -- titles read self.n_iters, so the old assignment to a local silently
    -- ignored the n_iters override.
    self.n_iters = n_iters or self.n_iters

    print('Generating data for RMS vs episode')
    local l0_data = get_rms_episode_data(self, 0)
    local l1_data = get_rms_episode_data(self, 1)
    -- NOTE(review): deliberately left global -- plot_rms_episode_data also
    -- reads the global `data` when rendering; keep the two in sync if fixed.
    data = {}
    data[0] = l0_data
    data[1] = l1_data

    print('Generating plots for RMS vs episode')
    plot_results(self,
        function ()
            plot_rms_episode_data(self, data)
        end,
        image_fname)
end
160 |
--------------------------------------------------------------------------------
/test/unittest_TableSarsa.lua:
--------------------------------------------------------------------------------
1 |
2 | local tester = torch.Tester()
3 |
4 | local TestTableSarsa = {}
5 | local discount_factor = 0.9
6 | local mdp = rl.TestMdp()
7 | local mdp_config = rl.MdpConfig(mdp, discount_factor)
8 |
-- Compare a TableSarsa's bookkeeping -- state visit counts (Ns), state-action
-- visit counts (Nsa), and eligibility traces -- against expected tables.
local function non_q_params_match(
    sarsa,
    Ns_expected,
    Nsa_expected,
    eligibility_expected)
    local Ns = sarsa.Ns
    local Nsa = sarsa.Nsa
    local elig = sarsa.eligibility

    return rl.util.do_vtable_vfunc_match(mdp, Ns_expected, Ns)
        and rl.util.do_qtable_qfunc_match(mdp, Nsa_expected, Nsa)
        and rl.util.do_qtable_qfunc_match(mdp, eligibility_expected, elig)
end
22 |
-- A single visit to (s=2, a=1) should set the visit counts and eligibility
-- trace to 1 at that cell and leave everything else at 0.
function TestTableSarsa.test_update_eligibility_one_step()
    local lambda = 1
    local sarsa = rl.TableSarsa(mdp_config, lambda)

    local s = 2
    local a = 1
    sarsa:update_eligibility(s, a)

    local Ns_expected = {0, 1, 0}
    local Nsa_expected = { -- row = state, colm = action
        [1] = {0, 0, 0},
        [2] = {1, 0, 0},
        [3] = {0, 0, 0}
    }
    -- After one update the trace equals the visit counts.
    local eligibility_expected = Nsa_expected
    local correct = non_q_params_match(
        sarsa,
        Ns_expected,
        Nsa_expected,
        eligibility_expected)
    tester:assert(correct)
end
45 |
-- Same scenario as test_update_eligibility_one_step (lambda = 1, one visit),
-- with Ns_expected written in explicit key form.
function TestTableSarsa.test_update_eligibility_lambda1()
    local lambda = 1
    local sarsa = rl.TableSarsa(mdp_config, lambda)

    local s = 2
    local a = 1
    sarsa:update_eligibility(s, a)

    local Ns_expected = {
        [1] = 0,
        [2] = 1,
        [3] = 0,
    }
    local Nsa_expected = { -- row = state, colm = action
        [1] = {0, 0, 0},
        [2] = {1, 0, 0},
        [3] = {0, 0, 0}
    }
    local eligibility_expected = Nsa_expected
    local correct = non_q_params_match(
        sarsa,
        Ns_expected,
        Nsa_expected,
        eligibility_expected)
    tester:assert(correct)
end
72 |
-- With 0 < lambda < 1, repeated visits to the same (s,a) accumulate a
-- geometrically decayed eligibility trace (decay d = lambda * gamma per
-- update) while the raw visit counts simply increment.
function TestTableSarsa:test_update_eligibility_lambda_frac()
    local lambda = 0.5
    local sarsa = rl.TableSarsa(mdp_config, lambda)

    local s, a = 2, 1
    for _ = 1, 3 do
        sarsa:update_eligibility(s, a)
    end

    local decay_factor = lambda * discount_factor

    local Ns_expected = {0, 3, 0}
    local Nsa_expected = { -- row = state, colm = action
        [1] = {0, 0, 0},
        [2] = {3, 0, 0},
        [3] = {0, 0, 0}
    }
    -- Trace after three visits: 1 + d*(1 + d*1), with d = decay_factor.
    local trace = 1 + decay_factor * (1 + decay_factor * 1)
    local eligibility_expected = {
        [1] = {0, 0, 0},
        [2] = {trace, 0, 0},
        [3] = {0, 0, 0}
    }
    tester:assert(non_q_params_match(
        sarsa,
        Ns_expected,
        Nsa_expected,
        eligibility_expected))
end
103 |
-- With lambda = 0 the decay factor is 0, so no trace carries over between
-- updates: the expected trace for the visited (s,a) is exactly 1.
function TestTableSarsa:test_update_eligibility_lambda0()
    local lambda = 0
    local sarsa = rl.TableSarsa(mdp_config, lambda)

    local s, a = 2, 1
    for _ = 1, 3 do
        sarsa:update_eligibility(s, a)
    end

    local decay_factor = lambda * discount_factor -- 0 when lambda == 0

    local Ns_expected = {0, 3, 0}
    local Nsa_expected = { -- row = state, colm = action
        [1] = {0, 0, 0},
        [2] = {3, 0, 0},
        [3] = {0, 0, 0}
    }
    -- Same closed form as the lambda-fraction test; evaluates to 1 here
    -- because decay_factor == 0.
    local trace = 1 + decay_factor * (1 + decay_factor * 1)
    local eligibility_expected = {
        [1] = {0, 0, 0},
        [2] = {trace, 0, 0},
        [3] = {0, 0, 0}
    }
    tester:assert(non_q_params_match(
        sarsa,
        Ns_expected,
        Nsa_expected,
        eligibility_expected))
end
134 |
-- Visits to several distinct (s,a) pairs: each older trace is decayed by
-- d = lambda * gamma once per subsequent update, while the freshest pair's
-- trace is 1. Visit counts track each pair independently.
function TestTableSarsa:test_update_eligibility_mixed_updates()
    local lambda = 0.4
    local sarsa = rl.TableSarsa(mdp_config, lambda)

    -- Visit (s=2,a=1), then (s=1,a=2), then (s=2,a=3), in that order.
    local visits = {{2, 1}, {1, 2}, {2, 3}}
    for _, sa in ipairs(visits) do
        sarsa:update_eligibility(sa[1], sa[2])
    end

    local decay_factor = lambda * discount_factor

    local Ns_expected = {1, 2, 0} -- state 1 visited once, state 2 twice
    local Nsa_expected = { -- row = state, colm = action
        [1] = {0, 1, 0},
        [2] = {1, 0, 1},
        [3] = {0, 0, 0}
    }
    -- (2,1) has decayed twice, (1,2) once, and (2,3) is fresh.
    local eligibility_expected = {
        [1] = {0, decay_factor, 0},
        [2] = {decay_factor^2, 0, 1},
        [3] = {0, 0, 0}
    }
    tester:assert(non_q_params_match(
        sarsa,
        Ns_expected,
        Nsa_expected,
        eligibility_expected))
end
169 |
-- A single TD update after one eligibility update: only the visited (s,a)
-- entry of Q changes, and here it ends up equal to td_error.
function TestTableSarsa:test_td_update_one_update()
    local lambda = 0.4
    local sarsa = rl.TableSarsa(mdp_config, lambda)

    sarsa:update_eligibility(2, 1)
    local td_error = 5
    sarsa:td_update(td_error)

    local q_expected = { -- row = state, colm = action
        [1] = {0, 0, 0},
        [2] = {td_error, 0, 0},
        [3] = {0, 0, 0}
    }
    tester:assert(rl.util.do_qtable_qfunc_match(mdp, q_expected, sarsa.q))
end
186 |
-- Multiple interleaved eligibility/TD updates. Two updates hit (s=2,a=1)
-- with td_error = 5, then one hits (s=3,a=3) with td_error = -10. The
-- expected Q(2,1) mixes all three updates because (2,1) still carries a
-- decayed eligibility trace when the third TD update is applied.
function TestTableSarsa:test_td_update_many_updates()
    local lambda = 0.4
    local sarsa = rl.TableSarsa(mdp_config, lambda)

    -- Two visit/update cycles on (2,1).
    local s = 2
    local a = 1
    sarsa:update_eligibility(s, a)
    local td_error = 5
    sarsa:td_update(td_error)
    sarsa:update_eligibility(s, a)
    sarsa:td_update(td_error)

    -- One visit/update cycle on (3,3) with a negative TD error.
    s = 3
    a = 3
    sarsa:update_eligibility(s, a)
    td_error = -10
    sarsa:td_update(td_error)

    local decay_factor = lambda * discount_factor
    -- Q(2,1): first update contributes 5; the second contributes
    -- 5*(1+decay_factor)/2 (trace 1+decay_factor, presumably a 1/N(s,a)=1/2
    -- step size -- TODO confirm against TableSarsa); the third contributes
    -- -10*decay_factor*(1+decay_factor)/2 via the further-decayed trace.
    local q_expected = { -- row = state, colm = action
        [1] = {0, 0, 0},
        [2] = {5+5*(1+decay_factor)/2-10*decay_factor*(1+decay_factor)/2, 0, 0},
        [3] = {0, 0, -10}
    }
    tester:assert(rl.util.do_qtable_qfunc_match(mdp, q_expected, sarsa.q))
end
213 |
-- Register the suite and run all tests.
tester:add(TestTableSarsa)

tester:run()
217 |
218 |
--------------------------------------------------------------------------------
/images/q3a.eps:
--------------------------------------------------------------------------------
1 | %!PS-Adobe-2.0 EPSF-2.0
2 | %%Title: q3a.eps
3 | %%Creator: gnuplot 4.6 patchlevel 4
4 | %%CreationDate: Wed Jan 20 18:03:19 2016
5 | %%DocumentFonts: (atend)
6 | %%BoundingBox: 50 50 410 302
7 | %%EndComments
8 | %%BeginProlog
9 | /gnudict 256 dict def
10 | gnudict begin
11 | %
12 | % The following true/false flags may be edited by hand if desired.
13 | % The unit line width and grayscale image gamma correction may also be changed.
14 | %
15 | /Color true def
16 | /Blacktext false def
17 | /Solid false def
18 | /Dashlength 1 def
19 | /Landscape false def
20 | /Level1 false def
21 | /Rounded false def
22 | /ClipToBoundingBox false def
23 | /SuppressPDFMark false def
24 | /TransparentPatterns false def
25 | /gnulinewidth 5.000 def
26 | /userlinewidth gnulinewidth def
27 | /Gamma 1.0 def
28 | /BackgroundColor {-1.000 -1.000 -1.000} def
29 | %
30 | /vshift -46 def
31 | /dl1 {
32 | 10.0 Dashlength mul mul
33 | Rounded { currentlinewidth 0.75 mul sub dup 0 le { pop 0.01 } if } if
34 | } def
35 | /dl2 {
36 | 10.0 Dashlength mul mul
37 | Rounded { currentlinewidth 0.75 mul add } if
38 | } def
39 | /hpt_ 31.5 def
40 | /vpt_ 31.5 def
41 | /hpt hpt_ def
42 | /vpt vpt_ def
43 | /doclip {
44 | ClipToBoundingBox {
45 | newpath 50 50 moveto 410 50 lineto 410 302 lineto 50 302 lineto closepath
46 | clip
47 | } if
48 | } def
49 | %
50 | % Gnuplot Prolog Version 4.6 (September 2012)
51 | %
52 | %/SuppressPDFMark true def
53 | %
54 | /M {moveto} bind def
55 | /L {lineto} bind def
56 | /R {rmoveto} bind def
57 | /V {rlineto} bind def
58 | /N {newpath moveto} bind def
59 | /Z {closepath} bind def
60 | /C {setrgbcolor} bind def
61 | /f {rlineto fill} bind def
62 | /g {setgray} bind def
63 | /Gshow {show} def % May be redefined later in the file to support UTF-8
64 | /vpt2 vpt 2 mul def
65 | /hpt2 hpt 2 mul def
66 | /Lshow {currentpoint stroke M 0 vshift R
67 | Blacktext {gsave 0 setgray show grestore} {show} ifelse} def
68 | /Rshow {currentpoint stroke M dup stringwidth pop neg vshift R
69 | Blacktext {gsave 0 setgray show grestore} {show} ifelse} def
70 | /Cshow {currentpoint stroke M dup stringwidth pop -2 div vshift R
71 | Blacktext {gsave 0 setgray show grestore} {show} ifelse} def
72 | /UP {dup vpt_ mul /vpt exch def hpt_ mul /hpt exch def
73 | /hpt2 hpt 2 mul def /vpt2 vpt 2 mul def} def
74 | /DL {Color {setrgbcolor Solid {pop []} if 0 setdash}
75 | {pop pop pop 0 setgray Solid {pop []} if 0 setdash} ifelse} def
76 | /BL {stroke userlinewidth 2 mul setlinewidth
77 | Rounded {1 setlinejoin 1 setlinecap} if} def
78 | /AL {stroke userlinewidth 2 div setlinewidth
79 | Rounded {1 setlinejoin 1 setlinecap} if} def
80 | /UL {dup gnulinewidth mul /userlinewidth exch def
81 | dup 1 lt {pop 1} if 10 mul /udl exch def} def
82 | /PL {stroke userlinewidth setlinewidth
83 | Rounded {1 setlinejoin 1 setlinecap} if} def
84 | 3.8 setmiterlimit
85 | % Default Line colors
86 | /LCw {1 1 1} def
87 | /LCb {0 0 0} def
88 | /LCa {0 0 0} def
89 | /LC0 {1 0 0} def
90 | /LC1 {0 1 0} def
91 | /LC2 {0 0 1} def
92 | /LC3 {1 0 1} def
93 | /LC4 {0 1 1} def
94 | /LC5 {1 1 0} def
95 | /LC6 {0 0 0} def
96 | /LC7 {1 0.3 0} def
97 | /LC8 {0.5 0.5 0.5} def
98 | % Default Line Types
99 | /LTw {PL [] 1 setgray} def
100 | /LTb {BL [] LCb DL} def
101 | /LTa {AL [1 udl mul 2 udl mul] 0 setdash LCa setrgbcolor} def
102 | /LT0 {PL [] LC0 DL} def
103 | /LT1 {PL [4 dl1 2 dl2] LC1 DL} def
104 | /LT2 {PL [2 dl1 3 dl2] LC2 DL} def
105 | /LT3 {PL [1 dl1 1.5 dl2] LC3 DL} def
106 | /LT4 {PL [6 dl1 2 dl2 1 dl1 2 dl2] LC4 DL} def
107 | /LT5 {PL [3 dl1 3 dl2 1 dl1 3 dl2] LC5 DL} def
108 | /LT6 {PL [2 dl1 2 dl2 2 dl1 6 dl2] LC6 DL} def
109 | /LT7 {PL [1 dl1 2 dl2 6 dl1 2 dl2 1 dl1 2 dl2] LC7 DL} def
110 | /LT8 {PL [2 dl1 2 dl2 2 dl1 2 dl2 2 dl1 2 dl2 2 dl1 4 dl2] LC8 DL} def
111 | /Pnt {stroke [] 0 setdash gsave 1 setlinecap M 0 0 V stroke grestore} def
112 | /Dia {stroke [] 0 setdash 2 copy vpt add M
113 | hpt neg vpt neg V hpt vpt neg V
114 | hpt vpt V hpt neg vpt V closepath stroke
115 | Pnt} def
116 | /Pls {stroke [] 0 setdash vpt sub M 0 vpt2 V
117 | currentpoint stroke M
118 | hpt neg vpt neg R hpt2 0 V stroke
119 | } def
120 | /Box {stroke [] 0 setdash 2 copy exch hpt sub exch vpt add M
121 | 0 vpt2 neg V hpt2 0 V 0 vpt2 V
122 | hpt2 neg 0 V closepath stroke
123 | Pnt} def
124 | /Crs {stroke [] 0 setdash exch hpt sub exch vpt add M
125 | hpt2 vpt2 neg V currentpoint stroke M
126 | hpt2 neg 0 R hpt2 vpt2 V stroke} def
127 | /TriU {stroke [] 0 setdash 2 copy vpt 1.12 mul add M
128 | hpt neg vpt -1.62 mul V
129 | hpt 2 mul 0 V
130 | hpt neg vpt 1.62 mul V closepath stroke
131 | Pnt} def
132 | /Star {2 copy Pls Crs} def
133 | /BoxF {stroke [] 0 setdash exch hpt sub exch vpt add M
134 | 0 vpt2 neg V hpt2 0 V 0 vpt2 V
135 | hpt2 neg 0 V closepath fill} def
136 | /TriUF {stroke [] 0 setdash vpt 1.12 mul add M
137 | hpt neg vpt -1.62 mul V
138 | hpt 2 mul 0 V
139 | hpt neg vpt 1.62 mul V closepath fill} def
140 | /TriD {stroke [] 0 setdash 2 copy vpt 1.12 mul sub M
141 | hpt neg vpt 1.62 mul V
142 | hpt 2 mul 0 V
143 | hpt neg vpt -1.62 mul V closepath stroke
144 | Pnt} def
145 | /TriDF {stroke [] 0 setdash vpt 1.12 mul sub M
146 | hpt neg vpt 1.62 mul V
147 | hpt 2 mul 0 V
148 | hpt neg vpt -1.62 mul V closepath fill} def
149 | /DiaF {stroke [] 0 setdash vpt add M
150 | hpt neg vpt neg V hpt vpt neg V
151 | hpt vpt V hpt neg vpt V closepath fill} def
152 | /Pent {stroke [] 0 setdash 2 copy gsave
153 | translate 0 hpt M 4 {72 rotate 0 hpt L} repeat
154 | closepath stroke grestore Pnt} def
155 | /PentF {stroke [] 0 setdash gsave
156 | translate 0 hpt M 4 {72 rotate 0 hpt L} repeat
157 | closepath fill grestore} def
158 | /Circle {stroke [] 0 setdash 2 copy
159 | hpt 0 360 arc stroke Pnt} def
160 | /CircleF {stroke [] 0 setdash hpt 0 360 arc fill} def
161 | /C0 {BL [] 0 setdash 2 copy moveto vpt 90 450 arc} bind def
162 | /C1 {BL [] 0 setdash 2 copy moveto
163 | 2 copy vpt 0 90 arc closepath fill
164 | vpt 0 360 arc closepath} bind def
165 | /C2 {BL [] 0 setdash 2 copy moveto
166 | 2 copy vpt 90 180 arc closepath fill
167 | vpt 0 360 arc closepath} bind def
168 | /C3 {BL [] 0 setdash 2 copy moveto
169 | 2 copy vpt 0 180 arc closepath fill
170 | vpt 0 360 arc closepath} bind def
171 | /C4 {BL [] 0 setdash 2 copy moveto
172 | 2 copy vpt 180 270 arc closepath fill
173 | vpt 0 360 arc closepath} bind def
174 | /C5 {BL [] 0 setdash 2 copy moveto
175 | 2 copy vpt 0 90 arc
176 | 2 copy moveto
177 | 2 copy vpt 180 270 arc closepath fill
178 | vpt 0 360 arc} bind def
179 | /C6 {BL [] 0 setdash 2 copy moveto
180 | 2 copy vpt 90 270 arc closepath fill
181 | vpt 0 360 arc closepath} bind def
182 | /C7 {BL [] 0 setdash 2 copy moveto
183 | 2 copy vpt 0 270 arc closepath fill
184 | vpt 0 360 arc closepath} bind def
185 | /C8 {BL [] 0 setdash 2 copy moveto
186 | 2 copy vpt 270 360 arc closepath fill
187 | vpt 0 360 arc closepath} bind def
188 | /C9 {BL [] 0 setdash 2 copy moveto
189 | 2 copy vpt 270 450 arc closepath fill
190 | vpt 0 360 arc closepath} bind def
191 | /C10 {BL [] 0 setdash 2 copy 2 copy moveto vpt 270 360 arc closepath fill
192 | 2 copy moveto
193 | 2 copy vpt 90 180 arc closepath fill
194 | vpt 0 360 arc closepath} bind def
195 | /C11 {BL [] 0 setdash 2 copy moveto
196 | 2 copy vpt 0 180 arc closepath fill
197 | 2 copy moveto
198 | 2 copy vpt 270 360 arc closepath fill
199 | vpt 0 360 arc closepath} bind def
200 | /C12 {BL [] 0 setdash 2 copy moveto
201 | 2 copy vpt 180 360 arc closepath fill
202 | vpt 0 360 arc closepath} bind def
203 | /C13 {BL [] 0 setdash 2 copy moveto
204 | 2 copy vpt 0 90 arc closepath fill
205 | 2 copy moveto
206 | 2 copy vpt 180 360 arc closepath fill
207 | vpt 0 360 arc closepath} bind def
208 | /C14 {BL [] 0 setdash 2 copy moveto
209 | 2 copy vpt 90 360 arc closepath fill
210 | vpt 0 360 arc} bind def
211 | /C15 {BL [] 0 setdash 2 copy vpt 0 360 arc closepath fill
212 | vpt 0 360 arc closepath} bind def
213 | /Rec {newpath 4 2 roll moveto 1 index 0 rlineto 0 exch rlineto
214 | neg 0 rlineto closepath} bind def
215 | /Square {dup Rec} bind def
216 | /Bsquare {vpt sub exch vpt sub exch vpt2 Square} bind def
217 | /S0 {BL [] 0 setdash 2 copy moveto 0 vpt rlineto BL Bsquare} bind def
218 | /S1 {BL [] 0 setdash 2 copy vpt Square fill Bsquare} bind def
219 | /S2 {BL [] 0 setdash 2 copy exch vpt sub exch vpt Square fill Bsquare} bind def
220 | /S3 {BL [] 0 setdash 2 copy exch vpt sub exch vpt2 vpt Rec fill Bsquare} bind def
221 | /S4 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt Square fill Bsquare} bind def
222 | /S5 {BL [] 0 setdash 2 copy 2 copy vpt Square fill
223 | exch vpt sub exch vpt sub vpt Square fill Bsquare} bind def
224 | /S6 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt vpt2 Rec fill Bsquare} bind def
225 | /S7 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt vpt2 Rec fill
226 | 2 copy vpt Square fill Bsquare} bind def
227 | /S8 {BL [] 0 setdash 2 copy vpt sub vpt Square fill Bsquare} bind def
228 | /S9 {BL [] 0 setdash 2 copy vpt sub vpt vpt2 Rec fill Bsquare} bind def
229 | /S10 {BL [] 0 setdash 2 copy vpt sub vpt Square fill 2 copy exch vpt sub exch vpt Square fill
230 | Bsquare} bind def
231 | /S11 {BL [] 0 setdash 2 copy vpt sub vpt Square fill 2 copy exch vpt sub exch vpt2 vpt Rec fill
232 | Bsquare} bind def
233 | /S12 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt2 vpt Rec fill Bsquare} bind def
234 | /S13 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt2 vpt Rec fill
235 | 2 copy vpt Square fill Bsquare} bind def
236 | /S14 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt2 vpt Rec fill
237 | 2 copy exch vpt sub exch vpt Square fill Bsquare} bind def
238 | /S15 {BL [] 0 setdash 2 copy Bsquare fill Bsquare} bind def
239 | /D0 {gsave translate 45 rotate 0 0 S0 stroke grestore} bind def
240 | /D1 {gsave translate 45 rotate 0 0 S1 stroke grestore} bind def
241 | /D2 {gsave translate 45 rotate 0 0 S2 stroke grestore} bind def
242 | /D3 {gsave translate 45 rotate 0 0 S3 stroke grestore} bind def
243 | /D4 {gsave translate 45 rotate 0 0 S4 stroke grestore} bind def
244 | /D5 {gsave translate 45 rotate 0 0 S5 stroke grestore} bind def
245 | /D6 {gsave translate 45 rotate 0 0 S6 stroke grestore} bind def
246 | /D7 {gsave translate 45 rotate 0 0 S7 stroke grestore} bind def
247 | /D8 {gsave translate 45 rotate 0 0 S8 stroke grestore} bind def
248 | /D9 {gsave translate 45 rotate 0 0 S9 stroke grestore} bind def
249 | /D10 {gsave translate 45 rotate 0 0 S10 stroke grestore} bind def
250 | /D11 {gsave translate 45 rotate 0 0 S11 stroke grestore} bind def
251 | /D12 {gsave translate 45 rotate 0 0 S12 stroke grestore} bind def
252 | /D13 {gsave translate 45 rotate 0 0 S13 stroke grestore} bind def
253 | /D14 {gsave translate 45 rotate 0 0 S14 stroke grestore} bind def
254 | /D15 {gsave translate 45 rotate 0 0 S15 stroke grestore} bind def
255 | /DiaE {stroke [] 0 setdash vpt add M
256 | hpt neg vpt neg V hpt vpt neg V
257 | hpt vpt V hpt neg vpt V closepath stroke} def
258 | /BoxE {stroke [] 0 setdash exch hpt sub exch vpt add M
259 | 0 vpt2 neg V hpt2 0 V 0 vpt2 V
260 | hpt2 neg 0 V closepath stroke} def
261 | /TriUE {stroke [] 0 setdash vpt 1.12 mul add M
262 | hpt neg vpt -1.62 mul V
263 | hpt 2 mul 0 V
264 | hpt neg vpt 1.62 mul V closepath stroke} def
265 | /TriDE {stroke [] 0 setdash vpt 1.12 mul sub M
266 | hpt neg vpt 1.62 mul V
267 | hpt 2 mul 0 V
268 | hpt neg vpt -1.62 mul V closepath stroke} def
269 | /PentE {stroke [] 0 setdash gsave
270 | translate 0 hpt M 4 {72 rotate 0 hpt L} repeat
271 | closepath stroke grestore} def
272 | /CircE {stroke [] 0 setdash
273 | hpt 0 360 arc stroke} def
274 | /Opaque {gsave closepath 1 setgray fill grestore 0 setgray closepath} def
275 | /DiaW {stroke [] 0 setdash vpt add M
276 | hpt neg vpt neg V hpt vpt neg V
277 | hpt vpt V hpt neg vpt V Opaque stroke} def
278 | /BoxW {stroke [] 0 setdash exch hpt sub exch vpt add M
279 | 0 vpt2 neg V hpt2 0 V 0 vpt2 V
280 | hpt2 neg 0 V Opaque stroke} def
281 | /TriUW {stroke [] 0 setdash vpt 1.12 mul add M
282 | hpt neg vpt -1.62 mul V
283 | hpt 2 mul 0 V
284 | hpt neg vpt 1.62 mul V Opaque stroke} def
285 | /TriDW {stroke [] 0 setdash vpt 1.12 mul sub M
286 | hpt neg vpt 1.62 mul V
287 | hpt 2 mul 0 V
288 | hpt neg vpt -1.62 mul V Opaque stroke} def
289 | /PentW {stroke [] 0 setdash gsave
290 | translate 0 hpt M 4 {72 rotate 0 hpt L} repeat
291 | Opaque stroke grestore} def
292 | /CircW {stroke [] 0 setdash
293 | hpt 0 360 arc Opaque stroke} def
294 | /BoxFill {gsave Rec 1 setgray fill grestore} def
295 | /Density {
296 | /Fillden exch def
297 | currentrgbcolor
298 | /ColB exch def /ColG exch def /ColR exch def
299 | /ColR ColR Fillden mul Fillden sub 1 add def
300 | /ColG ColG Fillden mul Fillden sub 1 add def
301 | /ColB ColB Fillden mul Fillden sub 1 add def
302 | ColR ColG ColB setrgbcolor} def
303 | /BoxColFill {gsave Rec PolyFill} def
304 | /PolyFill {gsave Density fill grestore grestore} def
305 | /h {rlineto rlineto rlineto gsave closepath fill grestore} bind def
306 | %
307 | % PostScript Level 1 Pattern Fill routine for rectangles
308 | % Usage: x y w h s a XX PatternFill
309 | % x,y = lower left corner of box to be filled
310 | % w,h = width and height of box
311 | % a = angle in degrees between lines and x-axis
312 | % XX = 0/1 for no/yes cross-hatch
313 | %
314 | /PatternFill {gsave /PFa [ 9 2 roll ] def
315 | PFa 0 get PFa 2 get 2 div add PFa 1 get PFa 3 get 2 div add translate
316 | PFa 2 get -2 div PFa 3 get -2 div PFa 2 get PFa 3 get Rec
317 | TransparentPatterns {} {gsave 1 setgray fill grestore} ifelse
318 | clip
319 | currentlinewidth 0.5 mul setlinewidth
320 | /PFs PFa 2 get dup mul PFa 3 get dup mul add sqrt def
321 | 0 0 M PFa 5 get rotate PFs -2 div dup translate
322 | 0 1 PFs PFa 4 get div 1 add floor cvi
323 | {PFa 4 get mul 0 M 0 PFs V} for
324 | 0 PFa 6 get ne {
325 | 0 1 PFs PFa 4 get div 1 add floor cvi
326 | {PFa 4 get mul 0 2 1 roll M PFs 0 V} for
327 | } if
328 | stroke grestore} def
329 | %
330 | /languagelevel where
331 | {pop languagelevel} {1} ifelse
332 | 2 lt
333 | {/InterpretLevel1 true def}
334 | {/InterpretLevel1 Level1 def}
335 | ifelse
336 | %
337 | % PostScript level 2 pattern fill definitions
338 | %
339 | /Level2PatternFill {
340 | /Tile8x8 {/PaintType 2 /PatternType 1 /TilingType 1 /BBox [0 0 8 8] /XStep 8 /YStep 8}
341 | bind def
342 | /KeepColor {currentrgbcolor [/Pattern /DeviceRGB] setcolorspace} bind def
343 | << Tile8x8
344 | /PaintProc {0.5 setlinewidth pop 0 0 M 8 8 L 0 8 M 8 0 L stroke}
345 | >> matrix makepattern
346 | /Pat1 exch def
347 | << Tile8x8
348 | /PaintProc {0.5 setlinewidth pop 0 0 M 8 8 L 0 8 M 8 0 L stroke
349 | 0 4 M 4 8 L 8 4 L 4 0 L 0 4 L stroke}
350 | >> matrix makepattern
351 | /Pat2 exch def
352 | << Tile8x8
353 | /PaintProc {0.5 setlinewidth pop 0 0 M 0 8 L
354 | 8 8 L 8 0 L 0 0 L fill}
355 | >> matrix makepattern
356 | /Pat3 exch def
357 | << Tile8x8
358 | /PaintProc {0.5 setlinewidth pop -4 8 M 8 -4 L
359 | 0 12 M 12 0 L stroke}
360 | >> matrix makepattern
361 | /Pat4 exch def
362 | << Tile8x8
363 | /PaintProc {0.5 setlinewidth pop -4 0 M 8 12 L
364 | 0 -4 M 12 8 L stroke}
365 | >> matrix makepattern
366 | /Pat5 exch def
367 | << Tile8x8
368 | /PaintProc {0.5 setlinewidth pop -2 8 M 4 -4 L
369 | 0 12 M 8 -4 L 4 12 M 10 0 L stroke}
370 | >> matrix makepattern
371 | /Pat6 exch def
372 | << Tile8x8
373 | /PaintProc {0.5 setlinewidth pop -2 0 M 4 12 L
374 | 0 -4 M 8 12 L 4 -4 M 10 8 L stroke}
375 | >> matrix makepattern
376 | /Pat7 exch def
377 | << Tile8x8
378 | /PaintProc {0.5 setlinewidth pop 8 -2 M -4 4 L
379 | 12 0 M -4 8 L 12 4 M 0 10 L stroke}
380 | >> matrix makepattern
381 | /Pat8 exch def
382 | << Tile8x8
383 | /PaintProc {0.5 setlinewidth pop 0 -2 M 12 4 L
384 | -4 0 M 12 8 L -4 4 M 8 10 L stroke}
385 | >> matrix makepattern
386 | /Pat9 exch def
387 | /Pattern1 {PatternBgnd KeepColor Pat1 setpattern} bind def
388 | /Pattern2 {PatternBgnd KeepColor Pat2 setpattern} bind def
389 | /Pattern3 {PatternBgnd KeepColor Pat3 setpattern} bind def
390 | /Pattern4 {PatternBgnd KeepColor Landscape {Pat5} {Pat4} ifelse setpattern} bind def
391 | /Pattern5 {PatternBgnd KeepColor Landscape {Pat4} {Pat5} ifelse setpattern} bind def
392 | /Pattern6 {PatternBgnd KeepColor Landscape {Pat9} {Pat6} ifelse setpattern} bind def
393 | /Pattern7 {PatternBgnd KeepColor Landscape {Pat8} {Pat7} ifelse setpattern} bind def
394 | } def
395 | %
396 | %
397 | %End of PostScript Level 2 code
398 | %
399 | /PatternBgnd {
400 | TransparentPatterns {} {gsave 1 setgray fill grestore} ifelse
401 | } def
402 | %
403 | % Substitute for Level 2 pattern fill codes with
404 | % grayscale if Level 2 support is not selected.
405 | %
406 | /Level1PatternFill {
407 | /Pattern1 {0.250 Density} bind def
408 | /Pattern2 {0.500 Density} bind def
409 | /Pattern3 {0.750 Density} bind def
410 | /Pattern4 {0.125 Density} bind def
411 | /Pattern5 {0.375 Density} bind def
412 | /Pattern6 {0.625 Density} bind def
413 | /Pattern7 {0.875 Density} bind def
414 | } def
415 | %
416 | % Now test for support of Level 2 code
417 | %
418 | Level1 {Level1PatternFill} {Level2PatternFill} ifelse
419 | %
420 | /Symbol-Oblique /Symbol findfont [1 0 .167 1 0 0] makefont
421 | dup length dict begin {1 index /FID eq {pop pop} {def} ifelse} forall
422 | currentdict end definefont pop
423 | /MFshow {
424 | { dup 5 get 3 ge
425 | { 5 get 3 eq {gsave} {grestore} ifelse }
426 | {dup dup 0 get findfont exch 1 get scalefont setfont
427 | [ currentpoint ] exch dup 2 get 0 exch R dup 5 get 2 ne {dup dup 6
428 | get exch 4 get {Gshow} {stringwidth pop 0 R} ifelse }if dup 5 get 0 eq
429 | {dup 3 get {2 get neg 0 exch R pop} {pop aload pop M} ifelse} {dup 5
430 | get 1 eq {dup 2 get exch dup 3 get exch 6 get stringwidth pop -2 div
431 | dup 0 R} {dup 6 get stringwidth pop -2 div 0 R 6 get
432 | show 2 index {aload pop M neg 3 -1 roll neg R pop pop} {pop pop pop
433 | pop aload pop M} ifelse }ifelse }ifelse }
434 | ifelse }
435 | forall} def
436 | /Gswidth {dup type /stringtype eq {stringwidth} {pop (n) stringwidth} ifelse} def
437 | /MFwidth {0 exch { dup 5 get 3 ge { 5 get 3 eq { 0 } { pop } ifelse }
438 | {dup 3 get{dup dup 0 get findfont exch 1 get scalefont setfont
439 | 6 get Gswidth pop add} {pop} ifelse} ifelse} forall} def
440 | /MLshow { currentpoint stroke M
441 | 0 exch R
442 | Blacktext {gsave 0 setgray MFshow grestore} {MFshow} ifelse } bind def
443 | /MRshow { currentpoint stroke M
444 | exch dup MFwidth neg 3 -1 roll R
445 | Blacktext {gsave 0 setgray MFshow grestore} {MFshow} ifelse } bind def
446 | /MCshow { currentpoint stroke M
447 | exch dup MFwidth -2 div 3 -1 roll R
448 | Blacktext {gsave 0 setgray MFshow grestore} {MFshow} ifelse } bind def
449 | /XYsave { [( ) 1 2 true false 3 ()] } bind def
450 | /XYrestore { [( ) 1 2 true false 4 ()] } bind def
451 | Level1 SuppressPDFMark or
452 | {} {
453 | /SDict 10 dict def
454 | systemdict /pdfmark known not {
455 | userdict /pdfmark systemdict /cleartomark get put
456 | } if
457 | SDict begin [
458 | /Title (q3a.eps)
459 | /Subject (gnuplot plot)
460 | /Creator (gnuplot 4.6 patchlevel 4)
461 | /Author (vitchyr)
462 | % /Producer (gnuplot)
463 | % /Keywords ()
464 | /CreationDate (Wed Jan 20 18:03:19 2016)
465 | /DOCINFO pdfmark
466 | end
467 | } ifelse
468 | end
469 | %%EndProlog
470 | %%Page: 1 1
471 | gnudict begin
472 | gsave
473 | doclip
474 | 50 50 translate
475 | 0.050 0.050 scale
476 | 0 setgray
477 | newpath
478 | (Helvetica) findfont 140 scalefont setfont
479 | BackgroundColor 0 lt 3 1 roll 0 lt exch 0 lt or or not {BackgroundColor C 1.000 0 0 7200.00 5040.00 BoxColFill} if
480 | 1.000 UL
481 | LTb
482 | 0.13 0.13 0.13 C 1.000 UL
483 | LTa
484 | LCa setrgbcolor
485 | 602 448 M
486 | 6345 0 V
487 | stroke
488 | LTb
489 | 0.13 0.13 0.13 C 602 448 M
490 | 63 0 V
491 | 6282 0 R
492 | -63 0 V
493 | stroke
494 | 518 448 M
495 | [ [(Helvetica) 140.0 0.0 true true 0 ( 30)]
496 | ] -46.7 MRshow
497 | 1.000 UL
498 | LTb
499 | 0.13 0.13 0.13 C 1.000 UL
500 | LTa
501 | LCa setrgbcolor
502 | 602 1143 M
503 | 6345 0 V
504 | stroke
505 | LTb
506 | 0.13 0.13 0.13 C 602 1143 M
507 | 63 0 V
508 | 6282 0 R
509 | -63 0 V
510 | stroke
511 | 518 1143 M
512 | [ [(Helvetica) 140.0 0.0 true true 0 ( 35)]
513 | ] -46.7 MRshow
514 | 1.000 UL
515 | LTb
516 | 0.13 0.13 0.13 C 1.000 UL
517 | LTa
518 | LCa setrgbcolor
519 | 602 1838 M
520 | 6345 0 V
521 | stroke
522 | LTb
523 | 0.13 0.13 0.13 C 602 1838 M
524 | 63 0 V
525 | 6282 0 R
526 | -63 0 V
527 | stroke
528 | 518 1838 M
529 | [ [(Helvetica) 140.0 0.0 true true 0 ( 40)]
530 | ] -46.7 MRshow
531 | 1.000 UL
532 | LTb
533 | 0.13 0.13 0.13 C 1.000 UL
534 | LTa
535 | LCa setrgbcolor
536 | 602 2534 M
537 | 6345 0 V
538 | stroke
539 | LTb
540 | 0.13 0.13 0.13 C 602 2534 M
541 | 63 0 V
542 | 6282 0 R
543 | -63 0 V
544 | stroke
545 | 518 2534 M
546 | [ [(Helvetica) 140.0 0.0 true true 0 ( 45)]
547 | ] -46.7 MRshow
548 | 1.000 UL
549 | LTb
550 | 0.13 0.13 0.13 C 1.000 UL
551 | LTa
552 | LCa setrgbcolor
553 | 602 3229 M
554 | 6345 0 V
555 | stroke
556 | LTb
557 | 0.13 0.13 0.13 C 602 3229 M
558 | 63 0 V
559 | 6282 0 R
560 | -63 0 V
561 | stroke
562 | 518 3229 M
563 | [ [(Helvetica) 140.0 0.0 true true 0 ( 50)]
564 | ] -46.7 MRshow
565 | 1.000 UL
566 | LTb
567 | 0.13 0.13 0.13 C 1.000 UL
568 | LTa
569 | LCa setrgbcolor
570 | 602 3924 M
571 | 6345 0 V
572 | stroke
573 | LTb
574 | 0.13 0.13 0.13 C 602 3924 M
575 | 63 0 V
576 | 6282 0 R
577 | -63 0 V
578 | stroke
579 | 518 3924 M
580 | [ [(Helvetica) 140.0 0.0 true true 0 ( 55)]
581 | ] -46.7 MRshow
582 | 1.000 UL
583 | LTb
584 | 0.13 0.13 0.13 C 1.000 UL
585 | LTa
586 | LCa setrgbcolor
587 | 602 4619 M
588 | 6345 0 V
589 | stroke
590 | LTb
591 | 0.13 0.13 0.13 C 602 4619 M
592 | 63 0 V
593 | 6282 0 R
594 | -63 0 V
595 | stroke
596 | 518 4619 M
597 | [ [(Helvetica) 140.0 0.0 true true 0 ( 60)]
598 | ] -46.7 MRshow
599 | 1.000 UL
600 | LTb
601 | 0.13 0.13 0.13 C 1.000 UL
602 | LTa
603 | LCa setrgbcolor
604 | 602 448 M
605 | 0 4171 V
606 | stroke
607 | LTb
608 | 0.13 0.13 0.13 C 602 448 M
609 | 0 63 V
610 | 0 4108 R
611 | 0 -63 V
612 | stroke
613 | 602 308 M
614 | [ [(Helvetica) 140.0 0.0 true true 0 ( 0)]
615 | ] -46.7 MCshow
616 | 1.000 UL
617 | LTb
618 | 0.13 0.13 0.13 C 1.000 UL
619 | LTa
620 | LCa setrgbcolor
621 | 1871 448 M
622 | 0 4171 V
623 | stroke
624 | LTb
625 | 0.13 0.13 0.13 C 1871 448 M
626 | 0 63 V
627 | 0 4108 R
628 | 0 -63 V
629 | stroke
630 | 1871 308 M
631 | [ [(Helvetica) 140.0 0.0 true true 0 ( 0.2)]
632 | ] -46.7 MCshow
633 | 1.000 UL
634 | LTb
635 | 0.13 0.13 0.13 C 1.000 UL
636 | LTa
637 | LCa setrgbcolor
638 | 3140 448 M
639 | 0 4171 V
640 | stroke
641 | LTb
642 | 0.13 0.13 0.13 C 3140 448 M
643 | 0 63 V
644 | 0 4108 R
645 | 0 -63 V
646 | stroke
647 | 3140 308 M
648 | [ [(Helvetica) 140.0 0.0 true true 0 ( 0.4)]
649 | ] -46.7 MCshow
650 | 1.000 UL
651 | LTb
652 | 0.13 0.13 0.13 C 1.000 UL
653 | LTa
654 | LCa setrgbcolor
655 | 4409 448 M
656 | 0 4171 V
657 | stroke
658 | LTb
659 | 0.13 0.13 0.13 C 4409 448 M
660 | 0 63 V
661 | 0 4108 R
662 | 0 -63 V
663 | stroke
664 | 4409 308 M
665 | [ [(Helvetica) 140.0 0.0 true true 0 ( 0.6)]
666 | ] -46.7 MCshow
667 | 1.000 UL
668 | LTb
669 | 0.13 0.13 0.13 C 1.000 UL
670 | LTa
671 | LCa setrgbcolor
672 | 5678 448 M
673 | 0 4171 V
674 | stroke
675 | LTb
676 | 0.13 0.13 0.13 C 5678 448 M
677 | 0 63 V
678 | 0 4108 R
679 | 0 -63 V
680 | stroke
681 | 5678 308 M
682 | [ [(Helvetica) 140.0 0.0 true true 0 ( 0.8)]
683 | ] -46.7 MCshow
684 | 1.000 UL
685 | LTb
686 | 0.13 0.13 0.13 C 1.000 UL
687 | LTa
688 | LCa setrgbcolor
689 | 6947 448 M
690 | 0 4171 V
691 | stroke
692 | LTb
693 | 0.13 0.13 0.13 C 6947 448 M
694 | 0 63 V
695 | 0 4108 R
696 | 0 -63 V
697 | stroke
698 | 6947 308 M
699 | [ [(Helvetica) 140.0 0.0 true true 0 ( 1)]
700 | ] -46.7 MCshow
701 | 1.000 UL
702 | LTb
703 | 0.13 0.13 0.13 C 1.000 UL
704 | LTb
705 | 0.13 0.13 0.13 C 602 4619 N
706 | 602 448 L
707 | 6345 0 V
708 | 0 4171 V
709 | -6345 0 V
710 | Z stroke
711 | LCb setrgbcolor
712 | 112 2533 M
713 | currentpoint gsave translate -270 rotate 0 0 moveto
714 | [ [(Helvetica) 140.0 0.0 true true 0 (RMS between Q-MC and Q-SARSA)]
715 | ] -46.7 MCshow
716 | grestore
717 | LTb
718 | LCb setrgbcolor
719 | 3774 98 M
720 | [ [(Helvetica) 140.0 0.0 true true 0 (lambda)]
721 | ] -46.7 MCshow
722 | LTb
723 | 3774 4829 M
724 | [ [(Helvetica) 140.0 0.0 true true 0 (Q RMS after 1000 episodes vs lambda)]
725 | ] -46.7 MCshow
726 | 1.000 UP
727 | 1.000 UL
728 | LTb
729 | 0.13 0.13 0.13 C % Begin plot #1
730 | 1.000 UP
731 | 2.000 UL
732 | LT0
733 | 0.11 0.27 0.60 C 602 1604 M
734 | 635 -594 V
735 | 634 116 V
736 | 635 722 V
737 | 634 -474 V
738 | 635 -238 V
739 | 634 1132 V
740 | 635 -215 V
741 | 634 952 V
742 | 635 -471 V
743 | 634 1835 V
744 | 602 1604 CircleF
745 | 1237 1010 CircleF
746 | 1871 1126 CircleF
747 | 2506 1848 CircleF
748 | 3140 1374 CircleF
749 | 3775 1136 CircleF
750 | 4409 2268 CircleF
751 | 5044 2053 CircleF
752 | 5678 3005 CircleF
753 | 6313 2534 CircleF
754 | 6947 4369 CircleF
755 | % End plot #1
756 | 1.000 UL
757 | LTb
758 | 0.13 0.13 0.13 C 602 4619 N
759 | 602 448 L
760 | 6345 0 V
761 | 0 4171 V
762 | -6345 0 V
763 | Z stroke
764 | 1.000 UP
765 | 1.000 UL
766 | LTb
767 | 0.13 0.13 0.13 C stroke
768 | grestore
769 | end
770 | showpage
771 | %%Trailer
772 | %%DocumentFonts: Helvetica
773 |
--------------------------------------------------------------------------------
/images/q4a.eps:
--------------------------------------------------------------------------------
1 | %!PS-Adobe-2.0 EPSF-2.0
2 | %%Title: q4a.eps
3 | %%Creator: gnuplot 4.6 patchlevel 4
4 | %%CreationDate: Wed Jan 20 18:25:52 2016
5 | %%DocumentFonts: (atend)
6 | %%BoundingBox: 50 50 410 302
7 | %%EndComments
8 | %%BeginProlog
9 | /gnudict 256 dict def
10 | gnudict begin
11 | %
12 | % The following true/false flags may be edited by hand if desired.
13 | % The unit line width and grayscale image gamma correction may also be changed.
14 | %
15 | /Color true def
16 | /Blacktext false def
17 | /Solid false def
18 | /Dashlength 1 def
19 | /Landscape false def
20 | /Level1 false def
21 | /Rounded false def
22 | /ClipToBoundingBox false def
23 | /SuppressPDFMark false def
24 | /TransparentPatterns false def
25 | /gnulinewidth 5.000 def
26 | /userlinewidth gnulinewidth def
27 | /Gamma 1.0 def
28 | /BackgroundColor {-1.000 -1.000 -1.000} def
29 | %
30 | /vshift -46 def
31 | /dl1 {
32 | 10.0 Dashlength mul mul
33 | Rounded { currentlinewidth 0.75 mul sub dup 0 le { pop 0.01 } if } if
34 | } def
35 | /dl2 {
36 | 10.0 Dashlength mul mul
37 | Rounded { currentlinewidth 0.75 mul add } if
38 | } def
39 | /hpt_ 31.5 def
40 | /vpt_ 31.5 def
41 | /hpt hpt_ def
42 | /vpt vpt_ def
43 | /doclip {
44 | ClipToBoundingBox {
45 | newpath 50 50 moveto 410 50 lineto 410 302 lineto 50 302 lineto closepath
46 | clip
47 | } if
48 | } def
49 | %
50 | % Gnuplot Prolog Version 4.6 (September 2012)
51 | %
52 | %/SuppressPDFMark true def
53 | %
54 | /M {moveto} bind def
55 | /L {lineto} bind def
56 | /R {rmoveto} bind def
57 | /V {rlineto} bind def
58 | /N {newpath moveto} bind def
59 | /Z {closepath} bind def
60 | /C {setrgbcolor} bind def
61 | /f {rlineto fill} bind def
62 | /g {setgray} bind def
63 | /Gshow {show} def % May be redefined later in the file to support UTF-8
64 | /vpt2 vpt 2 mul def
65 | /hpt2 hpt 2 mul def
66 | /Lshow {currentpoint stroke M 0 vshift R
67 | Blacktext {gsave 0 setgray show grestore} {show} ifelse} def
68 | /Rshow {currentpoint stroke M dup stringwidth pop neg vshift R
69 | Blacktext {gsave 0 setgray show grestore} {show} ifelse} def
70 | /Cshow {currentpoint stroke M dup stringwidth pop -2 div vshift R
71 | Blacktext {gsave 0 setgray show grestore} {show} ifelse} def
72 | /UP {dup vpt_ mul /vpt exch def hpt_ mul /hpt exch def
73 | /hpt2 hpt 2 mul def /vpt2 vpt 2 mul def} def
74 | /DL {Color {setrgbcolor Solid {pop []} if 0 setdash}
75 | {pop pop pop 0 setgray Solid {pop []} if 0 setdash} ifelse} def
76 | /BL {stroke userlinewidth 2 mul setlinewidth
77 | Rounded {1 setlinejoin 1 setlinecap} if} def
78 | /AL {stroke userlinewidth 2 div setlinewidth
79 | Rounded {1 setlinejoin 1 setlinecap} if} def
80 | /UL {dup gnulinewidth mul /userlinewidth exch def
81 | dup 1 lt {pop 1} if 10 mul /udl exch def} def
82 | /PL {stroke userlinewidth setlinewidth
83 | Rounded {1 setlinejoin 1 setlinecap} if} def
84 | 3.8 setmiterlimit
85 | % Default Line colors
86 | /LCw {1 1 1} def
87 | /LCb {0 0 0} def
88 | /LCa {0 0 0} def
89 | /LC0 {1 0 0} def
90 | /LC1 {0 1 0} def
91 | /LC2 {0 0 1} def
92 | /LC3 {1 0 1} def
93 | /LC4 {0 1 1} def
94 | /LC5 {1 1 0} def
95 | /LC6 {0 0 0} def
96 | /LC7 {1 0.3 0} def
97 | /LC8 {0.5 0.5 0.5} def
98 | % Default Line Types
99 | /LTw {PL [] 1 setgray} def
100 | /LTb {BL [] LCb DL} def
101 | /LTa {AL [1 udl mul 2 udl mul] 0 setdash LCa setrgbcolor} def
102 | /LT0 {PL [] LC0 DL} def
103 | /LT1 {PL [4 dl1 2 dl2] LC1 DL} def
104 | /LT2 {PL [2 dl1 3 dl2] LC2 DL} def
105 | /LT3 {PL [1 dl1 1.5 dl2] LC3 DL} def
106 | /LT4 {PL [6 dl1 2 dl2 1 dl1 2 dl2] LC4 DL} def
107 | /LT5 {PL [3 dl1 3 dl2 1 dl1 3 dl2] LC5 DL} def
108 | /LT6 {PL [2 dl1 2 dl2 2 dl1 6 dl2] LC6 DL} def
109 | /LT7 {PL [1 dl1 2 dl2 6 dl1 2 dl2 1 dl1 2 dl2] LC7 DL} def
110 | /LT8 {PL [2 dl1 2 dl2 2 dl1 2 dl2 2 dl1 2 dl2 2 dl1 4 dl2] LC8 DL} def
111 | /Pnt {stroke [] 0 setdash gsave 1 setlinecap M 0 0 V stroke grestore} def
112 | /Dia {stroke [] 0 setdash 2 copy vpt add M
113 | hpt neg vpt neg V hpt vpt neg V
114 | hpt vpt V hpt neg vpt V closepath stroke
115 | Pnt} def
116 | /Pls {stroke [] 0 setdash vpt sub M 0 vpt2 V
117 | currentpoint stroke M
118 | hpt neg vpt neg R hpt2 0 V stroke
119 | } def
120 | /Box {stroke [] 0 setdash 2 copy exch hpt sub exch vpt add M
121 | 0 vpt2 neg V hpt2 0 V 0 vpt2 V
122 | hpt2 neg 0 V closepath stroke
123 | Pnt} def
124 | /Crs {stroke [] 0 setdash exch hpt sub exch vpt add M
125 | hpt2 vpt2 neg V currentpoint stroke M
126 | hpt2 neg 0 R hpt2 vpt2 V stroke} def
127 | /TriU {stroke [] 0 setdash 2 copy vpt 1.12 mul add M
128 | hpt neg vpt -1.62 mul V
129 | hpt 2 mul 0 V
130 | hpt neg vpt 1.62 mul V closepath stroke
131 | Pnt} def
132 | /Star {2 copy Pls Crs} def
133 | /BoxF {stroke [] 0 setdash exch hpt sub exch vpt add M
134 | 0 vpt2 neg V hpt2 0 V 0 vpt2 V
135 | hpt2 neg 0 V closepath fill} def
136 | /TriUF {stroke [] 0 setdash vpt 1.12 mul add M
137 | hpt neg vpt -1.62 mul V
138 | hpt 2 mul 0 V
139 | hpt neg vpt 1.62 mul V closepath fill} def
140 | /TriD {stroke [] 0 setdash 2 copy vpt 1.12 mul sub M
141 | hpt neg vpt 1.62 mul V
142 | hpt 2 mul 0 V
143 | hpt neg vpt -1.62 mul V closepath stroke
144 | Pnt} def
145 | /TriDF {stroke [] 0 setdash vpt 1.12 mul sub M
146 | hpt neg vpt 1.62 mul V
147 | hpt 2 mul 0 V
148 | hpt neg vpt -1.62 mul V closepath fill} def
149 | /DiaF {stroke [] 0 setdash vpt add M
150 | hpt neg vpt neg V hpt vpt neg V
151 | hpt vpt V hpt neg vpt V closepath fill} def
152 | /Pent {stroke [] 0 setdash 2 copy gsave
153 | translate 0 hpt M 4 {72 rotate 0 hpt L} repeat
154 | closepath stroke grestore Pnt} def
155 | /PentF {stroke [] 0 setdash gsave
156 | translate 0 hpt M 4 {72 rotate 0 hpt L} repeat
157 | closepath fill grestore} def
158 | /Circle {stroke [] 0 setdash 2 copy
159 | hpt 0 360 arc stroke Pnt} def
160 | /CircleF {stroke [] 0 setdash hpt 0 360 arc fill} def
161 | /C0 {BL [] 0 setdash 2 copy moveto vpt 90 450 arc} bind def
162 | /C1 {BL [] 0 setdash 2 copy moveto
163 | 2 copy vpt 0 90 arc closepath fill
164 | vpt 0 360 arc closepath} bind def
165 | /C2 {BL [] 0 setdash 2 copy moveto
166 | 2 copy vpt 90 180 arc closepath fill
167 | vpt 0 360 arc closepath} bind def
168 | /C3 {BL [] 0 setdash 2 copy moveto
169 | 2 copy vpt 0 180 arc closepath fill
170 | vpt 0 360 arc closepath} bind def
171 | /C4 {BL [] 0 setdash 2 copy moveto
172 | 2 copy vpt 180 270 arc closepath fill
173 | vpt 0 360 arc closepath} bind def
174 | /C5 {BL [] 0 setdash 2 copy moveto
175 | 2 copy vpt 0 90 arc
176 | 2 copy moveto
177 | 2 copy vpt 180 270 arc closepath fill
178 | vpt 0 360 arc} bind def
179 | /C6 {BL [] 0 setdash 2 copy moveto
180 | 2 copy vpt 90 270 arc closepath fill
181 | vpt 0 360 arc closepath} bind def
182 | /C7 {BL [] 0 setdash 2 copy moveto
183 | 2 copy vpt 0 270 arc closepath fill
184 | vpt 0 360 arc closepath} bind def
185 | /C8 {BL [] 0 setdash 2 copy moveto
186 | 2 copy vpt 270 360 arc closepath fill
187 | vpt 0 360 arc closepath} bind def
188 | /C9 {BL [] 0 setdash 2 copy moveto
189 | 2 copy vpt 270 450 arc closepath fill
190 | vpt 0 360 arc closepath} bind def
191 | /C10 {BL [] 0 setdash 2 copy 2 copy moveto vpt 270 360 arc closepath fill
192 | 2 copy moveto
193 | 2 copy vpt 90 180 arc closepath fill
194 | vpt 0 360 arc closepath} bind def
195 | /C11 {BL [] 0 setdash 2 copy moveto
196 | 2 copy vpt 0 180 arc closepath fill
197 | 2 copy moveto
198 | 2 copy vpt 270 360 arc closepath fill
199 | vpt 0 360 arc closepath} bind def
200 | /C12 {BL [] 0 setdash 2 copy moveto
201 | 2 copy vpt 180 360 arc closepath fill
202 | vpt 0 360 arc closepath} bind def
203 | /C13 {BL [] 0 setdash 2 copy moveto
204 | 2 copy vpt 0 90 arc closepath fill
205 | 2 copy moveto
206 | 2 copy vpt 180 360 arc closepath fill
207 | vpt 0 360 arc closepath} bind def
208 | /C14 {BL [] 0 setdash 2 copy moveto
209 | 2 copy vpt 90 360 arc closepath fill
210 | vpt 0 360 arc} bind def
211 | /C15 {BL [] 0 setdash 2 copy vpt 0 360 arc closepath fill
212 | vpt 0 360 arc closepath} bind def
213 | /Rec {newpath 4 2 roll moveto 1 index 0 rlineto 0 exch rlineto
214 | neg 0 rlineto closepath} bind def
215 | /Square {dup Rec} bind def
216 | /Bsquare {vpt sub exch vpt sub exch vpt2 Square} bind def
217 | /S0 {BL [] 0 setdash 2 copy moveto 0 vpt rlineto BL Bsquare} bind def
218 | /S1 {BL [] 0 setdash 2 copy vpt Square fill Bsquare} bind def
219 | /S2 {BL [] 0 setdash 2 copy exch vpt sub exch vpt Square fill Bsquare} bind def
220 | /S3 {BL [] 0 setdash 2 copy exch vpt sub exch vpt2 vpt Rec fill Bsquare} bind def
221 | /S4 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt Square fill Bsquare} bind def
222 | /S5 {BL [] 0 setdash 2 copy 2 copy vpt Square fill
223 | exch vpt sub exch vpt sub vpt Square fill Bsquare} bind def
224 | /S6 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt vpt2 Rec fill Bsquare} bind def
225 | /S7 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt vpt2 Rec fill
226 | 2 copy vpt Square fill Bsquare} bind def
227 | /S8 {BL [] 0 setdash 2 copy vpt sub vpt Square fill Bsquare} bind def
228 | /S9 {BL [] 0 setdash 2 copy vpt sub vpt vpt2 Rec fill Bsquare} bind def
229 | /S10 {BL [] 0 setdash 2 copy vpt sub vpt Square fill 2 copy exch vpt sub exch vpt Square fill
230 | Bsquare} bind def
231 | /S11 {BL [] 0 setdash 2 copy vpt sub vpt Square fill 2 copy exch vpt sub exch vpt2 vpt Rec fill
232 | Bsquare} bind def
233 | /S12 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt2 vpt Rec fill Bsquare} bind def
234 | /S13 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt2 vpt Rec fill
235 | 2 copy vpt Square fill Bsquare} bind def
236 | /S14 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt2 vpt Rec fill
237 | 2 copy exch vpt sub exch vpt Square fill Bsquare} bind def
238 | /S15 {BL [] 0 setdash 2 copy Bsquare fill Bsquare} bind def
239 | /D0 {gsave translate 45 rotate 0 0 S0 stroke grestore} bind def
240 | /D1 {gsave translate 45 rotate 0 0 S1 stroke grestore} bind def
241 | /D2 {gsave translate 45 rotate 0 0 S2 stroke grestore} bind def
242 | /D3 {gsave translate 45 rotate 0 0 S3 stroke grestore} bind def
243 | /D4 {gsave translate 45 rotate 0 0 S4 stroke grestore} bind def
244 | /D5 {gsave translate 45 rotate 0 0 S5 stroke grestore} bind def
245 | /D6 {gsave translate 45 rotate 0 0 S6 stroke grestore} bind def
246 | /D7 {gsave translate 45 rotate 0 0 S7 stroke grestore} bind def
247 | /D8 {gsave translate 45 rotate 0 0 S8 stroke grestore} bind def
248 | /D9 {gsave translate 45 rotate 0 0 S9 stroke grestore} bind def
249 | /D10 {gsave translate 45 rotate 0 0 S10 stroke grestore} bind def
250 | /D11 {gsave translate 45 rotate 0 0 S11 stroke grestore} bind def
251 | /D12 {gsave translate 45 rotate 0 0 S12 stroke grestore} bind def
252 | /D13 {gsave translate 45 rotate 0 0 S13 stroke grestore} bind def
253 | /D14 {gsave translate 45 rotate 0 0 S14 stroke grestore} bind def
254 | /D15 {gsave translate 45 rotate 0 0 S15 stroke grestore} bind def
255 | /DiaE {stroke [] 0 setdash vpt add M
256 | hpt neg vpt neg V hpt vpt neg V
257 | hpt vpt V hpt neg vpt V closepath stroke} def
258 | /BoxE {stroke [] 0 setdash exch hpt sub exch vpt add M
259 | 0 vpt2 neg V hpt2 0 V 0 vpt2 V
260 | hpt2 neg 0 V closepath stroke} def
261 | /TriUE {stroke [] 0 setdash vpt 1.12 mul add M
262 | hpt neg vpt -1.62 mul V
263 | hpt 2 mul 0 V
264 | hpt neg vpt 1.62 mul V closepath stroke} def
265 | /TriDE {stroke [] 0 setdash vpt 1.12 mul sub M
266 | hpt neg vpt 1.62 mul V
267 | hpt 2 mul 0 V
268 | hpt neg vpt -1.62 mul V closepath stroke} def
269 | /PentE {stroke [] 0 setdash gsave
270 | translate 0 hpt M 4 {72 rotate 0 hpt L} repeat
271 | closepath stroke grestore} def
272 | /CircE {stroke [] 0 setdash
273 | hpt 0 360 arc stroke} def
274 | /Opaque {gsave closepath 1 setgray fill grestore 0 setgray closepath} def
275 | /DiaW {stroke [] 0 setdash vpt add M
276 | hpt neg vpt neg V hpt vpt neg V
277 | hpt vpt V hpt neg vpt V Opaque stroke} def
278 | /BoxW {stroke [] 0 setdash exch hpt sub exch vpt add M
279 | 0 vpt2 neg V hpt2 0 V 0 vpt2 V
280 | hpt2 neg 0 V Opaque stroke} def
281 | /TriUW {stroke [] 0 setdash vpt 1.12 mul add M
282 | hpt neg vpt -1.62 mul V
283 | hpt 2 mul 0 V
284 | hpt neg vpt 1.62 mul V Opaque stroke} def
285 | /TriDW {stroke [] 0 setdash vpt 1.12 mul sub M
286 | hpt neg vpt 1.62 mul V
287 | hpt 2 mul 0 V
288 | hpt neg vpt -1.62 mul V Opaque stroke} def
289 | /PentW {stroke [] 0 setdash gsave
290 | translate 0 hpt M 4 {72 rotate 0 hpt L} repeat
291 | Opaque stroke grestore} def
292 | /CircW {stroke [] 0 setdash
293 | hpt 0 360 arc Opaque stroke} def
294 | /BoxFill {gsave Rec 1 setgray fill grestore} def
295 | /Density {
296 | /Fillden exch def
297 | currentrgbcolor
298 | /ColB exch def /ColG exch def /ColR exch def
299 | /ColR ColR Fillden mul Fillden sub 1 add def
300 | /ColG ColG Fillden mul Fillden sub 1 add def
301 | /ColB ColB Fillden mul Fillden sub 1 add def
302 | ColR ColG ColB setrgbcolor} def
303 | /BoxColFill {gsave Rec PolyFill} def
304 | /PolyFill {gsave Density fill grestore grestore} def
305 | /h {rlineto rlineto rlineto gsave closepath fill grestore} bind def
306 | %
307 | % PostScript Level 1 Pattern Fill routine for rectangles
308 | % Usage: x y w h s a XX PatternFill
309 | % x,y = lower left corner of box to be filled
310 | % w,h = width and height of box
311 | % a = angle in degrees between lines and x-axis
312 | % XX = 0/1 for no/yes cross-hatch
313 | %
314 | /PatternFill {gsave /PFa [ 9 2 roll ] def
315 | PFa 0 get PFa 2 get 2 div add PFa 1 get PFa 3 get 2 div add translate
316 | PFa 2 get -2 div PFa 3 get -2 div PFa 2 get PFa 3 get Rec
317 | TransparentPatterns {} {gsave 1 setgray fill grestore} ifelse
318 | clip
319 | currentlinewidth 0.5 mul setlinewidth
320 | /PFs PFa 2 get dup mul PFa 3 get dup mul add sqrt def
321 | 0 0 M PFa 5 get rotate PFs -2 div dup translate
322 | 0 1 PFs PFa 4 get div 1 add floor cvi
323 | {PFa 4 get mul 0 M 0 PFs V} for
324 | 0 PFa 6 get ne {
325 | 0 1 PFs PFa 4 get div 1 add floor cvi
326 | {PFa 4 get mul 0 2 1 roll M PFs 0 V} for
327 | } if
328 | stroke grestore} def
329 | %
330 | /languagelevel where
331 | {pop languagelevel} {1} ifelse
332 | 2 lt
333 | {/InterpretLevel1 true def}
334 | {/InterpretLevel1 Level1 def}
335 | ifelse
336 | %
337 | % PostScript level 2 pattern fill definitions
338 | %
339 | /Level2PatternFill {
340 | /Tile8x8 {/PaintType 2 /PatternType 1 /TilingType 1 /BBox [0 0 8 8] /XStep 8 /YStep 8}
341 | bind def
342 | /KeepColor {currentrgbcolor [/Pattern /DeviceRGB] setcolorspace} bind def
343 | << Tile8x8
344 | /PaintProc {0.5 setlinewidth pop 0 0 M 8 8 L 0 8 M 8 0 L stroke}
345 | >> matrix makepattern
346 | /Pat1 exch def
347 | << Tile8x8
348 | /PaintProc {0.5 setlinewidth pop 0 0 M 8 8 L 0 8 M 8 0 L stroke
349 | 0 4 M 4 8 L 8 4 L 4 0 L 0 4 L stroke}
350 | >> matrix makepattern
351 | /Pat2 exch def
352 | << Tile8x8
353 | /PaintProc {0.5 setlinewidth pop 0 0 M 0 8 L
354 | 8 8 L 8 0 L 0 0 L fill}
355 | >> matrix makepattern
356 | /Pat3 exch def
357 | << Tile8x8
358 | /PaintProc {0.5 setlinewidth pop -4 8 M 8 -4 L
359 | 0 12 M 12 0 L stroke}
360 | >> matrix makepattern
361 | /Pat4 exch def
362 | << Tile8x8
363 | /PaintProc {0.5 setlinewidth pop -4 0 M 8 12 L
364 | 0 -4 M 12 8 L stroke}
365 | >> matrix makepattern
366 | /Pat5 exch def
367 | << Tile8x8
368 | /PaintProc {0.5 setlinewidth pop -2 8 M 4 -4 L
369 | 0 12 M 8 -4 L 4 12 M 10 0 L stroke}
370 | >> matrix makepattern
371 | /Pat6 exch def
372 | << Tile8x8
373 | /PaintProc {0.5 setlinewidth pop -2 0 M 4 12 L
374 | 0 -4 M 8 12 L 4 -4 M 10 8 L stroke}
375 | >> matrix makepattern
376 | /Pat7 exch def
377 | << Tile8x8
378 | /PaintProc {0.5 setlinewidth pop 8 -2 M -4 4 L
379 | 12 0 M -4 8 L 12 4 M 0 10 L stroke}
380 | >> matrix makepattern
381 | /Pat8 exch def
382 | << Tile8x8
383 | /PaintProc {0.5 setlinewidth pop 0 -2 M 12 4 L
384 | -4 0 M 12 8 L -4 4 M 8 10 L stroke}
385 | >> matrix makepattern
386 | /Pat9 exch def
387 | /Pattern1 {PatternBgnd KeepColor Pat1 setpattern} bind def
388 | /Pattern2 {PatternBgnd KeepColor Pat2 setpattern} bind def
389 | /Pattern3 {PatternBgnd KeepColor Pat3 setpattern} bind def
390 | /Pattern4 {PatternBgnd KeepColor Landscape {Pat5} {Pat4} ifelse setpattern} bind def
391 | /Pattern5 {PatternBgnd KeepColor Landscape {Pat4} {Pat5} ifelse setpattern} bind def
392 | /Pattern6 {PatternBgnd KeepColor Landscape {Pat9} {Pat6} ifelse setpattern} bind def
393 | /Pattern7 {PatternBgnd KeepColor Landscape {Pat8} {Pat7} ifelse setpattern} bind def
394 | } def
395 | %
396 | %
397 | %End of PostScript Level 2 code
398 | %
399 | /PatternBgnd {
400 | TransparentPatterns {} {gsave 1 setgray fill grestore} ifelse
401 | } def
402 | %
403 | % Substitute for Level 2 pattern fill codes with
404 | % grayscale if Level 2 support is not selected.
405 | %
406 | /Level1PatternFill {
407 | /Pattern1 {0.250 Density} bind def
408 | /Pattern2 {0.500 Density} bind def
409 | /Pattern3 {0.750 Density} bind def
410 | /Pattern4 {0.125 Density} bind def
411 | /Pattern5 {0.375 Density} bind def
412 | /Pattern6 {0.625 Density} bind def
413 | /Pattern7 {0.875 Density} bind def
414 | } def
415 | %
416 | % Now test for support of Level 2 code
417 | %
418 | Level1 {Level1PatternFill} {Level2PatternFill} ifelse
419 | %
420 | /Symbol-Oblique /Symbol findfont [1 0 .167 1 0 0] makefont
421 | dup length dict begin {1 index /FID eq {pop pop} {def} ifelse} forall
422 | currentdict end definefont pop
423 | /MFshow {
424 | { dup 5 get 3 ge
425 | { 5 get 3 eq {gsave} {grestore} ifelse }
426 | {dup dup 0 get findfont exch 1 get scalefont setfont
427 | [ currentpoint ] exch dup 2 get 0 exch R dup 5 get 2 ne {dup dup 6
428 | get exch 4 get {Gshow} {stringwidth pop 0 R} ifelse }if dup 5 get 0 eq
429 | {dup 3 get {2 get neg 0 exch R pop} {pop aload pop M} ifelse} {dup 5
430 | get 1 eq {dup 2 get exch dup 3 get exch 6 get stringwidth pop -2 div
431 | dup 0 R} {dup 6 get stringwidth pop -2 div 0 R 6 get
432 | show 2 index {aload pop M neg 3 -1 roll neg R pop pop} {pop pop pop
433 | pop aload pop M} ifelse }ifelse }ifelse }
434 | ifelse }
435 | forall} def
436 | /Gswidth {dup type /stringtype eq {stringwidth} {pop (n) stringwidth} ifelse} def
437 | /MFwidth {0 exch { dup 5 get 3 ge { 5 get 3 eq { 0 } { pop } ifelse }
438 | {dup 3 get{dup dup 0 get findfont exch 1 get scalefont setfont
439 | 6 get Gswidth pop add} {pop} ifelse} ifelse} forall} def
440 | /MLshow { currentpoint stroke M
441 | 0 exch R
442 | Blacktext {gsave 0 setgray MFshow grestore} {MFshow} ifelse } bind def
443 | /MRshow { currentpoint stroke M
444 | exch dup MFwidth neg 3 -1 roll R
445 | Blacktext {gsave 0 setgray MFshow grestore} {MFshow} ifelse } bind def
446 | /MCshow { currentpoint stroke M
447 | exch dup MFwidth -2 div 3 -1 roll R
448 | Blacktext {gsave 0 setgray MFshow grestore} {MFshow} ifelse } bind def
449 | /XYsave { [( ) 1 2 true false 3 ()] } bind def
450 | /XYrestore { [( ) 1 2 true false 4 ()] } bind def
451 | Level1 SuppressPDFMark or
452 | {} {
453 | /SDict 10 dict def
454 | systemdict /pdfmark known not {
455 | userdict /pdfmark systemdict /cleartomark get put
456 | } if
457 | SDict begin [
458 | /Title (q4a.eps)
459 | /Subject (gnuplot plot)
460 | /Creator (gnuplot 4.6 patchlevel 4)
461 | /Author (vitchyr)
462 | % /Producer (gnuplot)
463 | % /Keywords ()
464 | /CreationDate (Wed Jan 20 18:25:52 2016)
465 | /DOCINFO pdfmark
466 | end
467 | } ifelse
468 | end
469 | %%EndProlog
470 | %%Page: 1 1
471 | gnudict begin
472 | gsave
473 | doclip
474 | 50 50 translate
475 | 0.050 0.050 scale
476 | 0 setgray
477 | newpath
478 | (Helvetica) findfont 140 scalefont setfont
479 | BackgroundColor 0 lt 3 1 roll 0 lt exch 0 lt or or not {BackgroundColor C 1.000 0 0 7200.00 5040.00 BoxColFill} if
480 | 1.000 UL
481 | LTb
482 | 0.13 0.13 0.13 C 1.000 UL
483 | LTa
484 | LCa setrgbcolor
485 | 602 448 M
486 | 6345 0 V
487 | stroke
488 | LTb
489 | 0.13 0.13 0.13 C 602 448 M
490 | 63 0 V
491 | 6282 0 R
492 | -63 0 V
493 | stroke
494 | 518 448 M
495 | [ [(Helvetica) 140.0 0.0 true true 0 ( 44)]
496 | ] -46.7 MRshow
497 | 1.000 UL
498 | LTb
499 | 0.13 0.13 0.13 C 1.000 UL
500 | LTa
501 | LCa setrgbcolor
502 | 602 1044 M
503 | 6345 0 V
504 | stroke
505 | LTb
506 | 0.13 0.13 0.13 C 602 1044 M
507 | 63 0 V
508 | 6282 0 R
509 | -63 0 V
510 | stroke
511 | 518 1044 M
512 | [ [(Helvetica) 140.0 0.0 true true 0 ( 46)]
513 | ] -46.7 MRshow
514 | 1.000 UL
515 | LTb
516 | 0.13 0.13 0.13 C 1.000 UL
517 | LTa
518 | LCa setrgbcolor
519 | 602 1640 M
520 | 6345 0 V
521 | stroke
522 | LTb
523 | 0.13 0.13 0.13 C 602 1640 M
524 | 63 0 V
525 | 6282 0 R
526 | -63 0 V
527 | stroke
528 | 518 1640 M
529 | [ [(Helvetica) 140.0 0.0 true true 0 ( 48)]
530 | ] -46.7 MRshow
531 | 1.000 UL
532 | LTb
533 | 0.13 0.13 0.13 C 1.000 UL
534 | LTa
535 | LCa setrgbcolor
536 | 602 2236 M
537 | 6345 0 V
538 | stroke
539 | LTb
540 | 0.13 0.13 0.13 C 602 2236 M
541 | 63 0 V
542 | 6282 0 R
543 | -63 0 V
544 | stroke
545 | 518 2236 M
546 | [ [(Helvetica) 140.0 0.0 true true 0 ( 50)]
547 | ] -46.7 MRshow
548 | 1.000 UL
549 | LTb
550 | 0.13 0.13 0.13 C 1.000 UL
551 | LTa
552 | LCa setrgbcolor
553 | 602 2831 M
554 | 6345 0 V
555 | stroke
556 | LTb
557 | 0.13 0.13 0.13 C 602 2831 M
558 | 63 0 V
559 | 6282 0 R
560 | -63 0 V
561 | stroke
562 | 518 2831 M
563 | [ [(Helvetica) 140.0 0.0 true true 0 ( 52)]
564 | ] -46.7 MRshow
565 | 1.000 UL
566 | LTb
567 | 0.13 0.13 0.13 C 1.000 UL
568 | LTa
569 | LCa setrgbcolor
570 | 602 3427 M
571 | 6345 0 V
572 | stroke
573 | LTb
574 | 0.13 0.13 0.13 C 602 3427 M
575 | 63 0 V
576 | 6282 0 R
577 | -63 0 V
578 | stroke
579 | 518 3427 M
580 | [ [(Helvetica) 140.0 0.0 true true 0 ( 54)]
581 | ] -46.7 MRshow
582 | 1.000 UL
583 | LTb
584 | 0.13 0.13 0.13 C 1.000 UL
585 | LTa
586 | LCa setrgbcolor
587 | 602 4023 M
588 | 6345 0 V
589 | stroke
590 | LTb
591 | 0.13 0.13 0.13 C 602 4023 M
592 | 63 0 V
593 | 6282 0 R
594 | -63 0 V
595 | stroke
596 | 518 4023 M
597 | [ [(Helvetica) 140.0 0.0 true true 0 ( 56)]
598 | ] -46.7 MRshow
599 | 1.000 UL
600 | LTb
601 | 0.13 0.13 0.13 C 1.000 UL
602 | LTa
603 | LCa setrgbcolor
604 | 602 4619 M
605 | 6345 0 V
606 | stroke
607 | LTb
608 | 0.13 0.13 0.13 C 602 4619 M
609 | 63 0 V
610 | 6282 0 R
611 | -63 0 V
612 | stroke
613 | 518 4619 M
614 | [ [(Helvetica) 140.0 0.0 true true 0 ( 58)]
615 | ] -46.7 MRshow
616 | 1.000 UL
617 | LTb
618 | 0.13 0.13 0.13 C 1.000 UL
619 | LTa
620 | LCa setrgbcolor
621 | 602 448 M
622 | 0 4171 V
623 | stroke
624 | LTb
625 | 0.13 0.13 0.13 C 602 448 M
626 | 0 63 V
627 | 0 4108 R
628 | 0 -63 V
629 | stroke
630 | 602 308 M
631 | [ [(Helvetica) 140.0 0.0 true true 0 ( 0)]
632 | ] -46.7 MCshow
633 | 1.000 UL
634 | LTb
635 | 0.13 0.13 0.13 C 1.000 UL
636 | LTa
637 | LCa setrgbcolor
638 | 1871 448 M
639 | 0 4171 V
640 | stroke
641 | LTb
642 | 0.13 0.13 0.13 C 1871 448 M
643 | 0 63 V
644 | 0 4108 R
645 | 0 -63 V
646 | stroke
647 | 1871 308 M
648 | [ [(Helvetica) 140.0 0.0 true true 0 ( 0.2)]
649 | ] -46.7 MCshow
650 | 1.000 UL
651 | LTb
652 | 0.13 0.13 0.13 C 1.000 UL
653 | LTa
654 | LCa setrgbcolor
655 | 3140 448 M
656 | 0 4171 V
657 | stroke
658 | LTb
659 | 0.13 0.13 0.13 C 3140 448 M
660 | 0 63 V
661 | 0 4108 R
662 | 0 -63 V
663 | stroke
664 | 3140 308 M
665 | [ [(Helvetica) 140.0 0.0 true true 0 ( 0.4)]
666 | ] -46.7 MCshow
667 | 1.000 UL
668 | LTb
669 | 0.13 0.13 0.13 C 1.000 UL
670 | LTa
671 | LCa setrgbcolor
672 | 4409 448 M
673 | 0 4171 V
674 | stroke
675 | LTb
676 | 0.13 0.13 0.13 C 4409 448 M
677 | 0 63 V
678 | 0 4108 R
679 | 0 -63 V
680 | stroke
681 | 4409 308 M
682 | [ [(Helvetica) 140.0 0.0 true true 0 ( 0.6)]
683 | ] -46.7 MCshow
684 | 1.000 UL
685 | LTb
686 | 0.13 0.13 0.13 C 1.000 UL
687 | LTa
688 | LCa setrgbcolor
689 | 5678 448 M
690 | 0 4171 V
691 | stroke
692 | LTb
693 | 0.13 0.13 0.13 C 5678 448 M
694 | 0 63 V
695 | 0 4108 R
696 | 0 -63 V
697 | stroke
698 | 5678 308 M
699 | [ [(Helvetica) 140.0 0.0 true true 0 ( 0.8)]
700 | ] -46.7 MCshow
701 | 1.000 UL
702 | LTb
703 | 0.13 0.13 0.13 C 1.000 UL
704 | LTa
705 | LCa setrgbcolor
706 | 6947 448 M
707 | 0 4171 V
708 | stroke
709 | LTb
710 | 0.13 0.13 0.13 C 6947 448 M
711 | 0 63 V
712 | 0 4108 R
713 | 0 -63 V
714 | stroke
715 | 6947 308 M
716 | [ [(Helvetica) 140.0 0.0 true true 0 ( 1)]
717 | ] -46.7 MCshow
718 | 1.000 UL
719 | LTb
720 | 0.13 0.13 0.13 C 1.000 UL
721 | LTb
722 | 0.13 0.13 0.13 C 602 4619 N
723 | 602 448 L
724 | 6345 0 V
725 | 0 4171 V
726 | -6345 0 V
727 | Z stroke
728 | LCb setrgbcolor
729 | 112 2533 M
730 | currentpoint gsave translate -270 rotate 0 0 moveto
731 | [ [(Helvetica) 140.0 0.0 true true 0 (RMS between Q-MC and Q-SARSA)]
732 | ] -46.7 MCshow
733 | grestore
734 | LTb
735 | LCb setrgbcolor
736 | 3774 98 M
737 | [ [(Helvetica) 140.0 0.0 true true 0 (lambda)]
738 | ] -46.7 MCshow
739 | LTb
740 | 3774 4829 M
741 | [ [(Helvetica) 140.0 0.0 true true 0 (Q RMS after 1000 episodes vs lambda)]
742 | ] -46.7 MCshow
743 | 1.000 UP
744 | 1.000 UL
745 | LTb
746 | 0.13 0.13 0.13 C % Begin plot #1
747 | 1.000 UP
748 | 2.000 UL
749 | LT0
750 | 0.11 0.27 0.60 C 602 1646 M
751 | 635 1139 V
752 | 634 -696 V
753 | 2506 1038 L
754 | 634 1373 V
755 | 635 -82 V
756 | 634 1844 V
757 | 635 -626 V
758 | 634 95 V
759 | 635 -518 V
760 | 634 1402 V
761 | 602 1646 CircleF
762 | 1237 2785 CircleF
763 | 1871 2089 CircleF
764 | 2506 1038 CircleF
765 | 3140 2411 CircleF
766 | 3775 2329 CircleF
767 | 4409 4173 CircleF
768 | 5044 3547 CircleF
769 | 5678 3642 CircleF
770 | 6313 3124 CircleF
771 | 6947 4526 CircleF
772 | % End plot #1
773 | 1.000 UL
774 | LTb
775 | 0.13 0.13 0.13 C 602 4619 N
776 | 602 448 L
777 | 6345 0 V
778 | 0 4171 V
779 | -6345 0 V
780 | Z stroke
781 | 1.000 UP
782 | 1.000 UL
783 | LTb
784 | 0.13 0.13 0.13 C stroke
785 | grestore
786 | end
787 | showpage
788 | %%Trailer
789 | %%DocumentFonts: Helvetica
790 |
--------------------------------------------------------------------------------
/images/q5a.eps:
--------------------------------------------------------------------------------
1 | %!PS-Adobe-2.0 EPSF-2.0
2 | %%Title: q5a.eps
3 | %%Creator: gnuplot 4.6 patchlevel 4
4 | %%CreationDate: Thu Jan 21 00:24:42 2016
5 | %%DocumentFonts: (atend)
6 | %%BoundingBox: 50 50 410 302
7 | %%EndComments
8 | %%BeginProlog
9 | /gnudict 256 dict def
10 | gnudict begin
11 | %
12 | % The following true/false flags may be edited by hand if desired.
13 | % The unit line width and grayscale image gamma correction may also be changed.
14 | %
15 | /Color true def
16 | /Blacktext false def
17 | /Solid false def
18 | /Dashlength 1 def
19 | /Landscape false def
20 | /Level1 false def
21 | /Rounded false def
22 | /ClipToBoundingBox false def
23 | /SuppressPDFMark false def
24 | /TransparentPatterns false def
25 | /gnulinewidth 5.000 def
26 | /userlinewidth gnulinewidth def
27 | /Gamma 1.0 def
28 | /BackgroundColor {-1.000 -1.000 -1.000} def
29 | %
30 | /vshift -46 def
31 | /dl1 {
32 | 10.0 Dashlength mul mul
33 | Rounded { currentlinewidth 0.75 mul sub dup 0 le { pop 0.01 } if } if
34 | } def
35 | /dl2 {
36 | 10.0 Dashlength mul mul
37 | Rounded { currentlinewidth 0.75 mul add } if
38 | } def
39 | /hpt_ 31.5 def
40 | /vpt_ 31.5 def
41 | /hpt hpt_ def
42 | /vpt vpt_ def
43 | /doclip {
44 | ClipToBoundingBox {
45 | newpath 50 50 moveto 410 50 lineto 410 302 lineto 50 302 lineto closepath
46 | clip
47 | } if
48 | } def
49 | %
50 | % Gnuplot Prolog Version 4.6 (September 2012)
51 | %
52 | %/SuppressPDFMark true def
53 | %
54 | /M {moveto} bind def
55 | /L {lineto} bind def
56 | /R {rmoveto} bind def
57 | /V {rlineto} bind def
58 | /N {newpath moveto} bind def
59 | /Z {closepath} bind def
60 | /C {setrgbcolor} bind def
61 | /f {rlineto fill} bind def
62 | /g {setgray} bind def
63 | /Gshow {show} def % May be redefined later in the file to support UTF-8
64 | /vpt2 vpt 2 mul def
65 | /hpt2 hpt 2 mul def
66 | /Lshow {currentpoint stroke M 0 vshift R
67 | Blacktext {gsave 0 setgray show grestore} {show} ifelse} def
68 | /Rshow {currentpoint stroke M dup stringwidth pop neg vshift R
69 | Blacktext {gsave 0 setgray show grestore} {show} ifelse} def
70 | /Cshow {currentpoint stroke M dup stringwidth pop -2 div vshift R
71 | Blacktext {gsave 0 setgray show grestore} {show} ifelse} def
72 | /UP {dup vpt_ mul /vpt exch def hpt_ mul /hpt exch def
73 | /hpt2 hpt 2 mul def /vpt2 vpt 2 mul def} def
74 | /DL {Color {setrgbcolor Solid {pop []} if 0 setdash}
75 | {pop pop pop 0 setgray Solid {pop []} if 0 setdash} ifelse} def
76 | /BL {stroke userlinewidth 2 mul setlinewidth
77 | Rounded {1 setlinejoin 1 setlinecap} if} def
78 | /AL {stroke userlinewidth 2 div setlinewidth
79 | Rounded {1 setlinejoin 1 setlinecap} if} def
80 | /UL {dup gnulinewidth mul /userlinewidth exch def
81 | dup 1 lt {pop 1} if 10 mul /udl exch def} def
82 | /PL {stroke userlinewidth setlinewidth
83 | Rounded {1 setlinejoin 1 setlinecap} if} def
84 | 3.8 setmiterlimit
85 | % Default Line colors
86 | /LCw {1 1 1} def
87 | /LCb {0 0 0} def
88 | /LCa {0 0 0} def
89 | /LC0 {1 0 0} def
90 | /LC1 {0 1 0} def
91 | /LC2 {0 0 1} def
92 | /LC3 {1 0 1} def
93 | /LC4 {0 1 1} def
94 | /LC5 {1 1 0} def
95 | /LC6 {0 0 0} def
96 | /LC7 {1 0.3 0} def
97 | /LC8 {0.5 0.5 0.5} def
98 | % Default Line Types
99 | /LTw {PL [] 1 setgray} def
100 | /LTb {BL [] LCb DL} def
101 | /LTa {AL [1 udl mul 2 udl mul] 0 setdash LCa setrgbcolor} def
102 | /LT0 {PL [] LC0 DL} def
103 | /LT1 {PL [4 dl1 2 dl2] LC1 DL} def
104 | /LT2 {PL [2 dl1 3 dl2] LC2 DL} def
105 | /LT3 {PL [1 dl1 1.5 dl2] LC3 DL} def
106 | /LT4 {PL [6 dl1 2 dl2 1 dl1 2 dl2] LC4 DL} def
107 | /LT5 {PL [3 dl1 3 dl2 1 dl1 3 dl2] LC5 DL} def
108 | /LT6 {PL [2 dl1 2 dl2 2 dl1 6 dl2] LC6 DL} def
109 | /LT7 {PL [1 dl1 2 dl2 6 dl1 2 dl2 1 dl1 2 dl2] LC7 DL} def
110 | /LT8 {PL [2 dl1 2 dl2 2 dl1 2 dl2 2 dl1 2 dl2 2 dl1 4 dl2] LC8 DL} def
111 | /Pnt {stroke [] 0 setdash gsave 1 setlinecap M 0 0 V stroke grestore} def
112 | /Dia {stroke [] 0 setdash 2 copy vpt add M
113 | hpt neg vpt neg V hpt vpt neg V
114 | hpt vpt V hpt neg vpt V closepath stroke
115 | Pnt} def
116 | /Pls {stroke [] 0 setdash vpt sub M 0 vpt2 V
117 | currentpoint stroke M
118 | hpt neg vpt neg R hpt2 0 V stroke
119 | } def
120 | /Box {stroke [] 0 setdash 2 copy exch hpt sub exch vpt add M
121 | 0 vpt2 neg V hpt2 0 V 0 vpt2 V
122 | hpt2 neg 0 V closepath stroke
123 | Pnt} def
124 | /Crs {stroke [] 0 setdash exch hpt sub exch vpt add M
125 | hpt2 vpt2 neg V currentpoint stroke M
126 | hpt2 neg 0 R hpt2 vpt2 V stroke} def
127 | /TriU {stroke [] 0 setdash 2 copy vpt 1.12 mul add M
128 | hpt neg vpt -1.62 mul V
129 | hpt 2 mul 0 V
130 | hpt neg vpt 1.62 mul V closepath stroke
131 | Pnt} def
132 | /Star {2 copy Pls Crs} def
133 | /BoxF {stroke [] 0 setdash exch hpt sub exch vpt add M
134 | 0 vpt2 neg V hpt2 0 V 0 vpt2 V
135 | hpt2 neg 0 V closepath fill} def
136 | /TriUF {stroke [] 0 setdash vpt 1.12 mul add M
137 | hpt neg vpt -1.62 mul V
138 | hpt 2 mul 0 V
139 | hpt neg vpt 1.62 mul V closepath fill} def
140 | /TriD {stroke [] 0 setdash 2 copy vpt 1.12 mul sub M
141 | hpt neg vpt 1.62 mul V
142 | hpt 2 mul 0 V
143 | hpt neg vpt -1.62 mul V closepath stroke
144 | Pnt} def
145 | /TriDF {stroke [] 0 setdash vpt 1.12 mul sub M
146 | hpt neg vpt 1.62 mul V
147 | hpt 2 mul 0 V
148 | hpt neg vpt -1.62 mul V closepath fill} def
149 | /DiaF {stroke [] 0 setdash vpt add M
150 | hpt neg vpt neg V hpt vpt neg V
151 | hpt vpt V hpt neg vpt V closepath fill} def
152 | /Pent {stroke [] 0 setdash 2 copy gsave
153 | translate 0 hpt M 4 {72 rotate 0 hpt L} repeat
154 | closepath stroke grestore Pnt} def
155 | /PentF {stroke [] 0 setdash gsave
156 | translate 0 hpt M 4 {72 rotate 0 hpt L} repeat
157 | closepath fill grestore} def
158 | /Circle {stroke [] 0 setdash 2 copy
159 | hpt 0 360 arc stroke Pnt} def
160 | /CircleF {stroke [] 0 setdash hpt 0 360 arc fill} def
161 | /C0 {BL [] 0 setdash 2 copy moveto vpt 90 450 arc} bind def
162 | /C1 {BL [] 0 setdash 2 copy moveto
163 | 2 copy vpt 0 90 arc closepath fill
164 | vpt 0 360 arc closepath} bind def
165 | /C2 {BL [] 0 setdash 2 copy moveto
166 | 2 copy vpt 90 180 arc closepath fill
167 | vpt 0 360 arc closepath} bind def
168 | /C3 {BL [] 0 setdash 2 copy moveto
169 | 2 copy vpt 0 180 arc closepath fill
170 | vpt 0 360 arc closepath} bind def
171 | /C4 {BL [] 0 setdash 2 copy moveto
172 | 2 copy vpt 180 270 arc closepath fill
173 | vpt 0 360 arc closepath} bind def
174 | /C5 {BL [] 0 setdash 2 copy moveto
175 | 2 copy vpt 0 90 arc
176 | 2 copy moveto
177 | 2 copy vpt 180 270 arc closepath fill
178 | vpt 0 360 arc} bind def
179 | /C6 {BL [] 0 setdash 2 copy moveto
180 | 2 copy vpt 90 270 arc closepath fill
181 | vpt 0 360 arc closepath} bind def
182 | /C7 {BL [] 0 setdash 2 copy moveto
183 | 2 copy vpt 0 270 arc closepath fill
184 | vpt 0 360 arc closepath} bind def
185 | /C8 {BL [] 0 setdash 2 copy moveto
186 | 2 copy vpt 270 360 arc closepath fill
187 | vpt 0 360 arc closepath} bind def
188 | /C9 {BL [] 0 setdash 2 copy moveto
189 | 2 copy vpt 270 450 arc closepath fill
190 | vpt 0 360 arc closepath} bind def
191 | /C10 {BL [] 0 setdash 2 copy 2 copy moveto vpt 270 360 arc closepath fill
192 | 2 copy moveto
193 | 2 copy vpt 90 180 arc closepath fill
194 | vpt 0 360 arc closepath} bind def
195 | /C11 {BL [] 0 setdash 2 copy moveto
196 | 2 copy vpt 0 180 arc closepath fill
197 | 2 copy moveto
198 | 2 copy vpt 270 360 arc closepath fill
199 | vpt 0 360 arc closepath} bind def
200 | /C12 {BL [] 0 setdash 2 copy moveto
201 | 2 copy vpt 180 360 arc closepath fill
202 | vpt 0 360 arc closepath} bind def
203 | /C13 {BL [] 0 setdash 2 copy moveto
204 | 2 copy vpt 0 90 arc closepath fill
205 | 2 copy moveto
206 | 2 copy vpt 180 360 arc closepath fill
207 | vpt 0 360 arc closepath} bind def
208 | /C14 {BL [] 0 setdash 2 copy moveto
209 | 2 copy vpt 90 360 arc closepath fill
210 | vpt 0 360 arc} bind def
211 | /C15 {BL [] 0 setdash 2 copy vpt 0 360 arc closepath fill
212 | vpt 0 360 arc closepath} bind def
213 | /Rec {newpath 4 2 roll moveto 1 index 0 rlineto 0 exch rlineto
214 | neg 0 rlineto closepath} bind def
215 | /Square {dup Rec} bind def
216 | /Bsquare {vpt sub exch vpt sub exch vpt2 Square} bind def
217 | /S0 {BL [] 0 setdash 2 copy moveto 0 vpt rlineto BL Bsquare} bind def
218 | /S1 {BL [] 0 setdash 2 copy vpt Square fill Bsquare} bind def
219 | /S2 {BL [] 0 setdash 2 copy exch vpt sub exch vpt Square fill Bsquare} bind def
220 | /S3 {BL [] 0 setdash 2 copy exch vpt sub exch vpt2 vpt Rec fill Bsquare} bind def
221 | /S4 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt Square fill Bsquare} bind def
222 | /S5 {BL [] 0 setdash 2 copy 2 copy vpt Square fill
223 | exch vpt sub exch vpt sub vpt Square fill Bsquare} bind def
224 | /S6 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt vpt2 Rec fill Bsquare} bind def
225 | /S7 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt vpt2 Rec fill
226 | 2 copy vpt Square fill Bsquare} bind def
227 | /S8 {BL [] 0 setdash 2 copy vpt sub vpt Square fill Bsquare} bind def
228 | /S9 {BL [] 0 setdash 2 copy vpt sub vpt vpt2 Rec fill Bsquare} bind def
229 | /S10 {BL [] 0 setdash 2 copy vpt sub vpt Square fill 2 copy exch vpt sub exch vpt Square fill
230 | Bsquare} bind def
231 | /S11 {BL [] 0 setdash 2 copy vpt sub vpt Square fill 2 copy exch vpt sub exch vpt2 vpt Rec fill
232 | Bsquare} bind def
233 | /S12 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt2 vpt Rec fill Bsquare} bind def
234 | /S13 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt2 vpt Rec fill
235 | 2 copy vpt Square fill Bsquare} bind def
236 | /S14 {BL [] 0 setdash 2 copy exch vpt sub exch vpt sub vpt2 vpt Rec fill
237 | 2 copy exch vpt sub exch vpt Square fill Bsquare} bind def
238 | /S15 {BL [] 0 setdash 2 copy Bsquare fill Bsquare} bind def
239 | /D0 {gsave translate 45 rotate 0 0 S0 stroke grestore} bind def
240 | /D1 {gsave translate 45 rotate 0 0 S1 stroke grestore} bind def
241 | /D2 {gsave translate 45 rotate 0 0 S2 stroke grestore} bind def
242 | /D3 {gsave translate 45 rotate 0 0 S3 stroke grestore} bind def
243 | /D4 {gsave translate 45 rotate 0 0 S4 stroke grestore} bind def
244 | /D5 {gsave translate 45 rotate 0 0 S5 stroke grestore} bind def
245 | /D6 {gsave translate 45 rotate 0 0 S6 stroke grestore} bind def
246 | /D7 {gsave translate 45 rotate 0 0 S7 stroke grestore} bind def
247 | /D8 {gsave translate 45 rotate 0 0 S8 stroke grestore} bind def
248 | /D9 {gsave translate 45 rotate 0 0 S9 stroke grestore} bind def
249 | /D10 {gsave translate 45 rotate 0 0 S10 stroke grestore} bind def
250 | /D11 {gsave translate 45 rotate 0 0 S11 stroke grestore} bind def
251 | /D12 {gsave translate 45 rotate 0 0 S12 stroke grestore} bind def
252 | /D13 {gsave translate 45 rotate 0 0 S13 stroke grestore} bind def
253 | /D14 {gsave translate 45 rotate 0 0 S14 stroke grestore} bind def
254 | /D15 {gsave translate 45 rotate 0 0 S15 stroke grestore} bind def
255 | /DiaE {stroke [] 0 setdash vpt add M
256 | hpt neg vpt neg V hpt vpt neg V
257 | hpt vpt V hpt neg vpt V closepath stroke} def
258 | /BoxE {stroke [] 0 setdash exch hpt sub exch vpt add M
259 | 0 vpt2 neg V hpt2 0 V 0 vpt2 V
260 | hpt2 neg 0 V closepath stroke} def
261 | /TriUE {stroke [] 0 setdash vpt 1.12 mul add M
262 | hpt neg vpt -1.62 mul V
263 | hpt 2 mul 0 V
264 | hpt neg vpt 1.62 mul V closepath stroke} def
265 | /TriDE {stroke [] 0 setdash vpt 1.12 mul sub M
266 | hpt neg vpt 1.62 mul V
267 | hpt 2 mul 0 V
268 | hpt neg vpt -1.62 mul V closepath stroke} def
269 | /PentE {stroke [] 0 setdash gsave
270 | translate 0 hpt M 4 {72 rotate 0 hpt L} repeat
271 | closepath stroke grestore} def
272 | /CircE {stroke [] 0 setdash
273 | hpt 0 360 arc stroke} def
274 | /Opaque {gsave closepath 1 setgray fill grestore 0 setgray closepath} def
275 | /DiaW {stroke [] 0 setdash vpt add M
276 | hpt neg vpt neg V hpt vpt neg V
277 | hpt vpt V hpt neg vpt V Opaque stroke} def
278 | /BoxW {stroke [] 0 setdash exch hpt sub exch vpt add M
279 | 0 vpt2 neg V hpt2 0 V 0 vpt2 V
280 | hpt2 neg 0 V Opaque stroke} def
281 | /TriUW {stroke [] 0 setdash vpt 1.12 mul add M
282 | hpt neg vpt -1.62 mul V
283 | hpt 2 mul 0 V
284 | hpt neg vpt 1.62 mul V Opaque stroke} def
285 | /TriDW {stroke [] 0 setdash vpt 1.12 mul sub M
286 | hpt neg vpt 1.62 mul V
287 | hpt 2 mul 0 V
288 | hpt neg vpt -1.62 mul V Opaque stroke} def
289 | /PentW {stroke [] 0 setdash gsave
290 | translate 0 hpt M 4 {72 rotate 0 hpt L} repeat
291 | Opaque stroke grestore} def
292 | /CircW {stroke [] 0 setdash
293 | hpt 0 360 arc Opaque stroke} def
294 | /BoxFill {gsave Rec 1 setgray fill grestore} def
295 | /Density {
296 | /Fillden exch def
297 | currentrgbcolor
298 | /ColB exch def /ColG exch def /ColR exch def
299 | /ColR ColR Fillden mul Fillden sub 1 add def
300 | /ColG ColG Fillden mul Fillden sub 1 add def
301 | /ColB ColB Fillden mul Fillden sub 1 add def
302 | ColR ColG ColB setrgbcolor} def
303 | /BoxColFill {gsave Rec PolyFill} def
304 | /PolyFill {gsave Density fill grestore grestore} def
305 | /h {rlineto rlineto rlineto gsave closepath fill grestore} bind def
306 | %
307 | % PostScript Level 1 Pattern Fill routine for rectangles
308 | % Usage: x y w h s a XX PatternFill
309 | % x,y = lower left corner of box to be filled
310 | % w,h = width and height of box
311 | % a = angle in degrees between lines and x-axis
312 | % XX = 0/1 for no/yes cross-hatch
313 | %
314 | /PatternFill {gsave /PFa [ 9 2 roll ] def
315 | PFa 0 get PFa 2 get 2 div add PFa 1 get PFa 3 get 2 div add translate
316 | PFa 2 get -2 div PFa 3 get -2 div PFa 2 get PFa 3 get Rec
317 | TransparentPatterns {} {gsave 1 setgray fill grestore} ifelse
318 | clip
319 | currentlinewidth 0.5 mul setlinewidth
320 | /PFs PFa 2 get dup mul PFa 3 get dup mul add sqrt def
321 | 0 0 M PFa 5 get rotate PFs -2 div dup translate
322 | 0 1 PFs PFa 4 get div 1 add floor cvi
323 | {PFa 4 get mul 0 M 0 PFs V} for
324 | 0 PFa 6 get ne {
325 | 0 1 PFs PFa 4 get div 1 add floor cvi
326 | {PFa 4 get mul 0 2 1 roll M PFs 0 V} for
327 | } if
328 | stroke grestore} def
329 | %
330 | /languagelevel where
331 | {pop languagelevel} {1} ifelse
332 | 2 lt
333 | {/InterpretLevel1 true def}
334 | {/InterpretLevel1 Level1 def}
335 | ifelse
336 | %
337 | % PostScript level 2 pattern fill definitions
338 | %
339 | /Level2PatternFill {
340 | /Tile8x8 {/PaintType 2 /PatternType 1 /TilingType 1 /BBox [0 0 8 8] /XStep 8 /YStep 8}
341 | bind def
342 | /KeepColor {currentrgbcolor [/Pattern /DeviceRGB] setcolorspace} bind def
343 | << Tile8x8
344 | /PaintProc {0.5 setlinewidth pop 0 0 M 8 8 L 0 8 M 8 0 L stroke}
345 | >> matrix makepattern
346 | /Pat1 exch def
347 | << Tile8x8
348 | /PaintProc {0.5 setlinewidth pop 0 0 M 8 8 L 0 8 M 8 0 L stroke
349 | 0 4 M 4 8 L 8 4 L 4 0 L 0 4 L stroke}
350 | >> matrix makepattern
351 | /Pat2 exch def
352 | << Tile8x8
353 | /PaintProc {0.5 setlinewidth pop 0 0 M 0 8 L
354 | 8 8 L 8 0 L 0 0 L fill}
355 | >> matrix makepattern
356 | /Pat3 exch def
357 | << Tile8x8
358 | /PaintProc {0.5 setlinewidth pop -4 8 M 8 -4 L
359 | 0 12 M 12 0 L stroke}
360 | >> matrix makepattern
361 | /Pat4 exch def
362 | << Tile8x8
363 | /PaintProc {0.5 setlinewidth pop -4 0 M 8 12 L
364 | 0 -4 M 12 8 L stroke}
365 | >> matrix makepattern
366 | /Pat5 exch def
367 | << Tile8x8
368 | /PaintProc {0.5 setlinewidth pop -2 8 M 4 -4 L
369 | 0 12 M 8 -4 L 4 12 M 10 0 L stroke}
370 | >> matrix makepattern
371 | /Pat6 exch def
372 | << Tile8x8
373 | /PaintProc {0.5 setlinewidth pop -2 0 M 4 12 L
374 | 0 -4 M 8 12 L 4 -4 M 10 8 L stroke}
375 | >> matrix makepattern
376 | /Pat7 exch def
377 | << Tile8x8
378 | /PaintProc {0.5 setlinewidth pop 8 -2 M -4 4 L
379 | 12 0 M -4 8 L 12 4 M 0 10 L stroke}
380 | >> matrix makepattern
381 | /Pat8 exch def
382 | << Tile8x8
383 | /PaintProc {0.5 setlinewidth pop 0 -2 M 12 4 L
384 | -4 0 M 12 8 L -4 4 M 8 10 L stroke}
385 | >> matrix makepattern
386 | /Pat9 exch def
387 | /Pattern1 {PatternBgnd KeepColor Pat1 setpattern} bind def
388 | /Pattern2 {PatternBgnd KeepColor Pat2 setpattern} bind def
389 | /Pattern3 {PatternBgnd KeepColor Pat3 setpattern} bind def
390 | /Pattern4 {PatternBgnd KeepColor Landscape {Pat5} {Pat4} ifelse setpattern} bind def
391 | /Pattern5 {PatternBgnd KeepColor Landscape {Pat4} {Pat5} ifelse setpattern} bind def
392 | /Pattern6 {PatternBgnd KeepColor Landscape {Pat9} {Pat6} ifelse setpattern} bind def
393 | /Pattern7 {PatternBgnd KeepColor Landscape {Pat8} {Pat7} ifelse setpattern} bind def
394 | } def
395 | %
396 | %
397 | %End of PostScript Level 2 code
398 | %
399 | /PatternBgnd {
400 | TransparentPatterns {} {gsave 1 setgray fill grestore} ifelse
401 | } def
402 | %
403 | % Substitute for Level 2 pattern fill codes with
404 | % grayscale if Level 2 support is not selected.
405 | %
406 | /Level1PatternFill {
407 | /Pattern1 {0.250 Density} bind def
408 | /Pattern2 {0.500 Density} bind def
409 | /Pattern3 {0.750 Density} bind def
410 | /Pattern4 {0.125 Density} bind def
411 | /Pattern5 {0.375 Density} bind def
412 | /Pattern6 {0.625 Density} bind def
413 | /Pattern7 {0.875 Density} bind def
414 | } def
415 | %
416 | % Now test for support of Level 2 code
417 | %
418 | Level1 {Level1PatternFill} {Level2PatternFill} ifelse
419 | %
420 | /Symbol-Oblique /Symbol findfont [1 0 .167 1 0 0] makefont
421 | dup length dict begin {1 index /FID eq {pop pop} {def} ifelse} forall
422 | currentdict end definefont pop
423 | /MFshow {
424 | { dup 5 get 3 ge
425 | { 5 get 3 eq {gsave} {grestore} ifelse }
426 | {dup dup 0 get findfont exch 1 get scalefont setfont
427 | [ currentpoint ] exch dup 2 get 0 exch R dup 5 get 2 ne {dup dup 6
428 | get exch 4 get {Gshow} {stringwidth pop 0 R} ifelse }if dup 5 get 0 eq
429 | {dup 3 get {2 get neg 0 exch R pop} {pop aload pop M} ifelse} {dup 5
430 | get 1 eq {dup 2 get exch dup 3 get exch 6 get stringwidth pop -2 div
431 | dup 0 R} {dup 6 get stringwidth pop -2 div 0 R 6 get
432 | show 2 index {aload pop M neg 3 -1 roll neg R pop pop} {pop pop pop
433 | pop aload pop M} ifelse }ifelse }ifelse }
434 | ifelse }
435 | forall} def
436 | /Gswidth {dup type /stringtype eq {stringwidth} {pop (n) stringwidth} ifelse} def
437 | /MFwidth {0 exch { dup 5 get 3 ge { 5 get 3 eq { 0 } { pop } ifelse }
438 | {dup 3 get{dup dup 0 get findfont exch 1 get scalefont setfont
439 | 6 get Gswidth pop add} {pop} ifelse} ifelse} forall} def
440 | /MLshow { currentpoint stroke M
441 | 0 exch R
442 | Blacktext {gsave 0 setgray MFshow grestore} {MFshow} ifelse } bind def
443 | /MRshow { currentpoint stroke M
444 | exch dup MFwidth neg 3 -1 roll R
445 | Blacktext {gsave 0 setgray MFshow grestore} {MFshow} ifelse } bind def
446 | /MCshow { currentpoint stroke M
447 | exch dup MFwidth -2 div 3 -1 roll R
448 | Blacktext {gsave 0 setgray MFshow grestore} {MFshow} ifelse } bind def
449 | /XYsave { [( ) 1 2 true false 3 ()] } bind def
450 | /XYrestore { [( ) 1 2 true false 4 ()] } bind def
451 | Level1 SuppressPDFMark or
452 | {} {
453 | /SDict 10 dict def
454 | systemdict /pdfmark known not {
455 | userdict /pdfmark systemdict /cleartomark get put
456 | } if
457 | SDict begin [
458 | /Title (q5a.eps)
459 | /Subject (gnuplot plot)
460 | /Creator (gnuplot 4.6 patchlevel 4)
461 | /Author (vitchyr)
462 | % /Producer (gnuplot)
463 | % /Keywords ()
464 | /CreationDate (Thu Jan 21 00:24:42 2016)
465 | /DOCINFO pdfmark
466 | end
467 | } ifelse
468 | end
469 | %%EndProlog
470 | %%Page: 1 1
471 | gnudict begin
472 | gsave
473 | doclip
474 | 50 50 translate
475 | 0.050 0.050 scale
476 | 0 setgray
477 | newpath
478 | (Helvetica) findfont 140 scalefont setfont
479 | BackgroundColor 0 lt 3 1 roll 0 lt exch 0 lt or or not {BackgroundColor C 1.000 0 0 7200.00 5040.00 BoxColFill} if
480 | 1.000 UL
481 | LTb
482 | 0.13 0.13 0.13 C 1.000 UL
483 | LTa
484 | LCa setrgbcolor
485 | 602 448 M
486 | 6345 0 V
487 | stroke
488 | LTb
489 | 0.13 0.13 0.13 C 602 448 M
490 | 63 0 V
491 | 6282 0 R
492 | -63 0 V
493 | stroke
494 | 518 448 M
495 | [ [(Helvetica) 140.0 0.0 true true 0 ( 28)]
496 | ] -46.7 MRshow
497 | 1.000 UL
498 | LTb
499 | 0.13 0.13 0.13 C 1.000 UL
500 | LTa
501 | LCa setrgbcolor
502 | 602 969 M
503 | 6345 0 V
504 | stroke
505 | LTb
506 | 0.13 0.13 0.13 C 602 969 M
507 | 63 0 V
508 | 6282 0 R
509 | -63 0 V
510 | stroke
511 | 518 969 M
512 | [ [(Helvetica) 140.0 0.0 true true 0 ( 30)]
513 | ] -46.7 MRshow
514 | 1.000 UL
515 | LTb
516 | 0.13 0.13 0.13 C 1.000 UL
517 | LTa
518 | LCa setrgbcolor
519 | 602 1491 M
520 | 6345 0 V
521 | stroke
522 | LTb
523 | 0.13 0.13 0.13 C 602 1491 M
524 | 63 0 V
525 | 6282 0 R
526 | -63 0 V
527 | stroke
528 | 518 1491 M
529 | [ [(Helvetica) 140.0 0.0 true true 0 ( 32)]
530 | ] -46.7 MRshow
531 | 1.000 UL
532 | LTb
533 | 0.13 0.13 0.13 C 1.000 UL
534 | LTa
535 | LCa setrgbcolor
536 | 602 2012 M
537 | 6345 0 V
538 | stroke
539 | LTb
540 | 0.13 0.13 0.13 C 602 2012 M
541 | 63 0 V
542 | 6282 0 R
543 | -63 0 V
544 | stroke
545 | 518 2012 M
546 | [ [(Helvetica) 140.0 0.0 true true 0 ( 34)]
547 | ] -46.7 MRshow
548 | 1.000 UL
549 | LTb
550 | 0.13 0.13 0.13 C 1.000 UL
551 | LTa
552 | LCa setrgbcolor
553 | 602 2534 M
554 | 6345 0 V
555 | stroke
556 | LTb
557 | 0.13 0.13 0.13 C 602 2534 M
558 | 63 0 V
559 | 6282 0 R
560 | -63 0 V
561 | stroke
562 | 518 2534 M
563 | [ [(Helvetica) 140.0 0.0 true true 0 ( 36)]
564 | ] -46.7 MRshow
565 | 1.000 UL
566 | LTb
567 | 0.13 0.13 0.13 C 1.000 UL
568 | LTa
569 | LCa setrgbcolor
570 | 602 3055 M
571 | 6345 0 V
572 | stroke
573 | LTb
574 | 0.13 0.13 0.13 C 602 3055 M
575 | 63 0 V
576 | 6282 0 R
577 | -63 0 V
578 | stroke
579 | 518 3055 M
580 | [ [(Helvetica) 140.0 0.0 true true 0 ( 38)]
581 | ] -46.7 MRshow
582 | 1.000 UL
583 | LTb
584 | 0.13 0.13 0.13 C 1.000 UL
585 | LTa
586 | LCa setrgbcolor
587 | 602 3576 M
588 | 6345 0 V
589 | stroke
590 | LTb
591 | 0.13 0.13 0.13 C 602 3576 M
592 | 63 0 V
593 | 6282 0 R
594 | -63 0 V
595 | stroke
596 | 518 3576 M
597 | [ [(Helvetica) 140.0 0.0 true true 0 ( 40)]
598 | ] -46.7 MRshow
599 | 1.000 UL
600 | LTb
601 | 0.13 0.13 0.13 C 1.000 UL
602 | LTa
603 | LCa setrgbcolor
604 | 602 4098 M
605 | 6345 0 V
606 | stroke
607 | LTb
608 | 0.13 0.13 0.13 C 602 4098 M
609 | 63 0 V
610 | 6282 0 R
611 | -63 0 V
612 | stroke
613 | 518 4098 M
614 | [ [(Helvetica) 140.0 0.0 true true 0 ( 42)]
615 | ] -46.7 MRshow
616 | 1.000 UL
617 | LTb
618 | 0.13 0.13 0.13 C 1.000 UL
619 | LTa
620 | LCa setrgbcolor
621 | 602 4619 M
622 | 6345 0 V
623 | stroke
624 | LTb
625 | 0.13 0.13 0.13 C 602 4619 M
626 | 63 0 V
627 | 6282 0 R
628 | -63 0 V
629 | stroke
630 | 518 4619 M
631 | [ [(Helvetica) 140.0 0.0 true true 0 ( 44)]
632 | ] -46.7 MRshow
633 | 1.000 UL
634 | LTb
635 | 0.13 0.13 0.13 C 1.000 UL
636 | LTa
637 | LCa setrgbcolor
638 | 602 448 M
639 | 0 4171 V
640 | stroke
641 | LTb
642 | 0.13 0.13 0.13 C 602 448 M
643 | 0 63 V
644 | 0 4108 R
645 | 0 -63 V
646 | stroke
647 | 602 308 M
648 | [ [(Helvetica) 140.0 0.0 true true 0 ( 0)]
649 | ] -46.7 MCshow
650 | 1.000 UL
651 | LTb
652 | 0.13 0.13 0.13 C 1.000 UL
653 | LTa
654 | LCa setrgbcolor
655 | 1871 448 M
656 | 0 4171 V
657 | stroke
658 | LTb
659 | 0.13 0.13 0.13 C 1871 448 M
660 | 0 63 V
661 | 0 4108 R
662 | 0 -63 V
663 | stroke
664 | 1871 308 M
665 | [ [(Helvetica) 140.0 0.0 true true 0 ( 0.2)]
666 | ] -46.7 MCshow
667 | 1.000 UL
668 | LTb
669 | 0.13 0.13 0.13 C 1.000 UL
670 | LTa
671 | LCa setrgbcolor
672 | 3140 448 M
673 | 0 4171 V
674 | stroke
675 | LTb
676 | 0.13 0.13 0.13 C 3140 448 M
677 | 0 63 V
678 | 0 4108 R
679 | 0 -63 V
680 | stroke
681 | 3140 308 M
682 | [ [(Helvetica) 140.0 0.0 true true 0 ( 0.4)]
683 | ] -46.7 MCshow
684 | 1.000 UL
685 | LTb
686 | 0.13 0.13 0.13 C 1.000 UL
687 | LTa
688 | LCa setrgbcolor
689 | 4409 448 M
690 | 0 4171 V
691 | stroke
692 | LTb
693 | 0.13 0.13 0.13 C 4409 448 M
694 | 0 63 V
695 | 0 4108 R
696 | 0 -63 V
697 | stroke
698 | 4409 308 M
699 | [ [(Helvetica) 140.0 0.0 true true 0 ( 0.6)]
700 | ] -46.7 MCshow
701 | 1.000 UL
702 | LTb
703 | 0.13 0.13 0.13 C 1.000 UL
704 | LTa
705 | LCa setrgbcolor
706 | 5678 448 M
707 | 0 4171 V
708 | stroke
709 | LTb
710 | 0.13 0.13 0.13 C 5678 448 M
711 | 0 63 V
712 | 0 4108 R
713 | 0 -63 V
714 | stroke
715 | 5678 308 M
716 | [ [(Helvetica) 140.0 0.0 true true 0 ( 0.8)]
717 | ] -46.7 MCshow
718 | 1.000 UL
719 | LTb
720 | 0.13 0.13 0.13 C 1.000 UL
721 | LTa
722 | LCa setrgbcolor
723 | 6947 448 M
724 | 0 4171 V
725 | stroke
726 | LTb
727 | 0.13 0.13 0.13 C 6947 448 M
728 | 0 63 V
729 | 0 4108 R
730 | 0 -63 V
731 | stroke
732 | 6947 308 M
733 | [ [(Helvetica) 140.0 0.0 true true 0 ( 1)]
734 | ] -46.7 MCshow
735 | 1.000 UL
736 | LTb
737 | 0.13 0.13 0.13 C 1.000 UL
738 | LTb
739 | 0.13 0.13 0.13 C 602 4619 N
740 | 602 448 L
741 | 6345 0 V
742 | 0 4171 V
743 | -6345 0 V
744 | Z stroke
745 | LCb setrgbcolor
746 | 112 2533 M
747 | currentpoint gsave translate -270 rotate 0 0 moveto
748 | [ [(Helvetica) 140.0 0.0 true true 0 (RMS between Q-MC and Q-SARSA)]
749 | ] -46.7 MCshow
750 | grestore
751 | LTb
752 | LCb setrgbcolor
753 | 3774 98 M
754 | [ [(Helvetica) 140.0 0.0 true true 0 (lambda)]
755 | ] -46.7 MCshow
756 | LTb
757 | 3774 4829 M
758 | [ [(Helvetica) 140.0 0.0 true true 0 (Q RMS after 1000 episodes vs lambda)]
759 | ] -46.7 MCshow
760 | 1.000 UP
761 | 1.000 UL
762 | LTb
763 | 0.13 0.13 0.13 C % Begin plot #1
764 | 1.000 UP
765 | 2.000 UL
766 | LT0
767 | 0.11 0.27 0.60 C 602 910 M
768 | 635 2723 V
769 | 1871 1339 L
770 | 2506 517 L
771 | 634 3912 V
772 | 3775 1021 L
773 | 4409 813 L
774 | 635 526 V
775 | 634 780 V
776 | 6313 934 L
777 | 634 35 V
778 | 602 910 CircleF
779 | 1237 3633 CircleF
780 | 1871 1339 CircleF
781 | 2506 517 CircleF
782 | 3140 4429 CircleF
783 | 3775 1021 CircleF
784 | 4409 813 CircleF
785 | 5044 1339 CircleF
786 | 5678 2119 CircleF
787 | 6313 934 CircleF
788 | 6947 969 CircleF
789 | % End plot #1
790 | 1.000 UL
791 | LTb
792 | 0.13 0.13 0.13 C 602 4619 N
793 | 602 448 L
794 | 6345 0 V
795 | 0 4171 V
796 | -6345 0 V
797 | Z stroke
798 | 1.000 UP
799 | 1.000 UL
800 | LTb
801 | 0.13 0.13 0.13 C stroke
802 | grestore
803 | end
804 | showpage
805 | %%Trailer
806 | %%DocumentFonts: Helvetica
807 |
--------------------------------------------------------------------------------