├── .gitignore ├── images ├── restaurant.png ├── bayesnetwork.png └── sprinklernet.jpg ├── .gitmodules ├── .travis.yml ├── aimajulia.jl ├── LICENSE ├── tests ├── run_travis_tests.sh ├── run_game_tests.jl ├── run_mdp_tests.jl ├── run_util_tests.jl ├── run_agent_tests.jl ├── non_deterministic_astar.jl ├── run_rl_tests.jl ├── run_search_tests.jl ├── run_csp_tests.jl ├── run_planning_tests.jl ├── run_nlp_tests.jl ├── run_learning_tests.jl ├── run_probability_tests.jl ├── run_text_tests.jl ├── run_logic_tests.jl └── run_kl_tests.jl ├── nlp_apps.ipynb ├── CONTRIBUTING.md ├── README.md ├── planning.ipynb ├── mdp.jl ├── rl.jl ├── games.jl └── kl.jl /.gitignore: -------------------------------------------------------------------------------- 1 | # IJulia Notebook 2 | .ipynb_checkpoints -------------------------------------------------------------------------------- /images/restaurant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/aima-julia/master/images/restaurant.png -------------------------------------------------------------------------------- /images/bayesnetwork.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/aima-julia/master/images/bayesnetwork.png -------------------------------------------------------------------------------- /images/sprinklernet.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/findmyway/aima-julia/master/images/sprinklernet.jpg -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "aima-data"] 2 | path = aima-data 3 | url = https://github.com/aimacode/aima-data 4 | -------------------------------------------------------------------------------- /.travis.yml: 
-------------------------------------------------------------------------------- 1 | language: julia 2 | sudo: false 3 | 4 | before_install: 5 | - git submodule update --init 6 | 7 | os: 8 | - linux 9 | - osx 10 | 11 | julia: 12 | - 0.6 13 | 14 | script: 15 | - "travis_wait 20 sleep 1200 &" 16 | - sh tests/run_travis_tests.sh 17 | -------------------------------------------------------------------------------- /aimajulia.jl: -------------------------------------------------------------------------------- 1 | module aimajulia; 2 | 3 | include("utils.jl"); 4 | 5 | using aimajulia.utils; 6 | 7 | AIMAJULIA_DIRECTORY = Base.source_dir(); 8 | 9 | include("logic.jl"); 10 | 11 | include("agents.jl"); 12 | 13 | include("search.jl"); 14 | 15 | include("games.jl"); 16 | 17 | include("csp.jl"); 18 | 19 | include("planning.jl"); 20 | 21 | include("probability.jl"); 22 | 23 | include("mdp.jl"); 24 | 25 | include("learning.jl"); 26 | 27 | include("kl.jl"); 28 | 29 | include("rl.jl"); 30 | 31 | include("nlp.jl"); 32 | 33 | include("text.jl"); 34 | 35 | end; -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 aima-julia contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tests/run_travis_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | 5 | echo "Changing directory to: $(dirname $0)" 6 | 7 | cd $(dirname $0) 8 | 9 | echo "TRAVIS_PULL_REQUEST: $TRAVIS_PULL_REQUEST" 10 | echo "TRAVIS_PULL_REQUEST_SHA: $TRAVIS_PULL_REQUEST_SHA" 11 | 12 | echo "$" "ulimit -a" 13 | 14 | ulimit -a 15 | 16 | echo 17 | 18 | git clone https://github.com/aimacode/aima-data 19 | 20 | echo 21 | 22 | julia -e "versioninfo();" 23 | 24 | echo 25 | 26 | #Some of the testv() doctests in agents.py can sometimes fail when the 27 | #scores are out of expected bounds. 
28 | 29 | julia --color=yes run_agent_tests.jl 30 | 31 | julia --color=yes run_search_tests.jl 32 | 33 | julia --color=yes run_util_tests.jl 34 | 35 | julia --color=yes run_game_tests.jl 36 | 37 | julia --color=yes run_csp_tests.jl 38 | 39 | julia --color=yes run_logic_tests.jl 40 | 41 | julia --color=yes run_planning_tests.jl 42 | 43 | julia --color=yes run_probability_tests.jl 44 | 45 | julia --color=yes run_mdp_tests.jl 46 | 47 | julia --color=yes run_learning_tests.jl 48 | 49 | julia --color=yes run_kl_tests.jl 50 | 51 | julia --color=yes run_rl_tests.jl 52 | 53 | julia --color=yes run_nlp_tests.jl 54 | 55 | julia --color=yes run_text_tests.jl 56 | 57 | -------------------------------------------------------------------------------- /tests/run_game_tests.jl: -------------------------------------------------------------------------------- 1 | include("../aimajulia.jl"); 2 | 3 | using Base.Test; 4 | 5 | using aimajulia; 6 | 7 | #The following game tests are from the aima-python doctest 8 | 9 | @test minimax_decision("A", Figure52Game()) == "A1"; 10 | 11 | @test alphabeta_full_search("A", Figure52Game()) == "A1"; 12 | 13 | @test alphabeta_search("A", Figure52Game()) == "A1"; 14 | 15 | @test play_game(Figure52Game(), alphabeta_player, alphabeta_player) == 3; 16 | 17 | #= 18 | 19 | The following tests may fail sometimes because the tests run on random behavior. 20 | 21 | However, the results of tests that fail does not imply something is wrong. 
22 | 23 | =# 24 | 25 | function colorize_testv_doctest_results(result::Bool) 26 | if (result) 27 | print_with_color(:green, "Test Passed\n"); 28 | else 29 | print_with_color(:red, "Test Failed\n"); 30 | end 31 | end 32 | 33 | randf52_result = play_game(Figure52Game(), random_player, random_player); 34 | colorize_testv_doctest_results(randf52_result == 6); 35 | println("Expression: play_game(Figure52Game(), random_player, random_player) == 6"); 36 | println("Evaluated: ", randf52_result, " == 6"); 37 | 38 | randttt_result = play_game(TicTacToeGame(), random_player, random_player); 39 | colorize_testv_doctest_results(randttt_result == 0); 40 | println("Expression: play_game(TicTacToeGame(), random_player, random_player) == 0"); 41 | println("Evaluated: ", randttt_result, " == 0"); 42 | -------------------------------------------------------------------------------- /tests/run_mdp_tests.jl: -------------------------------------------------------------------------------- 1 | include("../aimajulia.jl"); 2 | 3 | using Base.Test; 4 | 5 | using aimajulia; 6 | 7 | #The following mdp tests are from the aima-python doctests 8 | 9 | tm = Dict([Pair("A", Dict([Pair("a1", (0.3, "B")), Pair("a2", (0.7, "C"))])), 10 | Pair("B", Dict([Pair("a1", (0.5, "B")), Pair("a2", (0.5, "A"))])), 11 | Pair("C", Dict([Pair("a1", (0.9, "A")), Pair("a2", (0.1, "B"))]))]); 12 | 13 | mdp = MarkovDecisionProcess("A", Set(["a1", "a2"]), Set(["C"]), tm, states=Set(["A","B","C"])); 14 | 15 | @test (transition_model(mdp, "A", "a1") == (0.3, "B")); 16 | 17 | @test (transition_model(mdp, "B", "a2") == (0.5, "A")); 18 | 19 | @test (transition_model(mdp, "C", "a1") == (0.9, "A")); 20 | 21 | @test (repr(value_iteration(aimajulia.sequential_decision_environment, epsilon=0.01)) == 22 | "Dict((2, 3)=>0.486437,(2, 1)=>0.398102,(3, 1)=>0.509285,(1, 4)=>0.129589,(3, 3)=>0.795361,(1, 3)=>0.344613,(3, 2)=>0.649581,(2, 4)=>-1.0,(1, 1)=>0.295435,(1, 2)=>0.253487,(3, 4)=>1.0)"); 23 | 24 | pi = 
optimal_policy(aimajulia.sequential_decision_environment, value_iteration(aimajulia.sequential_decision_environment, epsilon=0.01)); 25 | 26 | @test (repr(pi) == "Dict{Any,Any}(Pair{Any,Any}((2, 3), (1, 0)),Pair{Any,Any}((2, 1), (1, 0)),Pair{Any,Any}((3, 1), (0, 1)),Pair{Any,Any}((1, 4), (0, -1)),Pair{Any,Any}((3, 3), (0, 1)),Pair{Any,Any}((1, 3), (1, 0)),Pair{Any,Any}((3, 2), (0, 1)),Pair{Any,Any}((2, 4), nothing),Pair{Any,Any}((1, 1), (1, 0)),Pair{Any,Any}((1, 2), (0, 1)),Pair{Any,Any}((3, 4), nothing))"); 27 | 28 | @test (repr(to_arrows(aimajulia.sequential_decision_environment, pi)) == "Nullable{String}[\"v\" \">\" \"v\" \"<\"; \"v\" #NULL \"v\" \".\"; \">\" \">\" \">\" \".\"]"); 29 | 30 | @test (policy_iteration(aimajulia.sequential_decision_environment) == 31 | Dict([Pair((2,3),(1,0)), 32 | Pair((2,1),(1,0)), 33 | Pair((3,1),(0,1)), 34 | Pair((1,4),(0,-1)), 35 | Pair((3,3),(0,1)), 36 | Pair((1,3),(1,0)), 37 | Pair((3,2),(0,1)), 38 | Pair((2,4),nothing), 39 | Pair((1,1),(1,0)), 40 | Pair((1,2),(0,1)), 41 | Pair((3,4),nothing)])); 42 | 43 | -------------------------------------------------------------------------------- /tests/run_util_tests.jl: -------------------------------------------------------------------------------- 1 | include("../aimajulia.jl"); 2 | 3 | using Base.Test; 4 | 5 | using aimajulia; 6 | 7 | using aimajulia.utils; 8 | 9 | # The following util functions' tests are from the aima-python utils.py doctest 10 | 11 | na = [1, 8, 2, 7, 5, 6, -99, 99, 4, 3, 0]; 12 | 13 | function qtest(qf::DataType; order::Union{Bool, Base.Order.Ordering}=false, f::Union{Void, Function, MemoizedFunction}=nothing) 14 | if (!(qf <: PQueue)) 15 | q = qf(); 16 | extend!(q, na); 17 | for num in na 18 | @test num in q; 19 | end 20 | @test !(42 in q); 21 | return [pop!(q) for i in range(0, length(q))]; 22 | else 23 | if (order == false) 24 | q = qf(); 25 | else 26 | q = qf(order=order); 27 | end 28 | if (!(typeof(f) <: Void)) 29 | extend!(q, na, f); 30 | else 31 | 
extend!(q, na, (function(item) return item; end)); 32 | end 33 | for num in na 34 | @test num in [getindex(x, 2) for x in collect(q)]; 35 | end 36 | @test !(42 in [getindex(x, 2) for x in collect(q)]); 37 | return [pop!(q) for i in range(0, length(q))]; 38 | end 39 | end 40 | 41 | @test qtest(Stack) == [0, 3, 4, 99, -99, 6, 5, 7, 2, 8, 1]; 42 | 43 | @test qtest(FIFOQueue) == [1, 8, 2, 7, 5, 6, -99, 99, 4, 3, 0]; 44 | 45 | @test qtest(PQueue) == [-99, 0, 1, 2, 3, 4, 5, 6, 7, 8, 99]; 46 | 47 | @test qtest(PQueue, order=Base.Order.Reverse) == [99, 8, 7, 6, 5, 4, 3, 2, 1, 0, -99]; 48 | 49 | @test qtest(PQueue, f=abs) == [0, 1, 2, 3, 4, 5, 6, 7, 8, -99, 99]; 50 | 51 | @test qtest(PQueue, order=Base.Order.Reverse, f=abs) == [99, -99, 8, 7, 6, 5, 4, 3, 2, 1, 0]; 52 | 53 | mabs = MemoizedFunction(abs); #memoize abs() 54 | 55 | @test qtest(PQueue, f=mabs) == [0, 1, 2, 3, 4, 5, 6, 7, 8, -99, 99]; 56 | 57 | @test qtest(PQueue, order=Base.Order.Reverse, f=mabs) == [99, -99, 8, 7, 6, 5, 4, 3, 2, 1, 0]; 58 | 59 | @test weighted_sample_with_replacement([], [], 0) == []; 60 | 61 | @test weighted_sample_with_replacement("a", [3], 2) == ['a', 'a']; 62 | 63 | @test weighted_sample_with_replacement("ab", [0, 3], 3) == ['b', 'b', 'b']; 64 | 65 | @test count(isfunction, [42, nothing, max, min]) == 2; 66 | 67 | @test findfirst(isfunction, [3, min, max]) == 2; 68 | 69 | @test findfirst(isfunction, [1, 2, 3]) == 0; 70 | 71 | @test normalize_probability_distribution([1, 2, 1]) == [0.25, 0.5, 0.25]; 72 | 73 | -------------------------------------------------------------------------------- /tests/run_agent_tests.jl: -------------------------------------------------------------------------------- 1 | include("../aimajulia.jl"); 2 | 3 | using Base.Test; 4 | 5 | using aimajulia; 6 | 7 | #The following Agent tests are from the aima-python doctest 8 | 9 | RVA = ReflexVacuumAgent(); 10 | 11 | @test execute(RVA.program, (aimajulia.loc_A, "Clean")) == "Right"; 12 | 13 | @test execute(RVA.program, 
(aimajulia.loc_B, "Clean")) == "Left"; 14 | 15 | @test execute(RVA.program, (aimajulia.loc_A, "Dirty")) == "Suck"; 16 | 17 | @test execute(RVA.program, (aimajulia.loc_B, "Dirty")) == "Suck"; 18 | 19 | TVE = TrivialVacuumEnvironment(); 20 | 21 | @test add_object(TVE, ModelBasedVacuumAgent()) == nothing; 22 | 23 | @test run(TVE, steps=5) == nothing; 24 | 25 | #= 26 | 27 | The following tests may fail sometimes because the tests check for the expected bounds. 28 | 29 | However, the results of tests that lie outside of expected bounds does not imply something is wrong. 30 | 31 | =# 32 | 33 | function colorize_testv_doctest_results(result::Bool) 34 | if (result) 35 | print_with_color(:green, "Test Passed\n"); 36 | else 37 | print_with_color(:red, "Test Failed\n"); 38 | end 39 | end 40 | 41 | envs = [TrivialVacuumEnvironment() for i in range(0, 100)]; 42 | 43 | mbva_result = test_agent(ModelBasedVacuumAgent, 4, deepcopy(envs)); 44 | colorize_testv_doctest_results(7 < mbva_result < 11); 45 | println("Expression: 7 < test_agent(ModelBasedVacuumAgent, 4, deepcopy(envs)) < 11"); 46 | println("Evaluated: 7 < ", mbva_result, " < 11"); 47 | 48 | refva_result = test_agent(ReflexVacuumAgent, 4, deepcopy(envs)); 49 | colorize_testv_doctest_results(5 < refva_result < 9); 50 | println("Expression: 5 < test_agent(ReflexVacuumAgent, 4, deepcopy(envs)) < 9"); 51 | println("Evaluated: 5 < ", refva_result, " < 9"); 52 | 53 | tdva_result = test_agent(TableDrivenVacuumAgent, 4, deepcopy(envs)); 54 | colorize_testv_doctest_results(2 < tdva_result < 6); 55 | println("Expression: 2 < test_agent(TableDrivenVacuumAgent, 4, deepcopy(envs)) < 6"); 56 | println("Evaluated: 2 < ", tdva_result, " < 6"); 57 | 58 | randva_result = test_agent(RandomVacuumAgent, 4, deepcopy(envs)); 59 | colorize_testv_doctest_results(0.5 < randva_result < 3); 60 | println("Expression: 0.5 < test_agent(RandomVacuumAgent, 4, deepcopy(envs)) < 3"); 61 | println("Evaluated: 0.5 < ", randva_result, " < 3"); 62 | 
-------------------------------------------------------------------------------- /tests/non_deterministic_astar.jl: -------------------------------------------------------------------------------- 1 | include("../aimajulia.jl"); 2 | using aimajulia; 3 | using aimajulia.utils; 4 | 5 | function beautify_node_args(n) 6 | if (typeof(n) <: Node) 7 | return beautify_node(n); 8 | elseif (typeof(n) <: Tuple) 9 | return map(beautify_node_args, n); 10 | else 11 | return n; 12 | end 13 | end 14 | 15 | function bfgs{T <: AbstractProblem}(problem::T, f::Function) 16 | mf = MemoizedFunction(f); 17 | local node = Node{typeof(problem.initial)}(problem.initial); 18 | if (goal_test(problem, node.state)) 19 | return node; 20 | end 21 | local frontier = PQueue(); 22 | push!(frontier, node, mf); 23 | local explored = Set{String}(); 24 | while (length(frontier) != 0) 25 | node = pop!(frontier); 26 | if (goal_test(problem, node.state)) 27 | return node; 28 | end 29 | push!(explored, node.state); 30 | for child_node in expand(node, problem) 31 | if (!(child_node.state in explored) && 32 | !(child_node in collect(getindex(x, 2) for x in frontier.array))) 33 | push!(frontier, child_node, mf); 34 | elseif (child_node in [getindex(x, 2) for x in frontier.array]) 35 | #Recall that Nodes can share the same state and different values for other fields. 
36 | local existing_node = pop!(collect(getindex(x, 2) 37 | for x in frontier.array 38 | if (getindex(x, 2) == child_node))); 39 | 40 | eval_memoized_function(mf, child_node); 41 | eval_memoized_function(mf, existing_node); 42 | 43 | if (eval_memoized_function(mf, child_node) < eval_memoized_function(mf, existing_node)) 44 | delete!(frontier, existing_node); 45 | push!(frontier, child_node, mf); 46 | end 47 | end 48 | end 49 | print("length of memoization dictionary: ", length(mf.values), " "); 50 | println(map(beautify_node_args, collect(keys(mf.values)))...); 51 | end 52 | return nothing; 53 | end 54 | 55 | function astar_memoized_search(problem::GraphProblem; h::Union{Void, Function}=nothing) 56 | local mh::MemoizedFunction; #memoized h(n) function 57 | if (!(typeof(h) <: Void)) 58 | mh = h; 59 | else 60 | mh = problem.h; 61 | end 62 | return bfgs(problem, 63 | (function(node::Node; h::MemoizedFunction=mh, prob::GraphProblem=problem) 64 | return node.path_cost + eval_memoized_function(h, prob, node);end)); 65 | end 66 | 67 | astar_str="";for i in 1:5 68 | if (Node{String}("P")==get(astar_memoized_search(GraphProblem("A", "B", aimajulia.romania)).parent)) 69 | astar_str = astar_str * "1"; 70 | println("Test: Passed!"); 71 | else 72 | astar_str = astar_str * "0"; 73 | println("Test: Failed!"); 74 | end 75 | end;println(count(i->(i=='1'), astar_str), " of 5 tries passed!"); 76 | -------------------------------------------------------------------------------- /tests/run_rl_tests.jl: -------------------------------------------------------------------------------- 1 | include("../aimajulia.jl"); 2 | 3 | using Base.Test; 4 | 5 | using aimajulia; 6 | 7 | using aimajulia.utils; 8 | 9 | #The following reinforcement learning tests are from the aima-python doctests 10 | 11 | north, south, west, east = (1, 0), (-1, 0), (0, -1), (0, 1); 12 | 13 | policy = Dict([Pair((1, 1), north), 14 | Pair((1, 2), west), 15 | Pair((1, 3), west), 16 | Pair((1, 4), west), 17 | Pair((2, 1), 
north), 18 | Pair((2, 3), north), 19 | Pair((2, 4), nothing), 20 | Pair((3, 1), east), 21 | Pair((3, 2), east), 22 | Pair((3, 3), east), 23 | Pair((3, 4), nothing)]) 24 | 25 | passive_adp_agent = PassiveADPAgentProgram(policy, aimajulia.sequential_decision_environment); 26 | for i in 1:75 27 | aimajulia.run_single_trial(passive_adp_agent, aimajulia.sequential_decision_environment); 28 | end 29 | 30 | @test (passive_adp_agent.U[(1, 1)] > 0.15); 31 | println("passive_adp_agent.U[(1, 1)] (expected ~0.3): ", passive_adp_agent.U[(1, 1)]); 32 | 33 | @test (passive_adp_agent.U[(2, 1)] > 0.15); 34 | println("passive_adp_agent.U[(2, 1)] (expected ~0.4): ", passive_adp_agent.U[(2, 1)]); 35 | 36 | @test (passive_adp_agent.U[(1, 2)] > 0); 37 | println("passive_adp_agent.U[(1, 2)] (expected ~0.2): ", passive_adp_agent.U[(1, 2)]); 38 | 39 | passive_td_agent = PassiveTDAgentProgram(policy, 40 | aimajulia.sequential_decision_environment, 41 | alpha=(function(n::Number) 42 | return (60/(59+n)); 43 | end)); 44 | 45 | for i in 1:200 46 | aimajulia.run_single_trial(passive_td_agent, aimajulia.sequential_decision_environment); 47 | end 48 | 49 | @test (passive_td_agent.U[(1, 1)] > 0.15); 50 | println("passive_td_agent.U[(1, 1)] (expected ~0.3): ", passive_td_agent.U[(1, 1)]); 51 | 52 | @test (passive_td_agent.U[(2, 1)] > 0.15); 53 | println("passive_td_agent.U[(2, 1)] (expected ~0.35): ", passive_td_agent.U[(2, 1)]); 54 | 55 | @test (passive_td_agent.U[(1, 2)] > 0.13); 56 | println("passive_td_agent.U[(1, 2)] (expected ~0.25): ", passive_td_agent.U[(1, 2)]); 57 | 58 | qlearning_agent = QLearningAgentProgram(aimajulia.sequential_decision_environment, 59 | 5, 60 | 2, 61 | alpha=(function(n::Number) 62 | return (60/(59 + n)); 63 | end)); 64 | 65 | for i in 1:200 66 | aimajulia.run_single_trial(qlearning_agent, aimajulia.sequential_decision_environment); 67 | end 68 | 69 | @test (qlearning_agent.Q[((2, 1), (1, 0))]>= -0.5); 70 | println("qlearning_agent.Q[((2, 1), (1, 0))] expected (0.1): 
", qlearning_agent.Q[((2, 1), (1, 0))]); 71 | 72 | @test (qlearning_agent.Q[((1, 2), (-1, 0))] <= 0.5); 73 | println("qlearning_agent.Q[((1, 2), (-1, 0))] expected (-0.1): ", qlearning_agent.Q[((1, 2), (-1, 0))]); 74 | 75 | -------------------------------------------------------------------------------- /tests/run_search_tests.jl: -------------------------------------------------------------------------------- 1 | include("../aimajulia.jl"); 2 | 3 | using Base.Test; 4 | 5 | using aimajulia; 6 | 7 | #The following search tests are from the aima-python doctest 8 | 9 | @test depth_first_tree_search(NQueensProblem(8)) == Node{Array{Int64, 1}}([8, 4, 1, 3, 6, 2, 7, 5]); 10 | 11 | #Specify wordlist path for travis-ci testing. 12 | filename = "./aima-data/EN-text/wordlist.txt"; 13 | 14 | @test length(BoggleFinder(board=collect("SARTELNID"), fn=filename)) == 206; 15 | 16 | ab = GraphProblem("A", "B", aimajulia.romania); 17 | 18 | @test solution(breadth_first_tree_search(ab)) == ["S", "F", "B"]; 19 | 20 | @test solution(breadth_first_search(ab)) == ["S", "F", "B"]; 21 | 22 | @test solution(uniform_cost_search(ab)) == ["S", "R", "P", "B"]; 23 | 24 | @test solution(depth_first_graph_search(ab)) == ["T", "L", "M", "D", "C", "P", "B"]; 25 | 26 | @test solution(iterative_deepening_search(ab)) == ["S", "F", "B"]; 27 | 28 | @test length(solution(depth_limited_search(ab))) == 50; 29 | 30 | @test solution(astar_search(ab)) == ["S", "R", "P", "B"]; 31 | 32 | @test solution(recursive_best_first_search(ab)) == ["S", "R", "P", "B"]; 33 | 34 | @test compare_searchers([GraphProblem("A", "B", aimajulia.romania), 35 | GraphProblem("O", "N", aimajulia.romania), 36 | GraphProblem("Q", "WA", aimajulia.australia)], 37 | ["Searcher", "Romania(A, B)", "Romania(O, N)", "Australia"]) == 38 | ["Searcher" "Romania(A, B)" "Romania(O, N)" "Australia"; 39 | "aimajulia.breadth_first_tree_search" "< 23/ 24/ 63/B>" "<1191/1192/3378/N>" "< 9/ 10/ 32/WA>"; 40 | "aimajulia.breadth_first_search" "< 7/ 11/ 
18/B>" "< 18/ 20/ 44/N>" "< 3/ 6/ 9/WA>"; 41 | "aimajulia.depth_first_graph_search" "< 8/ 9/ 20/B>" "< 16/ 17/ 37/N>" "< 2/ 3/ 8/WA>"; 42 | "aimajulia.iterative_deepening_search" "< 13/ 36/ 36/B>" "< 683/1874/1875/N>" "< 4/ 13/ 12/WA>"; 43 | "aimajulia.depth_limited_search" "< 64/ 94/ 167/B>" "< 948/2629/2701/N>" "< 51/ 57/ 153/WA>"; 44 | "aimajulia.recursive_best_first_search" "< 11/ 12/ 35/B>" "<8481/8482/23788/N>" "< 10/ 11/ 38/WA>"]; 45 | 46 | # Initialize LRTAStarAgentProgram with an OnlineSearchProblem. 47 | 48 | lrtastar_program = OnlineSearchProblem("State_3", "State_5", aimajulia.one_dim_state_space, aimajulia.one_dim_state_space_least_costs); 49 | lrtastar_agentprogram = LRTAStarAgentProgram(lrtastar_program); 50 | 51 | @test execute(lrtastar_agentprogram, "State_3") == "Right"; 52 | 53 | @test execute(lrtastar_agentprogram, "State_4") == "Left"; 54 | 55 | @test execute(lrtastar_agentprogram, "State_3") == "Right"; 56 | 57 | @test execute(lrtastar_agentprogram, "State_4") == "Right"; 58 | 59 | @test execute(lrtastar_agentprogram, "State_5") == nothing; 60 | 61 | lrtastar_agentprogram = LRTAStarAgentProgram(lrtastar_program); 62 | 63 | @test execute(lrtastar_agentprogram, "State_4") == "Left"; 64 | 65 | lrtastar_agentprogram = LRTAStarAgentProgram(lrtastar_program); 66 | 67 | @test execute(lrtastar_agentprogram, "State_5") == nothing; 68 | 69 | -------------------------------------------------------------------------------- /tests/run_csp_tests.jl: -------------------------------------------------------------------------------- 1 | include("../aimajulia.jl"); 2 | 3 | using Base.Test; 4 | 5 | using aimajulia; 6 | 7 | #The following CSP tests are from the aima-python doctest 8 | 9 | @test (solution(depth_first_graph_search(aimajulia.australia_csp))...) 
== (("NSW","B"),("Q","G"),("NT","B"),("T","B"),("V","G"),("SA","R"),("WA","G")); 10 | 11 | d = CSPDict(ConstantFunctionDict(42)); 12 | 13 | @test d["life"] == 42; 14 | 15 | @test (!(typeof(backtracking_search(aimajulia.australia_csp)) <: Void) == true); 16 | 17 | @test (!(typeof(backtracking_search(aimajulia.australia_csp, 18 | select_unassigned_variable=minimum_remaining_values)) <: Void) == true); 19 | 20 | @test (!(typeof(backtracking_search(aimajulia.australia_csp, 21 | order_domain_values=least_constraining_values)) <: Void) == true); 22 | 23 | @test (!(typeof(backtracking_search(aimajulia.australia_csp, 24 | select_unassigned_variable=minimum_remaining_values, 25 | order_domain_values=least_constraining_values)) <: Void) == true); 26 | 27 | @test (!(typeof(backtracking_search(aimajulia.australia_csp, inference=forward_checking)) <: Void) == true); 28 | 29 | @test (!(typeof(backtracking_search(aimajulia.australia_csp, 30 | inference=maintain_arc_consistency)) <: Void) == true); 31 | 32 | @test (!(typeof(backtracking_search(aimajulia.australia_csp, 33 | select_unassigned_variable=minimum_remaining_values, 34 | order_domain_values=least_constraining_values, 35 | inference=maintain_arc_consistency)) <: Void) == true); 36 | 37 | topological_sorted_nodes, parent_dict = topological_sort(aimajulia.australia_csp, "NT"); 38 | 39 | @test topological_sorted_nodes == Any["NT","SA","Q","NSW","V","WA"]; 40 | 41 | @test haskey(parent_dict, "NT") == false; 42 | 43 | @test parent_dict == Dict{Any,Any}(Pair("NSW","Q"), Pair("Q","SA"), Pair("V","NSW"), Pair("SA","NT"), Pair("WA","SA")); 44 | 45 | @test length(backtracking_search(NQueensCSP(8))) == 8; 46 | 47 | e = SudokuCSP(aimajulia.easy_sudoku_grid); 48 | 49 | @test display(e, infer_assignment(e)) == ". . 3 | . 2 . | 6 . .\n9 . . | 3 . 5 | . . 1\n. . 1 | 8 . 6 | 4 . .\n------+-------+------\n. . 8 | 1 . 2 | 9 . .\n7 . . | . . . | . . 8\n. . 6 | 7 . 8 | 2 . .\n------+-------+------\n. . 2 | 6 . 9 | 5 . .\n8 . . | 2 . 3 | . . 
9\n. . 5 | . 1 . | 3 . ."; 50 | 51 | AC3(e); 52 | 53 | @test display(e, infer_assignment(e)) == "4 8 3 | 9 2 1 | 6 5 7\n9 6 7 | 3 4 5 | 8 2 1\n2 5 1 | 8 7 6 | 4 9 3\n------+-------+------\n5 4 8 | 1 3 2 | 9 7 6\n7 2 9 | 5 6 4 | 1 3 8\n1 3 6 | 7 9 8 | 2 4 5\n------+-------+------\n3 7 2 | 6 8 9 | 5 1 4\n8 1 4 | 2 5 3 | 7 6 9\n6 9 5 | 4 1 7 | 3 8 2"; 54 | 55 | @test !(typeof(backtracking_search(SudokuCSP(aimajulia.harder_sudoku_grid), select_unassigned_variable=minimum_remaining_values, inference=forward_checking)) <: Void) 56 | 57 | @test solve_zebra(ZebraCSP(), backtracking_search) == (5, 58 | 1, 59 | 75472, 60 | Dict{Any,Any}(Pair{Any,Any}("Tea",2), 61 | Pair{Any,Any}("Red",3), 62 | Pair{Any,Any}("Kools",1), 63 | Pair{Any,Any}("Green",5), 64 | Pair{Any,Any}("Horse",2), 65 | Pair{Any,Any}("Zebra",5), 66 | Pair{Any,Any}("OJ",4), 67 | Pair{Any,Any}("Milk",3), 68 | Pair{Any,Any}("Coffee",5), 69 | Pair{Any,Any}("Ukranian",2), 70 | Pair{Any,Any}("Japanese",5), 71 | Pair{Any,Any}("Snails",3), 72 | Pair{Any,Any}("Spaniard",4), 73 | Pair{Any,Any}("Water",1), 74 | Pair{Any,Any}("Winston",3), 75 | Pair{Any,Any}("Norwegian",1), 76 | Pair{Any,Any}("Fox",1), 77 | Pair{Any,Any}("Dog",4), 78 | Pair{Any,Any}("Ivory",4), 79 | Pair{Any,Any}("Englishman",3), 80 | Pair{Any,Any}("Yellow",1), 81 | Pair{Any,Any}("LuckyStrike",4), 82 | Pair{Any,Any}("Parliaments",5), 83 | Pair{Any,Any}("Blue",2), 84 | Pair{Any,Any}("Chesterfields",2))); 85 | -------------------------------------------------------------------------------- /tests/run_planning_tests.jl: -------------------------------------------------------------------------------- 1 | include("../aimajulia.jl"); 2 | 3 | using Base.Test; 4 | 5 | using aimajulia; 6 | 7 | #The following planning tests are from the aima-python doctests 8 | 9 | precondition = (map(expr, ["P(x)", "Q(y, z)"]), [expr("Q(x)")]); 10 | 11 | effect = ([expr("Q(x)")], [expr("P(x)")]); 12 | 13 | plan_action = PlanningAction(expr("A(x, y, z)"), precondition, effect); 14 
| 15 | arguments = map(expr, ["A", "B", "C"]); 16 | 17 | @test (substitute(plan_action, expr("P(x, z, y)"), (arguments...)) == expr("P(A, C, B)")); 18 | 19 | test_planning_kb = FirstOrderLogicKnowledgeBase(map(expr, ["P(A)", "Q(B, C)", "R(D)"])); 20 | 21 | @test check_precondition(plan_action, test_planning_kb, (arguments...)); 22 | 23 | execute_action(plan_action, test_planning_kb, (arguments...)); 24 | 25 | # Found no valid substitutions! 26 | @test (length(ask(test_planning_kb, expr("P(A)"))) == 0); 27 | 28 | # Found valid substitutions! 29 | @test (length(ask(test_planning_kb, expr("Q(A)"))) != 0); 30 | 31 | # Found valid substitutions! 32 | @test (length(ask(test_planning_kb, expr("Q(A)"))) != 0); 33 | 34 | @test (!check_precondition(plan_action, test_planning_kb, (arguments...))); 35 | 36 | air_cargo = air_cargo_pddl(); 37 | 38 | @test (goal_test(air_cargo) == false); 39 | 40 | for action in map(expr, ("Load(C1, P1, SFO)", "Fly(P1, SFO, JFK)", "Unload(C1, P1, JFK)", 41 | "Load(C2, P2, JFK)", "Fly(P2, JFK, SFO)", "Unload(C2, P2, SFO)")) 42 | execute_action(air_cargo, action); 43 | end 44 | 45 | @test goal_test(air_cargo); 46 | 47 | air_cargo = air_cargo_pddl(); 48 | 49 | @test (goal_test(air_cargo) == false); 50 | 51 | for action in map(expr, ("Load(C2, P2, JFK)", "Fly(P2, JFK, SFO)", "Unload(C2, P2, SFO)", 52 | "Load(C1, P1, SFO)", "Fly(P1, SFO, JFK)", "Unload(C1, P1, JFK)")) 53 | execute_action(air_cargo, action); 54 | end 55 | 56 | @test goal_test(air_cargo); 57 | 58 | spare_tire = spare_tire_pddl(); 59 | 60 | @test (goal_test(spare_tire) == false); 61 | 62 | for action in map(expr, ("Remove(Flat, Axle)", "Remove(Spare, Trunk)", "PutOn(Spare, Axle)")) 63 | execute_action(spare_tire, action); 64 | end 65 | 66 | @test goal_test(spare_tire); 67 | 68 | three_block_tower = three_block_tower_pddl(); 69 | 70 | @test (goal_test(three_block_tower) == false); 71 | 72 | for action in map(expr, ("MoveToTable(C, A)", "Move(B, Table, C)", "Move(A, Table, B)")) 73 | 
execute_action(three_block_tower, action); 74 | end 75 | 76 | @test goal_test(three_block_tower); 77 | 78 | have_cake_and_eat_cake_too = have_cake_and_eat_cake_too_pddl(); 79 | 80 | @test (goal_test(have_cake_and_eat_cake_too) == false); 81 | 82 | for action in map(expr, ("Eat(Cake)", "Bake(Cake)")) 83 | execute_action(have_cake_and_eat_cake_too, action); 84 | end 85 | 86 | @test goal_test(have_cake_and_eat_cake_too); 87 | 88 | spare_tire = spare_tire_pddl(); 89 | 90 | negated_kb = FirstOrderLogicKnowledgeBase([expr("At(Flat, Trunk)")]); 91 | 92 | spare_tire_graph = PlanningGraph(spare_tire, negated_kb); 93 | 94 | untouched_graph_levels_count = length(spare_tire_graph.levels); 95 | 96 | expand_graph(spare_tire_graph); 97 | 98 | @test (untouched_graph_levels_count == (length(spare_tire_graph.levels) - 1)); 99 | 100 | # Apply graphplan() to spare tire planning problem. 101 | 102 | spare_tire = spare_tire_pddl(); 103 | 104 | negated_kb = FirstOrderLogicKnowledgeBase([expr("At(Flat, Trunk)")]); 105 | 106 | spare_tire_gp = GraphPlanProblem(spare_tire, negated_kb); 107 | 108 | @test (!(typeof(graphplan(spare_tire_gp, ([expr("At(Spare, Axle)"), expr("At(Flat, Ground)")], []))) <: Void)); 109 | 110 | doubles_tennis = doubles_tennis_pddl(); 111 | 112 | @test (goal_test(doubles_tennis) == false); 113 | 114 | for action in map(expr, ["Go(A, LeftBaseLine, RightBaseLine)", "Hit(A, RightBaseLine, Ball)", "Go(A, RightBaseLine, LeftNet)"]) 115 | execute_action(doubles_tennis, action); 116 | end 117 | 118 | @test goal_test(doubles_tennis); 119 | 120 | # Create dictionary representation of possible refinements for "going to San Francisco airport HLA" (Fig. 11.4). 
121 | go_to_sfo_refinements_dict = Dict([Pair("HLA", ["Go(Home,SFO)", "Go(Home,SFO)", "Drive(Home, SFOLongTermParking)", "Shuttle(SFOLongTermParking, SFO)", "Taxi(Home, SFO)"]), 122 | Pair("steps", [["Drive(Home, SFOLongTermParking)", "Shuttle(SFOLongTermParking, SFO)"], ["Taxi(Home, SFO)"], [], [], []]), 123 | Pair("precondition_positive", [["At(Home), Have(Car)"], ["At(Home)"], ["At(Home)", "Have(Car)"], ["At(SFOLongTermParking)"], ["At(Home)"]]), 124 | Pair("precondition_negated", [[], [], [], [], []]), 125 | Pair("effect_add_list", [["At(SFO)"], ["At(SFO)"], ["At(SFOLongTermParking)"], ["At(SFO)"], ["At(SFO)"]]), 126 | Pair("effect_delete_list", [["At(Home)"], ["At(Home)"], ["At(Home)"], ["At(SFOLongTermParking)"], ["At(Home)"]]) 127 | ]); 128 | 129 | # Base.Test tests for refinements(). 130 | function test_refinement_goal_test(kb::FirstOrderLogicKnowledgeBase) 131 | return ask(kb, expr("At(SFO)")); 132 | end 133 | 134 | refinement_lib = Dict([Pair("HLA", ["Go(Home, SFO)", "Taxi(Home, SFO)"]), 135 | Pair("steps", [["Taxi(Home, SFO)"], []]), 136 | Pair("precondition_positive", [["At(Home)"], ["At(Home)"]]), 137 | Pair("precondition_negated", [[],[]]), 138 | Pair("effect_add_list", [["At(SFO)"],["At(SFO)"]]), 139 | Pair("effect_delete_list", [["At(Home)"], ["At(Home)"]])]); 140 | # Go to San Francisco airport high-level action schema 141 | precondition_positive = Array{Expression, 1}([expr("At(Home)")]); 142 | precondition_negated = Array{Expression, 1}([]); 143 | effect_add_list = Array{Expression, 1}([expr("At(SFO)")]); 144 | effect_delete_list = Array{Expression, 1}([expr("At(Home)")]); 145 | go_sfo = PlanningHighLevelAction(expr("Go(Home, SFO)"), 146 | (precondition_positive, precondition_negated), 147 | (effect_add_list, effect_delete_list)); 148 | # Take Taxi to San Francisco airport high-level action schema 149 | precondition_positive = Array{Expression, 1}([expr("At(Home)")]); 150 | precondition_negated = Array{Expression, 1}([]); 151 | effect_add_list = 
Array{Expression, 1}([expr("At(SFO)")]); 152 | effect_delete_list = Array{Expression, 1}([expr("At(Home)")]); 153 | taxi_sfo = PlanningHighLevelAction(expr("Go(Home, SFO)"), 154 | (precondition_positive, precondition_negated), 155 | (effect_add_list, effect_delete_list)); 156 | go_sfo_pddl = HighLevelPDDL(Array{Expression, 1}([expr("At(Home)")]), [go_sfo, taxi_sfo], test_refinement_goal_test); 157 | result = refinements(go_sfo, go_sfo_pddl, refinement_lib); 158 | @test (length(result) == 1); 159 | @test (result[1].name == "Taxi"); 160 | @test (result[1].arguments == (expr("Home"), expr("SFO"))); 161 | 162 | job_shop_scheduling = job_shop_scheduling_pddl(); 163 | 164 | @test (goal_test(job_shop_scheduling) == false); 165 | 166 | for i in reverse(1:2) 167 | for j in 1:3 168 | execute_action(job_shop_scheduling, job_shop_scheduling.jobs[i][j]); 169 | end 170 | end 171 | 172 | @test goal_test(job_shop_scheduling); 173 | 174 | -------------------------------------------------------------------------------- /nlp_apps.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# NATURAL LANGUAGE PROCESSING APPLICATIONS" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "In this notebook we will take a look at some indicative applications of natural language processing. We will cover content from [`nlp.jl`](https://github.com/aimacode/aima-julia/blob/master/nlp.jl) and [`text.jl`](https://github.com/aimacode/aima-julia/blob/master/text.jl), for chapters 22 and 23 of Stuart Russel's and Peter Norvig's book [*Artificial Intelligence: A Modern Approach*](http://aima.cs.berkeley.edu/)." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## CONTENTS\n", 22 | "* Language Recognition" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "## LANGUAGE RECOGNITION\n", 30 | "A very useful application of text model is categorizing text into a language. In fact, with enough data we can categorize correctly mostly any text. That is because different languages have certain characteristics that set them apart. For example, in German it is very usual for 'c' to be followed by 'h' while in English we see 't' followed by 'h' a lot.\n", 31 | "\n", 32 | "Here we will build an application to categorize sentences in either English or German.\n", 33 | "\n", 34 | "First we need to build our dataset. We will take as input text in English and in German and we will extract n-gram character models (in this case, *bigrams* for n=2). For English, we will use *Flatland* by Edwin Abbott and for German *Faust* by Goethe.\n", 35 | "\n", 36 | "Let's build our text models for each language, which will hold the probability of each bigram occuring in the text." 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 1, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "include(\"aimajulia.jl\");\n", 46 | "\n", 47 | "using aimajulia;\n", 48 | "\n", 49 | "flatland = readstring(open(\"./aima-data/EN-text/flatland.txt\"));\n", 50 | "wordseq = extract_words(flatland);\n", 51 | "\n", 52 | "P_flatland = NgramCharModel(2, wordseq);\n", 53 | "\n", 54 | "faust = readstring(open(\"./aima-data/faust.txt\"));\n", 55 | "wordseq = extract_words(faust);\n", 56 | "\n", 57 | "P_faust = NgramCharModel(2, wordseq);" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "We can use this information to build a *Naive Bayes Classifier* that will be used to categorize sentences. 
The classifier will take as input the probability distribution of bigrams and given a list of bigrams (extracted from the sentence to be classified), it will calculate the probability of the example/sentence coming from each language and pick the maximum.\n", 65 | "\n", 66 | "Let's build our classifier, with the assumption that English is as probable as German (the input is a dictionary with values the text models and keys the tuple `language, probability`):" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 2, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "dist = Dict([((\"English\", 1), P_flatland), ((\"German\", 1), P_faust)]);\n", 76 | "\n", 77 | "nBS = NaiveBayesLearner(dist, simple=true);" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "Now we need to write a function that takes as input a sentence, breaks it into a list of bigrams and classifies it with the naive bayes classifier from above.\n", 85 | "\n", 86 | "Once we get the text model for the sentence, we need to unravel it. The text models show the probability of each bigram, but the classifier can't handle that extra data. It requires a simple *list* of bigrams. So, if the text model shows that a bigram appears three times, we need to add it three times in the list. Since the text model stores the n-gram information in a dictionary (with the key being the n-gram and the value the number of times the n-gram appears) we need to iterate through the items of the dictionary and manually add them to the list of n-grams." 
87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 3, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "function recognize(sentence::String, nBS::aimajulia.NaiveBayesLearner, n::Int64)\n", 96 | " sentence = lowercase(sentence);\n", 97 | " wordseq = extract_words(sentence);\n", 98 | " P_sentence = NgramCharModel(n, wordseq);\n", 99 | " \n", 100 | " ngrams = [];\n", 101 | " for (b, p) in P_sentence.dict\n", 102 | " ngrams = vcat(ngrams,fill(b, p));\n", 103 | " end\n", 104 | " \n", 105 | " return predict(nBS, ngrams);\n", 106 | "end;" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "Now we can start categorizing sentences." 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 4, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "data": { 123 | "text/plain": [ 124 | "\"German\"" 125 | ] 126 | }, 127 | "execution_count": 4, 128 | "metadata": {}, 129 | "output_type": "execute_result" 130 | } 131 | ], 132 | "source": [ 133 | "recognize(\"Ich bin ein platz\", nBS, 2)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 5, 139 | "metadata": {}, 140 | "outputs": [ 141 | { 142 | "data": { 143 | "text/plain": [ 144 | "\"English\"" 145 | ] 146 | }, 147 | "execution_count": 5, 148 | "metadata": {}, 149 | "output_type": "execute_result" 150 | } 151 | ], 152 | "source": [ 153 | "recognize(\"Turtles fly high\", nBS, 2)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 6, 159 | "metadata": {}, 160 | "outputs": [ 161 | { 162 | "data": { 163 | "text/plain": [ 164 | "\"German\"" 165 | ] 166 | }, 167 | "execution_count": 6, 168 | "metadata": {}, 169 | "output_type": "execute_result" 170 | } 171 | ], 172 | "source": [ 173 | "recognize(\"Der pelikan ist hier\", nBS, 2)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 7, 179 | "metadata": {}, 180 | "outputs": [ 181 | { 182 | "data": { 183 | 
"text/plain": [ 184 | "\"English\"" 185 | ] 186 | }, 187 | "execution_count": 7, 188 | "metadata": {}, 189 | "output_type": "execute_result" 190 | } 191 | ], 192 | "source": [ 193 | "recognize(\"And thus the wizard spoke\", nBS, 2)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": {}, 199 | "source": [ 200 | "You can add more languages if you want, the algorithm works for as many as you like! Also, you can play around with *n*. Here we used 2, but other numbers work too (even though 2 suffices). The algorithm is not perfect, but it has high accuracy even for small samples like the ones we used. That is because English and German are very different languages. The closer together languages are (for example, Norwegian and Swedish share a lot of common ground) the lower the accuracy of the classifier." 201 | ] 202 | } 203 | ], 204 | "metadata": { 205 | "kernelspec": { 206 | "display_name": "Julia 0.6.0", 207 | "language": "julia", 208 | "name": "julia-0.6" 209 | }, 210 | "language_info": { 211 | "file_extension": ".jl", 212 | "mimetype": "application/julia", 213 | "name": "julia", 214 | "version": "0.6.0" 215 | } 216 | }, 217 | "nbformat": 4, 218 | "nbformat_minor": 2 219 | } 220 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | How to Contribute to aima-julia 2 | =============================== 3 | 4 | Thanks for considering contributing to aima-julia! 5 | 6 | Here is a guide (similar to the [aima-python contributing guide](https://github.com/aimacode/aima-python/blob/master/CONTRIBUTING.md)) on how you can help. 7 | 8 | In general, the main ways you can contribute to the repository are the following: 9 | 10 | 1. Implement algorithms from the [list of algorithms](https://github.com/aimacode/aima-julia/blob/master/README.md). 11 | 2. 
Add tests for algorithms that are missing them (you can also add more tests to algorithms that already have some). 12 | 3. Take care of [issues](https://github.com/aimacode/aima-julia/issues). 13 | 4. Write on the notebooks (`.ipynb` files). 14 | 5. Add and edit documentation (the docstrings in `.jl` files). 15 | 16 | In more detail: 17 | 18 | ## Read the Code and Start on an Issue 19 | 20 | - First, read and understand the code to get a feel for the extent and the style (see Style Guide below). 21 | - Look at the issues and pick one to work on. 22 | - One of the issues is that some algorithms are missing from the list of algorithms and that some don't have tests. 23 | 24 | ## RandomDevice() 25 | 26 | Avoid using `RandomDevice()`. Try using `aimajulia.RandomDeviceInstance` (can be referenced as `RandomDeviceInstance` within the `aimajulia` module) instead. Multiple `RandomDevice()` calls cause errors on some operating systems when opening multiple concurrent file descriptors to `/dev/urandom` or `/dev/random`. 27 | 28 | ## Haskell-like type assertion, Haskell/Lisp Functional Programming 29 | 30 | - When writing functions, the arguments should be type asserted (with exception to functions like getindex() for the Dict DataType). 31 | - The use of `collect()`, `map()`, `reduce()`, `mapreduce()`, and anonymous functions are recommended. 32 | - When declaring type definitions, try to assert the type of the fields. 33 | 34 | ## Porting to Julia from Python 35 | 36 | - Use comprehensions when possible. In addition, use `Iterators` when dealing with Python generator expressions (collecting the items if required). 37 | - String formatting can be accomplished with `sprintf()` and string concatenation (using the `*` operator or passing the `*` operator to `reduce()`). 38 | - Division between 2 real numbers results in a float. 39 | - Julia has native matrices, avoid using arrays of arrays unless required. 40 | - Add more tests in `test_*.jl` files. 
Strive for terseness; it is ok to group multiple asserts into one function. Move most tests to `test_*.jl`, but it is fine to have a single doctest example in the docstring of a function in the `.jl` file, if the purpose of the doctest is to explain how to use the function, rather than test the implementation. 41 | 42 | ## New and Improved Algorithms 43 | 44 | - Implement functions that were in the third edition of the book but were not yet implemented in the code. Check the [list of pseudocode algorithms (pdf)](https://github.com/aimacode/pseudocode/blob/master/aima3e-algorithms.pdf) to see what's missing. 45 | - As we finish chapters for the new fourth edition, we will share the new pseudocode in the [`aima-pseudocode`](https://github.com/aimacode/aima-pseudocode) repository, and describe what changes are necessary. We hope to have an `algorithm-name.md` file for each algorithm, eventually; it would be great if contributors could add some for the existing algorithms. 46 | - Give examples of how to use the code in the `.ipynb` files. 47 | 48 | ## Jupyter Notebooks 49 | 50 | In this project we use Jupyter/IJulia Notebooks to showcase the algorithms in the book. They serve as short tutorials on what the algorithms do, how they are implemented and how one can use them. To install Jupyter, you can follow the instructions here. These are some ways you can contribute to the notebooks: 51 | 52 | - Proofread the notebooks for grammar mistakes, typos, or general errors. 53 | - Move visualization and unrelated to the algorithm code from notebooks to `notebook.jl` (a file used to store code for the notebooks, like visualization and other miscellaneous stuff). Make sure the notebooks still work and have their outputs showing! 54 | - Replace the `%psource` magic notebook command with the function `psource` from `notebook.jl` where needed. 
Examples where this is useful are a) when we want to show code for algorithm implementation and b) when we have consecutive cells with the magic keyword (in this case, if the code is large, it's best to leave the output hidden). 55 | - Add the function pseudocode(algorithm_name) in algorithm sections. The function prints the pseudocode of the algorithm. You can see some example usage in `knowledge.ipynb`. 56 | - Edit existing sections for algorithms to add more information and/or examples. 57 | - Add visualizations for algorithms. The visualization code should go in notebook.jl to keep things clean. 58 | - Add new sections for algorithms not yet covered. The general format we use in the notebooks is the following: First start with an overview of the algorithm, printing the pseudocode and explaining how it works. Then, add some implementation details, including showing the code (using psource). Finally, add examples for the implementations, showing how the algorithms work. Don't fret with adding complex, real-world examples; the project is meant for educational purposes. You can of course choose another format if something better suits an algorithm. 59 | 60 | Apart from the notebooks explaining how the algorithms work, we also have notebooks showcasing some indicative applications of the algorithms. These notebooks are in the `*_apps.ipynb` format. We aim to have an apps notebook for each module, so if you don't see one for the module you would like to contribute to, feel free to create it from scratch! In these notebooks we are looking for applications showing what the algorithms can do. The general format of these sections is this: Add a description of the problem you are trying to solve, then explain how you are going to solve it and finally provide your solution with examples. Note that any code you write should not require any external libraries apart from the ones already provided (like matplotlib). 
61 | 62 | # Style Guide 63 | 64 | There are a few style rules that are unique to this project: 65 | 66 | - The first rule is that the code should correspond directly to the pseudocode in the book. When possible this will be almost one-to-one, just allowing for the syntactic differences between Julia and pseudocode, and for different library functions. 67 | - Don't make a function more complicated than the pseudocode in the book, even if the complication would add a nice feature, or give an efficiency gain. Instead, remain faithful to the pseudocode, and if you must, add a new function (not in the book) with the added feature. 68 | - I use functional programming (functions with no side effects) in many cases, but not exclusively (sometimes type declarations and/or functions with side effects are used). Let the book's pseudocode be the guide. 69 | 70 | Beyond the above rules, we use the official Julia Style Guide ([0.5](https://docs.julialang.org/en/release-0.5/manual/style-guide/)/[0.6](https://docs.julialang.org/en/release-0.5/manual/style-guide/)), with a few minor exceptions: 71 | 72 | - One line comments start with a space after the # sign. 73 | - Use 4 spaces instead of tabs 74 | - Strunk and White is [not a good guide for English](http://chronicle.com/article/50-Years-of-Stupid-Grammar/25497). 75 | - I prefer more concise docstrings. In most cases, a one-line docstring does not suffice. It is necessary to list what each argument does; the name of the argument usually is enough. 76 | - Not all constants have to be uppercase. 77 | - Parenthesize expressions consisting of multiple subexpressions to avoid confusion. 78 | 79 | Updating existing code to newer Julia versions 80 | ============================================== 81 | 82 | The Julia language frequently changes their latest stable version (pre v1.0). For example, Julia 0.5 was announced October 11, 2016 and Julia 0.6 was announced June 27, 2017. 
As a result, we should have a separate branch for each supported Julia version. Pull requests should be made specifically to those branches (make an issue if the branch does not exist). 83 | 84 | Contributing a Patch 85 | ==================== 86 | 87 | - Submit an issue describing your proposed change to the repo in question (or work on an existing issue). 88 | 89 | - The repo owner will respond to your issue promptly. 90 | 91 | - Fork the desired repo, develop and test your code changes. 92 | 93 | - Submit a pull request. 94 | 95 | Reporting Issues 96 | ================ 97 | 98 | - Under which versions of Julia does this happen? 99 | 100 | - Provide an example of the issue occurring. 101 | 102 | - Is anybody working on this? 103 | 104 | Patch Rules 105 | =========== 106 | 107 | - Ensure that the patch is Julia 0.6 compliant. 108 | 109 | - Include tests if your patch is supposed to solve a bug, and explain clearly under which circumstances the bug happens. Make sure the test fails without your patch. 110 | 111 | - Follow the style guidelines described above. 112 | 113 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # aima-julia 2 | 3 | [](https://travis-ci.org/aimacode/aima-julia) 4 | 5 | Julia (v0.6) implementation of the algorithms found in "Artificial Intelligence: A Modern Approach". 6 | 7 | This project is not intended to be a standard Julia package (i.e. the Julia package manager). 8 | 9 | We're looking for [solid contributors](CONTRIBUTING.md) to help. 10 | 11 | Structure of the Project 12 | ------------------------ 13 | 14 | When complete, this project will have Julia implementations for all the pseudocode algorithms in the book, as well as tests and examples of use. 
For each major topic, such as `nlp` (natural language processing), we provide the following files: 15 | 16 | - `nlp.jl`: Implementations of all the pseudocode algorithms, and necessary support functions/datatypes/data. 17 | - `tests/test_nlp.jl`: A lightweight test suite using Base.Test macros. 18 | - `nlp.ipynb`: A Jupyter (IJulia) notebook that explains and gives examples of how to use the code. 19 | - `nlp_apps.ipynb`: A Jupyter notebook that gives example applications of the code. 20 | 21 | Using aima-julia for portable purposes 22 | -------------------------------------- 23 | 24 | Include the following lines in all files within the same directory. 25 | 26 | ~~~ 27 | include("aimajulia.jl"); 28 | using aimajulia; 29 | ~~~ 30 | 31 | Index of Algorithms 32 | ------------------- 33 | 34 | | **Figure** | **Name (in 3rd edition)** | **Name (in repository)** | **File** 35 | |:--------|:-------------------|:---------|:-----------| 36 | | 2.1 | Environment | `Environment` | [`agents.jl`](agents.jl) | 37 | | 2.1 | Agent | `Agent` | [`agents.jl`](agents.jl) | 38 | | 2.3 | Table-Driven-Vacuum-Agent | `TableDrivenVacuumAgent` | [`agents.jl`](agents.jl) | 39 | | 2.7 | Table-Driven-Agent | `TableDrivenAgentProgram` | [`agents.jl`](agents.jl) | 40 | | 2.8 | Reflex-Vacuum-Agent | `ReflexVacuumAgent` | [`agents.jl`](agents.jl) | 41 | | 2.10 | Simple-Reflex-Agent | `SimpleReflexAgent` | [`agents.jl`](agents.jl) | 42 | | 2.12 | Model-Based-Reflex-Agent | `ModelBasedReflexAgentProgram` | [`agents.jl`](agents.jl) | 43 | | 3 | Problem | `Problem` | [`search.jl`](search.jl) | 44 | | 3 | Node | `Node` | [`search.jl`](search.jl) | 45 | | 3 | Queue | `Queue` | [`utils.jl`](utils.jl) | 46 | | 3.1 | Simple-Problem-Solving-Agent | `SimpleProblemSolvingAgent` | [`search.jl`](search.jl) | 47 | | 3.2 | Romania | `romania` | [`search.jl`](search.jl) | 48 | | 3.7 | Tree-Search | `tree_search` | [`search.jl`](search.jl) | 49 | | 3.7 | Graph-Search | `graph_search` | [`search.jl`](search.jl) | 50 
| | 3.11 | Breadth-First-Search | `breadth_first_search` | [`search.jl`](search.jl) | 51 | | 3.14 | Uniform-Cost-Search | `uniform_cost_search` | [`search.jl`](search.jl) | 52 | | 3.17 | Depth-Limited-Search | `depth_limited_search` | [`search.jl`](search.jl) | 53 | | 3.18 | Iterative-Deepening-Search | `iterative_deepening_search` | [`search.jl`](search.jl) | 54 | | 3.22 | Best-First-Search | `best_first_graph_search` | [`search.jl`](search.jl) | 55 | | 3.24 | A\*-Search | `astar_search` | [`search.jl`](search.jl) | 56 | | 3.26 | Recursive-Best-First-Search | `recursive_best_first_search` | [`search.jl`](search.jl) | 57 | | 4.2 | Hill-Climbing | `hill_climbing` | [`search.jl`](search.jl) | 58 | | 4.5 | Simulated-Annealing | `simulated_annealing` | [`search.jl`](search.jl) | 59 | | 4.8 | Genetic-Algorithm | `genetic_algorithm` | [`search.jl`](search.jl) | 60 | | 4.11 | And-Or-Graph-Search | `and_or_graph_search` | [`search.jl`](search.jl) | 61 | | 4.21 | Online-DFS-Agent | `OnlineDFSAgentProgram` | [`search.jl`](search.jl) | 62 | | 4.24 | LRTA\*-Agent | `LRTAStarAgentProgram` | [`search.jl`](search.jl) | 63 | | 5.3 | Minimax-Decision | `minimax_decision` | [`games.jl`](games.jl) | 64 | | 5.7 | Alpha-Beta-Search | `alphabeta_search` | [`games.jl`](games.jl) | 65 | | 6 | CSP | `CSP` | [`csp.jl`](csp.jl) | 66 | | 6.3 | AC-3 | `AC3` | [`csp.jl`](csp.jl) | 67 | | 6.5 | Backtracking-Search | `backtracking_search` | [`csp.jl`](csp.jl) | 68 | | 6.8 | Min-Conflicts | `min_conflicts` | [`csp.jl`](csp.jl) | 69 | | 6.11 | Tree-CSP-Solver | `tree_csp_solver` | [`csp.jl`](csp.jl) | 70 | | 7 | KB | `KnowledgeBase` | [`logic.jl`](logic.jl) | 71 | | 7.1 | KB-Agent | `KnowledgeBaseAgentProgram` | [`logic.jl`](logic.jl) | 72 | | 7.7 | Propositional Logic Sentence | `Expression` | [`logic.jl`](logic.jl) | 73 | | 7.10 | TT-Entails | `tt_entails` | [`logic.jl`](logic.jl) | 74 | | 7.12 | PL-Resolution | `pl_resolution` | [`logic.jl`](logic.jl) | 75 | | 7.14 | Convert to CNF | 
`to_conjunctive_normal_form` | [`logic.jl`](logic.jl) | 76 | | 7.15 | PL-FC-Entails? | `pl_fc_resolution` | [`logic.jl`](logic.jl) | 77 | | 7.17 | DPLL-Satisfiable? | `dpll_satisfiable` | [`logic.jl`](logic.jl) | 78 | | 7.18 | WalkSAT | `walksat` | [`logic.jl`](logic.jl) | 79 | | 7.20 | Hybrid-Wumpus-Agent | | | 80 | | 7.22 | SATPlan | `sat_plan` | [`logic.jl`](logic.jl) | 81 | | 9 | Subst | `substitute` | [`logic.jl`](logic.jl) | 82 | | 9.1 | Unify | `unify` | [`logic.jl`](logic.jl) | 83 | | 9.3 | FOL-FC-Ask | `fol_fc_ask` | [`logic.jl`](logic.jl) | 84 | | 9.6 | FOL-BC-Ask | `fol_bc_ask` | [`logic.jl`](logic.jl) | 85 | | 9.8 | Append | | | 86 | | 10.1 | Air-Cargo-problem | `air_cargo_pddl` |[`planning.jl`](planning.jl) | 87 | | 10.2 | Spare-Tire-problem | `spare_tire_pddl` |[`planning.jl`](planning.jl) | 88 | | 10.3 | Three-Block-Tower | `three_block_tower_pddl` |[`planning.jl`](planning.jl) | 89 | | 10.7 | Cake-problem | `have_cake_and_eat_cake_too_pddl` |[`planning.jl`](planning.jl) | 90 | | 10.9 | Graphplan | `graphplan` | [`planning.jl`](planning.jl) | 91 | | 10.13 | Partial-Order-Planner | | | 92 | | 11.1 | Job-Shop-Problem-With-Resources | `job_shop_scheduling_pddl` |[`planning.jl`](planning.jl) | 93 | | 11.5 | Hierarchical-Search | `hierarchical_search` | [`planning.jl`](planning.jl) | 94 | | 11.8 | Angelic-Search | | | 95 | | 11.10 | Doubles-Tennis-problem | `doubles_tennis_pddl` | [`planning.jl`](planning.jl) | 96 | | 13 | Discrete Probability Distribution | `ProbabilityDistribution` | [`probability.jl`](probability.jl) | 97 | | 13.1 | DT-Agent | `DecisionTheoreticAgentProgram` | [`probability.jl`](probability.jl) | 98 | | 14.9 | Enumeration-Ask | `enumeration_ask` | [`probability.jl`](probability.jl) | 99 | | 14.11 | Elimination-Ask | `elimination_ask` | [`probability.jl`](probability.jl) | 100 | | 14.13 | Prior-Sample | `prior_sample` | [`probability.jl`](probability.jl) | 101 | | 14.14 | Rejection-Sampling | `rejection_sample` | 
[`probability.jl`](probability.jl) | 102 | | 14.15 | Likelihood-Weighting | `likelihood_weighting` | [`probability.jl`](probability.jl) | 103 | | 14.16 | Gibbs-Ask | `gibbs_ask` | [`probability.jl`](probability.jl) | 104 | | 15.4 | Forward-Backward | `forward_backward` | [`probability.jl`](probability.jl) | 105 | | 15.6 | Fixed-Lag-Smoothing | `fixed_lag_smoothing` | [`probability.jl`](probability.jl) | 106 | | 15.17 | Particle-Filtering | `particle_filtering` | [`probability.jl`](probability.jl) | 107 | | 16.9 | Information-Gathering-Agent | | | 108 | | 17.4 | Value-Iteration | `value_iteration` | [`mdp.jl`](mdp.jl) | 109 | | 17.7 | Policy-Iteration | `policy_iteration` | [`mdp.jl`](mdp.jl) | 110 | | 17.7 | POMDP-Value-Iteration | | | 111 | | 18.5 | Decision-Tree-Learning | `decision_tree_learning` | [`learning.jl`](learning.jl) | 112 | | 18.8 | Cross-Validation | `cross_validation` | [`learning.jl`](learning.jl) | 113 | | 18.11 | Decision-List-Learning | `decision_list_learning` | [`learning.jl`](learning.jl) | 114 | | 18.24 | Back-Prop-Learning | `back_propagation_learning!` | [`learning.jl`](learning.jl) | 115 | | 18.34 | AdaBoost | `adaboost!` | [`learning.jl`](learning.jl) | 116 | | 19.2 | Current-Best-Learning | `current_best_learning` | [`kl.jl`](kl.jl) | 117 | | 19.3 | Version-Space-Learning | `version_space_learning` | [`kl.jl`](kl.jl) | 118 | | 19.8 | Minimal-Consistent-Det | `minimal_consistent_determination` | [`kl.jl`](kl.jl) | 119 | | 19.12 | FOIL | `foil` | [`kl.jl`](kl.jl) | 120 | | 21.2 | Passive-ADP-Agent | `PassiveADPAgentProgram` | [`rl.jl`](rl.jl) | 121 | | 21.4 | Passive-TD-Agent | `PassiveTDAgentProgram` | [`rl.jl`](rl.jl) | 122 | | 21.8 | Q-Learning-Agent | `QLearningAgentProgram` | [`rl.jl`](rl.jl) | 123 | | 22.1 | HITS | `HITS` | [`nlp.jl`](nlp.jl) | 124 | | 23 | Chart-Parse | `Chart` | [`nlp.jl`](nlp.jl) | 125 | | 23.5 | CYK-Parse | `cyk_parse` | [`nlp.jl`](nlp.jl) | 126 | | 25.9 | Monte-Carlo-Localization| `monte_carlo_localization` | 
[`probability.jl`](probability.jl) | 127 | 128 | Running tests 129 | ------------- 130 | 131 | All `Base.Test` tests for the aima-julia project can be found in the [tests](https://github.com/aimacode/aima-julia/tree/master/tests) directory. 132 | 133 | ## Acknowledgements 134 | 135 | The algorithms implemented in this project are found from both Russell And Norvig's "Artificial Intelligence - A Modern Approach" and [aima-pseudocode](https://github.com/aimacode/aima-pseudocode). 136 | -------------------------------------------------------------------------------- /tests/run_nlp_tests.jl: -------------------------------------------------------------------------------- 1 | include("../aimajulia.jl"); 2 | 3 | using Base.Test; 4 | 5 | using aimajulia; 6 | 7 | using aimajulia.utils; 8 | 9 | #The following nlp tests are from the aima-python doctests 10 | 11 | # Test the Grammar DataType methods 12 | check = Dict([Pair("A", [["B", "C"], ["D", "E"]]), 13 | Pair("B", [["E"], ["a"], ["b", "c"]])]); 14 | 15 | @test (Rules(["A"=>"B C | D E", "B"=>"E | a | b c"]) == check); 16 | 17 | check = Dict([Pair("Article", ["the", "a", "an"]), 18 | Pair("Pronoun", ["i", "you", "he"])]); 19 | 20 | @test (Lexicon(["Article"=>"the | a | an", "Pronoun"=>"i | you | he"]) == check); 21 | 22 | rules = Rules(["A"=>"B C | D E", "B"=>"E | a | b c"]); 23 | lexicon = Lexicon(["Article"=>"the | a | an", "Pronoun"=>"i | you | he"]); 24 | grammar = Grammar("Simplegram", rules, lexicon); 25 | 26 | @test (rewrites_for(grammar, "A") == [["B", "C"], ["D", "E"]]); 27 | 28 | @test (is_category(grammar, "the", "Article") == true); 29 | 30 | grammar = aimajulia.nlp.epsilon_chomsky; 31 | 32 | @test (all((function(rule::Tuple) 33 | return (length(rule) == 3); 34 | end), cnf_rules(grammar))); 35 | 36 | lexicon = Lexicon(["Article"=>"the | a | an", "Pronoun"=>"i | you | he"]); 37 | rules = Rules(["S"=>"Article | More | Pronoun", "More"=>"Article Pronoun | Pronoun Pronoun"]); 38 | grammar = Grammar("Simplegram", 
rules, lexicon); 39 | sentence = generate_random_sentence(grammar, "S"); 40 | 41 | @test (all((function(token::String) 42 | return any((function(terminals::AbstractVector) 43 | return (token in terminals); 44 | end), values(grammar.lexicon)); 45 | end), map(String, split(sentence)))); 46 | 47 | # Test the ProbabilityGrammar DataType methods 48 | check = Dict([Pair("A", [(["B", "C"], 0.3), (["D", "E"], 0.7)]), 49 | Pair("B", [(["E"], 0.1), (["a"], 0.2), (["b", "c"], 0.7)])]); 50 | 51 | @test (ProbabilityRules(["A"=>"B C [0.3] | D E [0.7]", "B"=>"E [0.1] | a [0.2] | b c [0.7]"]) == check); 52 | 53 | check = Dict([Pair("Article", [("the", 0.5), ("a", 0.25), ("an", 0.25)]), 54 | Pair("Pronoun", [("i", 0.4), ("you", 0.3), ("he", 0.3)])]); 55 | 56 | @test (ProbabilityLexicon(["Article"=>"the [0.5] | a [0.25] | an [0.25]", "Pronoun"=>"i [0.4] | you [0.3] | he [0.3]"]) == check); 57 | 58 | rules = ProbabilityRules(["A"=>"B C [0.3] | D E [0.7]", "B"=>"E [0.1] | a [0.2] | b c [0.7]"]); 59 | lexicon = ProbabilityLexicon(["Article"=>"the [0.5] | a [0.25] | an [0.25]", "Pronoun"=>"i [0.4] | you [0.3] | he [0.3]"]); 60 | grammar = ProbabilityGrammar("Simplegram", rules, lexicon); 61 | 62 | @test (rewrites_for(grammar, "A") == [(["B", "C"], 0.3), (["D", "E"], 0.7)]); 63 | 64 | @test (is_category(grammar, "the", "Article") == true); 65 | 66 | grammar = aimajulia.nlp.epsilon_probability_chomsky; 67 | 68 | @test (all((function(rule::Tuple) 69 | return (length(rule) == 4); 70 | end), cnf_rules(grammar))); 71 | 72 | lexicon = ProbabilityLexicon(["Verb"=>"am [0.5] | are [0.25] | is [0.25]", 73 | "Pronoun"=>"i [0.4] | you [0.3] | he [0.3]"]); 74 | rules = ProbabilityRules(["S"=>"Verb [0.5] | More [0.3] | Pronoun [0.1] | nobody is here [0.1]", 75 | "More"=>"Pronoun Verb [0.7] | Pronoun Pronoun [0.3]"]); 76 | grammar = ProbabilityGrammar("Simplegram", rules, lexicon); 77 | sentence = generate_random_sentence(grammar, "S"); 78 | 79 | @test (length(sentence) == 2); 80 | 81 | # Test the 
Chart DataType 82 | chart = Chart(aimajulia.nlp.epsilon_0); 83 | 84 | @test (length(parse_sentence(chart, "the stench is in 2 2")) == 1); 85 | 86 | # Test the CYK parsing 87 | grammar = aimajulia.nlp.epsilon_probability_chomsky; 88 | words = ["the", "robot", "is", "good"]; 89 | 90 | @test (length(cyk_parse(words, grammar)) == 52); 91 | 92 | # Test the HTML parsing functions 93 | address = "https://en.wikipedia.org/wiki/Ethics"; 94 | 95 | page = load_page_html([address]); 96 | 97 | page_html = page[address]; 98 | 99 | @test ((!contains(page_html, "
")) && (!contains(page_html, ""))); 100 | 101 | test_html_1 = ("Keyword String 1: A man is a male human." 102 | *"Keyword String 2: Like most other male mammals, a man inherits an" 103 | *"X from his mom and a Y from his dad." 104 | *"Links:" 105 | *"href=\"https://google.com.au\"" 106 | *"AIMA book
"; 110 | no_head_test_html_3 = replace(test_html_3, @r_str("(.*)", "s"), ""); 111 | 112 | @test ((!contains(no_head_test_html_3, "")) && (!contains(no_head_test_html_3, ""))); 113 | 114 | @test ((contains(test_html_3, "AIMA book")) && (contains(no_head_test_html_3, "AIMA book"))); 115 | 116 | page_A = Page("A", inlinks=["B", "C", "E"], outlinks=["D"], hub=1, authority=6); 117 | page_B = Page("B", inlinks=["E"], outlinks=["A", "C", "D"], hub=2, authority=5); 118 | page_C = Page("C", inlinks=["B", "E"], outlinks=["A", "D"], hub=3, authority=4); 119 | page_D = Page("D", inlinks=["A", "B", "C", "E"], outlinks=[], hub=4, authority=3); 120 | page_E = Page("E", inlinks=[], outlinks=["A", "B", "C", "D", "F"], hub=5, authority=2); 121 | page_F = Page("F", inlinks=["E"], outlinks=[], hub=6, authority=1); 122 | 123 | page_dict = Dict([Pair(page_A.address, page_A), 124 | Pair(page_B.address, page_B), 125 | Pair(page_C.address, page_C), 126 | Pair(page_D.address, page_D), 127 | Pair(page_E.address, page_E), 128 | Pair(page_F.address, page_F)]); 129 | 130 | pages_index = page_dict; 131 | 132 | pages_content = Dict([Pair(page_A.address, test_html_1), 133 | Pair(page_B.address, test_html_2), 134 | Pair(page_C.address, test_html_1), 135 | Pair(page_D.address, test_html_2), 136 | Pair(page_E.address, test_html_1), 137 | Pair(page_F.address, test_html_2)]); 138 | 139 | @test (Set(determine_inlinks(page_A, pages_index)) == Set(["B", "C", "E"])); 140 | 141 | @test (Set(determine_inlinks(page_E, pages_index)) == Set([])); 142 | 143 | @test (Set(determine_inlinks(page_F, pages_index)) == Set(["E"])); 144 | 145 | test_page_A = page_dict[page_A.address]; 146 | test_outlinks = find_outlinks(test_page_A, pages_content, only_wikipedia_urls); 147 | 148 | @test ("https://en.wikipedia.org/wiki/TestThing" in test_outlinks); 149 | 150 | @test (!("https://google.com.au" in test_outlinks)); 151 | 152 | pages = Dict(collect((k, page_dict[k]) for k in ("F",))); 153 | pages_two = Dict(collect((k, 
page_dict[k]) for k in ("A", "E"))); 154 | expanded_pages = expand_pages(pages, pages_index); 155 | 156 | @test (all(x in keys(expanded_pages) for x in ("F", "E"))); 157 | 158 | @test (all(!(x in keys(expanded_pages)) for x in ("A", "B", "C", "D"))); 159 | 160 | expanded_pages = expand_pages(pages_two, pages_index); 161 | 162 | @test (all(x in keys(expanded_pages) for x in ("A", "B", "C", "D", "E", "F"))); 163 | 164 | pages = relevant_pages("his dad", pages_index, pages_content); 165 | 166 | @test (all((x in keys(pages)) for x in ("A", "C", "E"))); 167 | 168 | @test (all((!(x in keys(pages))) for x in ("B", "D", "F"))); 169 | 170 | pages = relevant_pages("mom and dad", pages_index, pages_content); 171 | 172 | @test (all((x in keys(pages)) for x in ("A", "B", "C", "D", "E", "F"))); 173 | 174 | pages = relevant_pages("philosophy", pages_index, pages_content); 175 | 176 | @test (all((!(x in keys(pages))) for x in ("A", "B", "C", "D", "E", "F"))); 177 | 178 | normalize_pages(page_dict); 179 | println("pages_index hubs: ", collect(page.hub for page in values(pages_index))); 180 | 181 | expected_hubs = [(1 / sqrt(91)), (2 / sqrt(91)), (3 / sqrt(91)), (4 / sqrt(91)), (5 / sqrt(91)), (6 / sqrt(91))]; 182 | expected_authorities = collect(reverse(expected_hubs)); 183 | 184 | @test (length(expected_hubs) == length(expected_authorities) == length(pages_index)); 185 | 186 | sorted_pages = sort(collect(pages_index), 187 | lt=(function(p1::Pair, p2::Pair) 188 | if (p1.first < p2.first) 189 | return true; 190 | else 191 | return false; 192 | end 193 | end)); 194 | 195 | @test (expected_hubs == collect(page.hub for (address, page) in sorted_pages)); 196 | 197 | @test (expected_authorities == collect(page.authority for (address, page) in sorted_pages)); 198 | 199 | convergence = ConvergenceDetector(); 200 | detect_convergence(convergence, pages_index); 201 | 202 | @test (detect_convergence(convergence, pages_index)); 203 | 204 | new_pages_index = deepcopy(pages_index); 205 | 206 | 
for (address, page) in new_pages_index 207 | page.hub = page.hub + 0.0003; 208 | page.authority = page.authority + 0.0004; 209 | end 210 | 211 | @test (detect_convergence(convergence, new_pages_index)); 212 | 213 | for (address, page) in new_pages_index 214 | page.hub = page.hub + 3000000; 215 | page.authority = page.authority + 3000000; 216 | end 217 | 218 | @test (!detect_convergence(convergence, new_pages_index)); 219 | 220 | @test (sort(get_inlinks(page_dict["A"], pages_index)) == page_dict["A"].inlinks); 221 | 222 | @test (sort(get_outlinks(page_dict["A"], pages_index, pages_content)) == page_dict["A"].outlinks); 223 | 224 | HITS("inherit", pages_index, pages_content); 225 | 226 | authorities = [page_A.authority, page_B.authority, page_C.authority, page_D.authority, page_E.authority, page_F.authority]; 227 | hubs = [page_A.hub, page_B.hub, page_C.hub, page_D.hub, page_E.hub, page_F.hub]; 228 | 229 | @test (reduce(max, authorities) == page_D.authority); 230 | 231 | @test (reduce(max, hubs) == page_E.hub); 232 | 233 | -------------------------------------------------------------------------------- /tests/run_learning_tests.jl: -------------------------------------------------------------------------------- 1 | include("../aimajulia.jl"); 2 | 3 | using Base.Test; 4 | 5 | using aimajulia; 6 | 7 | using aimajulia.utils; 8 | 9 | #The following learning tests are from the aima-python doctests 10 | 11 | @test (repr(euclidean_distance([1, 2], [3, 4])) == "2.8284271247461903"); 12 | 13 | @test (repr(euclidean_distance([1, 2, 3], [4, 5, 6])) == "5.196152422706632"); 14 | 15 | @test (repr(euclidean_distance([0, 0, 0], [0, 0, 0])) == "0.0"); 16 | 17 | @test (root_mean_square_error([2, 2], [2, 2]) == 0); 18 | 19 | @test (root_mean_square_error([0, 0], [0, 1]) == sqrt(0.5)); 20 | 21 | @test (root_mean_square_error([1, 0], [0, 1]) == 1); 22 | 23 | @test (root_mean_square_error([0, 0], [0, -1]) == sqrt(0.5)); 24 | 25 | @test (root_mean_square_error([0, 0.5], [0, -0.5]) == 
sqrt(0.5)); 26 | 27 | @test (manhattan_distance([2, 2], [2, 2]) == 0); 28 | 29 | @test (manhattan_distance([0, 0], [0, 1]) == 1); 30 | 31 | @test (manhattan_distance([1, 0], [0, 1]) == 2); 32 | 33 | @test (manhattan_distance([0, 0], [0, -1]) == 1); 34 | 35 | @test (manhattan_distance([0, 0.5], [0, -0.5]) == 1); 36 | 37 | @test (mean_boolean_error([1, 1], [0, 0]) == 1) 38 | 39 | @test (mean_boolean_error([0, 1], [1, 0]) == 1) 40 | 41 | @test (mean_boolean_error([1, 1], [0, 1]) == 0.5) 42 | 43 | @test (mean_boolean_error([0, 0], [0, 0]) == 0) 44 | 45 | @test (mean_boolean_error([1, 1], [1, 1]) == 0) 46 | 47 | @test (mean_error([2, 2], [2, 2]) == 0); 48 | 49 | @test (mean_error([0, 0], [0, 1]) == 0.5); 50 | 51 | @test (mean_error([1, 0], [0, 1]) == 1); 52 | 53 | @test (mean_error([0, 0], [0, -1]) == 0.5); 54 | 55 | @test (mean_error([0, 0.5], [0, -0.5]) == 0.5); 56 | 57 | @test (gaussian(1,0.5,0.7) == 0.6664492057835993); 58 | 59 | @test (gaussian(5,2,4.5) == 0.19333405840142462); 60 | 61 | @test (gaussian(3,1,3) == 0.3989422804014327); 62 | 63 | iris_dataset = DataSet(name="iris", examples="./aima-data/iris.csv", exclude=[4]); 64 | 65 | @test (iris_dataset.inputs == [1, 2, 3]); 66 | 67 | iris_dataset = DataSet(name="iris", examples="./aima-data/iris.csv"); 68 | means_dict, deviations_dict = find_means_and_deviations(iris_dataset); 69 | 70 | @test (means_dict["setosa"][1] == 5.006); 71 | 72 | @test (means_dict["versicolor"][1] == 5.936); 73 | 74 | @test (means_dict["virginica"][1] == 6.587999999999999); 75 | 76 | @test (deviations_dict["setosa"][1] == 0.3524896872134513); 77 | 78 | @test (deviations_dict["versicolor"][1] == 0.5161711470638634); 79 | 80 | @test (deviations_dict["virginica"][1] == 0.6358795932744321); 81 | 82 | cpd = CountingProbabilityDistribution(); 83 | 84 | for i in 1:10000 85 | add(cpd, rand(RandomDeviceInstance, ["1", "2", "3", "4", "5", "6"])); 86 | end 87 | 88 | probabilities = collect(cpd[n] for n in ("1", "2", "3", "4", "5", "6")); 89 | 90 | 
@test ((1.0/7.0) <= reduce(min, probabilities) <= reduce(max, probabilities) <= (1.0/5.0)); 91 | 92 | zoo_dataset = DataSet(name="zoo", examples="./aima-data/zoo.csv"); 93 | 94 | pl = PluralityLearner(zoo_dataset); 95 | 96 | @test (predict(pl, [1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 4, 1, 0, 1]) == "mammal"); 97 | 98 | iris_dataset = DataSet(name="iris", examples="./aima-data/iris.csv"); 99 | 100 | # Naive Discrete Model 101 | nbdm = NaiveBayesLearner(iris_dataset, continuous=false); 102 | 103 | @test (predict(nbdm, [5, 3, 1, 0.1]) == "setosa"); 104 | 105 | @test (predict(nbdm, [6, 3, 4, 1.1]) == "versicolor"); 106 | 107 | @test (predict(nbdm, [7.7, 3, 6, 2]) == "virginica"); 108 | 109 | # Naive Continuous Model 110 | nbcm = NaiveBayesLearner(iris_dataset, continuous=true); 111 | 112 | @test (predict(nbcm, [5, 3, 1, 0.1]) == "setosa"); 113 | 114 | @test (predict(nbcm, [6, 5, 3, 1.5]) == "versicolor"); 115 | 116 | @test (predict(nbcm, [7, 3, 6.5, 2]) == "virginica"); 117 | 118 | # Naive Conditional Probability Model 119 | d1 = CountingProbabilityDistribution(vcat(fill('a', 50), fill('b', 30), fill('c', 15))); 120 | d2 = CountingProbabilityDistribution(vcat(fill('a', 30), fill('b', 45), fill('c', 20))); 121 | d3 = CountingProbabilityDistribution(vcat(fill('a', 20), fill('b', 20), fill('c', 35))); 122 | nbsm = NaiveBayesLearner(Dict([(("First", 0.5), d1), (("Second", 0.3), d2), (("Third", 0.2), d3)]), simple=true); 123 | 124 | @test (predict(nbsm, "aab") == "First"); 125 | 126 | @test (predict(nbsm, ['b', 'b']) == "Second"); 127 | 128 | @test (predict(nbsm, "ccbcc") == "Third"); 129 | 130 | iris_dataset = DataSet(name="iris", examples="./aima-data/iris.csv"); 131 | 132 | k_nearest_neighbors = NearestNeighborLearner(iris_dataset, 3); 133 | 134 | @test (predict(k_nearest_neighbors, [5, 3, 1, 0.1]) == "setosa"); 135 | 136 | @test (predict(k_nearest_neighbors, [6, 5, 3, 1.5]) == "versicolor"); 137 | 138 | @test (predict(k_nearest_neighbors, [7.5, 4, 6, 2]) == "virginica"); 
"""
    test_rf_predictions(ex1_results::AbstractVector, ex2_results::AbstractVector, ex3_results::AbstractVector)

Train a fresh `RandomForest` on the global `iris_dataset` and record, one `Bool`
per vector, whether the forest classifies each of three reference iris examples
correctly (setosa, versicolor, virginica respectively). Returns `nothing`.
"""
function test_rf_predictions(ex1_results::AbstractVector, ex2_results::AbstractVector, ex3_results::AbstractVector)
    local forest::RandomForest = RandomForest(iris_dataset);
    # Each tuple pairs a result sink with a reference example and its expected class.
    for (sink, features, expected_class) in ((ex1_results, [5, 3, 1, 0.1], "setosa"),
                                             (ex2_results, [6, 5, 3, 1], "versicolor"),
                                             (ex3_results, [7.5, 4, 6, 2], "virginica"))
        push!(sink, (predict(forest, features) == expected_class));
    end
    nothing;
end
failure rate: approximately ", Float64(1000 - setosa_results_count)/10.0, "%"); 189 | println("versicolor assert count (out of 1000): ", versicolor_results_count); 190 | println("versicolor assertion failure rate: approximately ", Float64(1000 - versicolor_results_count)/10.0, "%"); 191 | println("virginica assert count (out of 1000): ", virginica_results_count); 192 | println("virginica assertion failure rate: approximately ", Float64(1000 - virginica_results_count)/10.0, "%"); 193 | println(); 194 | 195 | weights = random_weights(-0.5, 0.5, 10); 196 | 197 | @test (length(weights) == 10); 198 | 199 | @test (all(((weight >= -0.5) && (weight <= 0.5)) for weight in weights)); 200 | 201 | iris_dataset = DataSet(name="iris", examples="./aima-data/iris.csv"); 202 | 203 | # The DataType of the example classification must match the eltype of the classes array. 204 | 205 | classes = map(SubString{String}, ["setosa", "versicolor", "virginica"]); 206 | 207 | classes_to_numbers(iris_dataset, classes); 208 | 209 | nnl = NeuralNetworkLearner(iris_dataset, hidden_layers_sizes=[5], learning_rate=0.15, epochs=75); 210 | 211 | neural_network_learner_score = aimajulia.grade_learner(nnl, 212 | [([5, 3, 1, 0.1], 1), 213 | ([5, 3.5, 1, 0], 1), 214 | ([6, 3, 4, 1.1], 2), 215 | ([6, 2, 3.5, 1], 2), 216 | ([7.5, 4, 6, 2], 3), 217 | ([7, 3, 6, 2.5], 3)]); 218 | 219 | println("neural network learner score (out of 1.0): ", neural_network_learner_score); 220 | 221 | # Allow up to 2 failed tests. 222 | @test (neural_network_learner_score >= (2/3)); 223 | 224 | neural_network_learner_error_ratio = aimajulia.error_ratio(nnl, iris_dataset); 225 | 226 | println("neural network learner error ratio: ", (neural_network_learner_error_ratio * 100), "%"); 227 | println(); 228 | 229 | # NeuralNetworkLearner previously had an error ratio of 0.33333333333333337. 
230 | @test (neural_network_learner_error_ratio < 0.40); 231 | 232 | iris_dataset = DataSet(name="iris", examples="./aima-data/iris.csv"); 233 | 234 | classes_to_numbers(iris_dataset, nothing); 235 | 236 | pl = PerceptronLearner(iris_dataset); 237 | 238 | perceptron_learner_score = aimajulia.grade_learner(pl, 239 | [([5, 3, 1, 0.1], 1), 240 | ([5, 3.5, 1, 0], 1), 241 | ([6, 3, 4, 1.1], 2), 242 | ([6, 2, 3.5, 1], 2), 243 | ([7.5, 4, 6, 2], 3), 244 | ([7, 3, 6, 2.5], 3)]); 245 | 246 | println("perceptron learner score (out of 1.0): ", perceptron_learner_score); 247 | 248 | # Allow up to 3 failed tests. 249 | @test (perceptron_learner_score >= (1/2)); 250 | 251 | perceptron_learner_error_ratio = aimajulia.error_ratio(pl, iris_dataset); 252 | 253 | println("perceptron learner error ratio: ", (perceptron_learner_error_ratio * 100), "%"); 254 | println(); 255 | 256 | @test (perceptron_learner_error_ratio < 0.40); 257 | 258 | @test (weighted_mode("abbaa", [1, 2, 3, 1, 2]) == "b"); 259 | 260 | @test (weighted_mode(["a", "b", "b", "a", "a"], [1, 2, 3, 1, 2]) == "b"); 261 | 262 | @test (weighted_replicate(["A", "B", "C"], [1, 2, 1], 4) == ["A", "B", "B", "C"]); 263 | 264 | -------------------------------------------------------------------------------- /tests/run_probability_tests.jl: -------------------------------------------------------------------------------- 1 | include("../aimajulia.jl"); 2 | 3 | using Base.Test; 4 | 5 | using aimajulia; 6 | 7 | #The following probability tests are from the aima-python doctests 8 | 9 | cpt = variable_node(aimajulia.burglary_network, "Alarm"); 10 | event = Dict([Pair("Burglary", true), Pair("Earthquake", true)]); 11 | 12 | @test probability(cpt, true, event) == 0.95; 13 | 14 | event["Burglary"] = false; 15 | 16 | @test probability(cpt, false, event) == 0.71; 17 | 18 | s = Dict([Pair("A", true), 19 | Pair("B", false), 20 | Pair("C", true), 21 | Pair("D", false)]); 22 | 23 | @test consistent_with(s, Dict()); 24 | 25 | @test 
consistent_with(s, s); 26 | 27 | @test !consistent_with(s, Dict(Pair("A", false))); 28 | 29 | @test !consistent_with(s, Dict(Pair("D", true))); 30 | 31 | p = ProbabilityDistribution(variable_name="Flip"); 32 | p["H"], p["T"] = 0.25, 0.75; 33 | 34 | @test p["H"] == 0.25; 35 | 36 | p = ProbabilityDistribution(variable_name="X", frequencies=Dict([Pair("lo", 125), 37 | Pair("med", 375), 38 | Pair("hi", 500)])); 39 | 40 | @test ((p["lo"], p["med"], p["hi"]) == (0.125, 0.375, 0.5)); 41 | 42 | p = JointProbabilityDistribution(["X", "Y"]); 43 | p[(1,1)] = 0.25; 44 | 45 | @test p[(1,1)] == 0.25; 46 | 47 | p[Dict([Pair("X", 0), Pair("Y", 1)])] = 0.5; 48 | 49 | @test p[Dict([Pair("X", 0), Pair("Y", 1)])] == 0.5; 50 | 51 | @test (event_values(Dict([Pair("A", 10), Pair("B", 9), Pair("C", 8)]), ["C", "A"]) == (8, 10)); 52 | 53 | @test (event_values((1, 2), ["C", "A"]) == (1, 2)); 54 | 55 | p = JointProbabilityDistribution(["X", "Y"]); 56 | p[(0, 0)], p[(0, 1)], p[(1, 1)], p[(2, 1)] = 0.25, 0.5, 0.125, 0.125; 57 | 58 | @test enumerate_joint(["Y"], Dict([Pair("X", 0)]), p) == 0.75; 59 | 60 | @test enumerate_joint(["X"], Dict([Pair("Y", 2)]), p) == 0; 61 | 62 | @test enumerate_joint(["X"], Dict([Pair("Y", 1)]), p) == 0.75; 63 | 64 | @test show_approximation(enumerate_joint_ask("X", Dict([Pair("Y", 1)]), p)) == "0: 0.6667, 1: 0.1667, 2: 0.1667"; 65 | 66 | bn = BayesianNetworkNode("X", "Burglary", Dict([Pair(true, 0.2), Pair(false, 0.625)])); 67 | 68 | @test probability(bn, false, Dict([Pair("Burglary", false), Pair("Earthquake", true)])) == 0.375; 69 | 70 | @test probability(BayesianNetworkNode("W", "", 0.75), false, Dict([Pair("Random", true)])) == 0.25; 71 | 72 | X = BayesianNetworkNode("X", "Burglary", Dict([Pair(true, 0.2), Pair(false, 0.625)])); 73 | 74 | @test (sample(X, Dict([Pair("Burglary", false), Pair("Earthquake", true)])) in (true, false)); 75 | 76 | Z = BayesianNetworkNode("Z", "P Q", Dict([Pair((true, true), 0.2), 77 | Pair((true, false), 0.3), 78 | Pair((false, 
true), 0.5), 79 | Pair((false, false), 0.7)])); 80 | 81 | @test (sample(Z, Dict([Pair("P", true), Pair("Q", false)])) in (true, false)); 82 | 83 | @test show_approximation(enumeration_ask("Burglary", 84 | Dict([Pair("JohnCalls", true), Pair("MaryCalls", true)]), 85 | aimajulia.burglary_network)) == "false: 0.7158, true: 0.2842"; 86 | 87 | @test show_approximation(elimination_ask("Burglary", 88 | Dict([Pair("JohnCalls", true), Pair("MaryCalls", true)]), 89 | aimajulia.burglary_network)) == "false: 0.7158, true: 0.2842"; 90 | 91 | # RandomDevice() does not allow seeding. 92 | 93 | mt_rng = MersenneTwister(21); 94 | 95 | p = rejection_sampling("Earthquake", Dict(), aimajulia.burglary_network, 1000, mt_rng); 96 | 97 | @test ((p[true], p[false]) == (0.002, 0.998)); 98 | 99 | mt_rng = srand(mt_rng, 71); 100 | 101 | p = likelihood_weighting("Earthquake", Dict(), aimajulia.burglary_network, 1000, mt_rng); 102 | 103 | @test ((p[true], p[false]) == (0.0, 1.0)); 104 | 105 | mt_rng = srand(mt_rng, 1017); 106 | 107 | @test (show_approximation(likelihood_weighting("Burglary", 108 | Dict([Pair("JohnCalls", true), Pair("MaryCalls", true)]), 109 | aimajulia.burglary_network, 110 | 10000, 111 | mt_rng)) == "false: 0.718, true: 0.282"); 112 | 113 | umbrella_prior = [0.5, 0.5]; 114 | umbrella_transition = [[0.7, 0.3], [0.3, 0.7]]; 115 | umbrella_sensor = [[0.9, 0.2], [0.1, 0.8]]; 116 | umbrella_hmm = HiddenMarkovModel(umbrella_transition, umbrella_sensor); 117 | umbrella_evidence = [true, true, false, true, true]; # Umbrella observation sequence (Fig. 
15.5b) 118 | 119 | @test (repr(forward_backward(umbrella_hmm, umbrella_evidence, umbrella_prior)) == 120 | "Array{Float64,1}[[0.646936, 0.353064], [0.867339, 0.132661], [0.820419, 0.179581], [0.307484, 0.692516], [0.820419, 0.179581], [0.867339, 0.132661]]"); 121 | 122 | umbrella_evidence = [true, false, true, false, true]; 123 | 124 | @test (repr(forward_backward(umbrella_hmm, umbrella_evidence, umbrella_prior)) == 125 | "Array{Float64,1}[[0.587074, 0.412926], [0.717684, 0.282316], [0.2324, 0.7676], [0.607195, 0.392805], [0.2324, 0.7676], [0.717684, 0.282316]]"); 126 | 127 | umbrella_prior = [0.5, 0.5]; 128 | umbrella_transition = [[0.7, 0.3], [0.3, 0.7]]; 129 | umbrella_sensor = [[0.9, 0.2], [0.1, 0.8]]; 130 | umbrella_hmm = HiddenMarkovModel(umbrella_transition, umbrella_sensor); 131 | umbrella_evidence = [true, false, true, false, true]; 132 | e_t = false; 133 | t = 4; 134 | d = 2; 135 | 136 | @test (repr(fixed_lag_smoothing(e_t, umbrella_hmm, d, umbrella_evidence; t=t)) == "[0.111111, 0.888889]"); 137 | 138 | d = 5; 139 | 140 | @test (fixed_lag_smoothing(e_t, umbrella_hmm, d, umbrella_evidence; t=t) == nothing); 141 | 142 | umbrella_evidence = [true, true, false, true, true]; 143 | e_t = true; 144 | d = 1; 145 | 146 | @test (repr(fixed_lag_smoothing(e_t, umbrella_hmm, d, umbrella_evidence; t=t)) == "[0.993865, 0.00613497]"); 147 | 148 | N = 10; 149 | umbrella_evidence = true; 150 | umbrella_transition = [[0.7, 0.3], [0.3, 0.7]]; 151 | umbrella_sensor = [[0.9, 0.2], [0.1, 0.8]]; 152 | umbrella_hmm = HiddenMarkovModel(umbrella_transition, umbrella_sensor); 153 | s = particle_filtering(umbrella_evidence, N, umbrella_hmm); 154 | 155 | @test length(s) == N; 156 | 157 | @test all(state in ("A", "B") for state in s); 158 | 159 | # Probability Distribution Example (p.493) 160 | p = ProbabilityDistribution(variable_name="Weather"); 161 | p["sunny"] = 0.6; 162 | p["rain"] = 0.1; 163 | p["cloudy"] = 0.29; 164 | p["snow"] = 0.01; 165 | 166 | @test p["rain"] == 0.1; 167 | 
168 | # Joint Probability Distribution Example (Fig. 13.3) 169 | p = JointProbabilityDistribution(["Toothache", "Cavity", "Catch"]); 170 | p[(true, true, true)] = 0.108; 171 | p[(true, true, false)] = 0.012; 172 | p[(false, true, true)] = 0.072; 173 | p[(false, true, false)] = 0.008; 174 | p[(true, false, true)] = 0.016; 175 | p[(true, false, false)] = 0.064; 176 | p[(false, false, true)] = 0.144; 177 | p[(false, false, false)] = 0.576; 178 | 179 | @test p[(true, true, true)] == 0.108; 180 | 181 | # P(Cavity | Toothache) example from page 500 182 | probability_cavity = enumerate_joint_ask("Cavity", Dict([Pair("Toothache", true)]), p); 183 | 184 | @test show_approximation(probability_cavity) == "false: 0.4, true: 0.6"; 185 | 186 | @test (0.6 - 0.001 < probability_cavity[true] < 0.6 + 0.001); 187 | 188 | @test (0.4 - 0.001 < probability_cavity[false] < 0.4 + 0.001); 189 | 190 | # Seed the RNG for Monte Carlo localization Base.Tests 191 | mt_rng = MersenneTwister(sum(Vector{UInt8}("aima-julia"))); 192 | 193 | m = MonteCarloLocalizationMap([0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 1 0; 194 | 0 0 0 0 0 0 0 0 0 0 1 1 1 0 1 1 0; 195 | 1 1 1 1 1 1 1 1 0 0 1 1 1 0 1 1 0; 196 | 0 0 0 0 0 0 0 0 0 0 1 1 1 0 1 1 0; 197 | 0 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 0; 198 | 0 0 0 0 0 0 0 0 0 0 1 1 1 0 1 1 0; 199 | 0 0 0 0 0 0 0 0 0 0 1 1 1 0 1 1 0; 200 | 0 0 1 1 1 1 1 0 0 0 1 1 1 0 1 1 0; 201 | 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 1 0; 202 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0; 203 | 0 0 1 1 1 1 1 0 0 0 1 1 1 0 0 1 0], 204 | rng=mt_rng); 205 | 206 | """ 207 | P_motion_sample(kinematic_state::Tuple, v::Tuple, w::Int64) 208 | 209 | Return a sample from the possible kinematic states (using a single element 210 | probability distribution). 211 | """ 212 | function P_motion_sample(kinematic_state::Tuple, v::Tuple, w::Int64) 213 | local position::Tuple = kinematic_state[1:2]; 214 | local orientation::Int64 = kinematic_state[3]; 215 | 216 | # Rotate the robot. 
"""
    P_sensor(x::Integer, y::Integer)

Return the conditional probability of the range sensor reading `y` given that
the true range is `x` (the sensor noise model for Monte Carlo localization).

The distribution is a discrete spike-and-slab:
- 0.8 for an exact match,
- 0.05 for a near miss within a tolerance of 2,
- 0.0 otherwise.

Arguments are widened from `Int64` to `Integer` so the model also works with
other integer types and on platforms where `Int` is `Int32`; callers passing
`Int64` are unaffected.
"""
function P_sensor(x::Integer, y::Integer)
    if (x == y)
        # Exact reading.
        return 0.8;
    elseif (abs(x - y) <= 2)
        # Near miss within sensor tolerance.
        return 0.05;
    else
        # Readings further than 2 from the true range are impossible under this model.
        return 0.0;
    end
end
10 | 11 | # Base.Test tests for unigram and n-gram of words 12 | flatland = String(read("./aima-data/EN-text/flatland.txt")); 13 | word_sequence = extract_words(flatland); 14 | P1 = UnigramWordModel(word_sequence); 15 | P2 = NgramWordModel(2, word_sequence); 16 | P3 = NgramWordModel(3, word_sequence); 17 | 18 | @test (top(P1, 5) == [(2081,"the"),(1479,"of"),(1021,"and"),(1008,"to"),(850,"a")]); 19 | 20 | @test (top(P2, 5) == [(368,("of","the")),(152,("to","the")),(152,("in","the")),(86,("of","a")),(80,("it","is"))]); 21 | 22 | @test (top(P3, 5) == [(30,("a","straight","line")), 23 | (19,("of","three","dimensions")), 24 | (16,("the","sense","of")), 25 | (13,("by","the","sense")), 26 | (13,("as","well","as"))]); 27 | 28 | @test (abs(P1["the"] - 0.0611) <= (0.001 * max(abs(P1["the"]), abs(0.0611)))); 29 | 30 | @test (abs(P2[("of", "the")] - 0.0108) <= (0.01 * max(abs(P2[("of", "the")]), abs(0.0108)))); 31 | 32 | @test (abs(P3[("so", "as", "to")] - 0.000323) <= (0.001 * max(abs(P3[("so", "as", "to")]), abs(0.000323)))); 33 | 34 | @test (!haskey(P2.conditional_probabilities, ("went",))); 35 | 36 | @test (P3.conditional_probabilities[("in", "order")].dict == Dict([Pair("to", 6)])); 37 | 38 | test_string = "unigram"; 39 | word_sequence = extract_words(test_string); 40 | P1 = UnigramWordModel(word_sequence); 41 | 42 | @test (P1.dict == Dict([Pair("unigram", 1)])); 43 | 44 | test_string = "bigram text"; 45 | word_sequence = extract_words(test_string); 46 | P2 = NgramWordModel(2, word_sequence); 47 | 48 | @test (P2.dict == Dict([Pair(("bigram", "text"), 1)])); 49 | 50 | test_string = "test trigram text here"; 51 | word_sequence = extract_words(test_string); 52 | P3 = NgramWordModel(3, word_sequence); 53 | 54 | @test (haskey(P3.dict, ("test", "trigram", "text"))); 55 | 56 | @test (haskey(P3.dict, ("trigram", "text", "here"))); 57 | 58 | # Base.Test tests for canonicalizing text 59 | @test (extract_words("``EGAD'' Edgar cried.") == ["egad", "edgar", "cried"]); 60 | 61 | @test 
(canonicalize_text("``EGAD'' Edgar cried.") == "egad edgar cried"); 62 | 63 | # Base.Test tests for samples() methods 64 | story = String(read("./aima-data/EN-text/flatland.txt")); 65 | story = story*String(read("./aima-data/gutenberg.txt")); 66 | word_sequence = extract_words(story); 67 | P1 = UnigramWordModel(word_sequence); 68 | P2 = NgramWordModel(2, word_sequence); 69 | P3 = NgramWordModel(3, word_sequence); 70 | 71 | @test (length(split(samples(P1, 10))) == 10); 72 | 73 | @test (length(split(samples(P2, 10))) == 10); 74 | 75 | @test (length(split(samples(P3, 10))) == 10); 76 | 77 | # Base.Test tests for unigram and n-gram of characters/letters 78 | test_string = "test unigram"; 79 | word_sequence = extract_words(test_string); 80 | P1 = UnigramCharModel(word_sequence); 81 | expected_unigrams = Dict([Pair('n',1), 82 | Pair('g', 1), 83 | Pair('t', 2), 84 | Pair('a', 1), 85 | Pair('u', 1), 86 | Pair('i', 1), 87 | Pair('m', 1), 88 | Pair('e', 1), 89 | Pair('s', 1), 90 | Pair('r', 1)]); 91 | 92 | @test (length(P1.dict) == length(expected_unigrams)); 93 | 94 | @test (all(haskey(P1.dict, character) for character in setdiff(Set(test_string), Set(" ")))); 95 | 96 | test_string = "alpha beta"; 97 | word_sequence = extract_words(test_string); 98 | P1 = NgramCharModel(1, word_sequence); 99 | 100 | @test (length(P1.dict) == length(Set(collect(test_string)))); 101 | 102 | @test (all((function(c::Char) 103 | return haskey(P1.dict, (c,)); 104 | end), 105 | Set(collect(test_string)))); 106 | 107 | test_string = "bigram"; 108 | word_sequence = extract_words(test_string); 109 | P2 = NgramCharModel(2, word_sequence); 110 | expected_bigrams = Dict([Pair((' ', 'b'), 1), 111 | Pair(('b', 'i'), 1), 112 | Pair(('i', 'g'), 1), 113 | Pair(('g', 'r'), 1), 114 | Pair(('r', 'a'), 1), 115 | Pair(('a', 'm'), 1)]); 116 | 117 | @test (length(P2.dict) == length(expected_bigrams)); 118 | 119 | @test (all(haskey(P2.dict, key) for key in keys(expected_bigrams))); 120 | 121 | @test 
(all((P2.dict[key] == expected_bigrams[key]) for key in keys(expected_bigrams))); 122 | 123 | test_string = "trigram"; 124 | word_sequence = extract_words(test_string); 125 | P3 = NgramCharModel(3, word_sequence); 126 | expected_trigrams = Dict([Pair((' ', 't', 'r'), 1), 127 | Pair(('t', 'r', 'i'), 1), 128 | Pair(('r', 'i', 'g'), 1), 129 | Pair(('i', 'g', 'r'), 1), 130 | Pair(('g', 'r', 'a'), 1), 131 | Pair(('r', 'a', 'm'), 1)]); 132 | 133 | @test (length(P3.dict) == length(expected_trigrams)); 134 | 135 | @test (all(haskey(P3.dict, key) for key in keys(expected_trigrams))); 136 | 137 | @test (all((P3.dict[key] == expected_trigrams[key]) for key in keys(expected_trigrams))); 138 | 139 | test_string = "trigram trigram trigram"; 140 | word_sequence = extract_words(test_string); 141 | P3 = NgramCharModel(3, word_sequence); 142 | expected_trigrams = Dict([Pair((' ', 't', 'r'), 3), 143 | Pair(('t', 'r', 'i'), 3), 144 | Pair(('r', 'i', 'g'), 3), 145 | Pair(('i', 'g', 'r'), 3), 146 | Pair(('g', 'r', 'a'), 3), 147 | Pair(('r', 'a', 'm'), 3)]); 148 | 149 | @test (length(P3.dict) == length(expected_trigrams)); 150 | 151 | @test (all(haskey(P3.dict, key) for key in keys(expected_trigrams))); 152 | 153 | @test (all((P3.dict[key] == expected_trigrams[key]) for key in keys(expected_trigrams))); 154 | 155 | # Base.Test tests for encoding 156 | @test (shift_encode("This is a secret message.", 17) == "Kyzj zj r jvtivk dvjjrxv."); 157 | 158 | @test (rot13("Hello, world!") == "Uryyb, jbeyq!"); 159 | 160 | @test (reduce(*, map((function(c::Char) 161 | if (c == ' ') 162 | return String(['s', ' ', c]); 163 | else 164 | return String([c]); 165 | end 166 | end), 167 | collect("orange apple lemon "))) == "oranges apples lemons "); 168 | 169 | # Base.Test tests for decoding 170 | flatland = readstring("./aima-data/EN-text/flatland.txt"); 171 | ring = ShiftCipherDecoder(flatland); 172 | 173 | @test (decode_text(ring, "Kyzj zj r jvtivk dvjjrxv.") == "This is a secret message."); 174 | 175 | 
# Compare the ranked retrieval 'results' of the IR system 'irs' against the
# 'expected' (score, url) pairs: the result counts must match, relevance
# scores must agree to four decimal places, and document filenames must
# coincide.
function check_query{T <: AbstractInformationRetrievalSystem}(irs::T, results::AbstractVector, expected::AbstractVector)
    @test (length(results) == length(expected));
    for ((score, id), (expected_score, expected_url)) in zip(results, expected)
        # Compare scores as rounded strings to sidestep floating-point noise.
        @test (@sprintf("%.4f", score) == @sprintf("%.4f", expected_score));
        # Only the filename is compared; the directory prefix may differ per setup.
        @test (basename(irs.documents[id].url) == basename(expected_url));
    end
    nothing;
end
"aima-data/MAN/shred.txt"), 220 | (0.5746, "aima-data/MAN/pico.txt"), 221 | (0.4338, "aima-data/MAN/login.txt"), 222 | (0.4193, "aima-data/MAN/ln.txt")]); 223 | 224 | check_query(uc, 225 | execute_query(uc, "how do I delete a file"), 226 | [(0.7547, "aima-data/MAN/diff.txt"), 227 | (0.6912, "aima-data/MAN/pine.txt"), 228 | (0.6356, "aima-data/MAN/tar.txt"), 229 | (0.6063, "aima-data/MAN/zip.txt"), 230 | (0.5746, "aima-data/MAN/pico.txt"), 231 | (0.5128, "aima-data/MAN/shred.txt"), 232 | (0.2672, "aima-data/MAN/tr.txt")]); 233 | 234 | check_query(uc, 235 | execute_query(uc, "email"), 236 | [(0.1839, "aima-data/MAN/pine.txt"), 237 | (0.1201, "aima-data/MAN/info.txt"), 238 | (0.0989, "aima-data/MAN/pico.txt"), 239 | (0.0873, "aima-data/MAN/grep.txt"), 240 | (0.0807, "aima-data/MAN/zip.txt")]); 241 | 242 | check_query(uc, 243 | execute_query(uc, "word count for files"), 244 | [(1.2815, "aima-data/MAN/grep.txt"), 245 | (0.9420, "aima-data/MAN/find.txt"), 246 | (0.8171, "aima-data/MAN/du.txt"), 247 | (0.5545, "aima-data/MAN/ps.txt"), 248 | (0.5342, "aima-data/MAN/more.txt"), 249 | (0.4200, "aima-data/MAN/dd.txt"), 250 | (0.1285, "aima-data/MAN/who.txt")]); 251 | 252 | if (!is_windows()) # Windows 7/8 does not install a date executable by default 253 | check_query(uc, execute_query(uc, "learn: date"), []); 254 | end 255 | 256 | check_query(uc, 257 | execute_query(uc, "2003"), 258 | [(0.1458, "aima-data/MAN/pine.txt"), 259 | (0.1162, "aima-data/MAN/jar.txt")]); 260 | 261 | -------------------------------------------------------------------------------- /planning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Planning: planning.jl; chapters 10-11\n", 8 | "This notebook describes the [planning.jl](https://github.com/aimacode/aima-julia/blob/master/planning.jl) module, which covers Chapters 10 (Classical Planning) and 11 (Planning and 
Thus, for each precondition, we maintain a separate list of those preconditions that must hold true, and those whose negations must hold true.
their negations would be true).\n", 36 | "\n", 37 | "The constructor parameters, however combine the two precondition arrays into a single `precond` parameter, and the effect arrays into a single `effect` parameter.\n", 38 | "\n", 39 | "`PDDL` is used to represent planning problems in this module. The following attributes are essential to be able to define a problem:\n", 40 | "* a goal test\n", 41 | "* an initial state\n", 42 | "* a set of viable actions that can be executed in the search space of the problem\n", 43 | "\n", 44 | "Now lets try to define a planning problem. Since we already know about the map of Romania, lets see if we can plan a trip across a simplified map of Romania.\n", 45 | "\n", 46 | "Here is our simplified map definition:" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 2, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "knowledge_base = [\n", 56 | " expr(\"Connected(Bucharest,Pitesti)\"),\n", 57 | " expr(\"Connected(Pitesti,Rimnicu)\"),\n", 58 | " expr(\"Connected(Rimnicu,Sibiu)\"),\n", 59 | " expr(\"Connected(Sibiu,Fagaras)\"),\n", 60 | " expr(\"Connected(Fagaras,Bucharest)\"),\n", 61 | " expr(\"Connected(Pitesti,Craiova)\"),\n", 62 | " expr(\"Connected(Craiova,Rimnicu)\"),\n", 63 | "];" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "Let us add some logic propositions to complete our knowledge about travelling around the map. These are the typical symmetry and transitivity properties of connections on a map. We can now be sure that our `knowledge_base` understands what it truly means for two locations to be connected in the sense usually meant by humans when we use the term.\n", 71 | "\n", 72 | "Let's also add our starting location - *Sibiu* to the map." 
73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 3, 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "data": { 82 | "text/plain": [ 83 | "10-element Array{aimajulia.Expression,1}:\n", 84 | " Connected(Bucharest, Pitesti) \n", 85 | " Connected(Pitesti, Rimnicu) \n", 86 | " Connected(Rimnicu, Sibiu) \n", 87 | " Connected(Sibiu, Fagaras) \n", 88 | " Connected(Fagaras, Bucharest) \n", 89 | " Connected(Pitesti, Craiova) \n", 90 | " Connected(Craiova, Rimnicu) \n", 91 | " (Connected(x, y) ==> Connected(y, x)) \n", 92 | " ((Connected(x, y) & Connected(y, z)) ==> Connected(x, z))\n", 93 | " At(Sibiu) " 94 | ] 95 | }, 96 | "execution_count": 3, 97 | "metadata": {}, 98 | "output_type": "execute_result" 99 | } 100 | ], 101 | "source": [ 102 | "for element in [\n", 103 | " expr(\"Connected(x,y) ==> Connected(y,x)\"),\n", 104 | " expr(\"Connected(x,y) & Connected(y,z) ==> Connected(x,z)\"),\n", 105 | " expr(\"At(Sibiu)\")\n", 106 | " ]\n", 107 | " push!(knowledge_base, element);\n", 108 | "end\n", 109 | "knowledge_base" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "We now define possible actions to our problem. We know that we can drive between any connected places. 
    "precond_pos = [expr(\"At(Sibiu)\")];\n",
    "function goal_test(kb::PDDL)\n",
225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 7, 230 | "metadata": {}, 231 | "outputs": [ 232 | { 233 | "data": { 234 | "text/plain": [ 235 | "aimajulia.PDDL(aimajulia.FirstOrderLogicKnowledgeBase(aimajulia.Expression[Connected(Bucharest, Pitesti), Connected(Pitesti, Rimnicu), Connected(Rimnicu, Sibiu), Connected(Sibiu, Fagaras), Connected(Fagaras, Bucharest), Connected(Pitesti, Craiova), Connected(Craiova, Rimnicu), (Connected(x, y) ==> Connected(y, x)), ((Connected(x, y) & Connected(y, z)) ==> Connected(x, z)), At(Sibiu)]), aimajulia.PlanningAction[aimajulia.PlanningAction(\"Fly\", (Sibiu, Bucharest), aimajulia.Expression[ft(Sibiu)], aimajulia.Expression[], aimajulia.Expression[At(Bucharest)], aimajulia.Expression[At(Sibiu)]), aimajulia.PlanningAction(\"Fly\", (Bucharest, Sibiu), aimajulia.Expression[At(Bucharest)], aimajulia.Expression[], aimajulia.Expression[At(Sibiu)], aimajulia.Expression[At(Bucharest)]), aimajulia.PlanningAction(\"Fly\", (Sibiu, Craiova), aimajulia.Expression[At(Sibiu)], aimajulia.Expression[], aimajulia.Expression[At(Craiova)], aimajulia.Expression[At(Sibiu)]), aimajulia.PlanningAction(\"Fly\", (Craiova, Sibiu), aimajulia.Expression[At(Craiova)], aimajulia.Expression[], aimajulia.Expression[At(Sibiu)], aimajulia.Expression[At(Craiova)]), aimajulia.PlanningAction(\"Fly\", (Bucharest, Craiova), aimajulia.Expression[At(Bucharest)], aimajulia.Expression[], aimajulia.Expression[At(Craiova)], aimajulia.Expression[At(Bucharest)]), aimajulia.PlanningAction(\"Fly\", (Craiova, Bucharest), aimajulia.Expression[At(Craiova)], aimajulia.Expression[], aimajulia.Expression[At(Bucharest)], aimajulia.Expression[At(Craiova)]), aimajulia.PlanningAction(\"Drive\", (x, y), aimajulia.Expression[At(x)], aimajulia.Expression[], aimajulia.Expression[At(y)], aimajulia.Expression[At(x)])], aimajulia.goal_test)" 236 | ] 237 | }, 238 | "execution_count": 7, 239 | "metadata": {}, 240 | "output_type": "execute_result" 241 | } 242 | ], 243 | 
"source": [ 244 | "prob = PDDL(knowledge_base, [fly_s_b, fly_b_s, fly_s_c, fly_c_s, fly_b_c, fly_c_b, drive], goal_test)" 245 | ] 246 | } 247 | ], 248 | "metadata": { 249 | "kernelspec": { 250 | "display_name": "Julia 0.6.0", 251 | "language": "julia", 252 | "name": "julia-0.6" 253 | }, 254 | "language_info": { 255 | "file_extension": ".jl", 256 | "mimetype": "application/julia", 257 | "name": "julia", 258 | "version": "0.6.0" 259 | } 260 | }, 261 | "nbformat": 4, 262 | "nbformat_minor": 2 263 | } 264 | -------------------------------------------------------------------------------- /mdp.jl: -------------------------------------------------------------------------------- 1 | 2 | export AbstractMarkovDecisionProcess, MarkovDecisionProcess, 3 | reward, transition_model, actions, 4 | GridMarkovDecisionProcess, go_to, show_grid, to_arrows, 5 | value_iteration, expected_utility, optimal_policy, 6 | policy_evaluation, policy_iteration; 7 | 8 | abstract type AbstractMarkovDecisionProcess end; 9 | 10 | #= 11 | 12 | MarkovDecisionProcess is a MDP implementation of AbstractMarkovDecisionProcess. 13 | 14 | A Markov decision process is a sequential decision problem with fully observable 15 | 16 | and stochastic environment with a transition model and rewards function. 17 | 18 | The discount factor (gamma variable) describes the preference for current rewards 19 | 20 | over future rewards. 
#=
    MarkovDecisionProcess is a MDP implementation of AbstractMarkovDecisionProcess.
    A Markov decision process is a sequential decision problem with fully observable
    and stochastic environment with a transition model and rewards function.
    The discount factor (gamma variable) describes the preference for current rewards
    over future rewards.
=#
struct MarkovDecisionProcess{T} <: AbstractMarkovDecisionProcess
    initial::T                  # starting state
    states::Set{T}              # set of reachable states
    actions::Set{T}             # full action set (filtered per-state by actions())
    terminal_states::Set{T}     # states in which the decision process ends
    transitions::Dict           # transitions[state][action] => [(P(s'|s, a), s'), ...]
    gamma::Float64              # discount factor, must satisfy 0 < gamma <= 1
    reward::Dict                # state => reward mapping

    # Validate 'gamma' and normalize 'states': passing 'nothing' for states
    # yields an empty state set of the element type of 'initial'.
    # NOTE(review): the 'reward' field is always initialized to an empty Dict
    # here, so reward(mdp, state) relies on it being populated afterwards —
    # confirm against callers.
    function MarkovDecisionProcess{T}(initial::T, actions_list::Set{T}, terminal_states::Set{T}, transitions::Dict, states::Union{Void, Set{T}}, gamma::Float64) where T
        if (!(0 < gamma <= 1))
            error("MarkovDecisionProcess(): The gamma variable of an MDP must be between 0 and 1, the constructor was given ", gamma, "!");
        end
        local new_states::Set{typeof(initial)};
        if (typeof(states) <: Set)
            new_states = states;
        else
            new_states = Set{typeof(initial)}();
        end
        return new(initial, new_states, actions_list, terminal_states, transitions, gamma, Dict());
    end
end

# Keyword-friendly outer constructor; 'states' defaults to nothing (empty set)
# and 'gamma' to 0.9.
MarkovDecisionProcess(initial, actions_list::Set, terminal_states::Set, transitions::Dict; states::Union{Void, Set}=nothing, gamma::Float64=0.9) = MarkovDecisionProcess{typeof(initial)}(initial, actions_list, terminal_states, transitions, states, gamma);

"""
    reward{T <: AbstractMarkovDecisionProcess}(mdp::T, state)

Return a reward based on the given 'state'.
Looks the state up in the MDP's reward dictionary; raises a KeyError if the
state has no recorded reward.
"""
function reward{T <: AbstractMarkovDecisionProcess}(mdp::T, state)
    return mdp.reward[state];
end
#=
    GridMarkovDecisionProcess is a two-dimensional environment MDP implementation
    of AbstractMarkovDecisionProcess. Obstacles in the environment are represented
    by a null.
=#
struct GridMarkovDecisionProcess <: AbstractMarkovDecisionProcess
    initial::Tuple{Int64, Int64}                # starting cell as (row, column)
    states::Set{Tuple{Int64, Int64}}            # all non-obstacle cells
    actions::Set{Tuple{Int64, Int64}}           # the four orientation vectors
    terminal_states::Set{Tuple{Int64, Int64}}   # cells where the process ends
    grid::Array{Nullable{Float64}, 2}           # per-cell rewards; null marks an obstacle
    gamma::Float64                              # discount factor in (0, 1]
    reward::Dict                                # (i, j) => Nullable{Float64} reward

    function GridMarkovDecisionProcess(initial::Tuple{Int64, Int64}, terminal_states::Set{Tuple{Int64, Int64}}, grid::Array{Nullable{Float64}, 2}; states::Union{Void, Set{Tuple{Int64, Int64}}}=nothing, gamma::Float64=0.9)
        if (!(0 < gamma <= 1))
            error("GridMarkovDecisionProcess(): The gamma variable of an MDP must be between 0 and 1, the constructor was given ", gamma, "!");
        end
        local new_states::Set{Tuple{Int64, Int64}};
        if (typeof(states) <: Set)
            new_states = states;
        else
            new_states = Set{Tuple{Int64, Int64}}();
        end
        # The four cardinal movement directions: down, right, up, left.
        local orientations::Set = Set{Tuple{Int64, Int64}}([(1, 0), (0, 1), (-1, 0), (0, -1)]);
        local reward::Dict = Dict();
        # Record a reward for every cell; only non-null (non-obstacle) cells
        # become states.
        for i in 1:size(grid, 1)
            # Fixed: was 'getindex(size(grid, 2))' — a misplaced closing
            # parenthesis that applied getindex to the column count with no
            # index instead of selecting the 2nd element of size(grid).
            # Use the dimension-selecting form directly.
            for j in 1:size(grid, 2)
                reward[(i, j)] = grid[i, j];
                if (!isnull(grid[i, j]))
                    push!(new_states, (i, j));
                end
            end
        end
        return new(initial, new_states, orientations, terminal_states, grid, gamma, reward);
    end
end
# (0, 1) will move the agent rightward.
# (-1, 0) will move the agent upward.
# (0, -1) will move the agent leftward.
# (1, 0) will move the agent downward.
function to_arrows(gmdp::GridMarkovDecisionProcess, policy::Dict)
    # Map each movement direction to a display glyph; 'nothing' marks a
    # terminal state where no action is taken.
    local glyphs::Dict = Dict([Pair((0, 1), ">"),
                               Pair((-1, 0), "^"),
                               Pair((0, -1), "<"),
                               Pair((1, 0), "v"),
                               Pair(nothing, ".")]);
    # Translate the policy into a (state => glyph) mapping and render it on
    # the environment's grid.
    local cell_labels::Dict = Dict();
    for (state, action) in policy
        cell_labels[state] = glyphs[action];
    end
    return show_grid(gmdp, cell_labels);
end
function value_iteration{T <: AbstractMarkovDecisionProcess}(mdp::T; epsilon::Float64=0.001)
    # Utilities from the current sweep; copied into U at the top of each pass.
    local U_prime::Dict = Dict(collect(Pair(state, 0.0) for state in mdp.states));
    while (true)
        local U::Dict = copy(U_prime);
        local delta::Float64 = 0.0;
        for state in mdp.states
            # Bellman update: immediate reward plus the discounted best
            # expected utility over all available actions.
            U_prime[state] = (reward(mdp, state)
                            + (mdp.gamma
                            * max((sum(collect(p * U[state_prime]
                                                for (p, state_prime) in transition_model(mdp, state, action)))
                                    for action in actions(mdp, state))...)));
            delta = max(delta, abs(U_prime[state] - U[state]));
        end
        # Convergence test from Fig. 17.4: stop when the largest change is
        # below epsilon * (1 - gamma) / gamma.
        if (delta < ((epsilon * (1 - mdp.gamma))/mdp.gamma))
            return U;
        end
    end
end

"""
    value_iteration(gmdp::GridMarkovDecisionProcess; epsilon::Float64=0.001)

Grid-specialized value iteration: identical to the generic method except that
rewards are stored as Nullable{Float64} and must be unwrapped with get().
"""
function value_iteration(gmdp::GridMarkovDecisionProcess; epsilon::Float64=0.001)
    local U_prime::Dict = Dict(collect(Pair(state, 0.0) for state in gmdp.states));
    while (true)
        local U::Dict = copy(U_prime);
        local delta::Float64 = 0.0;
        for state in gmdp.states
            # Extract Float64 from Nullable{Float64}
            U_prime[state] = (get(reward(gmdp, state))
                            + (gmdp.gamma
                            * max((sum(collect(p * U[state_prime]
                                                for (p, state_prime) in transition_model(gmdp, state, action)))
                                    for action in actions(gmdp, state))...)));
            delta = max(delta, abs(U_prime[state] - U[state]));
        end
        if (delta < ((epsilon * (1 - gmdp.gamma))/gmdp.gamma))
            return U;
        end
    end
end

"""
    expected_utility{T <: AbstractMarkovDecisionProcess}(mdp::T, U::Dict, state::Tuple{Int64, Int64}, action::Tuple{Int64, Int64})

Return the expected utility of taking 'action' in 'state': the
probability-weighted sum of the utilities 'U' of the successor states.
"""
function expected_utility{T <: AbstractMarkovDecisionProcess}(mdp::T, U::Dict, state::Tuple{Int64, Int64}, action::Tuple{Int64, Int64})
    return sum((p * U[state_prime] for (p, state_prime) in transition_model(mdp, state, action)));
end

# Method for the no-op action (Void) used in terminal states.
function expected_utility{T <: AbstractMarkovDecisionProcess}(mdp::T, U::Dict, state::Tuple{Int64, Int64}, action::Void)
    return sum((p * U[state_prime] for (p, state_prime) in transition_model(mdp, state, action)));
end
"""
    optimal_policy{T <: AbstractMarkovDecisionProcess}(mdp::T, U::Dict)

Return the optimal_policy 'π*(s)' (Equation 17.4) given the Markov decision process 'mdp'
and the utility function 'U'.
"""
function optimal_policy{T <: AbstractMarkovDecisionProcess}(mdp::T, U::Dict)
    local pi::Dict = Dict();
    for state in mdp.states
        # Choose, per state, the action maximizing the expected utility of
        # its successor states under 'U'.
        local score = (function(action::Union{Void, Tuple{Int64, Int64}})
                            return expected_utility(mdp, U, state, action);
                        end);
        pi[state] = argmax(collect(actions(mdp, state)), score);
    end
    return pi;
end

"""
    policy_evaluation{T <: AbstractMarkovDecisionProcess}(pi::Dict, U::Dict, mdp::T; k::Int64=20)

Return the updated utilities of the MDP's states by applying the modified policy iteration
algorithm on the given Markov decision process 'mdp', utility function 'U', policy 'pi',
and number of Bellman updates to use 'k'.
"""
function policy_evaluation{T <: AbstractMarkovDecisionProcess}(pi::Dict, U::Dict, mdp::T; k::Int64=20)
    for sweep in 1:k
        for state in mdp.states
            # Simplified Bellman update: the action is fixed by pi[state].
            local expectation = sum((p * U[state_prime]
                                     for (p, state_prime) in transition_model(mdp, state, pi[state])));
            U[state] = reward(mdp, state) + (mdp.gamma * expectation);
        end
    end
    return U;
end

# Grid-specialized variant: rewards are Nullable{Float64} and must be
# unwrapped with get() before use.
function policy_evaluation(pi::Dict, U::Dict, mdp::GridMarkovDecisionProcess; k::Int64=20)
    for sweep in 1:k
        for state in mdp.states
            local expectation = sum((p * U[state_prime]
                                     for (p, state_prime) in transition_model(mdp, state, pi[state])));
            U[state] = get(reward(mdp, state)) + (mdp.gamma * expectation);
        end
    end
    return U;
end
"""
    policy_iteration{T <: AbstractMarkovDecisionProcess}(mdp::T)

Return a policy using the policy iteration algorithm (Fig. 17.7) given the Markov decision process 'mdp'.
"""
function policy_iteration{T <: AbstractMarkovDecisionProcess}(mdp::T)
    # Start from zero utilities and a uniformly random initial policy.
    local U::Dict = Dict(collect(Pair(state, 0.0) for state in mdp.states));
    local pi::Dict = Dict(collect(Pair(state, rand(RandomDeviceInstance, collect(actions(mdp, state))))
                                for state in mdp.states));
    while (true)
        # Evaluate the current policy, then greedily improve it.
        U = policy_evaluation(pi, U, mdp);
        local stable::Bool = true;
        for state in mdp.states
            local best_action = argmax(collect(actions(mdp, state)),
                                       (function(action::Union{Void, Tuple{Int64, Int64}})
                                            return expected_utility(mdp, U, state, action);
                                        end));
            if (best_action != pi[state])
                pi[state] = best_action;
                stable = false;
            end
        end
        # A full pass with no policy change means convergence.
        if (stable)
            return pi;
        end
    end
end
z | A")) == [Expression("A")]; 36 | 37 | @test proposition_symbols(expr("(x & B(z)) ==> Farmer(y) | A")) == [Expression("A"), expr("Farmer(y)"), expr("B(z)")]; 38 | 39 | @test tt_true("(P ==> Q) <=> (~P | Q)") == true; 40 | 41 | @test typeof(pl_true(Expression("P"))) <: Void; 42 | 43 | @test typeof(pl_true(expr("P | P"))) <: Void; 44 | 45 | @test pl_true(expr("P | Q"), model=Dict([Pair(Expression("P"), true)])) == true; 46 | 47 | @test pl_true(expr("(A | B) & (C | D)"), 48 | model=Dict([Pair(Expression("A"), false), 49 | Pair(Expression("B"), true), 50 | Pair(Expression("C"), true)])) == true; 51 | 52 | @test pl_true(expr("(A & B) & (C | D)"), 53 | model=Dict([Pair(Expression("A"), false), 54 | Pair(Expression("B"), true), 55 | Pair(Expression("C"), true)])) == false; 56 | 57 | @test pl_true(expr("(A & B) | (A & C)"), 58 | model=Dict([Pair(Expression("A"), false), 59 | Pair(Expression("B"), true), 60 | Pair(Expression("C"), true)])) == false; 61 | 62 | @test typeof(pl_true(expr("(A | B) & (C | D)"), 63 | model=Dict([Pair(Expression("A"), true), 64 | Pair(Expression("D"), false)]))) <: Void; 65 | 66 | @test pl_true(Expression("P"), model=Dict([Pair(Expression("P"), false)])) == false; 67 | 68 | @test typeof(pl_true(expr("P | ~P"))) <: Void; 69 | 70 | @test eliminate_implications(expr("A ==> (~B <== C)")) == expr("(~B | ~C) | ~A"); 71 | 72 | @test eliminate_implications(expr("A ^ B")) == expr("(A & ~B) | (~A & B)"); 73 | 74 | @test move_not_inwards(expr("~(A | B)")) == expr("~A & ~B"); 75 | 76 | @test move_not_inwards(expr("~(A & B)")) == expr("~A | ~B"); 77 | 78 | @test move_not_inwards(expr("~(~(A | ~B) | ~(~C))")) == expr("(A | ~B) & ~C"); 79 | 80 | @test distribute_and_over_or(expr("(A & B) | C")) == expr("(A | C) & (B | C)"); 81 | 82 | @test associate("&", (expr("A & B"), expr("B | C"), expr("B & C"))) == expr("&(A, B, (B | C), B, C)"); 83 | 84 | @test associate("|", (expr("A | (B | (C | (A & B)))"),)) == expr("|(A, B, C, (A & B))"); 85 | 86 | @test 
conjuncts(expr("A & B")) == [Expression("A"), Expression("B")]; 87 | 88 | @test conjuncts(expr("A | B")) == [expr("A | B")]; 89 | 90 | @test disjuncts(expr("A | B")) == [Expression("A"), Expression("B")]; 91 | 92 | @test disjuncts(expr("A & B")) == [expr("A & B")]; 93 | 94 | @test repr(to_conjunctive_normal_form(Expression("&", 95 | aimajulia.wumpus_world_inference, 96 | Expression("~", expr("~P12"))))) == 97 | "((~(P12) | B11) & (~(P21) | B11) & (P12 | P21 | ~(B11)) & ~(B11) & P12)"; 98 | 99 | @test to_conjunctive_normal_form(expr("~(B | C)")) == expr("~B & ~C"); 100 | 101 | @test repr(to_conjunctive_normal_form(expr("~(B | C)"))) == "(~(B) & ~(C))"; 102 | 103 | @test to_conjunctive_normal_form(expr("(P & Q) | (~P & ~Q)")) == expr("&((~P | P), (~Q | P), (~P | Q), (~Q | Q))"); 104 | 105 | @test repr(to_conjunctive_normal_form(expr("(P & Q) | (~P & ~Q)"))) == "((~(P) | P) & (~(Q) | P) & (~(P) | Q) & (~(Q) | Q))"; 106 | 107 | @test to_conjunctive_normal_form(expr("B <=> (P1 | P2)")) == expr("&((~P1 | B), (~P2 | B), |(P1, P2, ~B))"); 108 | 109 | @test repr(to_conjunctive_normal_form(expr("B <=> (P1 | P2)"))) == "((~(P1) | B) & (~(P2) | B) & (P1 | P2 | ~(B)))"; 110 | 111 | @test to_conjunctive_normal_form(expr("a | (b & c) | d")) == expr("|(b, a, d) & |(c, a, d)"); 112 | 113 | @test repr(to_conjunctive_normal_form(expr("a | (b & c) | d"))) == "((b | a | d) & (c | a | d))"; 114 | 115 | @test to_conjunctive_normal_form(expr("A & (B | (D & E))")) == expr("&(A, (D | B), (E | B))"); 116 | 117 | @test repr(to_conjunctive_normal_form(expr("A & (B | (D & E))"))) == "(A & (D | B) & (E | B))"; 118 | 119 | @test to_conjunctive_normal_form(expr("A | (B | (C | (D & E)))")) == expr("|(D, A, B, C) & |(E, A, B, C)"); 120 | 121 | @test repr(to_conjunctive_normal_form(expr("A | (B | (C | (D & E)))"))) == "((D | A | B | C) & (E | A | B | C))"; 122 | 123 | prop_kb = PropositionalKnowledgeBase(); 124 | 125 | @test count((function(item) 126 | if (typeof(item) <: Bool) 127 | return item; 128 
| else 129 | return true; 130 | end 131 | end), collect(ask(prop_kb, e) for e in map(expr, ["A", "C", "D", "E", "Q"]))) == 0; 132 | 133 | tell(prop_kb, expr("A & E")); 134 | 135 | @test ask(prop_kb, expr("A")) == Dict([]); 136 | 137 | @test ask(prop_kb, expr("E")) == Dict([]); 138 | 139 | tell(prop_kb, expr("E ==> C")); 140 | 141 | @test ask(prop_kb, expr("C")) == Dict([]); 142 | 143 | retract(prop_kb, expr("E")); 144 | 145 | @test ask(prop_kb, expr("E")) == false; 146 | 147 | @test ask(prop_kb, expr("C")) == false; 148 | 149 | plr_results = pl_resolve(to_conjunctive_normal_form(expr("A | B | C")), 150 | to_conjunctive_normal_form(expr("~B | ~C | F"))); 151 | 152 | @test pretty_set(Set{Expression}(disjuncts(plr_results[1]))) == "Set(aimajulia.Expression[A, C, F, ~(C)])"; 153 | 154 | @test pretty_set(Set{Expression}(disjuncts(plr_results[2]))) == "Set(aimajulia.Expression[A, B, F, ~(B)])"; 155 | 156 | # Use PropositionalKnowledgeBase to represent the Wumpus World (Fig. 7.4) 157 | 158 | kb_wumpus = PropositionalKnowledgeBase(); 159 | tell(kb_wumpus, expr("~P11")); 160 | tell(kb_wumpus, expr("B11 <=> (P12 | P21)")); 161 | tell(kb_wumpus, expr("B21 <=> (P11 | P22 | P31)")); 162 | tell(kb_wumpus, expr("~B11")); 163 | tell(kb_wumpus, expr("B21")); 164 | 165 | # Can't find a pit at location (1, 1). 166 | @test ask(kb_wumpus, expr("~P11")) == Dict([]); 167 | 168 | # Can't find a pit at location (1, 2). 169 | @test ask(kb_wumpus, expr("~P12")) == Dict([]); 170 | 171 | # Found pit at location (2, 2). 172 | @test ask(kb_wumpus, expr("P22")) == false; 173 | 174 | # Found pit at location (3, 1). 175 | @test ask(kb_wumpus, expr("P31")) == false; 176 | 177 | # Locations (1, 2) and (2, 1) do not contain pits. 178 | @test ask(kb_wumpus, expr("~P12 & ~P21")) == Dict([]); 179 | 180 | # Found a pit in either (3, 1) or (2,2). 
181 | @test ask(kb_wumpus, expr("P22 | P31")) == Dict([]); 182 | 183 | @test pl_fc_entails(aimajulia.horn_clauses_kb, Expression("Q")) == true; 184 | 185 | @test pl_fc_entails(aimajulia.horn_clauses_kb, Expression("SomethingSilly")) == false; 186 | 187 | @test inspect_literal(Expression("P")) == (Expression("P"), true); 188 | 189 | @test inspect_literal(Expression("~", Expression("P"))) == (Expression("P"), false); 190 | 191 | @test unit_clause_assign(expr("A | B | C"), Dict([Pair(Expression("A"), true)])) == (nothing, nothing); 192 | 193 | @test unit_clause_assign(expr("B | ~C"), Dict([Pair(Expression("A"), true)])) == (nothing, nothing); 194 | 195 | @test unit_clause_assign(expr("B | C"), Dict([Pair(Expression("A"), true)])) == (nothing, nothing); 196 | 197 | @test unit_clause_assign(expr("~A | ~B"), Dict([Pair(Expression("A"), true)])) == (Expression("B"), false); 198 | 199 | @test unit_clause_assign(expr("B | ~A"), Dict([Pair(Expression("A"), true)])) == (Expression("B"), true); 200 | 201 | @test find_unit_clause(map(expr, ["A | B | C", "B | ~C", "~A | ~B"]), Dict([Pair(Expression("A"), true)])) == (Expression("B"), false); 202 | 203 | @test find_pure_symbol(map(expr, ["A", "B", "C"]), map(expr, ["A | ~B", "~B | ~C", "C | A"])) == (Expression("A"), true); 204 | 205 | @test find_pure_symbol(map(expr, ["A", "B", "C"]), map(expr, ["~A | ~B", "~B | ~C", "C | A"])) == (Expression("B"), false); 206 | 207 | @test find_pure_symbol(map(expr, ["A", "B", "C"]), map(expr, ["~A | B", "~B | ~C", "C | A"])) == (nothing, nothing); 208 | 209 | @test dpll_satisfiable(expr("A & ~B")) == Dict([Pair(Expression("A"), true), 210 | Pair(Expression("B"), false),]); 211 | 212 | @test dpll_satisfiable(expr("P & ~P")) == false; 213 | 214 | @test (dpll_satisfiable(expr("A & ~B & C & (A | ~D) & (~E | ~D) & (C | ~D) & (~A | ~F) & (E | ~F) & (~D | ~F) & (B | ~C | D) & (A | ~E | F) & (~A | E | D)")) 215 | == Dict([Pair(Expression("A"), true), 216 | Pair(Expression("B"), false), 217 | 
"""
    walksat_test(clauses::Array{Expression, 1}; solutions::Dict=Dict())

Run walksat() on 'clauses' and, if a model was found, check with @test that
every clause is satisfied by it. When a non-empty 'solutions' model is given,
additionally check that 'solutions' satisfies the clauses and that it matches
the model returned by walksat().
"""
function walksat_test(clauses::Array{Expression, 1}; solutions::Dict=Dict())
    local model = walksat(clauses);
    if (typeof(model) <: Void)
        # walksat() did not find a satisfying assignment; nothing to check.
        return nothing;
    end
    # Every clause must evaluate to true under the model found by walksat().
    @test all(collect(pl_true(clause, model=model) for clause in clauses));
    if (!isempty(solutions))
        @test all(collect(pl_true(clause, model=solutions) for clause in clauses));
        @test model == solutions;
    end
    return nothing;
end
(1, 0)), Pair("Down", (1, 0))])), 263 | Pair((1, 1), Dict([Pair("Left", (1, 0)), Pair("Up", (0, 1))]))]); 264 | 265 | @test sat_plan((0, 0), transition, (1, 1), 4) == ["Right", "Down"]; 266 | 267 | @test unify(expr("x + y"), expr("y + C"), Dict([])) == Dict([Pair(Expression("x"), Expression("y")), 268 | Pair(Expression("y"), Expression("C"))]); 269 | 270 | @test unify(expr("x"), expr("3"), Dict([])) == Dict([Pair(Expression("x"), Expression("3"))]); 271 | 272 | @test unify(expr("x"), expr("x"), Dict([])) == Dict([]); 273 | 274 | @test extend(Dict([Pair(Expression("x"), 1)]), Expression("y"), 2) == Dict([Pair(Expression("x"), 1), 275 | Pair(Expression("y"), 2)]); 276 | 277 | @test repr(substitute(Dict([Pair(Expression("x"), Expression("42")), 278 | Pair(Expression("y"), Expression("0"))]), 279 | expr("F(x) + y"))) == "(F(42) + 0)"; 280 | 281 | function fol_bc_ask_query(q::Expression; kb::Union{Void, AbstractKnowledgeBase}=nothing) 282 | local answers::Tuple; 283 | if (typeof(kb) <: Void) 284 | answers = fol_bc_ask(aimajulia.test_fol_kb, q); 285 | else 286 | answers = fol_bc_ask(kb, q); 287 | end 288 | local test_vars = variables(q); 289 | return sort(collect(Dict(collect(Pair(k, v) for (k, v) in answer if (k in test_vars))) for answer in answers), 290 | lt=(function(d1::Dict, d2::Dict) 291 | return isless(repr(d1), repr(d2)); 292 | end)); 293 | end 294 | 295 | @test fol_bc_ask_query(expr("Farmer(x)")) == [Dict([Pair(Expression("x"), Expression("Mac"))])]; 296 | 297 | @test fol_bc_ask_query(expr("Human(x)")) == [Dict([Pair(Expression("x"), Expression("Mac"))]), 298 | Dict([Pair(Expression("x"), Expression("MrsMac"))])]; 299 | 300 | @test fol_bc_ask_query(expr("Rabbit(x)")) == [Dict([Pair(Expression("x"), Expression("MrsRabbit"))]), 301 | Dict([Pair(Expression("x"), Expression("Pete"))])]; 302 | 303 | @test fol_bc_ask_query(expr("Criminal(x)"), kb=aimajulia.crime_kb) == [Dict([Pair(Expression("x"), Expression("West"))])]; 304 | 305 | @test differentiate(expr("x * x"), 
expr("x")) == expr("(x * 1) + (x * 1)"); 306 | 307 | @test differentiate_simplify(expr("x * x"), expr("x")) == expr("2 * x"); 308 | 309 | @test (constant_symbols(expr("x & y & z | A")) == Set([expr("A")])); 310 | 311 | @test (constant_symbols(expr("(x & B(z)) & Father(John) ==> Farmer(y) | A")) == Set(map(expr, ["A", "John"]))); 312 | 313 | @test (predicate_symbols(expr("x & y & z | A")) == Set()); 314 | 315 | @test (predicate_symbols(expr("(x & B(z)) & Father(John) ==> Farmer(y) | A")) == Set([("B", 1), ("Father", 1), ("Farmer", 1)])); 316 | 317 | @test (predicate_symbols(expr("(x & B(x, y, z)) & F(G(x, y), x) ==> P(Q(R(x, y)), x, y, z)")) == Set([("G", 2), ("P", 4), ("R", 2), ("B", 3), ("Q", 1), ("F", 2)])); 318 | 319 | -------------------------------------------------------------------------------- /rl.jl: -------------------------------------------------------------------------------- 1 | 2 | export PassiveADPAgentMDP, PassiveADPAgentProgram, 3 | PassiveTDAgentProgram, QLearningAgentProgram; 4 | 5 | #= 6 | 7 | PassiveADPAgentMDP is a MDP implementation of AbstractMarkovDecisionProcess 8 | 9 | that consists of a MarkovDecisionProcess 'mdp'. 10 | 11 | =# 12 | struct PassiveADPAgentMDP{T} <: AbstractMarkovDecisionProcess 13 | mdp::MarkovDecisionProcess{T} 14 | 15 | 16 | function PassiveADPAgentMDP{T}(initial::T, actions_list::Set{T}, terminal_states::Set{T}, gamma::Float64) where T 17 | return new(MarkovDecisionProcess(initial, actions_list, terminal_states, Dict(), gamma=gamma)); 18 | end 19 | end 20 | 21 | PassiveADPAgentMDP(initial, actions_list::Set, terminal_states::Set, gamma::Float64) = PassiveADPAgentMDP{typeof(initial)}(initial, actions_list, terminal_states, gamma); 22 | 23 | """ 24 | reward(mdp::PassiveADPAgentMDP, state) 25 | 26 | Return a reward based on the given 'state'. 
"""
    reward(mdp::PassiveADPAgentMDP, state)

Return the reward recorded for the given 'state'.
"""
function reward(mdp::PassiveADPAgentMDP, state)
    return mdp.mdp.reward[state];
end

"""
    transition_model(mdp::PassiveADPAgentMDP, state, action)

Return a list of (P(s'|s, a), s') pairs for taking 'action' in 'state'.
"""
function transition_model(mdp::PassiveADPAgentMDP, state, action)
    # get!() deliberately inserts an empty Dict for unseen (state, action)
    # keys so later learning updates can fill the entry in.
    return collect((v, k) for (k, v) in get!(mdp.mdp.transitions, (state, action), Dict()));
end

"""
    actions(mdp::PassiveADPAgentMDP, state)

Return the set of actions possible in 'state'; a terminal state admits
only the 'nothing' action.
"""
function actions(mdp::PassiveADPAgentMDP, state)
    if (state in mdp.mdp.terminal_states)
        return Set{Void}([nothing]);
    else
        return mdp.mdp.actions;
    end
end

"""
    policy_evaluation(pi::Dict, U::Dict, mdp::PassiveADPAgentMDP; k::Int64=20)

Return updated utilities for the states of 'mdp' by applying 'k' rounds of
simplified Bellman updates under the fixed policy 'pi' (modified policy
iteration).
"""
function policy_evaluation(pi::Dict, U::Dict, mdp::PassiveADPAgentMDP; k::Int64=20)
    for i in 1:k
        for state in mdp.mdp.states
            local outcomes = transition_model(mdp, state, pi[state]);
            if (length(outcomes) != 0)
                U[state] = (reward(mdp, state)
                            + (mdp.mdp.gamma
                               * sum((p * U[state_prime] for (p, state_prime) in outcomes))));
            else
                # No recorded outcomes yet: the expected future utility term is zero.
                U[state] = (reward(mdp, state) + (mdp.mdp.gamma * 0));
            end
        end
    end
    return U;
end
81 | 82 | =# 83 | mutable struct PassiveADPAgentProgram <: AgentProgram 84 | state::Nullable 85 | action::Nullable 86 | U::Dict 87 | pi::Dict 88 | mdp::PassiveADPAgentMDP 89 | N_sa::Dict 90 | N_s_prime_sa::Dict 91 | 92 | function PassiveADPAgentProgram{T <: AbstractMarkovDecisionProcess}(pi::Dict, mdp::T) 93 | return new(Nullable(), 94 | Nullable(), 95 | Dict(), 96 | pi, 97 | PassiveADPAgentMDP(mdp.initial, mdp.actions, mdp.terminal_states, mdp.gamma), 98 | Dict(), 99 | Dict()); 100 | end 101 | end 102 | 103 | function execute(padpap::PassiveADPAgentProgram, percept::Tuple{Any, Any}) 104 | local r_prime::Float64; 105 | s_prime, r_prime = percept; 106 | 107 | push!(padpap.mdp.mdp.states, s_prime); 108 | if (!haskey(padpap.mdp.mdp.reward, s_prime)) 109 | padpap.U[s_prime] = r_prime; 110 | padpap.mdp.mdp.reward[s_prime] = r_prime; 111 | end 112 | if (!isnull(padpap.state)) 113 | padpap.N_sa[(get(padpap.state), get(padpap.action))] = get!(padpap.N_sa, (get(padpap.state), get(padpap.action)), 0) + 1; 114 | padpap.N_s_prime_sa[(s_prime, get(padpap.state), get(padpap.action))] = get!(padpap.N_s_prime_sa, (s_prime, get(padpap.state), get(padpap.action)), 0) + 1; 115 | for t in collect(result_state 116 | for ((result_state, state, action), occurrences) in padpap.N_s_prime_sa 117 | if (((state, action) == (get(padpap.state), get(padpap.action))) && (occurrences != 0))) 118 | get!(padpap.mdp.mdp.transitions, (get(padpap.state), get(padpap.action)), Dict())[t] = padpap.N_s_prime_sa[(t, get(padpap.state), get(padpap.action))] / padpap.N_sa[(get(padpap.state), get(padpap.action))]; 119 | end 120 | end 121 | local U::Dict = policy_evaluation(padpap.pi, padpap.U, padpap.mdp); 122 | if (s_prime in padpap.mdp.mdp.terminal_states) 123 | padpap.state = Nullable(); 124 | padpap.action = Nullable(); 125 | else 126 | padpap.state = Nullable(s_prime); 127 | padpap.action = Nullable(padpap.pi[s_prime]); 128 | end 129 | if (isnull(padpap.action)) 130 | return nothing; 131 | else 132 | 
return get(padpap.action); 133 | end 134 | end 135 | 136 | function update_state(padpap::PassiveADPAgentProgram, percept::Tuple{Any, Any}) 137 | return percept; 138 | end 139 | 140 | #= 141 | 142 | PassiveTDAgentProgram is a passive reinforcement learning agent that learns 143 | 144 | utility estimates by using temporal differences (Fig. 21.4). 145 | 146 | =# 147 | mutable struct PassiveTDAgentProgram <: AgentProgram 148 | state::Nullable 149 | action::Nullable 150 | reward::Nullable 151 | gamma::Float64 152 | U::Dict 153 | pi::Dict 154 | N_s::Dict 155 | terminal_states::Set 156 | alpha::Function 157 | 158 | function PassiveTDAgentProgram{T <: AbstractMarkovDecisionProcess}(pi::Dict, mdp::T; alpha::Union{Void, Function}=nothing) 159 | local gamma::Float64; 160 | local terminal_states::Set; 161 | local new_alpha::Function; 162 | if (typeof(mdp) <: PassiveADPAgentMDP) 163 | gamma = mdp.mdp.gamma; 164 | terminal_states = mdp.mdp.terminal_states; 165 | else 166 | gamma = mdp.gamma; 167 | terminal_states = mdp.terminal_states; 168 | end 169 | if (typeof(alpha) <: Void) 170 | new_alpha = (function(n::Number) 171 | return (1/(n + 1)); 172 | end); 173 | else 174 | new_alpha = alpha; 175 | end 176 | return new(Nullable(), 177 | Nullable(), 178 | Nullable(), 179 | gamma, 180 | Dict(), 181 | pi, 182 | Dict(), 183 | terminal_states, 184 | new_alpha); 185 | end 186 | end 187 | 188 | function execute(ptdap::PassiveTDAgentProgram, percept::Tuple{Any, Any}) 189 | local r_prime::Float64; 190 | s_prime, r_prime = update_state(ptdap, percept); 191 | if (!haskey(ptdap.N_s, s_prime)) 192 | ptdap.U[s_prime] = r_prime; 193 | end 194 | if (!isnull(ptdap.state)) 195 | ptdap.N_s[get(ptdap.state)] = get!(ptdap.N_s, get(ptdap.state), 0) + 1; 196 | ptdap.U[get(ptdap.state)] = (get!(ptdap.U, get(ptdap.state), 0.0) 197 | + ptdap.alpha(get!(ptdap.N_s, get(ptdap.state), 0)) 198 | * (get(ptdap.reward) 199 | + (ptdap.gamma * get!(ptdap.U, s_prime, 0.0)) 200 | - get!(ptdap.U, get(ptdap.state), 
0.0))); 201 | end 202 | if (s_prime in ptdap.terminal_states) 203 | ptdap.state = Nullable(); 204 | ptdap.action = Nullable(); 205 | ptdap.reward = Nullable(); 206 | else 207 | ptdap.state = Nullable(s_prime); 208 | ptdap.action = Nullable(ptdap.pi[s_prime]); 209 | ptdap.reward = Nullable(r_prime); 210 | end 211 | if (isnull(ptdap.action)) 212 | return nothing; 213 | else 214 | return get(ptdap.action); 215 | end 216 | end 217 | 218 | function update_state(ptdap::PassiveTDAgentProgram, percept::Tuple{Any, Any}) 219 | return percept; 220 | end 221 | 222 | #= 223 | 224 | QLearningAgentProgram is an exploratory Q-learning agent that learns the value 225 | 226 | Q(state, action) for each action in each situation (Fig. 21.8). The agent uses the 227 | 228 | same exploration function as the exploratory ADP agent, but avoid learning the 229 | 230 | transition model because the Q-value of a state can be related directly to those of 231 | 232 | its neighbor. 233 | 234 | =# 235 | mutable struct QLearningAgentProgram <: AgentProgram 236 | state::Nullable 237 | action::Nullable 238 | reward::Nullable 239 | gamma::Float64 240 | Q::Dict 241 | N_sa::Dict 242 | actions::Set 243 | terminal_states::Set 244 | R_plus::Float64 # optimistic estimate of the best possible reward obtainable 245 | N_e::Int64 # try action-state pair at least N_e times 246 | f::Function 247 | alpha::Function 248 | 249 | function QLearningAgentProgram{T <: AbstractMarkovDecisionProcess}(mdp::T, N_e::Int64, R_plus::Number; alpha::Union{Void, Function}=nothing) 250 | local new_alpha::Function; 251 | local gamma::Float64; 252 | local actions::Set; 253 | local terminal_states::Set; 254 | if (typeof(mdp) <: PassiveADPAgentMDP) 255 | gamma = mdp.mdp.gamma; 256 | actions = mdp.mdp.actions; 257 | terminal_states = mdp.mdp.terminal_states; 258 | else 259 | gamma = mdp.gamma; 260 | actions = mdp.actions; 261 | terminal_states = mdp.terminal_states; 262 | end 263 | if (typeof(alpha) <: Void) 264 | new_alpha = 
"""
    exploration_function(qlap::QLearningAgentProgram, u::Number, n::Number)

Exploration function f(u, n) (Fig. 21.8): return the optimistic reward
estimate R_plus while the state-action pair has been tried fewer than N_e
times, otherwise return the utility estimate 'u'.
"""
function exploration_function(qlap::QLearningAgentProgram, u::Number, n::Number)
    return ((n < qlap.N_e) ? qlap.R_plus : u);
end

"""
    actions(qlap::QLearningAgentProgram, state)

Return the actions available to the agent in 'state'; a terminal state
admits only the 'nothing' action.
"""
function actions(qlap::QLearningAgentProgram, state)
    return ((state in qlap.terminal_states) ? Set([nothing]) : qlap.actions);
end
310 | get!(qlap.Q, (get(qlap.state), get(qlap.action)), 0.0); 311 | qlap.Q[(get(qlap.state), get(qlap.action))] = (qlap.Q[(get(qlap.state), get(qlap.action))] 312 | + (qlap.alpha(qlap.N_sa[(get(qlap.state), get(qlap.action))]) * 313 | (get(qlap.reward) + 314 | (qlap.gamma * reduce(max, collect(get!(qlap.Q, (s_prime, a_prime), 0.0) 315 | for a_prime in actions(qlap, s_prime)))) 316 | - qlap.Q[(get(qlap.state), get(qlap.action))]))); 317 | end 318 | if (!isnull(qlap.state) && get(qlap.state) in qlap.terminal_states) 319 | qlap.state = Nullable(); 320 | qlap.action = Nullable(); 321 | qlap.reward = Nullable(); 322 | else 323 | qlap.state = Nullable(s_prime); 324 | qlap.action = Nullable(argmax(collect(actions(qlap, s_prime)), 325 | (function(a_prime) 326 | return qlap.f(qlap, get!(qlap.Q, (s_prime, a_prime), 0.0), get!(qlap.N_sa, (s_prime, a_prime), 0)); 327 | end))); 328 | qlap.reward = Nullable(r_prime); 329 | end 330 | if (isnull(qlap.action)) 331 | return nothing; 332 | else 333 | return get(qlap.action); 334 | end 335 | end 336 | 337 | function update_state(qlap::QLearningAgentProgram, percept::Tuple{Any, Any}) 338 | return percept; 339 | end 340 | 341 | """ 342 | take_single_action{T <: AbstractMarkovDecisionProcess}(mdp::T, state, action) 343 | 344 | Return the next state by choosing a weighted sample of the resulting states for 345 | taking the action 'action' in state 'state'. 
"""
    take_single_action{T <: AbstractMarkovDecisionProcess}(mdp::T, state, action)

Sample the next state reached by taking 'action' in 'state', weighting the
candidate successor states by their transition probabilities.
"""
function take_single_action{T <: AbstractMarkovDecisionProcess}(mdp::T, state, action)
    local threshold::Float64 = rand(RandomDeviceInstance);
    local accumulated::Float64 = 0.0;
    # Walk the cumulative distribution until the random draw falls inside it.
    for (probability, successor) in transition_model(mdp, state, action)
        accumulated = accumulated + probability;
        if (threshold < accumulated)
            return successor;
        end
    end
    error("take_single_action(): Could not find a valid resulting state for the state ", state,
            " and action ", action, "!");
end

"""
    run_single_trial{T1 <: AgentProgram, T2 <: AbstractMarkovDecisionProcess}(ap::T1, mdp::T2)

Execute one trial of the agent program 'ap' in the environment described by
the MDP 'mdp', starting from the MDP's initial state and stopping when the
agent program returns 'nothing'.
"""
function run_single_trial{T1 <: AgentProgram, T2 <: AbstractMarkovDecisionProcess}(ap::T1, mdp::T2)
    current_state = mdp.initial;
    while (true)
        local current_reward::Float64;
        local observed = reward(mdp, current_state);
        # Some MDP implementations wrap rewards in a Nullable; unwrap if so.
        if (typeof(observed) <: Nullable)
            current_reward = get(observed);
        else
            current_reward = observed;
        end
        local percept::Tuple = (current_state, current_reward);
        local next_action = execute(ap, percept);
        if (typeof(next_action) <: Void)
            break;
        end
        current_state = take_single_action(mdp, current_state, next_action);
    end
    return nothing;
end
| 13 | restaurant = [Dict(collect(zip(restaurant_attribute_names, ("Yes", "No", "No", "Yes", "Some", "\$\$\$", "No", "Yes", "French", "0-10", true)))), 14 | Dict(collect(zip(restaurant_attribute_names, ("Yes", "No", "No", "Yes", "Full", "\$", "No", "No", "Thai", "30-60", false)))), 15 | Dict(collect(zip(restaurant_attribute_names, ("No", "Yes", "No", "No", "Some", "\$", "No", "No", "Burger", "0-10", true)))), 16 | Dict(collect(zip(restaurant_attribute_names, ("Yes", "No", "Yes", "Yes", "Full", "\$", "Yes", "No", "Thai", "10-30", true)))), 17 | Dict(collect(zip(restaurant_attribute_names, ("Yes", "No", "Yes", "No", "Full", "\$\$\$", "No", "Yes", "French", ">60", false)))), 18 | Dict(collect(zip(restaurant_attribute_names, ("No", "Yes", "No", "Yes", "Some", "\$\$", "Yes", "Yes", "Italian", "0-10", true)))), 19 | Dict(collect(zip(restaurant_attribute_names, ("No", "Yes", "No", "No", "None", "\$", "Yes", "No", "Burger", "0-10", false)))), 20 | Dict(collect(zip(restaurant_attribute_names, ("No", "No", "No", "Yes", "Some", "\$\$", "Yes", "Yes", "Thai", "0-10", true)))), 21 | Dict(collect(zip(restaurant_attribute_names, ("No", "Yes", "Yes", "No", "Full", "\$", "Yes", "No", "Burger", ">60", false)))), 22 | Dict(collect(zip(restaurant_attribute_names, ("Yes", "Yes", "Yes", "Yes", "Full", "\$\$\$", "No", "Yes", "Italian", "10-30",false)))), 23 | Dict(collect(zip(restaurant_attribute_names, ("No", "No", "No", "No", "None", "\$", "No", "No", "Thai", "0-10", false)))), 24 | Dict(collect(zip(restaurant_attribute_names, ("Yes", "Yes", "Yes", "Yes", "Full", "\$", "No", "No", "Burger", "30-60", true))))]; 25 | 26 | initial_h = [Dict([Pair("Alternate", "Yes")])]; 27 | 28 | h = current_best_learning(restaurant, initial_h); 29 | 30 | @test (map(guess_example_value, restaurant, Base.Iterators.repeated(h)) == [true, false, true, true, false, true, false, true, false, false, false, true]); 31 | 32 | animal_umbrellas = [Dict([("Species", "Cat"), ("Rain", "Yes"), ("Coat", "No"), ("GOAL", 
true)]), 33 | Dict([("Species", "Cat"), ("Rain", "Yes"), ("Coat", "Yes"), ("GOAL", true)]), 34 | Dict([("Species", "Dog"), ("Rain", "Yes"), ("Coat", "Yes"), ("GOAL", true)]), 35 | Dict([("Species", "Dog"), ("Rain", "Yes"), ("Coat", "No"), ("GOAL", false)]), 36 | Dict([("Species", "Dog"), ("Rain", "No"), ("Coat", "No"), ("GOAL", false)]), 37 | Dict([("Species", "Cat"), ("Rain", "No"), ("Coat", "No"), ("GOAL", false)]), 38 | Dict([("Species", "Cat"), ("Rain", "No"), ("Coat", "Yes"), ("GOAL", true)])]; 39 | 40 | initial_h = [Dict([Pair("Species", "Cat")])]; 41 | 42 | h = current_best_learning(animal_umbrellas, initial_h); 43 | 44 | @test (map(guess_example_value, animal_umbrellas, Base.Iterators.repeated(h)) == [true, true, true, false, false, false, true]); 45 | 46 | party = [Dict([("Pizza", "Yes"), ("Soda", "No"), ("GOAL", true)]), 47 | Dict([("Pizza", "Yes"), ("Soda", "Yes"), ("GOAL", true)]), 48 | Dict([("Pizza", "No"), ("Soda", "No"), ("GOAL", false)])]; 49 | 50 | initial_h = [Dict([Pair("Pizza", "Yes")])]; 51 | 52 | h = current_best_learning(party, initial_h); 53 | 54 | @test (map(guess_example_value, party, Base.Iterators.repeated(h)) == [true, true, false]); 55 | 56 | party = [Dict([("Pizza", "Yes"), ("Soda", "No"), ("GOAL", true)]), 57 | Dict([("Pizza", "Yes"), ("Soda", "Yes"), ("GOAL", true)]), 58 | Dict([("Pizza", "No"), ("Soda", "No"), ("GOAL", false)])]; 59 | 60 | version_space = version_space_learning(party); 61 | 62 | @test (map((function(e::Dict, V::AbstractVector) 63 | for h in V 64 | if (guess_example_value(e, h)) 65 | return true; 66 | end 67 | end 68 | return false; 69 | end), party, Base.Iterators.repeated(version_space)) == [true, true, false]); 70 | 71 | @test ([Dict([Pair("Pizza", "Yes")])] in version_space); 72 | 73 | party = [Dict([("Pizza", "Yes"), ("Soda", "No"), ("GOAL", true)]), 74 | Dict([("Pizza", "Yes"), ("Soda", "Yes"), ("GOAL", true)]), 75 | Dict([("Pizza", "No"), ("Soda", "No"), ("GOAL", false)])]; 76 | 77 | animal_umbrellas = 
[Dict([("Species", "Cat"), ("Rain", "Yes"), ("Coat", "No"), ("GOAL", true)]), 78 | Dict([("Species", "Cat"), ("Rain", "Yes"), ("Coat", "Yes"), ("GOAL", true)]), 79 | Dict([("Species", "Dog"), ("Rain", "Yes"), ("Coat", "Yes"), ("GOAL", true)]), 80 | Dict([("Species", "Dog"), ("Rain", "Yes"), ("Coat", "No"), ("GOAL", false)]), 81 | Dict([("Species", "Dog"), ("Rain", "No"), ("Coat", "No"), ("GOAL", false)]), 82 | Dict([("Species", "Cat"), ("Rain", "No"), ("Coat", "No"), ("GOAL", false)]), 83 | Dict([("Species", "Cat"), ("Rain", "No"), ("Coat", "Yes"), ("GOAL", true)])]; 84 | 85 | conductance_attribute_names = ("Sample", "Mass", "Temperature", "Material", "Size", "GOAL"); 86 | conductance = [Dict(collect(zip(conductance_attribute_names, ("S1", 12, 26, "Cu", 3, 0.59)))), 87 | Dict(collect(zip(conductance_attribute_names, ("S1", 12, 100, "Cu", 3, 0.57)))), 88 | Dict(collect(zip(conductance_attribute_names, ("S2", 24, 26, "Cu", 6, 0.59)))), 89 | Dict(collect(zip(conductance_attribute_names, ("S3", 12, 26, "Pb", 2, 0.05)))), 90 | Dict(collect(zip(conductance_attribute_names, ("S3", 12, 100, "Pb", 2, 0.04)))), 91 | Dict(collect(zip(conductance_attribute_names, ("S4", 18, 100, "Pb", 3, 0.04)))), 92 | Dict(collect(zip(conductance_attribute_names, ("S4", 18, 100, "Pb", 3, 0.04)))), 93 | Dict(collect(zip(conductance_attribute_names, ("S5", 24, 100, "Pb", 4, 0.04)))), 94 | Dict(collect(zip(conductance_attribute_names, ("S6", 36, 26, "Pb", 6, 0.05))))]; 95 | 96 | @test (minimal_consistent_determination(party, Set(["Pizza", "Soda"])) == Set(["Pizza"])); 97 | 98 | @test (minimal_consistent_determination(party[1:2], Set(["Pizza", "Soda"])) == Set()); 99 | 100 | @test (minimal_consistent_determination(animal_umbrellas, Set(["Species", "Rain", "Coat"])) == Set(["Species", "Rain", "Coat"])); 101 | 102 | @test (minimal_consistent_determination(conductance, Set(["Mass", "Temperature", "Material", "Size"])) == Set(["Temperature", "Material"])); 103 | 104 | @test 
(minimal_consistent_determination(conductance, Set(["Mass", "Temperature", "Size"])) == Set(["Mass", "Temperature", "Size"])); 105 | 106 | # Initialize FOIL knowledge bases for extend_example(), choose_literal(), new_clause(), 107 | # new_literals(), and foil(). 108 | 109 | test_network = FOILKnowledgeBase([expr("Conn(A, B)"), 110 | expr("Conn(A ,D)"), 111 | expr("Conn(B, C)"), 112 | expr("Conn(D, C)"), 113 | expr("Conn(D, E)"), 114 | expr("Conn(E ,F)"), 115 | expr("Conn(E, G)"), 116 | expr("Conn(G, I)"), 117 | expr("Conn(H, G)"), 118 | expr("Conn(H, I)")]); 119 | 120 | small_family = FOILKnowledgeBase([expr("Mother(Anne, Peter)"), 121 | expr("Mother(Anne, Zara)"), 122 | expr("Mother(Sarah, Beatrice)"), 123 | expr("Mother(Sarah, Eugenie)"), 124 | expr("Father(Mark, Peter)"), 125 | expr("Father(Mark, Zara)"), 126 | expr("Father(Andrew, Beatrice)"), 127 | expr("Father(Andrew, Eugenie)"), 128 | expr("Father(Philip, Anne)"), 129 | expr("Father(Philip, Andrew)"), 130 | expr("Mother(Elizabeth, Anne)"), 131 | expr("Mother(Elizabeth, Andrew)"), 132 | expr("Male(Philip)"), 133 | expr("Male(Mark)"), 134 | expr("Male(Andrew)"), 135 | expr("Male(Peter)"), 136 | expr("Female(Elizabeth)"), 137 | expr("Female(Anne)"), 138 | expr("Female(Sarah)"), 139 | expr("Female(Zara)"), 140 | expr("Female(Beatrice)"), 141 | expr("Female(Eugenie)")]); 142 | 143 | @test (extend_example(test_network, Dict([(expr("x"), expr("A")), (expr("y"), expr("B"))]), expr("Conn(x, z)")) 144 | == [Dict([(expr("x"), expr("A")), 145 | (expr("y"), expr("B")), 146 | (expr("z"), expr("B"))]), 147 | Dict([(expr("x"), expr("A")), 148 | (expr("y"), expr("B")), 149 | (expr("z"), expr("D"))])]); 150 | 151 | @test (extend_example(test_network, Dict([(expr("x"), expr("G"))]), expr("Conn(x, y)")) 152 | == [Dict([(expr("x"), expr("G")), 153 | (expr("y"), expr("I"))])]); 154 | 155 | 156 | @test (extend_example(test_network, Dict([(expr("x"), expr("C"))]), expr("Conn(x, y)")) == []); 157 | 158 | @test 
(length(extend_example(test_network, Dict(), expr("Conn(x, y)"))) == 10); 159 | 160 | @test (length(extend_example(small_family, Dict([(expr("x"), expr("Andrew"))]), expr("Father(x, y)"))) == 2); 161 | 162 | @test (length(extend_example(small_family, Dict([(expr("x"), expr("Andrew"))]), expr("Mother(x, y)"))) == 0); 163 | 164 | @test (length(extend_example(small_family, Dict([(expr("x"), expr("Andrew"))]), expr("Female(y)"))) == 6); 165 | 166 | # Initialize Tuple of literals and examples for choose_literal(). 167 | 168 | literals = map(expr, ("Conn(p, q)", "Conn(x, z)", "Conn(r, s)", "Conn(t, y)")); 169 | 170 | examples_positive = [Dict([map(expr, ("x", "A")), map(expr, ("y", "B"))]), 171 | Dict([map(expr, ("x", "A")), map(expr, ("y", "D"))])]; 172 | 173 | examples_negative = [Dict([map(expr, ("x", "A")), map(expr, ("y", "C"))]), 174 | Dict([map(expr, ("x", "C")), map(expr, ("y", "A"))]), 175 | Dict([map(expr, ("x", "C")), map(expr, ("y", "B"))]), 176 | Dict([map(expr, ("x", "A")), map(expr, ("y", "I"))])]; 177 | 178 | @test (choose_literal(test_network, literals, (examples_positive, examples_negative)) == expr("Conn(x, z)")); 179 | 180 | literals = map(expr, ("Conn(x, p)", "Conn(p, x)", "Conn(p, q)")); 181 | 182 | examples_positive = [Dict([map(expr, ("x", "C"))]), 183 | Dict([map(expr, ("x", "F"))]), 184 | Dict([map(expr, ("x", "I"))])]; 185 | 186 | examples_negative = [Dict([map(expr, ("x", "D"))]), 187 | Dict([map(expr, ("x", "A"))]), 188 | Dict([map(expr, ("x", "B"))]), 189 | Dict([map(expr, ("x", "G"))])]; 190 | 191 | @test (choose_literal(test_network, literals, (examples_positive, examples_negative)) == expr("Conn(p, x)")); 192 | 193 | literals = map(expr, ("Father(x, y)", "Father(y, x)", "Mother(x, y)", "Mother(x, y)")); 194 | 195 | examples_positive = [Dict([map(expr, ("x", "Philip"))]), 196 | Dict([map(expr, ("x", "Mark"))]), 197 | Dict([map(expr, ("x", "Peter"))])]; 198 | 199 | examples_negative = [Dict([map(expr, ("x", "Elizabeth"))]), 200 | 
Dict([map(expr, ("x", "Sarah"))])]; 201 | 202 | @test (choose_literal(small_family, literals, (examples_positive, examples_negative)) == expr("Father(x, y)")); 203 | 204 | literals = map(expr, ("Father(x, y)", "Father(y, x)", "Male(x)")); 205 | 206 | examples_positive = [Dict([map(expr, ("x", "Philip"))]), 207 | Dict([map(expr, ("x", "Mark"))]), 208 | Dict([map(expr, ("x", "Andrew"))])]; 209 | 210 | examples_negative = [Dict([map(expr, ("x", "Elizabeth"))]), 211 | Dict([map(expr, ("x", "Sarah"))])]; 212 | 213 | @test (choose_literal(small_family, literals, (examples_positive, examples_negative)) == expr("Male(x)")); 214 | 215 | # Initialize target literal and examples for new_clause(). 216 | 217 | target = expr("Open(x, y)"); 218 | 219 | examples_positive = [Dict([map(expr, ("x", "B"))]), 220 | Dict([map(expr, ("x", "A"))]), 221 | Dict([map(expr, ("x", "G"))])]; 222 | 223 | examples_negative = [Dict([map(expr, ("x", "C"))]), 224 | Dict([map(expr, ("x", "F"))]), 225 | Dict([map(expr, ("x", "I"))])]; 226 | 227 | clause = new_clause(test_network, (examples_positive, examples_negative), target)[1][2]; 228 | 229 | @test ((length(clause) == 1) 230 | && (clause[1].operator == "Conn") 231 | && (clause[1].arguments[1] == expr("x"))); 232 | 233 | target = expr("Flow(x, y)"); 234 | 235 | examples_positive = [Dict([map(expr, ("x", "B"))]), 236 | Dict([map(expr, ("x", "D"))]), 237 | Dict([map(expr, ("x", "E"))]), 238 | Dict([map(expr, ("x", "G"))])]; 239 | 240 | examples_negative = [Dict([map(expr, ("x", "A"))]), 241 | Dict([map(expr, ("x", "C"))]), 242 | Dict([map(expr, ("x", "F"))]), 243 | Dict([map(expr, ("x", "I"))]), 244 | Dict([map(expr, ("x", "H"))])]; 245 | 246 | clause = new_clause(test_network, (examples_positive, examples_negative), target)[1][2]; 247 | 248 | @test ((length(clause) == 2) && 249 | (((clause[1].arguments[1] == expr("x")) && (clause[2].arguments[2] == expr("x"))) 250 | || ((clause[1].arguments[2] == expr("x")) && (clause[2].arguments[1] == 
expr("x"))))); 251 | 252 | # Check length of returned Tuple for new_literals(). 253 | 254 | @test (length(new_literals(test_network, (expr("p | q"), [expr("p")]))) == 8); 255 | 256 | @test (length(new_literals(test_network, (expr("p"), [expr("q"), expr("p | r")]))) == 15); 257 | 258 | @test (length(new_literals(small_family, (expr("p"), []))) == 8); 259 | 260 | @test (length(new_literals(small_family, (expr("p & q"), []))) == 20); 261 | 262 | # Initialize examples Tuple and target literal for foil(). 263 | 264 | target = expr("Parent(x, y)"); 265 | 266 | examples_positive = [Dict([map(expr, ("x", "Elizabeth")), map(expr, ("y", "Anne"))]), 267 | Dict([map(expr, ("x", "Elizabeth")), map(expr, ("y", "Andrew"))]), 268 | Dict([map(expr, ("x", "Philip")), map(expr, ("y", "Anne"))]), 269 | Dict([map(expr, ("x", "Philip")), map(expr, ("y", "Andrew"))]), 270 | Dict([map(expr, ("x", "Anne")), map(expr, ("y", "Peter"))]), 271 | Dict([map(expr, ("x", "Anne")), map(expr, ("y", "Zara"))]), 272 | Dict([map(expr, ("x", "Mark")), map(expr, ("y", "Peter"))]), 273 | Dict([map(expr, ("x", "Mark")), map(expr, ("y", "Zara"))]), 274 | Dict([map(expr, ("x", "Andrew")), map(expr, ("y", "Beatrice"))]), 275 | Dict([map(expr, ("x", "Andrew")), map(expr, ("y", "Eugenie"))]), 276 | Dict([map(expr, ("x", "Sarah")), map(expr, ("y", "Beatrice"))]), 277 | Dict([map(expr, ("x", "Sarah")), map(expr, ("y", "Eugenie"))])]; 278 | 279 | examples_negative = [Dict([map(expr, ("x", "Anne")), map(expr, ("y", "Eugenie"))]), 280 | Dict([map(expr, ("x", "Beatrice")), map(expr, ("y", "Eugenie"))]), 281 | Dict([map(expr, ("x", "Mark")), map(expr, ("y", "Elizabeth"))]), 282 | Dict([map(expr, ("x", "Beatrice")), map(expr, ("y", "Philip"))])]; 283 | 284 | clauses = foil(small_family, (examples_positive, examples_negative), target); 285 | 286 | @test ((length(clauses) == 2) && 287 | (((clauses[1][2][1] == expr("Father(x, y)")) && (clauses[2][2][1] == expr("Mother(x, y)"))) 288 | || ((clauses[2][2][1] == 
expr("Father(x, y)")) && (clauses[1][2][1] == expr("Mother(x, y)"))))); 289 | 290 | -------------------------------------------------------------------------------- /games.jl: -------------------------------------------------------------------------------- 1 | 2 | import Base.display; 3 | 4 | export AbstractGame, Figure52Game, TicTacToeGame, ConnectFourGame, 5 | TicTacToeState, ConnectFourState, 6 | minimax_decision, alphabeta_full_search, alphabeta_search, 7 | display, 8 | random_player, alphabeta_player, play_game; 9 | 10 | abstract type AbstractGame end; 11 | 12 | #= 13 | 14 | Game is an abstract game that contains an initial state. 15 | 16 | Games have a corresponding utility function, terminal test, set of legal moves, and transition model. 17 | 18 | =# 19 | 20 | struct Game <: AbstractGame 21 | initial::String 22 | 23 | function Game(initial_state::String) 24 | return new(initial_state); 25 | end 26 | end 27 | 28 | function actions{T <: AbstractGame}(game::T, state::String) 29 | println("actions() is not implemented yet for ", typeof(game), "!"); 30 | nothing; 31 | end 32 | 33 | function result{T <: AbstractGame}(game::T, state::String, move::String) 34 | println("result() is not implemented yet for ", typeof(game), "!"); 35 | nothing; 36 | end 37 | 38 | function utility{T <: AbstractGame}(game::T, state::String, player::String) 39 | println("utility() is not implemented yet for ", typeof(game), "!"); 40 | nothing; 41 | end 42 | 43 | function terminal_test{T <: AbstractGame}(game::T, state::String) 44 | if (length(actions(game, state)) == 0) 45 | return true; 46 | else 47 | return false; 48 | end 49 | end 50 | 51 | function to_move{T <: AbstractGame}(game::T, state::String) 52 | println("to_move() is not implemented yet for ", typeof(game), "!"); 53 | nothing; 54 | end 55 | 56 | function display{T <: AbstractGame}(game::T, state::String) 57 | println(state); 58 | end 59 | 60 | #= 61 | 62 | Figure52Game is the game represented by the game tree in Fig. 5.2. 
63 | 64 | =# 65 | struct Figure52Game <: AbstractGame 66 | initial::String 67 | nodes::Dict 68 | utilities::Dict 69 | 70 | function Figure52Game() 71 | return new("A", Dict([ 72 | Pair("A", Dict("A1"=>"B", "A2"=>"C", "A3"=>"D")), 73 | Pair("B", Dict("B1"=>"B1", "B2"=>"B2", "B3"=>"B3")), 74 | Pair("C", Dict("C1"=>"C1", "C2"=>"C2", "C3"=>"C3")), 75 | Pair("D", Dict("D1"=>"D1", "D2"=>"D2", "D3"=>"D3")), 76 | ]), 77 | Dict([ 78 | Pair("B1", 3), 79 | Pair("B2", 12), 80 | Pair("B3", 8), 81 | Pair("C1", 2), 82 | Pair("C2", 4), 83 | Pair("C3", 6), 84 | Pair("D1", 14), 85 | Pair("D2", 5), 86 | Pair("D3", 2), 87 | ])); 88 | end 89 | end 90 | 91 | function actions(game::Figure52Game, state::String) 92 | return collect(keys(get(game.nodes, state, Dict()))); 93 | end 94 | 95 | function result(game::Figure52Game, state::String, move::String) 96 | return game.nodes[state][move]; 97 | end 98 | 99 | function utility(game::Figure52Game, state::String, player::String) 100 | if (player == "MAX") 101 | return game.utilities[state]; 102 | else 103 | return -game.utilities[state]; 104 | end 105 | end 106 | 107 | function terminal_test(game::Figure52Game, state::String) 108 | return !(state in ["A", "B", "C", "D"]); 109 | end 110 | 111 | function to_move(game::Figure52Game, state::String) 112 | return if_((state in ["B", "C", "D"]), "MIN", "MAX"); 113 | end 114 | 115 | struct TicTacToeState 116 | turn::String 117 | utility::Int64 118 | board::Dict 119 | moves::AbstractVector 120 | 121 | function TicTacToeState(turn::String, utility::Int64, board::Dict, moves::AbstractVector) 122 | return new(turn, utility, board, moves); 123 | end 124 | end 125 | 126 | #= 127 | 128 | TicTacToeGame is a AbstractGame implementation of the Tic-tac-toe game. 
129 | 130 | =# 131 | struct TicTacToeGame <: AbstractGame 132 | initial::TicTacToeState 133 | h::Int64 134 | v::Int64 135 | k::Int64 136 | 137 | function TicTacToeGame(initial::TicTacToeState) 138 | return new(initial, 3, 3, 3); 139 | end 140 | 141 | function TicTacToeGame() 142 | return new(TicTacToeState("X", 0, Dict(), collect((x, y) for x in 1:3 for y in 1:3)), 3, 3, 3); 143 | end 144 | end 145 | 146 | function actions(game::TicTacToeGame, state::TicTacToeState) 147 | return state.moves; 148 | end 149 | 150 | function result(game::TicTacToeGame, state::TicTacToeState, move::Tuple{Signed, Signed}) 151 | if (!(move in state.moves)) 152 | return state; 153 | end 154 | local board::Dict = copy(state.board); 155 | board[move] = state.turn; 156 | local moves::Array{eltype(state.moves), 1} = collect(state.moves); 157 | for (i, element) in enumerate(moves) 158 | if (element == move) 159 | deleteat!(moves, i); 160 | break; 161 | end 162 | end 163 | return TicTacToeState(if_((state.turn == "X"), "O", "X"), compute_utility(game, board, move, state.turn), board, moves); 164 | end 165 | 166 | function utility(game::TicTacToeGame, state::TicTacToeState, player::String) 167 | return if_((player == "X"), state.utility, -state.utility); 168 | end 169 | 170 | function terminal_test(game::TicTacToeGame, state::TicTacToeState) 171 | return ((state.utility != 0) || (length(state.moves) == 0)); 172 | end 173 | 174 | function to_move(game::TicTacToeGame, state::TicTacToeState) 175 | return state.turn; 176 | end 177 | 178 | function display(game::TicTacToeGame, state::TicTacToeState) 179 | for x in 1:game.h 180 | for y in 1:game.v 181 | print(get(state.board, (x, y), ".")); 182 | end 183 | println(); 184 | end 185 | end 186 | 187 | function compute_utility{T <: Dict}(game::TicTacToeGame, board::T, move::Tuple{Signed, Signed}, player::String) 188 | if (k_in_row(game, board, move, player, (0, 1)) || 189 | k_in_row(game, board, move, player, (1, 0)) || 190 | k_in_row(game, board, move, 
player, (1, -1)) || 191 | k_in_row(game, board, move, player, (1, 1))) 192 | return if_((player == "X"), 1, -1); 193 | else 194 | return 0; 195 | end 196 | end 197 | 198 | function k_in_row(game::TicTacToeGame, board::Dict, move::Tuple{Signed, Signed}, player::String, delta::Tuple{Signed, Signed}) 199 | local delta_x::Int64 = Int64(getindex(delta, 1)); 200 | local delta_y::Int64 = Int64(getindex(delta, 2)); 201 | local x::Int64 = Int64(getindex(move, 1)); 202 | local y::Int64 = Int64(getindex(move, 2)); 203 | local n::Int64 = Int64(0); 204 | while (get(board, (x,y), nothing) == player) 205 | n = n + 1; 206 | x = x + delta_x; 207 | y = y + delta_y; 208 | end 209 | x = Int64(getindex(move, 1)); 210 | y = Int64(getindex(move, 2)); 211 | while (get(board, (x,y), nothing) == player) 212 | n = n + 1; 213 | x = x - delta_x; 214 | y = y - delta_y; 215 | end 216 | n = n - 1; #remove the duplicate check on get(board, move, nothing) 217 | return n >= game.k; 218 | end 219 | 220 | const ConnectFourState = TicTacToeState; 221 | 222 | #= 223 | 224 | ConnectFourGame is a AbstractGame implementation of the Connect Four game. 
=#
struct ConnectFourGame <: AbstractGame
    initial::ConnectFourState
    h::Int64     # number of columns
    v::Int64     # number of rows
    k::Int64     # pieces in a row needed to win

    function ConnectFourGame(initial::ConnectFourState)
        # NOTE(review): this constructor keeps the original 3x3, k = 3 dimensions
        # for a caller-supplied state (copied from TicTacToeGame) -- confirm these
        # are the intended dimensions for a custom Connect Four position.
        return new(initial, 3, 3, 3);
    end

    function ConnectFourGame()
        # Standard Connect Four: 7 columns, 6 rows, 4 in a row to win.
        return new(ConnectFourState("X", 0, Dict(), collect((x, y) for x in 1:7 for y in 1:6)), 7, 6, 4);
    end
end

"""
    actions(game::ConnectFourGame, state::ConnectFourState)

Return the legal moves in 'state': a square is playable only on the bottom
row (y == 1; the board is 1-indexed) or directly above an occupied square.

The original condition used 'y == 0' (never true for 1-indexed moves) and
'(x, y - 1) in state.board' (which tests the Dict's key=>value pairs, not
its keys), so no move was ever legal; both checks are fixed here.
"""
function actions(game::ConnectFourGame, state::ConnectFourState)
    return collect((x, y) for (x, y) in state.moves if ((y == 1) || haskey(state.board, (x, y - 1))));
end

# Apply 'move' for the player to move; illegal moves leave the state unchanged.
function result(game::ConnectFourGame, state::ConnectFourState, move::Tuple{Signed, Signed})
    if (!(move in state.moves))
        return state;
    end
    local board::Dict = copy(state.board);
    board[move] = state.turn;
    local moves::Array{eltype(state.moves), 1} = collect(state.moves);
    for (i, element) in enumerate(moves)
        if (element == move)
            deleteat!(moves, i);
            break;
        end
    end
    return ConnectFourState(if_((state.turn == "X"), "O", "X"), compute_utility(game, board, move, state.turn), board, moves);
end

# Utility is stored from "X"'s perspective; negate it for the other player.
function utility(game::ConnectFourGame, state::ConnectFourState, player::String)
    return if_((player == "X"), state.utility, -state.utility);
end

# A state is terminal when someone has won or no moves remain.
function terminal_test(game::ConnectFourGame, state::ConnectFourState)
    return ((state.utility != 0) || (length(state.moves) == 0));
end

function to_move(game::ConnectFourGame, state::ConnectFourState)
    return state.turn;
end

# Print the board, one row of squares per line; "." marks an empty square.
function display(game::ConnectFourGame, state::ConnectFourState)
    for x in 1:game.h
        for y in 1:game.v
            print(get(state.board, (x, y), "."));
        end
        println();
    end
end

# Return +1 if "X" completed a line through 'move', -1 if "O" did, else 0.
function compute_utility{T <: Dict}(game::ConnectFourGame, board::T, move::Tuple{Signed, Signed}, player::String)
    if (k_in_row(game, board, move, player, (0, 1)) ||
        k_in_row(game, board, move, player, (1, 0)) ||
        k_in_row(game, board, move, player, (1, -1)) ||
        k_in_row(game, board, move, player, (1, 1)))
        return if_((player == "X"), 1, -1);
    else
        return 0;
    end
end

# Count contiguous pieces owned by 'player' through 'move' along direction
# 'delta', scanning both ways; true when the run reaches game.k.
function k_in_row(game::ConnectFourGame, board::Dict, move::Tuple{Signed, Signed}, player::String, delta::Tuple{Signed, Signed})
    local delta_x::Int64 = Int64(getindex(delta, 1));
    local delta_y::Int64 = Int64(getindex(delta, 2));
    local x::Int64 = Int64(getindex(move, 1));
    local y::Int64 = Int64(getindex(move, 2));
    local n::Int64 = Int64(0);
    while (get(board, (x,y), nothing) == player)
        n = n + 1;
        x = x + delta_x;
        y = y + delta_y;
    end
    x = Int64(getindex(move, 1));
    y = Int64(getindex(move, 2));
    while (get(board, (x,y), nothing) == player)
        n = n + 1;
        x = x - delta_x;
        y = y - delta_y;
    end
    n = n - 1;  # remove the duplicate check on get(board, move, nothing)
    return n >= game.k;
end

# Full-depth minimax value of 'state' when it is MAX's turn.
function minimax_max_value{T <: AbstractGame}(game::T, player::String, state::String)
    if (terminal_test(game, state))
        return utility(game, state, player);
    end
    local v::Float64 = -Inf64;
    v = reduce(max, vcat(v, collect(minimax_min_value(game, player, result(game, state, action))
                                    for action in actions(game, state))));
    return v;
end

# Full-depth minimax value of 'state' when it is MIN's turn.
function minimax_min_value{T <: AbstractGame}(game::T, player::String, state::String)
    if (terminal_test(game, state))
        return utility(game, state, player);
    end
    local v::Float64 = Inf64;
    v = reduce(min, vcat(v, collect(minimax_max_value(game, player, result(game, state, action))
                                    for action in actions(game, state))));
    return v;
end

"""
    minimax_decision(state, game)

Calculate the best move by searching through moves, all the way to
the leaves (terminal states) (Fig 5.3). 340 | """ 341 | function minimax_decision{T <: AbstractGame}(state::String, game::T) 342 | local player = to_move(game, state); 343 | return argmax(actions(game, state), 344 | (function(action::String,; relevant_game::AbstractGame=game, relevant_player::String=player, relevant_state::String=state) 345 | return minimax_min_value(relevant_game, relevant_player, result(relevant_game, relevant_state, action)); 346 | end)); 347 | end 348 | 349 | function alphabeta_full_search_max_value{T <: AbstractGame}(game::T, player::String, state::String, alpha::Number, beta::Number) 350 | if (terminal_test(game, state)) 351 | return utility(game, state, player) 352 | end 353 | local v::Float64 = -Inf64; 354 | for action in actions(game, state) 355 | v = max(v, alphabeta_full_search_min_value(game, player, result(game, state, action), alpha, beta)); 356 | if (v >= beta) 357 | return v; 358 | end 359 | alpha = max(alpha, v); 360 | end 361 | return v; 362 | end 363 | 364 | function alphabeta_full_search_min_value{T <: AbstractGame}(game::T, player::String, state::String, alpha::Number, beta::Number) 365 | if (terminal_test(game, state)) 366 | return utility(game, state, player); 367 | end 368 | local v::Float64 = Inf64; 369 | for action in actions(game, state) 370 | v = min(v, alphabeta_full_search_max_value(game, player, result(game, state, action), alpha, beta)); 371 | if (v <= alpha) 372 | return v; 373 | end 374 | beta = min(beta, v); 375 | end 376 | return v; 377 | end 378 | 379 | """ 380 | alphabeta_full_search(state, game) 381 | 382 | Search the given game to find the best action using alpha-beta pruning (Fig 5.7). 
383 | """ 384 | function alphabeta_full_search{T <: AbstractGame}(state::String, game::T) 385 | local player::String = to_move(game, state); 386 | return argmax(actions(game, state), 387 | (function(action::String,; relevant_game::AbstractGame=game, relevant_state::String=state, relevant_player::String=player) 388 | return alphabeta_full_search_min_value(relevant_game, relevant_player, result(relevant_game, relevant_state, action), -Inf64, Inf64); 389 | end)); 390 | end 391 | 392 | function alphabeta_search_max_value{T <: AbstractGame}(game::T, player::String, cutoff_test_fn::Function, evaluation_fn::Function, state::String, alpha::Number, beta::Number, depth::Int64) 393 | if (cutoff_test_fn(state, depth)) 394 | return evaluation_fn(state); 395 | end 396 | local v::Float64 = -Inf64; 397 | for action in actions(game, state) 398 | v = max(v, alphabeta_search_min_value(game, player, cutoff_test_fn, evaluation_fn, result(game, state, action), alpha, beta, depth + 1)); 399 | if (v >= beta) 400 | return v; 401 | end 402 | alpha = max(alpha, v); 403 | end 404 | return v; 405 | end 406 | 407 | function alphabeta_search_max_value{T <: AbstractGame}(game::T, player::String, cutoff_test_fn::Function, evaluation_fn::Function, state::TicTacToeState, alpha::Number, beta::Number, depth::Int64) 408 | if (cutoff_test_fn(state, depth)) 409 | return evaluation_fn(state); 410 | end 411 | local v::Float64 = -Inf64; 412 | for action in actions(game, state) 413 | v = max(v, alphabeta_search_min_value(game, player, cutoff_test_fn, evaluation_fn, result(game, state, action), alpha, beta, depth + 1)); 414 | if (v >= beta) 415 | return v; 416 | end 417 | alpha = max(alpha, v); 418 | end 419 | return v; 420 | end 421 | 422 | function alphabeta_search_min_value{T <: AbstractGame}(game::T, player::String, cutoff_test_fn::Function, evaluation_fn::Function, state::String, alpha::Number, beta::Number, depth::Int64) 423 | if (cutoff_test_fn(state, depth)) 424 | return evaluation_fn(state); 425 | 
end 426 | local v::Float64 = Inf64; 427 | for action in actions(game, state) 428 | v = min(v, alphabeta_search_max_value(game, player, cutoff_test_fn, evaluation_fn, result(game, state, action), alpha, beta, depth + 1)); 429 | if (v >= alpha) 430 | return v; 431 | end 432 | beta = min(alpha, v); 433 | end 434 | return v; 435 | end 436 | 437 | function alphabeta_search_min_value{T <: AbstractGame}(game::T, player::String, cutoff_test_fn::Function, evaluation_fn::Function, state::TicTacToeState, alpha::Number, beta::Number, depth::Int64) 438 | if (cutoff_test_fn(state, depth)) 439 | return evaluation_fn(state); 440 | end 441 | local v::Float64 = Inf64; 442 | for action in actions(game, state) 443 | v = min(v, alphabeta_search_max_value(game, player, cutoff_test_fn, evaluation_fn, result(game, state, action), alpha, beta, depth + 1)); 444 | if (v >= alpha) 445 | return v; 446 | end 447 | beta = min(alpha, v); 448 | end 449 | return v; 450 | end 451 | 452 | """ 453 | alphabeta_search(state, game) 454 | 455 | Search the given game to find the best action using alpha-beta pruning. However, this function also uses a 456 | cutoff test to cut off the search early and apply a heuristic evaluation function to turn nonterminal 457 | states into terminal states. 
458 | """ 459 | function alphabeta_search{T <: AbstractGame}(state::String, game::T; d::Int64=4, cutoff_test_fn::Union{Void, Function}=nothing, evaluation_fn::Union{Void, Function}=nothing) 460 | local player::String = to_move(game, state); 461 | if (typeof(cutoff_test_fn) <: Void) 462 | cutoff_test_fn = (function(state::String, depth::Int64; dvar::Int64=d, relevant_game::AbstractGame=game) 463 | return ((depth > dvar) || terminal_test(relevant_game, state)); 464 | end); 465 | end 466 | if (typeof(evaluation_fn) <: Void) 467 | evaluation_fn = (function(state::String, ; relevant_game::AbstractGame=game, relevant_player::String=player) 468 | return utility(relevant_game, state, relevant_player); 469 | end); 470 | end 471 | return argmax(actions(game, state), 472 | (function(action::String,; relevant_game::AbstractGame=game, relevant_state::String=state, relevant_player::String=player, cutoff_test::Function=cutoff_test_fn, eval_fn::Function=evaluation_fn) 473 | return alphabeta_search_min_value(relevant_game, relevant_player, cutoff_test, eval_fn, result(relevant_game, relevant_state, action), -Inf64, Inf64, 0); 474 | end)); 475 | end 476 | 477 | function alphabeta_search{T <: AbstractGame}(state::TicTacToeState, game::T; d::Int64=4, cutoff_test_fn::Union{Void, Function}=nothing, evaluation_fn::Union{Void, Function}=nothing) 478 | local player::String = to_move(game, state); 479 | if (typeof(cutoff_test_fn) <: Void) 480 | cutoff_test_fn = (function(state::TicTacToeState, depth::Int64; dvar::Int64=d, relevant_game::AbstractGame=game) 481 | return ((depth > dvar) || terminal_test(relevant_game, state)); 482 | end); 483 | end 484 | if (typeof(evaluation_fn) <: Void) 485 | evaluation_fn = (function(state::TicTacToeState, ; relevant_game::AbstractGame=game, relevant_player::String=player) 486 | return utility(relevant_game, state, relevant_player); 487 | end); 488 | end 489 | return argmax(actions(game, state), 490 | (function(action::Tuple{Signed, Signed},; 
relevant_game::AbstractGame=game, relevant_state::TicTacToeState=state, relevant_player::String=player, cutoff_test::Function=cutoff_test_fn, eval_fn::Function=evaluation_fn) 491 | return alphabeta_search_min_value(relevant_game, relevant_player, cutoff_test, eval_fn, result(relevant_game, relevant_state, action), -Inf64, Inf64, 0); 492 | end)); 493 | end 494 | 495 | function random_player{T <: AbstractGame}(game::T, state::String) 496 | return rand(RandomDeviceInstance, actions(game, state)); 497 | end 498 | 499 | function random_player{T <: AbstractGame}(game::T, state::TicTacToeState) 500 | return rand(RandomDeviceInstance, actions(game, state)); 501 | end 502 | 503 | function alphabeta_player{T <: AbstractGame}(game::T, state::String) 504 | return alphabeta_search(state, game); 505 | end 506 | 507 | function alphabeta_player{T <: AbstractGame}(game::T, state::TicTacToeState) 508 | return alphabeta_search(state, game); 509 | end 510 | 511 | function play_game{T <: AbstractGame}(game::T, players::Vararg{Function}) 512 | state = game.initial; 513 | while (true) 514 | for player in players 515 | move = player(game, state); 516 | state = result(game, state, move); 517 | if (terminal_test(game, state)) 518 | return utility(game, state, to_move(game, game.initial)); 519 | end 520 | end 521 | end 522 | end 523 | 524 | -------------------------------------------------------------------------------- /kl.jl: -------------------------------------------------------------------------------- 1 | 2 | # Learning with knowledge 3 | 4 | export guess_example_value, generate_powerset, current_best_learning, 5 | version_space_learning, 6 | is_consistent_determination, minimal_consistent_determination, 7 | FOILKnowledgeBase, extend_example, choose_literal, new_literals, 8 | new_clause, foil; 9 | 10 | function disjunction_value(e::Dict, d::Dict) 11 | for (k, v) in d 12 | if (!(typeof(v) <: AbstractString)) 13 | error("disjunction_value(): Found an unexpected type, ", typeof(v), "!"); 
14 | end 15 | # Check for negation 16 | if (v[1] == '!') 17 | if (e[k] == v[2:end]) 18 | return false; 19 | end 20 | elseif (e[k] != v) 21 | return false; 22 | end 23 | end 24 | return true; 25 | end 26 | 27 | """ 28 | guess_example_value(e::Dict, h::AbstractVector) 29 | 30 | Return a guess for the logical value of the given example 'e' based on the given hypothesis 'h'. 31 | """ 32 | function guess_example_value(e::Dict, h::AbstractVector) 33 | for d in h 34 | if (disjunction_value(e, d)) 35 | return true; 36 | end 37 | end 38 | return false; 39 | end 40 | 41 | function example_is_consistent(e::Dict, h::AbstractVector) 42 | return (e["GOAL"] == guess_example_value(e, h)); 43 | end 44 | 45 | function example_is_false_positive(e::Dict, h::AbstractVector) 46 | if (e["GOAL"] == false) 47 | if (guess_example_value(e, h)) 48 | return true; 49 | end 50 | end 51 | return false; 52 | end 53 | 54 | function example_is_false_negative(e::Dict, h::AbstractVector) 55 | if (e["GOAL"] == true) 56 | if (!(guess_example_value(e, h))) 57 | return true; 58 | end 59 | end 60 | return false; 61 | end 62 | 63 | function check_all_consistency(examples::AbstractVector, h::AbstractVector) 64 | for example in examples 65 | if (!(example_is_consistent(example, h))) 66 | return false; 67 | end 68 | end 69 | return true; 70 | end 71 | 72 | function specializations(prior_examples::AbstractVector, h::AbstractVector) 73 | local hypotheses::AbstractVector = []; 74 | for (i, disjunction) in enumerate(h) 75 | for example in prior_examples 76 | for (k, v) in example 77 | if ((haskey(disjunction, k)) || k == "GOAL") 78 | continue; 79 | end 80 | 81 | local h_prime::Dict = copy(h[i]); 82 | h_prime[k] = "!" 
* v; 83 | local h_prime_prime::AbstractVector = copy(h); 84 | h_prime_prime[i] = h_prime; 85 | 86 | if (check_all_consistency(prior_examples, h_prime_prime)) 87 | push!(hypotheses, h_prime_prime); 88 | end 89 | end 90 | end 91 | end 92 | shuffle!(RandomDeviceInstance, hypotheses); 93 | return hypotheses; 94 | end 95 | 96 | function check_negative_consistency(examples::AbstractVector, h::Dict) 97 | for example in examples 98 | if (example["GOAL"]) 99 | continue; 100 | end 101 | if (!example_is_consistent(example, [h])) 102 | return false; 103 | end 104 | end 105 | return true; 106 | end 107 | 108 | function generate_powerset(array::AbstractVector) 109 | local result::AbstractVector = Array{Any, 1}([()]); 110 | for element in array 111 | for i in eachindex(result) 112 | push!(result, (result[i]..., element)); 113 | end 114 | end 115 | return Set{Tuple}(result); 116 | end 117 | 118 | function add_or_examples(prior_examples::AbstractVector, h::AbstractVector) 119 | local result::AbstractVector = []; 120 | local example::Dict = prior_examples[end]; 121 | local attributes::Dict = Dict((k, v) for (k, v) in example if (k != "GOAL")); 122 | local attribute_powerset = setdiff!(generate_powerset(collect(keys(attributes))), Set([()])); 123 | 124 | for subset in attribute_powerset 125 | local h_prime::Dict = Dict(); 126 | for key in subset 127 | h_prime[key] = attributes[key]; 128 | end 129 | if (check_negative_consistency(prior_examples, h_prime)) 130 | local h_prime_prime::AbstractVector = copy(h); 131 | push!(h_prime_prime, h_prime); 132 | push!(result, h_prime_prime); 133 | end 134 | end 135 | 136 | return result; 137 | end 138 | 139 | function generalizations(prior_examples::AbstractVector, h::AbstractVector) 140 | local hypotheses::AbstractVector = []; 141 | # Remove the empty set from the powerset. 
142 | local disjunctions_powerset::Set = setdiff!(generate_powerset(collect(1:length(h))), Set([()])); 143 | for disjunctions in disjunctions_powerset 144 | h_prime = copy(h); 145 | deleteat!(h_prime, disjunctions); 146 | 147 | if (check_all_consistency(prior_examples, h_prime)) 148 | append!(hypotheses, h_prime); 149 | end 150 | end 151 | 152 | for (i, disjunction) in enumerate(h) 153 | local attribute_powerset::Set = setdiff!(generate_powerset(collect(keys(disjunction))), Set([()])); 154 | for attributes in attribute_powerset 155 | h_prime = copy(h[i]); 156 | 157 | if (check_all_consistency(prior_examples, [h_prime])) 158 | local h_prime_prime::AbstractVector = copy(h); 159 | h_prime_prime[i] = copy(h_prime); 160 | push!(hypotheses, h_prime_prime); 161 | end 162 | end 163 | end 164 | if ((length(hypotheses) == 0) || (hypotheses == [Dict()])) 165 | hypotheses = add_or_examples(prior_examples, h); 166 | else 167 | append!(hypotheses, add_or_examples(prior_examples, h)); 168 | end 169 | 170 | shuffle!(hypotheses); 171 | return hypotheses; 172 | end 173 | 174 | """ 175 | current_best_learning(examples::AbstractVector, h::AbstractVector, prior_examples::AbstractVector) 176 | current_best_learning(examples::AbstractVector, h::AbstractVector) 177 | 178 | Apply the current-best-hypothesis learning algorithm (Fig. 19.2) on the given examples 'examples' 179 | and hypothesis 'h' (an array of dictionaries where each Dict represents a disjunction). Return 180 | a consistent hypothesis if possible, otherwise 'nothing' on failure. 
181 | """ 182 | function current_best_learning(examples::AbstractVector, h::AbstractVector, prior_examples::AbstractVector) 183 | if (length(examples) == 0) 184 | return h; 185 | end 186 | 187 | local example::Dict = examples[1]; 188 | 189 | push!(prior_examples, example); 190 | 191 | if (example_is_consistent(example, h)) 192 | return current_best_learning(examples[2:end], h, prior_examples); 193 | elseif (example_is_false_positive(example, h)) 194 | for h_prime in specializations(prior_examples, h) 195 | h_prime_prime = current_best_learning(examples[2:end], h_prime, prior_examples); 196 | if (!(typeof(h_prime_prime) <: Void)) 197 | return h_prime_prime; 198 | end 199 | end 200 | elseif (example_is_false_negative(example, h)) 201 | for h_prime in generalizations(prior_examples, h) 202 | h_prime_prime = current_best_learning(examples[2:end], h_prime, prior_examples); 203 | if (!(typeof(h_prime_prime) <: Void)) 204 | return h_prime_prime; 205 | end 206 | end 207 | end 208 | return nothing; 209 | end 210 | 211 | function current_best_learning(examples::AbstractVector, h::AbstractVector) 212 | return current_best_learning(examples, h, []); 213 | end 214 | 215 | function version_space_update(V::AbstractVector, e::Dict) 216 | return collect(h for h in V if (example_is_consistent(e, h))); 217 | end 218 | 219 | function values_table(examples::AbstractVector) 220 | local values::Dict = Dict(); 221 | for example in examples 222 | for (k, v) in example 223 | if (k == "GOAL") 224 | continue 225 | end 226 | local modifier::String = "!"; 227 | if (example["GOAL"]) 228 | modifier = ""; 229 | end 230 | 231 | local modified_value::String = modifier * v; 232 | if (!(modified_value in get!(values, k, []))) 233 | push!(get!(values, k, []), modified_value); 234 | end 235 | end 236 | end 237 | return values; 238 | end 239 | 240 | function build_attribute_combinations(subset::Tuple, values::Dict) 241 | local h::AbstractVector = []; 242 | if (length(subset) == 1) 243 | k = 
values[subset[1]] 244 | h = collect([Dict([Pair(subset[1], v)])] for v in values[subset[1]]); 245 | return h; 246 | end 247 | for (i, attribute) in enumerate(subset) 248 | local rest::AbstractVector = build_attribute_combinations(subset[2:end], values); 249 | for value in values[attribute] 250 | local combination::Dict = Dict([Pair(attribute, value)]); 251 | for rest_item in rest 252 | local combination_prime::Dict = copy(combination); 253 | for dictionary in rest_item 254 | merge!(combination_prime, dictionary); 255 | end 256 | push!(h, [combination_prime]); 257 | end 258 | end 259 | end 260 | return h; 261 | end 262 | 263 | function build_h_combinations(hypotheses::AbstractVector) 264 | local h::AbstractVector = []; 265 | local h_powerset::Set = setdiff!(generate_powerset(collect(1:length(hypotheses))), Set([()])); 266 | 267 | for subset in h_powerset 268 | local combination::AbstractVector = []; 269 | for index in subset 270 | append!(combination, hypotheses[index]); 271 | end 272 | push!(h, combination); 273 | end 274 | 275 | return h; 276 | end 277 | 278 | function all_hypotheses(examples::AbstractVector) 279 | local values::Dict = values_table(examples); 280 | local h_powerset::Set = setdiff!(generate_powerset(collect(keys(values))), Set([()])); 281 | local hypotheses::AbstractVector = []; 282 | for subset in h_powerset 283 | append!(hypotheses, build_attribute_combinations(subset, values)); 284 | end 285 | append!(hypotheses, build_h_combinations(hypotheses)); 286 | return hypotheses; 287 | end 288 | 289 | """ 290 | version_space_learning(examples::AbstractVector) 291 | 292 | Return a version space for the given 'examples' by using the version space learning 293 | algorithm (Fig. 19.3). 
294 | """ 295 | function version_space_learning(examples::AbstractVector) 296 | local V::AbstractVector = all_hypotheses(examples); 297 | for example in examples 298 | if (length(V) != 0) 299 | V = version_space_update(V, example); 300 | end 301 | end 302 | return V; 303 | end 304 | 305 | function is_consistent_determination(A::AbstractVector, E::AbstractVector) 306 | local H::Dict = Dict(); 307 | 308 | for example in E 309 | local attribute_values::Tuple = Tuple((collect(example[attribute] for attribute in A)...)); 310 | if (haskey(H, attribute_values)) 311 | if (H[attribute_values] != example["GOAL"]) 312 | return false; 313 | end 314 | end 315 | H[attribute_values] = example["GOAL"]; 316 | end 317 | 318 | return true; 319 | end 320 | 321 | """ 322 | minimal_consistent_determination(E::AbstractVector, A::Set) 323 | 324 | Return a set of attributes by using the algorithm for finding a minimal consistent 325 | determination (Fig. 19.8). 326 | """ 327 | function minimal_consistent_determination(E::AbstractVector, A::Set) 328 | local n::Int64 = length(A); 329 | for i in 0:n 330 | for A_i in combinations(A, i); 331 | if (is_consistent_determination(A_i, E)) 332 | return Set(A_i); 333 | end 334 | end 335 | end 336 | return nothing; 337 | end 338 | 339 | #= 340 | 341 | FOILKnowledgeBase is a knowledge base that consists of first order logic definite clauses, 342 | 343 | constant symbols, and predicate symbols used by foil(). 
"""
FOILKnowledgeBase is a knowledge base that consists of first order logic definite clauses,
constant symbols, and predicate symbols used by foil().
"""
mutable struct FOILKnowledgeBase <: AbstractKnowledgeBase
    fol_kb::FirstOrderLogicKnowledgeBase    # backing first-order-logic clause store
    constant_symbols::Set                   # constant symbols collected from told clauses
    predicate_symbols::Set                  # predicate symbols collected from told clauses

    function FOILKnowledgeBase()
        return new(FirstOrderLogicKnowledgeBase(), Set(), Set());
    end

    function FOILKnowledgeBase(initial_clauses::Array{Expression, 1})
        local fkb::FOILKnowledgeBase = new(FirstOrderLogicKnowledgeBase(), Set(), Set());
        for clause in initial_clauses
            tell(fkb, clause);
        end
        return fkb;
    end
end

"""
    tell(fkb::FOILKnowledgeBase, e::Expression)

Add the definite clause 'e' to the FOIL knowledge base 'fkb', recording the constant
and predicate symbols it mentions. Throw an error when 'e' is not a definite clause.
"""
function tell(fkb::FOILKnowledgeBase, e::Expression)
    if (!is_logic_definite_clause(e))
        error("tell(): ", repr(e), " , is not a definite clause!");
    end
    # Index the symbols the clause introduces, then store the clause itself.
    fkb.predicate_symbols = union(fkb.predicate_symbols, predicate_symbols(e));
    fkb.constant_symbols = union(fkb.constant_symbols, constant_symbols(e));
    tell(fkb.fol_kb, e);
    return nothing;
end

# Answer the query 'e' by backward chaining over the underlying FOL knowledge base.
function ask(fkb::FOILKnowledgeBase, e::Expression)
    return fol_bc_ask(fkb.fol_kb, e);
end

# Remove the clause 'e' from the underlying FOL knowledge base.
function retract(fkb::FOILKnowledgeBase, e::Expression)
    retract(fkb.fol_kb, e);
    return nothing;
end
"""
    extend_example(fkb::FOILKnowledgeBase, example::Dict, literal::Expression)

Return an array of extended examples by extending the given example 'example' to satisfy
the given literal 'literal'.
"""
function extend_example(fkb::FOILKnowledgeBase, example::Dict, literal::Expression)
    local extended::AbstractVector = [];
    # Every substitution that proves the literal yields one extended example; the
    # bindings already in 'example' win over the new ones (merge! argument order).
    for substitution in ask(fkb, substitute(example, literal))
        push!(extended, merge!(substitution, example));
    end
    return extended;
end

"""
    update_positive_examples(fkb::FOILKnowledgeBase, examples_positive::AbstractVector, extended_positive_examples::AbstractVector, target::Expression)

Return an array of uncovered positive examples given the positive examples 'positive_examples' and
the extended positive examples 'extended_positive_examples'.
"""
function update_positive_examples(fkb::FOILKnowledgeBase, examples_positive::AbstractVector, extended_positive_examples::AbstractVector, target::Expression)
    local still_uncovered::Array{Dict, 1} = Array{Dict, 1}();
    for example in examples_positive
        # An example is covered when some extended example agrees with it on every binding.
        local covered::Bool = any(all(extension[variable] == example[variable]
                                      for variable in keys(example))
                                  for extension in extended_positive_examples);
        if (covered)
            # Covered positives become known facts of the knowledge base.
            tell(fkb, substitute(example, target));
        else
            push!(still_uncovered, example);
        end
    end
    return still_uncovered;
end
"""
    new_literals(fkb::FOILKnowledgeBase, clause::Tuple{Expression, AbstractVector})

Return a Tuple of literals given the known predicate symbols in the FOIL knowledge base 'fkb'
and the horn clause 'clause'.

Each literal in the returned literals share at least 1 variable with the given horn clause.
"""
function new_literals(fkb::FOILKnowledgeBase, clause::Tuple{Expression, AbstractVector})
    # Gather every variable already appearing in the clause head or its antecedents.
    local known_variables::Set = variables(clause[1]);
    for antecedent in clause[2]
        union!(known_variables, variables(antecedent));
    end
    local generated::Tuple = ();
    for (predicate, arity) in fkb.predicate_symbols
        # Invent (arity - 1) fresh variables so a new literal can bind new objects.
        local fresh_variables::Set = Set(collect(standardize_variables(expr("x"), standardize_variables_counter)
                                                 for i in 1:(arity - 1)));
        for arguments in iterable_cartesian_product(fill(union(known_variables, fresh_variables), arity))
            # Keep only literals connected to the clause through at least one shared variable.
            if (any((variable in known_variables) for variable in arguments))
                generated = Tuple((generated..., Expression(predicate, arguments...)));
            end
        end
    end
    return generated;
end

"""
    choose_literal(fkb::FOILKnowledgeBase, literals::Tuple, examples::Tuple{AbstractVector, AbstractVector})

Return the best literal from the given literals 'literals' by comparing the information gained.
"""
function choose_literal(fkb::FOILKnowledgeBase, literals::Tuple, examples::Tuple{AbstractVector, AbstractVector})
    # Score one candidate literal with the FOIL information-gain heuristic,
    # returning a (literal, gain) pair.
    local information_gain::Function = (function(literal::Expression)
        local pre_positive::Int64 = length(examples[1]);
        local pre_negative::Int64 = length(examples[2]);
        # Index 1 holds the extended positive examples, index 2 the extended negatives.
        local extended_examples::AbstractVector = collect(vcat(collect(extend_example(fkb, example, literal)
                                                                       for example in examples[i])...)
                                                          for i in 1:2);
        local post_positive::Int64 = length(extended_examples[1]);
        local post_negative::Int64 = length(extended_examples[2]);
        # A literal matching nothing (or an empty example set) cannot produce gain.
        if ((pre_positive + pre_negative == 0) ||
            (post_positive + post_negative == 0))
            return (literal, -1);
        end
        # T counts the positive examples still represented after extension.
        local T::Int64 = 0;
        for example in examples[1]
            if (any(all(extension[x] == example[x] for x in keys(example))
                    for extension in extended_examples[1]))
                T = T + 1;
            end
        end
        # Gain ∝ T * log of the improvement in the positive ratio; the small additive
        # constant 0.0001 guards against a zero numerator.
        return (literal, (T * log((post_positive * (pre_positive + pre_negative) + 0.0001)/((post_positive + post_negative) * pre_positive))));
    end);

    # Evaluate every candidate and keep the literal with the largest gain.
    local scored::Tuple = map(information_gain, literals);
    local best::Tuple = reduce((function(t1::Tuple, t2::Tuple)
                                    if (getindex(t1, 2) < getindex(t2, 2))
                                        return t2;
                                    else
                                        return t1;
                                    end
                                end), scored);
    return best[1];
end

"""
    new_clause(fkb::FOILKnowledgeBase, examples::Tuple{AbstractVector, AbstractVector}, target::Expression)

Return a horn clause and the extended positive examples as Tuple.

The horn clause is represented as (consequent, array of antecendents).
"""
function new_clause(fkb::FOILKnowledgeBase, examples::Tuple{AbstractVector, AbstractVector}, target::Expression)
    local clause::Tuple = (target, Array{Expression, 1}());
    local current_examples = examples;
    # Keep specializing the clause until it rules out every negative example.
    while (length(current_examples[2]) != 0)
        local literal::Expression = choose_literal(fkb, new_literals(fkb, clause), current_examples);
        push!(clause[2], literal);
        current_examples = (collect(vcat(collect(extend_example(fkb, example, literal)
                                                 for example in current_examples[i])...)
                                    for i in 1:2)...,);
    end
    return (clause, current_examples[1]);
end

"""
    foil(fkb::FOILKnowledgeBase, examples::Tuple{AbstractVector, AbstractVector}, target::Expression)

Return an array of horn clauses by using the FOIL algorithm (Fig. 19.12) on the given FOIL knowledge
base 'fkb', set of examples 'examples', and the target literal 'target'.
"""
function foil(fkb::FOILKnowledgeBase, examples::Tuple{AbstractVector, AbstractVector}, target::Expression)
    local clauses::AbstractVector = [];
    local remaining_positive::AbstractVector = examples[1];
    local negative_examples::AbstractVector = examples[2];

    while (length(remaining_positive) != 0)
        # Learn one clause that covers some positives and excludes all negatives.
        local clause::Tuple;
        local covered_extended::AbstractVector;
        clause, covered_extended = new_clause(fkb, (remaining_positive, negative_examples), target);
        # Remove positive examples covered by 'clause' and record the clause.
        remaining_positive = update_positive_examples(fkb, remaining_positive, covered_extended, target);
        push!(clauses, clause);
    end
    return clauses;
end