├── tests ├── __init__.py ├── .pytest_cache │ └── v │ │ └── cache │ │ └── lastfailed ├── pytest.ini ├── test_games.py ├── test_rl.py ├── test_mdp.py ├── test_learning.py ├── test_utils.py ├── test_agents.py ├── test_nlp.py ├── test_text.py ├── test_knowledge.py └── test_logic.py ├── pytest.ini ├── images ├── -4.jpg ├── 4.jpg ├── ge0.jpg ├── ge1.jpg ├── ge2.jpg ├── ge4.jpg ├── mdp.png ├── pop.jpg ├── -0.04.jpg ├── -0.4.jpg ├── maze.png ├── mdp-a.png ├── mdp-b.png ├── mdp-c.png ├── mdp-d.png ├── bayesnet.png ├── fig_5_2.png ├── grid_mdp.jpg ├── knn_plot.png ├── queen_s.png ├── aima3e_big.jpg ├── aima_logo.png ├── cake_graph.jpg ├── neural_net.png ├── parse_tree.png ├── perceptron.png ├── refinement.png ├── restaurant.png ├── wall-icon.jpg ├── dirt05-icon.jpg ├── hillclimb-tsp.png ├── random_forest.png ├── romania_map.png ├── sprinklernet.jpg ├── vacuum-icon.jpg ├── grid_mdp_agent.jpg ├── point_crossover.png ├── decisiontree_fruit.jpg ├── ensemble_learner.jpg ├── uniform_crossover.png ├── simple_reflex_agent.jpg ├── general_learning_agent.jpg ├── knowledge_foil_family.png ├── model_goal_based_agent.jpg ├── pluralityLearner_plot.png ├── model_based_reflex_agent.jpg ├── model_utility_based_agent.jpg ├── knowledge_FOIL_grandparent.png ├── simple_problem_solving_agent.jpg ├── IMAGE-CREDITS ├── makefile └── vacuum.svg ├── requirements.txt ├── .gitmodules ├── .flake8 ├── .travis.yml ├── SUBMODULE.md ├── LICENSE ├── .gitignore ├── js ├── continuousworld.js ├── canvas.js └── gridworld.js ├── index.ipynb ├── gui ├── eight_puzzle.py ├── vacuum_agent.py ├── genetic_algorithm_example.py ├── tic-tac-toe.py └── xy_vacuum_environment.py ├── ipyviews.py ├── intro.ipynb ├── CONTRIBUTING.md ├── rl.py └── knowledge.py /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/.pytest_cache/v/cache/lastfailed: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | filterwarnings = 3 | ignore::ResourceWarning 4 | -------------------------------------------------------------------------------- /tests/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | filterwarnings = 3 | ignore::ResourceWarning -------------------------------------------------------------------------------- /images/-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/-4.jpg -------------------------------------------------------------------------------- /images/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/4.jpg -------------------------------------------------------------------------------- /images/ge0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/ge0.jpg -------------------------------------------------------------------------------- /images/ge1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/ge1.jpg -------------------------------------------------------------------------------- /images/ge2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/ge2.jpg -------------------------------------------------------------------------------- /images/ge4.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/ge4.jpg -------------------------------------------------------------------------------- /images/mdp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/mdp.png -------------------------------------------------------------------------------- /images/pop.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/pop.jpg -------------------------------------------------------------------------------- /images/-0.04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/-0.04.jpg -------------------------------------------------------------------------------- /images/-0.4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/-0.4.jpg -------------------------------------------------------------------------------- /images/maze.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/maze.png -------------------------------------------------------------------------------- /images/mdp-a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/mdp-a.png -------------------------------------------------------------------------------- /images/mdp-b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/mdp-b.png 
-------------------------------------------------------------------------------- /images/mdp-c.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/mdp-c.png -------------------------------------------------------------------------------- /images/mdp-d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/mdp-d.png -------------------------------------------------------------------------------- /images/bayesnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/bayesnet.png -------------------------------------------------------------------------------- /images/fig_5_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/fig_5_2.png -------------------------------------------------------------------------------- /images/grid_mdp.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/grid_mdp.jpg -------------------------------------------------------------------------------- /images/knn_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/knn_plot.png -------------------------------------------------------------------------------- /images/queen_s.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/queen_s.png -------------------------------------------------------------------------------- /images/aima3e_big.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/aima3e_big.jpg -------------------------------------------------------------------------------- /images/aima_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/aima_logo.png -------------------------------------------------------------------------------- /images/cake_graph.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/cake_graph.jpg -------------------------------------------------------------------------------- /images/neural_net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/neural_net.png -------------------------------------------------------------------------------- /images/parse_tree.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/parse_tree.png -------------------------------------------------------------------------------- /images/perceptron.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/perceptron.png -------------------------------------------------------------------------------- /images/refinement.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/refinement.png -------------------------------------------------------------------------------- /images/restaurant.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/D-K-E/aima-python/master/images/restaurant.png -------------------------------------------------------------------------------- /images/wall-icon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/wall-icon.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | networkx==1.11 2 | jupyter 3 | pandas 4 | matplotlib 5 | pillow 6 | Image 7 | -------------------------------------------------------------------------------- /images/dirt05-icon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/dirt05-icon.jpg -------------------------------------------------------------------------------- /images/hillclimb-tsp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/hillclimb-tsp.png -------------------------------------------------------------------------------- /images/random_forest.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/random_forest.png -------------------------------------------------------------------------------- /images/romania_map.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/romania_map.png -------------------------------------------------------------------------------- /images/sprinklernet.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/sprinklernet.jpg 
-------------------------------------------------------------------------------- /images/vacuum-icon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/vacuum-icon.jpg -------------------------------------------------------------------------------- /images/grid_mdp_agent.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/grid_mdp_agent.jpg -------------------------------------------------------------------------------- /images/point_crossover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/point_crossover.png -------------------------------------------------------------------------------- /images/decisiontree_fruit.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/decisiontree_fruit.jpg -------------------------------------------------------------------------------- /images/ensemble_learner.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/ensemble_learner.jpg -------------------------------------------------------------------------------- /images/uniform_crossover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/uniform_crossover.png -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "aima-data"] 2 | path = aima-data 3 | url = https://github.com/aimacode/aima-data.git 4 | 
-------------------------------------------------------------------------------- /images/simple_reflex_agent.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/simple_reflex_agent.jpg -------------------------------------------------------------------------------- /images/general_learning_agent.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/general_learning_agent.jpg -------------------------------------------------------------------------------- /images/knowledge_foil_family.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/knowledge_foil_family.png -------------------------------------------------------------------------------- /images/model_goal_based_agent.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/model_goal_based_agent.jpg -------------------------------------------------------------------------------- /images/pluralityLearner_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/pluralityLearner_plot.png -------------------------------------------------------------------------------- /images/model_based_reflex_agent.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/model_based_reflex_agent.jpg -------------------------------------------------------------------------------- /images/model_utility_based_agent.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/D-K-E/aima-python/master/images/model_utility_based_agent.jpg -------------------------------------------------------------------------------- /images/knowledge_FOIL_grandparent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/knowledge_FOIL_grandparent.png -------------------------------------------------------------------------------- /images/simple_problem_solving_agent.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/aima-python/master/images/simple_problem_solving_agent.jpg -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 100 3 | ignore = E121,E123,E126,E221,E222,E225,E226,E242,E701,E702,E704,E731,W503,F405,F841 4 | exclude = tests 5 | -------------------------------------------------------------------------------- /images/IMAGE-CREDITS: -------------------------------------------------------------------------------- 1 | PHOTO CREDITS 2 | 3 | Image After http://www.imageafter.com/ 4 | 5 | b15woods003.jpg 6 | (Cropped to 764x764 and scaled to 50x50 to make wall-icon.jpg 7 | by Gregory Weber) 8 | 9 | Noctua Graphics, http://www.noctua-graphics.de/english/fraset_e.htm 10 | 11 | dirt05.jpg 512x512 12 | (Scaled to 50x50 to make dirt05-icon.jpg by Gregory Weber) 13 | 14 | Gregory Weber 15 | 16 | dirt.svg, dirt.png 17 | vacuum.svg, vacuum.png 18 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: 2 | - python 3 | 4 | python: 5 | - "3.4" 6 | 7 | before_install: 8 | - git submodule update --remote 9 | 10 | install: 11 | - pip install six 12 | - pip install flake8 13 
| - pip install ipython 14 | - pip install matplotlib 15 | - pip install networkx 16 | - pip install ipywidgets 17 | - pip install Pillow 18 | 19 | script: 20 | - py.test 21 | - python -m doctest -v *.py 22 | 23 | after_success: 24 | - flake8 --max-line-length 100 --ignore=E121,E123,E126,E221,E222,E225,E226,E242,E701,E702,E704,E731,W503 . 25 | 26 | notifications: 27 | email: false 28 | -------------------------------------------------------------------------------- /SUBMODULE.md: -------------------------------------------------------------------------------- 1 | This is a guide on how to update the `aima-data` submodule to the latest version. This needs to be done every time something changes in the [aima-data](https://github.com/aimacode/aima-data) repository. All the below commands should be executed from the local directory of the `aima-python` repository, using `git`. 2 | 3 | ``` 4 | git submodule deinit aima-data 5 | git rm aima-data 6 | git submodule add https://github.com/aimacode/aima-data.git aima-data 7 | git commit 8 | git push origin 9 | ``` 10 | 11 | Then you need to pull request the changes (unless you are a collaborator, in which case you can commit directly to the master). 
12 | -------------------------------------------------------------------------------- /images/makefile: -------------------------------------------------------------------------------- 1 | # makefile for images 2 | 3 | Sources = dirt.svg vacuum.svg 4 | 5 | Targets = $(Sources:.svg=.png) 6 | 7 | ImageScale = 50x50 8 | 9 | Temporary = tmp.jpg 10 | 11 | .PHONY: all 12 | 13 | all: $(Targets) 14 | 15 | .PHONY: clean 16 | 17 | clean: 18 | rm -f $(Targets) $(Temporary) 19 | 20 | %.png: %.svg 21 | convert -scale $(ImageScale) $< $@ 22 | 23 | %-icon.jpg: %.svg 24 | convert -scale $(ImageScale) $< $@ 25 | 26 | %-icon.jpg: %.jpg 27 | convert -scale $(ImageScale) $< $@ 28 | 29 | wall-icon.jpg: b15woods003.jpg 30 | convert -crop 764x764+0+0 $< tmp.jpg 31 | convert -resize 50x50+0+0 tmp.jpg $@ 32 | 33 | vacuum-icon.jpg: vacuum.svg 34 | convert -scale $(ImageScale) -transparent white $< $@ 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 aima-python contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask instance folder 57 | instance/ 58 | 59 | # Sphinx documentation 60 | docs/_build/ 61 | 62 | # PyBuilder 63 | target/ 64 | 65 | # IPython Notebook 66 | .ipynb_checkpoints 67 | 68 | # pyenv 69 | .python-version 70 | 71 | # dotenv 72 | .env 73 | .idea 74 | 75 | # for macOS 76 | .DS_Store 77 | ._.DS_Store 78 | -------------------------------------------------------------------------------- /js/continuousworld.js: -------------------------------------------------------------------------------- 1 | var latest_output_area ="NONE"; // Jquery object for the DOM element of output area which was used most 
recently 2 | function handle_output(out, block){ 3 | var output = out.content.data["text/html"]; 4 | latest_output_area.html(output); 5 | } 6 | function polygon_complete(canvas, vertices){ 7 | latest_output_area = $(canvas).parents('.output_subarea'); 8 | var world_object_name = canvas.dataset.world_name; 9 | var command = world_object_name + ".handle_add_obstacle(" + JSON.stringify(vertices) + ")"; 10 | console.log("Executing Command: " + command); 11 | var kernel = IPython.notebook.kernel; 12 | var callbacks = { 'iopub' : {'output' : handle_output}}; 13 | kernel.execute(command,callbacks); 14 | } 15 | var canvas , ctx; 16 | function drawPolygon(array) { 17 | ctx.fillStyle = '#f00'; 18 | ctx.beginPath(); 19 | ctx.moveTo(array[0][0],array[0][1]); 20 | for(var i = 1;i1) 40 | { 41 | drawPoint(pArray[0][0],pArray[0][1]); 42 | } 43 | //check overlap 44 | if(ctx.isPointInPath(x, y) && (pArray.length>1)) { 45 | //Do something 46 | drawPolygon(pArray); 47 | polygon_complete(canvas,pArray); 48 | } 49 | else { 50 | var point = new Array(); 51 | point.push(x,y); 52 | pArray.push(point); 53 | } 54 | } 55 | function drawPoint(x, y) { 56 | ctx.beginPath(); 57 | ctx.arc(x, y, 5, 0, Math.PI*2); 58 | ctx.fillStyle = '#00f'; 59 | ctx.fill(); 60 | ctx.closePath(); 61 | } 62 | function initalizeObstacles(objects) { 63 | canvas = $('canvas.main-robo-world').get(0); 64 | ctx = canvas.getContext('2d'); 65 | $('canvas.main-robo-world').removeClass('main-robo-world'); 66 | for(var i=0;i= 0 60 | 61 | # The player 'X' (one who plays first) in TicTacToe never loses: 62 | assert ttt.play_game(alphabeta_player, random_player) >= 0 63 | -------------------------------------------------------------------------------- /tests/test_rl.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from rl import * 4 | from mdp import sequential_decision_environment 5 | 6 | 7 | north = (0, 1) 8 | south = (0,-1) 9 | west = (-1, 0) 10 | east = (1, 0) 11 | 
12 | policy = { 13 | (0, 2): east, (1, 2): east, (2, 2): east, (3, 2): None, 14 | (0, 1): north, (2, 1): north, (3, 1): None, 15 | (0, 0): north, (1, 0): west, (2, 0): west, (3, 0): west, 16 | } 17 | 18 | def test_PassiveDUEAgent(): 19 | agent = PassiveDUEAgent(policy, sequential_decision_environment) 20 | for i in range(200): 21 | run_single_trial(agent,sequential_decision_environment) 22 | agent.estimate_U() 23 | # Agent does not always produce same results. 24 | # Check if results are good enough. 25 | #print(agent.U[(0, 0)], agent.U[(0,1)], agent.U[(1,0)]) 26 | assert agent.U[(0, 0)] > 0.15 # In reality around 0.3 27 | assert agent.U[(0, 1)] > 0.15 # In reality around 0.4 28 | assert agent.U[(1, 0)] > 0 # In reality around 0.2 29 | 30 | def test_PassiveADPAgent(): 31 | agent = PassiveADPAgent(policy, sequential_decision_environment) 32 | for i in range(100): 33 | run_single_trial(agent,sequential_decision_environment) 34 | 35 | # Agent does not always produce same results. 36 | # Check if results are good enough. 37 | #print(agent.U[(0, 0)], agent.U[(0,1)], agent.U[(1,0)]) 38 | assert agent.U[(0, 0)] > 0.15 # In reality around 0.3 39 | assert agent.U[(0, 1)] > 0.15 # In reality around 0.4 40 | assert agent.U[(1, 0)] > 0 # In reality around 0.2 41 | 42 | 43 | 44 | def test_PassiveTDAgent(): 45 | agent = PassiveTDAgent(policy, sequential_decision_environment, alpha=lambda n: 60./(59+n)) 46 | for i in range(200): 47 | run_single_trial(agent,sequential_decision_environment) 48 | 49 | # Agent does not always produce same results. 50 | # Check if results are good enough. 
51 | assert agent.U[(0, 0)] > 0.15 # In reality around 0.3 52 | assert agent.U[(0, 1)] > 0.15 # In reality around 0.35 53 | assert agent.U[(1, 0)] > 0.15 # In reality around 0.25 54 | 55 | 56 | def test_QLearning(): 57 | q_agent = QLearningAgent(sequential_decision_environment, Ne=5, Rplus=2, 58 | alpha=lambda n: 60./(59+n)) 59 | 60 | for i in range(200): 61 | run_single_trial(q_agent,sequential_decision_environment) 62 | 63 | # Agent does not always produce same results. 64 | # Check if results are good enough. 65 | assert q_agent.Q[((0, 1), (0, 1))] >= -0.5 # In reality around 0.1 66 | assert q_agent.Q[((1, 0), (0, -1))] <= 0.5 # In reality around -0.1 67 | -------------------------------------------------------------------------------- /index.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# AIMA Python Binder Index\n", 8 | "\n", 9 | "Welcome to the AIMA Python Code Repository. You should be seeing this index notebook if you clicked on the **Launch Binder** button on the [repository](https://github.com/aimacode/aima-python). If you are viewing this notebook directly on Github we suggest that you use the **Launch Binder** button instead. Binder allows you to experiment with all the code in the browser itself without the need of installing anything on your local machine. Below is the list of notebooks that should assist you in navigating the different notebooks available. \n", 10 | "\n", 11 | "If you are completely new to AIMA Python or Jupyter Notebooks we suggest that you start with the Introduction Notebook.\n", 12 | "\n", 13 | "# List of Notebooks\n", 14 | "\n", 15 | "1. [**Introduction**](./intro.ipynb)\n", 16 | "\n", 17 | "2. [**Agents**](./agents.ipynb)\n", 18 | "\n", 19 | "3. [**Search**](./search.ipynb)\n", 20 | "\n", 21 | "4. [**Search - 4th edition**](./search-4e.ipynb)\n", 22 | "\n", 23 | "4. 
[**Games**](./games.ipynb)\n", 24 | "\n", 25 | "5. [**Constraint Satisfaction Problems**](./csp.ipynb)\n", 26 | "\n", 27 | "6. [**Logic**](./logic.ipynb)\n", 28 | "\n", 29 | "7. [**Planning**](./planning.ipynb)\n", 30 | "\n", 31 | "8. [**Probability**](./probability.ipynb)\n", 32 | "\n", 33 | "9. [**Markov Decision Processes**](./mdp.ipynb)\n", 34 | "\n", 35 | "10. [**Learning**](./learning.ipynb)\n", 36 | "\n", 37 | "11. [**Reinforcement Learning**](./rl.ipynb)\n", 38 | "\n", 39 | "12. [**Statistical Language Processing Tools**](./text.ipynb)\n", 40 | "\n", 41 | "13. [**Natural Language Processing**](./nlp.ipynb)\n", 42 | "\n", 43 | "Besides the notebooks it is also possible to make direct modifications to the Python/JS code. To view/modify the complete set of files [click here](.) to view the Directory structure." 44 | ] 45 | } 46 | ], 47 | "metadata": { 48 | "kernelspec": { 49 | "display_name": "Python 3", 50 | "language": "python", 51 | "name": "python3" 52 | }, 53 | "language_info": { 54 | "codemirror_mode": { 55 | "name": "ipython", 56 | "version": 3 57 | }, 58 | "file_extension": ".py", 59 | "mimetype": "text/x-python", 60 | "name": "python", 61 | "nbconvert_exporter": "python", 62 | "pygments_lexer": "ipython3", 63 | "version": "3.5.1" 64 | } 65 | }, 66 | "nbformat": 4, 67 | "nbformat_minor": 0 68 | } 69 | -------------------------------------------------------------------------------- /js/canvas.js: -------------------------------------------------------------------------------- 1 | /* 2 | JavaScript functions that are executed by running the corresponding methods of a Canvas object 3 | Donot use these functions by making a js file. Instead use the python Canvas class. 
4 | See canvas.py for help on how to use the Canvas class to draw on the HTML Canvas 5 | */ 6 | 7 | 8 | //Manages the output of code executed in IPython kernel 9 | function output_callback(out, block){ 10 | console.log(out); 11 | //Handle error in python 12 | if(out.msg_type == "error"){ 13 | console.log("Error in python script!"); 14 | console.log(out.content); 15 | return ; 16 | } 17 | script = out.content.data['text/html']; 18 | script = script.substr(8, script.length - 17); 19 | eval(script) 20 | } 21 | 22 | //Handles mouse click by calling mouse_click of Canvas object with the co-ordinates as arguments 23 | function click_callback(element, event, varname){ 24 | var rect = element.getBoundingClientRect(); 25 | var x = event.clientX - rect.left; 26 | var y = event.clientY - rect.top; 27 | var kernel = IPython.notebook.kernel; 28 | var exec_str = varname + ".mouse_click(" + String(x) + ", " + String(y) + ")"; 29 | console.log(exec_str); 30 | kernel.execute(exec_str,{'iopub': {'output': output_callback}}, {silent: false}); 31 | } 32 | 33 | function rgbToHex(r,g,b){ 34 | var hexValue=(r<<16) + (g<<8) + (b<<0); 35 | var hexString=hexValue.toString(16); 36 | hexString ='#' + Array(7-hexString.length).join('0') + hexString; //Add 0 padding 37 | return hexString; 38 | } 39 | 40 | function toRad(x){ 41 | return x*Math.PI/180; 42 | } 43 | 44 | //Canvas class to store variables 45 | function Canvas(id){ 46 | this.canvas = document.getElementById(id); 47 | this.ctx = this.canvas.getContext("2d"); 48 | this.WIDTH = this.canvas.width; 49 | this.HEIGHT = this.canvas.height; 50 | this.MOUSE = {x:0,y:0}; 51 | } 52 | 53 | //Sets the fill color with which shapes are filled 54 | Canvas.prototype.fill = function(r, g, b){ 55 | this.ctx.fillStyle = rgbToHex(r,g,b); 56 | } 57 | 58 | //Set the stroke color 59 | Canvas.prototype.stroke = function(r, g, b){ 60 | this.ctx.strokeStyle = rgbToHex(r,g,b); 61 | } 62 | 63 | //Set width of the lines/strokes 64 | Canvas.prototype.strokeWidth = 
function(w){ 65 | this.ctx.lineWidth = w; 66 | } 67 | 68 | //Draw a rectangle with top left at (x,y) with 'w' width and 'h' height 69 | Canvas.prototype.rect = function(x, y, w, h){ 70 | this.ctx.fillRect(x,y,w,h); 71 | } 72 | 73 | //Draw a line with (x1, y1) and (x2, y2) as end points 74 | Canvas.prototype.line = function(x1, y1, x2, y2){ 75 | this.ctx.beginPath(); 76 | this.ctx.moveTo(x1, y1); 77 | this.ctx.lineTo(x2, y2); 78 | this.ctx.stroke(); 79 | } 80 | 81 | //Draw an arc with (x, y) as centre, 'r' as radius from angles start to stop 82 | Canvas.prototype.arc = function(x, y, r, start, stop){ 83 | this.ctx.beginPath(); 84 | this.ctx.arc(x, y, r, toRad(start), toRad(stop)); 85 | this.ctx.stroke(); 86 | } 87 | 88 | //Clear the HTML canvas 89 | Canvas.prototype.clear = function(){ 90 | this.ctx.clearRect(0, 0, this.WIDTH, this.HEIGHT); 91 | } 92 | 93 | //Change font, size and style 94 | Canvas.prototype.font = function(font_str){ 95 | this.ctx.font = font_str; 96 | } 97 | 98 | //Draws "filled" text on the canvas 99 | Canvas.prototype.fill_text = function(text, x, y){ 100 | this.ctx.fillText(text, x, y); 101 | } 102 | 103 | //Write text on the canvas 104 | Canvas.prototype.stroke_text = function(text, x, y){ 105 | this.ctx.strokeText(text, x, y); 106 | } 107 | 108 | 109 | //Test if the canvas functions are working 110 | Canvas.prototype.test_run = function(){ 111 | var dbg = false; 112 | if(dbg) 113 | alert("1"); 114 | this.clear(); 115 | if(dbg) 116 | alert("2"); 117 | this.fill(0, 200, 0); 118 | if(dbg) 119 | alert("3"); 120 | this.rect(this.MOUSE.x, this.MOUSE.y, 100, 200); 121 | if(dbg) 122 | alert("4"); 123 | this.stroke(0, 0, 50); 124 | if(dbg) 125 | alert("5"); 126 | this.line(0, 0, 100, 100); 127 | if(dbg) 128 | alert("6"); 129 | this.stroke(200, 200, 200); 130 | if(dbg) 131 | alert("7"); 132 | this.arc(200, 100, 50, 0, 360); 133 | if(dbg) 134 | alert("8"); 135 | } 136 | -------------------------------------------------------------------------------- 
/js/gridworld.js: -------------------------------------------------------------------------------- 1 | var latest_output_area ="NONE"; // Jquery object for the DOM element of output area which was used most recently 2 | 3 | function handle_output(out, block){ 4 | var output = out.content.data["text/html"]; 5 | latest_output_area.html(output); 6 | } 7 | 8 | function handle_click(canvas,coord) { 9 | console.log(canvas,coord); 10 | latest_output_area = $(canvas).parents('.output_subarea'); 11 | $(canvas).parents('.output_subarea') 12 | var world_object_name = canvas.dataset.world_name; 13 | var command = world_object_name + ".handle_click(" + JSON.stringify(coord) + ")"; 14 | console.log("Executing Command: " + command); 15 | var kernel = IPython.notebook.kernel; 16 | var callbacks = { 'iopub' : {'output' : handle_output}}; 17 | kernel.execute(command,callbacks); 18 | }; 19 | 20 | 21 | function generateGridWorld(state,size,elements) 22 | { 23 | // Declaring array to store image object 24 | var $imgArray = new Object(), hasImg=false; 25 | // Loading images LOOP 26 | $.each(elements, function(i, val) { 27 | // filtering for type img 28 | if(val["type"]=="img") { 29 | // setting image load 30 | hasImg = true; 31 | $imgArray[i] = $('').attr({height:size,width:size,src:val["source"]}).data({name:i,loaded:false}).load(function(){ 32 | // Check for all image loaded 33 | var execute=true; 34 | $(this).data("loaded",true); 35 | $.each($imgArray, function(i, val) { 36 | if(!$(this).data("loaded")) { 37 | execute=false; 38 | // exit on unloaded image 39 | return false; 40 | } 41 | }); 42 | if (execute) { 43 | // Converting loaded image to canvas covering block size. 
44 | $.each($imgArray, function(i, val) { 45 | $imgArray[i] = $('').attr({width:size,height:size}).get(0); 46 | $imgArray[i].getContext('2d').drawImage(val.get(0),0,0,size,size); 47 | }); 48 | // initialize the world 49 | initializeWorld(); 50 | } 51 | }); 52 | } 53 | }); 54 | 55 | if(!hasImg) { 56 | initializeWorld(); 57 | } 58 | 59 | function initializeWorld(){ 60 | var $parentDiv = $('div.map-grid-world'); 61 | // remove object reference 62 | $('div.map-grid-world').removeClass('map-grid-world'); 63 | // get some info about the canvas 64 | var row = state.length; 65 | var column = state[0].length; 66 | var canvas = $parentDiv.find('canvas').get(0); 67 | var ctx = canvas.getContext('2d'); 68 | canvas.width = size * column; 69 | canvas.height = size * row; 70 | 71 | //Initialize previous positions 72 | for(var i=0;i=0 && gx=0 && gy 16 | 17 | 18 | 19 | 23 | ''' # noqa 24 | 25 | with open('js/continuousworld.js', 'r') as js_file: 26 | _JS_CONTINUOUS_WORLD = js_file.read() 27 | 28 | 29 | class ContinuousWorldView: 30 | """ View for continuousworld Implementation in agents.py """ 31 | 32 | def __init__(self, world, fill="#AAA"): 33 | self.time = time.time() 34 | self.world = world 35 | self.width = world.width 36 | self.height = world.height 37 | 38 | def object_name(self): 39 | globals_in_main = {x: getattr(__main__, x) for x in dir(__main__)} 40 | for x in globals_in_main: 41 | if isinstance(globals_in_main[x], type(self)): 42 | if globals_in_main[x].time == self.time: 43 | return x 44 | 45 | def handle_add_obstacle(self, vertices): 46 | """ Vertices must be a nestedtuple. This method 47 | is called from kernel.execute on completion of 48 | a polygon. 
""" 49 | self.world.add_obstacle(vertices) 50 | self.show() 51 | 52 | def handle_remove_obstacle(self): 53 | return NotImplementedError 54 | 55 | def get_polygon_obstacles_coordinates(self): 56 | obstacle_coordiantes = [] 57 | for thing in self.world.things: 58 | if isinstance(thing, PolygonObstacle): 59 | obstacle_coordiantes.append(thing.coordinates) 60 | return obstacle_coordiantes 61 | 62 | def show(self): 63 | clear_output() 64 | total_html = _CONTINUOUS_WORLD_HTML.format(self.width, self.height, self.object_name(), 65 | str(self.get_polygon_obstacles_coordinates()), 66 | _JS_CONTINUOUS_WORLD) 67 | display(HTML(total_html)) 68 | 69 | 70 | # ______________________________________________________________________________ 71 | # Grid environment 72 | 73 | _GRID_WORLD_HTML = ''' 74 |
75 | 76 |
77 | 78 |
79 |
80 | 84 | ''' 85 | 86 | with open('js/gridworld.js', 'r') as js_file: 87 | _JS_GRID_WORLD = js_file.read() 88 | 89 | 90 | class GridWorldView: 91 | """ View for grid world. Uses XYEnviornment in agents.py as model. 92 | world: an instance of XYEnviornment. 93 | block_size: size of individual blocks in pixes. 94 | default_fill: color of blocks. A hex value or name should be passed. 95 | """ 96 | 97 | def __init__(self, world, block_size=30, default_fill="white"): 98 | self.time = time.time() 99 | self.world = world 100 | self.labels = defaultdict(str) # locations as keys 101 | self.representation = {"default": {"type": "color", "source": default_fill}} 102 | self.block_size = block_size 103 | 104 | def object_name(self): 105 | globals_in_main = {x: getattr(__main__, x) for x in dir(__main__)} 106 | for x in globals_in_main: 107 | if isinstance(globals_in_main[x], type(self)): 108 | if globals_in_main[x].time == self.time: 109 | return x 110 | 111 | def set_label(self, coordinates, label): 112 | """ Add lables to a particular block of grid. 113 | coordinates: a tuple of (row, column). 114 | rows and columns are 0 indexed. 115 | """ 116 | self.labels[coordinates] = label 117 | 118 | def set_representation(self, thing, repr_type, source): 119 | """ Set the representation of different things in the 120 | environment. 121 | thing: a thing object. 122 | repr_type : type of representation can be either "color" or "img" 123 | source: Hex value in case of color. Image path in case of image. 124 | """ 125 | thing_class_name = thing.__class__.__name__ 126 | if repr_type not in ("img", "color"): 127 | raise ValueError('Invalid repr_type passed. Possible types are img/color') 128 | self.representation[thing_class_name] = {"type": repr_type, "source": source} 129 | 130 | def handle_click(self, coordinates): 131 | """ This method needs to be overidden. Make sure to include a 132 | self.show() call at the end. 
""" 133 | self.show() 134 | 135 | def map_to_render(self): 136 | default_representation = {"val": "default", "tooltip": ""} 137 | world_map = [[copy.deepcopy(default_representation) for _ in range(self.world.width)] 138 | for _ in range(self.world.height)] 139 | 140 | for thing in self.world.things: 141 | row, column = thing.location 142 | thing_class_name = thing.__class__.__name__ 143 | if thing_class_name not in self.representation: 144 | raise KeyError('Representation not found for {}'.format(thing_class_name)) 145 | world_map[row][column]["val"] = thing.__class__.__name__ 146 | 147 | for location, label in self.labels.items(): 148 | row, column = location 149 | world_map[row][column]["tooltip"] = label 150 | 151 | return json.dumps(world_map) 152 | 153 | def show(self): 154 | clear_output() 155 | total_html = _GRID_WORLD_HTML.format( 156 | self.object_name(), self.map_to_render(), 157 | self.block_size, json.dumps(self.representation), _JS_GRID_WORLD) 158 | display(HTML(total_html)) 159 | -------------------------------------------------------------------------------- /images/vacuum.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 22 | 24 | 42 | 44 | 45 | 47 | image/svg+xml 48 | 50 | 51 | 52 | 53 | 57 | 67 | 79 | 93 | 108 | 124 | 137 | 149 | 150 | 151 | -------------------------------------------------------------------------------- /gui/vacuum_agent.py: -------------------------------------------------------------------------------- 1 | from tkinter import * 2 | import random 3 | import sys 4 | import os.path 5 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 6 | from agents import * 7 | 8 | loc_A, loc_B = (0, 0), (1, 0) # The two locations for the Vacuum world 9 | 10 | 11 | class Gui(Environment): 12 | 13 | """This GUI environment has two locations, A and B. Each can be Dirty 14 | or Clean. 
The agent perceives its location and the location's 15 | status.""" 16 | 17 | def __init__(self, root, height=300, width=380): 18 | super().__init__() 19 | self.status = {loc_A: 'Clean', 20 | loc_B: 'Clean'} 21 | self.root = root 22 | self.height = height 23 | self.width = width 24 | self.canvas = None 25 | self.buttons = [] 26 | self.create_canvas() 27 | self.create_buttons() 28 | 29 | def thing_classes(self): 30 | """The list of things which can be used in the environment.""" 31 | return [Wall, Dirt, ReflexVacuumAgent, RandomVacuumAgent, 32 | TableDrivenVacuumAgent, ModelBasedVacuumAgent] 33 | 34 | def percept(self, agent): 35 | """Returns the agent's location, and the location status (Dirty/Clean).""" 36 | return (agent.location, self.status[agent.location]) 37 | 38 | def execute_action(self, agent, action): 39 | """Change the location status (Dirty/Clean); track performance. 40 | Score 10 for each dirt cleaned; -1 for each move.""" 41 | if action == 'Right': 42 | agent.location = loc_B 43 | agent.performance -= 1 44 | elif action == 'Left': 45 | agent.location = loc_A 46 | agent.performance -= 1 47 | elif action == 'Suck': 48 | if self.status[agent.location] == 'Dirty': 49 | if agent.location == loc_A: 50 | self.buttons[0].config(bg='white', activebackground='light grey') 51 | else: 52 | self.buttons[1].config(bg='white', activebackground='light grey') 53 | agent.performance += 10 54 | self.status[agent.location] = 'Clean' 55 | 56 | def default_location(self, thing): 57 | """Agents start in either location at random.""" 58 | return random.choice([loc_A, loc_B]) 59 | 60 | def create_canvas(self): 61 | """Creates Canvas element in the GUI.""" 62 | self.canvas = Canvas( 63 | self.root, 64 | width=self.width, 65 | height=self.height, 66 | background='powder blue') 67 | self.canvas.pack(side='bottom') 68 | 69 | def create_buttons(self): 70 | """Creates the buttons required in the GUI.""" 71 | button_left = Button(self.root, height=4, width=12, padx=2, pady=2, 
bg='white') 72 | button_left.config(command=lambda btn=button_left: self.dirt_switch(btn)) 73 | self.buttons.append(button_left) 74 | button_left_window = self.canvas.create_window(130, 200, anchor=N, window=button_left) 75 | button_right = Button(self.root, height=4, width=12, padx=2, pady=2, bg='white') 76 | button_right.config(command=lambda btn=button_right: self.dirt_switch(btn)) 77 | self.buttons.append(button_right) 78 | button_right_window = self.canvas.create_window(250, 200, anchor=N, window=button_right) 79 | 80 | def dirt_switch(self, button): 81 | """Gives user the option to put dirt in any tile.""" 82 | bg_color = button['bg'] 83 | if bg_color == 'saddle brown': 84 | button.config(bg='white', activebackground='light grey') 85 | elif bg_color == 'white': 86 | button.config(bg='saddle brown', activebackground='light goldenrod') 87 | 88 | def read_env(self): 89 | """Reads the current state of the GUI.""" 90 | for i, btn in enumerate(self.buttons): 91 | if i == 0: 92 | if btn['bg'] == 'white': 93 | self.status[loc_A] = 'Clean' 94 | else: 95 | self.status[loc_A] = 'Dirty' 96 | else: 97 | if btn['bg'] == 'white': 98 | self.status[loc_B] = 'Clean' 99 | else: 100 | self.status[loc_B] = 'Dirty' 101 | 102 | def update_env(self, agent): 103 | """Updates the GUI according to the agent's action.""" 104 | self.read_env() 105 | # print(self.status) 106 | before_step = agent.location 107 | self.step() 108 | # print(self.status) 109 | # print(agent.location) 110 | move_agent(self, agent, before_step) 111 | 112 | 113 | def create_agent(env, agent): 114 | """Creates the agent in the GUI and is kept independent of the environment.""" 115 | env.add_thing(agent) 116 | # print(agent.location) 117 | if agent.location == (0, 0): 118 | env.agent_rect = env.canvas.create_rectangle(80, 100, 175, 180, fill='lime green') 119 | env.text = env.canvas.create_text(128, 140, font="Helvetica 10 bold italic", text="Agent") 120 | else: 121 | env.agent_rect = 
env.canvas.create_rectangle(200, 100, 295, 180, fill='lime green') 122 | env.text = env.canvas.create_text(248, 140, font="Helvetica 10 bold italic", text="Agent") 123 | 124 | 125 | def move_agent(env, agent, before_step): 126 | """Moves the agent in the GUI when 'next' button is pressed.""" 127 | if agent.location == before_step: 128 | pass 129 | else: 130 | if agent.location == (1, 0): 131 | env.canvas.move(env.text, 120, 0) 132 | env.canvas.move(env.agent_rect, 120, 0) 133 | elif agent.location == (0, 0): 134 | env.canvas.move(env.text, -120, 0) 135 | env.canvas.move(env.agent_rect, -120, 0) 136 | 137 | 138 | # TODO: Add more agents to the environment. 139 | # TODO: Expand the environment to XYEnvironment. 140 | def main(): 141 | """The main function of the program.""" 142 | root = Tk() 143 | root.title("Vacuum Environment") 144 | root.geometry("420x380") 145 | root.resizable(0, 0) 146 | frame = Frame(root, bg='black') 147 | # reset_button = Button(frame, text='Reset', height=2, width=6, padx=2, pady=2, command=None) 148 | # reset_button.pack(side='left') 149 | next_button = Button(frame, text='Next', height=2, width=6, padx=2, pady=2) 150 | next_button.pack(side='left') 151 | frame.pack(side='bottom') 152 | env = Gui(root) 153 | agent = ReflexVacuumAgent() 154 | create_agent(env, agent) 155 | next_button.config(command=lambda: env.update_env(agent)) 156 | root.mainloop() 157 | 158 | 159 | if __name__ == "__main__": 160 | main() 161 | -------------------------------------------------------------------------------- /intro.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# An Introduction To `aima-python` \n", 8 | " \n", 9 | "The [aima-python](https://github.com/aimacode/aima-python) repository implements, in Python code, the algorithms in the textbook *[Artificial Intelligence: A Modern Approach](http://aima.cs.berkeley.edu)*. 
A typical module in the repository has the code for a single chapter in the book, but some modules combine several chapters. See [the index](https://github.com/aimacode/aima-python#index-of-code) if you can't find the algorithm you want. The code in this repository attempts to mirror the pseudocode in the textbook as closely as possible and to stress readability foremost; if you are looking for high-performance code with advanced features, there are other repositories for you. For each module, there are three/four files, for example:\n", 10 | "\n", 11 | "- [**`nlp.py`**](https://github.com/aimacode/aima-python/blob/master/nlp.py): Source code with data types and algorithms for natural language processing; functions have docstrings explaining their use.\n", 12 | "- [**`nlp.ipynb`**](https://github.com/aimacode/aima-python/blob/master/nlp.ipynb): A notebook like this one; gives more detailed examples and explanations of use.\n", 13 | "- [**`nlp_apps.ipynb`**](https://github.com/aimacode/aima-python/blob/master/nlp_apps.ipynb): A Jupyter notebook that gives example applications of the code.\n", 14 | "- [**`tests/test_nlp.py`**](https://github.com/aimacode/aima-python/blob/master/tests/test_nlp.py): Test cases, used to verify the code is correct, and also useful to see examples of use.\n", 15 | "\n", 16 | "There is also an [aima-java](https://github.com/aimacode/aima-java) repository, if you prefer Java.\n", 17 | " \n", 18 | "## What version of Python?\n", 19 | " \n", 20 | "The code is tested in Python [3.4](https://www.python.org/download/releases/3.4.3/) and [3.5](https://www.python.org/downloads/release/python-351/). If you try a different version of Python 3 and find a problem, please report it as an [Issue](https://github.com/aimacode/aima-python/issues). There is an incomplete [legacy branch](https://github.com/aimacode/aima-python/tree/aima3python2) for those who must run in Python 2. 
\n", 21 | " \n", 22 | "We recommend the [Anaconda](https://www.continuum.io/downloads) distribution of Python 3.5. It comes with additional tools like the powerful IPython interpreter, the Jupyter Notebook and many helpful packages for scientific computing. After installing Anaconda, you will be good to go to run all the code and all the IPython notebooks. \n", 23 | "\n", 24 | "## IPython notebooks \n", 25 | " \n", 26 | "The IPython notebooks in this repository explain how to use the modules, and give examples of usage. \n", 27 | "You can use them in three ways: \n", 28 | "\n", 29 | "1. View static HTML pages. (Just browse to the [repository](https://github.com/aimacode/aima-python) and click on a `.ipynb` file link.)\n", 30 | "2. Run, modify, and re-run code, live. (Download the repository (by [zip file](https://github.com/aimacode/aima-python/archive/master.zip) or by `git` commands), start a Jupyter notebook server with the shell command \"`jupyter notebook`\" (issued from the directory where the files are), and click on the notebook you want to interact with.)\n", 31 | "3. Binder - Click on the binder badge on the [repository](https://github.com/aimacode/aima-python) main page to open the notebooks in an executable environment, online. This method does not require any extra installation. The code can be executed and modified from the browser itself. Note that this is an unstable option; there is a chance the notebooks will never load.\n", 32 | "\n", 33 | " \n", 34 | "You can [read about notebooks](https://jupyter-notebook-beginner-guide.readthedocs.org/en/latest/) and then [get started](https://nbviewer.jupyter.org/github/jupyter/notebook/blob/master/docs/source/examples/Notebook/Running%20Code.ipynb)." 
35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": { 40 | "collapsed": true 41 | }, 42 | "source": [ 43 | "# Helpful Tips\n", 44 | "\n", 45 | "Most of these notebooks start by importing all the symbols in a module:" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 1, 51 | "metadata": { 52 | "collapsed": true 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "from logic import *" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "From there, the notebook alternates explanations with examples of use. You can run the examples as they are, and you can modify the code cells (or add new cells) and run your own examples. If you have some really good examples to add, you can make a github pull request.\n", 64 | "\n", 65 | "If you want to see the source code of a function, you can open a browser or editor and see it in another window, or from within the notebook you can use the IPython magic function `%psource` (for \"print source\") or the function `psource` from `notebook.py`. Also, if the algorithm has pseudocode, you can read it by calling the `pseudocode` function with input the name of the algorithm." 
66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 2, 71 | "metadata": { 72 | "collapsed": true 73 | }, 74 | "outputs": [], 75 | "source": [ 76 | "%psource WalkSAT" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "from notebook import psource, pseudocode\n", 86 | "\n", 87 | "psource(WalkSAT)\n", 88 | "pseudocode(\"WalkSAT\")" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": {}, 94 | "source": [ 95 | "Or see an abbreviated description of an object with a trailing question mark:" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 3, 101 | "metadata": { 102 | "collapsed": true 103 | }, 104 | "outputs": [], 105 | "source": [ 106 | "WalkSAT?" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "# Authors\n", 114 | "\n", 115 | "This notebook is written by [Chirag Vertak](https://github.com/chiragvartak) and [Peter Norvig](https://github.com/norvig)." 116 | ] 117 | } 118 | ], 119 | "metadata": { 120 | "kernelspec": { 121 | "display_name": "Python 3", 122 | "language": "python", 123 | "name": "python3" 124 | }, 125 | "language_info": { 126 | "codemirror_mode": { 127 | "name": "ipython", 128 | "version": 3 129 | }, 130 | "file_extension": ".py", 131 | "mimetype": "text/x-python", 132 | "name": "python", 133 | "nbconvert_exporter": "python", 134 | "pygments_lexer": "ipython3", 135 | "version": "3.5.3" 136 | } 137 | }, 138 | "nbformat": 4, 139 | "nbformat_minor": 1 140 | } 141 | -------------------------------------------------------------------------------- /gui/genetic_algorithm_example.py: -------------------------------------------------------------------------------- 1 | # author: ad71 2 | # A simple program that implements the solution to the phrase generation problem using 3 | # genetic algorithms as given in the search.ipynb notebook. 
4 | # 5 | # Type on the home screen to change the target phrase 6 | # Click on the slider to change genetic algorithm parameters 7 | # Click 'GO' to run the algorithm with the specified variables 8 | # Displays best individual of the current generation 9 | # Displays a progress bar that indicates the amount of completion of the algorithm 10 | # Displays the first few individuals of the current generation 11 | 12 | import sys 13 | import time 14 | import random 15 | import os.path 16 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 17 | 18 | from tkinter import * 19 | from tkinter import ttk 20 | 21 | import search 22 | from utils import argmax 23 | 24 | LARGE_FONT = ('Verdana', 12) 25 | EXTRA_LARGE_FONT = ('Consolas', 36, 'bold') 26 | 27 | canvas_width = 800 28 | canvas_height = 600 29 | 30 | black = '#000000' 31 | white = '#ffffff' 32 | p_blue = '#042533' 33 | lp_blue = '#0c394c' 34 | 35 | # genetic algorithm variables 36 | # feel free to play around with these 37 | target = 'Genetic Algorithm' # the phrase to be generated 38 | max_population = 100 # number of samples in each population 39 | mutation_rate = 0.1 # probability of mutation 40 | f_thres = len(target) # fitness threshold 41 | ngen = 1200 # max number of generations to run the genetic algorithm 42 | 43 | generation = 0 # counter to keep track of generation number 44 | 45 | u_case = [chr(x) for x in range(65, 91)] # list containing all uppercase characters 46 | l_case = [chr(x) for x in range(97, 123)] # list containing all lowercase characters 47 | punctuations1 = [chr(x) for x in range(33, 48)] # lists containing punctuation symbols 48 | punctuations2 = [chr(x) for x in range(58, 65)] 49 | punctuations3 = [chr(x) for x in range(91, 97)] 50 | numerals = [chr(x) for x in range(48, 58)] # list containing numbers 51 | 52 | # extend the gene pool with the required lists and append the space character 53 | gene_pool = [] 54 | gene_pool.extend(u_case) 55 | gene_pool.extend(l_case) 56 | 
gene_pool.append(' ') 57 | 58 | # callbacks to update global variables from the slider values 59 | def update_max_population(slider_value): 60 | global max_population 61 | max_population = slider_value 62 | 63 | def update_mutation_rate(slider_value): 64 | global mutation_rate 65 | mutation_rate = slider_value 66 | 67 | def update_f_thres(slider_value): 68 | global f_thres 69 | f_thres = slider_value 70 | 71 | def update_ngen(slider_value): 72 | global ngen 73 | ngen = slider_value 74 | 75 | # fitness function 76 | def fitness_fn(_list): 77 | fitness = 0 78 | # create string from list of characters 79 | phrase = ''.join(_list) 80 | # add 1 to fitness value for every matching character 81 | for i in range(len(phrase)): 82 | if target[i] == phrase[i]: 83 | fitness += 1 84 | return fitness 85 | 86 | # function to bring a new frame on top 87 | def raise_frame(frame, init=False, update_target=False, target_entry=None, f_thres_slider=None): 88 | frame.tkraise() 89 | global target 90 | if update_target and target_entry is not None: 91 | target = target_entry.get() 92 | f_thres_slider.config(to=len(target)) 93 | if init: 94 | population = search.init_population(max_population, gene_pool, len(target)) 95 | genetic_algorithm_stepwise(population) 96 | 97 | # defining root and child frames 98 | root = Tk() 99 | f1 = Frame(root) 100 | f2 = Frame(root) 101 | 102 | # pack frames on top of one another 103 | for frame in (f1, f2): 104 | frame.grid(row=0, column=0, sticky='news') 105 | 106 | # Home Screen (f1) widgets 107 | target_entry = Entry(f1, font=('Consolas 46 bold'), exportselection=0, foreground=p_blue, justify=CENTER) 108 | target_entry.insert(0, target) 109 | target_entry.pack(expand=YES, side=TOP, fill=X, padx=50) 110 | target_entry.focus_force() 111 | 112 | max_population_slider = Scale(f1, from_=3, to=1000, orient=HORIZONTAL, label='Max population', command=lambda value: update_max_population(int(value))) 113 | max_population_slider.set(max_population) 114 | 
max_population_slider.pack(expand=YES, side=TOP, fill=X, padx=40) 115 | 116 | mutation_rate_slider = Scale(f1, from_=0, to=1, orient=HORIZONTAL, label='Mutation rate', resolution=0.0001, command=lambda value: update_mutation_rate(float(value))) 117 | mutation_rate_slider.set(mutation_rate) 118 | mutation_rate_slider.pack(expand=YES, side=TOP, fill=X, padx=40) 119 | 120 | f_thres_slider = Scale(f1, from_=0, to=len(target), orient=HORIZONTAL, label='Fitness threshold', command=lambda value: update_f_thres(int(value))) 121 | f_thres_slider.set(f_thres) 122 | f_thres_slider.pack(expand=YES, side=TOP, fill=X, padx=40) 123 | 124 | ngen_slider = Scale(f1, from_=1, to=5000, orient=HORIZONTAL, label='Max number of generations', command=lambda value: update_ngen(int(value))) 125 | ngen_slider.set(ngen) 126 | ngen_slider.pack(expand=YES, side=TOP, fill=X, padx=40) 127 | 128 | button = ttk.Button(f1, text='RUN', command=lambda: raise_frame(f2, init=True, update_target=True, target_entry=target_entry, f_thres_slider=f_thres_slider)).pack(side=BOTTOM, pady=50) 129 | 130 | # f2 widgets 131 | canvas = Canvas(f2, width=canvas_width, height=canvas_height) 132 | canvas.pack(expand=YES, fill=BOTH, padx=20, pady=15) 133 | button = ttk.Button(f2, text='EXIT', command=lambda: raise_frame(f1)).pack(side=BOTTOM, pady=15) 134 | 135 | # function to run the genetic algorithm and update text on the canvas 136 | def genetic_algorithm_stepwise(population): 137 | root.title('Genetic Algorithm') 138 | for generation in range(ngen): 139 | # generating new population after selecting, recombining and mutating the existing population 140 | population = [search.mutate(search.recombine(*search.select(2, population, fitness_fn)), gene_pool, mutation_rate) for i in range(len(population))] 141 | # genome with the highest fitness in the current generation 142 | current_best = ''.join(argmax(population, key=fitness_fn)) 143 | # collecting first few examples from the current population 144 | members = 
[''.join(x) for x in population][:48] 145 | 146 | # clear the canvas 147 | canvas.delete('all') 148 | # displays current best on top of the screen 149 | canvas.create_text(canvas_width / 2, 40, fill=p_blue, font='Consolas 46 bold', text=current_best) 150 | 151 | # displaying a part of the population on the screen 152 | for i in range(len(members) // 3): 153 | canvas.create_text((canvas_width * .175), (canvas_height * .25 + (25 * i)), fill=lp_blue, font='Consolas 16', text=members[3 * i]) 154 | canvas.create_text((canvas_width * .500), (canvas_height * .25 + (25 * i)), fill=lp_blue, font='Consolas 16', text=members[3 * i + 1]) 155 | canvas.create_text((canvas_width * .825), (canvas_height * .25 + (25 * i)), fill=lp_blue, font='Consolas 16', text=members[3 * i + 2]) 156 | 157 | # displays current generation number 158 | canvas.create_text((canvas_width * .5), (canvas_height * 0.95), fill=p_blue, font='Consolas 18 bold', text=f'Generation {generation}') 159 | 160 | # displays blue bar that indicates current maximum fitness compared to maximum possible fitness 161 | scaling_factor = fitness_fn(current_best) / len(target) 162 | canvas.create_rectangle(canvas_width * 0.1, 90, canvas_width * 0.9, 100, outline=p_blue) 163 | canvas.create_rectangle(canvas_width * 0.1, 90, canvas_width * 0.1 + scaling_factor * canvas_width * 0.8, 100, fill=lp_blue) 164 | canvas.update() 165 | 166 | # checks for completion 167 | fittest_individual = search.fitness_threshold(fitness_fn, f_thres, population) 168 | if fittest_individual: 169 | break 170 | 171 | raise_frame(f1) 172 | root.mainloop() -------------------------------------------------------------------------------- /gui/tic-tac-toe.py: -------------------------------------------------------------------------------- 1 | from tkinter import * 2 | import sys 3 | import os.path 4 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 5 | from games import minimax_decision, alphabeta_player, random_player, TicTacToe 6 | # 
# "gen_state" can be used to generate a game state to apply the algorithm
from tests.test_games import gen_state

ttt = TicTacToe()
root = None
buttons = []    # 3x3 grid of board buttons (row-major), filled by create_frames
frames = []     # the three row frames that hold the board buttons
x_pos = []      # 1-based (row, col) cells played by X
o_pos = []      # 1-based (row, col) cells played by O
count = 0       # number of moves made so far; its parity selects X or O
sym = ""
result = None   # StringVar with the status message, created in main()
choices = None  # StringVar with the selected opponent, created in main()


def create_frames(root):
    """
    Build the board layout inside *root*.

    Three frames each receive a row of three cell buttons (see
    ``create_buttons``); a fourth frame holds the Exit and Reset buttons.
    The row frames are recorded in the global ``frames`` list and their
    buttons in the global ``buttons`` grid, one inner list per row.
    """
    frame1 = Frame(root)
    frame2 = Frame(root)
    frame3 = Frame(root)
    frame4 = Frame(root)
    create_buttons(frame1)
    create_buttons(frame2)
    create_buttons(frame3)
    buttonExit = Button(
        frame4, height=1, width=2,
        text="Exit",
        command=lambda: exit_game(root))
    buttonExit.pack(side=LEFT)
    # Packed against the bottom edge in reverse order, so frame1 ends up on
    # top and the control frame (frame4) at the very bottom.
    frame4.pack(side=BOTTOM)
    frame3.pack(side=BOTTOM)
    frame2.pack(side=BOTTOM)
    frame1.pack(side=BOTTOM)
    frames.append(frame1)
    frames.append(frame2)
    frames.append(frame3)
    for x in frames:
        buttons_in_frame = []
        for y in x.winfo_children():
            buttons_in_frame.append(y)
        buttons.append(buttons_in_frame)
    buttonReset = Button(frame4, height=1, width=2,
                         text="Reset", command=lambda: reset_game())
    buttonReset.pack(side=LEFT)


def create_buttons(frame):
    """
    Add a row of three blank board buttons to *frame*.

    Each button calls ``on_click`` with itself when pressed.
    """
    button0 = Button(frame, height=2, width=2, text=" ",
                     command=lambda: on_click(button0))
    button0.pack(side=LEFT)
    button1 = Button(frame, height=2, width=2, text=" ",
                     command=lambda: on_click(button1))
    button1.pack(side=LEFT)
    button2 = Button(frame, height=2, width=2, text=" ",
                     command=lambda: on_click(button2))
    button2.pack(side=LEFT)


# TODO: Add a choice option for the user.
def on_click(button):
    """
    Handle a click on one of the board cells.

    The clicked cell is marked with the current symbol. The human moves
    first, so this is "X" on the human's turn. The move is then checked for
    a win.

    If the game is not over, the opponent selected in the ``choices`` menu
    replies:
    - "Random" -> random_player
    - "Pro" -> minimax_decision
    - anything else ("Vs Legend") -> alphabeta_player

    If the opponent cannot produce a move (ValueError, IndexError or
    TypeError), the game is declared a draw. If its move completes a line,
    the human loses.
    """
    global ttt, choices, count, sym, result, x_pos, o_pos

    if count % 2 == 0:
        sym = "X"
    else:
        sym = "O"
    count += 1

    button.config(
        text=sym,
        state='disabled',
        disabledforeground="red")  # For cross

    x, y = get_coordinates(button)
    x += 1
    y += 1
    x_pos.append((x, y))

    # Judge the human's move before asking the opponent for a reply. A winning
    # move that also fills the board leaves the opponent without a move, and
    # the handler below would otherwise report that win as a draw.
    if check_victory(button):
        result.set("You win :)")
        disable_game()
        return

    state = gen_state(to_move='O', x_positions=x_pos,
                      o_positions=o_pos)
    try:
        choice = choices.get()
        if "Random" in choice:
            a, b = random_player(ttt, state)
        elif "Pro" in choice:
            a, b = minimax_decision(state, ttt)
        else:
            a, b = alphabeta_player(ttt, state)
    except (ValueError, IndexError, TypeError):
        # The opponent could not produce a move; treated as a full board.
        disable_game()
        result.set("It's a draw :|")
        return
    if 1 <= a <= 3 and 1 <= b <= 3:
        o_pos.append((a, b))
        button_to_change = get_button(a - 1, b - 1)
        if count % 2 == 0:  # Used again, will become handy when user is given the choice of turn.
            sym = "X"
        else:
            sym = "O"
        count += 1

        button_to_change.config(text=sym, state='disabled',
                                disabledforeground="black")
        if check_victory(button_to_change):
            result.set("You lose :(")
            disable_game()


# TODO: Replace "check_victory" by "k_in_row" function.
def check_victory(button):
    """
    Return True if the move just made on *button* completes a line.

    Only the lines through *button* are examined: its column, its row and,
    when the cell lies on one, the main or the secondary diagonal. On a win,
    the three winning cells are relabelled with a marker showing the line's
    direction (e.g. "|X|" for a column, "--X--" for a row).
    """
    # check if previous move caused a win on vertical line
    global buttons
    x, y = get_coordinates(button)
    tt = button['text']
    if buttons[0][y]['text'] == buttons[1][y]['text'] == buttons[2][y]['text'] != " ":
        buttons[0][y].config(text="|" + tt + "|")
        buttons[1][y].config(text="|" + tt + "|")
        buttons[2][y].config(text="|" + tt + "|")
        return True

    # check if previous move caused a win on horizontal line
    if buttons[x][0]['text'] == buttons[x][1]['text'] == buttons[x][2]['text'] != " ":
        buttons[x][0].config(text="--" + tt + "--")
        buttons[x][1].config(text="--" + tt + "--")
        buttons[x][2].config(text="--" + tt + "--")
        return True

    # check if previous move was on the main diagonal and caused a win
    if x == y and buttons[0][0]['text'] == buttons[1][1]['text'] == buttons[2][2]['text'] != " ":
        buttons[0][0].config(text="\\" + tt + "\\")
        buttons[1][1].config(text="\\" + tt + "\\")
        buttons[2][2].config(text="\\" + tt + "\\")
        return True

    # check if previous move was on the secondary diagonal and caused a win
    if x + y == 2 and buttons[0][2]['text'] == buttons[1][1]['text'] == buttons[2][0]['text'] != " ":
        buttons[0][2].config(text="/" + tt + "/")
        buttons[1][1].config(text="/" + tt + "/")
        buttons[2][0].config(text="/" + tt + "/")
        return True

    return False


def get_coordinates(button):
    """
    Return the 0-based (row, column) of *button* in the ``buttons`` grid.

    Returns None if the button is not part of the board.
    """
    global buttons
    for x in range(len(buttons)):
        for y in range(len(buttons[x])):
            if buttons[x][y] == button:
                return x, y


def get_button(x, y):
    """
    Return the board button at 0-based row *x*, column *y*.
    """
    global buttons
    return buttons[x][y]


def reset_game():
    """
    Start a new game.

    Clears and re-enables every board cell, empties the move history and
    restores the status message.
    """
    global x_pos, o_pos, frames, count

    count = 0
    x_pos = []
    o_pos = []
    result.set("Your Turn!")
    for x in frames:
        for y in x.winfo_children():
            y.config(text=" ", state='normal')


def disable_game():
    """
    Disable every board cell after a win, loss or draw.
    """
    global frames
    for x in frames:
        for y in x.winfo_children():
            y.config(state='disabled')


def exit_game(root):
    """
    Exit the game by destroying the root window.
    """
    root.destroy()


def main():
    """
    Create the window, status label, board and opponent menu, then run Tk.

    The opponent defaults to "Vs Pro".
    """
    global result, choices

    root = Tk()
    root.title("TicTacToe")
    root.geometry("150x200")  # Improved the window geometry
    root.resizable(0, 0)  # To remove the maximize window option
    result = StringVar()
    result.set("Your Turn!")
    w = Label(root, textvariable=result)
    w.pack(side=BOTTOM)
    create_frames(root)
    choices = StringVar(root)
    choices.set("Vs Pro")
    menu = OptionMenu(root, choices, "Vs Random", "Vs Pro", "Vs Legend")
    menu.pack()
    root.mainloop()


if __name__ == "__main__":
    main()

# --------------------------------------------------------------------------------
# /gui/xy_vacuum_environment.py:
# --------------------------------------------------------------------------------
from tkinter import *
import random
import sys
import os.path
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from agents import *


class Gui(VacuumEnvironment):
    """Two-dimensional vacuum-world environment shown as a grid of Tk buttons.

    Each location may be dirty, clean or can have a wall. The user can change
    these at each step by clicking a cell (see ``display_element``). The
    outermost ring of buttons is labelled as a fixed wall boundary.
    """
    # Agent location recorded by update_env just before a step; execute_action
    # uses it to clear the agent's previous cell on the grid.
    xi, yi = (0, 0)
    perceptible_distance = 1

    def __init__(self, root, width=7, height=7, elements=None):
        super().__init__(width, height)
        self.root = root
        self.create_frames()
        self.create_buttons()
        self.create_walls()
        # Button labels that read_env maps to things: elements[0] -> Dirt,
        # elements[1] -> Wall. A None default avoids sharing one mutable list
        # between every instance.
        self.elements = ['D', 'W'] if elements is None else elements

    def create_frames(self):
        """Add the row frames to the GUI environment.

        NOTE(review): the row count is fixed at 7 and ignores the width and
        height arguments of __init__.
        """
        self.frames = []
        for _ in range(7):
            frame = Frame(self.root, bg='grey')
            # side='bottom' means frame 0 is displayed lowest on screen.
            frame.pack(side='bottom')
            self.frames.append(frame)

    def create_buttons(self):
        """Add a row of 7 cell buttons to every frame.

        ``self.buttons[i][j]`` is the j-th button of the i-th frame; clicking
        a button cycles its label via ``display_element``.
        """
        self.buttons = []
        for frame in self.frames:
            button_row = []
            for _ in range(7):
                button = Button(frame, height=3, width=5, padx=2, pady=2)
                # Bind the button as a default argument so each lambda keeps its own.
                button.config(
                    command=lambda btn=button: self.display_element(btn))
                button.pack(side='left')
                button_row.append(button)
            self.buttons.append(button_row)

    def create_walls(self):
        """Label the outer boundary cells as fixed walls and draw the agent."""
        for row, button_row in enumerate(self.buttons):
            if row == 0 or row == len(self.buttons) - 1:
                for button in button_row:
                    button.config(text='W', state='disabled',
                                  disabledforeground='black')
            else:
                button_row[0].config(
                    text='W', state='disabled', disabledforeground='black')
                button_row[len(button_row) - 1].config(text='W',
                                                       state='disabled', disabledforeground='black')
        # Place the agent in the centre of the grid.
        self.buttons[3][3].config(
            text='A', state='disabled', disabledforeground='black')

    def display_element(self, button):
        """Cycle a clicked cell's label: 'W' -> 'D' -> '' -> 'W'.

        The agent's cell ('A') is left unchanged.
        """
        txt = button['text']
        if txt != 'A':
            if txt == 'W':
                button.config(text='D')
            elif txt == 'D':
                button.config(text='')
            elif txt == '':
                button.config(text='W')

    def execute_action(self, agent, action):
        """Apply *action* for *agent* and redraw the agent on the grid.

        'Suck' removes the first Dirt at the agent's location (+100
        performance). 'TurnRight'/'TurnLeft' rotate the agent. 'Forward'
        moves it unless it bumps. Every action except 'NoOp' costs 1
        performance point.
        """
        xi, yi = ((self.xi, self.yi))
        if action == 'Suck':
            dirt_list = self.list_things_at(agent.location, Dirt)
            if dirt_list != []:
                dirt = dirt_list[0]
                agent.performance += 100
                self.delete_thing(dirt)
                self.buttons[xi][yi].config(text='', state='normal')
                xf, yf = agent.location
                self.buttons[xf][yf].config(
                    text='A', state='disabled', disabledforeground='black')

        else:
            agent.bump = False
            if action == 'TurnRight':
                agent.direction += Direction.R
            elif action == 'TurnLeft':
                agent.direction += Direction.L
            elif action == 'Forward':
                agent.bump = self.move_to(agent, agent.direction.move_forward(agent.location))
                if not agent.bump:
                    self.buttons[xi][yi].config(text='', state='normal')
                    xf, yf = agent.location
                    self.buttons[xf][yf].config(
                        text='A', state='disabled', disabledforeground='black')

        if action != 'NoOp':
            agent.performance -= 1

    def read_env(self):
        """Sync the environment's things with the interior button labels.

        For every interior cell other than the agent's, existing things are
        deleted. A Dirt or Wall is then added when the button shows
        elements[0] or elements[1] respectively.
        """
        for i, btn_row in enumerate(self.buttons):
            for j, btn in enumerate(btn_row):
                if (i != 0 and i != len(self.buttons) - 1) and (j != 0 and j != len(btn_row) - 1):
                    agt_loc = self.agents[0].location
                    if self.some_things_at((i, j)) and (i, j) != agt_loc:
                        for thing in self.list_things_at((i, j)):
                            self.delete_thing(thing)
                    if btn['text'] == self.elements[0]:
                        self.add_thing(Dirt(), (i, j))
                    elif btn['text'] == self.elements[1]:
                        self.add_thing(Wall(), (i, j))

    def update_env(self):
        """Read the GUI state, remember the agent's location, then run one step."""
        self.read_env()
        agt = self.agents[0]
        previous_agent_location = agt.location
        self.xi, self.yi = previous_agent_location
        self.step()

    def reset_env(self, agt):
        """Clear every interior cell and put *agt* back at (3, 3)."""
        self.read_env()
        for i, btn_row in enumerate(self.buttons):
            for j, btn in enumerate(btn_row):
                if (i != 0 and i != len(self.buttons) - 1) and (j != 0 and j != len(btn_row) - 1):
                    if self.some_things_at((i, j)):
                        for thing in self.list_things_at((i, j)):
                            self.delete_thing(thing)
                    btn.config(text='', state='normal')
        self.add_thing(agt, location=(3, 3))
        self.buttons[3][3].config(
            text='A', state='disabled', disabledforeground='black')


def XYReflexAgentProgram(percept):
    """Reflex agent program for the GUI vacuum world.

    *percept* is a (status, bump) pair. When status is 'Dirty' the agent
    sucks. Otherwise it picks a random action:
    - after a 'Bump', only a turn (TurnRight or TurnLeft, equally likely);
    - otherwise TurnRight, TurnLeft or Forward with probabilities 1/4, 1/4
      and 1/2.
    """
    status, bump = percept
    if status == 'Dirty':
        return 'Suck'

    if bump == 'Bump':
        value = random.choice((1, 2))
    else:
        value = random.choice((1, 2, 3, 4))  # 1-right, 2-left, others-forward

    if value == 1:
        return 'TurnRight'
    elif value == 2:
        return 'TurnLeft'
    else:
        return 'Forward'


class XYReflexAgent(Agent):
    """Reflex vacuum agent for the GUI; starts at (3, 3) facing up."""

    def __init__(self, program=None):
        super().__init__(program)
        self.location = (3, 3)
        self.direction = Direction("up")


# TODO:
# Check the coordinate system.
# Give manual choice for agent's location.
def main():
    """Build the vacuum-world window and run the Tk event loop.

    Creates the Reset/Next control buttons and a ``Gui`` environment with
    its default size. It then places an ``XYReflexAgent`` at (3, 3) and
    wires Next to ``env.update_env`` and Reset to ``env.reset_env``.
    """
    root = Tk()
    root.title("Vacuum Environment")
    root.geometry("420x440")
    root.resizable(0, 0)
    frame = Frame(root, bg='black')
    reset_button = Button(frame, text='Reset', height=2,
                          width=6, padx=2, pady=2)
    reset_button.pack(side='left')
    next_button = Button(frame, text='Next', height=2,
                         width=6, padx=2, pady=2)
    next_button.pack(side='left')
    frame.pack(side='bottom')
    env = Gui(root)
    agt = XYReflexAgent(program=XYReflexAgentProgram)
    env.add_thing(agt, location=(3, 3))
    next_button.config(command=env.update_env)
    reset_button.config(command=lambda: env.reset_env(agt))
    root.mainloop()


if __name__ == "__main__":
    main()

# --------------------------------------------------------------------------------
# /tests/test_mdp.py:
# --------------------------------------------------------------------------------
# Tests for the MDP/POMDP solvers. The expected values are exact floats produced
# by the current implementation, so any numerical change will surface here.
from mdp import *

# Variants of the 4x3 grid world with different living rewards
# (None marks the obstacle cell).
sequential_decision_environment_1 = GridMDP([[-0.1, -0.1, -0.1, +1],
                                             [-0.1, None, -0.1, -1],
                                             [-0.1, -0.1, -0.1, -0.1]],
                                            terminals=[(3, 2), (3, 1)])

sequential_decision_environment_2 = GridMDP([[-2, -2, -2, +1],
                                             [-2, None, -2, -1],
                                             [-2, -2, -2, -2]],
                                            terminals=[(3, 2), (3, 1)])

# A larger 6x5 grid with several obstacles and four terminal states.
sequential_decision_environment_3 = GridMDP([[-1.0, -0.1, -0.1, -0.1, -0.1, 0.5],
                                             [-0.1, None, None, -0.5, -0.1, -0.1],
                                             [-0.1, None, 1.0, 3.0, None, -0.1],
                                             [-0.1, -0.1, -0.1, None, None, -0.1],
                                             [0.5, -0.1, -0.1, -0.1, -0.1, -1.0]],
                                            terminals=[(2, 2), (3, 2), (0, 4), (5, 0)])


def test_value_iteration():
    # sequential_decision_environment presumably comes from `from mdp import *`.
    assert value_iteration(sequential_decision_environment, .01) == {
        (3, 2): 1.0, (3, 1): -1.0,
        (3, 0): 0.12958868267972745, (0, 1): 0.39810203830605462,
        (0, 2): 0.50928545646220924, (1, 0): 0.25348746162470537,
        (0, 0): 0.29543540628363629, (1, 2): 0.64958064617168676,
        (2, 0): 0.34461306281476806, (2, 1): 0.48643676237737926,
        (2, 2): 0.79536093684710951}

    assert value_iteration(sequential_decision_environment_1, .01) == {
        (3, 2): 1.0, (3, 1): -1.0,
        (3, 0): -0.0897388258468311, (0, 1): 0.146419707398967840,
        (0, 2): 0.30596200514385086, (1, 0): 0.010092796415625799,
        (0, 0): 0.00633408092008296, (1, 2): 0.507390193380827400,
        (2, 0): 0.15072242145212010, (2, 1): 0.358309043654212570,
        (2, 2): 0.71675493618997840}

    assert value_iteration(sequential_decision_environment_2, .01) == {
        (3, 2): 1.0, (3, 1): -1.0,
        (3, 0): -3.5141584808407855, (0, 1): -7.8000009574737180,
        (0, 2): -6.1064293596058830, (1, 0): -7.1012549580376760,
        (0, 0): -8.5872244532783200, (1, 2): -3.9653547121245810,
        (2, 0): -5.3099468802901630, (2, 1): -3.3543366255753995,
        (2, 2): -1.7383376462930498}

    assert value_iteration(sequential_decision_environment_3, .01) == {
        (0, 0): 4.350592130345558, (0, 1): 3.640700980321895, (0, 2): 3.0734806370346943, (0, 3): 2.5754335063434937, (0, 4): -1.0,
        (1, 0): 3.640700980321895, (1, 1): 3.129579352304856, (1, 4): 2.0787517066719916,
        (2, 0): 3.0259220379893352, (2, 1): 2.5926103577982897, (2, 2): 1.0, (2, 4): 2.507774181360808,
        (3, 0): 2.5336747364500076, (3, 2): 3.0, (3, 3): 2.292172805400873, (3, 4): 2.996383110867515,
        (4, 0): 2.1014575936349886, (4, 3): 3.1297590518608907, (4, 4): 3.6408806798779287,
        (5, 0): -1.0, (5, 1): 2.5756132058995282, (5, 2): 3.0736603365907276, (5, 3): 3.6408806798779287, (5, 4): 4.350771829901593}


def test_policy_iteration():
    # Policies map each state to a direction vector; terminal states map to None.
    assert policy_iteration(sequential_decision_environment) == {
        (0, 0): (0, 1), (0, 1): (0, 1), (0, 2): (1, 0),
        (1, 0): (1, 0), (1, 2): (1, 0), (2, 0): (0, 1),
        (2, 1): (0, 1), (2, 2): (1, 0), (3, 0): (-1, 0),
        (3, 1): None, (3, 2): None}

    assert policy_iteration(sequential_decision_environment_1) == {
        (0, 0): (0, 1), (0, 1): (0, 1), (0, 2): (1, 0),
        (1, 0): (1, 0), (1, 2): (1, 0), (2, 0): (0, 1),
        (2, 1): (0, 1), (2, 2): (1, 0), (3, 0): (-1, 0),
        (3, 1): None, (3, 2): None}

    assert policy_iteration(sequential_decision_environment_2) == {
        (0, 0): (1, 0), (0, 1): (0, 1), (0, 2): (1, 0),
        (1, 0): (1, 0), (1, 2): (1, 0), (2, 0): (1, 0),
        (2, 1): (1, 0), (2, 2): (1, 0), (3, 0): (0, 1),
        (3, 1): None, (3, 2): None}


def test_best_policy():
    # to_arrows renders a policy as a grid of '>', '<', '^', 'v'; '.' marks
    # terminals and None marks obstacles.
    pi = best_policy(sequential_decision_environment,
                     value_iteration(sequential_decision_environment, .01))
    assert sequential_decision_environment.to_arrows(pi) == [['>', '>', '>', '.'],
                                                              ['^', None, '^', '.'],
                                                              ['^', '>', '^', '<']]

    pi_1 = best_policy(sequential_decision_environment_1,
                       value_iteration(sequential_decision_environment_1, .01))
    assert sequential_decision_environment_1.to_arrows(pi_1) == [['>', '>', '>', '.'],
                                                                  ['^', None, '^', '.'],
                                                                  ['^', '>', '^', '<']]

    pi_2 = best_policy(sequential_decision_environment_2,
                       value_iteration(sequential_decision_environment_2, .01))
    assert sequential_decision_environment_2.to_arrows(pi_2) == [['>', '>', '>', '.'],
                                                                  ['^', None, '>', '.'],
                                                                  ['>', '>', '>', '^']]

    pi_3 = best_policy(sequential_decision_environment_3,
                       value_iteration(sequential_decision_environment_3, .01))
    assert sequential_decision_environment_3.to_arrows(pi_3) == [['.', '>', '>', '>', '>', '>'],
                                                                  ['v', None, None, '>', '>', '^'],
                                                                  ['v', None, '.', '.', None, '^'],
                                                                  ['v', '<', 'v', None, None, '^'],
                                                                  ['<', '<', '<', '<', '<', '.']]


def test_transition_model():
    # Explicit transition table: state -> action -> [(probability, next_state)].
    # NOTE(review): 'b'/'plan1' sums to 1.1. That is harmless here because only
    # the T() lookup is checked, but it is not a valid distribution.
    transition_model = { 'a' : { 'plan1' : [(0.2, 'a'), (0.3, 'b'), (0.3, 'c'), (0.2, 'd')],
                                 'plan2' : [(0.4, 'a'), (0.15, 'b'), (0.45, 'c')],
                                 'plan3' : [(0.2, 'a'), (0.5, 'b'), (0.3, 'c')],
                               },
                         'b' : { 'plan1' : [(0.2, 'a'), (0.6, 'b'), (0.2, 'c'), (0.1, 'd')],
                                 'plan2' : [(0.6, 'a'), (0.2, 'b'), (0.1, 'c'), (0.1, 'd')],
                                 'plan3' : [(0.3, 'a'), (0.3, 'b'), (0.4, 'c')],
                               },
                         'c' : { 'plan1' : [(0.3, 'a'), (0.5, 'b'), (0.1, 'c'), (0.1, 'd')],
                                 'plan2' : [(0.5, 'a'), (0.3, 'b'), (0.1, 'c'), (0.1, 'd')],
                                 'plan3' : [(0.1, 'a'), (0.3, 'b'), (0.1, 'c'), (0.5, 'd')],
                               },
                       }

    mdp = MDP(init="a", actlist={"plan1","plan2", "plan3"}, terminals={"d"}, states={"a","b","c", "d"}, transitions=transition_model)

    assert mdp.T("a","plan3") == [(0.2, 'a'), (0.5, 'b'), (0.3, 'c')]
    assert mdp.T("b","plan2") == [(0.6, 'a'), (0.2, 'b'), (0.1, 'c'), (0.1, 'd')]
    assert mdp.T("c","plan1") == [(0.3, 'a'), (0.5, 'b'), (0.1, 'c'), (0.1, 'd')]


def test_pomdp_value_iteration():
    # Two-state POMDP with three actions; rewards[action][state].
    t_prob = [[[0.65, 0.35], [0.65, 0.35]], [[0.65, 0.35], [0.65, 0.35]], [[1.0, 0.0], [0.0, 1.0]]]
    e_prob = [[[0.5, 0.5], [0.5, 0.5]], [[0.5, 0.5], [0.5, 0.5]], [[0.8, 0.2], [0.3, 0.7]]]
    rewards = [[5, -10], [-20, 5], [-1, -1]]

    gamma = 0.95
    actions = ('0', '1', '2')
    states = ('0', '1')

    pomdp = POMDP(actions, t_prob, e_prob, rewards, states, gamma)
    utility = pomdp_value_iteration(pomdp, epsilon=5)

    for _, v in utility.items():
        sum_ = 0
        for element in v:
            sum_ += sum(element)

    # NOTE(review): indentation reconstructed from a flattened dump. As
    # written, this checks only the sum for the last entry of `utility`;
    # confirm whether the assert was meant to sit inside the loop.
    assert -9.76 < sum_ < -9.70 or 246.5 < sum_ < 248.5 or 0 < sum_ < 1


def test_pomdp_value_iteration2():
    t_prob = [[[0.5, 0.5], [0.5, 0.5]], [[0.5, 0.5], [0.5, 0.5]], [[1.0, 0.0], [0.0, 1.0]]]
    e_prob = [[[0.5, 0.5], [0.5, 0.5]], [[0.5, 0.5], [0.5, 0.5]], [[0.85, 0.15], [0.15, 0.85]]]
    rewards = [[-100, 10], [10, -100], [-1, -1]]

    gamma = 0.95
    actions = ('0', '1', '2')
    states = ('0', '1')

    pomdp = POMDP(actions, t_prob, e_prob, rewards, states, gamma)
    utility = pomdp_value_iteration(pomdp, epsilon=100)

    for _, v in utility.items():
        sum_ = 0
        for element in v:
            sum_ += sum(element)

    # NOTE(review): as above, only the last entry's sum is checked.
    assert -77.31 < sum_ < -77.25 or 799 < sum_ < 800

# --------------------------------------------------------------------------------
# /tests/test_learning.py:
# --------------------------------------------------------------------------------
import pytest
import math
import random
from utils import open_data
from learning import *


# Seed the shared RNG once at import time. Randomized learners below presumably
# draw from this stream, so results depend on test order and on every call
# made before them.
random.seed("aima-python")


def test_euclidean():
    distance = euclidean_distance([1, 2], [3, 4])
    assert round(distance, 2) == 2.83

    distance = euclidean_distance([1, 2, 3], [4, 5, 6])
    assert round(distance, 2) == 5.2

    distance = euclidean_distance([0, 0, 0], [0, 0, 0])
    assert distance == 0


def test_cross_entropy():
    loss = cross_entropy_loss([1,0], [0.9, 0.3])
    assert round(loss,2) == 0.23

    loss = cross_entropy_loss([1,0,0,1], [0.9,0.3,0.5,0.75])
    assert round(loss,2) == 0.36

    loss = cross_entropy_loss([1,0,0,1,1,0,1,1], [0.9,0.3,0.5,0.75,0.85,0.14,0.93,0.79])
    assert round(loss,2) == 0.26


def test_rms_error():
    assert rms_error([2, 2], [2, 2]) == 0
    assert rms_error((0, 0), (0, 1)) == math.sqrt(0.5)
    assert rms_error((1, 0), (0, 1)) == 1
    assert rms_error((0, 0), (0, -1)) == math.sqrt(0.5)
    assert rms_error((0, 0.5), (0, -0.5)) == math.sqrt(0.5)


def test_manhattan_distance():
    assert manhattan_distance([2, 2], [2, 2]) == 0
    assert manhattan_distance([0, 0], [0, 1]) == 1
    assert manhattan_distance([1, 0], [0, 1]) == 2
    assert manhattan_distance([0, 0], [0, -1]) == 1
    assert manhattan_distance([0, 0.5], [0, -0.5]) == 1


def test_mean_boolean_error():
    assert mean_boolean_error([1, 1], [0, 0]) == 1
    assert mean_boolean_error([0, 1], [1, 0]) == 1
    assert mean_boolean_error([1, 1], [0, 1]) == 0.5
    assert mean_boolean_error([0, 0], [0, 0]) == 0
    assert mean_boolean_error([1, 1], [1, 1]) == 0


def test_mean_error():
    assert mean_error([2, 2], [2, 2]) == 0
    assert mean_error([0, 0], [0, 1]) == 0.5
    assert mean_error([1, 0], [0, 1]) == 1
    assert mean_error([0, 0], [0, -1]) == 0.5
    assert mean_error([0, 0.5], [0, -0.5]) == 0.5


def test_exclude():
    # Excluding attribute index 3 leaves only the first three as inputs.
    iris = DataSet(name='iris', exclude=[3])
    assert iris.inputs == [0, 1, 2]


def test_parse_csv():
    # NOTE(review): if open_data returns a file object, it is never closed
    # here; the root pytest.ini ignores ResourceWarning, which would hide this.
    Iris = open_data('iris.csv').read()
    assert parse_csv(Iris)[0] == [5.1, 3.5, 1.4, 0.2, 'setosa']


def test_weighted_mode():
    assert weighted_mode('abbaa', [1, 2, 3, 1, 2]) == 'b'


def test_weighted_replicate():
    assert weighted_replicate('ABC', [1, 2, 1], 4) == ['A', 'B', 'B', 'C']


def test_means_and_deviation():
    iris = DataSet(name="iris")

    means, deviations = iris.find_means_and_deviations()

    # Only attribute 0 of each class is checked, to 3 decimal places.
    assert round(means["setosa"][0], 3) == 5.006
    assert round(means["versicolor"][0], 3) == 5.936
    assert round(means["virginica"][0], 3) == 6.588

    assert round(deviations["setosa"][0], 3) == 0.352
    assert round(deviations["versicolor"][0], 3) == 0.516
    assert round(deviations["virginica"][0], 3) == 0.636


def test_plurality_learner():
    zoo = DataSet(name="zoo")

    pL = PluralityLearner(zoo)
    assert pL([1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 4, 1, 0, 1]) == "mammal"


def test_naive_bayes():
    iris = DataSet(name="iris")

    # Discrete
    nBD = NaiveBayesLearner(iris, continuous=False)
    assert nBD([5, 3, 1, 0.1]) == "setosa"
    assert nBD([6, 3, 4, 1.1]) == "versicolor"
    assert nBD([7.7, 3, 6, 2]) == "virginica"

    # Continuous
    nBC = NaiveBayesLearner(iris, continuous=True)
    assert nBC([5, 3, 1, 0.1]) == "setosa"
    assert nBC([6, 5, 3, 1.5]) == "versicolor"
    assert nBC([7, 3, 6.5, 2]) == "virginica"

    # Simple
    # Keys appear to pair a class label with its prior probability; each value
    # is a letter-frequency distribution for that class.
    data1 = 'a'*50 + 'b'*30 + 'c'*15
    dist1 = CountingProbDist(data1)
    data2 = 'a'*30 + 'b'*45 + 'c'*20
    dist2 = CountingProbDist(data2)
    data3 = 'a'*20 + 'b'*20 + 'c'*35
    dist3 = CountingProbDist(data3)

    dist = {('First', 0.5): dist1, ('Second', 0.3): dist2, ('Third', 0.2): dist3}
    nBS = NaiveBayesLearner(dist, simple=True)
    assert nBS('aab') == 'First'
    assert nBS(['b', 'b']) == 'Second'
    assert nBS('ccbcc') == 'Third'


def test_k_nearest_neighbors():
    iris = DataSet(name="iris")
    kNN = NearestNeighborLearner(iris, k=3)
    # NOTE(review): the next two assertions are identical.
    assert kNN([5, 3, 1, 0.1]) == "setosa"
    assert kNN([5, 3, 1, 0.1]) == "setosa"
    assert kNN([6, 5, 3, 1.5]) == "versicolor"
    assert kNN([7.5, 4, 6, 2]) == "virginica"


def test_truncated_svd():
    # Only singular-value magnitudes are checked; isclose presumably comes
    # from a star import.
    test_mat = [[17, 0],
                [0, 11]]
    _, _, eival = truncated_svd(test_mat)
    assert isclose(abs(eival[0]), 17)
    assert isclose(abs(eival[1]), 11)

    test_mat = [[17, 0],
                [0, -34]]
    _, _, eival = truncated_svd(test_mat)
    assert isclose(abs(eival[0]), 34)
    assert isclose(abs(eival[1]), 17)

    test_mat = [[1, 0, 0, 0, 2],
                [0, 0, 3, 0, 0],
                [0, 0, 0, 0, 0],
                [0, 2, 0, 0, 0]]
    _, _, eival = truncated_svd(test_mat)
    assert isclose(abs(eival[0]), 3)
    assert isclose(abs(eival[1]), 5**0.5)

    test_mat = [[3, 2, 2],
                [2, 3, -2]]
    _, _, eival = truncated_svd(test_mat)
    assert isclose(abs(eival[0]), 5)
    assert isclose(abs(eival[1]), 3)


def test_decision_tree_learner():
    iris = DataSet(name="iris")
    dTL = DecisionTreeLearner(iris)
    assert dTL([5, 3, 1, 0.1]) == "setosa"
    assert dTL([6, 5, 3, 1.5]) == "versicolor"
    assert dTL([7.5, 4, 6, 2]) == "virginica"


def test_information_content():
    assert information_content([]) == 0
    assert information_content([4]) == 0
    assert information_content([5, 4, 0, 2, 5, 0]) > 1.9
    assert information_content([5, 4, 0, 2, 5, 0]) < 2
    assert information_content([1.5, 2.5]) > 0.9
    assert information_content([1.5, 2.5]) < 1.0


def test_random_forest():
    iris = DataSet(name="iris")
    rF = RandomForest(iris)
    tests = [([5.0, 3.0, 1.0, 0.1], "setosa"),
             ([5.1, 3.3, 1.1, 0.1], "setosa"),
             ([6.0, 5.0, 3.0, 1.0], "versicolor"),
             ([6.1, 2.2, 3.5, 1.0], "versicolor"),
             ([7.5, 4.1, 6.2, 2.3], "virginica"),
             ([7.3, 3.7, 6.1, 2.5], "virginica")]
    # Loose threshold: at least a third of the samples must be classified correctly.
    assert grade_learner(rF, tests) >= 1/3


def test_neural_network_learner():
    iris = DataSet(name="iris")
    classes = ["setosa", "versicolor", "virginica"]
    # Replace class names with numeric targets (expected outputs below are 0, 1, 2).
    iris.classes_to_numbers(classes)
    nNL = NeuralNetLearner(iris, [5], 0.15, 75)
    tests = [([5.0, 3.1, 0.9, 0.1], 0),
             ([5.1, 3.5, 1.0, 0.0], 0),
             ([4.9, 3.3, 1.1, 0.1], 0),
             ([6.0, 3.0, 4.0, 1.1], 1),
             ([6.1, 2.2, 3.5, 1.0], 1),
             ([5.9, 2.5, 3.3, 1.1], 1),
             ([7.5, 4.1, 6.2, 2.3], 2),
             ([7.3, 4.0, 6.1, 2.4], 2),
             ([7.0, 3.3, 6.1, 2.5], 2)]
    assert grade_learner(nNL, tests) >= 1/3
    assert err_ratio(nNL, iris) < 0.21


def test_perceptron():
    iris = DataSet(name="iris")
    iris.classes_to_numbers()
    # NOTE(review): classes_number is computed but never used.
    classes_number = len(iris.values[iris.target])
    perceptron = PerceptronLearner(iris)
    tests = [([5, 3, 1, 0.1], 0),
             ([5, 3.5, 1, 0], 0),
             ([6, 3, 4, 1.1], 1),
             ([6, 2, 3.5, 1], 1),
             ([7.5, 4, 6, 2], 2),
             ([7, 3, 6, 2.5], 2)]
    assert grade_learner(perceptron, tests) > 1/2
    assert err_ratio(perceptron, iris) < 0.4


def test_random_weights():
    min_value = -0.5
    max_value = 0.5
    num_weights = 10
    test_weights = random_weights(min_value, max_value, num_weights)
    assert len(test_weights) == num_weights
    # Every weight must fall inside the closed interval [min_value, max_value].
    for weight in test_weights:
        assert weight >= min_value and weight <= max_value


def test_adaboost():
    iris = DataSet(name="iris")
    iris.classes_to_numbers()
    # Boost 5 weighted perceptrons.
    WeightedPerceptron = WeightedLearner(PerceptronLearner)
    AdaboostLearner = AdaBoost(WeightedPerceptron, 5)
    adaboost = AdaboostLearner(iris)
    tests = [([5, 3, 1, 0.1], 0),
             ([5, 3.5, 1, 0], 0),
             ([6, 3, 4, 1.1], 1),
             ([6, 2, 3.5, 1], 1),
             ([7.5, 4, 6, 2], 2),
             ([7, 3, 6, 2.5], 2)]
    assert grade_learner(adaboost, tests) > 4/6
    assert err_ratio(adaboost, iris) < 0.25

# --------------------------------------------------------------------------------
# /tests/test_utils.py:
# --------------------------------------------------------------------------------
import pytest
from utils import *
import random


def test_removeall_list():
    # removeall drops every occurrence of the item, not just the first.
    assert removeall(4, []) == []
    assert removeall(4, [1, 2, 3, 4]) == [1, 2, 3]
    assert removeall(4, [4, 1, 4, 2, 3, 4, 4]) == [1, 2, 3]


def test_removeall_string():
    # For strings, every occurrence of the character is removed.
    assert removeall('s', '') == ''
    assert removeall('s', 'This is a test. Was a test.') == 'Thi i a tet. Wa a tet.'
15 | 16 | 17 | def test_unique(): 18 | assert unique([1, 2, 3, 2, 1]) == [1, 2, 3] 19 | assert unique([1, 5, 6, 7, 6, 5]) == [1, 5, 6, 7] 20 | 21 | 22 | def test_count(): 23 | assert count([1, 2, 3, 4, 2, 3, 4]) == 7 24 | assert count("aldpeofmhngvia") == 14 25 | assert count([True, False, True, True, False]) == 3 26 | assert count([5 > 1, len("abc") == 3, 3+1 == 5]) == 2 27 | 28 | 29 | def test_product(): 30 | assert product([1, 2, 3, 4]) == 24 31 | assert product(list(range(1, 11))) == 3628800 32 | 33 | 34 | def test_first(): 35 | assert first('word') == 'w' 36 | assert first('') is None 37 | assert first('', 'empty') == 'empty' 38 | assert first(range(10)) == 0 39 | assert first(x for x in range(10) if x > 3) == 4 40 | assert first(x for x in range(10) if x > 100) is None 41 | 42 | 43 | def test_is_in(): 44 | e = [] 45 | assert is_in(e, [1, e, 3]) is True 46 | assert is_in(e, [1, [], 3]) is False 47 | 48 | 49 | def test_mode(): 50 | assert mode([12, 32, 2, 1, 2, 3, 2, 3, 2, 3, 44, 3, 12, 4, 9, 0, 3, 45, 3]) == 3 51 | assert mode("absndkwoajfkalwpdlsdlfllalsflfdslgflal") == 'l' 52 | 53 | 54 | def test_powerset(): 55 | assert powerset([1, 2, 3]) == [(1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)] 56 | 57 | 58 | def test_argminmax(): 59 | assert argmin([-2, 1], key=abs) == 1 60 | assert argmax([-2, 1], key=abs) == -2 61 | assert argmax(['one', 'to', 'three'], key=len) == 'three' 62 | 63 | 64 | def test_histogram(): 65 | assert histogram([1, 2, 4, 2, 4, 5, 7, 9, 2, 1]) == [(1, 2), (2, 3), 66 | (4, 2), (5, 1), 67 | (7, 1), (9, 1)] 68 | assert histogram([1, 2, 4, 2, 4, 5, 7, 9, 2, 1], 0, lambda x: x*x) == [(1, 2), (4, 3), 69 | (16, 2), (25, 1), 70 | (49, 1), (81, 1)] 71 | assert histogram([1, 2, 4, 2, 4, 5, 7, 9, 2, 1], 1) == [(2, 3), (4, 2), 72 | (1, 2), (9, 1), 73 | (7, 1), (5, 1)] 74 | 75 | 76 | def test_dotproduct(): 77 | assert dotproduct([1, 2, 3], [1000, 100, 10]) == 1230 78 | 79 | 80 | def test_element_wise_product(): 81 | assert element_wise_product([1, 
2, 5], [7, 10, 0]) == [7, 20, 0] 82 | assert element_wise_product([1, 6, 3, 0], [9, 12, 0, 0]) == [9, 72, 0, 0] 83 | 84 | 85 | def test_matrix_multiplication(): 86 | assert matrix_multiplication([[1, 2, 3], 87 | [2, 3, 4]], 88 | [[3, 4], 89 | [1, 2], 90 | [1, 0]]) == [[8, 8], [13, 14]] 91 | 92 | assert matrix_multiplication([[1, 2, 3], 93 | [2, 3, 4]], 94 | [[3, 4, 8, 1], 95 | [1, 2, 5, 0], 96 | [1, 0, 0, 3]], 97 | [[1, 2], 98 | [3, 4], 99 | [5, 6], 100 | [1, 2]]) == [[132, 176], [224, 296]] 101 | 102 | 103 | def test_vector_to_diagonal(): 104 | assert vector_to_diagonal([1, 2, 3]) == [[1, 0, 0], [0, 2, 0], [0, 0, 3]] 105 | assert vector_to_diagonal([0, 3, 6]) == [[0, 0, 0], [0, 3, 0], [0, 0, 6]] 106 | 107 | 108 | def test_vector_add(): 109 | assert vector_add((0, 1), (8, 9)) == (8, 10) 110 | 111 | 112 | def test_scalar_vector_product(): 113 | assert scalar_vector_product(2, [1, 2, 3]) == [2, 4, 6] 114 | 115 | 116 | def test_scalar_matrix_product(): 117 | assert rounder(scalar_matrix_product(-5, [[1, 2], [3, 4], [0, 6]])) == [[-5, -10], [-15, -20], 118 | [0, -30]] 119 | assert rounder(scalar_matrix_product(0.2, [[1, 2], [2, 3]])) == [[0.2, 0.4], [0.4, 0.6]] 120 | 121 | 122 | def test_inverse_matrix(): 123 | assert rounder(inverse_matrix([[1, 0], [0, 1]])) == [[1, 0], [0, 1]] 124 | assert rounder(inverse_matrix([[2, 1], [4, 3]])) == [[1.5, -0.5], [-2.0, 1.0]] 125 | assert rounder(inverse_matrix([[4, 7], [2, 6]])) == [[0.6, -0.7], [-0.2, 0.4]] 126 | 127 | 128 | def test_rounder(): 129 | assert rounder(5.3330000300330) == 5.3330 130 | assert rounder(10.234566) == 10.2346 131 | assert rounder([1.234566, 0.555555, 6.010101]) == [1.2346, 0.5556, 6.0101] 132 | assert rounder([[1.234566, 0.555555, 6.010101], 133 | [10.505050, 12.121212, 6.030303]]) == [[1.2346, 0.5556, 6.0101], 134 | [10.5051, 12.1212, 6.0303]] 135 | 136 | 137 | def test_num_or_str(): 138 | assert num_or_str('42') == 42 139 | assert num_or_str(' 42x ') == '42x' 140 | 141 | 142 | def test_normalize(): 143 | 
assert normalize([1, 2, 1]) == [0.25, 0.5, 0.25] 144 | 145 | 146 | def test_norm(): 147 | assert isclose(norm([1, 2, 1], 1), 4) 148 | assert isclose(norm([3, 4], 2), 5) 149 | assert isclose(norm([-1, 1, 2], 4), 18**0.25) 150 | 151 | 152 | def test_clip(): 153 | assert [clip(x, 0, 1) for x in [-1, 0.5, 10]] == [0, 0.5, 1] 154 | 155 | 156 | def test_sigmoid(): 157 | assert isclose(0.5, sigmoid(0)) 158 | assert isclose(0.7310585786300049, sigmoid(1)) 159 | assert isclose(0.2689414213699951, sigmoid(-1)) 160 | 161 | 162 | def test_gaussian(): 163 | assert gaussian(1,0.5,0.7) == 0.6664492057835993 164 | assert gaussian(5,2,4.5) == 0.19333405840142462 165 | assert gaussian(3,1,3) == 0.3989422804014327 166 | 167 | 168 | def test_sigmoid_derivative(): 169 | value = 1 170 | assert sigmoid_derivative(value) == 0 171 | 172 | value = 3 173 | assert sigmoid_derivative(value) == -6 174 | 175 | 176 | def test_weighted_choice(): 177 | choices = [('a', 0.5), ('b', 0.3), ('c', 0.2)] 178 | choice = weighted_choice(choices) 179 | assert choice in choices 180 | 181 | 182 | def compare_list(x, y): 183 | return all([elm_x == y[i] for i, elm_x in enumerate(x)]) 184 | 185 | 186 | def test_distance(): 187 | assert distance((1, 2), (5, 5)) == 5.0 188 | 189 | 190 | def test_distance_squared(): 191 | assert distance_squared((1, 2), (5, 5)) == 25.0 192 | 193 | 194 | def test_vector_clip(): 195 | assert vector_clip((-1, 10), (0, 0), (9, 9)) == (0, 9) 196 | 197 | 198 | def test_turn_heading(): 199 | assert turn_heading((0, 1), 1) == (-1, 0) 200 | assert turn_heading((0, 1), -1) == (1, 0) 201 | assert turn_heading((1, 0), 1) == (0, 1) 202 | assert turn_heading((1, 0), -1) == (0, -1) 203 | assert turn_heading((0, -1), 1) == (1, 0) 204 | assert turn_heading((0, -1), -1) == (-1, 0) 205 | assert turn_heading((-1, 0), 1) == (0, -1) 206 | assert turn_heading((-1, 0), -1) == (0, 1) 207 | 208 | 209 | def test_turn_left(): 210 | assert turn_left((0, 1)) == (-1, 0) 211 | 212 | 213 | def test_turn_right(): 
214 | assert turn_right((0, 1)) == (1, 0) 215 | 216 | 217 | def test_step(): 218 | assert step(1) == step(0.5) == 1 219 | assert step(0) == 1 220 | assert step(-1) == step(-0.5) == 0 221 | 222 | 223 | def test_Expr(): 224 | A, B, C = symbols('A, B, C') 225 | assert symbols('A, B, C') == (Symbol('A'), Symbol('B'), Symbol('C')) 226 | assert A.op == repr(A) == 'A' 227 | assert arity(A) == 0 and A.args == () 228 | 229 | b = Expr('+', A, 1) 230 | assert arity(b) == 2 and b.op == '+' and b.args == (A, 1) 231 | 232 | u = Expr('-', b) 233 | assert arity(u) == 1 and u.op == '-' and u.args == (b,) 234 | 235 | assert (b ** u) == (b ** u) 236 | assert (b ** u) != (u ** b) 237 | 238 | assert A + b * C ** 2 == A + (b * (C ** 2)) 239 | 240 | ex = C + 1 / (A % 1) 241 | assert list(subexpressions(ex)) == [(C + (1 / (A % 1))), C, (1 / (A % 1)), 1, (A % 1), A, 1] 242 | assert A in subexpressions(ex) 243 | assert B not in subexpressions(ex) 244 | 245 | 246 | def test_expr(): 247 | P, Q, x, y, z, GP = symbols('P, Q, x, y, z, GP') 248 | assert (expr(y + 2 * x) 249 | == expr('y + 2 * x') 250 | == Expr('+', y, Expr('*', 2, x))) 251 | assert expr('P & Q ==> P') == Expr('==>', P & Q, P) 252 | assert expr('P & Q <=> Q & P') == Expr('<=>', (P & Q), (Q & P)) 253 | assert expr('P(x) | P(y) & Q(z)') == (P(x) | (P(y) & Q(z))) 254 | # x is grandparent of z if x is parent of y and y is parent of z: 255 | assert (expr('GP(x, z) <== P(x, y) & P(y, z)') 256 | == Expr('<==', GP(x, z), P(x, y) & P(y, z))) 257 | 258 | 259 | if __name__ == '__main__': 260 | pytest.main() 261 | -------------------------------------------------------------------------------- /tests/test_agents.py: -------------------------------------------------------------------------------- 1 | import random 2 | from agents import Direction 3 | from agents import Agent 4 | from agents import ReflexVacuumAgent, ModelBasedVacuumAgent, TrivialVacuumEnvironment, compare_agents,\ 5 | RandomVacuumAgent, TableDrivenVacuumAgent, 
TableDrivenAgentProgram, RandomAgentProgram, \ 6 | SimpleReflexAgentProgram, ModelBasedReflexAgentProgram, rule_match 7 | 8 | 9 | random.seed("aima-python") 10 | 11 | 12 | def test_move_forward(): 13 | d = Direction("up") 14 | l1 = d.move_forward((0, 0)) 15 | assert l1 == (0, -1) 16 | 17 | d = Direction(Direction.R) 18 | l1 = d.move_forward((0, 0)) 19 | assert l1 == (1, 0) 20 | 21 | d = Direction(Direction.D) 22 | l1 = d.move_forward((0, 0)) 23 | assert l1 == (0, 1) 24 | 25 | d = Direction("left") 26 | l1 = d.move_forward((0, 0)) 27 | assert l1 == (-1, 0) 28 | 29 | l2 = d.move_forward((1, 0)) 30 | assert l2 == (0, 0) 31 | 32 | 33 | def test_add(): 34 | d = Direction(Direction.U) 35 | l1 = d + "right" 36 | l2 = d + "left" 37 | assert l1.direction == Direction.R 38 | assert l2.direction == Direction.L 39 | 40 | d = Direction("right") 41 | l1 = d.__add__(Direction.L) 42 | l2 = d.__add__(Direction.R) 43 | assert l1.direction == "up" 44 | assert l2.direction == "down" 45 | 46 | d = Direction("down") 47 | l1 = d.__add__("right") 48 | l2 = d.__add__("left") 49 | assert l1.direction == Direction.L 50 | assert l2.direction == Direction.R 51 | 52 | d = Direction(Direction.L) 53 | l1 = d + Direction.R 54 | l2 = d + Direction.L 55 | assert l1.direction == Direction.U 56 | assert l2.direction == Direction.D 57 | 58 | 59 | def test_RandomAgentProgram() : 60 | #create a list of all the actions a vacuum cleaner can perform 61 | list = ['Right', 'Left', 'Suck', 'NoOp'] 62 | # create a program and then an object of the RandomAgentProgram 63 | program = RandomAgentProgram(list) 64 | 65 | agent = Agent(program) 66 | # create an object of TrivialVacuumEnvironment 67 | environment = TrivialVacuumEnvironment() 68 | # add agent to the environment 69 | environment.add_thing(agent) 70 | # run the environment 71 | environment.run() 72 | # check final status of the environment 73 | assert environment.status == {(1, 0): 'Clean' , (0, 0): 'Clean'} 74 | 75 | 76 | def test_RandomVacuumAgent() : 
77 | # create an object of the RandomVacuumAgent 78 | agent = RandomVacuumAgent() 79 | # create an object of TrivialVacuumEnvironment 80 | environment = TrivialVacuumEnvironment() 81 | # add agent to the environment 82 | environment.add_thing(agent) 83 | # run the environment 84 | environment.run() 85 | # check final status of the environment 86 | assert environment.status == {(1,0):'Clean' , (0,0) : 'Clean'} 87 | 88 | 89 | def test_TableDrivenAgent(): 90 | loc_A, loc_B = (0, 0), (1, 0) 91 | # table defining all the possible states of the agent 92 | table = {((loc_A, 'Clean'),): 'Right', 93 | ((loc_A, 'Dirty'),): 'Suck', 94 | ((loc_B, 'Clean'),): 'Left', 95 | ((loc_B, 'Dirty'),): 'Suck', 96 | ((loc_A, 'Dirty'), (loc_A, 'Clean')): 'Right', 97 | ((loc_A, 'Clean'), (loc_B, 'Dirty')): 'Suck', 98 | ((loc_B, 'Clean'), (loc_A, 'Dirty')): 'Suck', 99 | ((loc_B, 'Dirty'), (loc_B, 'Clean')): 'Left', 100 | ((loc_A, 'Dirty'), (loc_A, 'Clean'), (loc_B, 'Dirty')): 'Suck', 101 | ((loc_B, 'Dirty'), (loc_B, 'Clean'), (loc_A, 'Dirty')): 'Suck' 102 | } 103 | 104 | # create an program and then an object of the TableDrivenAgent 105 | program = TableDrivenAgentProgram(table) 106 | agent = Agent(program) 107 | # create an object of TrivialVacuumEnvironment 108 | environment = TrivialVacuumEnvironment() 109 | # initializing some environment status 110 | environment.status = {loc_A:'Dirty', loc_B:'Dirty'} 111 | # add agent to the environment 112 | environment.add_thing(agent) 113 | 114 | # run the environment by single step everytime to check how environment evolves using TableDrivenAgentProgram 115 | environment.run(steps = 1) 116 | assert environment.status == {(1,0): 'Clean', (0,0): 'Dirty'} 117 | 118 | environment.run(steps = 1) 119 | assert environment.status == {(1,0): 'Clean', (0,0): 'Dirty'} 120 | 121 | environment.run(steps = 1) 122 | assert environment.status == {(1,0): 'Clean', (0,0): 'Clean'} 123 | 124 | 125 | def test_ReflexVacuumAgent() : 126 | # create an object of the 
ReflexVacuumAgent 127 | agent = ReflexVacuumAgent() 128 | # create an object of TrivialVacuumEnvironment 129 | environment = TrivialVacuumEnvironment() 130 | # add agent to the environment 131 | environment.add_thing(agent) 132 | # run the environment 133 | environment.run() 134 | # check final status of the environment 135 | assert environment.status == {(1,0):'Clean' , (0,0) : 'Clean'} 136 | 137 | 138 | def test_SimpleReflexAgentProgram(): 139 | class Rule: 140 | 141 | def __init__(self, state, action): 142 | self.__state = state 143 | self.action = action 144 | 145 | def matches(self, state): 146 | return self.__state == state 147 | 148 | loc_A = (0, 0) 149 | loc_B = (1, 0) 150 | 151 | # create rules for a two state Vacuum Environment 152 | rules = [Rule((loc_A, "Dirty"), "Suck"), Rule((loc_A, "Clean"), "Right"), 153 | Rule((loc_B, "Dirty"), "Suck"), Rule((loc_B, "Clean"), "Left")] 154 | 155 | def interpret_input(state): 156 | return state 157 | 158 | # create a program and then an object of the SimpleReflexAgentProgram 159 | program = SimpleReflexAgentProgram(rules, interpret_input) 160 | agent = Agent(program) 161 | # create an object of TrivialVacuumEnvironment 162 | environment = TrivialVacuumEnvironment() 163 | # add agent to the environment 164 | environment.add_thing(agent) 165 | # run the environment 166 | environment.run() 167 | # check final status of the environment 168 | assert environment.status == {(1,0):'Clean' , (0,0) : 'Clean'} 169 | 170 | 171 | def test_ModelBasedReflexAgentProgram(): 172 | class Rule: 173 | 174 | def __init__(self, state, action): 175 | self.__state = state 176 | self.action = action 177 | 178 | def matches(self, state): 179 | return self.__state == state 180 | 181 | loc_A = (0, 0) 182 | loc_B = (1, 0) 183 | 184 | # create rules for a two-state vacuum environment 185 | rules = [Rule((loc_A, "Dirty"), "Suck"), Rule((loc_A, "Clean"), "Right"), 186 | Rule((loc_B, "Dirty"), "Suck"), Rule((loc_B, "Clean"), "Left")] 187 | 188 | def 
update_state(state, action, percept, model): 189 | return percept 190 | 191 | # create a program and then an object of the ModelBasedReflexAgentProgram class 192 | program = ModelBasedReflexAgentProgram(rules, update_state, None) 193 | agent = Agent(program) 194 | # create an object of TrivialVacuumEnvironment 195 | environment = TrivialVacuumEnvironment() 196 | # add agent to the environment 197 | environment.add_thing(agent) 198 | # run the environment 199 | environment.run() 200 | # check final status of the environment 201 | assert environment.status == {(1, 0): 'Clean', (0, 0): 'Clean'} 202 | 203 | 204 | def test_ModelBasedVacuumAgent() : 205 | # create an object of the ModelBasedVacuumAgent 206 | agent = ModelBasedVacuumAgent() 207 | # create an object of TrivialVacuumEnvironment 208 | environment = TrivialVacuumEnvironment() 209 | # add agent to the environment 210 | environment.add_thing(agent) 211 | # run the environment 212 | environment.run() 213 | # check final status of the environment 214 | assert environment.status == {(1,0):'Clean' , (0,0) : 'Clean'} 215 | 216 | 217 | def test_TableDrivenVacuumAgent() : 218 | # create an object of the TableDrivenVacuumAgent 219 | agent = TableDrivenVacuumAgent() 220 | # create an object of the TrivialVacuumEnvironment 221 | environment = TrivialVacuumEnvironment() 222 | # add agent to the environment 223 | environment.add_thing(agent) 224 | # run the environment 225 | environment.run() 226 | # check final status of the environment 227 | assert environment.status == {(1, 0):'Clean', (0, 0):'Clean'} 228 | 229 | 230 | def test_compare_agents() : 231 | environment = TrivialVacuumEnvironment 232 | agents = [ModelBasedVacuumAgent, ReflexVacuumAgent] 233 | 234 | result = compare_agents(environment, agents) 235 | performance_ModelBasedVacummAgent = result[0][1] 236 | performance_ReflexVacummAgent = result[1][1] 237 | 238 | # The performance of ModelBasedVacuumAgent will be at least as good as that of 239 | # 
ReflexVacuumAgent, since ModelBasedVacuumAgent can identify when it has 240 | # reached the terminal state (both locations being clean) and will perform 241 | # NoOp leading to 0 performance change, whereas ReflexVacuumAgent cannot 242 | # identify the terminal state and thus will keep moving, leading to worse 243 | # performance compared to ModelBasedVacuumAgent. 244 | assert performance_ReflexVacummAgent <= performance_ModelBasedVacummAgent 245 | 246 | 247 | def test_TableDrivenAgentProgram(): 248 | table = {(('foo', 1),): 'action1', 249 | (('foo', 2),): 'action2', 250 | (('bar', 1),): 'action3', 251 | (('bar', 2),): 'action1', 252 | (('foo', 1), ('foo', 1),): 'action2', 253 | (('foo', 1), ('foo', 2),): 'action3', 254 | } 255 | agent_program = TableDrivenAgentProgram(table) 256 | assert agent_program(('foo', 1)) == 'action1' 257 | assert agent_program(('foo', 2)) == 'action3' 258 | assert agent_program(('invalid percept',)) == None 259 | 260 | 261 | def test_Agent(): 262 | def constant_prog(percept): 263 | return percept 264 | agent = Agent(constant_prog) 265 | result = agent.program(5) 266 | assert result == 5 267 | -------------------------------------------------------------------------------- /tests/test_nlp.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import nlp 3 | 4 | from nlp import loadPageHTML, stripRawHTML, findOutlinks, onlyWikipediaURLS 5 | from nlp import expand_pages, relevant_pages, normalize, ConvergenceDetector, getInlinks 6 | from nlp import getOutlinks, Page, determineInlinks, HITS 7 | from nlp import Rules, Lexicon, Grammar, ProbRules, ProbLexicon, ProbGrammar 8 | from nlp import Chart, CYK_parse 9 | # Clumsy imports because we want to access certain nlp.py globals explicitly, because 10 | # they are accessed by functions within nlp.py 11 | 12 | from unittest.mock import patch 13 | from io import BytesIO 14 | 15 | 16 | def test_rules(): 17 | check = {'A': [['B', 'C'], ['D', 'E']], 
'B': [['E'], ['a'], ['b', 'c']]} 18 | assert Rules(A="B C | D E", B="E | a | b c") == check 19 | 20 | 21 | def test_lexicon(): 22 | check = {'Article': ['the', 'a', 'an'], 'Pronoun': ['i', 'you', 'he']} 23 | lexicon = Lexicon(Article="the | a | an", Pronoun="i | you | he") 24 | assert lexicon == check 25 | 26 | 27 | def test_grammar(): 28 | rules = Rules(A="B C | D E", B="E | a | b c") 29 | lexicon = Lexicon(Article="the | a | an", Pronoun="i | you | he") 30 | grammar = Grammar("Simplegram", rules, lexicon) 31 | 32 | assert grammar.rewrites_for('A') == [['B', 'C'], ['D', 'E']] 33 | assert grammar.isa('the', 'Article') 34 | 35 | grammar = nlp.E_Chomsky 36 | for rule in grammar.cnf_rules(): 37 | assert len(rule) == 3 38 | 39 | 40 | def test_generation(): 41 | lexicon = Lexicon(Article="the | a | an", 42 | Pronoun="i | you | he") 43 | 44 | rules = Rules( 45 | S="Article | More | Pronoun", 46 | More="Article Pronoun | Pronoun Pronoun" 47 | ) 48 | 49 | grammar = Grammar("Simplegram", rules, lexicon) 50 | 51 | sentence = grammar.generate_random('S') 52 | for token in sentence.split(): 53 | found = False 54 | for non_terminal, terminals in grammar.lexicon.items(): 55 | if token in terminals: 56 | found = True 57 | assert found 58 | 59 | 60 | def test_prob_rules(): 61 | check = {'A': [(['B', 'C'], 0.3), (['D', 'E'], 0.7)], 62 | 'B': [(['E'], 0.1), (['a'], 0.2), (['b', 'c'], 0.7)]} 63 | rules = ProbRules(A="B C [0.3] | D E [0.7]", B="E [0.1] | a [0.2] | b c [0.7]") 64 | assert rules == check 65 | 66 | 67 | def test_prob_lexicon(): 68 | check = {'Article': [('the', 0.5), ('a', 0.25), ('an', 0.25)], 69 | 'Pronoun': [('i', 0.4), ('you', 0.3), ('he', 0.3)]} 70 | lexicon = ProbLexicon(Article="the [0.5] | a [0.25] | an [0.25]", 71 | Pronoun="i [0.4] | you [0.3] | he [0.3]") 72 | assert lexicon == check 73 | 74 | 75 | def test_prob_grammar(): 76 | rules = ProbRules(A="B C [0.3] | D E [0.7]", B="E [0.1] | a [0.2] | b c [0.7]") 77 | lexicon = ProbLexicon(Article="the [0.5] | a 
[0.25] | an [0.25]", 78 | Pronoun="i [0.4] | you [0.3] | he [0.3]") 79 | grammar = ProbGrammar("Simplegram", rules, lexicon) 80 | 81 | assert grammar.rewrites_for('A') == [(['B', 'C'], 0.3), (['D', 'E'], 0.7)] 82 | assert grammar.isa('the', 'Article') 83 | 84 | grammar = nlp.E_Prob_Chomsky 85 | for rule in grammar.cnf_rules(): 86 | assert len(rule) == 4 87 | 88 | 89 | def test_prob_generation(): 90 | lexicon = ProbLexicon(Verb="am [0.5] | are [0.25] | is [0.25]", 91 | Pronoun="i [0.4] | you [0.3] | he [0.3]") 92 | 93 | rules = ProbRules( 94 | S="Verb [0.5] | More [0.3] | Pronoun [0.1] | nobody is here [0.1]", 95 | More="Pronoun Verb [0.7] | Pronoun Pronoun [0.3]" 96 | ) 97 | 98 | grammar = ProbGrammar("Simplegram", rules, lexicon) 99 | 100 | sentence = grammar.generate_random('S') 101 | assert len(sentence) == 2 102 | 103 | 104 | def test_chart_parsing(): 105 | chart = Chart(nlp.E0) 106 | parses = chart.parses('the stench is in 2 2') 107 | assert len(parses) == 1 108 | 109 | 110 | def test_CYK_parse(): 111 | grammar = nlp.E_Prob_Chomsky 112 | words = ['the', 'robot', 'is', 'good'] 113 | P = CYK_parse(words, grammar) 114 | assert len(P) == 52 115 | 116 | grammar = nlp.E_Prob_Chomsky_ 117 | words = ['astronomers', 'saw', 'stars'] 118 | P = CYK_parse(words, grammar) 119 | assert len(P) == 32 120 | 121 | 122 | # ______________________________________________________________________________ 123 | # Data Setup 124 | 125 | testHTML = """Keyword String 1: A man is a male human. 126 | Keyword String 2: Like most other male mammals, a man inherits an 127 | X from his mom and a Y from his dad. 128 | Links: 129 | href="https://google.com.au" 130 | < href="/wiki/TestThing" > href="/wiki/TestBoy" 131 | href="/wiki/TestLiving" href="/wiki/TestMan" >""" 132 | testHTML2 = "a mom and a dad" 133 | testHTML3 = """ 134 | 135 | 136 | 137 | Page Title 138 | 139 | 140 | 141 |

AIMA book

142 | 143 | 144 | 145 | """ 146 | 147 | pA = Page("A", ["B", "C", "E"], ["D"], 1, 6) 148 | pB = Page("B", ["E"], ["A", "C", "D"], 2, 5) 149 | pC = Page("C", ["B", "E"], ["A", "D"], 3, 4) 150 | pD = Page("D", ["A", "B", "C", "E"], [], 4, 3) 151 | pE = Page("E", [], ["A", "B", "C", "D", "F"], 5, 2) 152 | pF = Page("F", ["E"], [], 6, 1) 153 | pageDict = {pA.address: pA, pB.address: pB, pC.address: pC, 154 | pD.address: pD, pE.address: pE, pF.address: pF} 155 | nlp.pagesIndex = pageDict 156 | nlp.pagesContent ={pA.address: testHTML, pB.address: testHTML2, 157 | pC.address: testHTML, pD.address: testHTML2, 158 | pE.address: testHTML, pF.address: testHTML2} 159 | 160 | # This test takes a long time (> 60 secs) 161 | # def test_loadPageHTML(): 162 | # # first format all the relative URLs with the base URL 163 | # addresses = [examplePagesSet[0] + x for x in examplePagesSet[1:]] 164 | # loadedPages = loadPageHTML(addresses) 165 | # relURLs = ['Ancient_Greek','Ethics','Plato','Theology'] 166 | # fullURLs = ["https://en.wikipedia.org/wiki/"+x for x in relURLs] 167 | # assert all(x in loadedPages for x in fullURLs) 168 | # assert all(loadedPages.get(key,"") != "" for key in addresses) 169 | 170 | 171 | @patch('urllib.request.urlopen', return_value=BytesIO(testHTML3.encode())) 172 | def test_stripRawHTML(html_mock): 173 | addr = "https://en.wikipedia.org/wiki/Ethics" 174 | aPage = loadPageHTML([addr]) 175 | someHTML = aPage[addr] 176 | strippedHTML = stripRawHTML(someHTML) 177 | assert "" not in strippedHTML and "" not in strippedHTML 178 | assert "AIMA book" in someHTML and "AIMA book" in strippedHTML 179 | 180 | 181 | def test_determineInlinks(): 182 | assert set(determineInlinks(pA)) == set(['B', 'C', 'E']) 183 | assert set(determineInlinks(pE)) == set([]) 184 | assert set(determineInlinks(pF)) == set(['E']) 185 | 186 | def test_findOutlinks_wiki(): 187 | testPage = pageDict[pA.address] 188 | outlinks = findOutlinks(testPage, handleURLs=onlyWikipediaURLS) 189 | assert 
"https://en.wikipedia.org/wiki/TestThing" in outlinks 190 | assert "https://en.wikipedia.org/wiki/TestThing" in outlinks 191 | assert "https://google.com.au" not in outlinks 192 | # ______________________________________________________________________________ 193 | # HITS Helper Functions 194 | 195 | 196 | def test_expand_pages(): 197 | pages = {k: pageDict[k] for k in ('F')} 198 | pagesTwo = {k: pageDict[k] for k in ('A', 'E')} 199 | expanded_pages = expand_pages(pages) 200 | assert all(x in expanded_pages for x in ['F', 'E']) 201 | assert all(x not in expanded_pages for x in ['A', 'B', 'C', 'D']) 202 | expanded_pages = expand_pages(pagesTwo) 203 | print(expanded_pages) 204 | assert all(x in expanded_pages for x in ['A', 'B', 'C', 'D', 'E', 'F']) 205 | 206 | 207 | def test_relevant_pages(): 208 | pages = relevant_pages("his dad") 209 | assert all((x in pages) for x in ['A', 'C', 'E']) 210 | assert all((x not in pages) for x in ['B', 'D', 'F']) 211 | pages = relevant_pages("mom and dad") 212 | assert all((x in pages) for x in ['A', 'B', 'C', 'D', 'E', 'F']) 213 | pages = relevant_pages("philosophy") 214 | assert all((x not in pages) for x in ['A', 'B', 'C', 'D', 'E', 'F']) 215 | 216 | 217 | def test_normalize(): 218 | normalize(pageDict) 219 | print(page.hub for addr, page in nlp.pagesIndex.items()) 220 | expected_hub = [1/91**0.5, 2/91**0.5, 3/91**0.5, 4/91**0.5, 5/91**0.5, 6/91**0.5] # Works only for sample data above 221 | expected_auth = list(reversed(expected_hub)) 222 | assert len(expected_hub) == len(expected_auth) == len(nlp.pagesIndex) 223 | assert expected_hub == [page.hub for addr, page in sorted(nlp.pagesIndex.items())] 224 | assert expected_auth == [page.authority for addr, page in sorted(nlp.pagesIndex.items())] 225 | 226 | 227 | def test_detectConvergence(): 228 | # run detectConvergence once to initialise history 229 | convergence = ConvergenceDetector() 230 | convergence() 231 | assert convergence() # values haven't changed so should return True 
232 | # make tiny increase/decrease to all values 233 | for _, page in nlp.pagesIndex.items(): 234 | page.hub += 0.0003 235 | page.authority += 0.0004 236 | # retest function with values. Should still return True 237 | assert convergence() 238 | for _, page in nlp.pagesIndex.items(): 239 | page.hub += 3000000 240 | page.authority += 3000000 241 | # retest function with values. Should now return false 242 | assert not convergence() 243 | 244 | 245 | def test_getInlinks(): 246 | inlnks = getInlinks(pageDict['A']) 247 | assert sorted(inlnks) == pageDict['A'].inlinks 248 | 249 | 250 | def test_getOutlinks(): 251 | outlnks = getOutlinks(pageDict['A']) 252 | assert sorted(outlnks) == pageDict['A'].outlinks 253 | 254 | 255 | def test_HITS(): 256 | HITS('inherit') 257 | auth_list = [pA.authority, pB.authority, pC.authority, pD.authority, pE.authority, pF.authority] 258 | hub_list = [pA.hub, pB.hub, pC.hub, pD.hub, pE.hub, pF.hub] 259 | assert max(auth_list) == pD.authority 260 | assert max(hub_list) == pE.hub 261 | 262 | 263 | if __name__ == '__main__': 264 | pytest.main() 265 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | How to Contribute to aima-python 2 | ========================== 3 | 4 | Thanks for considering contributing to `aima-python`! Whether you are an aspiring [Google Summer of Code](https://summerofcode.withgoogle.com/organizations/5674023002832896/) student, or an independent contributor, here is a guide on how you can help. 5 | 6 | First of all, you can read these write-ups from past GSoC students to get an idea about what you can do for the project. [Chipe1](https://github.com/aimacode/aima-python/issues/641) - [MrDupin](https://github.com/aimacode/aima-python/issues/632) 7 | 8 | In general, the main ways you can contribute to the repository are the following: 9 | 10 | 1. 
Implement algorithms from the [list of algorithms](https://github.com/aimacode/aima-python/blob/master/README.md#index-of-algorithms). 11 | 1. Add tests for algorithms. 12 | 1. Take care of [issues](https://github.com/aimacode/aima-python/issues). 13 | 1. Work on the notebooks (`.ipynb` files). 14 | 1. Add and edit documentation (the docstrings in `.py` files). 15 | 16 | In more detail: 17 | 18 | ## Read the Code and Start on an Issue 19 | 20 | - First, read and understand the code to get a feel for the extent and the style. 21 | - Look at the [issues](https://github.com/aimacode/aima-python/issues) and pick one to work on. 22 | - One of the issues is that some algorithms are missing from the [list of algorithms](https://github.com/aimacode/aima-python/blob/master/README.md#index-of-algorithms) and that some don't have tests. 23 | 24 | ## Port to Python 3; Pythonic Idioms 25 | 26 | - Check for common problems in [porting to Python 3](http://python3porting.com/problems.html), such as: `print` is now a function; `range` and `map` and other functions no longer produce `list`s; objects of different types can no longer be compared with `<`; strings are now Unicode; it would be nice to move `%` string formatting to `.format`; there is a new `next` function for generators; dividing two integers with `/` now returns a float (use `//` for floor division); we can now use set literals. 27 | - Replace old Lisp-based idioms with proper Python idioms. For example, we have many functions that were taken directly from Common Lisp, such as the `every` function: `every(predicate, items)` returns true if `predicate(x)` is true for every element `x` of `items`. This is good Lisp style, but good Python style would be to use `all` and a generator expression: `all(predicate(x) for x in items)`. Eventually, fix all calls to these legacy Lisp functions and then remove the functions. 28 | 29 | ## New and Improved Algorithms 30 | 31 | - Implement functions that were in the third edition of the book but were not yet implemented in the code.
Check the [list of pseudocode algorithms (pdf)](https://github.com/aimacode/pseudocode/blob/master/aima3e-algorithms.pdf) to see what's missing. 32 | - As we finish chapters for the new fourth edition, we will share the new pseudocode in the [`aima-pseudocode`](https://github.com/aimacode/aima-pseudocode) repository, and describe what changes are necessary. 33 | We hope to have an `algorithm-name.md` file for each algorithm, eventually; it would be great if contributors could add some for the existing algorithms. 34 | 35 | ## Jupyter Notebooks 36 | 37 | In this project we use Jupyter/IPython Notebooks to showcase the algorithms in the book. They serve as short tutorials on what the algorithms do, how they are implemented and how one can use them. To install Jupyter, you can follow the instructions [here](https://jupyter.org/install.html). These are some ways you can contribute to the notebooks: 38 | 39 | - Proofread the notebooks for grammar mistakes, typos, or general errors. 40 | - Move visualization code, and any other code not related to the algorithms themselves, from the notebooks to `notebook.py` (a file used to store code for the notebooks, such as visualizations and other miscellaneous helpers). Make sure the notebooks still work and have their outputs showing! 41 | - Replace the `%psource` magic notebook command with the function `psource` from `notebook.py` where needed. Examples where this is useful are a) when we want to show the code of an algorithm's implementation and b) when we have consecutive cells with the magic keyword (in this case, if the code is large, it's best to leave the output hidden). 42 | - Add calls to the function `pseudocode(algorithm_name)` in algorithm sections. The function prints the pseudocode of the algorithm. You can see some example usage in [`knowledge.ipynb`](https://github.com/aimacode/aima-python/blob/master/knowledge.ipynb). 43 | - Edit existing sections for algorithms to add more information and/or examples. 44 | - Add visualizations for algorithms.
The visualization code should go in `notebook.py` to keep things clean. 45 | - Add new sections for algorithms not yet covered. The general format we use in the notebooks is the following: Start with an overview of the algorithm, printing the pseudocode and explaining how it works. Then, add some implementation details, including showing the code (using `psource`). Finally, add examples for the implementations, showing how the algorithms work. Don't fret about adding complex, real-world examples; the project is meant for educational purposes. You can of course choose another format if something better suits an algorithm. 46 | 47 | Apart from the notebooks explaining how the algorithms work, we also have notebooks showcasing some indicative applications of the algorithms. These notebooks follow the `*_apps.ipynb` naming pattern. We aim to have an `apps` notebook for each module, so if you don't see one for the module you would like to contribute to, feel free to create it from scratch! In these notebooks we are looking for applications showing what the algorithms can do. The general format of these sections is this: Add a description of the problem you are trying to solve, then explain how you are going to solve it, and finally provide your solution with examples. Note that any code you write should not require any external libraries apart from the ones already provided (like `matplotlib`). 48 | 49 | # Style Guide 50 | 51 | There are a few style rules that are unique to this project: 52 | 53 | - The first rule is that the code should correspond directly to the pseudocode in the book. When possible this will be almost one-to-one, just allowing for the syntactic differences between Python and pseudocode, and for different library functions. 54 | - Don't make a function more complicated than the pseudocode in the book, even if the complication would add a nice feature, or give an efficiency gain.
Instead, remain faithful to the pseudocode, and if you must, add a new function (not in the book) with the added feature. 55 | - I use functional programming (functions with no side effects) in many cases, but not exclusively (sometimes classes and/or functions with side effects are used). Let the book's pseudocode be the guide. 56 | 57 | Beyond the above rules, we use [PEP 8](https://www.python.org/dev/peps/pep-0008), with a few minor exceptions: 58 | 59 | - I have set `--max-line-length 100`, not 79. 60 | - You don't need two spaces after a sentence-ending period. 61 | - Strunk and White is [not a good guide for English](http://chronicle.com/article/50-Years-of-Stupid-Grammar/25497). 62 | - I prefer more concise docstrings; I don't follow [PEP 257](https://www.python.org/dev/peps/pep-0257/). In most cases, 63 | a one-line docstring suffices. It is rarely necessary to list what each argument does; the name of the argument usually is enough. 64 | - Not all constants have to be UPPERCASE. 65 | - At some point I may add [PEP 484](https://www.python.org/dev/peps/pep-0484/) type annotations, but I think I'll hold off for now; 66 | I want to get more experience with them, and some people may still be on Python 3.4. 67 | 68 | Reporting Issues 69 | ================ 70 | 71 | - Under which versions of Python does this happen? 72 | 73 | - Provide an example of the issue occurring. 74 | 75 | - Is anybody working on this? 76 | 77 | Patch Rules 78 | =========== 79 | 80 | - Ensure that the patch is Python 3.4 compliant. 81 | 82 | - Include tests if your patch is supposed to solve a bug, and explain 83 | clearly under which circumstances the bug happens. Make sure the test fails 84 | without your patch. 85 | 86 | - Follow the style guidelines described above. 87 | 88 | # Choice of Programming Languages 89 | 90 | Are we right to concentrate on Java and Python versions of the code?
I think so; both languages are popular; Java is 91 | fast enough for our purposes, and has reasonable type declarations (but can be verbose); Python is popular and has a very direct mapping to the pseudocode in the book (but lacks type declarations and can be slow). The [TIOBE Index](http://www.tiobe.com/tiobe_index) says the top seven most popular languages, in order, are: 92 | 93 | Java, C, C++, C#, Python, PHP, JavaScript 94 | 95 | So it might be reasonable to also support C++/C# at some point in the future. It might also be reasonable to support a language that combines the terse readability of Python with the type safety and speed of Java; perhaps Go or Julia. I see no reason to support PHP. JavaScript is the language of the browser; it would be nice to have code that runs in the browser without the need for any downloads; this would be in JavaScript or a variant such as TypeScript. 96 | 97 | There is also an `aima-lisp` project; in 1995 when we wrote the first edition of the book, Lisp was the right choice, but today it is less popular (currently #31 on the TIOBE index). 98 | 99 | What languages are instructors recommending for their AI class? To get an approximate idea, I gave the query [\[norvig russell "Modern Approach"\]](https://www.google.com/webhp#q=russell%20norvig%20%22modern%20approach%22%20java) along with the names of various languages and looked at the estimated counts of results on 100 | various dates. However, I don't have much confidence in these figures...
101 | 102 | |Language |2004 |2005 |2007 |2010 |2016 | 103 | |-------- |----: |----: |----: |----: |----: | 104 | |[none](http://www.google.com/search?q=norvig+russell+%22Modern+Approach%22)|8,080|20,100|75,200|150,000|132,000| 105 | |[java](http://www.google.com/search?q=java+norvig+russell+%22Modern+Approach%22)|1,990|4,930|44,200|37,000|50,000| 106 | |[c++](http://www.google.com/search?q=c%2B%2B+norvig+russell+%22Modern+Approach%22)|875|1,820|35,300|105,000|35,000| 107 | |[lisp](http://www.google.com/search?q=lisp+norvig+russell+%22Modern+Approach%22)|844|974|30,100|19,000|14,000| 108 | |[prolog](http://www.google.com/search?q=prolog+norvig+russell+%22Modern+Approach%22)|789|2,010|23,200|17,000|16,000| 109 | |[python](http://www.google.com/search?q=python+norvig+russell+%22Modern+Approach%22)|785|1,240|18,400|11,000|12,000| 110 | -------------------------------------------------------------------------------- /tests/test_text.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | import random 4 | 5 | from text import * 6 | from utils import isclose, open_data 7 | 8 | 9 | 10 | def test_text_models(): 11 | flatland = open_data("EN-text/flatland.txt").read() 12 | wordseq = words(flatland) 13 | P1 = UnigramWordModel(wordseq) 14 | P2 = NgramWordModel(2, wordseq) 15 | P3 = NgramWordModel(3, wordseq) 16 | 17 | # Test top 18 | assert P1.top(5) == [(2081, 'the'), (1479, 'of'), 19 | (1021, 'and'), (1008, 'to'), 20 | (850, 'a')] 21 | 22 | assert P2.top(5) == [(368, ('of', 'the')), (152, ('to', 'the')), 23 | (152, ('in', 'the')), (86, ('of', 'a')), 24 | (80, ('it', 'is'))] 25 | 26 | assert P3.top(5) == [(30, ('a', 'straight', 'line')), 27 | (19, ('of', 'three', 'dimensions')), 28 | (16, ('the', 'sense', 'of')), 29 | (13, ('by', 'the', 'sense')), 30 | (13, ('as', 'well', 'as'))] 31 | 32 | # Test isclose 33 | assert isclose(P1['the'], 0.0611, rel_tol=0.001) 34 | assert isclose(P2['of', 'the'], 0.0108, 
rel_tol=0.01) 35 | assert isclose(P3['so', 'as', 'to'], 0.000323, rel_tol=0.001) 36 | 37 | # Test cond_prob.get 38 | assert P2.cond_prob.get(('went',)) is None 39 | assert P3.cond_prob['in', 'order'].dictionary == {'to': 6} 40 | 41 | # Test dictionary 42 | test_string = 'unigram' 43 | wordseq = words(test_string) 44 | P1 = UnigramWordModel(wordseq) 45 | assert P1.dictionary == {('unigram'): 1} 46 | 47 | test_string = 'bigram text' 48 | wordseq = words(test_string) 49 | P2 = NgramWordModel(2, wordseq) 50 | assert P2.dictionary == {('bigram', 'text'): 1} 51 | 52 | test_string = 'test trigram text here' 53 | wordseq = words(test_string) 54 | P3 = NgramWordModel(3, wordseq) 55 | assert ('test', 'trigram', 'text') in P3.dictionary 56 | assert ('trigram', 'text', 'here') in P3.dictionary 57 | 58 | 59 | def test_char_models(): 60 | test_string = 'test unigram' 61 | wordseq = words(test_string) 62 | P1 = UnigramCharModel(wordseq) 63 | 64 | expected_unigrams = {'n': 1, 's': 1, 'e': 1, 'i': 1, 'm': 1, 'g': 1, 'r': 1, 'a': 1, 't': 2, 'u': 1} 65 | assert len(P1.dictionary) == len(expected_unigrams) 66 | for char in test_string.replace(' ', ''): 67 | assert char in P1.dictionary 68 | 69 | test_string = 'alpha beta' 70 | wordseq = words(test_string) 71 | P1 = NgramCharModel(1, wordseq) 72 | 73 | assert len(P1.dictionary) == len(set(test_string)) 74 | for char in set(test_string): 75 | assert tuple(char) in P1.dictionary 76 | 77 | test_string = 'bigram' 78 | wordseq = words(test_string) 79 | P2 = NgramCharModel(2, wordseq) 80 | 81 | expected_bigrams = {(' ', 'b'): 1, ('b', 'i'): 1, ('i', 'g'): 1, ('g', 'r'): 1, ('r', 'a'): 1, ('a', 'm'): 1} 82 | 83 | assert len(P2.dictionary) == len(expected_bigrams) 84 | for bigram, count in expected_bigrams.items(): 85 | assert bigram in P2.dictionary 86 | assert P2.dictionary[bigram] == count 87 | 88 | test_string = 'bigram bigram' 89 | wordseq = words(test_string) 90 | P2 = NgramCharModel(2, wordseq) 91 | 92 | expected_bigrams = {(' ', 'b'): 
2, ('b', 'i'): 2, ('i', 'g'): 2, ('g', 'r'): 2, ('r', 'a'): 2, ('a', 'm'): 2} 93 | 94 | assert len(P2.dictionary) == len(expected_bigrams) 95 | for bigram, count in expected_bigrams.items(): 96 | assert bigram in P2.dictionary 97 | assert P2.dictionary[bigram] == count 98 | 99 | test_string = 'trigram' 100 | wordseq = words(test_string) 101 | P3 = NgramCharModel(3, wordseq) 102 | expected_trigrams = {(' ', 't', 'r'): 1, ('t', 'r', 'i'): 1, 103 | ('r', 'i', 'g'): 1, ('i', 'g', 'r'): 1, 104 | ('g', 'r', 'a'): 1, ('r', 'a', 'm'): 1} 105 | 106 | assert len(P3.dictionary) == len(expected_trigrams) 107 | for bigram, count in expected_trigrams.items(): 108 | assert bigram in P3.dictionary 109 | assert P3.dictionary[bigram] == count 110 | 111 | test_string = 'trigram trigram trigram' 112 | wordseq = words(test_string) 113 | P3 = NgramCharModel(3, wordseq) 114 | expected_trigrams = {(' ', 't', 'r'): 3, ('t', 'r', 'i'): 3, 115 | ('r', 'i', 'g'): 3, ('i', 'g', 'r'): 3, 116 | ('g', 'r', 'a'): 3, ('r', 'a', 'm'): 3} 117 | 118 | assert len(P3.dictionary) == len(expected_trigrams) 119 | for bigram, count in expected_trigrams.items(): 120 | assert bigram in P3.dictionary 121 | assert P3.dictionary[bigram] == count 122 | 123 | 124 | def test_samples(): 125 | story = open_data("EN-text/flatland.txt").read() 126 | story += open_data("gutenberg.txt").read() 127 | wordseq = words(story) 128 | P1 = UnigramWordModel(wordseq) 129 | P2 = NgramWordModel(2, wordseq) 130 | P3 = NgramWordModel(3, wordseq) 131 | 132 | s1 = P1.samples(10) 133 | s2 = P3.samples(10) 134 | s3 = P3.samples(10) 135 | 136 | assert len(s1.split(' ')) == 10 137 | assert len(s2.split(' ')) == 10 138 | assert len(s3.split(' ')) == 10 139 | 140 | 141 | def test_viterbi_segmentation(): 142 | flatland = open_data("EN-text/flatland.txt").read() 143 | wordseq = words(flatland) 144 | P = UnigramWordModel(wordseq) 145 | text = "itiseasytoreadwordswithoutspaces" 146 | 147 | s, p = viterbi_segment(text, P) 148 | assert s == [ 149 
| 'it', 'is', 'easy', 'to', 'read', 'words', 'without', 'spaces'] 150 | 151 | 152 | def test_shift_encoding(): 153 | code = shift_encode("This is a secret message.", 17) 154 | 155 | assert code == 'Kyzj zj r jvtivk dvjjrxv.' 156 | 157 | 158 | def test_shift_decoding(): 159 | flatland = open_data("EN-text/flatland.txt").read() 160 | ring = ShiftDecoder(flatland) 161 | msg = ring.decode('Kyzj zj r jvtivk dvjjrxv.') 162 | 163 | assert msg == 'This is a secret message.' 164 | 165 | 166 | def test_permutation_decoder(): 167 | gutenberg = open_data("gutenberg.txt").read() 168 | flatland = open_data("EN-text/flatland.txt").read() 169 | 170 | pd = PermutationDecoder(canonicalize(gutenberg)) 171 | assert pd.decode('aba') in ('ece', 'ete', 'tat', 'tit', 'txt') 172 | 173 | pd = PermutationDecoder(canonicalize(flatland)) 174 | assert pd.decode('aba') in ('ded', 'did', 'ece', 'ele', 'eme', 'ere', 'eve', 'eye', 'iti', 'mom', 'ses', 'tat', 'tit') 175 | 176 | 177 | def test_rot13_encoding(): 178 | code = rot13('Hello, world!') 179 | 180 | assert code == 'Uryyb, jbeyq!' 181 | 182 | 183 | def test_rot13_decoding(): 184 | flatland = open_data("EN-text/flatland.txt").read() 185 | ring = ShiftDecoder(flatland) 186 | msg = ring.decode(rot13('Hello, world!')) 187 | 188 | assert msg == 'Hello, world!' 
189 | 190 | 191 | def test_counting_probability_distribution(): 192 | D = CountingProbDist() 193 | 194 | for i in range(10000): 195 | D.add(random.choice('123456')) 196 | 197 | ps = [D[n] for n in '123456'] 198 | 199 | assert 1 / 7 <= min(ps) <= max(ps) <= 1 / 5 200 | 201 | 202 | def test_ir_system(): 203 | from collections import namedtuple 204 | Results = namedtuple('IRResults', ['score', 'url']) 205 | 206 | uc = UnixConsultant() 207 | 208 | def verify_query(query, expected): 209 | assert len(expected) == len(query) 210 | 211 | for expected, (score, d) in zip(expected, query): 212 | doc = uc.documents[d] 213 | assert "{0:.2f}".format( 214 | expected.score) == "{0:.2f}".format(score * 100) 215 | assert os.path.basename(expected.url) == os.path.basename(doc.url) 216 | 217 | return True 218 | 219 | q1 = uc.query("how do I remove a file") 220 | assert verify_query(q1, [ 221 | Results(76.83, "aima-data/MAN/rm.txt"), 222 | Results(67.83, "aima-data/MAN/tar.txt"), 223 | Results(67.79, "aima-data/MAN/cp.txt"), 224 | Results(66.58, "aima-data/MAN/zip.txt"), 225 | Results(64.58, "aima-data/MAN/gzip.txt"), 226 | Results(63.74, "aima-data/MAN/pine.txt"), 227 | Results(62.95, "aima-data/MAN/shred.txt"), 228 | Results(57.46, "aima-data/MAN/pico.txt"), 229 | Results(43.38, "aima-data/MAN/login.txt"), 230 | Results(41.93, "aima-data/MAN/ln.txt"), 231 | ]) 232 | 233 | q2 = uc.query("how do I delete a file") 234 | assert verify_query(q2, [ 235 | Results(75.47, "aima-data/MAN/diff.txt"), 236 | Results(69.12, "aima-data/MAN/pine.txt"), 237 | Results(63.56, "aima-data/MAN/tar.txt"), 238 | Results(60.63, "aima-data/MAN/zip.txt"), 239 | Results(57.46, "aima-data/MAN/pico.txt"), 240 | Results(51.28, "aima-data/MAN/shred.txt"), 241 | Results(26.72, "aima-data/MAN/tr.txt"), 242 | ]) 243 | 244 | q3 = uc.query("email") 245 | assert verify_query(q3, [ 246 | Results(18.39, "aima-data/MAN/pine.txt"), 247 | Results(12.01, "aima-data/MAN/info.txt"), 248 | Results(9.89, "aima-data/MAN/pico.txt"), 
249 | Results(8.73, "aima-data/MAN/grep.txt"), 250 | Results(8.07, "aima-data/MAN/zip.txt"), 251 | ]) 252 | 253 | q4 = uc.query("word count for files") 254 | assert verify_query(q4, [ 255 | Results(128.15, "aima-data/MAN/grep.txt"), 256 | Results(94.20, "aima-data/MAN/find.txt"), 257 | Results(81.71, "aima-data/MAN/du.txt"), 258 | Results(55.45, "aima-data/MAN/ps.txt"), 259 | Results(53.42, "aima-data/MAN/more.txt"), 260 | Results(42.00, "aima-data/MAN/dd.txt"), 261 | Results(12.85, "aima-data/MAN/who.txt"), 262 | ]) 263 | 264 | q5 = uc.query("learn: date") 265 | assert verify_query(q5, []) 266 | 267 | q6 = uc.query("2003") 268 | assert verify_query(q6, [ 269 | Results(14.58, "aima-data/MAN/pine.txt"), 270 | Results(11.62, "aima-data/MAN/jar.txt"), 271 | ]) 272 | 273 | 274 | def test_words(): 275 | assert words("``EGAD!'' Edgar cried.") == ['egad', 'edgar', 'cried'] 276 | 277 | 278 | def test_canonicalize(): 279 | assert canonicalize("``EGAD!'' Edgar cried.") == 'egad edgar cried' 280 | 281 | 282 | def test_translate(): 283 | text = 'orange apple lemon ' 284 | func = lambda x: ('s ' + x) if x ==' ' else x 285 | 286 | assert translate(text, func) == 'oranges apples lemons ' 287 | 288 | 289 | def test_bigrams(): 290 | assert bigrams('this') == ['th', 'hi', 'is'] 291 | assert bigrams(['this', 'is', 'a', 'test']) == [['this', 'is'], ['is', 'a'], ['a', 'test']] 292 | 293 | 294 | 295 | if __name__ == '__main__': 296 | pytest.main() 297 | -------------------------------------------------------------------------------- /rl.py: -------------------------------------------------------------------------------- 1 | """Reinforcement Learning (Chapter 21)""" 2 | 3 | from collections import defaultdict 4 | from utils import argmax 5 | from mdp import MDP, policy_evaluation 6 | 7 | import random 8 | 9 | 10 | class PassiveDUEAgent: 11 | 12 | """Passive (non-learning) agent that uses direct utility estimation 13 | on a given MDP and policy. 
14 | 15 | import sys 16 | from mdp import sequential_decision_environment 17 | north = (0, 1) 18 | south = (0,-1) 19 | west = (-1, 0) 20 | east = (1, 0) 21 | policy = {(0, 2): east, (1, 2): east, (2, 2): east, (3, 2): None, (0, 1): north, (2, 1): north, (3, 1): None, (0, 0): north, (1, 0): west, (2, 0): west, (3, 0): west,} 22 | agent = PassiveDUEAgent(policy, sequential_decision_environment) 23 | for i in range(200): 24 | run_single_trial(agent,sequential_decision_environment) 25 | agent.estimate_U() 26 | agent.U[(0, 0)] > 0.2 27 | True 28 | 29 | """ 30 | def __init__(self, pi, mdp): 31 | self.pi = pi 32 | self.mdp = mdp 33 | self.U = {} # per-state utility estimates, updated by estimate_U() 34 | self.s = None 35 | self.a = None 36 | self.s_history = [] # states perceived in the current trial 37 | self.r_history = [] # rewards perceived in the current trial, aligned with s_history 38 | self.init = mdp.init 39 | 40 | def __call__(self, percept): 41 | s1, r1 = percept 42 | self.s_history.append(s1) 43 | self.r_history.append(r1) 44 | ## 45 | ## 46 | if s1 in self.mdp.terminals: # trial over: a = None, which estimate_U() requires 47 | self.s = self.a = None 48 | else: 49 | self.s, self.a = s1, self.pi[s1] 50 | return self.a 51 | 52 | def estimate_U(self): 53 | # this function can be called only if the MDP has reached a terminal state 54 | # it will also reset the mdp history 55 | assert self.a is None, 'MDP is not in terminal state' 56 | assert len(self.s_history) == len(self.r_history) 57 | # calculating the utilities based on the current iteration: each visit to s at step i contributes sum(r_history[i:]), the undiscounted reward-to-go (NOTE(review): PassiveADPAgent and PassiveTDAgent discount by mdp.gamma; confirm the difference is intended) 58 | U2 = {s : [] for s in set(self.s_history)} 59 | for i in range(len(self.s_history)): 60 | s = self.s_history[i] 61 | U2[s] += [sum(self.r_history[i:])] 62 | U2 = {k : sum(v)/max(len(v), 1) for k, v in U2.items()} # mean reward-to-go per state 63 | # resetting history 64 | self.s_history, self.r_history = [], [] 65 | # setting the new utilities to the average of the previous 66 | # iteration and this one (so each older trial's weight is halved again, rather than all trials being averaged equally) 67 | for k in U2.keys(): 68 | if k in self.U.keys(): 69 | self.U[k] = (self.U[k] + U2[k]) /2 70 | else: 71 | self.U[k] = U2[k] 72 | return self.U 73 | 74 | def update_state(self, percept): 75 | '''To be overridden in most cases.
The default case 76 | assumes the percept to be of type (state, reward)''' 77 | return percept 78 | 79 | 80 | 81 | class PassiveADPAgent: 82 | 83 | """Passive (non-learning) agent that uses adaptive dynamic programming 84 | on a given MDP and policy. [Figure 21.2] 85 | 86 | import sys 87 | from mdp import sequential_decision_environment 88 | north = (0, 1) 89 | south = (0,-1) 90 | west = (-1, 0) 91 | east = (1, 0) 92 | policy = {(0, 2): east, (1, 2): east, (2, 2): east, (3, 2): None, (0, 1): north, (2, 1): north, (3, 1): None, (0, 0): north, (1, 0): west, (2, 0): west, (3, 0): west,} 93 | agent = PassiveADPAgent(policy, sequential_decision_environment) 94 | for i in range(100): 95 | run_single_trial(agent,sequential_decision_environment) 96 | 97 | agent.U[(0, 0)] > 0.2 98 | True 99 | agent.U[(0, 1)] > 0.2 100 | True 101 | """ 102 | 103 | class ModelMDP(MDP): 104 | """ Class for implementing modified version of input MDP with 105 | an editable transition model P and a custom function T. """ 106 | def __init__(self, init, actlist, terminals, gamma, states): 107 | super().__init__(init, actlist, terminals, states=states, gamma=gamma) 108 | nested_dict = lambda: defaultdict(nested_dict) 109 | # StackOverflow:whats-the-best-way-to-initialize-a-dict-of-dicts-in-python 110 | self.P = nested_dict() # P[(s, a)][s1]: estimated probability that action a in state s leads to s1 111 | 112 | def T(self, s, a): 113 | """Return a list of tuples with probabilities for states 114 | based on the learnt model P.""" 115 | return [(prob, res) for (res, prob) in self.P[(s, a)].items()] 116 | 117 | def __init__(self, pi, mdp): 118 | self.pi = pi 119 | self.mdp = PassiveADPAgent.ModelMDP(mdp.init, mdp.actlist, 120 | mdp.terminals, mdp.gamma, mdp.states) 121 | self.U = {} # utility estimates, recomputed by policy_evaluation on every percept 122 | self.Nsa = defaultdict(int) # Nsa[(s, a)]: times action a was taken in state s 123 | self.Ns1_sa = defaultdict(int) # Ns1_sa[(s1, s, a)]: times taking a in s led to s1 124 | self.s = None 125 | self.a = None 126 | self.visited = set() # keeping track of visited states 127 | 128 | def __call__(self, percept): 129 | s1, r1 = percept 130 | mdp = self.mdp 131 | R, P, terminals, pi = mdp.reward,
mdp.P, mdp.terminals, self.pi 132 | s, a, Nsa, Ns1_sa, U = self.s, self.a, self.Nsa, self.Ns1_sa, self.U 133 | # R aliases mdp.reward; a state's reward is written into it the first time the state is perceived 134 | if s1 not in self.visited: # Reward is only known for visited state. 135 | U[s1] = R[s1] = r1 136 | self.visited.add(s1) 137 | if s is not None: 138 | Nsa[(s, a)] += 1 139 | Ns1_sa[(s1, s, a)] += 1 140 | # for each t such that Ns′|sa [t, s, a] is nonzero 141 | for t in [res for (res, state, act), freq in Ns1_sa.items() 142 | if (state, act) == (s, a) and freq != 0]: 143 | P[(s, a)][t] = Ns1_sa[(t, s, a)] / Nsa[(s, a)] 144 | # re-estimate the fixed policy's utilities under the updated model 145 | self.U = policy_evaluation(pi, U, mdp) 146 | ## 147 | # Nsa and Ns1_sa are the same objects as self.Nsa and self.Ns1_sa, so the next line only re-binds them 148 | self.Nsa, self.Ns1_sa = Nsa, Ns1_sa 149 | if s1 in terminals: # trial over: clear s and a and return None 150 | self.s = self.a = None 151 | else: 152 | self.s, self.a = s1, self.pi[s1] 153 | return self.a 154 | 155 | def update_state(self, percept): 156 | """To be overridden in most cases. The default case 157 | assumes the percept to be of type (state, reward).""" 158 | return percept 159 | 160 | 161 | class PassiveTDAgent: 162 | """The abstract class for a Passive (non-learning) agent that uses 163 | temporal differences to learn utility estimates. Override update_state 164 | method to convert percept to state and reward. The mdp being provided 165 | should be an instance of a subclass of the MDP Class.
[Figure 21.4] 166 | 167 | import sys 168 | from mdp import sequential_decision_environment 169 | north = (0, 1) 170 | south = (0,-1) 171 | west = (-1, 0) 172 | east = (1, 0) 173 | policy = {(0, 2): east, (1, 2): east, (2, 2): east, (3, 2): None, (0, 1): north, (2, 1): north, (3, 1): None, (0, 0): north, (1, 0): west, (2, 0): west, (3, 0): west,} 174 | agent = PassiveTDAgent(policy, sequential_decision_environment, alpha=lambda n: 60./(59+n)) 175 | for i in range(200): 176 | run_single_trial(agent,sequential_decision_environment) 177 | 178 | agent.U[(0, 0)] > 0.2 179 | True 180 | agent.U[(0, 1)] > 0.2 181 | True 182 | """ 183 | 184 | def __init__(self, pi, mdp, alpha=None): 185 | # pi maps non-terminal states to actions; alpha(n) gives the learning rate for a state updated n times 186 | self.pi = pi 187 | self.U = {s: 0. for s in mdp.states} 188 | self.Ns = {s: 0 for s in mdp.states} # number of TD updates applied to each state 189 | self.s = None 190 | self.a = None 191 | self.r = None 192 | self.gamma = mdp.gamma 193 | self.terminals = mdp.terminals 194 | 195 | if alpha: 196 | self.alpha = alpha 197 | else: 198 | self.alpha = lambda n: 1/(1+n) # udacity video 199 | 200 | def __call__(self, percept): 201 | s1, r1 = self.update_state(percept) 202 | pi, U, Ns, s, r = self.pi, self.U, self.Ns, self.s, self.r 203 | alpha, gamma, terminals = self.alpha, self.gamma, self.terminals 204 | if not Ns[s1]: # s1 not updated yet: seed its utility with the observed reward 205 | U[s1] = r1 206 | if s is not None: 207 | Ns[s] += 1 208 | U[s] += alpha(Ns[s]) * (r + gamma * U[s1] - U[s]) # TD update toward r + gamma * U[s1] 209 | if s1 in terminals: # trial over: clear s, a, r and return None 210 | self.s = self.a = self.r = None 211 | else: 212 | self.s, self.a, self.r = s1, pi[s1], r1 213 | return self.a 214 | 215 | def update_state(self, percept): 216 | """To be overridden in most cases. The default case 217 | assumes the percept to be of type (state, reward).""" 218 | return percept 219 | 220 | 221 | class QLearningAgent: 222 | """ An exploratory Q-learning agent. It avoids having to learn the transition 223 | model because the Q-value of a state can be related directly to those of 224 | its neighbors.
[Figure 21.8] 225 | 226 | import sys 227 | from mdp import sequential_decision_environment 228 | north = (0, 1) 229 | south = (0,-1) 230 | west = (-1, 0) 231 | east = (1, 0) 232 | policy = {(0, 2): east, (1, 2): east, (2, 2): east, (3, 2): None, (0, 1): north, (2, 1): north, (3, 1): None, (0, 0): north, (1, 0): west, (2, 0): west, (3, 0): west,} 233 | q_agent = QLearningAgent(sequential_decision_environment, Ne=5, Rplus=2, alpha=lambda n: 60./(59+n)) 234 | for i in range(200): 235 | run_single_trial(q_agent,sequential_decision_environment) 236 | 237 | q_agent.Q[((0, 1), (0, 1))] >= -0.5 238 | True 239 | q_agent.Q[((1, 0), (0, -1))] <= 0.5 240 | True 241 | """ 242 | def __init__(self, mdp, Ne, Rplus, alpha=None): 243 | 244 | self.gamma = mdp.gamma 245 | self.terminals = mdp.terminals 246 | self.all_act = mdp.actlist 247 | self.Ne = Ne # iteration limit in exploration function 248 | self.Rplus = Rplus # large value to assign before iteration limit 249 | self.Q = defaultdict(float) 250 | self.Nsa = defaultdict(float) 251 | self.s = None 252 | self.a = None 253 | self.r = None 254 | 255 | if alpha: 256 | self.alpha = alpha 257 | else: 258 | self.alpha = lambda n: 1./(1+n) # udacity video 259 | 260 | def f(self, u, n): 261 | """ Exploration function. Returns fixed Rplus until 262 | agent has visited state, action a Ne number of times. 263 | Same as ADP agent in book.""" 264 | if n < self.Ne: 265 | return self.Rplus 266 | else: 267 | return u 268 | 269 | def actions_in_state(self, state): 270 | """ Return actions possible in given state. 271 | Useful for max and argmax. 
""" 272 | if state in self.terminals: 273 | return [None] 274 | else: 275 | return self.all_act 276 | 277 | def __call__(self, percept): 278 | s1, r1 = self.update_state(percept) 279 | Q, Nsa, s, a, r = self.Q, self.Nsa, self.s, self.a, self.r 280 | alpha, gamma, terminals = self.alpha, self.gamma, self.terminals, 281 | actions_in_state = self.actions_in_state 282 | 283 | if s in terminals: 284 | Q[s, None] = r1 285 | if s is not None: 286 | Nsa[s, a] += 1 287 | Q[s, a] += alpha(Nsa[s, a]) * (r + gamma * max(Q[s1, a1] 288 | for a1 in actions_in_state(s1)) - Q[s, a]) 289 | if s in terminals: 290 | self.s = self.a = self.r = None 291 | else: 292 | self.s, self.r = s1, r1 293 | self.a = argmax(actions_in_state(s1), key=lambda a1: self.f(Q[s1, a1], Nsa[s1, a1])) 294 | return self.a 295 | 296 | def update_state(self, percept): 297 | """To be overridden in most cases. The default case 298 | assumes the percept to be of type (state, reward).""" 299 | return percept 300 | 301 | 302 | def run_single_trial(agent_program, mdp): 303 | """Execute trial for given agent_program 304 | and mdp. mdp should be an instance of subclass 305 | of mdp.MDP """ 306 | 307 | def take_single_action(mdp, s, a): 308 | """ 309 | Select outcome of taking action a 310 | in state s. Weighted Sampling. 
311 | """ 312 | x = random.uniform(0, 1) # pick the first outcome whose cumulative probability exceeds x 313 | cumulative_probability = 0.0 314 | for probability_state in mdp.T(s, a): 315 | probability, state = probability_state 316 | cumulative_probability += probability 317 | if x < cumulative_probability: 318 | break 319 | return state # last outcome if rounding keeps the sum <= x; NOTE(review): unbound (UnboundLocalError) if mdp.T(s, a) is empty 320 | 321 | current_state = mdp.init 322 | while True: 323 | current_reward = mdp.R(current_state) 324 | percept = (current_state, current_reward) 325 | next_action = agent_program(percept) 326 | if next_action is None: # the agent returns None to end the trial 327 | break 328 | current_state = take_single_action(mdp, current_state, next_action) 329 | -------------------------------------------------------------------------------- /tests/test_knowledge.py: -------------------------------------------------------------------------------- 1 | from knowledge import * 2 | from utils import expr 3 | import random 4 | 5 | random.seed("aima-python") 6 | 7 | 8 | 9 | party = [ 10 | {'Pizza': 'Yes', 'Soda': 'No', 'GOAL': True}, 11 | {'Pizza': 'Yes', 'Soda': 'Yes', 'GOAL': True}, 12 | {'Pizza': 'No', 'Soda': 'No', 'GOAL': False} 13 | ] 14 | 15 | animals_umbrellas = [ 16 | {'Species': 'Cat', 'Rain': 'Yes', 'Coat': 'No', 'GOAL': True}, 17 | {'Species': 'Cat', 'Rain': 'Yes', 'Coat': 'Yes', 'GOAL': True}, 18 | {'Species': 'Dog', 'Rain': 'Yes', 'Coat': 'Yes', 'GOAL': True}, 19 | {'Species': 'Dog', 'Rain': 'Yes', 'Coat': 'No', 'GOAL': False}, 20 | {'Species': 'Dog', 'Rain': 'No', 'Coat': 'No', 'GOAL': False}, 21 | {'Species': 'Cat', 'Rain': 'No', 'Coat': 'No', 'GOAL': False}, 22 | {'Species': 'Cat', 'Rain': 'No', 'Coat': 'Yes', 'GOAL': True} 23 | ] 24 | 25 | conductance = [ 26 | {'Sample': 'S1', 'Mass': 12, 'Temp': 26, 'Material': 'Cu', 'Size': 3, 'GOAL': 0.59}, 27 | {'Sample': 'S1', 'Mass': 12, 'Temp': 100, 'Material': 'Cu', 'Size': 3, 'GOAL': 0.57}, 28 | {'Sample': 'S2', 'Mass': 24, 'Temp': 26, 'Material': 'Cu', 'Size': 6, 'GOAL': 0.59}, 29 | {'Sample': 'S3', 'Mass': 12, 'Temp': 26, 'Material': 'Pb', 'Size': 2, 'GOAL': 0.05}, 30 | {'Sample': 'S3',
'Mass': 12, 'Temp': 100, 'Material': 'Pb', 'Size': 2, 'GOAL': 0.04}, 31 | {'Sample': 'S4', 'Mass': 18, 'Temp': 100, 'Material': 'Pb', 'Size': 3, 'GOAL': 0.04}, 32 | {'Sample': 'S4', 'Mass': 18, 'Temp': 100, 'Material': 'Pb', 'Size': 3, 'GOAL': 0.04}, 33 | {'Sample': 'S5', 'Mass': 24, 'Temp': 100, 'Material': 'Pb', 'Size': 4, 'GOAL': 0.04}, 34 | {'Sample': 'S6', 'Mass': 36, 'Temp': 26, 'Material': 'Pb', 'Size': 6, 'GOAL': 0.05}, 35 | ] 36 | 37 | def r_example(Alt, Bar, Fri, Hun, Pat, Price, Rain, Res, Type, Est, GOAL): 38 | return {'Alt': Alt, 'Bar': Bar, 'Fri': Fri, 'Hun': Hun, 'Pat': Pat, 39 | 'Price': Price, 'Rain': Rain, 'Res': Res, 'Type': Type, 'Est': Est, 40 | 'GOAL': GOAL} 41 | 42 | restaurant = [ 43 | r_example('Yes', 'No', 'No', 'Yes', 'Some', '$$$', 'No', 'Yes', 'French', '0-10', True), 44 | r_example('Yes', 'No', 'No', 'Yes', 'Full', '$', 'No', 'No', 'Thai', '30-60', False), 45 | r_example('No', 'Yes', 'No', 'No', 'Some', '$', 'No', 'No', 'Burger', '0-10', True), 46 | r_example('Yes', 'No', 'Yes', 'Yes', 'Full', '$', 'Yes', 'No', 'Thai', '10-30', True), 47 | r_example('Yes', 'No', 'Yes', 'No', 'Full', '$$$', 'No', 'Yes', 'French', '>60', False), 48 | r_example('No', 'Yes', 'No', 'Yes', 'Some', '$$', 'Yes', 'Yes', 'Italian', '0-10', True), 49 | r_example('No', 'Yes', 'No', 'No', 'None', '$', 'Yes', 'No', 'Burger', '0-10', False), 50 | r_example('No', 'No', 'No', 'Yes', 'Some', '$$', 'Yes', 'Yes', 'Thai', '0-10', True), 51 | r_example('No', 'Yes', 'Yes', 'No', 'Full', '$', 'Yes', 'No', 'Burger', '>60', False), 52 | r_example('Yes', 'Yes', 'Yes', 'Yes', 'Full', '$$$', 'No', 'Yes', 'Italian', '10-30', False), 53 | r_example('No', 'No', 'No', 'No', 'None', '$', 'No', 'No', 'Thai', '0-10', False), 54 | r_example('Yes', 'Yes', 'Yes', 'Yes', 'Full', '$', 'No', 'No', 'Burger', '30-60', True) 55 | ] 56 | 57 | 58 | def test_current_best_learning(): 59 | examples = restaurant 60 | hypothesis = [{'Alt': 'Yes'}] 61 | h = current_best_learning(examples, hypothesis) 62 
| values = [] 63 | for e in examples: 64 | values.append(guess_value(e, h)) 65 | 66 | assert values == [True, False, True, True, False, True, False, True, False, False, False, True] 67 | 68 | examples = animals_umbrellas 69 | initial_h = [{'Species': 'Cat'}] 70 | h = current_best_learning(examples, initial_h) 71 | values = [] 72 | for e in examples: 73 | values.append(guess_value(e, h)) 74 | 75 | assert values == [True, True, True, False, False, False, True] 76 | 77 | examples = party 78 | initial_h = [{'Pizza': 'Yes'}] 79 | h = current_best_learning(examples, initial_h) 80 | values = [] 81 | for e in examples: 82 | values.append(guess_value(e, h)) 83 | 84 | assert values == [True, True, False] 85 | 86 | 87 | def test_version_space_learning(): 88 | V = version_space_learning(party) 89 | results = [] 90 | for e in party: 91 | guess = False 92 | for h in V: 93 | if guess_value(e, h): 94 | guess = True 95 | break 96 | 97 | results.append(guess) 98 | 99 | assert results == [True, True, False] 100 | assert [{'Pizza': 'Yes'}] in V 101 | 102 | 103 | def test_minimal_consistent_det(): 104 | assert minimal_consistent_det(party, {'Pizza', 'Soda'}) == {'Pizza'} 105 | assert minimal_consistent_det(party[:2], {'Pizza', 'Soda'}) == set() 106 | assert minimal_consistent_det(animals_umbrellas, {'Species', 'Rain', 'Coat'}) == {'Species', 'Rain', 'Coat'} 107 | assert minimal_consistent_det(conductance, {'Mass', 'Temp', 'Material', 'Size'}) == {'Temp', 'Material'} 108 | assert minimal_consistent_det(conductance, {'Mass', 'Temp', 'Size'}) == {'Mass', 'Temp', 'Size'} 109 | 110 | 111 | A, B, C, D, E, F, G, H, I, x, y, z = map(expr, 'ABCDEFGHIxyz') 112 | 113 | # knowledge base containing family relations 114 | small_family = FOIL_container([expr("Mother(Anne, Peter)"), 115 | expr("Mother(Anne, Zara)"), 116 | expr("Mother(Sarah, Beatrice)"), 117 | expr("Mother(Sarah, Eugenie)"), 118 | expr("Father(Mark, Peter)"), 119 | expr("Father(Mark, Zara)"), 120 | expr("Father(Andrew, Beatrice)"), 
121 | expr("Father(Andrew, Eugenie)"), 122 | expr("Father(Philip, Anne)"), 123 | expr("Father(Philip, Andrew)"), 124 | expr("Mother(Elizabeth, Anne)"), 125 | expr("Mother(Elizabeth, Andrew)"), 126 | expr("Male(Philip)"), 127 | expr("Male(Mark)"), 128 | expr("Male(Andrew)"), 129 | expr("Male(Peter)"), 130 | expr("Female(Elizabeth)"), 131 | expr("Female(Anne)"), 132 | expr("Female(Sarah)"), 133 | expr("Female(Zara)"), 134 | expr("Female(Beatrice)"), 135 | expr("Female(Eugenie)"), 136 | ]) 137 | 138 | smaller_family = FOIL_container([expr("Mother(Anne, Peter)"), 139 | expr("Father(Mark, Peter)"), 140 | expr("Father(Philip, Anne)"), 141 | expr("Mother(Elizabeth, Anne)"), 142 | expr("Male(Philip)"), 143 | expr("Male(Mark)"), 144 | expr("Male(Peter)"), 145 | expr("Female(Elizabeth)"), 146 | expr("Female(Anne)") 147 | ]) 148 | 149 | 150 | # target relation 151 | target = expr('Parent(x, y)') 152 | 153 | #positive examples of target 154 | examples_pos = [{x: expr('Elizabeth'), y: expr('Anne')}, 155 | {x: expr('Elizabeth'), y: expr('Andrew')}, 156 | {x: expr('Philip'), y: expr('Anne')}, 157 | {x: expr('Philip'), y: expr('Andrew')}, 158 | {x: expr('Anne'), y: expr('Peter')}, 159 | {x: expr('Anne'), y: expr('Zara')}, 160 | {x: expr('Mark'), y: expr('Peter')}, 161 | {x: expr('Mark'), y: expr('Zara')}, 162 | {x: expr('Andrew'), y: expr('Beatrice')}, 163 | {x: expr('Andrew'), y: expr('Eugenie')}, 164 | {x: expr('Sarah'), y: expr('Beatrice')}, 165 | {x: expr('Sarah'), y: expr('Eugenie')}] 166 | 167 | # negative examples of target 168 | examples_neg = [{x: expr('Anne'), y: expr('Eugenie')}, 169 | {x: expr('Beatrice'), y: expr('Eugenie')}, 170 | {x: expr('Mark'), y: expr('Elizabeth')}, 171 | {x: expr('Beatrice'), y: expr('Philip')}] 172 | 173 | 174 | 175 | def test_tell(): 176 | """ 177 | adds in the knowledge base a sentence 178 | """ 179 | smaller_family.tell(expr("Male(George)")) 180 | smaller_family.tell(expr("Female(Mum)")) 181 | assert smaller_family.ask(expr("Male(George)")) 
== {} 182 | assert smaller_family.ask(expr("Female(Mum)"))=={} 183 | assert not smaller_family.ask(expr("Female(George)")) 184 | assert not smaller_family.ask(expr("Male(Mum)")) 185 | 186 | def test_extend_example(): 187 | """ 188 | Create the extended examples of the given clause. 189 | (The extended examples are a set of examples created by extending example 190 | with each possible constant value for each new variable in literal.) 191 | """ 192 | assert len(list(small_family.extend_example({x: expr('Andrew')}, expr('Father(x, y)')))) == 2 193 | assert len(list(small_family.extend_example({x: expr('Andrew')}, expr('Mother(x, y)')))) == 0 194 | assert len(list(small_family.extend_example({x: expr('Andrew')}, expr('Female(y)')))) == 6 195 | 196 | 197 | def test_new_literals(): 198 | assert len(list(small_family.new_literals([expr('p'), []]))) == 8 199 | assert len(list(small_family.new_literals([expr('p & q'), []]))) == 20 200 | 201 | def test_new_clause(): 202 | """ 203 | Finds the best clause to add in the set of clauses. 
204 | """ 205 | clause = small_family.new_clause([examples_pos, examples_neg], target)[0][1] 206 | assert len(clause) == 1 and ( clause[0].op in ['Male', 'Female', 'Father', 'Mother' ] ) 207 | 208 | 209 | def test_choose_literal(): 210 | """ 211 | Choose the best literal based on the information gain 212 | """ 213 | literals = [expr('Father(x, y)'), expr('Father(x, y)'), expr('Mother(x, y)'), expr('Mother(x, y)')] 214 | examples_pos = [{x: expr('Philip')}, {x: expr('Mark')}, {x: expr('Peter')}] 215 | examples_neg = [{x: expr('Elizabeth')}, {x: expr('Sarah')}] 216 | assert small_family.choose_literal(literals, [examples_pos, examples_neg]) == expr('Father(x, y)') 217 | literals = [expr('Father(x, y)'), expr('Father(y, x)'), expr('Male(x)')] 218 | examples_pos = [{x: expr('Philip')}, {x: expr('Mark')}, {x: expr('Andrew')}] 219 | examples_neg = [{x: expr('Elizabeth')}, {x: expr('Sarah')}] 220 | assert small_family.choose_literal(literals, [examples_pos, examples_neg]) == expr('Father(x,y)') 221 | 222 | 223 | def test_gain(): 224 | """ 225 | Calculates the utility of each literal, based on the information gained. 226 | """ 227 | gain_father = small_family.gain( expr('Father(x,y)'), [examples_pos, examples_neg] ) 228 | gain_male = small_family.gain(expr('Male(x)'), [examples_pos, examples_neg] ) 229 | assert round(gain_father, 2) == 2.49 230 | assert round(gain_male, 2) == 1.16 231 | 232 | def test_update_examples(): 233 | """Add to the kb those examples what are represented in extended_examples 234 | List of omitted examples is returned. 
235 | """ 236 | extended_examples = [{x: expr("Mark") , y: expr("Peter")}, 237 | {x: expr("Philip"), y: expr("Anne")} ] 238 | 239 | uncovered = smaller_family.update_examples(target, examples_pos, extended_examples) 240 | assert {x: expr("Elizabeth"), y: expr("Anne") } in uncovered 241 | assert {x: expr("Anne"), y: expr("Peter")} in uncovered 242 | assert {x: expr("Philip"), y: expr("Anne") } not in uncovered 243 | assert {x: expr("Mark"), y: expr("Peter")} not in uncovered 244 | 245 | 246 | 247 | def test_foil(): 248 | """ 249 | Test the FOIL algorithm, when target is Parent(x,y) 250 | """ 251 | clauses = small_family.foil([examples_pos, examples_neg], target) 252 | assert len(clauses) == 2 and \ 253 | ((clauses[0][1][0] == expr('Father(x, y)') and clauses[1][1][0] == expr('Mother(x, y)')) or \ 254 | (clauses[1][1][0] == expr('Father(x, y)') and clauses[0][1][0] == expr('Mother(x, y)'))) 255 | 256 | target_g = expr('Grandparent(x, y)') 257 | examples_pos_g = [{x: expr('Elizabeth'), y: expr('Peter')}, 258 | {x: expr('Elizabeth'), y: expr('Zara')}, 259 | {x: expr('Elizabeth'), y: expr('Beatrice')}, 260 | {x: expr('Elizabeth'), y: expr('Eugenie')}, 261 | {x: expr('Philip'), y: expr('Peter')}, 262 | {x: expr('Philip'), y: expr('Zara')}, 263 | {x: expr('Philip'), y: expr('Beatrice')}, 264 | {x: expr('Philip'), y: expr('Eugenie')}] 265 | examples_neg_g = [{x: expr('Anne'), y: expr('Eugenie')}, 266 | {x: expr('Beatrice'), y: expr('Eugenie')}, 267 | {x: expr('Elizabeth'), y: expr('Andrew')}, 268 | {x: expr('Elizabeth'), y: expr('Anne')}, 269 | {x: expr('Elizabeth'), y: expr('Mark')}, 270 | {x: expr('Elizabeth'), y: expr('Sarah')}, 271 | {x: expr('Philip'), y: expr('Anne')}, 272 | {x: expr('Philip'), y: expr('Andrew')}, 273 | {x: expr('Anne'), y: expr('Peter')}, 274 | {x: expr('Anne'), y: expr('Zara')}, 275 | {x: expr('Mark'), y: expr('Peter')}, 276 | {x: expr('Mark'), y: expr('Zara')}, 277 | {x: expr('Andrew'), y: expr('Beatrice')}, 278 | {x: expr('Andrew'), y: 
expr('Eugenie')}, 279 | {x: expr('Sarah'), y: expr('Beatrice')}, 280 | {x: expr('Mark'), y: expr('Elizabeth')}, 281 | {x: expr('Beatrice'), y: expr('Philip')}, 282 | {x: expr('Peter'), y: expr('Andrew')}, 283 | {x: expr('Zara'), y: expr('Mark')}, 284 | {x: expr('Peter'), y: expr('Anne')}, 285 | {x: expr('Zara'), y: expr('Eugenie')}] 286 | 287 | clauses = small_family.foil([examples_pos_g, examples_neg_g], target_g) 288 | assert len(clauses[0]) == 2 289 | assert clauses[0][1][0].op == 'Parent' 290 | assert clauses[0][1][0].args[0] == x 291 | assert clauses[0][1][1].op == 'Parent' 292 | assert clauses[0][1][1].args[1] == y 293 | -------------------------------------------------------------------------------- /tests/test_logic.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from logic import * 3 | from utils import expr_handle_infix_ops, count, Symbol 4 | 5 | definite_clauses_KB = PropDefiniteKB() 6 | for clause in ['(B & F)==>E', '(A & E & F)==>G', '(B & C)==>F', '(A & B)==>D', '(E & F)==>H', '(H & I)==>J', 'A', 'B', 'C']: 7 | definite_clauses_KB.tell(expr(clause)) 8 | 9 | 10 | def test_is_symbol(): 11 | assert is_symbol('x') 12 | assert is_symbol('X') 13 | assert is_symbol('N245') 14 | assert not is_symbol('') 15 | assert not is_symbol('1L') 16 | assert not is_symbol([1, 2, 3]) 17 | 18 | 19 | def test_is_var_symbol(): 20 | assert is_var_symbol('xt') 21 | assert not is_var_symbol('Txt') 22 | assert not is_var_symbol('') 23 | assert not is_var_symbol('52') 24 | 25 | 26 | def test_is_prop_symbol(): 27 | assert not is_prop_symbol('xt') 28 | assert is_prop_symbol('Txt') 29 | assert not is_prop_symbol('') 30 | assert not is_prop_symbol('52') 31 | 32 | 33 | def test_variables(): 34 | assert variables(expr('F(x, x) & G(x, y) & H(y, z) & R(A, z, 2)')) == {x, y, z} 35 | assert variables(expr('(x ==> y) & B(x, y) & A')) == {x, y} 36 | 37 | 38 | def test_expr(): 39 | assert repr(expr('P <=> Q(1)')) == '(P <=> Q(1))' 40 
| assert repr(expr('P & Q | ~R(x, F(x))')) == '((P & Q) | ~R(x, F(x)))' 41 | assert (expr_handle_infix_ops('P & Q ==> R & ~S') 42 | == "P & Q |'==>'| R & ~S") 43 | 44 | 45 | def test_extend(): 46 | assert extend({x: 1}, y, 2) == {x: 1, y: 2} 47 | 48 | 49 | def test_subst(): 50 | assert subst({x: 42, y:0}, F(x) + y) == (F(42) + 0) 51 | 52 | 53 | def test_PropKB(): 54 | kb = PropKB() 55 | assert count(kb.ask(expr) for expr in [A, C, D, E, Q]) is 0 56 | kb.tell(A & E) 57 | assert kb.ask(A) == kb.ask(E) == {} 58 | kb.tell(E |'==>'| C) 59 | assert kb.ask(C) == {} 60 | kb.retract(E) 61 | assert kb.ask(E) is False 62 | assert kb.ask(C) is False 63 | 64 | 65 | def test_wumpus_kb(): 66 | # Statement: There is no pit in [1,1]. 67 | assert wumpus_kb.ask(~P11) == {} 68 | 69 | # Statement: There is no pit in [1,2]. 70 | assert wumpus_kb.ask(~P12) == {} 71 | 72 | # Statement: There is a pit in [2,2]. 73 | assert wumpus_kb.ask(P22) is False 74 | 75 | # Statement: There is a pit in [3,1]. 76 | assert wumpus_kb.ask(P31) is False 77 | 78 | # Statement: Neither [1,2] nor [2,1] contains a pit. 79 | assert wumpus_kb.ask(~P12 & ~P21) == {} 80 | 81 | # Statement: There is a pit in either [2,2] or [3,1]. 
82 | assert wumpus_kb.ask(P22 | P31) == {} 83 | 84 | 85 | def test_is_definite_clause(): 86 | assert is_definite_clause(expr('A & B & C & D ==> E')) 87 | assert is_definite_clause(expr('Farmer(Mac)')) 88 | assert not is_definite_clause(expr('~Farmer(Mac)')) 89 | assert is_definite_clause(expr('(Farmer(f) & Rabbit(r)) ==> Hates(f, r)')) 90 | assert not is_definite_clause(expr('(Farmer(f) & ~Rabbit(r)) ==> Hates(f, r)')) 91 | assert not is_definite_clause(expr('(Farmer(f) | Rabbit(r)) ==> Hates(f, r)')) 92 | 93 | 94 | def test_parse_definite_clause(): 95 | assert parse_definite_clause(expr('A & B & C & D ==> E')) == ([A, B, C, D], E) 96 | assert parse_definite_clause(expr('Farmer(Mac)')) == ([], expr('Farmer(Mac)')) 97 | assert parse_definite_clause(expr('(Farmer(f) & Rabbit(r)) ==> Hates(f, r)')) == ([expr('Farmer(f)'), expr('Rabbit(r)')], expr('Hates(f, r)')) 98 | 99 | 100 | def test_pl_true(): 101 | assert pl_true(P, {}) is None 102 | assert pl_true(P, {P: False}) is False 103 | assert pl_true(P | Q, {P: True}) is True 104 | assert pl_true((A | B) & (C | D), {A: False, B: True, D: True}) is True 105 | assert pl_true((A & B) & (C | D), {A: False, B: True, D: True}) is False 106 | assert pl_true((A & B) | (A & C), {A: False, B: True, C: True}) is False 107 | assert pl_true((A | B) & (C | D), {A: True, D: False}) is None 108 | assert pl_true(P | P, {}) is None 109 | 110 | 111 | def test_tt_true(): 112 | assert tt_true(P | ~P) 113 | assert tt_true('~~P <=> P') 114 | assert not tt_true((P | ~Q) & (~P | Q)) 115 | assert not tt_true(P & ~P) 116 | assert not tt_true(P & Q) 117 | assert tt_true((P | ~Q) | (~P | Q)) 118 | assert tt_true('(A & B) ==> (A | B)') 119 | assert tt_true('((A & B) & C) <=> (A & (B & C))') 120 | assert tt_true('((A | B) | C) <=> (A | (B | C))') 121 | assert tt_true('(A ==> B) <=> (~B ==> ~A)') 122 | assert tt_true('(A ==> B) <=> (~A | B)') 123 | assert tt_true('(A <=> B) <=> ((A ==> B) & (B ==> A))') 124 | assert tt_true('~(A & B) <=> (~A | ~B)') 
125 | assert tt_true('~(A | B) <=> (~A & ~B)') 126 | assert tt_true('(A & (B | C)) <=> ((A & B) | (A & C))') 127 | assert tt_true('(A | (B & C)) <=> ((A | B) & (A | C))') 128 | 129 | 130 | def test_dpll(): 131 | assert (dpll_satisfiable(A & ~B & C & (A | ~D) & (~E | ~D) & (C | ~D) & (~A | ~F) & (E | ~F) 132 | & (~D | ~F) & (B | ~C | D) & (A | ~E | F) & (~A | E | D)) 133 | == {B: False, C: True, A: True, F: False, D: True, E: False}) 134 | assert dpll_satisfiable(A & B & ~C & D) == {C: False, A: True, D: True, B: True} 135 | assert dpll_satisfiable((A | (B & C)) |'<=>'| ((A | B) & (A | C))) == {C: True, A: True} or {C: True, B: True} 136 | assert dpll_satisfiable(A |'<=>'| B) == {A: True, B: True} 137 | assert dpll_satisfiable(A & ~B) == {A: True, B: False} 138 | assert dpll_satisfiable(P & ~P) is False 139 | 140 | 141 | def test_find_pure_symbol(): 142 | assert find_pure_symbol([A, B, C], [A|~B,~B|~C,C|A]) == (A, True) 143 | assert find_pure_symbol([A, B, C], [~A|~B,~B|~C,C|A]) == (B, False) 144 | assert find_pure_symbol([A, B, C], [~A|B,~B|~C,C|A]) == (None, None) 145 | 146 | 147 | def test_unit_clause_assign(): 148 | assert unit_clause_assign(A|B|C, {A:True}) == (None, None) 149 | assert unit_clause_assign(B|C, {A:True}) == (None, None) 150 | assert unit_clause_assign(B|~A, {A:True}) == (B, True) 151 | 152 | 153 | def test_find_unit_clause(): 154 | assert find_unit_clause([A|B|C, B|~C, ~A|~B], {A:True}) == (B, False) 155 | 156 | 157 | def test_unify(): 158 | assert unify(x, x, {}) == {} 159 | assert unify(x, 3, {}) == {x: 3} 160 | assert unify(x & 4 & y, 6 & y & 4, {}) == {x: 6, y: 4} 161 | assert unify(expr('A(x)'), expr('A(B)')) == {x: B} 162 | assert unify(expr('American(x) & Weapon(B)'), expr('American(A) & Weapon(y)')) == {x: A, y: B} 163 | 164 | 165 | def test_pl_fc_entails(): 166 | assert pl_fc_entails(horn_clauses_KB, expr('Q')) 167 | assert pl_fc_entails(definite_clauses_KB, expr('G')) 168 | assert pl_fc_entails(definite_clauses_KB, expr('H')) 169 | 
assert not pl_fc_entails(definite_clauses_KB, expr('I')) 170 | assert not pl_fc_entails(definite_clauses_KB, expr('J')) 171 | assert not pl_fc_entails(horn_clauses_KB, expr('SomethingSilly')) 172 | 173 | 174 | def test_tt_entails(): 175 | assert tt_entails(P & Q, Q) 176 | assert not tt_entails(P | Q, Q) 177 | assert tt_entails(A & (B | C) & E & F & ~(P | Q), A & E & F & ~P & ~Q) 178 | assert not tt_entails(P |'<=>'| Q, Q) 179 | assert tt_entails((P |'==>'| Q) & P, Q) 180 | assert not tt_entails((P |'<=>'| Q) & ~P, Q) 181 | 182 | 183 | def test_prop_symbols(): 184 | assert prop_symbols(expr('x & y & z | A')) == {A} 185 | assert prop_symbols(expr('(x & B(z)) ==> Farmer(y) | A')) == {A, expr('Farmer(y)'), expr('B(z)')} 186 | 187 | 188 | def test_constant_symbols(): 189 | assert constant_symbols(expr('x & y & z | A')) == {A} 190 | assert constant_symbols(expr('(x & B(z)) & Father(John) ==> Farmer(y) | A')) == {A, expr('John')} 191 | 192 | 193 | def test_predicate_symbols(): 194 | assert predicate_symbols(expr('x & y & z | A')) == set() 195 | assert predicate_symbols(expr('(x & B(z)) & Father(John) ==> Farmer(y) | A')) == { 196 | ('B', 1), 197 | ('Father', 1), 198 | ('Farmer', 1)} 199 | assert predicate_symbols(expr('(x & B(x, y, z)) & F(G(x, y), x) ==> P(Q(R(x, y)), x, y, z)')) == { 200 | ('B', 3), 201 | ('F', 2), 202 | ('G', 2), 203 | ('P', 4), 204 | ('Q', 1), 205 | ('R', 2)} 206 | 207 | 208 | def test_eliminate_implications(): 209 | assert repr(eliminate_implications('A ==> (~B <== C)')) == '((~B | ~C) | ~A)' 210 | assert repr(eliminate_implications(A ^ B)) == '((A & ~B) | (~A & B))' 211 | assert repr(eliminate_implications(A & B | C & ~D)) == '((A & B) | (C & ~D))' 212 | 213 | 214 | def test_dissociate(): 215 | assert dissociate('&', [A & B]) == [A, B] 216 | assert dissociate('|', [A, B, C & D, P | Q]) == [A, B, C & D, P, Q] 217 | assert dissociate('&', [A, B, C & D, P | Q]) == [A, B, C, D, P | Q] 218 | 219 | 220 | def test_associate(): 221 | assert 
(repr(associate('&', [(A & B), (B | C), (B & C)])) 222 | == '(A & B & (B | C) & B & C)') 223 | assert (repr(associate('|', [A | (B | (C | (A & B)))])) 224 | == '(A | B | C | (A & B))') 225 | 226 | 227 | def test_move_not_inwards(): 228 | assert repr(move_not_inwards(~(A | B))) == '(~A & ~B)' 229 | assert repr(move_not_inwards(~(A & B))) == '(~A | ~B)' 230 | assert repr(move_not_inwards(~(~(A | ~B) | ~~C))) == '((A | ~B) & ~C)' 231 | 232 | 233 | def test_distribute_and_over_or(): 234 | def test_entailment(s, has_and = False): 235 | result = distribute_and_over_or(s) 236 | if has_and: 237 | assert result.op == '&' 238 | assert tt_entails(s, result) 239 | assert tt_entails(result, s) 240 | test_entailment((A & B) | C, True) 241 | test_entailment((A | B) & C, True) 242 | test_entailment((A | B) | C, False) 243 | test_entailment((A & B) | (C | D), True) 244 | 245 | 246 | def test_to_cnf(): 247 | assert (repr(to_cnf(wumpus_world_inference & ~expr('~P12'))) == 248 | "((~P12 | B11) & (~P21 | B11) & (P12 | P21 | ~B11) & ~B11 & P12)") 249 | assert repr(to_cnf((P & Q) | (~P & ~Q))) == '((~P | P) & (~Q | P) & (~P | Q) & (~Q | Q))' 250 | assert repr(to_cnf('A <=> B')) == '((A | ~B) & (B | ~A))' 251 | assert repr(to_cnf("B <=> (P1 | P2)")) == '((~P1 | B) & (~P2 | B) & (P1 | P2 | ~B))' 252 | assert repr(to_cnf('A <=> (B & C)')) == '((A | ~B | ~C) & (B | ~A) & (C | ~A))' 253 | assert repr(to_cnf("a | (b & c) | d")) == '((b | a | d) & (c | a | d))' 254 | assert repr(to_cnf("A & (B | (D & E))")) == '(A & (D | B) & (E | B))' 255 | assert repr(to_cnf("A | (B | (C | (D & E)))")) == '((D | A | B | C) & (E | A | B | C))' 256 | assert repr(to_cnf('(A <=> ~B) ==> (C | ~D)')) == '((B | ~A | C | ~D) & (A | ~A | C | ~D) & (B | ~B | C | ~D) & (A | ~B | C | ~D))' 257 | 258 | 259 | def test_pl_resolution(): 260 | assert pl_resolution(wumpus_kb, ~P11) 261 | assert pl_resolution(wumpus_kb, ~B11) 262 | assert not pl_resolution(wumpus_kb, P22) 263 | assert pl_resolution(horn_clauses_KB, A) 264 | 
assert pl_resolution(horn_clauses_KB, B) 265 | assert not pl_resolution(horn_clauses_KB, P) 266 | assert not pl_resolution(definite_clauses_KB, P) 267 | 268 | 269 | def test_standardize_variables(): 270 | e = expr('F(a, b, c) & G(c, A, 23)') 271 | assert len(variables(standardize_variables(e))) == 3 272 | # assert variables(e).intersection(variables(standardize_variables(e))) == {} 273 | assert is_variable(standardize_variables(expr('x'))) 274 | 275 | 276 | def test_fol_bc_ask(): 277 | def test_ask(query, kb=None): 278 | q = expr(query) 279 | test_variables = variables(q) 280 | answers = fol_bc_ask(kb or test_kb, q) 281 | return sorted( 282 | [dict((x, v) for x, v in list(a.items()) if x in test_variables) 283 | for a in answers], key=repr) 284 | assert repr(test_ask('Farmer(x)')) == '[{x: Mac}]' 285 | assert repr(test_ask('Human(x)')) == '[{x: Mac}, {x: MrsMac}]' 286 | assert repr(test_ask('Rabbit(x)')) == '[{x: MrsRabbit}, {x: Pete}]' 287 | assert repr(test_ask('Criminal(x)', crime_kb)) == '[{x: West}]' 288 | 289 | 290 | def test_fol_fc_ask(): 291 | def test_ask(query, kb=None): 292 | q = expr(query) 293 | test_variables = variables(q) 294 | answers = fol_fc_ask(kb or test_kb, q) 295 | return sorted( 296 | [dict((x, v) for x, v in list(a.items()) if x in test_variables) 297 | for a in answers], key=repr) 298 | assert repr(test_ask('Criminal(x)', crime_kb)) == '[{x: West}]' 299 | assert repr(test_ask('Enemy(x, America)', crime_kb)) == '[{x: Nono}]' 300 | assert repr(test_ask('Farmer(x)')) == '[{x: Mac}]' 301 | assert repr(test_ask('Human(x)')) == '[{x: Mac}, {x: MrsMac}]' 302 | assert repr(test_ask('Rabbit(x)')) == '[{x: MrsRabbit}, {x: Pete}]' 303 | 304 | 305 | def test_d(): 306 | assert d(x * x - x, x) == 2 * x - 1 307 | 308 | 309 | def test_WalkSAT(): 310 | def check_SAT(clauses, single_solution={}): 311 | # Make sure the solution is correct if it is returned by WalkSat 312 | # Sometimes WalkSat may run out of flips before finding a solution 313 | soln = 
WalkSAT(clauses) 314 | if soln: 315 | assert all(pl_true(x, soln) for x in clauses) 316 | if single_solution: # Cross check the solution if only one exists 317 | assert all(pl_true(x, single_solution) for x in clauses) 318 | assert soln == single_solution 319 | # Test WalkSat for problems with solution 320 | check_SAT([A & B, A & C]) 321 | check_SAT([A | B, P & Q, P & B]) 322 | check_SAT([A & B, C | D, ~(D | P)], {A: True, B: True, C: True, D: False, P: False}) 323 | check_SAT([A, B, ~C, D], {C: False, A: True, B: True, D: True}) 324 | # Test WalkSat for problems without solution 325 | assert WalkSAT([A & ~A], 0.5, 100) is None 326 | assert WalkSAT([A & B, C | D, ~(D | B)], 0.5, 100) is None 327 | assert WalkSAT([A | B, ~A, ~(B | C), C | D, P | Q], 0.5, 100) is None 328 | assert WalkSAT([A | B, B & C, C | D, D & A, P, ~P], 0.5, 100) is None 329 | 330 | 331 | def test_SAT_plan(): 332 | transition = {'A': {'Left': 'A', 'Right': 'B'}, 333 | 'B': {'Left': 'A', 'Right': 'C'}, 334 | 'C': {'Left': 'B', 'Right': 'C'}} 335 | assert SAT_plan('A', transition, 'C', 2) is None 336 | assert SAT_plan('A', transition, 'B', 3) == ['Right'] 337 | assert SAT_plan('C', transition, 'A', 3) == ['Left', 'Left'] 338 | 339 | transition = {(0, 0): {'Right': (0, 1), 'Down': (1, 0)}, 340 | (0, 1): {'Left': (1, 0), 'Down': (1, 1)}, 341 | (1, 0): {'Right': (1, 0), 'Up': (1, 0), 'Left': (1, 0), 'Down': (1, 0)}, 342 | (1, 1): {'Left': (1, 0), 'Up': (0, 1)}} 343 | assert SAT_plan((0, 0), transition, (1, 1), 4) == ['Right', 'Down'] 344 | 345 | 346 | if __name__ == '__main__': 347 | pytest.main() 348 | -------------------------------------------------------------------------------- /knowledge.py: -------------------------------------------------------------------------------- 1 | """Knowledge in learning, Chapter 19""" 2 | 3 | from random import shuffle 4 | from math import log 5 | from utils import powerset 6 | from collections import defaultdict 7 | from itertools import combinations, product 8 | 
from logic import (FolKB, constant_symbols, predicate_symbols, standardize_variables, 9 | variables, is_definite_clause, subst, expr, Expr) 10 | from functools import partial 11 | 12 | # ______________________________________________________________________________ 13 | 14 | 15 | def current_best_learning(examples, h, examples_so_far=None): 16 | """ [Figure 19.2] 17 | The hypothesis is a list of dictionaries, with each dictionary representing 18 | a disjunction.""" 19 | if not examples: 20 | return h 21 | 22 | examples_so_far = examples_so_far or [] 23 | e = examples[0] 24 | if is_consistent(e, h): 25 | return current_best_learning(examples[1:], h, examples_so_far + [e]) 26 | elif false_positive(e, h): 27 | for h2 in specializations(examples_so_far + [e], h): 28 | h3 = current_best_learning(examples[1:], h2, examples_so_far + [e]) 29 | if h3 != 'FAIL': 30 | return h3 31 | elif false_negative(e, h): 32 | for h2 in generalizations(examples_so_far + [e], h): 33 | h3 = current_best_learning(examples[1:], h2, examples_so_far + [e]) 34 | if h3 != 'FAIL': 35 | return h3 36 | 37 | return 'FAIL' 38 | 39 | 40 | def specializations(examples_so_far, h): 41 | """Specialize the hypothesis by adding AND operations to the disjunctions""" 42 | hypotheses = [] 43 | 44 | for i, disj in enumerate(h): 45 | for e in examples_so_far: 46 | for k, v in e.items(): 47 | if k in disj or k == 'GOAL': 48 | continue 49 | 50 | h2 = h[i].copy() 51 | h2[k] = '!' + v 52 | h3 = h.copy() 53 | h3[i] = h2 54 | if check_all_consistency(examples_so_far, h3): 55 | hypotheses.append(h3) 56 | 57 | shuffle(hypotheses) 58 | return hypotheses 59 | 60 | 61 | def generalizations(examples_so_far, h): 62 | """Generalize the hypothesis. First delete operations 63 | (including disjunctions) from the hypothesis. 
Then, add OR operations.""" 64 | hypotheses = [] 65 | 66 | # Delete disjunctions 67 | disj_powerset = powerset(range(len(h))) 68 | for disjs in disj_powerset: 69 | h2 = h.copy() 70 | for d in reversed(list(disjs)): 71 | del h2[d] 72 | 73 | if check_all_consistency(examples_so_far, h2): 74 | hypotheses += h2 75 | 76 | # Delete AND operations in disjunctions 77 | for i, disj in enumerate(h): 78 | a_powerset = powerset(disj.keys()) 79 | for attrs in a_powerset: 80 | h2 = h[i].copy() 81 | for a in attrs: 82 | del h2[a] 83 | 84 | if check_all_consistency(examples_so_far, [h2]): 85 | h3 = h.copy() 86 | h3[i] = h2.copy() 87 | hypotheses += h3 88 | 89 | # Add OR operations 90 | if hypotheses == [] or hypotheses == [{}]: 91 | hypotheses = add_or(examples_so_far, h) 92 | else: 93 | hypotheses.extend(add_or(examples_so_far, h)) 94 | 95 | shuffle(hypotheses) 96 | return hypotheses 97 | 98 | 99 | def add_or(examples_so_far, h): 100 | """Add an OR operation to the hypothesis. The AND operations in the disjunction 101 | are generated by the last example (which is the problematic one).""" 102 | ors = [] 103 | e = examples_so_far[-1] 104 | 105 | attrs = {k: v for k, v in e.items() if k != 'GOAL'} 106 | a_powerset = powerset(attrs.keys()) 107 | 108 | for c in a_powerset: 109 | h2 = {} 110 | for k in c: 111 | h2[k] = attrs[k] 112 | 113 | if check_negative_consistency(examples_so_far, h2): 114 | h3 = h.copy() 115 | h3.append(h2) 116 | ors.append(h3) 117 | 118 | return ors 119 | 120 | # ______________________________________________________________________________ 121 | 122 | 123 | def version_space_learning(examples): 124 | """ [Figure 19.3] 125 | The version space is a list of hypotheses, which in turn are a list 126 | of dictionaries/disjunctions.""" 127 | V = all_hypotheses(examples) 128 | for e in examples: 129 | if V: 130 | V = version_space_update(V, e) 131 | 132 | return V 133 | 134 | 135 | def version_space_update(V, e): 136 | return [h for h in V if is_consistent(e, h)] 137 | 
def all_hypotheses(examples):
    """Build a list of all the possible hypotheses.

    For every attribute subset yielded by ``powerset`` over the attribute names
    of ``values_table(examples)``, the single-disjunction hypotheses from
    ``build_attr_combinations`` are collected. Every combination of those
    (from ``build_h_combinations``) is then appended, so the result holds both
    single- and multi-disjunction hypotheses.
    """
    values = values_table(examples)
    hypotheses = []
    for attrs in powerset(values.keys()):
        hypotheses.extend(build_attr_combinations(attrs, values))

    # The argument is evaluated before extend() runs, so combinations are
    # built only from the single-disjunction hypotheses collected above.
    hypotheses.extend(build_h_combinations(hypotheses))
    return hypotheses


def values_table(examples):
    """Build a table with all the possible values for each attribute.

    Returns a dict mapping each attribute name (every key except 'GOAL') to a
    list of the distinct values seen for it, in first-seen order. A value that
    comes from a negative example (falsy ``e['GOAL']``) is stored with a '!'
    prefix. Assumes attribute values are strings (they are concatenated with
    the prefix).
    """
    values = defaultdict(list)
    for e in examples:
        for k, v in e.items():
            if k == 'GOAL':
                continue
            entry = ('' if e['GOAL'] else '!') + v
            if entry not in values[k]:
                values[k].append(entry)

    return dict(values)


def build_attr_combinations(s, values):
    """Given a sequence of attributes, build all the combinations of values.

    Each returned item is a hypothesis with a single disjunction, i.e. a
    one-element list holding a dict ``{attribute: value, ...}``. If ``s`` holds
    more than one attribute, the combinations are built recursively.

    NOTE(review): for len(s) > 1 the result also contains combinations that
    skip leading attributes (e.g. for (a, b, c) it includes the {b, c}
    combinations). all_hypotheses already iterates the powerset, so this may
    produce duplicates; confirm intent before changing it.
    """
    if len(s) == 1:
        # s holds just one attribute: one single-term disjunction per value
        return [[{s[0]: v}] for v in values[s[0]]]

    h = []
    for i, a in enumerate(s):
        rest = build_attr_combinations(s[i+1:], values)
        for v in values[a]:
            o = {a: v}
            for r in rest:
                t = o.copy()
                for d in r:
                    t.update(d)
                h.append([t])

    return h


def build_h_combinations(hypotheses):
    """Given a list of hypotheses, build and return all their combinations.

    For each index subset produced by ``powerset(range(len(hypotheses)))``, the
    disjunction lists of the chosen hypotheses are concatenated into one new
    hypothesis. The result grows exponentially with ``len(hypotheses)``.
    """
    return [[d for i in subset for d in hypotheses[i]]
            for subset in powerset(range(len(hypotheses)))]

# ______________________________________________________________________________


def minimal_consistent_det(E, A):
    """Return a minimal set of attributes which give a consistent determination.

    Subsets of ``A`` are tried in increasing size (the empty subset first), and
    the first one that ``consistent_det`` accepts for examples ``E`` is
    returned as a set. Returns None if no subset, not even ``A`` itself, is
    consistent.
    """
    for size in range(len(A) + 1):
        for A_i in combinations(A, size):
            if consistent_det(A_i, E):
                return set(A_i)
    return None


def consistent_det(A, E):
    """Check whether the attributes ``A`` are consistent with the examples ``E``.

    Returns False as soon as two examples agree on every attribute in ``A`` but
    have different 'GOAL' values; otherwise returns True.
    """
    H = {}

    for e in E:
        attr_values = tuple(e[attr] for attr in A)
        if attr_values in H and H[attr_values] != e['GOAL']:
            return False
        H[attr_values] = e['GOAL']

    return True

# ______________________________________________________________________________


class FOIL_container(FolKB):
    """Hold the kb and other necessary elements required by FOIL.

    Besides the clauses stored by ``FolKB``, it records the constant symbols
    (``const_syms``) and the ``(predicate, arity)`` pairs (``pred_syms``) of
    every clause told to it. ``new_literals`` draws its candidate predicates
    from ``pred_syms``.
    """

    def __init__(self, clauses=None):
        # These sets must exist before FolKB.__init__ runs, which presumably
        # calls self.tell for each initial clause. Confirm against logic.FolKB.
        self.const_syms = set()
        self.pred_syms = set()
        FolKB.__init__(self, clauses)

    def tell(self, sentence):
        """Add a definite clause to the KB and record its symbols.

        Raises:
            Exception: if ``sentence`` is not a definite clause.
        """
        if not is_definite_clause(sentence):
            raise Exception("Not a definite clause: {}".format(sentence))
        self.clauses.append(sentence)
        self.const_syms.update(constant_symbols(sentence))
        self.pred_syms.update(predicate_symbols(sentence))

    def foil(self, examples, target):
        """Learn a list of first-order Horn clauses for ``target``.

        ``examples`` is a pair ``(positive_examples, negative_examples)``. Both
        are lists of substitutions (dicts from variables to values). Each
        returned clause has the form ``[target, [antecedent literals]]``.
        Positive examples covered by a learned clause are told to the KB
        through ``update_examples``.
        """
        clauses = []
        pos_examples, neg_examples = examples[0], examples[1]

        # NOTE(review): nothing bounds this loop. It relies on each new clause
        # covering at least one of the remaining positive examples.
        while pos_examples:
            clause, extended_pos_examples = self.new_clause((pos_examples, neg_examples), target)
            # keep only the positive examples the new clause does not cover
            pos_examples = self.update_examples(target, pos_examples, extended_pos_examples)
            clauses.append(clause)

        return clauses

    def new_clause(self, examples, target):
        """Find a Horn clause that covers some positive examples and no negative ones.

        The clause is ``[consequent, list of antecedents]``. Literals are added
        greedily (see ``choose_literal``) until no negative binding remains.
        Returns the tuple ``(horn_clause, extended_positive_examples)``.
        """
        clause = [target, []]
        # [positive_bindings, negative_bindings] for the clause built so far
        extended_examples = examples
        while extended_examples[1]:
            literal = self.choose_literal(self.new_literals(clause), extended_examples)
            clause[1].append(literal)
            extended_examples = [self._extend_all(extended_examples[0], literal),
                                 self._extend_all(extended_examples[1], literal)]

        return (clause, extended_examples[0])

    def extend_example(self, example, literal):
        """Yield extended examples (bindings) that satisfy ``literal``.

        Every substitution produced by ``ask_generator`` for ``literal`` with
        ``example`` substituted in is updated with ``example`` (the example's
        own bindings win on conflicts) and then yielded.
        """
        for s in self.ask_generator(subst(example, literal)):
            s.update(example)
            yield s

    def _extend_all(self, examples, literal):
        """Return, in order, every extended binding of every example for ``literal``."""
        return [s for example in examples for s in self.extend_example(example, literal)]

    @staticmethod
    def _covers(example, bindings):
        """Return True if some binding agrees with ``example`` on every key of ``example``."""
        return any(all(d[x] == example[x] for x in example) for d in bindings)

    def new_literals(self, clause):
        """Generate new candidate literals from the known predicate symbols.

        A generated literal must share at least one variable with the clause
        (its head or any body literal). Literals already in the clause body
        are not generated again.
        """
        share_vars = variables(clause[0])
        for l in clause[1]:
            share_vars.update(variables(l))
        for pred, arity in self.pred_syms:
            new_vars = {standardize_variables(expr('x')) for _ in range(arity - 1)}
            for args in product(share_vars.union(new_vars), repeat=arity):
                if any(var in share_vars for var in args):
                    # Build the literal with args unpacked. The previous check
                    # used Expr(pred, args), which never equals a body literal.
                    literal = Expr(pred, *args)
                    # make sure we don't return an existing rule
                    if literal not in clause[1]:
                        yield literal

    def choose_literal(self, literals, examples):
        """Choose the literal with the highest information gain (see ``gain``)."""
        return max(literals, key=partial(self.gain, examples=examples))

    def gain(self, l, examples):
        """Return the FOIL information gain of adding literal ``l`` to the clause body.

        ``examples`` is ``[positive_bindings, negative_bindings]`` for the
        current clause R. The utility function is:

            gain(R, l) = T * (log_2(post_pos / (post_pos + post_neg))
                              - log_2(pre_pos / (pre_pos + pre_neg)))

        where:
            pre_pos  = number of positive bindings of rule R (current clause)
            pre_neg  = number of negative bindings of rule R
            post_pos = number of positive bindings of rule R' (R with l added)
            post_neg = number of negative bindings of rule R'
            T        = number of positive bindings of R that are still covered
                       after adding literal l

        An epsilon of 1e-12 keeps the first logarithm finite when post_pos is 0.
        Returns -1 when R has no positive bindings or R' has no bindings at
        all, since the logarithms are undefined in those cases.
        """
        pre_pos = len(examples[0])
        pre_neg = len(examples[1])
        post_pos = self._extend_all(examples[0], l)
        post_neg = self._extend_all(examples[1], l)
        # pre_pos == 0 would make log_2(pre_pos / ...) a math domain error
        if pre_pos == 0 or len(post_pos) + len(post_neg) == 0:
            return -1
        # number of positive examples that are represented in the extended examples
        T = sum(1 for example in examples[0] if self._covers(example, post_pos))
        value = T * (log(len(post_pos) / (len(post_pos) + len(post_neg)) + 1e-12, 2)
                     - log(pre_pos / (pre_pos + pre_neg), 2))
        return value

    def update_examples(self, target, examples, extended_examples):
        """Tell the KB every example that is represented in ``extended_examples``.

        Each covered example is added as ``subst(example, target)``. Returns
        the list of examples that are not covered.
        """
        uncovered = []
        for example in examples:
            if self._covers(example, extended_examples):
                self.tell(subst(example, target))
            else:
                uncovered.append(example)

        return uncovered


# ______________________________________________________________________________


def check_all_consistency(examples, h):
    """Return True if every example is consistent under hypothesis ``h``."""
    return all(is_consistent(e, h) for e in examples)


def check_negative_consistency(examples, h):
    """Return True if every negative example is consistent under the disjunction ``h``.

    ``h`` is a single disjunction (dict), so it is wrapped in a list to form a
    one-disjunction hypothesis. Positive examples (truthy 'GOAL') are ignored.
    """
    return all(is_consistent(e, [h]) for e in examples if not e['GOAL'])


def disjunction_value(e, d):
    """Return the value of example ``e`` under the disjunction ``d``.

    Every ``attribute: value`` entry of ``d`` must hold. A value starting with
    '!' is a negation: ``e[attribute]`` must differ from the rest of the string.
    """
    for k, v in d.items():
        # v[:1] (rather than v[0]) does not raise on an empty string
        if v[:1] == '!':
            # v is a NOT expression, so e[k] must not equal v without the '!'
            if e[k] == v[1:]:
                return False
        elif e[k] != v:
            return False

    return True


def guess_value(e, h):
    """Guess the value of example ``e`` under hypothesis ``h`` (True if any disjunction holds)."""
    return any(disjunction_value(e, d) for d in h)


def is_consistent(e, h):
    """Return True if hypothesis ``h`` predicts the example's 'GOAL' value."""
    return e["GOAL"] == guess_value(e, h)


def false_positive(e, h):
    """Return a truthy value if ``h`` predicts True for an example whose 'GOAL' is falsy."""
    return guess_value(e, h) and not e["GOAL"]


def false_negative(e, h):
    """Return a truthy value if ``h`` predicts False for an example whose 'GOAL' is truthy."""
    return e["GOAL"] and not guess_value(e, h)