├── CONTRIBUTORS ├── LICENSE ├── README.md ├── Setup.hs ├── TODO.md ├── aima-haskell.cabal ├── data ├── problems_large.txt └── problems_small.txt ├── profiling ├── LikelihoodWeighting │ ├── cleanup.sh │ ├── likelihoodWeighting.hp │ ├── likelihoodWeighting.hs │ ├── likelihoodWeighting.prof │ ├── likelihoodWeighting.ps │ └── run.sh └── Restaurants │ ├── Restaurants.hs │ └── run.sh └── src └── AI ├── Core └── Agents.hs ├── Learning ├── Bootstrap.hs ├── Core.hs ├── CrossValidation.hs ├── DecisionTree.hs ├── Example │ ├── Restaurant.hs │ └── Students.hs ├── LinearRegression.hs ├── LogisticRegression.hs ├── NeuralNetwork.hs ├── Perceptron.hs └── RandomForest.hs ├── Logic ├── Core.hs ├── FOL.hs ├── Interactive.hs └── Propositional.hs ├── Probability ├── Bayes.hs ├── Example │ ├── Alarm.hs │ └── Grass.hs └── MDP.hs ├── Search ├── Adversarial.hs ├── CSP.hs ├── Core.hs ├── Example │ ├── Chess.hs │ ├── Connect4.hs │ ├── Fig52Game.hs │ ├── Graph.hs │ ├── MapColoring.hs │ ├── NQueens.hs │ ├── Sudoku.hs │ └── TicTacToe.hs ├── Informed.hs ├── Local.hs └── Uninformed.hs ├── Test ├── Learning │ └── LinearRegression.hs ├── Main.hs └── Util.hs └── Util ├── Array.hs ├── Graph.hs ├── Matrix.hs ├── ProbDist.hs ├── Queue.hs ├── Table.hs ├── Util.hs └── WeightedGraph.hs /CONTRIBUTORS: -------------------------------------------------------------------------------- 1 | Chris Taylor https://github.com/chris-taylor 2 | mhuesch https://github.com/mhuesch -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE 2 | Version 2, December 2004 3 | 4 | Copyright (C) 2004 Sam Hocevar 5 | 6 | Everyone is permitted to copy and distribute verbatim or modified 7 | copies of this license document, and changing it is allowed as long 8 | as the name is changed. 9 | 10 | DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE 11 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # aima-haskell 2 | 3 | Algorithms from *Artificial Intelligence: A Modern Approach* by Russell and Norvig. 4 | 5 | ## Part I. Artificial Intelligence 6 | 7 | ### 2. Intelligent Agents 8 | 9 | - Environment (Fig 2.1) 10 | - Agent (Fig 2.1) 11 | 12 | ## Part II. Problem Solving 13 | 14 | ### 3. Searching 15 | 16 | Completed: 17 | 18 | - Problem 19 | - Node 20 | - Tree Search (Fig 3.7) 21 | - Graph Search (Fig 3.7) 22 | - Breadth First Search (Fig 3.11) 23 | - Uniform Cost Search (Fig 3.14) 24 | - Depth First Search 25 | - Depth-Limited Search (Fig 3.17) 26 | - Iterative Deepening Search (Fig 3.18) 27 | - Greedy Best First Search 28 | - A* Search 29 | 30 | To do: 31 | 32 | - Recursive Best First Search (Fig 3.26) 33 | - Iterative-Deepening A* 34 | - Memory-Bounded A* (MA*) 35 | - Simplified MA* 36 | - Bidirectional Search 37 | - Eight Puzzle 38 | 39 | ### 4. Beyond Classical Search 40 | 41 | Completed: 42 | 43 | - Hill-Climbing (Fig 4.2) 44 | - Simulated Annealing (Fig 4.5) 45 | 46 | To do: 47 | 48 | - Genetic Algorithm (Fig 4.8) 49 | - And/Or Graph Search (Fig 4.11) 50 | - Online Depth First Search (Fig 4.21) 51 | - LRTA* (Fig 4.24) 52 | 53 | ### 5. 
Adversarial Search 54 | 55 | Completed: 56 | 57 | - Minimax Search (Fig 5.3) 58 | - Alpha-Beta Search (Fig 5.7) 59 | - Searching with cutoff 60 | 61 | To do: 62 | 63 | - Stochastic games 64 | 65 | ### 6. Constraint Satisfaction Problems 66 | 67 | Completed: 68 | 69 | - AC3 (Fig 6.3) 70 | - Backtracking Search (Fig 6.5) 71 | 72 | To do: 73 | 74 | - Min Conflicts (Fig 6.8) 75 | - Tree CSP Solver (Fig 6.11) 76 | 77 | ## Part III. Knowledge, Reasoning and Planning 78 | 79 | ### 7. Logical Agents 80 | 81 | Completed: 82 | 83 | - TT-Entails (Fig 7.10) 84 | - PL-Resolution (Fig 7.12) 85 | - PL-FC-Entails (Fig 7.15) 86 | 87 | To do: 88 | 89 | - DPLL-Satisfiable (Fig 7.17) 90 | - WalkSAT (Fig 7.18) 91 | - Wumpus World 92 | 93 | ### 8-9. First-Order Logic 94 | 95 | Completed: 96 | 97 | - Unify (Fig 9.1) 98 | - FOL-FC-Ask (Fig 9.3) 99 | 100 | To do: 101 | 102 | - FOL-BC-Ask (Fig 9.6) 103 | 104 | ### 10. Classical Planning 105 | 106 | ### 11. Planning and Acting in the Real World 107 | 108 | ### 12. Knowledge Representation 109 | 110 | ## Part IV. Uncertain Knowledge and Reasoning 111 | 112 | ### 14. Probabilistic Reasoning 113 | 114 | Completed: 115 | 116 | - Enumeration-Ask (Fig 14.9) 117 | - Elimination-Ask (Fig 14.11) 118 | - Prior-Sample (Fig 14.13) 119 | - Rejection-Sampling (Fig 14.14) 120 | - Likelihood-Weighting (Fig 14.15) 121 | 122 | To do: 123 | 124 | - Gibbs-Ask (Fig 14.16) 125 | - Fit Bayes Networks from data 126 | 127 | ### 15. Probabilistic Reasoning Over Time 128 | 129 | To do: 130 | 131 | - Kalman Filter 132 | - Particle Filter (Fig 15.17) 133 | 134 | ### 16/17. Making Complex Decisions 135 | 136 | Completed: 137 | 138 | - Value Iteration (Fig 17.4) 139 | - Policy Iteration (Fig 17.7) 140 | 141 | To do: 142 | 143 | - POMDP Value Iteration (Fig 17.9) 144 | 145 | ### 18. Learning from Examples 146 | 147 | Completed: 148 | 149 | - Decision Tree Learning (Fig 18.5) 150 | - Cross-Validation (Fig 18.8) 151 | - Linear regression 152 | - Logistic regression 153 | 154 | To do: 155 | 156 | - Decision List Learning (Fig 18.11) 157 | - Artificial Neural Networks 158 | - Back Prop Learning (Fig 18.24) 159 | - Nearest Neighbour 160 | - Nonparametric Regression 161 | - Regression Trees 162 | - Support Vector Machines 163 | - AdaBoost (Fig 18.34) 164 | 165 | ### 20. Statistical Learning 166 | 167 | To do: 168 | 169 | - Naive Bayes 170 | 171 | ### 21. Reinforcement Learning 172 | 173 | To do: 174 | 175 | - TD-Learning 176 | - Q-Learning 177 | - SARSA 178 | -------------------------------------------------------------------------------- /Setup.hs: -------------------------------------------------------------------------------- 1 | --#! /usr/bin/env runhaskell 2 | import Distribution.Simple 3 | main = defaultMain -------------------------------------------------------------------------------- /TODO.md: -------------------------------------------------------------------------------- 1 | # AIMA Haskell ToDo list 2 | 3 | ## General 4 | 5 | - Complete reorganization. Do we want to separate: 6 | - core modules 7 | - interactivity (e.g. playing games) 8 | - research (e.g. collecting statistics) 9 | - examples 10 | - Improve Haddock documentation, including section headings. 11 | - Use either System.TimeIt or Criterion for timing calculations. 12 | - Keep making improvements to ProbDist module (or use the PFP module?) 13 | 14 | ## Utils 15 | 16 | - Extensions to table-generating code. For example: 17 | - Write tables to arbitrary handles. 18 | - More customisable table layout. 19 | - Use an external library? 
20 | - Improve queueing module to use a more efficient data structure for priority queues.
21 | - Rewrite the Util.Array module to actually use arrays!
22 | - Organize Utils module into subsections.
23 | 
24 | ## Search
25 | 
26 | - Informed/uninformed search
27 | - More statistics, e.g. time taken.
28 | - Is it possible to structure backtracking search as a monad?
29 | - Round out the search functions by including e.g. depth-limited graph search.
30 | - Fill in missing search functions from AIMA
31 | - Round effective branching factor to a sensible number of decimal places.
32 | - Local search
33 | - Make changes to n-queens problem so that it can be used with simulated annealing search.
34 | - Include genetic algorithm
35 | - Adversarial Search
36 | - Alpha-beta search that orders nodes according to some heuristic before searching.
37 | - Stochastic games (using probability monad?)
38 | - Figure out why alpha/beta search sometimes makes stupid decisions while playing Connect 4.
39 | - Constraint Satisfaction
40 | - Wrapper for CSP class to allow statistics to be collected.
41 | - Examples - word puzzles, scheduling, n queens?
42 | - Examples
43 | - Use GraphViz to draw graph problems
44 | - Interactive stepping through a graph problem - highlight nodes as they are explored, keep track of running cost etc.
45 | - GUI for graph problems?
46 | - GUI for interactive playing of tic-tac-toe/connect 4?
47 | - Finish chess example
48 | - Other games, e.g. checkers?
49 | 
50 | ## Logic
51 | 
52 | - Propositional logic:
53 | - Truth-table SAT
54 | - Local search for SAT
55 | - Backward chaining
56 | - First-order logic:
57 | - Backward chaining
58 | - Reduction to normal form
59 | 
60 | ## Probability
61 | 
62 | - Flesh out functionality for MDPs.
63 | - Partially observed MDPs
64 | - Bayes Net uses lists rather than arrays to store the conditional probability table. This is probably inefficient - profile it and check!
65 | - Markov chain (Gibbs sampling) routines for Bayes Net
66 | - Function to compute children of a node in a Bayes Net
67 | - Function to compute Markov blanket of a node in a Bayes Net
68 | 
69 | ## Learning
70 | 
71 | - More functions to auto-fit decision trees
72 | - Decision tree demos are really slow - can they be optimized? In particular, computing the entropy for splits takes a long time.
73 | - Cross validation should return average validation set error rate
74 | - Handle continuous attributes
75 | - Compute precision, recall and F-statistic
76 | - Function to compare multiple classifiers
77 | - More examples - test random forest vs. 
pruned decision trees 78 | - Tests for linear/logistic regression and regularized regression 79 | - LASSO 80 | - Linear classifiers 81 | - Naive Bayes -------------------------------------------------------------------------------- /aima-haskell.cabal: -------------------------------------------------------------------------------- 1 | Name: aima-haskell 2 | Version: 0.1 3 | Synopsis: Artificial Intelligence: A Modern Approach 4 | Description: Implementation of algorithms in Russell and Norvig's AIMA 5 | License: OtherLicense 6 | License-File: LICENSE 7 | Author: Chris Taylor 8 | Maintainer: Chris Taylor 9 | Stability: Experimental 10 | Homepage: https://github.com/chris-taylor/aima-haskell 11 | Cabal-Version: >= 1.8 12 | Build-Type: Simple 13 | Category: Language 14 | 15 | Extra-Source-Files: README.md 16 | LICENSE 17 | 18 | Data-Files: data/problems_small.txt 19 | data/problems_large.txt 20 | 21 | Source-Repository head 22 | Type: git 23 | Location: git://github.com/chris-taylor/aima-haskell.git 24 | 25 | Library 26 | Extensions: MultiParamTypeClasses 27 | Hs-Source-Dirs: src 28 | GHC-Options: -O2 29 | GHC-Prof-Options: -O2 -auto-all 30 | Build-Depends: base >= 2, 31 | containers, 32 | random, 33 | stm, 34 | deepseq, 35 | mtl, 36 | text, 37 | parsec, 38 | MonadRandom, 39 | array, 40 | gnuplot, 41 | hmatrix, 42 | QuickCheck 43 | Exposed-Modules: AI.Core.Agents 44 | AI.Search.Core 45 | AI.Search.Uninformed 46 | AI.Search.Informed 47 | AI.Search.Local 48 | AI.Search.Adversarial 49 | AI.Search.CSP 50 | AI.Search.Example.Chess 51 | AI.Search.Example.Connect4 52 | AI.Search.Example.Fig52Game 53 | AI.Search.Example.Graph 54 | AI.Search.Example.MapColoring 55 | AI.Search.Example.NQueens 56 | AI.Search.Example.Sudoku 57 | AI.Search.Example.TicTacToe 58 | AI.Logic.Core 59 | AI.Logic.Interactive 60 | AI.Logic.Propositional 61 | AI.Logic.FOL 62 | AI.Probability.Bayes 63 | AI.Probability.MDP 64 | AI.Probability.Example.Alarm 65 | AI.Probability.Example.Grass 66 | AI.Learning.Bootstrap 67 | AI.Learning.Core 68 | AI.Learning.CrossValidation 69 | AI.Learning.DecisionTree 70 | AI.Learning.LinearRegression 71 | AI.Learning.LogisticRegression 72 | AI.Learning.NeuralNetwork 73 | AI.Learning.Perceptron 74 | AI.Learning.RandomForest 75 | AI.Learning.Example.Students 76 | AI.Learning.Example.Restaurant 77 | AI.Test.Main 78 | AI.Test.Learning.LinearRegression 79 | AI.Test.Util 80 | AI.Util.Array 81 | AI.Util.Graph 82 | AI.Util.Matrix 83 | AI.Util.ProbDist 84 | AI.Util.Queue 85 | AI.Util.Table 86 | AI.Util.Util 87 | AI.Util.WeightedGraph 88 | 89 | 90 | -------------------------------------------------------------------------------- /profiling/LikelihoodWeighting/cleanup.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | rm *.hi *.o *.aux 4 | rm likelihoodWeighting -------------------------------------------------------------------------------- /profiling/LikelihoodWeighting/likelihoodWeighting.hp: -------------------------------------------------------------------------------- 1 | JOB "likelihoodWeighting.exe +RTS -K100M -s -p -hy" 2 | DATE "Tue Jul 24 11:18 2012" 3 | SAMPLE_UNIT "seconds" 4 | VALUE_UNIT "bytes" 5 | BEGIN_SAMPLE 0.00 6 | END_SAMPLE 0.00 7 | BEGIN_SAMPLE 0.07 8 | ThreadId 8 9 | IO 20 10 | IO 20 11 | IO 12 12 | String 12 13 | Integer 16 14 | BA 8 15 | TextEncoding 16 16 | BayesNet 12 17 | BufferCodec 24 18 | BufferList 12 19 | MUT_VAR_DIRTY 8 20 | IO 8 21 | ->IO 8 22 | MVAR 16 23 | ->(#,#) 8 24 | ->(#,#) 8 25 | ->(#,#) 8 26 | 
->(#,#) 8 27 | ->(#,#) 8 28 | (,) 24 29 | Handle 12 30 | Word32 8 31 | ->IO 24 32 | WEAK 48 33 | ->[] 8 34 | ->>IO 32 35 | STRef 32 36 | ->>IO 8 37 | ->>IO 16 38 | ForeignPtr 12 39 | Buffer 84 40 | BufferState 12 41 | ForeignPtrContents 36 42 | Maybe 16 43 | ->(#,#) 8 44 | Node 60 45 | [] 1128 46 | MUT_VAR_CLEAN 64 47 | Handle__ 68 48 | ->IO 8 49 | Int32 24 50 | [] 20 51 | StdGen 12 52 | ->* 20 53 | Integer 12 54 | BLACKHOLE 32 55 | ARR_WORDS 33052 56 | Map 138852 57 | Float 39680 58 | Map 476232 59 | END_SAMPLE 0.07 60 | BEGIN_SAMPLE 0.18 61 | ThreadId 8 62 | IO 20 63 | BufferCodec 24 64 | BufferList 12 65 | MUT_VAR_DIRTY 8 66 | IO 8 67 | ->IO 8 68 | MVAR 16 69 | ->(#,#) 8 70 | ->(#,#) 8 71 | ->(#,#) 8 72 | ->(#,#) 8 73 | ->(#,#) 8 74 | (,) 24 75 | Handle 12 76 | Word32 8 77 | ->IO 24 78 | WEAK 48 79 | ->[] 8 80 | ->>IO 32 81 | ->>IO 8 82 | ->>IO 16 83 | ForeignPtr 12 84 | Buffer 84 85 | BufferState 12 86 | ForeignPtrContents 36 87 | Maybe 16 88 | ->(#,#) 8 89 | Node 60 90 | MUT_VAR_CLEAN 64 91 | Handle__ 68 92 | ->IO 8 93 | IO 20 94 | IO 12 95 | String 12 96 | Integer 16 97 | BA 8 98 | [] 1128 99 | TextEncoding 16 100 | STRef 32 101 | BayesNet 12 102 | Int32 24 103 | [] 20 104 | StdGen 12 105 | ->* 20 106 | Integer 12 107 | BLACKHOLE 32 108 | ARR_WORDS 33052 109 | Map 321188 110 | Float 91776 111 | Map 1101384 112 | END_SAMPLE 0.18 113 | BEGIN_SAMPLE 0.28 114 | IO 20 115 | IO 12 116 | String 12 117 | BA 8 118 | TextEncoding 16 119 | Integer 16 120 | BayesNet 12 121 | IO 20 122 | ThreadId 8 123 | ->[] 8 124 | ->>IO 32 125 | MUT_VAR_DIRTY 8 126 | STRef 32 127 | ->>IO 8 128 | ->>IO 16 129 | ForeignPtr 12 130 | Handle__ 68 131 | ->IO 8 132 | IO 8 133 | ->IO 8 134 | MVAR 16 135 | ->(#,#) 8 136 | ->(#,#) 8 137 | ->(#,#) 8 138 | ->(#,#) 8 139 | ->(#,#) 8 140 | Handle 12 141 | Word32 8 142 | ->IO 24 143 | WEAK 48 144 | Node 60 145 | BufferState 12 146 | ForeignPtrContents 36 147 | BufferCodec 24 148 | BufferList 12 149 | (,) 24 150 | Buffer 84 151 | Maybe 16 152 | ->(#,#) 8 153 | [] 1128 154 | MUT_VAR_CLEAN 64 155 | [] 20 156 | ->* 20 157 | ->>>>(#,#) 28 158 | BLACKHOLE 24 159 | ARR_WORDS 33036 160 | Int32 24 161 | Float 143216 162 | StdGen 12 163 | Map 1718736 164 | Map 501228 165 | END_SAMPLE 0.28 166 | BEGIN_SAMPLE 0.39 167 | ThreadId 8 168 | IO 20 169 | IO 20 170 | IO 12 171 | String 12 172 | Integer 16 173 | BA 8 174 | TextEncoding 16 175 | BayesNet 12 176 | BufferCodec 24 177 | BufferList 12 178 | MUT_VAR_DIRTY 8 179 | IO 8 180 | ->IO 8 181 | MVAR 16 182 | ->(#,#) 8 183 | ->(#,#) 8 184 | ->(#,#) 8 185 | ->(#,#) 8 186 | ->(#,#) 8 187 | (,) 24 188 | Handle 12 189 | Word32 8 190 | ->IO 24 191 | WEAK 48 192 | ->[] 8 193 | ->>IO 32 194 | STRef 32 195 | ->>IO 8 196 | ->>IO 16 197 | ForeignPtr 12 198 | Buffer 84 199 | BufferState 12 200 | ForeignPtrContents 36 201 | Maybe 16 202 | ->(#,#) 8 203 | Node 60 204 | [] 1128 205 | MUT_VAR_CLEAN 64 206 | Handle__ 68 207 | ->IO 8 208 | Int32 24 209 | [] 20 210 | StdGen 12 211 | ->* 20 212 | Integer 12 213 | BLACKHOLE 32 214 | ARR_WORDS 33052 215 | Map 685860 216 | Float 195968 217 | Map 2351688 218 | END_SAMPLE 0.39 219 | BEGIN_SAMPLE 0.48 220 | IO 20 221 | ThreadId 8 222 | IO 20 223 | IO 12 224 | String 12 225 | BA 8 226 | TextEncoding 16 227 | Integer 16 228 | BayesNet 12 229 | ->[] 8 230 | ->>IO 32 231 | MUT_VAR_DIRTY 8 232 | STRef 32 233 | ->>IO 8 234 | ->>IO 16 235 | ForeignPtr 12 236 | Handle__ 68 237 | ->IO 8 238 | IO 8 239 | ->IO 8 240 | MVAR 16 241 | ->(#,#) 8 242 | ->(#,#) 8 243 | ->(#,#) 8 244 | ->(#,#) 8 245 | ->(#,#) 8 246 | Handle 12 247 | 
Word32 8 248 | ->IO 24 249 | WEAK 48 250 | Node 60 251 | BufferState 12 252 | ForeignPtrContents 36 253 | BufferCodec 24 254 | BufferList 12 255 | (,) 24 256 | Buffer 84 257 | Maybe 16 258 | ->(#,#) 8 259 | [] 1128 260 | MUT_VAR_CLEAN 64 261 | [] 20 262 | ->* 20 263 | ->>>>(#,#) 28 264 | BLACKHOLE 24 265 | ARR_WORDS 33036 266 | Int32 24 267 | Float 247408 268 | StdGen 12 269 | Map 2969040 270 | Map 865900 271 | END_SAMPLE 0.48 272 | BEGIN_SAMPLE 0.57 273 | ThreadId 8 274 | IO 20 275 | IO 20 276 | IO 12 277 | String 12 278 | BufferCodec 24 279 | BufferList 12 280 | MUT_VAR_DIRTY 8 281 | IO 8 282 | ->IO 8 283 | MVAR 16 284 | ->(#,#) 8 285 | ->(#,#) 8 286 | ->(#,#) 8 287 | ->(#,#) 8 288 | ->(#,#) 8 289 | (,) 24 290 | Handle 12 291 | Word32 8 292 | ->IO 24 293 | WEAK 48 294 | ->[] 8 295 | ->>IO 32 296 | ->>IO 8 297 | ->>IO 16 298 | ForeignPtr 12 299 | Buffer 84 300 | BufferState 12 301 | ForeignPtrContents 36 302 | Maybe 16 303 | ->(#,#) 8 304 | Node 60 305 | MUT_VAR_CLEAN 64 306 | Handle__ 68 307 | ->IO 8 308 | Integer 16 309 | BA 8 310 | [] 1128 311 | TextEncoding 16 312 | STRef 32 313 | BayesNet 12 314 | Int32 24 315 | [] 20 316 | StdGen 12 317 | ->* 20 318 | Integer 12 319 | BLACKHOLE 32 320 | ARR_WORDS 33052 321 | Map 1050532 322 | Float 300160 323 | Map 3601992 324 | END_SAMPLE 0.57 325 | BEGIN_SAMPLE 0.68 326 | IO 20 327 | ThreadId 8 328 | IO 20 329 | IO 12 330 | String 12 331 | BA 8 332 | TextEncoding 16 333 | Integer 16 334 | BayesNet 12 335 | STRef 32 336 | ->>IO 8 337 | ->>IO 16 338 | ForeignPtr 12 339 | BufferState 12 340 | Handle__ 68 341 | ->IO 8 342 | IO 8 343 | ->IO 8 344 | MVAR 16 345 | ->(#,#) 8 346 | ->(#,#) 8 347 | ->(#,#) 8 348 | ->(#,#) 8 349 | ->(#,#) 8 350 | Handle 12 351 | Word32 8 352 | ->IO 24 353 | WEAK 48 354 | Node 60 355 | ->[] 8 356 | ->>IO 32 357 | MUT_VAR_DIRTY 8 358 | ForeignPtrContents 36 359 | BufferCodec 24 360 | BufferList 12 361 | (,) 24 362 | Buffer 84 363 | Maybe 16 364 | ->(#,#) 8 365 | [] 1128 366 | MUT_VAR_CLEAN 64 367 | [] 20 368 | ->* 20 369 | ->>>>(#,#) 28 370 | Integer 12 371 | BLACKHOLE 40 372 | ARR_WORDS 33048 373 | Int32 24 374 | StdGen 12 375 | Map 4250616 376 | Map 1239700 377 | Float 354208 378 | END_SAMPLE 0.68 379 | BEGIN_SAMPLE 0.79 380 | IO 20 381 | IO 12 382 | String 12 383 | BA 8 384 | TextEncoding 16 385 | Integer 16 386 | BayesNet 12 387 | ThreadId 8 388 | IO 20 389 | ->[] 8 390 | ->>IO 32 391 | MUT_VAR_DIRTY 8 392 | STRef 32 393 | ->>IO 8 394 | ->>IO 16 395 | ForeignPtr 12 396 | Handle__ 68 397 | ->IO 8 398 | IO 8 399 | ->IO 8 400 | MVAR 16 401 | ->(#,#) 8 402 | ->(#,#) 8 403 | ->(#,#) 8 404 | ->(#,#) 8 405 | ->(#,#) 8 406 | Handle 12 407 | Word32 8 408 | ->IO 24 409 | WEAK 48 410 | Node 60 411 | BufferState 12 412 | ForeignPtrContents 36 413 | BufferCodec 24 414 | BufferList 12 415 | (,) 24 416 | Buffer 84 417 | Maybe 16 418 | ->(#,#) 8 419 | [] 1128 420 | MUT_VAR_CLEAN 64 421 | [] 20 422 | ->* 20 423 | ->>>>(#,#) 28 424 | BLACKHOLE 24 425 | ARR_WORDS 33036 426 | Int32 24 427 | Float 406952 428 | StdGen 12 429 | Map 4883568 430 | Map 1424304 431 | END_SAMPLE 0.79 432 | BEGIN_SAMPLE 0.90 433 | IO 20 434 | IO 12 435 | String 12 436 | IO 20 437 | ThreadId 8 438 | Integer 16 439 | BA 8 440 | TextEncoding 16 441 | BayesNet 12 442 | BufferCodec 24 443 | BufferList 12 444 | MUT_VAR_DIRTY 8 445 | IO 8 446 | ->IO 8 447 | MVAR 16 448 | ->(#,#) 8 449 | ->(#,#) 8 450 | ->(#,#) 8 451 | ->(#,#) 8 452 | ->(#,#) 8 453 | (,) 24 454 | Handle 12 455 | Word32 8 456 | ->IO 24 457 | WEAK 48 458 | ->[] 8 459 | ->>IO 32 460 | STRef 32 461 | ->>IO 8 462 
| ->>IO 16 463 | ForeignPtr 12 464 | Buffer 84 465 | BufferState 12 466 | ForeignPtrContents 36 467 | Maybe 16 468 | ->(#,#) 8 469 | Node 60 470 | [] 1128 471 | MUT_VAR_CLEAN 64 472 | Handle__ 68 473 | ->IO 8 474 | [] 20 475 | Int32 24 476 | stg_sel_upd 12 477 | ->* 20 478 | BLACKHOLE 16 479 | StdGen 12 480 | ARR_WORDS 33032 481 | Map 1613500 482 | Float 461008 483 | Map 5532168 484 | END_SAMPLE 0.90 485 | BEGIN_SAMPLE 1.01 486 | IO 20 487 | ThreadId 8 488 | IO 20 489 | IO 12 490 | String 12 491 | Integer 16 492 | BA 8 493 | TextEncoding 16 494 | BayesNet 12 495 | BufferCodec 24 496 | BufferList 12 497 | MUT_VAR_DIRTY 8 498 | IO 8 499 | ->IO 8 500 | MVAR 16 501 | ->(#,#) 8 502 | ->(#,#) 8 503 | ->(#,#) 8 504 | ->(#,#) 8 505 | ->(#,#) 8 506 | (,) 24 507 | Handle 12 508 | Word32 8 509 | ->IO 24 510 | WEAK 48 511 | ->[] 8 512 | ->>IO 32 513 | STRef 32 514 | ->>IO 8 515 | ->>IO 16 516 | ForeignPtr 12 517 | Buffer 84 518 | BufferState 12 519 | ForeignPtrContents 36 520 | Maybe 16 521 | ->(#,#) 8 522 | Node 60 523 | [] 1128 524 | MUT_VAR_CLEAN 64 525 | Handle__ 68 526 | ->IO 8 527 | [] 20 528 | Int32 24 529 | stg_sel_upd 12 530 | ->* 20 531 | BLACKHOLE 16 532 | StdGen 12 533 | ARR_WORDS 33032 534 | Map 1795836 535 | Float 513104 536 | Map 6157320 537 | END_SAMPLE 1.01 538 | BEGIN_SAMPLE 1.10 539 | ThreadId 8 540 | IO 20 541 | IO 20 542 | IO 12 543 | String 12 544 | BA 8 545 | TextEncoding 16 546 | Integer 16 547 | BayesNet 12 548 | STRef 32 549 | ->>IO 8 550 | ->>IO 16 551 | ForeignPtr 12 552 | BufferState 12 553 | Handle__ 68 554 | ->IO 8 555 | IO 8 556 | ->IO 8 557 | MVAR 16 558 | ->(#,#) 8 559 | ->(#,#) 8 560 | ->(#,#) 8 561 | ->(#,#) 8 562 | ->(#,#) 8 563 | Handle 12 564 | Word32 8 565 | ->IO 24 566 | WEAK 48 567 | Node 60 568 | ->[] 8 569 | ->>IO 32 570 | MUT_VAR_DIRTY 8 571 | ForeignPtrContents 36 572 | BufferCodec 24 573 | BufferList 12 574 | (,) 24 575 | Buffer 84 576 | Maybe 16 577 | ->(#,#) 8 578 | [] 1128 579 | MUT_VAR_CLEAN 64 580 | [] 20 581 | ->* 20 582 | ->>>>(#,#) 28 583 | Integer 12 584 | BLACKHOLE 40 585 | ARR_WORDS 33048 586 | Int32 24 587 | StdGen 12 588 | Map 6790296 589 | Map 1980440 590 | Float 565848 591 | END_SAMPLE 1.10 592 | BEGIN_SAMPLE 1.20 593 | ThreadId 8 594 | IO 20 595 | IO 20 596 | IO 12 597 | String 12 598 | Integer 16 599 | BA 8 600 | TextEncoding 16 601 | BayesNet 12 602 | BufferCodec 24 603 | BufferList 12 604 | MUT_VAR_DIRTY 8 605 | IO 8 606 | ->IO 8 607 | MVAR 16 608 | ->(#,#) 8 609 | ->(#,#) 8 610 | ->(#,#) 8 611 | ->(#,#) 8 612 | ->(#,#) 8 613 | (,) 24 614 | Handle 12 615 | Word32 8 616 | ->IO 24 617 | WEAK 48 618 | ->[] 8 619 | ->>IO 32 620 | STRef 32 621 | ->>IO 8 622 | ->>IO 16 623 | ForeignPtr 12 624 | Buffer 84 625 | BufferState 12 626 | ForeignPtrContents 36 627 | Maybe 16 628 | ->(#,#) 8 629 | Node 60 630 | MUT_VAR_CLEAN 64 631 | Handle__ 68 632 | ->IO 8 633 | [] 1128 634 | [] 20 635 | Int32 24 636 | stg_sel_upd 12 637 | ->* 20 638 | BLACKHOLE 16 639 | StdGen 12 640 | ARR_WORDS 33032 641 | Map 2171904 642 | Float 620552 643 | Map 7446696 644 | END_SAMPLE 1.20 645 | BEGIN_SAMPLE 1.27 646 | ThreadId 8 647 | IO 20 648 | IO 20 649 | IO 12 650 | String 12 651 | BA 8 652 | TextEncoding 16 653 | Integer 16 654 | BayesNet 12 655 | ARR_WORDS 33008 656 | ->>IO 8 657 | ->>IO 16 658 | ForeignPtr 12 659 | BufferState 12 660 | ForeignPtrContents 36 661 | BufferCodec 24 662 | ->IO 8 663 | IO 8 664 | ->IO 8 665 | MVAR 16 666 | ->(#,#) 8 667 | ->(#,#) 8 668 | ->(#,#) 8 669 | ->(#,#) 8 670 | ->(#,#) 8 671 | Handle 12 672 | Word32 8 673 | ->IO 24 674 | WEAK 48 675 | 
Node 60 676 | ->[] 8 677 | ->>IO 32 678 | MUT_VAR_DIRTY 8 679 | STRef 32 680 | BufferList 12 681 | (,) 24 682 | Buffer 84 683 | Maybe 16 684 | ->(#,#) 8 685 | [] 1128 686 | MUT_VAR_CLEAN 64 687 | Handle__ 68 688 | Int32 24 689 | StdGen 12 690 | [] 20 691 | Float 672272 692 | ->* 20 693 | Map 2352924 694 | ->>>>(#,#) 28 695 | Map 8067384 696 | END_SAMPLE 1.27 697 | BEGIN_SAMPLE 1.37 698 | IO 20 699 | ThreadId 8 700 | IO 20 701 | IO 12 702 | String 12 703 | BA 8 704 | TextEncoding 16 705 | Integer 16 706 | BayesNet 12 707 | STRef 32 708 | ->>IO 8 709 | ->>IO 16 710 | ForeignPtr 12 711 | BufferState 12 712 | Handle__ 68 713 | ->IO 8 714 | IO 8 715 | ->IO 8 716 | MVAR 16 717 | ->(#,#) 8 718 | ->(#,#) 8 719 | ->(#,#) 8 720 | ->(#,#) 8 721 | ->(#,#) 8 722 | Handle 12 723 | Word32 8 724 | ->IO 24 725 | WEAK 48 726 | Node 60 727 | ->[] 8 728 | ->>IO 32 729 | MUT_VAR_DIRTY 8 730 | ForeignPtrContents 36 731 | BufferCodec 24 732 | BufferList 12 733 | (,) 24 734 | Buffer 84 735 | Maybe 16 736 | ->(#,#) 8 737 | [] 1128 738 | MUT_VAR_CLEAN 64 739 | [] 20 740 | ->* 20 741 | ->>>>(#,#) 28 742 | Integer 12 743 | BLACKHOLE 40 744 | ARR_WORDS 33048 745 | Int32 24 746 | StdGen 12 747 | Map 8669112 748 | Map 2528428 749 | Float 722416 750 | END_SAMPLE 1.37 751 | BEGIN_SAMPLE 1.48 752 | IO 20 753 | IO 20 754 | ThreadId 8 755 | IO 12 756 | String 12 757 | Integer 16 758 | BA 8 759 | TextEncoding 16 760 | BayesNet 12 761 | BufferCodec 24 762 | BufferList 12 763 | MUT_VAR_DIRTY 8 764 | IO 8 765 | ->IO 8 766 | MVAR 16 767 | ->(#,#) 8 768 | ->(#,#) 8 769 | ->(#,#) 8 770 | ->(#,#) 8 771 | ->(#,#) 8 772 | (,) 24 773 | Handle 12 774 | Word32 8 775 | ->IO 24 776 | WEAK 48 777 | ->[] 8 778 | ->>IO 32 779 | STRef 32 780 | ->>IO 8 781 | ->>IO 16 782 | ForeignPtr 12 783 | Buffer 84 784 | BufferState 12 785 | ForeignPtrContents 36 786 | Maybe 16 787 | ->(#,#) 8 788 | Node 60 789 | [] 1128 790 | MUT_VAR_CLEAN 64 791 | Handle__ 68 792 | ->IO 8 793 | Int32 24 794 | [] 20 795 | StdGen 12 796 | ->* 20 797 | Integer 12 798 | BLACKHOLE 32 799 | ARR_WORDS 33052 800 | Map 2715328 801 | Float 775816 802 | Map 9309864 803 | END_SAMPLE 1.48 804 | BEGIN_SAMPLE 1.59 805 | END_SAMPLE 1.59 806 | -------------------------------------------------------------------------------- /profiling/LikelihoodWeighting/likelihoodWeighting.hs: -------------------------------------------------------------------------------- 1 | import AI.Probability.Example.Alarm 2 | import AI.Util.ProbDist 3 | 4 | n :: Int 5 | n = 100000 6 | 7 | fixed :: [(String,Bool)] 8 | fixed = [("JohnCalls",True),("MaryCalls",True)] 9 | 10 | x :: String 11 | x = "Burglary" 12 | 13 | main :: IO () 14 | main = do 15 | putStrLn "----------" 16 | 17 | let d1 = enumerationAsk alarm fixed x 18 | putInfo d1 19 | 20 | d2 <- likelihoodWeighting n alarm fixed x 21 | putInfo d2 22 | 23 | putInfo :: Show a => Dist a -> IO () 24 | putInfo d = do 25 | putStrLn $ show d 26 | putStrLn $ "----------" 27 | -------------------------------------------------------------------------------- /profiling/LikelihoodWeighting/likelihoodWeighting.prof: -------------------------------------------------------------------------------- 1 | Tue Jul 24 11:18 2012 Time and Allocation Profiling Report (Final) 2 | 3 | likelihoodWeighting.exe +RTS -K100M -s -p -hy -RTS 4 | 5 | total time = 1.59 secs (1587 ticks @ 1000 us, 1 processor) 6 | total alloc = 392,068,176 bytes (excludes profiling overheads) 7 | 8 | COST CENTRE MODULE %time %alloc 9 | 10 | main Main 98.4 100.0 11 | main.d1 Main 1.6 0.0 12 | 13 | 14 | individual 
inherited 15 | COST CENTRE MODULE no. entries %time %alloc %time %alloc 16 | 17 | MAIN MAIN 57 0 0.0 0.0 100.0 100.0 18 | main Main 115 0 98.4 100.0 98.4 100.0 19 | n Main 122 1 0.0 0.0 0.0 0.0 20 | putInfo Main 117 1 0.0 0.0 0.0 0.0 21 | CAF GHC.Integer.Logarithms.Internals 113 0 0.0 0.0 0.0 0.0 22 | CAF GHC.IO.Encoding.CodePage 98 0 0.0 0.0 0.0 0.0 23 | CAF System.CPUTime 93 0 0.0 0.0 0.0 0.0 24 | CAF GHC.Show 92 0 0.0 0.0 0.0 0.0 25 | CAF GHC.IO.Encoding 89 0 0.0 0.0 0.0 0.0 26 | CAF Data.Fixed 79 0 0.0 0.0 0.0 0.0 27 | CAF GHC.IO.Handle.FD 75 0 0.0 0.0 0.0 0.0 28 | CAF Data.Time.Clock.POSIX 72 0 0.0 0.0 0.0 0.0 29 | CAF System.Random 71 0 0.0 0.0 0.0 0.0 30 | CAF AI.Probability.Example.Alarm 67 0 0.0 0.0 0.0 0.0 31 | CAF AI.Probability.Bayes 65 0 0.0 0.0 0.0 0.0 32 | CAF Main 64 0 0.0 0.0 1.6 0.0 33 | putInfo Main 121 0 0.0 0.0 0.0 0.0 34 | fixed Main 120 1 0.0 0.0 0.0 0.0 35 | x Main 119 1 0.0 0.0 0.0 0.0 36 | main Main 114 1 0.0 0.0 1.6 0.0 37 | main.d1 Main 118 1 1.6 0.0 1.6 0.0 38 | putInfo Main 116 1 0.0 0.0 0.0 0.0 39 | -------------------------------------------------------------------------------- /profiling/LikelihoodWeighting/likelihoodWeighting.ps: -------------------------------------------------------------------------------- 1 | %!PS-Adobe-2.0 2 | %%Title: likelihoodWeighting.exe +RTS -K100M -s -p -hy 3 | %%Creator: c:\Program Files (x86)\Haskell Platform\2012.2.0.0\bin\hp2ps.exe (version 0.25) 4 | %%CreationDate: Tue Jul 24 11:18 2012 5 | %%EndComments 6 | -90 rotate 7 | -756.000000 72.000000 translate 8 | /HE10 /Helvetica findfont 10 scalefont def 9 | /HE12 /Helvetica findfont 12 scalefont def 10 | newpath 11 | 0 0 moveto 12 | 0 432.000000 rlineto 13 | 648.000000 0 rlineto 14 | 0 -432.000000 rlineto 15 | closepath 16 | 0.500000 setlinewidth 17 | stroke 18 | newpath 19 | 5.000000 387.000000 moveto 20 | 0 40.000000 rlineto 21 | 638.000000 0 rlineto 22 | 0 -40.000000 rlineto 23 | closepath 24 | 0.500000 setlinewidth 25 | stroke 26 | 5.000000 407.000000 moveto 27 | 638.000000 0 rlineto 28 | stroke 29 | HE12 setfont 30 | 11.000000 413.000000 moveto 31 | (likelihoodWeighting.exe +RTS -K100M -s -p -hy) show 32 | HE12 setfont 33 | 11.000000 393.000000 moveto 34 | (10,076,868 bytes x seconds) 35 | show 36 | HE12 setfont 37 | (Tue Jul 24 11:18 2012) 38 | dup stringwidth pop 39 | 637.000000 40 | exch sub 41 | 393.000000 moveto 42 | show 43 | 45.000000 20.000000 moveto 44 | 540.188974 0 rlineto 45 | 0.500000 setlinewidth 46 | stroke 47 | HE10 setfont 48 | (seconds) 49 | dup stringwidth pop 50 | 585.188974 51 | exch sub 52 | 5.000000 moveto 53 | show 54 | 45.000000 20.000000 moveto 55 | 0 -4 rlineto 56 | stroke 57 | HE10 setfont 58 | (0.0) 59 | dup stringwidth pop 60 | 2 div 61 | 45.000000 exch sub 62 | 5.000000 moveto 63 | show 64 | 112.948299 20.000000 moveto 65 | 0 -4 rlineto 66 | stroke 67 | HE10 setfont 68 | (0.2) 69 | dup stringwidth pop 70 | 2 div 71 | 112.948299 exch sub 72 | 5.000000 moveto 73 | show 74 | 180.896597 20.000000 moveto 75 | 0 -4 rlineto 76 | stroke 77 | HE10 setfont 78 | (0.4) 79 | dup stringwidth pop 80 | 2 div 81 | 180.896597 exch sub 82 | 5.000000 moveto 83 | show 84 | 248.844896 20.000000 moveto 85 | 0 -4 rlineto 86 | stroke 87 | HE10 setfont 88 | (0.6) 89 | dup stringwidth pop 90 | 2 div 91 | 248.844896 exch sub 92 | 5.000000 moveto 93 | show 94 | 316.793194 20.000000 moveto 95 | 0 -4 rlineto 96 | stroke 97 | HE10 setfont 98 | (0.8) 99 | dup stringwidth pop 100 | 2 div 101 | 316.793194 exch sub 102 | 5.000000 moveto 103 | show 104 | 384.741493 20.000000 
moveto 105 | 0 -4 rlineto 106 | stroke 107 | HE10 setfont 108 | (1.0) 109 | dup stringwidth pop 110 | 2 div 111 | 384.741493 exch sub 112 | 5.000000 moveto 113 | show 114 | 452.689791 20.000000 moveto 115 | 0 -4 rlineto 116 | stroke 117 | HE10 setfont 118 | (1.2) 119 | dup stringwidth pop 120 | 2 div 121 | 452.689791 exch sub 122 | 5.000000 moveto 123 | show 124 | 520.638090 20.000000 moveto 125 | 0 -4 rlineto 126 | stroke 127 | HE10 setfont 128 | (1.4) 129 | dup stringwidth pop 130 | 2 div 131 | 520.638090 exch sub 132 | 5.000000 moveto 133 | show 134 | 45.000000 20.000000 moveto 135 | 0 362.000000 rlineto 136 | 0.500000 setlinewidth 137 | stroke 138 | gsave 139 | HE10 setfont 140 | (bytes) 141 | dup stringwidth pop 142 | 382.000000 143 | exch sub 144 | 40.000000 exch 145 | translate 146 | 90 rotate 147 | 0 0 moveto 148 | show 149 | grestore 150 | 45.000000 20.000000 moveto 151 | -4 0 rlineto 152 | stroke 153 | HE10 setfont 154 | (0M) 155 | dup stringwidth 156 | 2 div 157 | 20.000000 exch sub 158 | exch 159 | 40.000000 exch sub 160 | exch 161 | moveto 162 | show 163 | 45.000000 76.558046 moveto 164 | -4 0 rlineto 165 | stroke 166 | HE10 setfont 167 | (2M) 168 | dup stringwidth 169 | 2 div 170 | 76.558046 exch sub 171 | exch 172 | 40.000000 exch sub 173 | exch 174 | moveto 175 | show 176 | 45.000000 133.116092 moveto 177 | -4 0 rlineto 178 | stroke 179 | HE10 setfont 180 | (4M) 181 | dup stringwidth 182 | 2 div 183 | 133.116092 exch sub 184 | exch 185 | 40.000000 exch sub 186 | exch 187 | moveto 188 | show 189 | 45.000000 189.674138 moveto 190 | -4 0 rlineto 191 | stroke 192 | HE10 setfont 193 | (6M) 194 | dup stringwidth 195 | 2 div 196 | 189.674138 exch sub 197 | exch 198 | 40.000000 exch sub 199 | exch 200 | moveto 201 | show 202 | 45.000000 246.232184 moveto 203 | -4 0 rlineto 204 | stroke 205 | HE10 setfont 206 | (8M) 207 | dup stringwidth 208 | 2 div 209 | 246.232184 exch sub 210 | exch 211 | 40.000000 exch sub 212 | exch 213 | moveto 214 | show 215 | 45.000000 302.790230 moveto 216 | -4 0 rlineto 217 | stroke 218 | HE10 setfont 219 | (10M) 220 | dup stringwidth 221 | 2 div 222 | 302.790230 exch sub 223 | exch 224 | 40.000000 exch sub 225 | exch 226 | moveto 227 | show 228 | 590.188974 133.666667 moveto 229 | 0 14 rlineto 230 | 14 0 rlineto 231 | 0 -14 rlineto 232 | closepath 233 | gsave 234 | 0.000000 0.000000 0.000000 setrgbcolor 235 | fill 236 | grestore 237 | stroke 238 | HE10 setfont 239 | 609.188974 135.666667 moveto 240 | (Float) show 241 | 590.188974 254.333333 moveto 242 | 0 14 rlineto 243 | 14 0 rlineto 244 | 0 -14 rlineto 245 | closepath 246 | gsave 247 | 0.000000 0.000000 1.000000 setrgbcolor 248 | fill 249 | grestore 250 | stroke 251 | HE10 setfont 252 | 609.188974 256.333333 moveto 253 | (Map) show 254 | 45.000000 20.000000 moveto 255 | 45.000000 20.000000 lineto 256 | 68.781904 20.000000 lineto 257 | 106.153469 20.000000 lineto 258 | 140.127618 20.000000 lineto 259 | 177.499182 20.000000 lineto 260 | 208.075917 20.000000 lineto 261 | 238.652651 20.000000 lineto 262 | 276.024215 20.000000 lineto 263 | 313.395779 20.000000 lineto 264 | 350.767344 20.000000 lineto 265 | 388.138908 20.000000 lineto 266 | 418.715642 20.000000 lineto 267 | 452.689791 20.000000 lineto 268 | 476.471696 20.000000 lineto 269 | 510.445845 20.000000 lineto 270 | 547.817409 20.000000 lineto 271 | 585.188974 20.000000 lineto 272 | 585.188974 20.000000 lineto 273 | 585.188974 20.000000 lineto 274 | 547.817409 41.939319 lineto 275 | 510.445845 40.429219 lineto 276 | 476.471696 39.011195 lineto 277 | 
452.689791 37.548604 lineto 278 | 418.715642 36.001629 lineto 279 | 388.138908 34.510080 lineto 280 | 350.767344 33.036856 lineto 281 | 313.395779 31.508205 lineto 282 | 276.024215 30.016656 lineto 283 | 238.652651 28.488232 lineto 284 | 208.075917 26.996457 lineto 285 | 177.499182 25.541784 lineto 286 | 140.127618 24.050009 lineto 287 | 106.153469 22.595336 lineto 288 | 68.781904 21.122112 lineto 289 | 45.000000 20.000000 lineto 290 | closepath 291 | gsave 292 | 0.000000 0.000000 0.000000 setrgbcolor 293 | fill 294 | grestore 295 | stroke 296 | 45.000000 20.000000 moveto 297 | 45.000000 20.000000 lineto 298 | 68.781904 21.122112 lineto 299 | 106.153469 22.595336 lineto 300 | 140.127618 24.050009 lineto 301 | 177.499182 25.541784 lineto 302 | 208.075917 26.996457 lineto 303 | 238.652651 28.488232 lineto 304 | 276.024215 30.016656 lineto 305 | 313.395779 31.508205 lineto 306 | 350.767344 33.036856 lineto 307 | 388.138908 34.510080 lineto 308 | 418.715642 36.001629 lineto 309 | 452.689791 37.548604 lineto 310 | 476.471696 39.011195 lineto 311 | 510.445845 40.429219 lineto 312 | 547.817409 41.939319 lineto 313 | 585.188974 20.000000 lineto 314 | 585.188974 20.000000 lineto 315 | 585.188974 20.000000 lineto 316 | 547.817409 382.000000 lineto 317 | 510.445845 357.084710 lineto 318 | 476.471696 333.687325 lineto 319 | 452.689791 309.553215 lineto 320 | 418.715642 284.029474 lineto 321 | 388.138908 259.417561 lineto 322 | 350.767344 235.109366 lineto 323 | 313.395779 209.888663 lineto 324 | 276.024215 185.277429 lineto 325 | 238.652651 160.057065 lineto 326 | 208.075917 135.444813 lineto 327 | 177.499182 111.440673 lineto 328 | 140.127618 86.828422 lineto 329 | 106.153469 62.824282 lineto 330 | 68.781904 38.516086 lineto 331 | 45.000000 20.000000 lineto 332 | closepath 333 | gsave 334 | 0.000000 0.000000 1.000000 setrgbcolor 335 | fill 336 | grestore 337 | stroke 338 | showpage 339 | -------------------------------------------------------------------------------- /profiling/LikelihoodWeighting/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Reinstall library 4 | cd ../.. 5 | cabal install -p -O2 6 | cd "profiling/"$1 7 | 8 | # Compile test file 9 | ghc -rtsopts -prof -fprof-auto -fforce-recomp --make -O2 $1".hs" 10 | 11 | # Run with profiling enabled and view output 12 | ./$1 +RTS -K100M -s -p -hy 13 | cat $1".prof" 14 | hp2ps -c $1".hp" 15 | open $1".ps" 16 | 17 | # Clean up files from last compilation 18 | ./cleanup.sh -------------------------------------------------------------------------------- /profiling/Restaurants/Restaurants.hs: -------------------------------------------------------------------------------- 1 | import AI.Learning.Example.Restaurant 2 | 3 | main :: IO () 4 | main = demo1 5 | -------------------------------------------------------------------------------- /profiling/Restaurants/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Reinstall library 4 | cd ../.. 
5 | cabal install -p
6 | cd "profiling/Restaurants/"
7 | 
8 | # Compile test file
9 | ghc -rtsopts -prof -auto-all -fforce-recomp -O2 "Restaurants.hs"
10 | 
11 | # Run with profiling enabled and view output
12 | ./Restaurants +RTS -K100M -s -p # -hy
13 | cat "Restaurants.prof"
14 | # hp2ps -c "Restaurants.hp"
15 | # open "Restaurants.ps"
16 | 
17 | # Clean up files from last compilation
18 | rm *.hi *.o *.aux
19 | rm Restaurants
--------------------------------------------------------------------------------
/src/AI/Core/Agents.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE FlexibleInstances, FlexibleContexts #-}
2 | 
3 | module AI.Core.Agents where
4 | 
5 | import System.IO.Unsafe
6 | 
7 | -------------
8 | -- Objects --
9 | -------------
10 | 
11 | -- |The Object class represents any physical object that can appear in an
12 | -- Environment. You should create instances of Object to get the objects you
13 | -- want.
14 | class Show o => Object o where
15 |     -- |Objects that are alive should return 'True'.
16 |     isAlive :: o -> Bool
17 | 
18 | ------------
19 | -- Agents --
20 | ------------
21 | 
22 | -- |An Agent is a type with one method, 'program', which returns a function
23 | -- of one argument, percept -> action. An agent program that needs a model of
24 | -- the world (and of the agent itself) will have to build and maintain its
25 | -- own model.
26 | class Object (agent p a) => Agent agent p a where
27 |     -- |Given a percept, return an appropriate action.
28 |     program :: agent p a -> p -> a
29 | 
30 | -- |A wrapper for an agent that will print its percepts and actions. This will
31 | -- let you see what the agent is doing in the environment.
32 | newtype TraceAgent agent p a = TraceAgent { getAgent :: agent p a } deriving (Show)
33 | 
34 | -- |Make a TraceAgent into an instance of Object by wrapping 'isAlive'.
35 | instance Object (agent p a) => Object (TraceAgent agent p a) where
36 |     isAlive (TraceAgent agent) = isAlive agent
37 | 
38 | -- |Make a TraceAgent into an instance of Agent by wrapping 'program'.
39 | instance (Agent agent p a, Show p, Show a) => Agent (TraceAgent agent) p a where
40 |     program (TraceAgent agent) p = unsafePerformIO $ do
41 |         let a = program agent p
42 |         putStrLn $
43 |             show agent ++ " perceives " ++ show p ++ " and does " ++ show a
44 |         return a
45 | 
46 | ------------------
47 | -- Environments --
48 | ------------------
49 | 
50 | -- |The Environment class is used to maintain a collection of objects
51 | -- and agents.
52 | class (Object o, Agent agent p a) => Environment e o agent p a where
53 |     -- |Return a list of objects in the current environment.
54 |     objects :: e o agent p a -> [o]
55 | 
56 |     -- |Return a list of agents in the current environment.
57 |     agents :: e o agent p a -> [agent p a]
58 | 
59 |     -- |Return the percept that the agent sees at this point.
60 |     percept :: e o agent p a -> agent p a -> p
61 | 
62 |     -- |Modify the environment to reflect an agent taking an action.
63 |     execAction :: e o agent p a -> agent p a -> a -> e o agent p a
64 | 
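-- A minimal sketch of how the 'Object' and 'Agent' classes fit together,
-- using a hypothetical toy type (not part of the library; needs
-- MultiParamTypeClasses, which the package already enables):
--
-- > data Light p a = Light Bool deriving (Show)
-- >
-- > instance Object (Light p a) where
-- >     isAlive (Light on) = on
-- >
-- > instance Agent Light Bool Bool where
-- >     program _ percept = not percept  -- toggle on every percept
--
65 |     -- |Is the execution over? By default, we are done when we can't find a
66 |     -- live agent.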
67 | isDone :: e o agent p a -> Bool 68 | isDone env = not $ any isAlive $ agents env 69 | 70 | -------------------------------------------------------------------------------- /src/AI/Learning/Bootstrap.hs: -------------------------------------------------------------------------------- 1 | module AI.Learning.Bootstrap where 2 | 3 | import Control.Monad.Random 4 | import Foreign.Storable (Storable) 5 | import Numeric.LinearAlgebra 6 | import qualified Data.List as L 7 | 8 | import AI.Util.Matrix 9 | import AI.Util.Util 10 | 11 | --------------- 12 | -- Bootstrap -- 13 | --------------- 14 | 15 | -- |Generate a bootstrap sample of size @sz@. 16 | genBootstrapSample :: RandomGen g => Int -> Rand g [Int] 17 | genBootstrapSample sz = go sz [] 18 | where go 0 accum = return accum 19 | go n accum = do 20 | i <- getRandomR (0,sz-1) 21 | go (n - 1) (i:accum) 22 | 23 | sampleVector :: (Storable a, RandomGen g) => Vector a -> Rand g (Vector a) 24 | sampleVector v = do 25 | idx <- genBootstrapSample (dim v) 26 | return (v `subRefVec` idx) 27 | 28 | sampleMatrixRows :: (Element a, RandomGen g) => Matrix a -> Rand g (Matrix a) 29 | sampleMatrixRows m = do 30 | idx <- genBootstrapSample (rows m) 31 | return $ m `subRefRows` idx 32 | 33 | sampleMatrixCols :: (Element a, RandomGen g) => Matrix a -> Rand g (Matrix a) 34 | sampleMatrixCols m = do 35 | idx <- genBootstrapSample (cols m) 36 | return $ m `subRefCols` idx 37 | 38 | -- |Generate a bootstrap sample of a statistic from a data set of type /a/. 39 | bootStrapResample :: RandomGen g => 40 | (a -> Rand g a) -- Sampling function 41 | -> (a -> b) -- Statistic to be resampled 42 | -> Int -- Number of resamples 43 | -> a -- Data 44 | -> Rand g [b] 45 | bootStrapResample sample func nSamples x = go [] nSamples 46 | where 47 | go samples 0 = return samples 48 | go samples n = do 49 | x' <- sample x 50 | go (func x' : samples) (n-1) 51 | -------------------------------------------------------------------------------- /src/AI/Learning/Core.hs: -------------------------------------------------------------------------------- 1 | module AI.Learning.Core where 2 | 3 | import Numeric.LinearAlgebra 4 | import Numeric.GSL.Minimization 5 | 6 | --------------- 7 | -- Utilities -- 8 | --------------- 9 | 10 | -- |Sigmoid function: 11 | -- 12 | -- > sigmoid x = 1 / (1 + exp (-x)) 13 | -- 14 | -- Used in the logistic regression and neural network modules. 15 | sigmoid :: Floating a => a -> a 16 | sigmoid x = 1 / (1 + exp (-x)) 17 | 18 | ------------------ 19 | -- Optimization -- 20 | ------------------ 21 | 22 | -- |Simplified minimization function. You supply functions that compute the 23 | -- quantity to be minimized and its gradient, and an initial guess, and the 24 | -- final solution is returned. 
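--
-- For example, to minimize f(x) = (x - 3)^2, whose gradient is 2(x - 3)
-- (a sketch; with the settings below the result should be very close to 3):
--
-- > minimizeS (\v -> (v @> 0 - 3)^2)
-- >           (\v -> fromList [2 * (v @> 0 - 3)])
-- >           (fromList [0])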
25 | minimizeS :: (Vector Double -> Double) -- f
26 |     -> (Vector Double -> Vector Double) -- gradient
27 |     -> Vector Double -- initial x
28 |     -> Vector Double
29 | minimizeS f g x = fst $ minimizeVD VectorBFGS2 prec niter sz tol f g x
30 |     where prec = 1e-9
31 |           niter = 1000
32 |           sz = 0.1
33 |           tol = 0.1
34 | 
--------------------------------------------------------------------------------
/src/AI/Learning/CrossValidation.hs:
--------------------------------------------------------------------------------
1 | module AI.Learning.CrossValidation where
2 | 
3 | import Control.Monad.Random
4 | import Foreign.Storable (Storable)
5 | import Numeric.LinearAlgebra
6 | import qualified Data.List as L
7 | 
8 | import AI.Util.Matrix
9 | import AI.Util.Util
10 | 
11 | ----------------------
12 | -- Cross Validation --
13 | ----------------------
14 | 
15 | class Indexable c where
16 |     index :: c -> Index -> c
17 |     nobs :: c -> Int
18 | 
19 | instance Storable a => Indexable (Vector a) where
20 |     index = subRefVec
21 |     nobs = dim
22 | 
23 | instance Element a => Indexable (Matrix a) where
24 |     index = subRefRows
25 |     nobs = rows
26 | 
27 | instance Indexable [a] where
28 |     index = map . (!!)
29 |     nobs = length
30 | 
31 | -- |Indexes are lists of 'Int'. Should refactor this to use something more
32 | -- efficient.
33 | type Index = [Int]
34 | 
35 | -- |Type for cross-validation partition.
36 | data CVPartition = CVPartition [(Index, Index)]
37 | 
38 | -- |Specify what type of cross-validation you want to do.
39 | data CVType = LeaveOneOut
40 |             | KFold Int
41 | 
42 | -- |Prediction function. A prediction function should take a training and a test
43 | -- set, and use the training set to build a model whose performance is
44 | -- evaluated on the test set, returning a final score as a 'Double'.
45 | type PredFun a b = a -- Training set predictors
46 |                 -> b -- Training set target
47 |                 -> a -- Test set predictors
48 |                 -> b -- Test set target
49 |                 -> Double -- Performance score
50 | 
51 | -- |Create a partition into test and training sets.
52 | cvPartition :: RandomGen g => Int -> CVType -> Rand g CVPartition
53 | cvPartition sz cvtype = case cvtype of
54 |     KFold i -> cvp sz i
55 |     LeaveOneOut -> cvp sz sz
56 | 
57 | -- |Helper function for 'cvPartition'.
58 | cvp :: RandomGen g => Int -> Int -> Rand g CVPartition
59 | cvp n k = do
60 |     is <- go i (k - i) idx
61 |     return . CVPartition $ map (\i -> (idx L.\\ i, i)) is
62 |     where
63 |         go 0 0 idx = return []
64 | 
65 |         go 0 j idx = do
66 |             (is, idx') <- selectMany' s idx
67 |             iss <- go 0 (j-1) idx'
68 |             return (is:iss)
69 | 
70 |         go i j idx = do
71 |             (is, idx') <- selectMany' (s+1) idx
72 |             iss <- go (i-1) j idx'
73 |             return (is:iss)
74 | 
75 |         s = n `div` k
76 |         i = n `mod` k
77 |         idx = [0 .. n-1]
78 | 
79 | -- |Perform k-fold cross-validation. Given a 'CVPartition' containing a list
80 | -- of training and test sets, we repeatedly fit a model on the training set
81 | -- and test its performance on the test set.
82 | kFoldCV_ :: (Indexable a, Indexable b) =>
83 |     CVPartition
84 |     -> PredFun a b
85 |     -> a
86 |     -> b
87 |     -> [Double]
88 | kFoldCV_ (CVPartition partition) predfun x y = map go partition
89 |     where
90 |         go (trainIdx,testIdx) = predfun xTrain yTrain xTest yTest
91 |             where
92 |                 xTrain = x `index` trainIdx
93 |                 yTrain = y `index` trainIdx
94 |                 xTest = x `index` testIdx
95 |                 yTest = y `index` testIdx
96 | 
97 | -- |Perform k-fold cross-validation, randomly generating the training and
98 | -- test sets first.
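--
-- For example (a sketch; 'meanModel' is an illustrative 'PredFun' that
-- predicts the training-set mean and scores by mean squared error):
--
-- > meanModel :: PredFun (Matrix Double) (Vector Double)
-- > meanModel _ yTrain _ yTest =
-- >     let mu   = sum (toList yTrain) / fromIntegral (dim yTrain)
-- >         errs = map (subtract mu) (toList yTest)
-- >     in  sum (map (^2) errs) / fromIntegral (length errs)
-- >
-- > demo x y = evalRandIO (kFoldCV (KFold 10) meanModel x y)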
99 | kFoldCV :: (RandomGen g, Indexable a, Indexable b) =>
100 |     CVType -- What type of cross-validation?
101 |     -> PredFun a b -- Prediction function
102 |     -> a -- Predictors
103 |     -> b -- Targets
104 |     -> Rand g [Double] -- List of scores
105 | kFoldCV cvtype predfun x y = if nobs x /= nobs y
106 |     then error "Inconsistent dimensions -- KFOLDCV"
107 |     else do
108 |         cp <- cvPartition (nobs x) cvtype
109 |         return (kFoldCV_ cp predfun x y)
110 | 
111 | ---------------
112 | -- Old Stuff --
113 | ---------------
114 | 
115 | -- |Model builder. A model builder takes a training set of regressors and
116 | -- targets, and constructs a function that makes predictions from an out-
117 | -- of-sample set of regressors.
118 | type ModelBuilder = Matrix Double -- Training set regressors
119 |                  -> Vector Double -- Training set target
120 |                  -> Matrix Double -- Out-of-sample regressors
121 |                  -> Vector Double -- Predictions
122 | 
123 | -- |Evaluation function. An evaluation function takes a vector of targets and
124 | -- a vector of predictions, and returns a score corresponding to how closely
125 | -- the predictions match the target.
126 | type EvalFun = Vector Double -- Target
127 |             -> Vector Double -- Predictions
128 |             -> Double -- Score (e.g. MSE, MCR, likelihood)
--------------------------------------------------------------------------------
/src/AI/Learning/DecisionTree.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE ScopedTypeVariables, BangPatterns #-}
2 | 
3 | module AI.Learning.DecisionTree where
4 | 
5 | import Control.Monad.Random
6 | import Data.Map (Map, (!))
7 | import Data.Ord (comparing)
8 | import qualified Data.List as L
9 | import qualified Data.Map as M
10 | 
11 | import AI.Util.Util
12 | 
13 | ----------------
14 | -- Attributes --
15 | ----------------
16 | 
17 | -- |An attribute is anything that can split the data into a number of classes.
18 | data Att a = Att { test :: a -> Int
19 |                  , vals :: [Int]
20 |                  , label :: String }
21 | 
22 | instance Show (Att a) where
23 |     show att = "Att(" ++ label att ++ ")"
24 | 
25 | instance Eq (Att a) where
26 |     Att _ _ lab1 == Att _ _ lab2 = lab1 == lab2
27 | 
28 | -------------------
29 | -- Decision Tree --
30 | -------------------
31 | 
32 | -- |A decision tree which makes decisions based on attributes of type @a@ and
33 | -- returns results of type @b@. We store information of type @i@ at the nodes,
34 | -- which is useful for pruning the tree later.
35 | data DTree a i b = Result b
36 |                  | Decision (Att a) i (Map Int (DTree a i b))
37 | 
38 | instance Show b => Show (DTree a i b) where
39 |     show (Result b) = show b
40 |     show (Decision att _ ts) =
41 |         "Decision " ++ show att ++ " " ++ show (M.elems ts)
42 | 
43 | instance Functor (DTree a i) where
44 |     fmap f (Result b) = Result (f b)
45 |     fmap f (Decision att i branches) = Decision att i (fmap (fmap f) branches)
46 | 
47 | instance Monad (DTree a i) where
48 |     return b = Result b
49 |     Result b >>= f = f b
50 |     Decision att i ts >>= f = Decision att i (fmap (>>=f) ts)
51 | 
52 | mapI :: (i -> j) -> DTree a i b -> DTree a j b
53 | mapI f (Result b) = Result b
54 | mapI f (Decision att i ts) = Decision att (f i) (fmap (mapI f) ts)
55 | 
56 | dropInfo :: DTree a i b -> DTree a () b
57 | dropInfo = mapI $ const ()
58 | 
59 | -- |Create an attribute from a function and its name.
60 | att :: forall a b. (Enum b, Bounded b) => (a -> b) -> String -> Att a
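--
-- For example (a sketch): since 'Bool' is 'Enum' and 'Bounded', a predicate
-- gives a two-way split.
--
-- > isVowel :: Att Char
-- > isVowel = att (`elem` "aeiou") "IsVowel"
61 | att f str = Att (fromEnum . 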
f) (map fromEnum vs) str
62 |     where
63 |         vs = enum :: [b]
64 | 
65 | -- |Create a simple decision tree from a function.
66 | attribute :: (Enum b,Bounded b) => (a -> b) -> String -> DTree a () b
67 | attribute f label = Decision (att f label) () tree
68 |     where
69 |         tree = M.fromList $ zip [0..] (map Result enum)
70 | 
71 | -- |Run the decision tree on an example.
72 | decide :: DTree a i b -> a -> b
73 | decide (Result b) _ = b
74 | decide (Decision att _ branches) a = decide (branches ! test att a) a
75 | 
76 | -- |Fit a decision tree to data, using the ID3 algorithm.
77 | fitTree :: (Ord a,Ord b) => (a -> b) -> [Att a] -> [a] -> DTree a () b
78 | fitTree target atts as =
79 |     dropInfo $ fmap mode $ decisionTreeLearning target atts [] as
80 | 
81 | ----------------
82 | -- Data Split --
83 | ----------------
84 | 
85 | -- |The decision-tree learning algorithm (Fig 18.5). This returns a list of
86 | -- elements at each leaf. You can 'fmap' the 'mode' function over the leaves
87 | -- to get the plurality value at that leaf, or the 'uniform' function to get
88 | -- a probability distribution.
89 | decisionTreeLearning :: Ord b =>
90 |     (a -> b) -- Target function
91 |     -> [Att a] -- Attributes to split on
92 |     -> [a] -- Examples from the parent node
93 |     -> [a] -- Examples to be split at this node
94 |     -> DTree a [b] [b]
95 | decisionTreeLearning target atts ps as
96 |     | null as = Result ps'
97 |     | null atts || allEqual as' = Result as'
98 |     | otherwise =
99 |         Decision att as' (fmap (decisionTreeLearning target atts' as) m)
100 |     where
101 |         (att,atts',m) =
102 |             L.minimumBy (comparing (\(_,_,m) -> func (M.elems m))) choices
103 | 
104 |         choices =
105 |             [ (att,atts',partition att as) | (att,atts') <- points atts ]
106 | 
107 |         func = sumEntropy target
108 | 
109 |         as' = map target as
110 |         ps' = map target ps
111 | 
112 | -- |Partition a list based on a function that maps elements of the list to
113 | -- integers. This assumes that the test function maps every element into the attribute's 'vals' list.
114 | partition :: Att a -> [a] -> Map Int [a]
115 | partition att as = L.foldl' fun initial as
116 |     where
117 |         fun m a = M.insertWith' (++) (test att a) [a] m
118 |         initial = mkUniversalMap (vals att) []
119 | 
120 | entropy :: Ord a => [a] -> Float
121 | entropy as = entropy' probs
122 |     where
123 |         entropy' ps = negate . sum $ map (\p -> if p == 0 then 0 else p * log p) ps
124 |         probs = map ((/len) . fromIntegral) $ M.elems $ L.foldl' go M.empty as
125 |         go m a = M.insertWith' (const (+1)) a 1 m
126 |         len = fromIntegral (length as)
127 | 
128 | -- |When given a target function, this can be used to score a candidate split
129 | -- (it is the 'func' used by 'decisionTreeLearning' to pick an attribute).
130 | sumEntropy :: Ord b => (a -> b) -> [[a]] -> Float
131 | sumEntropy target as = sum $ map (entropy . map target) as
132 | 
133 | -------------
134 | -- Pruning --
135 | -------------
136 | 
137 | -- |Prune a tree to have a maximum depth of decisions.
138 | maxDecisions :: Int -> DTree a b b -> DTree a b b
139 | maxDecisions i (Decision att as ts) =
140 |     if i == 0
141 |         then Result as
142 |         else Decision att as $ fmap (maxDecisions (i-1)) ts
143 | maxDecisions _ r = r
144 | 
145 | -- |Prune decisions using a predicate.
146 | prune :: (b -> Bool) -> DTree a b b -> DTree a b b
147 | prune _ (Result b) = Result b
148 | prune p (Decision att i ts) =
149 |     if p i
150 |         then Result i
151 |         else Decision att i (fmap (prune p) ts)
152 | 
153 | -------------
154 | -- Testing --
155 | -------------
156 | 
157 | -- |Compute the misclassification rate (MCR) of a particular decision tree
158 | -- on a data set.
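--
-- For example (a sketch, classifying integers by parity): 'even' predicts
-- [False,True,False,True] on [1,2,3,4]; against true labels
-- [False,True,True,True] it gets three out of four right, so:
--
-- > mcr even [1,2,3,4] [False,True,True,True]  -- 0.25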
159 | mcr :: Eq b => 160 | (a -> b) -- Classification algorithm 161 | -> [a] -- List of elements to be classified 162 | -> [b] -- List of correct classifications 163 | -> Double -- Misclassification rate 164 | mcr predfun as bs = 165 | let bsPred = map predfun as 166 | numCorrect = countIf id (zipWith (==) bs bsPred) 167 | numTotal = length as 168 | in fromIntegral (numTotal - numCorrect) / fromIntegral numTotal 169 | 170 | predfun xtrain ytrain xtest ytest = undefined -------------------------------------------------------------------------------- /src/AI/Learning/Example/Restaurant.hs: -------------------------------------------------------------------------------- 1 | module AI.Learning.Example.Restaurant where 2 | 3 | import Control.Monad 4 | import Control.Monad.Random 5 | import qualified Graphics.Gnuplot.Simple as G 6 | import System.IO.Unsafe 7 | 8 | import AI.Learning.Core 9 | import AI.Learning.DecisionTree 10 | import qualified AI.Learning.RandomForest as RF 11 | import AI.Util.Util 12 | 13 | data Patrons = Empty | Some | Full deriving (Show,Eq,Ord,Enum,Bounded) 14 | data Price = Cheap | Medium | Expensive deriving (Show,Eq,Ord,Enum,Bounded) 15 | data Type = French | Thai | Burger | Italian deriving (Show,Eq,Ord,Enum,Bounded) 16 | data Wait = None | Short | Med | Long deriving (Show,Eq,Ord,Enum,Bounded) 17 | 18 | data Restaurant = Restaurant { 19 | alt :: Bool, -- is there an alternative? 20 | bar :: Bool, -- is there a bar? 21 | fri :: Bool, -- is it a friday? 22 | hun :: Bool, -- are you hungry? 23 | pat :: Patrons, -- how many patrons are there? 24 | price :: Price, -- how cheap is it? 25 | rain :: Bool, -- is it raining? 26 | res :: Bool, -- do you have a reservation? 27 | food :: Type, -- what type of food is it? 28 | wait :: Wait, -- what is the wait? 29 | willWait :: Bool -- will you wait? 
30 | } deriving (Show,Eq,Ord) 31 | 32 | atts :: [Att Restaurant] 33 | atts = [ att alt "Alternative" 34 | , att bar "Bar" 35 | , att fri "Friday" 36 | , att hun "Hungry" 37 | , att pat "Patrons" 38 | , att price "Price" 39 | , att rain "Raining" 40 | , att res "Reservation" 41 | , att food "Food" 42 | , att wait "Wait" ] 43 | 44 | randomRestaurantNoisy :: RandomGen g => Float -> Rand g Restaurant 45 | randomRestaurantNoisy noise = do 46 | alt <- getRandom 47 | bar <- getRandom 48 | fri <- getRandom 49 | hun <- getRandom 50 | pat <- getRandomEnum 3 51 | price <- getRandomEnum 3 52 | rain <- getRandom 53 | res <- getRandom 54 | food <- getRandomEnum 4 55 | wait <- getRandomEnum 4 56 | let mkR ww = Restaurant alt bar fri hun pat price rain res food wait ww 57 | willWait = decide actualTree (mkR False) 58 | p <- getRandomR (0,1) 59 | return $ if p > noise 60 | then mkR willWait 61 | else mkR (not willWait) 62 | 63 | randomRestaurant :: RandomGen g => Rand g Restaurant 64 | randomRestaurant = do 65 | alt <- getRandom 66 | bar <- getRandom 67 | fri <- getRandom 68 | hun <- getRandom 69 | pat <- getRandomEnum 3 70 | price <- getRandomEnum 3 71 | rain <- getRandom 72 | res <- getRandom 73 | food <- getRandomEnum 4 74 | wait <- getRandomEnum 4 75 | let mkR ww = Restaurant alt bar fri hun pat price rain res food wait ww 76 | willWait = decide actualTree (mkR False) 77 | return (mkR willWait) 78 | 79 | randomDataSetNoisy :: RandomGen g => Float -> Int -> Rand g [Restaurant] 80 | randomDataSetNoisy noise n = replicateM n (randomRestaurantNoisy noise) 81 | 82 | randomDataSet :: RandomGen g => Int -> Rand g [Restaurant] 83 | randomDataSet n = replicateM n randomRestaurant 84 | 85 | -------------------- 86 | -- Model builders -- 87 | -------------------- 88 | 89 | treeBuilder :: [Restaurant] -> [Bool] -> Restaurant -> Bool 90 | treeBuilder as _ a = 91 | let tree = fitTree willWait atts as 92 | in decide tree a 93 | 94 | forestBuilder :: Int -> Int -> [Restaurant] -> [Bool] -> Restaurant -> Bool 95 | forestBuilder nTree nAtt as bs a = do 96 | let forest = unsafePerformIO $ evalRandIO $ 97 | RF.randomForest nTree nAtt willWait atts as 98 | in RF.decide forest a 99 | 100 | --------------------------------------- 101 | -- Demo of the decision tree library -- 102 | --------------------------------------- 103 | 104 | --runWithNoise :: RandomGen g => 105 | -- Builder Restaurant Bool 106 | -- -> Int 107 | -- -> Int 108 | -- -> Float 109 | -- -> Rand g Float 110 | --runWithNoise builder nTrain nTest noise = do 111 | -- xTrain <- randomDataSetNoisy noise nTrain 112 | -- xTest <- randomDataSetNoisy noise nTest 113 | -- let yTrain = map willWait xTrain 114 | -- yTest = map willWait xTest 115 | -- return (crossValidate builder xTrain yTrain xTest yTest) 116 | 117 | --runNoNoise :: RandomGen g => 118 | -- Int 119 | -- -> Int 120 | -- -> Rand g Float 121 | --runNoNoise nTrain nTest = do 122 | -- xTrain <- randomDataSet nTrain 123 | -- xTest <- randomDataSet nTest 124 | -- let yTrain = map willWait xTrain 125 | -- yTest = map willWait xTest 126 | -- return (crossValidate treeBuilder xTrain yTrain xTest yTest) 127 | 128 | --demo2 :: Float -> IO () 129 | --demo2 noise = do 130 | -- vals <- evalRandIO $ do 131 | -- let ns = [1..100] 132 | -- mcrs <- forM ns $ \n -> do 133 | -- sampleMcrs <- replicateM 100 $ runWithNoise treeBuilder n 100 noise 134 | -- return (mean sampleMcrs) 135 | -- return (zip ns $ map (*100) mcrs) 136 | 137 | -- let xlabel = G.XLabel "Size of test set" 138 | -- ylabel = G.YLabel "Misclassification 
Rate (%)" 139 | -- title = G.Title "Decision Tree Demo (Restaurants)" 140 | 141 | -- G.plotList [xlabel,ylabel,title] vals 142 | 143 | --demo1 :: IO () 144 | --demo1 = do 145 | -- vals <- evalRandIO $ do 146 | -- let ns = [1..100] 147 | -- mcrs <- forM ns $ \n -> do 148 | -- sampleMcrs <- replicateM 100 $ runNoNoise n 100 149 | -- return (mean sampleMcrs) 150 | -- return (zip ns $ map (*100) mcrs) 151 | 152 | -- let xlabel = G.XLabel "Size of test set" 153 | -- ylabel = G.YLabel "Misclassification Rate (%)" 154 | -- title = G.Title "Decision Tree Demo (Restaurants)" 155 | 156 | -- G.plotList [xlabel,ylabel,title] vals 157 | 158 | -------------------------------------- 159 | -- The decision tree in Figure 18.2 -- 160 | -------------------------------------- 161 | 162 | actualTree :: DTree Restaurant () Bool 163 | actualTree = do 164 | patrons <- attribute pat "Patrons" 165 | case patrons of 166 | Empty -> return False 167 | Some -> return True 168 | Full -> do 169 | time <- attribute wait "WaitTime" 170 | case time of 171 | None -> return True 172 | Short -> do 173 | hungry <- attribute hun "Hungry" 174 | if not hungry 175 | then return True 176 | else do 177 | alternative <- attribute alt "Alternative" 178 | if not alternative 179 | then return True 180 | else do 181 | raining <- attribute rain "Rain" 182 | return (if raining then True else False) 183 | Med -> do 184 | alternative <- attribute alt "Alternative" 185 | if not alternative 186 | then do 187 | reservation <- attribute res "Reservation" 188 | if reservation 189 | then return True 190 | else do 191 | hasBar <- attribute bar "Bar" 192 | return (if hasBar then True else False) 193 | else do 194 | friday <- attribute fri "Fri/Sat" 195 | return (if friday then True else False) 196 | Long -> return False 197 | 198 | ------------------------------------------ 199 | -- This is the example in AIMA Fig 18.3 -- 200 | ------------------------------------------ 201 | 202 | restaurants = 203 | [ Restaurant True False False True Some Expensive False True French None True 204 | , Restaurant True False False True Full Cheap False False Thai Med False 205 | , Restaurant False True False False Some Cheap False False Burger None True 206 | , Restaurant True False True True Full Cheap True False Thai Short True 207 | , Restaurant True False True False Full Expensive False True French Long False 208 | , Restaurant False True False True Some Medium True True Italian None True 209 | , Restaurant False True False False Empty Cheap True False Burger None False 210 | , Restaurant False False False True Some Medium True True Thai None True 211 | , Restaurant False True True False Full Cheap True False Burger Long False 212 | , Restaurant True True True True Full Expensive False True Italian Short False 213 | , Restaurant False False False False Empty Cheap False False Thai None False 214 | , Restaurant True True True True Full Cheap False False Burger Med True ] 215 | 216 | fittedTree = fitTree willWait atts restaurants 217 | 218 | 219 | -------------------------------------------------------------------------------- /src/AI/Learning/Example/Students.hs: -------------------------------------------------------------------------------- 1 | module AI.Learning.Example.Students where 2 | 3 | import qualified Data.Map as M 4 | import AI.Learning.DecisionTree 5 | 6 | data Student = Student { 7 | firstLastYear :: Bool, 8 | male :: Bool, 9 | worksHard :: Bool, 10 | drinks :: Bool, 11 | firstThisYear :: Bool } deriving (Eq,Ord,Show) 12 | 13 | students = 
[richard,alan,alison,jeff,gail,simon] 14 | 15 | richard = Student True True False True True 16 | alan = Student True True True False True 17 | alison = Student False False True False True 18 | jeff = Student False True False True False 19 | gail = Student True False True True True 20 | simon = Student False True True True False 21 | 22 | matthew = Student False True False True True 23 | mary = Student False False True True False 24 | 25 | -- |Attributes. 26 | atts = [ att firstLastYear "firstLastYear" 27 | , att male "male" 28 | , att worksHard "worksHard" 29 | , att drinks "drinks" ] 30 | 31 | --tree = fitTree firstThisYear atts students 32 | -------------------------------------------------------------------------------- /src/AI/Learning/LinearRegression.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FunctionalDependencies, FlexibleInstances, FlexibleContexts, BangPatterns #-} 2 | 3 | module AI.Learning.LinearRegression where 4 | 5 | import Data.List (foldl') 6 | import Numeric.LinearAlgebra 7 | import Numeric.LinearAlgebra.Util (ones) 8 | 9 | import AI.Util.Matrix 10 | 11 | {- 12 | 13 | A linear model should contain the following information: 14 | 15 | * Regression coefficients 16 | * Covariance matrix of regression coefficients 17 | * Residuals 18 | * Confidence intervals for regression coefficients 19 | * Confidence intervals for residuals 20 | * T-Statistic for regression coefficients 21 | * P-Values for regression coefficients 22 | * F-Statistic for regression 23 | * P-Value for regression 24 | * MSE 25 | * SST 26 | * SSE 27 | * PRESS Statistic 28 | * R squared 29 | * adjusted r squared 30 | * 31 | 32 | -} 33 | 34 | ------------- 35 | -- Options -- 36 | ------------- 37 | 38 | -- |Type recording what kind of linear model we want to build. Choose from 39 | -- ordinary least squares (OLS), ridge regression with parameter lambda 40 | -- and LASSO with parameter lambda. 41 | data LMType = OLS 42 | | Ridge Double 43 | | LASSO Double 44 | deriving (Show) 45 | 46 | ---- |Data type for linear regression options. 47 | data LMOpts = LMOpts { fitIntercept :: Bool 48 | , standardizeRegressors :: Bool } 49 | deriving (Show) 50 | 51 | ---- |Standard options for linear regression (used by 'regress'). 52 | stdLMOpts = LMOpts { fitIntercept = True 53 | , standardizeRegressors = False } 54 | 55 | --------------------- 56 | ---- Linear Models -- 57 | --------------------- 58 | 59 | -- |Data type for a linear model. It consists of the coefficient vector, 60 | -- together with options that specify how to transform any data that is used 61 | -- to make new predictions. 62 | data LinearModel = LM { coefs :: Vector Double 63 | , lmIntercept :: Bool 64 | , lmStandardize :: Bool 65 | , lmMean :: Vector Double 66 | , lmStd :: Vector Double } 67 | deriving (Show) 68 | 69 | -- |Statistics structure for a linear regression. 70 | data LMStats = LMStats { covBeta :: Maybe (Matrix Double) 71 | , betaCI :: Maybe (Matrix Double) 72 | , sst :: Double 73 | , sse :: Double 74 | , mse :: Double 75 | , rSquare :: Double 76 | , tBeta :: Maybe (Vector Double) 77 | , pBeta :: Maybe (Vector Double) 78 | , fRegression :: Maybe Double 79 | , pRegression :: Maybe Double } 80 | deriving (Show) 81 | 82 | type Predictor = Matrix Double 83 | type Response = Vector Double 84 | 85 | -- |Fit an ordinary least squares linear model to data. 
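--
-- A hypothetical usage sketch (assuming a design matrix @x :: Matrix Double@
-- and a response vector @y :: Vector Double@ are already in scope; the names
-- are illustrative only):
--
-- >>> let model = lm x y
-- >>> let yhat  = lmPredict model x
-- >>> lmStats model x y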
86 | lm :: Matrix Double -> Vector Double -> LinearModel 87 | lm = lmWith OLS stdLMOpts 88 | 89 | -- |Fit a ridge regression (least squares with quadratic penalty) to data. 90 | lmRidge :: Matrix Double -> Vector Double -> Double -> LinearModel 91 | lmRidge x y lambda = lmWith (Ridge lambda) stdLMOpts x y 92 | 93 | -- |Fit a linear model to data. The exact model fitted is specified by the 94 | -- 'LMType' argument, and the options are specified in 'LMOpts'. 95 | lmWith :: LMType -> LMOpts -> Matrix Double -> Vector Double -> LinearModel 96 | lmWith kind opts x y = LM { coefs = beta 97 | , lmIntercept = fitIntercept opts 98 | , lmStandardize = standardizeRegressors opts 99 | , lmMean = mu 100 | , lmStd = sigma } 101 | where 102 | (xx,mu,sigma) = lmPrepare opts x 103 | beta = case kind of 104 | OLS -> regress xx y 105 | Ridge a -> ridgeRegress xx y (fitIntercept opts) a 106 | LASSO a -> error "LASSO not implemented." 107 | 108 | -- |Prepare data according to the specified options structure. This may involve 109 | -- centering and standardizing the data, or adding a column of constants. 110 | lmPrepare :: LMOpts 111 | -> Matrix Double 112 | -> (Matrix Double, Vector Double, Vector Double) 113 | lmPrepare opts x = (x3,mu,sigma) 114 | where 115 | (x1,mu,sigma) = standardize x 116 | x2 = if standardizeRegressors opts then x1 else x 117 | x3 = if fitIntercept opts then addOnes x2 else x2 118 | 119 | --------------------------------- 120 | ---- Predict from linear model -- 121 | --------------------------------- 122 | 123 | -- |Make predictions from a linear model. 124 | lmPredict :: LinearModel -> Matrix Double -> Vector Double 125 | lmPredict model x = x2 <> beta 126 | where 127 | beta = coefs model 128 | xbar = lmMean model 129 | sigma = lmStd model 130 | x1 = if lmStandardize model 131 | then eachRow (\x -> (x - xbar) / sigma) x 132 | else x 133 | x2 = if lmIntercept model 134 | then addOnes x1 135 | else x1 136 | 137 | -- |Calculate statistics for a linear regression. 138 | lmStats :: LinearModel -> Matrix Double -> Vector Double -> LMStats 139 | lmStats model x y = 140 | let ybar = constant (mean y) (dim y) 141 | yhat = lmPredict model x 142 | residuals = y - yhat 143 | covBeta = Nothing 144 | betaCI = Nothing 145 | sst = sumVector $ (y - ybar) ^ 2 146 | sse = sumVector $ residuals ^ 2 147 | mse = sse / fromIntegral (rows x) 148 | rSq = 1 - sse / sst 149 | tBeta = Nothing 150 | pBeta = Nothing 151 | fReg = Nothing 152 | pReg = Nothing 153 | in LMStats covBeta betaCI sst sse mse rSq tBeta pBeta fReg pReg 154 | 155 | mseEvalFun :: Response -> Response -> Double 156 | mseEvalFun actual predicted = mean $ (actual - predicted) ^ 2 157 | 158 | lmPredFun :: Predictor -> Response -> Predictor -> Response -> Double 159 | lmPredFun xtrain ytrain xtest ytest = mseEvalFun ytest ypred 160 | where 161 | ypred = lmPredict model xtest 162 | model = lm xtrain ytrain 163 | 164 | -------------------------- 165 | ---- Perform Regression -- 166 | -------------------------- 167 | 168 | ---- |Regress a vector y against a matrix of predictors x, with the specified 169 | ---- options. 170 | regress :: Matrix Double -- X 171 | -> Vector Double -- y 172 | -> Vector Double -- beta 173 | regress x y 174 | | rows x /= dim y = error "Inconsistent dimensions -- REGRESS" 175 | | otherwise = let (_,n) = size x 176 | (_,r) = qr x 177 | rr = takeRows n r 178 | in (trans rr <> rr) <\> trans x <> y 179 | 180 | -- |Ridge regression. 
This adds a penalty term to OLS regression, which 181 | -- discourages large coefficient values to prevent overfitting. 182 | ridgeRegress :: Matrix Double -- X 183 | -> Vector Double -- y 184 | -> Bool -- useConst? 185 | -> Double -- lambda 186 | -> Vector Double -- beta 187 | ridgeRegress x y useConst lambda 188 | | rows x /= dim y = error "Inconsistent dimensions -- RIDGEREGRESS" 189 | | otherwise = let (_,n) = size x 190 | (_,r) = qr x 191 | rr = takeRows n r 192 | ww = if useConst 193 | then diag $ join [0, constant 1 (n-1)] 194 | else ident n 195 | in (trans rr <> rr + lambda `scale` ww) <\> trans x <> y 196 | 197 | ----------------- 198 | ---- Utilities -- 199 | ----------------- 200 | 201 | -- |De-mean a sample. 202 | demean :: Matrix Double -> Matrix Double 203 | demean x = eachRow (subtract $ mean x) x 204 | 205 | -- |Standardize a sample to have zero mean and unit variance. Returns a triple 206 | -- consisting of the standardized sample, the sample mean and the sample 207 | -- standard deviation. 208 | standardize :: Matrix Double -> (Matrix Double, Vector Double, Vector Double) 209 | standardize m = (eachRow (\x -> (x - mu) / sigma) m, mu, sigma) 210 | where mu = mean m 211 | sigma = std m 212 | 213 | -- |Standardize a sample to have zero mean and unit variance. This variant of 214 | -- the function discards the mean and standard deviation vectors, only 215 | -- returning the standardized sample. 216 | standardize_ :: Matrix Double -> Matrix Double 217 | standardize_ x = a where (a,_,_) = standardize x 218 | 219 | ------------------------- 220 | -- Mean, Variance etc. -- 221 | ------------------------- 222 | 223 | class Mean a b | a -> b where 224 | mean :: a -> b 225 | 226 | class Floating b => Variance a b | a -> b where 227 | var :: a -> b 228 | var x = y * y where y = std x 229 | 230 | std :: a -> b 231 | std x = sqrt $ var x 232 | 233 | instance Mean (Vector Double) Double where 234 | mean v = sumVector v / fromIntegral (dim v) 235 | 236 | instance Variance (Vector Double) Double where 237 | var v = mean $ (v - constant vbar (dim v)) ^ 2 238 | where vbar = mean v 239 | 240 | instance Mean (Matrix Double) (Vector Double) where 241 | mean m = fromList $ mapCols mean m 242 | 243 | instance Variance (Matrix Double) (Vector Double) where 244 | var m = fromList $ mapCols var m 245 | 246 | 247 | -------------------------------------------------------------------------------- /src/AI/Learning/LogisticRegression.hs: -------------------------------------------------------------------------------- 1 | module AI.Learning.LogisticRegression where 2 | 3 | import Numeric.LinearAlgebra 4 | import Numeric.LinearAlgebra.Util 5 | 6 | import AI.Learning.Core 7 | import AI.Util.Matrix 8 | 9 | -- |Multivariate logistic regression. Given a vector /y/ of target variables and 10 | -- a design matrix /x/, this function fits the parameters /theta/ of a 11 | -- logistic model, that is, 12 | -- 13 | -- y = sigmoid( x_1 * theta_1 + ... + x_n * theta_n ) 14 | -- 15 | -- Typically the values in the vector /y/ are either boolean (i.e. 0/1) or they 16 | -- represent frequency of observations, i.e. they are values between 0.0 and 17 | -- 1.0. 18 | -- 19 | -- The function fits /theta/ by numerically maximizing the likelihood function. 20 | -- It may be subject to overfit or non-convergence in the case where the number 21 | -- of observations is small or the predictors are highly correlated. 
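--
-- A hypothetical usage sketch (assuming targets @y@ and design matrix @x@
-- are in scope; note that the argument order is @y@ first, then @x@, and
-- that 'sigmoid' comes from "AI.Learning.Core"):
--
-- >>> let theta = lr y x
-- >>> let probs = sigmoid (x <> theta)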
22 | lr :: Vector Double -- targets (y) 23 | -> Matrix Double -- design matrix (x) 24 | -> Vector Double -- coefficient vector (theta) 25 | lr y x = lrHelper (lrLogLikelihood y x) theta0 26 | where theta0 = constant 0 (cols x) 27 | 28 | -- |Regularized logistic regression with quadratic penalty on the coefficients. 29 | -- You should standardize the coefficients of the design matrix before using 30 | -- this function, as the regularization procedure is sensitive to the scale 31 | -- of the predictors. See also 'lr'. 32 | lrRegularized :: Vector Double -- targets (y) 33 | -> Matrix Double -- design matrix (x) 34 | -> Bool -- first column of design matrix is all ones? 35 | -> Double -- regularization constant (lambda) 36 | -> Vector Double -- coefficient vector (theta) 37 | lrRegularized y x useConst lambda = lrHelper costfun theta0 38 | where 39 | costfun = lrLogLikRegularized y x useConst lambda 40 | theta0 = constant 0 (cols x) 41 | 42 | -- |Helper function for logistic regressions. The first argument is a function 43 | -- that returns the cost and gradient for a given vector of parameters, and 44 | -- the second is the initial set of parameters to use. 45 | lrHelper :: (Vector Double -> (Double, Vector Double)) -> Vector Double -> Vector Double 46 | lrHelper fun theta0 = minimizeS cost grad theta0 47 | where cost = negate . fst . fun -- negate because we call minimize 48 | grad = negate . snd . fun 49 | 50 | -- |Cost function and derivative for logistic regression. This is maximized when 51 | -- fitting parameters for the regression. 52 | lrLogLikelihood :: Vector Double -- targets (y) 53 | -> Matrix Double -- design matrix (x) 54 | -> Vector Double -- coefficients (theta) 55 | -> (Double, Vector Double) -- (value, derivative) 56 | lrLogLikelihood y x theta = (cost, grad) 57 | where 58 | m = fromIntegral (rows x) -- For computing average 59 | h = sigmoid (x <> theta) -- Predictions for y 60 | cost = sumVector (y * log h + (1-y) * log (1-h)) / m 61 | grad = (1/m) `scale` (y - h) <> x 62 | 63 | -- |Cost function and derivative for regularized logistic regression. This is 64 | -- maximized when fitting parameters for the regression. 65 | lrLogLikRegularized :: Vector Double -- targets (y) 66 | -> Matrix Double -- design matrix (x) 67 | -> Bool -- is first col all ones? 68 | -> Double -- regularization const (lambda) 69 | -> Vector Double -- coefficients (theta) 70 | -> (Double, Vector Double) -- (value, derivative) 71 | lrLogLikRegularized y x useConst lambda theta = (cost, grad) 72 | where 73 | m = fromIntegral (rows x) 74 | (c,g) = lrLogLikelihood y x theta 75 | theta' = if useConst then join [0, dropVector 1 theta] else theta 76 | cost = c - (lambda / (2 * m)) * norm theta' ^ 2 77 | grad = g - (lambda / m) `scale` theta' 78 | 79 | ------------- 80 | -- Testing -- 81 | ------------- 82 | 83 | test n k lambda = do 84 | x <- randn n k -- design matrix 85 | e <- flatten `fmap` randn n 1 -- errors 86 | let theta = fromList $ 1 : replicate (k-1) 0 87 | h = sigmoid $ x <> theta + e 88 | y = (\i -> if i > 0.5 then 1 else 0) `mapVector` h 89 | theta_est1 = lr y x 90 | theta_est2 = lrRegularized y x False lambda 91 | --putStrLn $ "[y, h, x]" 92 | --disp 3 $ takeRows 10 $ fromColumns [y, h] ! 
x 93 | --putStrLn $ "[y, h, ypred]" 94 | --disp 3 $ takeRows 10 $ fromColumns [y, h, sigmoid $ x <> theta_est1] 95 | putStrLn $ " Number of observations: " ++ show n 96 | putStrLn $ " Number of regressors: " ++ show k 97 | putStrLn $ " Actual theta: " ++ show theta 98 | putStrLn $ "Estimated theta (unregularized): " ++ show theta_est1 99 | putStrLn $ " Estimated theta (regularized): " ++ show theta_est2 100 | 101 | -------------------------------------------------------------------------------- /src/AI/Learning/NeuralNetwork.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ScopedTypeVariables #-} 2 | 3 | module AI.Learning.NeuralNetwork 4 | ( 5 | -- * Representation 6 | NeuralNetwork 7 | -- * Prediction and Training 8 | , nnPredict 9 | , nnTrain 10 | , nnTrainIO 11 | ) where 12 | 13 | import Control.Monad.Random hiding (fromList) 14 | import Numeric.LinearAlgebra 15 | import Numeric.LinearAlgebra.Util 16 | import System.IO.Unsafe 17 | 18 | import AI.Learning.Core 19 | import AI.Util.Matrix 20 | 21 | ----------------------- 22 | -- NN Representation -- 23 | ---------------------- 24 | 25 | -- |Representation for a single-hidden-layer neural network. If the network 26 | -- has K input nodes, H hidden nodes and L output layers then the dimensions 27 | -- of the matrices theta0 and theta1 are 28 | -- 29 | -- * size Layer1 = (K+1) x H 30 | -- * size Layer2 = (H+1) x L 31 | -- 32 | -- the (+1)s account for the addition of bias nodes in the input layer and 33 | -- the hidden layer. Therefore the total number of parameters is 34 | -- H * (K + L) + H + L 35 | data NeuralNetwork = NN (Matrix Double) (Matrix Double) 36 | 37 | -- |Three-tuple describing the shape of a network: (input, hidden, output). 38 | type NNShape = (Int,Int,Int) 39 | 40 | instance Show NeuralNetwork where 41 | show (NN theta0 theta1) = "Neural Net:\n\n" ++ dispf 3 theta0 ++ 42 | "\n" ++ dispf 3 theta1 43 | 44 | fromVector :: NNShape -> Vector Double -> NeuralNetwork 45 | fromVector (k,h,l) vec = NN theta0 theta1 46 | where theta0 = reshape h $ takeVector ((k + 1) * h) vec 47 | theta1 = reshape l $ dropVector ((k + 1) * h) vec 48 | 49 | toVector :: Matrix Double -> Matrix Double -> Vector Double 50 | toVector theta0 theta1 = join [flatten theta0, flatten theta1] 51 | 52 | ---------------------- 53 | -- NN Train/Predict -- 54 | ---------------------- 55 | 56 | -- |Make predictions using a neural network. 57 | nnPredict :: NeuralNetwork -> Matrix Double -> Matrix Double 58 | nnPredict nn x = h where (_,_,h) = nnForwardProp nn x 59 | 60 | -- |Train a neural network from a training set. Note that you must supply 61 | -- an initial vector of weights. Supplying initial weights all equal to 62 | -- zero will generally give poor performance. 63 | nnTrain :: Int -- number of hidden neurons 64 | -> Matrix Double -- y 65 | -> Matrix Double -- x 66 | -> Double -- lambda 67 | -> Vector Double -- initial weights 68 | -> NeuralNetwork 69 | nnTrain h y x lambda initialVec = fromVector shape vec 70 | where shape = (cols x, h, cols y) 71 | vec = minimizeS cost grad initialVec 72 | cost = fst . nnCostGradient shape y x lambda 73 | grad = snd . nnCostGradient shape y x lambda 74 | 75 | -- |Train a neural network from a training set. 76 | -- 77 | -- Note that this is implemented as an IO action, because it randomly sets the 78 | -- weights in the network before performing numerical optimization (this is to 79 | -- break the symmetry that exists in the network when all weights are zero). 
80 | -- As a result, it is possible to get different results when training the 81 | -- network multiple times with the same data. 82 | nnTrainIO :: Int -- number of hidden neurons 83 | -> Matrix Double -- y 84 | -> Matrix Double -- x 85 | -> Double -- lambda 86 | -> IO NeuralNetwork 87 | nnTrainIO h y x lambda = nnTrain h y x lambda `fmap` initialVec (cols x, h, cols y) 88 | 89 | -- |Choose initial random weights for a neural network. 90 | initialVec :: NNShape -> IO (Vector Double) 91 | initialVec (k,h,l) = do 92 | let len = h * (k + l) + h + l 93 | xs <- getRandomRs (0.0, 0.01) 94 | return . fromList $ take len xs 95 | 96 | ------------------------------ 97 | -- Forward/Back Propagation -- 98 | ------------------------------ 99 | 100 | -- |Perform forward propagation through a neural network, returning the matrices 101 | -- created in the process. 102 | nnForwardProp :: NeuralNetwork -- neural net 103 | -> Matrix Double -- design matrix (x) 104 | -> (Matrix Double, Matrix Double, Matrix Double) -- results of fwd prop 105 | nnForwardProp (NN theta0 theta1) x = (a0,a1,a2) 106 | where a0 = addOnes $ x 107 | a1 = addOnes $ sigmoid (a0 <> theta0) 108 | a2 = sigmoid (a1 <> theta1) 109 | 110 | -- |Perform backward propagiation through a neural network. You must supply the 111 | -- target values and the results of forward propagation for each layer, and 112 | -- the function returns the gradient matrices for the neural network. 113 | nnBackProp :: NeuralNetwork -- neural net 114 | -> Matrix Double -- target (y) 115 | -> (Matrix Double, Matrix Double, Matrix Double) -- results of fwd prop 116 | -> (Matrix Double, Matrix Double) -- gradient (delta0, delta1) 117 | nnBackProp (NN _ theta1) y (a0,a1,a2) = (dropColumns 1 delta0, delta1) 118 | where 119 | d2 = a2 - y 120 | d1 = (d2 <> trans theta1) * a1 * (1 - a1) 121 | delta0 = trans a0 <> d1 122 | delta1 = trans a1 <> d2 123 | 124 | -- |Perform back and forward propagation through a neural network, returning the 125 | -- final predictions (variable /a2/) and the gradient matrices (variables 126 | -- /delta0/ and /delta1/) produced. 127 | nnFwdBackProp :: NeuralNetwork -> Matrix Double -> Matrix Double -> (Matrix Double, Matrix Double, Matrix Double) 128 | nnFwdBackProp nn@(NN theta0 theta1) y x = (a2, delta0, delta1) 129 | where 130 | (a0,a1,a2) = nnForwardProp nn x 131 | (delta0,delta1) = nnBackProp nn y (a0,a1,a2) 132 | 133 | -- |Compute the penalty function and gradient vector for a neural network 134 | -- given a training set. 135 | nnCostGradient :: NNShape -- (K,H,L) 136 | -> Matrix Double -- targets (y) 137 | -> Matrix Double -- design matrix (x) 138 | -> Double -- regularization parameter (lambda) 139 | -> Vector Double -- neural network 140 | -> (Double, Vector Double) -- (cost, gradient) 141 | nnCostGradient shape y x lambda vec = (cost, grad) 142 | where 143 | m = fromIntegral (rows x) 144 | nn@(NN theta0 theta1) = fromVector shape vec 145 | (h, delta0, delta1) = nnFwdBackProp nn y x 146 | 147 | cost = (cost1 + cost2) / m 148 | cost1 = negate $ sumMatrix $ y * log h + (1-y) * log (1-h) 149 | cost2 = lambda/2 * (normMatrix theta0 + normMatrix theta1) 150 | 151 | grad = (1/m) `scale` (grad1 + grad2) 152 | grad1 = toVector delta0 delta1 153 | grad2 = lambda `scale` toVector (insertNils theta0) (insertNils theta1) 154 | 155 | normMatrix m = sumMatrix $ (dropRows 1 m) ^ 2 156 | insertNils m = vertcat [0, dropRows 1 m] 157 | 158 | nnCost shape y x lambda = fst . nnCostGradient shape y x lambda 159 | nnGrad shape y x lambda = snd . 
nnCostGradient shape y x lambda 160 | 161 | -- |Use central differencing to compute an approximation to the gradient 162 | -- vector for a neural network. This is mainly used for checking the 163 | -- implementation of backprop. 164 | nnGradApprox :: NNShape -> Matrix Double -> Matrix Double -> Double -> Vector Double -> Vector Double 165 | nnGradApprox shape y x lambda vec = fromList $ g `map` [0..n-1] 166 | where 167 | h = 1e-6 168 | n = dim vec 169 | f v = nnCost shape y x lambda v 170 | g i = (f (vec + e i) - f (vec - e i)) / (2*h) 171 | e i = fromList $ replicate i 0 ++ [h] ++ replicate (n-i-1) 0 172 | 173 | ------------- 174 | -- Testing -- 175 | ------------- 176 | 177 | testNN :: NeuralNetwork 178 | testNN = NN t0 t1 179 | where t0 = fromLists [[11.9934, -5.1396], [-7.7162, 10.1512], [-7.668, 10.1835]] 180 | t1 = fromLists [[-16.8806], [10.0445], [8.7476]] 181 | 182 | testFwdProp :: IO () 183 | testFwdProp = do 184 | putStrLn "***\nCompare forward propagation to the MATLAB implementation.\n" 185 | let theta0 = fromLists [[0], [-10]] 186 | theta1 = fromLists [[5],[-10]] 187 | nn = NN theta0 theta1 188 | x = fromLists [[-1],[0],[1]] 189 | -- sigmoid [10, 0, -10] = [1, 0.5, 0] 190 | -- sigmoid [-5, 0, 5] = [0.0, 0.5, 1] 191 | let y = nnPredict nn x 192 | putStrLn "Predictions (should be roughly 0.0, 0.5, 1.0)" 193 | disp 2 y 194 | 195 | testBackProp :: IO () 196 | testBackProp = do 197 | putStrLn "***\nCompare back propagation to the MATLAB implementation.\n" 198 | let nn = testNN 199 | x = fromLists [[0, 0], [0, 1], [1, 0], [1, 1]] 200 | y = fromLists [[0], [1], [1], [0]] 201 | (h, delta0, delta1) = nnFwdBackProp nn y x 202 | putStrLn "Predictions (should be 0.0011, 0.8487, 0.8476, 0.0004)" 203 | disp 4 h 204 | putStrLn "Delta 0 (should be -0.0401, -0.0171, -0.0205, -0.0088, -0.0195, -0.0084)" 205 | disp 4 delta0 206 | putStrLn "Delta 1 (should be -0.3022, -0.2985, -0.3013)" 207 | disp 4 delta1 208 | 209 | testCostGradient :: IO () 210 | testCostGradient = do 211 | putStrLn "***\nTest cost/gradient against the MATLAB implementation.\n" 212 | let nn@(NN theta0 theta1) = testNN 213 | x = fromLists [[0, 0], [0, 1], [1, 0], [1, 1]] 214 | y = fromLists [[0], [1], [1], [0]] 215 | lambda = 1e-4 216 | (cost, grad) = nnCostGradient (2,2,1) y x lambda (toVector theta0 theta1) 217 | putStrLn "Cost (should be around 0.0890)" 218 | print cost 219 | putStrLn "Gradient (should be -0.0100, -0.0043, -0.0053, -0.0019, -0.0051, -0.0019, -0.0755, -0.0744, -0.0751)" 220 | disp 4 (asColumn grad) 221 | 222 | test :: Int -> Double -> IO () 223 | test n lambda = do 224 | putStrLn "***\nLearning XOR function.\n" 225 | x <- rand n 2 226 | e <- fmap (0.01*) (rand n 1) 227 | let y = mapMatrix (\x -> if x > 0.5 then 1.0 else 0.0) (xor x) 228 | nn <- nnTrainIO 4 y x lambda 229 | let ypred = nnPredict nn x 230 | -- Show predictions 231 | putStrLn "Predictions:" 232 | disp 2 $ takeRows 10 $ horzcat [x, y, ypred] 233 | -- Show neural net 234 | print nn 235 | -- Final test; should approximately compute xor function 236 | let xx = fromLists [[0,0],[0,1],[1,0],[1,1]] 237 | yy = nnPredict nn xx 238 | putStrLn "Exclusive or:" 239 | disp 2 $ horzcat [xx,yy] 240 | 241 | xor :: Matrix Double -> Matrix Double 242 | xor x = let [u,v] = toColumns x in asColumn (u + v - 2 * u * v) 243 | 244 | -------------------------------------------------------------------------------- /src/AI/Learning/Perceptron.hs: -------------------------------------------------------------------------------- 1 | module AI.Learning.Perceptron where 2 | 3 
| import Numeric.LinearAlgebra 4 | import Numeric.LinearAlgebra.Util 5 | 6 | perceptronPredict :: Vector Double -> Matrix Double -> Vector Double 7 | perceptronPredict weights x = step (x <> weights) 8 | 9 | perceptronCost :: Vector Double -> Vector Double -> Vector Double -> Double 10 | perceptronCost y yhat = undefined 11 | 12 | ---------------------- 13 | -- Gradient Descent -- 14 | ---------------------- 15 | 16 | gradientDescent :: (Vector Double -> Vector Double) -- f 17 | -> Vector Double -- x0 18 | -> Double -- alpha 19 | -> Double -- tol 20 | -> Vector Double 21 | gradientDescent g x0 alpha tol = go x0 (fun x0) 22 | where 23 | go x x' = if converged x x' 24 | then x' 25 | else go x' (fun x') 26 | converged a b = norm b / norm a - 1 < tol 27 | fun x = gradientDescentStep g alpha x 28 | 29 | gradientDescentStep :: (Vector Double -> Vector Double) -> Double -> Vector Double -> Vector Double 30 | gradientDescentStep g alpha x = x - alpha `scale` g x 31 | -------------------------------------------------------------------------------- /src/AI/Learning/RandomForest.hs: -------------------------------------------------------------------------------- 1 | module AI.Learning.RandomForest where 2 | 3 | import Control.Monad 4 | import Control.Monad.Random 5 | 6 | import AI.Learning.DecisionTree as D 7 | import AI.Util.Util 8 | 9 | -- |A forest is a list of decision trees. 10 | newtype Forest a b = Forest [D.DTree a () b] deriving (Show) 11 | 12 | -- |Create a new random forest. This function repeatedly selects a random 13 | -- subset of the attributes to split on, and fits an unpruned tree using 14 | -- a bootstrap sample of the observations. This is repeated many times, 15 | -- creating a /forest/ of decision trees. 16 | randomForest :: (Ord a, Ord b, RandomGen g) => 17 | Int -- Number of trees in the forest 18 | -> Int -- Number of attributes per tree 19 | -> (a -> b) -- Target attribute 20 | -> [Att a] -- Attributes to classify on 21 | -> [a] -- List of observations 22 | -> Rand g (Forest a b) 23 | randomForest nTree nAtt target atts as = fmap Forest (replicateM nTree go) 24 | where 25 | go = do atts' <- selectMany nAtt atts 26 | as' <- sampleWithReplacement (length as) as 27 | return $ D.fitTree target atts' as' 28 | 29 | -- |Use a forest to classify a new example. We run the example through each 30 | -- of the trees in the forest, and choose the most common classification. 31 | decide :: Ord b => Forest a b -> a -> b 32 | decide (Forest trees) a = mode $ map (`D.decide` a) trees -------------------------------------------------------------------------------- /src/AI/Logic/Core.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleContexts #-} 2 | 3 | -- |This module defines types and type classes that are used in the other Logic 4 | -- modules. 5 | module AI.Logic.Core ( 6 | -- * Expression 7 | Expr(..) 8 | -- * Knowledge Base 9 | , KB(..) 10 | -- * Error Handling 11 | , ThrowsError 12 | , LogicError(..) 
13 | ) where 14 | 15 | import Control.Monad.Error 16 | import Control.Monad.State 17 | import Data.Map (Map) 18 | 19 | import AI.Util.Util 20 | 21 | ---------------- 22 | -- Data Types -- 23 | ---------------- 24 | 25 | type ThrowsError = Either LogicError 26 | 27 | data LogicError = ParseError 28 | | InvalidExpression 29 | | UnknownCommand 30 | | DefaultError deriving (Show) 31 | 32 | instance Error LogicError where noMsg = DefaultError 33 | 34 | ----------------- 35 | -- Expressions -- 36 | ----------------- 37 | 38 | -- |Class for logical expressions, supporting only a single method 'parseExpr' 39 | -- which parses a 'String' into an expression. 40 | class (Show p, Eq p) => Expr p where 41 | -- |Parse a 'String' as a logical expression. 42 | parseExpr :: String -> ThrowsError p 43 | 44 | -------------------- 45 | -- Knowledge Base -- 46 | -------------------- 47 | 48 | -- |A class for knowledge bases, supporting operations 'tell' (to tell a new 49 | -- fact), 'ask' (to query the knowledge base) and 'retract' (to un-tell a 50 | -- fact). Instances of 'KB' are used in routines in the @Logic.Interative@ 51 | -- package. 52 | class (Expr p, Show t) => KB k p t where 53 | -- |Returns an empty knowledge base. 54 | empty :: k p t 55 | 56 | -- |Store a new fact in the knowledge base. 57 | tell :: k p t -> p -> k p t 58 | 59 | -- |Ask whether a particular statement is entailed by the knowledge base. 60 | ask :: k p t -> p -> Bool 61 | 62 | -- |Given a statement containing variables, return an assignment of 63 | -- variables to terms that satisfies the statement. 64 | askVars :: k p t -> p -> [Map String t] 65 | 66 | -- |Retract a fact from the knowledge base. 67 | retract :: k p t -> p -> k p t 68 | 69 | -- |List the propositions stored in the knowledge base. 70 | axioms :: k p t -> [p] 71 | 72 | -- |Ask if the knowledge base contains a particular fact. 73 | contains :: k p t -> p -> Bool 74 | contains kb p = p `elem` axioms kb 75 | 76 | -------------------------------------------------------------------------------- /src/AI/Logic/Interactive.hs: -------------------------------------------------------------------------------- 1 | module AI.Logic.Interactive where 2 | 3 | import Control.Monad 4 | import Control.Monad.Error 5 | import Control.Monad.State 6 | 7 | import qualified Data.List as L 8 | import qualified Data.Map as M 9 | 10 | import AI.Logic.Core 11 | import AI.Util.Util 12 | 13 | import qualified AI.Logic.Propositional as P 14 | import qualified AI.Logic.FOL as F 15 | 16 | ----------- 17 | -- Types -- 18 | ----------- 19 | 20 | type IOThrowsError = ErrorT LogicError IO 21 | type Logic k = StateT k IOThrowsError 22 | 23 | -- |Run a computation of type `Logic k'. The computation represents a live 24 | -- interaction with a knowledge base. We don't care about the result - we 25 | -- just want to get the side effects from storing premises in the knowledge 26 | -- base and querying it for new information. 27 | runLogic :: Logic k a -> k -> IO () 28 | runLogic c s = ignoreResult $ runErrorT $ evalStateT c s 29 | 30 | ---------------------- 31 | -- Interactive Code -- 32 | ---------------------- 33 | 34 | -- |Start an interaction with a propositional logic theorem prover that uses 35 | -- a resolution algorithm to do inference. An example interaction might look 36 | -- like this: 37 | -- 38 | -- >>> runProp 39 | -- Propositional Logic Theorem Prover (Resolution) 40 | -- >>> tell p 41 | -- >>> tell p=>q 42 | -- >>> show 43 | -- 0. p 44 | -- 1. 
(~p | q) 45 | -- >>> ask q 46 | -- Entailed: q 47 | -- >>> quit 48 | runProp :: IO () 49 | runProp = do 50 | putStrLn "Propositional Logic Theorem Prover (Resolution)" 51 | runLogic loop (empty :: P.PropKB P.PLExpr Bool) 52 | 53 | -- |Start an interaction with a propositional logic theorem prover that uses 54 | -- truth tables to do inference. See the documentation for 'runProp' for more 55 | -- information. 56 | runTruthTable :: IO () 57 | runTruthTable = do 58 | putStrLn "Propositional Logic Theorem Prover (Truth Table)" 59 | runLogic loop (empty :: P.TruthTableKB P.PLExpr Bool) 60 | 61 | -- |Start an interaction with a theorem prover that uses Horn clause and forward 62 | -- chaining to do inference. See the documentation for 'runProp' for more 63 | -- information. 64 | runForwardChaining :: IO () 65 | runForwardChaining = do 66 | putStrLn "Proposition Logic Theorem Prover (Forward Chaining)" 67 | runLogic loop (empty :: P.DefClauseKB P.DefiniteClause Bool) 68 | 69 | -- |Start an interaction with a first-order logic theorem prover that uses 70 | -- forward chaining. A typical interaction might look like this: 71 | -- 72 | -- >>> runFOL 73 | -- First Order Logic Theorem Prover (Forward Chaining) 74 | -- >>> tell Man(x)=>Mortal(x) 75 | -- >>> tell Man(Socrates) 76 | -- >>> show 77 | -- 0. Man(x) => Mortal(x) 78 | -- 1. Man(Socrates) 79 | -- >>> ask Mortal(Socrates) 80 | -- Entailed: Mortal(Socrates) 81 | -- >>> sat Mortal(x) 82 | -- Valid assignments: 83 | -- x: Socrates 84 | -- >>> quit 85 | runFOL :: IO () 86 | runFOL = do 87 | putStrLn "First Order Logic Theorem Prover (Forward Chaining)" 88 | runLogic loop (empty :: F.FCKB F.DefiniteClause F.Term) 89 | 90 | -- |The input/output loop for a theorem prover. We repeatedly ask for input 91 | -- from the user, and then dispatch on the result, until the user enters the 92 | -- command @"quit"@. 93 | loop :: KB k p t => Logic (k p t) () 94 | loop = untilM (== "quit") (liftIO readPrompt) (trapError . dispatch) 95 | 96 | -- |Decide what to do with user input to a theorem prover. 97 | dispatch :: KB k p t => String -> Logic (k p t) () 98 | dispatch str = case cmd of 99 | "show" -> get >>= (\kb -> liftIO $ showPremises kb) 100 | "help" -> liftIO showHelp 101 | "tell" -> parse rest >>= tellKB 102 | "ask" -> parse rest >>= askKB 103 | "sat" -> parse rest >>= satisfyKB 104 | "retract" -> parse rest >>= retractKB 105 | "clear" -> clear 106 | "" -> return () 107 | _ -> liftIO unknown 108 | where 109 | (cmd,rest) = break (== ' ') str 110 | 111 | -- |Parse an expression entered by the user to be passed into either `tellKB', 112 | -- `askKB' or `retractKB'. 113 | parse :: KB k p t => String -> Logic (k p t) p 114 | parse str = case parseExpr (strip str) of 115 | Left _ -> liftIO (putStrLn "***parse error") >> throwError ParseError 116 | Right p -> return p 117 | 118 | -- |Store a new premise in the knowledge base. 119 | tellKB :: KB k p t => p -> Logic (k p t) () 120 | tellKB expr = modify (\kb -> tell kb expr) 121 | 122 | -- |Query the knowledge base. 123 | askKB :: KB k p t => p -> Logic (k p t) () 124 | askKB expr = do 125 | kb <- get 126 | --liftIO $ showPremises kb 127 | if ask kb expr 128 | then liftIO $ putStrLn $ "Entailed: " ++ show expr 129 | else liftIO $ putStrLn $ "Not entailed: " ++ show expr 130 | 131 | -- |Find all assignments that satisfy a given query. 
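--
-- Continuing the 'runFOL' example above, this is the handler invoked by the
-- @sat@ command:
--
-- >>> sat Mortal(x)
-- Valid assignments:
--   x: Socrates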
132 | satisfyKB :: KB k p t => p -> Logic (k p t) () 133 | satisfyKB expr = do 134 | kb <- get 135 | --liftIO $ showPremises kb 136 | case askVars kb expr of 137 | [] -> liftIO $ putStrLn "No assignments satisfy expression" 138 | xs -> liftIO $ putStrLn "Valid assignments:" >> mapM_ display1 xs 139 | 140 | display1 :: Show t => M.Map String t -> IO () 141 | display1 xs = putStrLn $ (" "++) $ L.intercalate ", " $ map func $ M.toList xs 142 | where 143 | func (k,v) = k ++ ": " ++ show v 144 | 145 | -- |Remove a premise from the knowledge base. 146 | retractKB :: KB k p t => p -> Logic (k p t) () 147 | retractKB expr = do 148 | kb <- get 149 | if kb `contains` expr 150 | then modify $ \kb -> retract kb expr 151 | else liftIO $ putStrLn $ "***expression " ++ show expr ++ " not in KB" 152 | 153 | -- |Empty the knowledge base of all previously entered premises. 154 | clear :: KB k p t => Logic (k p t) () 155 | clear = put empty 156 | 157 | -- |Display a list of all premises currently stored in the knowledge base. 158 | showPremises :: KB k p t => k p t -> IO () 159 | showPremises kb = forM_ (enumerate $ axioms kb) $ 160 | \(n,p) -> putStrLn (" " ++ show n ++ ". " ++ show p) 161 | 162 | -- |IO routine to deal with unrecognised commands. 163 | unknown :: IO () 164 | unknown = putStrLn "***unknown command" 165 | 166 | -- |Show a useful help message for an interaction with a theorem prover. 167 | showHelp :: IO () 168 | showHelp = do 169 | putStrLn " tell

<p> Store proposition <p> in the knowledge base" 170 | putStrLn " retract <p> Remove proposition <p> from the knowledge base" 171 | putStrLn " ask <p> Ask whether <p> is entailed by the knowledge base" 172 | putStrLn " sat <p> Find assignments that satisfy query <p>
" 173 | putStrLn " clear Remove all propositions from the knowledge base" 174 | putStrLn " show Display the current state of the knowledge base" 175 | putStrLn " help Show this help file" 176 | putStrLn " quit Exit the PLTP" 177 | -------------------------------------------------------------------------------- /src/AI/Probability/Bayes.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE BangPatterns #-} 2 | 3 | module AI.Probability.Bayes 4 | ( BayesNet 5 | , fromList 6 | , enumerationAsk 7 | , eliminationAsk 8 | , rejectionSample 9 | , likelihoodWeighting ) where 10 | 11 | import AI.Util.ProbDist 12 | import AI.Util.Array 13 | import AI.Util.Util 14 | 15 | import Control.DeepSeq 16 | import Control.Monad 17 | import Data.Map (Map, (!)) 18 | import Data.Ord (comparing) 19 | import qualified Control.Monad.Random as R 20 | import qualified Data.List as L 21 | import qualified Data.Map as M 22 | 23 | --------------- 24 | -- Bayes Net -- 25 | --------------- 26 | 27 | -- |A node in a Bayes Net. We keep things very lightweight, storing just a 28 | -- list of the node's parents and its conditional probability table as a list. 29 | data Node e = Node { nodeParents :: [e] 30 | , nodeCond :: [Prob] } deriving (Show) 31 | 32 | -- |A Bayes Net contains two fields - a list of variables ordered from parents 33 | -- to children, and a 'Map' from variable names to Nodes. 34 | data BayesNet e = BayesNet { bnVars :: [e] 35 | , bnMap :: Map e (Node e) } deriving (Show) 36 | 37 | -- |This function creates a Bayes Net from a list of elements of the form 38 | -- (variable, parents, conditional probs). The conditional probability table 39 | -- is specified with the first parent varying most slowly. For exampleif the 40 | -- parents are A and B, and the conditional probability table is 41 | -- 42 | -- > A | B | Prob 43 | -- > --+---+----- 44 | -- > T | T | 0.9 45 | -- > T | F | 0.8 46 | -- > F | T | 0.7 47 | -- > F | F | 0.1 48 | -- 49 | -- then the list of probabilities should be @[0.9,0.8,0.7,0.1]@. 50 | fromList :: Ord e => [ (e, [e], [Prob]) ] -> BayesNet e 51 | fromList xs = BayesNet vars net 52 | where 53 | net = foldr go M.empty xs 54 | 55 | go (ev,ps,cond) = if length cond /= 2 ^ length ps 56 | then error "Invalid length for probability table" 57 | else M.insert ev (Node ps cond) 58 | 59 | vars = L.sortBy (comparing rank) (M.keys net) 60 | 61 | rank e = if null ps then 0 else 1 + maximum (map rank ps) 62 | where ps = nodeParents (net ! e) 63 | 64 | --------------------- 65 | -- Enumeration Ask -- 66 | --------------------- 67 | 68 | -- |The Enumeration-Ask algorithm. This iterates over variables in the Bayes 69 | -- Net, from parents to children, summing over the possible values when a 70 | -- variable is not assigned. 71 | enumerationAsk :: Ord e => BayesNet e -> [(e,Bool)] -> e -> Dist Bool 72 | enumerationAsk bn fixed e = normalize $ D [(True, p True), (False, p False)] 73 | where 74 | p x = enumerateAll bn (M.insert e x a) (bnVars bn) 75 | a = M.fromList fixed 76 | 77 | -- |A helper function for 'enumerationAsk'. This performs the hard work of 78 | -- enumerating all unassigned values in the network and summing over their 79 | -- conditional probabilities. 
80 | enumerateAll :: Ord e => BayesNet e -> Map e Bool -> [e] -> Prob 81 | enumerateAll bn a [] = 1.0 82 | enumerateAll bn a (v:vs) = if v `M.member` a 83 | then bnProb bn a (v, a!v) * enumerateAll bn a vs 84 | else p * go True + (1 - p) * go False 85 | where 86 | p = bnProb bn a (v,True) 87 | go x = enumerateAll bn (M.insert v x a) vs 88 | 89 | --------------------- 90 | -- Elimination Ask -- 91 | --------------------- 92 | 93 | -- |A factor in the variable elimination algorithm. A factor is a list of 94 | -- variables (unfixed by the problem) and a conditional probability table 95 | -- associated with them. 96 | data Factor e = Factor { fVars :: [e], fVals :: [Prob] } deriving (Show) 97 | 98 | -- |Exact inference using the elimination-ask algorithm. 99 | eliminationAsk :: Ord e => BayesNet e -> [(e,Bool)] -> e -> Dist Bool 100 | eliminationAsk bn fixed e = go [] (reverse $ bnVars bn) 101 | where 102 | go factors [] = let f = pointwiseProduct factors 103 | in normalize $ D $ zip [True,False] (fVals f) 104 | 105 | go factors (v:vs) = let factors' = (mkFactor bn fixed v) : factors 106 | in if v `elem` hidden 107 | then go [sumOut v factors'] vs 108 | else go factors' vs 109 | 110 | hidden = (e:map fst fixed) `deleteAll` bnVars bn 111 | 112 | -- |Given a Bayes Net, a list of fixed variables and a target variable, return 113 | -- a factor to be used in the 'eliminationAsk' algorithm. 114 | mkFactor :: Ord e => BayesNet e -> [(e,Bool)] -> e -> Factor e 115 | mkFactor bn fixed e = Factor fvar (subSlice cond is) 116 | where 117 | vars = e : bnParents bn e 118 | cond = bnCond bn e ++ map (1-) (bnCond bn e) 119 | fvar = map fst fixed `deleteAll` vars 120 | is = getIxVector vars fixed 121 | 122 | -- |Return the pointwise product of a list of factors. This is simply a strict 123 | -- fold over the factors using 'mulF'. 124 | pointwiseProduct :: Eq e => [Factor e] -> Factor e 125 | pointwiseProduct = L.foldl1' mulF 126 | 127 | -- |Sum out a factor with respect to one of its variables. 128 | sumOut :: Eq e => e -> [Factor e] -> Factor e 129 | sumOut e factors = with True `addF` with False 130 | where 131 | with x = pointwiseProduct $ map (set e x) factors 132 | 133 | -- |Return the pointwise product of two factors. This is ugly at the moment! 134 | -- It should be refactored to (a) be prettier and (b) use an 'Array' instead 135 | -- of a list to store the factor values, as the huge amount of list indexing 136 | -- going on will probably be really inefficient. 137 | mulF :: Eq e => Factor e -> Factor e -> Factor e 138 | mulF f1 f2 = Factor vars (map f vals) 139 | where 140 | vars = L.union (fVars f1) (fVars f2) 141 | vals = bools (length vars) 142 | 143 | f bs = valueAt bs (getIxs f1) (fVals f1) * valueAt bs (getIxs f2) (fVals f2) 144 | 145 | getIxs factor = map (vars `indexOf`) (fVars factor) 146 | valueAt bs ns vals = vals !! bnSubRef (bs `elemsAt` ns) 147 | 148 | -- |Return the pointwise sum of two factors. This performs a quick sanity check, 149 | -- requiring that the factors have the same variables in the same order. 150 | addF :: Eq e => Factor e -> Factor e -> Factor e 151 | addF (Factor vs1 ps1) (Factor vs2 ps2) = if vs1 /= vs2 152 | then error "Can't add factors with different variables" 153 | else Factor vs1 $ zipWith (+) ps1 ps2 154 | 155 | -- |Take a slice of a factor by setting one of its variables to a fixed value. 156 | -- This is a helper function for 'sumOut'. 
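--
-- For instance, with the first variable varying most slowly, setting
-- @A = True@ in a factor over @[A,B]@ with values @[p1,p2,p3,p4]@ should
-- leave a factor over @[B]@ with values @[p1,p2]@ (the rows where @A@ is
-- true).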
157 | set :: Eq e => e -> Bool -> Factor e -> Factor e 158 | set e x (Factor vs ps) = if not (e `elem` vs) 159 | then Factor vs ps 160 | else Factor (L.delete e vs) (subSlice1 ps (i,x)) 161 | where 162 | i = vs `indexOf` e 163 | 164 | ------------------------ 165 | -- Rejection Sampling -- 166 | ------------------------ 167 | 168 | -- |Random sample from a Bayes Net, according to the prior distribution. 169 | bnSample :: Ord e => BayesNet e -> IO (Map e Bool) 170 | bnSample bn = go M.empty (bnVars bn) 171 | where 172 | go assignment [] = return assignment 173 | go assignment (v:vs) = do 174 | let !p = bnProb bn assignment (v,True) 175 | x <- probabilityIO p 176 | let !assignment' = M.insert v x assignment 177 | go assignment' vs 178 | 179 | -- |Rejection sampling algorithm. Repeatedly samples from a Bayes Net and 180 | -- discards samples that do not match the evidence, and builds a probability 181 | -- distribution from the result. 182 | rejectionSample :: Ord e => Int -> BayesNet e -> [(e,Bool)] -> e -> IO (Dist Bool) 183 | rejectionSample nIter bn fixed e = 184 | foldM func initial [1..nIter] >>= return . weighted . M.toList 185 | 186 | where 187 | func m _ = do 188 | a <- bnSample bn 189 | if isConsistent a 190 | then let x = a!e in return (M.insertWith' (+) x 1 m) 191 | else return m 192 | 193 | initial = M.fromList [(True,0),(False,0)] 194 | 195 | isConsistent a = map (a!) vars == vals 196 | 197 | (vars,vals) = unzip fixed 198 | 199 | -------------------------- 200 | -- Likelihood Weighting -- 201 | -------------------------- 202 | 203 | -- |Random sample from a Bayes Net, with an associated likelihood weight. The 204 | -- weight gives the likelihood of the fixed evidence, given the sample. 205 | weightedSample :: Ord e => BayesNet e -> [(e,Bool)] -> IO (Map e Bool, Prob) 206 | weightedSample bn fixed = go 1.0 (M.fromList fixed) (bnVars bn) 207 | where 208 | go w assignment [] = return (assignment, w) 209 | go w assignment (v:vs) = if v `elem` vars 210 | then 211 | let !w' = w * bnProb bn assignment (v, fixed %! v) 212 | in go w' assignment vs 213 | else do 214 | let !p = bnProb bn assignment (v,True) 215 | x <- probabilityIO p 216 | let !assignment' = M.insert v x assignment 217 | go w assignment' vs 218 | 219 | vars = map fst fixed 220 | 221 | -- |Repeatedly draw likelihood-weighted samples from a distribution to infer 222 | -- probabilities from a Bayes Net. 223 | likelihoodWeighting :: Ord e => Int -> BayesNet e -> [(e,Bool)] -> e -> IO (Dist Bool) 224 | likelihoodWeighting nIter bn fixed e = 225 | foldM func initial [1..nIter] >>= distribution 226 | 227 | where 228 | func m _ = do 229 | (a, w) <- weightedSample bn fixed 230 | let x = a ! e 231 | return (M.insertWith' (+) x w m) 232 | 233 | initial = M.fromList [(True,0),(False,0)] 234 | 235 | distribution = return . normalize . D . M.toList 236 | 237 | ------------------------- 238 | -- Bayes Net Utilities -- 239 | ------------------------- 240 | 241 | -- |Given a set of assignments and a variable, this function returns the values 242 | -- of the variable's parents in the assignment, in the order that they are 243 | -- specified in the Bayes Net. 244 | bnVals :: Ord e => BayesNet e -> Map e Bool -> e -> [Bool] 245 | bnVals bn a x = map (a!) (bnParents bn x) 246 | 247 | -- |Return the parents of a specified variable in a Bayes Net. 248 | bnParents :: Ord e => BayesNet e -> e -> [e] 249 | bnParents (BayesNet _ m) x = nodeParents (m ! x) 250 | 251 | -- |Return the conditional probability table of a variable in a Bayes Net. 
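--
-- For example, with the alarm network from "AI.Probability.Example.Alarm"
-- below, @bnCond alarm \"Alarm\"@ returns the four-entry table
-- @[0.95,0.94,0.29,0.001]@ that the network was built with.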
252 | bnCond :: Ord e => BayesNet e -> e -> [Prob] 253 | bnCond (BayesNet _ m) x = nodeCond (m ! x) 254 | 255 | -- |Given a set of assignments and a (variable,value) pair, this function 256 | -- returns the probability that the variable has that value, given the 257 | -- assignments. Note that the variable's parents must be already assigned 258 | -- (this is why it is important to perform the enumeration of variables from 259 | -- parents to children). 260 | bnProb :: Ord e => BayesNet e -> Map e Bool -> (e, Bool) -> Prob 261 | bnProb bn a (v,b) = if b then p else 1 - p 262 | where p = bnCond bn v !! bnSubRef (bnVals bn a v) 263 | 264 | -- |A helper function for 'bnProb'. Given a list of parent values, this returns 265 | -- the correct index for a probability to be extracted from the conditional 266 | -- probability table associated with a variable. 267 | bnSubRef :: [Bool] -> Int 268 | bnSubRef = ndSubRef . map (\x -> if x then 0 else 1) 269 | -------------------------------------------------------------------------------- /src/AI/Probability/Example/Alarm.hs: -------------------------------------------------------------------------------- 1 | module AI.Probability.Example.Alarm 2 | ( alarm 3 | , AI.Probability.Bayes.enumerationAsk 4 | , AI.Probability.Bayes.eliminationAsk 5 | , AI.Probability.Bayes.rejectionSample 6 | , AI.Probability.Bayes.likelihoodWeighting 7 | ) where 8 | 9 | import AI.Probability.Bayes 10 | 11 | -- |The "alarm" example from AIMA. You can query the network using any of the 12 | -- ask functions. For example, to query the distribution of /Burglary/ given 13 | -- that /JohnCalls/ is true, you would do 14 | -- 15 | -- >>> enumerationAsk alarm [("JohnCalls",True)] "Burglary" 16 | -- True 1.6% 17 | -- False 98.4% 18 | -- 19 | -- At present, there seems to be a bug, in that the following happens: 20 | -- 21 | -- >>> enumerationAsk alarm [("JohnCalls",True)] "JohnCalls" 22 | -- True 5.2% 23 | -- False 94.8% 24 | -- 25 | -- whereas I think that a distribution representing certainty should be 26 | -- returned instead. 27 | alarm :: BayesNet String 28 | alarm = fromList [ ("Burglary", [], [0.001]) 29 | , ("Earthquake", [], [0.002]) 30 | , ("Alarm", ["Burglary","Earthquake"], [0.95,0.94,0.29,0.001]) 31 | , ("JohnCalls", ["Alarm"], [0.9,0.05]) 32 | , ("MaryCalls", ["Alarm"], [0.7,0.01]) ] -------------------------------------------------------------------------------- /src/AI/Probability/Example/Grass.hs: -------------------------------------------------------------------------------- 1 | module AI.Probability.Example.Grass 2 | ( grass 3 | , AI.Probability.Bayes.enumerationAsk 4 | , AI.Probability.Bayes.eliminationAsk 5 | , AI.Probability.Bayes.rejectionSample 6 | , AI.Probability.Bayes.likelihoodWeighting 7 | ) where 8 | 9 | import AI.Probability.Bayes 10 | 11 | -- |The "grass" example from AIMA. You can query the network using any of the 12 | -- ask functions. For example, to query the distribution of /Rain/ given 13 | -- that /GrassWet/ is true, you would do 14 | -- 15 | -- >>> enumerationAsk grass [("GrassWet",True)] "Rain" 16 | -- True 35.8% 17 | -- False 64.2% 18 | -- 19 | -- The same bug as in the alarm example also occurs here.
20 | grass :: BayesNet String 21 | grass = fromList [ ("Rain", [], [0.2]) 22 | , ("Sprinkler", ["Rain"], [0.01, 0.4]) 23 | , ("GrassWet", ["Sprinkler","Rain"], [0.99, 0.9, 0.8, 0]) ] -------------------------------------------------------------------------------- /src/AI/Probability/MDP.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE MultiParamTypeClasses #-} 2 | 3 | module AI.Probability.MDP 4 | ( MDP 5 | , Utility 6 | , Policy 7 | , valueIteration 8 | , policyIteration ) where 9 | 10 | import Data.Map (Map, (!)) 11 | import GHC.Float 12 | 13 | import qualified Data.Map as Map 14 | 15 | import AI.Util.ProbDist 16 | import AI.Util.Util 17 | 18 | -- |Type for a utility function (a mapping from states to utilities) 19 | type Utility s = s -> Double 20 | 21 | -- |Type for a policy (a mapping from states to actions) 22 | type Policy s a = s -> a 23 | 24 | -- |Class for a Markov Decision Process. An MDP is defined by an initial state, 25 | -- a transition model and a reward function. We may also specify a discount 26 | -- factor (often called gamma). The transition model is represented somewhat 27 | -- differently to the text. Instead of T(s, a, s') being a probability for each 28 | -- state/action/state triplet, we instead have T(s,a) return a probability 29 | -- distribution over states. We also keep track of the possible states, 30 | -- terminal states and actions for each state. 31 | class Ord s => MDP m s a where 32 | -- |Return the initial state for the problem. 33 | initial :: m s a -> s 34 | 35 | -- |Return a list of the possible actions for this MDP. 36 | actionList :: m s a -> [a] 37 | 38 | -- |Return a list of all states for this MDP. 39 | stateList :: m s a -> [s] 40 | 41 | -- |Return a list of the terminal states for this MDP. 42 | terminalStates :: m s a -> [s] 43 | 44 | -- |Return the reward associated with a particular state. 45 | reward :: m s a -> s -> Double 46 | 47 | -- |Return a probability distribution over states for a given (state,action) 48 | -- pair. 49 | transition :: m s a -> s -> a -> Dist s 50 | 51 | -- |Return the discount factor 'gamma' for this MDP. 52 | discountFactor :: m s a -> Double 53 | discountFactor _ = 0.9 54 | 55 | -- |Return the list of actions that can be performed in this state. By 56 | -- default this is a fixed list of actions, except for at terminal states. 57 | -- You can override this method if you need to specialize by state. 58 | actions :: m s a -> s -> [a] 59 | actions m s = if s `elem` terminalStates m then [] else actionList m 60 | 61 | -------------------- 62 | -- MDP Algorithms -- 63 | -------------------- 64 | 65 | -- |Solve a Markov Decision Process using value iteration. 66 | valueIteration :: MDP m s a => 67 | m s a -- ^ Problem to be solved 68 | -> Double -- ^ Tolerance: determines when to terminate 69 | -> Policy s a -- ^ The final policy 70 | valueIteration mdp epsilon = bestPolicy mdp $ go (const 0.0) 71 | where 72 | go u = if delta < epsilon * (1 - gamma) / gamma then u1 else go u1 73 | where 74 | delta = maximum [ abs (u1 s - u s) | s <- states ] 75 | u1 = listToFunction $ [ (s,f s) | s <- states ] 76 | f s = reward mdp s + gamma * maximum (g s) 77 | g s = [ expectedUtility mdp u s a | a <- actions mdp s ] 78 | 79 | gamma = discountFactor mdp 80 | states = stateList mdp 81 | 82 | -- |Solve a Markov Decision Process using (modified) policy iteration. 
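--
-- Usage mirrors 'valueIteration'; given some instance @myMDP@ of the 'MDP'
-- class (a hypothetical name), a sketch looks like:
--
-- >>> let policy = policyIteration myMDP
-- >>> policy (initial myMDP)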
83 | policyIteration :: (Eq a, MDP m s a) => 84 | m s a -- ^ Problem to be solved 85 | -> Policy s a -- ^ Final policy 86 | policyIteration mdp = go (\s -> head (actions mdp s)) (const 0) 87 | where 88 | go p u = if unchanged then p else go p1 u1 89 | where 90 | u1 = policyEvaluation mdp p u 20 91 | p1 = bestPolicy mdp u1 92 | unchanged = all (\s -> p s == p1 s) (stateList mdp) 93 | 94 | -- |Return an updated utility mapp from each state in the MDP to its utility, 95 | -- using an approximation (modified policy iteration). 96 | policyEvaluation :: MDP m s a => 97 | m s a -- ^ Markov Decision Process 98 | -> Policy s a -- ^ Policy to be evaluated 99 | -> Utility s -- ^ Initial utility function 100 | -> Int -- ^ Number of iterations 101 | -> Utility s -- ^ Final utility function 102 | policyEvaluation mdp p u k = go u k 103 | where 104 | go u k = if k == 0 then u else go u1 (k - 1) 105 | where 106 | u1 = listToFunction [ (s, f s) | s <- stateList mdp ] 107 | f s = reward mdp s + gamma * expectedUtility mdp u s (p s) 108 | 109 | gamma = discountFactor mdp 110 | 111 | --------------- 112 | -- Utilities -- 113 | --------------- 114 | 115 | -- |Given an MDP and a utility function, determine the best policy. 116 | bestPolicy :: MDP m s a => m s a -> Utility s -> Policy s a 117 | bestPolicy mdp u = listToFunction [ (s, bestAction s) | s <- stateList mdp ] 118 | where 119 | bestAction s = argMax (actions mdp s) (expectedUtility mdp u s) 120 | 121 | -- |The expected utility of taking action @a@ in state @s@, according to the 122 | -- decision process and a particular utility function. 123 | expectedUtility :: MDP m s a => m s a -> Utility s -> s -> a -> Double 124 | expectedUtility mdp u s a = 125 | float2Double $ expectation $ fmap u $ transition mdp s a 126 | -------------------------------------------------------------------------------- /src/AI/Search/Core.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE MultiParamTypeClasses, FlexibleInstances #-} 2 | 3 | module AI.Search.Core ( 4 | -- * Core classes and data structures 5 | Problem (..) 6 | , Node (..) 7 | , Cost 8 | , root 9 | , path 10 | , expand 11 | -- * Search algorithms 12 | , treeSearch 13 | , graphSearch 14 | -- * Algorithm comparison 15 | , compareSearchers 16 | , detailedCompareSearchers 17 | -- * Algorithm tracking 18 | , ProblemIO 19 | , mkProblemIO) where 20 | 21 | import Control.DeepSeq 22 | import Control.Monad 23 | import Data.IORef 24 | import Data.Maybe (fromJust,listToMaybe) 25 | import System.IO.Unsafe 26 | 27 | import qualified Data.Set as S 28 | 29 | import AI.Util.Queue 30 | import AI.Util.Table 31 | import AI.Util.Util 32 | 33 | --import qualified AI.Util.Graph as G 34 | 35 | -- |The type used to represent the cost associated with a particular path. 36 | type Cost = Double 37 | 38 | -- |Class for an abstract problem with state type s and action type a. A 39 | -- minimal implementation consists of 'initial' and 'successor', and one 40 | -- of 'goal' or 'goalTest'. 41 | class Eq s => Problem p s a where 42 | -- | The initial state of the problem. 43 | initial :: p s a -> s 44 | 45 | -- | Given a state, return a sequence of (action, state) pairs reachable 46 | -- from this state. Because of lazy evaluation we only ever compute as 47 | -- many elements of the list as the program needs. 48 | successor :: p s a -> s -> [(a, s)] 49 | 50 | -- | If the problem has a unique goal state, this method should return it. 
51 |     -- The default implementation of 'goalTest' compares for equality with
52 |     -- this state.
53 |     goal :: p s a -> s
54 |     goal = undefined
55 | 
56 |     -- | Return true if the state is a goal. The default method compares the
57 |     -- state to the state specified in the implementation of 'goal'. You can
58 |     -- override this method if checking against a single goal is not enough.
59 |     goalTest :: p s a -> s -> Bool
60 |     goalTest p s = s == goal p
61 | 
62 |     -- | Return the cost of a solution path that arrives at the second state
63 |     -- from the first state, via the specified action. If the problem is such
64 |     -- that the path doesn't matter, the function will only look at the second
65 |     -- state. The default implementation costs 1 for every step in the path.
66 |     costP :: p s a -> Cost -> s -> a -> s -> Cost
67 |     costP _ c _ _ _ = c + 1
68 | 
69 |     -- | You may want to specify a heuristic function for the problem. The
70 |     -- default implementation always returns zero.
71 |     heuristic :: p s a -> Node s a -> Cost
72 |     heuristic _ = const 0
73 | 
74 |     -- | For optimization problems, each state has a value. Hill-climbing and
75 |     -- related algorithms try to maximise this value. The default
76 |     -- implementation always returns zero.
77 |     valueP :: p s a -> s -> Double
78 |     valueP _ = const 0
79 | 
80 | -- |A node in a search tree. It contains a reference to its parent (the node
81 | -- that this is a successor of) and to the state for this node. Note that if
82 | -- a state can be arrived at by two paths, there will be two nodes with the
83 | -- same state. It may also include the action that got us to this state, and
84 | -- the total path cost.
85 | data Node s a = Node { state  :: s
86 |                      , parent :: Maybe (Node s a)
87 |                      , action :: Maybe a
88 |                      , cost   :: Cost
89 |                      , depth  :: Int
90 |                      , value  :: Double }
91 | 
92 | instance (Show s, Show a) => Show (Node s a) where
93 |     show (Node state _ action cost depth _) =
94 |         "Node(state=" ++ show state ++ ",action=" ++ show action ++
95 |         ",cost=" ++ show cost ++ ",depth=" ++ show depth ++ ")"
96 | 
97 | -- |A convenience constructor for root nodes (a node with no parent, no action
98 | -- that leads to it, and zero cost).
99 | root :: (Problem p s a) => p s a -> Node s a
100 | root p = Node s Nothing Nothing 0 0 (valueP p s) where s = initial p
101 | 
102 | -- |Return the list of nodes on the path from this node back to the root.
103 | path :: Node s a -> [Node s a]
104 | path n = case parent n of
105 |     Nothing -> [n]
106 |     Just n' -> n : path n'
107 | 
108 | -- |Return a list of nodes reachable from this node in the context of the
109 | -- specified problem.
110 | expand :: (Problem p s a) => p s a -> Node s a -> [Node s a]
111 | expand p node = [ mkNode a s | (a,s) <- successor p (state node) ]
112 |     where
113 |         mkNode a s = Node s (Just node) (Just a) (c a s) (1 + depth node) (v s)
114 |         c a s = costP p (cost node) (state node) a s
115 |         v s   = valueP p s   -- the value of the child state, not the parent's
116 | 
117 | ----------------------------
118 | -- Core Search Algorithms --
119 | ----------------------------
120 | 
121 | -- |Search through the successors of a node to find a goal. The argument
122 | -- @fringe@ should be an empty queue. We don't worry about repeated paths
123 | -- to a state.
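--
-- For instance, instantiating the fringe with a plain list gives depth-first
-- behaviour and with a 'FifoQueue' gives breadth-first behaviour; this is
-- exactly how the corresponding searches in "AI.Search.Uninformed" are built:
--
-- > depthFirstTreeSearch   = treeSearch []
-- > breadthFirstTreeSearch = treeSearch (newQueue :: FifoQueue (Node s a))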
124 | treeSearch :: (Problem p s a, Queue q) =>
125 |               q (Node s a)     -- ^ Empty queue
126 |            -> p s a            -- ^ Problem
127 |            -> Maybe (Node s a)
128 | treeSearch q prob = listToMaybe $ genericSearch f q prob
129 |     where
130 |         f node closed = (expand prob node, closed)
131 | 
132 | 
133 | -- |Search through the successors of a node to find a goal. The argument
134 | -- @fringe@ should be an empty queue. If two paths reach the same state, use
135 | -- only the best one.
136 | graphSearch :: (Problem p s a, Queue q, Ord s) =>
137 |                q (Node s a)     -- ^ Empty queue
138 |             -> p s a            -- ^ Problem
139 |             -> Maybe (Node s a)
140 | graphSearch q prob = listToMaybe $ genericSearch f q prob
141 |     where
142 |         f node closed
143 |             | state node `S.member` closed = (newQueue, closed)   -- no new nodes
144 |             | otherwise                    = (expand prob node, closed')
145 |             where
146 |                 closed' = state node `S.insert` closed
147 | 
148 | 
149 | genericSearch :: (Queue q, Problem p s a) =>
150 |                  (Node s a -> S.Set a1 -> ([Node s a], S.Set a1))
151 |               -> q (Node s a) -> p s a -> [(Node s a)]
152 | genericSearch f q prob = findFinalState (genericSearchPath f (root prob `push` q))
153 |     where
154 |         findFinalState = filter (goalTest prob . state)
155 | 
156 | -- Return a (potentially infinite) list of nodes to search.
157 | -- Since the result is lazy, you can break out early if you find a result.
158 | genericSearchPath :: Queue q => (a -> S.Set a1 -> ([a], S.Set a1)) -> q a -> [a]
159 | genericSearchPath f q = go (q, S.empty)
160 |     where
161 |         go (fringe, closed)
162 |             | empty fringe = []
163 |             | otherwise    = go' (pop fringe) closed
164 |         go' (node, rest) closed
165 |             | (new, closed') <- f node closed = node : go (new `extend` rest, closed')
166 | 
167 | 
168 | -----------------------
169 | -- Compare Searchers --
170 | -----------------------
171 | 
172 | -- |Wrapper for a problem that keeps statistics on how many times nodes were
173 | -- expanded in the course of a search. We track the number of times 'goalTest'
174 | -- was called, the number of times 'successor' was called, and the total number
175 | -- of states expanded.
176 | data ProblemIO p s a = PIO
177 |     { problemIO     :: p s a
178 |     , numGoalChecks :: IORef Int
179 |     , numSuccs      :: IORef Int
180 |     , numStates     :: IORef Int }
181 | 
182 | -- |Construct a new ProblemIO, with all counters initialized to zero.
183 | mkProblemIO :: p s a -> IO (ProblemIO p s a)
184 | mkProblemIO p = do
185 |     i <- newIORef 0
186 |     j <- newIORef 0
187 |     k <- newIORef 0
188 |     return (PIO p i j k)
189 | 
190 | -- |Make ProblemIO into an instance of Problem. It uses the same implementation
191 | -- as the problem it wraps, except that calls to 'goalTest' and 'successor' also update the relevant counters.
192 | instance (Problem p s a, Eq s, Show s) => Problem (ProblemIO p) s a where
193 |     initial (PIO p _ _ _) = initial p
194 | 
195 |     goalTest (PIO p n _ _) s = unsafePerformIO $ do
196 |         modifyIORef n (+1)
197 |         return (goalTest p s)
198 | 
199 |     successor (PIO p _ n m) s = unsafePerformIO $ do
200 |         let succs = successor p s
201 |         modifyIORef n (+1)
202 |         modifyIORef m (+length succs)
203 |         return succs
204 | 
205 |     costP (PIO p _ _ _) = costP p
206 | 
207 |     heuristic (PIO p _ _ _) = heuristic p
208 | 
209 | -- |Given a problem and a search algorithm, run the searcher on the problem
210 | -- and return the solution found, together with statistics about how many
211 | -- nodes were expanded in the course of finding the solution.
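--
-- A usage sketch (@gp2@ is the Romania route-finding problem defined in
-- "AI.Search.Example.Graph"; any 'Problem' instance works):
--
-- > do (result, goalChecks, succCalls, numStates) <-
-- >        testSearcher gp2 breadthFirstGraphSearch
-- >    print (result, goalChecks, succCalls, numStates)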
212 | testSearcher :: p s a -> (ProblemIO p s a -> t) -> IO (t,Int,Int,Int)
213 | testSearcher prob searcher = do
214 |     p@(PIO _ numGoalChecks numSuccs numStates) <- mkProblemIO prob
215 |     let result = searcher p in result `seq` do
216 |         i <- readIORef numGoalChecks
217 |         j <- readIORef numSuccs
218 |         k <- readIORef numStates
219 |         return (result, i, j, k)
220 | 
221 | -- |NFData instance for search nodes.
222 | instance (NFData s, NFData a) => NFData (Node s a) where
223 |     rnf (Node state parent action cost depth value) =
224 |         rnf state `seq` rnf parent `seq` rnf action `seq`
225 |         rnf cost `seq` rnf depth `seq` rnf value
226 | 
227 | 
228 | -- |Run a search algorithm over a problem, returning the time it took as well
229 | -- as other statistics.
230 | testSearcher' :: (NFData t) => p s a -> (ProblemIO p s a -> t) -> IO (t,Int,Int,Int,Int)
231 | testSearcher' prob searcher = do
232 |     p@(PIO _ numGoalChecks numSuccs numStates) <- mkProblemIO prob
233 |     (result, t) <- timed (searcher p)
234 |     i <- readIORef numGoalChecks
235 |     j <- readIORef numSuccs
236 |     k <- readIORef numStates
237 |     return (result, t, i, j, k)
238 | 
239 | -- |Test multiple searchers on the same problem, and return a list of results
240 | -- and statistics.
241 | testSearchers :: [ProblemIO p s a -> t] -> p s a -> IO [(t,Int,Int,Int)]
242 | testSearchers searchers prob = testSearcher prob `mapM` searchers
243 | 
244 | -- |Given a list of problems and a list of searchers, run every algorithm on
245 | -- every problem and print out a table showing the performance of each.
246 | compareSearchers :: (Show t) =>
247 |                     [ProblemIO p s a -> t]  -- ^ List of search algorithms
248 |                  -> [p s a]                 -- ^ List of problems
249 |                  -> [String]                -- ^ Problem names
250 |                  -> [String]                -- ^ Search algorithm names
251 |                  -> IO [[(t,Int,Int,Int)]]
252 | compareSearchers searchers probs header rownames = do
253 |     results <- testSearchers searchers `mapM` probs
254 |     printTable 20 (map (map f) (transpose results)) header rownames
255 |     return results
256 |     where
257 |         f (x,i,j,k) = SB (i,j,k)
258 | 
259 | -- |Given a problem and a list of searchers, run each search algorithm over the
260 | -- problem, and print out a table showing the performance of each searcher.
261 | -- The columns of the table indicate: [Algorithm name, Depth of solution,
262 | -- Cost of solution, Number of goal checks, Number of node expansions,
263 | -- Number of states expanded].
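--
-- For example, to compare breadth-first graph search against A* on the
-- Romania problem from "AI.Search.Example.Graph" (an illustrative call, not
-- a test that ships with the module):
--
-- > detailedCompareSearchers
-- >     [breadthFirstGraphSearch, aStarSearch']
-- >     ["BFS", "A*"]
-- >     gp2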
264 | detailedCompareSearchers ::
265 |        [ProblemIO p s a -> Maybe (Node s1 a1)]  -- ^ List of searchers
266 |     -> [String]                                 -- ^ Names of searchers
267 |     -> p s a                                    -- ^ Problem
268 |     -> IO ()
269 | detailedCompareSearchers searchers names prob = do
270 |     result <- testSearchers searchers prob
271 |     table  <- forM result $ \(n,numGoalChecks,numSuccs,numStates) -> do
272 |         let d = depth $ fromJust n
273 |         let c = round $ cost $ fromJust n
274 |         let b = fromIntegral numStates ** (1/fromIntegral d)
275 |         return [SB d,SB c,SB numGoalChecks,SB numSuccs,SB numStates,SB b]
276 |     printTable 20 table header names
277 |     where
278 |         header = ["Searcher","Depth","Cost","Goal Checks","Successors",
279 |                   "States","Eff Branching Factor"]
--------------------------------------------------------------------------------
/src/AI/Search/Example/Chess.hs:
--------------------------------------------------------------------------------
1 | -- |Adapted from the Haskell Live project:
2 | -- https://github.com/haskelllive/haskelllive
3 | module AI.Search.Example.Chess where
4 | 
5 | type Board = [[Square]]
6 | 
7 | initialBoardStr = unlines ["rnbqkbnr"
8 |                           ,"pppppppp"
9 |                           ,"        "
10 |                           ,"        "
11 |                           ,"        "
12 |                           ,"        "
13 |                           ,"PPPPPPPP"
14 |                           ,"RNBQKBNR"]
15 | 
16 | readBoard :: String -> Board
17 | readBoard = map readRow . lines
18 |   where readRow = map readSquare
19 | 
20 | showBoard :: Board -> String
21 | showBoard = unlines . map showRow
22 |   where showRow = map showSquare
23 | 
24 | type Square = Maybe Piece
25 | 
26 | -- | Show a square using FEN notation or ' ' for an empty square.
27 | showSquare :: Square -> Char
28 | showSquare = maybe ' ' showPiece
29 | 
30 | -- | Read a square using FEN notation or ' ' for an empty square.
31 | readSquare :: Char -> Square
32 | readSquare ' ' = Nothing
33 | readSquare c = Just (readPiece c)
34 | 
35 | data Piece = Piece PColor PType deriving (Show)
36 | data PColor = White | Black deriving (Show)
37 | data PType = Pawn | Knight | Bishop | Rook | Queen | King deriving (Show)
38 | 
39 | -- | Shows a piece using FEN notation.
40 | --
41 | -- * White pieces are "PNBRQK"
42 | -- * Black pieces are "pnbrqk"
43 | showPiece :: Piece -> Char
44 | showPiece (Piece White Pawn)   = 'P'
45 | showPiece (Piece White Knight) = 'N'
46 | showPiece (Piece White Bishop) = 'B'
47 | showPiece (Piece White Rook)   = 'R'
48 | showPiece (Piece White Queen)  = 'Q'
49 | showPiece (Piece White King)   = 'K'
50 | showPiece (Piece Black Pawn)   = 'p'
51 | showPiece (Piece Black Knight) = 'n'
52 | showPiece (Piece Black Bishop) = 'b'
53 | showPiece (Piece Black Rook)   = 'r'
54 | showPiece (Piece Black Queen)  = 'q'
55 | showPiece (Piece Black King)   = 'k'
56 | 
57 | -- | Reads a piece using FEN notation.
58 | --
59 | -- * White pieces are "PNBRQK"
60 | -- * Black pieces are "pnbrqk"
61 | readPiece :: Char -> Piece
62 | readPiece 'P' = Piece White Pawn
63 | readPiece 'N' = Piece White Knight
64 | readPiece 'B' = Piece White Bishop
65 | readPiece 'R' = Piece White Rook
66 | readPiece 'Q' = Piece White Queen
67 | readPiece 'K' = Piece White King
68 | readPiece 'p' = Piece Black Pawn
69 | readPiece 'n' = Piece Black Knight
70 | readPiece 'b' = Piece Black Bishop
71 | readPiece 'r' = Piece Black Rook
72 | readPiece 'q' = Piece Black Queen
73 | readPiece 'k' = Piece Black King
74 | 
75 | -- Tests
76 | 
77 | --tests = TestList $ map TestCase
78 | --  [assertEqual "add tests here" 1 1
79 | --  ]
80 | 
81 | --prop_empty c1 = (c1::Int) == c1
82 | 
83 | --runTests = do
84 | --  runTestTT tests
85 | --  quickCheck prop_empty
86 | 
87 | ---- | For now, main will run our tests.
88 | --main :: IO ()
89 | --main = runTests
--------------------------------------------------------------------------------
/src/AI/Search/Example/Connect4.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE TypeSynonymInstances, FlexibleInstances #-}
2 | 
3 | module AI.Search.Example.Connect4 where
4 | 
5 | import Data.Ord (comparing)
6 | import qualified Data.Map as M
7 | import qualified Data.List as L
8 | 
9 | import AI.Search.Adversarial
10 | import AI.Search.Example.TicTacToe
11 | import AI.Util.Util
12 | 
13 | ---------------
14 | -- Connect 4 --
15 | ---------------
16 | 
17 | -- |The 'Connect4' data type is a wrapper around 'TicTacToe', which allows us
18 | -- to inherit most of its behaviour.
19 | data Connect4 s a = C (TicTacToe TTState TTMove)
20 | 
21 | -- |Type for Connect 4 state.
22 | type C4State = TTState
23 | 
24 | -- |Type for Connect 4 moves.
25 | type C4Move = Int
26 | 
27 | -- |The standard Connect 4 board.
28 | connect4 :: Connect4 C4State C4Move
29 | connect4 = C (TTT 7 6 4)
30 | 
31 | -- |A Connect 4 game is identical to tic tac toe in most respects. It differs in
32 | -- the set of legal moves, the fact that it sorts moves so that those closest
33 | -- to the center are considered first, and the heuristic function.
34 | instance Game Connect4 C4State C4Move where
35 |     initial (C g) = initial g
36 |     toMove (C g) s = toMove g s
37 |     utility (C g) s p = utility g s p
38 |     terminalTest (C g) s = terminalTest g s
39 | 
40 |     makeMove (C g) col s = let row = lowestUnoccupied (col-1) s
41 |                             in makeMove g (col-1, row) s
42 | 
43 |     sortMoves (C (TTT h _ _)) as = L.sortBy (comparing f) as
44 |         where f x = abs (x - (h+1) `div` 2)
45 | 
46 |     legalMoves (C g) s@(TTS board _ _ _) =
47 |         [ x+1 | (x,y) <- legalMoves g s, y == 0 || (x,y-1) `M.member` board ]
48 | 
49 |     heuristic g = heuristicTTT [0.1,-0.1,0.9,-0.9]
50 | 
51 | -- |Return the lowest row in the specified column which is currently unoccupied.
52 | lowestUnoccupied :: Int -> TTState -> Int
53 | lowestUnoccupied col (TTS board _ _ (_,v,_)) =
54 |     let coords   = map (\row -> (col,row)) [0..v-1]
55 |         counters = map (`M.lookup` board) coords
56 |     in (countIf (/=Nothing) counters)
57 | 
58 | -- |Play a game of Connect 4 against an opponent using alpha/beta search.
59 | demo :: IO ()
60 | demo = playGameIO connect4 queryPlayer (iterativeAlphaBetaPlayer 5) >> return ()
--------------------------------------------------------------------------------
/src/AI/Search/Example/Fig52Game.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE TypeSynonymInstances, FlexibleInstances #-}
2 | 
3 | module AI.Search.Example.Fig52Game where
4 | 
5 | import AI.Search.Adversarial
6 | import AI.Util.Util
7 | 
8 | ----------------------------
9 | -- Example Game (Fig 5.2) --
10 | ----------------------------
11 | 
12 | -- |Data type representing the example game.
13 | data ExampleGame s a = ExampleGame deriving (Show)
14 | 
15 | -- |Instance of the example game.
16 | exampleGame :: ExampleGame String Int
17 | exampleGame = ExampleGame
18 | 
19 | -- |Definition of the example game in Fig 5.2 (mainly useful as an example of
20 | -- how to create games).
21 | instance Game ExampleGame String Int where
22 |     initial g = "A"
23 | 
24 |     toMove g "A" = Max
25 |     toMove g _   = Min
26 | 
27 |     legalMoves _ s = if s `elem` ["A","B","C","D"]
28 |                         then [1,2,3]
29 |                         else []
30 | 
31 |     makeMove _ n "A" = ["B","C","D"]    !! (n-1)
32 |     makeMove _ n "B" = ["B1","B2","B3"] !! (n-1)
33 |     makeMove _ n "C" = ["C1","C2","C3"] !! (n-1)
34 |     makeMove _ n "D" = ["D1","D2","D3"] !! (n-1)
35 | 
36 |     utility _ s p = let u = util s in if p == Max then u else -u
37 |         where
38 |             util = listToFunction [ ("B1", 3), ("B2",12), ("B3", 8)
39 |                                   , ("C1", 2), ("C2", 4), ("C3", 6)
40 |                                   , ("D1",14), ("D2", 5), ("D3", 2) ]
41 | 
42 |     terminalTest t s = s `elem` ["B1","B2","B3","C1","C2","C3","D1","D2","D3"]
--------------------------------------------------------------------------------
/src/AI/Search/Example/Graph.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE FlexibleInstances #-}
2 | 
3 | module AI.Search.Example.Graph where
4 | 
5 | import Control.DeepSeq
6 | import Control.Monad.State (StateT)
7 | import Control.Monad
8 | import Data.IORef
9 | import Data.Map (Map, (!))
10 | import Data.Maybe (fromJust)
11 | import System.IO
12 | import System.IO.Unsafe
13 | 
14 | import qualified Control.Monad.State as State
15 | import qualified Data.List as L
16 | import qualified Data.Map as M
17 | import qualified Data.Ord as O
18 | import qualified System.Random as R
19 | 
20 | import AI.Search.Core
21 | import AI.Search.Uninformed
22 | import AI.Search.Informed
23 | import AI.Util.WeightedGraph (WeightedGraph)
24 | import AI.Util.Table
25 | import AI.Util.Util
26 | 
27 | import qualified AI.Util.WeightedGraph as G
28 | 
29 | -------------------------------
30 | -- Graphs and Graph Problems --
31 | -------------------------------
32 | 
33 | -- |Data structure to hold a graph (edge weights correspond to the distance
34 | -- between nodes) and a map of graph nodes to locations.
35 | data GraphMap a = G
36 |     { getGraph     :: WeightedGraph a Cost
37 |     , getLocations :: Map a Location } deriving (Show,Read)
38 | 
39 | -- |Type synonym for a pair of doubles, representing a location in cartesian
40 | -- coordinates.
41 | type Location = (Double,Double)
42 | 
43 | -- |Creates a GraphMap from the graph's adjacency list representation and a list
44 | -- of (node, location) pairs. This function creates undirected graphs, so you
45 | -- don't need to include reverse links in the adjacency list (though you can
46 | -- if you like).
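--
-- A minimal sketch with made-up coordinates: a two-node map with a single
-- edge of cost 1 would be built as
--
-- > tiny = mkGraphMap [ ("A", [("B",1)]) ] [ ("A",(0,0)), ("B",(1,0)) ]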
47 | mkGraphMap :: (Ord a) => [(a,[(a,Cost)])] -> [(a,Location)] -> GraphMap a 48 | mkGraphMap conn loc = G (G.toUndirectedGraph conn) (M.fromList loc) 49 | 50 | -- |Get the neighbours of a node from a GraphMap. 51 | getNeighbours :: Ord a => a -> GraphMap a -> [(a,Cost)] 52 | getNeighbours a (G g _) = G.getNeighbours a g 53 | 54 | -- |Get the location of a node from a GraphMap. 55 | getLocation :: Ord a => a -> GraphMap a -> Location 56 | getLocation a (G _ l) = case M.lookup a l of 57 | Nothing -> error "Vertex not found in graph -- GETLOCATION" 58 | Just pt -> pt 59 | 60 | -- | Add an edge between two nodes to a GraphMap. 61 | addEdge :: Ord a => a -> a -> Cost -> GraphMap a -> GraphMap a 62 | addEdge x y cost (G graph locs) = G (G.addUndirectedEdge x y cost graph) locs 63 | 64 | -- |The cost associated with moving between two nodes in a GraphMap. If the 65 | -- nodes are not connected by an edge, then the cost is returned as infinity. 66 | costFromTo :: Ord a => GraphMap a -> a -> a -> Cost 67 | costFromTo graph a b = case lookup b (getNeighbours a graph) of 68 | Nothing -> 1/0 69 | Just c -> c 70 | 71 | -- |Data structure to hold a graph problem (represented as a GraphMap together 72 | -- with an initial and final node). 73 | data GraphProblem s a = GP 74 | { graphGP :: GraphMap s 75 | , initGP :: s 76 | , goalGP :: s } deriving (Show,Read) 77 | 78 | -- |GraphProblems are an instance of Problem. The heuristic function measures 79 | -- the Euclidean (straight-line) distance between two nodes. It is assumed that 80 | -- this is less than or equal to the cost of moving along edges. 81 | instance Ord s => Problem GraphProblem s s where 82 | initial = initGP 83 | goal = goalGP 84 | successor (GP g _ _) s = [ (x,x) | (x,_) <- getNeighbours s g ] 85 | costP (GP g _ _) c s _ s' = c + costFromTo g s s' 86 | heuristic (GP g _ goal) n = euclideanDist x y 87 | where 88 | x = getLocation (state n) g 89 | y = getLocation goal g 90 | 91 | -- |Measures the Euclidean (straight-line) distance between two locations. 92 | euclideanDist :: Location -> Location -> Double 93 | euclideanDist (x,y) (x',y') = sqrt $ (x-x')^2 + (y-y')^2 94 | 95 | -- |Construct a random graph with the specified number of nodes, and random 96 | -- links. The nodes are laid out randomly on a @(width x height)@ rectangle. 97 | -- Then each node is connected to the @minLinks@ nearest neighbours. Because 98 | -- inverse links are added, some nodes will have more connections. The distance 99 | -- between nodes is the hypotenuse multiplied by @curvature@, where @curvature@ 100 | -- defaults to a random number between 1.1 and 1.5. 101 | randomGraphMap :: 102 | Int -- ^ Number of nodes 103 | -> Int -- ^ Minimum number of links 104 | -> Double -- ^ Width 105 | -> Double -- ^ Height 106 | -> IO (GraphMap Int) 107 | randomGraphMap n minLinks width height = State.execStateT go (mkGraphMap [] []) where 108 | go = do 109 | replicateM n mkLocation >>= State.put . mkGraphMap [] . zip nodes 110 | 111 | forM_ nodes $ \x -> do 112 | 113 | State.modify (addEmpty x) 114 | g @ (G _ loc) <- State.get 115 | 116 | let nbrs = map fst (getNeighbours x g) 117 | numNbrs = length nbrs 118 | 119 | unconnected = deleteAll (x:nbrs) nodes 120 | sorted = L.sortBy (O.comparing to_x) unconnected 121 | to_x y = euclideanDist (loc ! x) (loc ! 
y) 122 | toAdd = take (minLinks - numNbrs) sorted 123 | 124 | mapM_ (addLink x) toAdd 125 | 126 | where 127 | nodes = [1..n] 128 | 129 | addLink x y = do 130 | curv <- curvature 131 | dist <- distance x y 132 | State.modify $ addEdge x y (dist * curv) 133 | 134 | addEmpty x (G graph xs) = G (M.insert x M.empty graph) xs 135 | 136 | mkLocation = State.liftIO $ do 137 | x <- R.randomRIO (0,width) 138 | y <- R.randomRIO (0,height) 139 | return (x,y) 140 | 141 | curvature = State.liftIO $ R.randomRIO (1.1, 1.5) 142 | 143 | distance x y = do 144 | (G _ loc) <- State.get 145 | return $ euclideanDist (loc ! x) (loc ! y) 146 | 147 | -- |Return a random instance of a graph problem with the specified number of 148 | -- nodes and minimum number of links. 149 | randomGraphProblem :: Int -> Int -> IO (GraphProblem Int Int) 150 | randomGraphProblem numNodes minLinks = do 151 | g@(G _ loc) <- randomGraphMap numNodes minLinks 100 100 152 | let initial = fst $ L.minimumBy (O.comparing (fst.snd)) (M.toList loc) 153 | goal = fst $ L.maximumBy (O.comparing (fst.snd)) (M.toList loc) 154 | return (GP g initial goal) 155 | 156 | -- |Write a list of graph problems to a file. 157 | writeGraphProblems :: Show p => FilePath -> [p] -> IO () 158 | writeGraphProblems filename ps = do 159 | h <- openFile filename WriteMode 160 | forM_ ps (hPrint h) 161 | hClose h 162 | 163 | -- |Read a list of graph problems from a file. 164 | readGraphProblems :: FilePath -> IO [GraphProblem Int Int] 165 | readGraphProblems filepath = do 166 | contents <- readFile filepath 167 | return $ map read $ lines contents 168 | 169 | -- |Generate random graph problems and write them to a file. Each problem is 170 | -- checked for solvability by running the 'depthFirstGraphSearch' algorithm 171 | -- on it. This function finds poor solutions, but terminates quickly on this 172 | -- kind of problem. 173 | generateGraphProblems :: Int -> Int -> Int -> FilePath -> IO () 174 | generateGraphProblems numProbs numNodes minLinks filepath = do 175 | probs <- go numProbs 176 | writeGraphProblems filepath probs 177 | where 178 | go 0 = return [] 179 | go n = do 180 | p <- randomGraphProblem numNodes minLinks 181 | case depthFirstGraphSearch p of 182 | Nothing -> go n 183 | Just _ -> go (n-1) >>= \ps -> return (p:ps) 184 | 185 | ---------------------------------- 186 | -- Graphs used in AIMA examples -- 187 | ---------------------------------- 188 | 189 | -- |The Romania graph from AIMA. 190 | romania :: GraphMap String 191 | romania = mkGraphMap 192 | 193 | [ ("A", [("Z",75), ("S",140), ("T",118)]) 194 | , ("B", [("U",85), ("P",101), ("G",90), ("F",211)]) 195 | , ("C", [("D",120), ("R",146), ("P",138)]) 196 | , ("D", [("M",75)]) 197 | , ("E", [("H",86)]) 198 | , ("F", [("S",99)]) 199 | , ("H", [("U",98)]) 200 | , ("I", [("V",92), ("N",87)]) 201 | , ("L", [("T",111), ("M",70)]) 202 | , ("O", [("Z",71), ("S",151)]) 203 | , ("P", [("R",97)]) 204 | , ("R", [("S",80)]) 205 | , ("U", [("V",142)]) ] 206 | 207 | [ ("A",( 91,491)), ("B",(400,327)), ("C",(253,288)), ("D",(165,299)) 208 | , ("E",(562,293)), ("F",(305,449)), ("G",(375,270)), ("H",(534,350)) 209 | , ("I",(473,506)), ("L",(165,379)), ("M",(168,339)), ("N",(406,537)) 210 | , ("O",(131,571)), ("P",(320,368)), ("R",(233,410)), ("S",(207,457)) 211 | , ("T",( 94,410)), ("U",(456,350)), ("V",(509,444)), ("Z",(108,531)) ] 212 | 213 | -- |The Australia graph from AIMA. 
214 | australia :: GraphMap String
215 | australia = mkGraphMap
216 | 
217 |     [ ("T",   [])
218 |     , ("SA",  [("WA",1), ("NT",1), ("Q",1), ("NSW",1), ("V",1)])
219 |     , ("NT",  [("WA",1), ("Q",1)])
220 |     , ("NSW", [("Q", 1), ("V",1)]) ]
221 | 
222 |     [ ("WA",(120,24)), ("NT" ,(135,20)), ("SA",(135,30)),
223 |       ("Q" ,(145,20)), ("NSW",(145,32)), ("T" ,(145,42)), ("V",(145,37))]
224 | 
225 | gp1, gp2, gp3 :: GraphProblem String String
226 | gp1 = GP { graphGP = australia, initGP = "Q", goalGP = "WA" }
227 | gp2 = GP { graphGP = romania, initGP = "A", goalGP = "B" }
228 | gp3 = GP { graphGP = romania, initGP = "O", goalGP = "N" }
229 | 
230 | -----------------------------
231 | -- Compare Graph Searchers --
232 | -----------------------------
233 | 
234 | -- |Run all search algorithms over a particular problem and print out
235 | -- performance statistics.
236 | runDetailedCompare :: (Problem p s a, Ord s, Show s) => p s a -> IO ()
237 | runDetailedCompare = detailedCompareSearchers allSearchers allSearcherNames
238 | 
239 | -- |List of all search algorithms that can be applied to problems with a graph
240 | -- structure. I'd like to add an iterative deepening graph search to this list,
241 | -- as well as some of the more exotic search algorithms described in the
242 | -- textbook.
243 | allSearchers :: (Problem p s a, Ord s) => [p s a -> Maybe (Node s a)]
244 | allSearchers = [ breadthFirstGraphSearch, depthFirstGraphSearch
245 |                , greedyBestFirstSearch, uniformCostSearch, aStarSearch']
246 | 
247 | -- |Names for the search algorithms in this module.
248 | allSearcherNames :: [String]
249 | allSearcherNames = [ "Breadth First Graph Search" , "Depth First Graph Search"
250 |                    , "Greedy Best First Search", "Uniform Cost Search"
251 |                    , "A* Search"]
252 | 
253 | -----------
254 | -- Demos --
255 | -----------
256 | 
257 | -- |Run all search algorithms over a few example problems.
258 | demo1 :: IO ()
259 | demo1 = compareSearchers allSearchers probs header allSearcherNames >> return ()
260 |     where
261 |         probs  = [gp1, gp2, gp3]
262 |         header = ["Searcher", "Australia", "Romania(A,B)","Romania(O,N)"]
263 | 
264 | -- |Load ten example problems from a file and run all searchers over them.
265 | demo2 :: IO ()
266 | demo2 = do
267 |     ps <- readGraphProblems "data/problems_small.txt"
268 |     mapM_ runDetailedCompare ps
269 | 
270 | -- |Load 100 example problems from a file and run all searchers over them.
271 | demo3 :: IO ()
272 | demo3 = do
273 |     ps <- readGraphProblems "data/problems_large.txt"
274 |     mapM_ runDetailedCompare ps
275 | 
--------------------------------------------------------------------------------
/src/AI/Search/Example/MapColoring.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE TypeSynonymInstances, FlexibleInstances #-}
2 | 
3 | module AI.Search.Example.MapColoring where
4 | 
5 | import Data.Map (Map, (!))
6 | import qualified Data.Map as M
7 | 
8 | import AI.Search.CSP
9 | import AI.Util.Graph (Graph)
10 | import AI.Util.Util
11 | 
12 | import qualified AI.Util.Graph as G
13 | 
14 | ----------------------
15 | -- Map Coloring CSP --
16 | ----------------------
17 | 
18 | data MapColoringCSP v a = MCP
19 |     { neighboursMC :: Graph String
20 |     , colorsMC     :: [Char] } deriving (Show)
21 | 
22 | instance CSP MapColoringCSP String Char where
23 |     vars (MCP nbrs _) = M.keys nbrs
24 | 
25 |     domains csp = mkUniversalMap (vars csp) (colorsMC csp)
26 | 
27 |     neighbours (MCP nbrs _) = nbrs
28 | 
29 |     constraints csp x xv y yv = xv /= yv || not (y `elem` neighbours csp ! x)
30 | 
31 | 
32 | 
33 | -----------------------------------
34 | -- Map Coloring Problems in AIMA --
35 | -----------------------------------
36 | 
37 | australia :: MapColoringCSP String Char
38 | australia = MCP territories "RGB"
39 |     where
40 |         territories = G.toGraph $
41 |             [ ("SA",  ["WA","NT","Q","NSW","V"])
42 |             , ("NT",  ["WA","Q","SA"])
43 |             , ("NSW", ["Q","V","SA"])
44 |             , ("T",   [])
45 |             , ("WA",  ["SA","NT"])
46 |             , ("Q",   ["SA","NT","NSW"])
47 |             , ("V",   ["SA","NSW"]) ]
48 | 
49 | usa :: MapColoringCSP String Char
50 | usa = MCP states "RGBY"
51 |     where states = G.parseGraph
52 |             "WA: OR ID; OR: ID NV CA; CA: NV AZ; NV: ID UT AZ; ID: MT WY UT;\
53 |             \UT: WY CO AZ; MT: ND SD WY; WY: SD NE CO; CO: NE KA OK NM; NM: OK TX;\
54 |             \ND: MN SD; SD: MN IA NE; NE: IA MO KA; KA: MO OK; OK: MO AR TX;\
55 |             \TX: AR LA; MN: WI IA; IA: WI IL MO; MO: IL KY TN AR; AR: MS TN LA;\
56 |             \LA: MS; WI: MI IL; IL: IN; IN: KY; MS: TN AL; AL: TN GA FL; MI: OH;\
57 |             \OH: PA WV KY; KY: WV VA TN; TN: VA NC GA; GA: NC SC FL;\
58 |             \PA: NY NJ DE MD WV; WV: MD VA; VA: MD DC NC; NC: SC; NY: VT MA CT NJ;\
59 |             \NJ: DE; DE: MD; MD: DC; VT: NH MA; MA: NH RI CT; CT: RI; ME: NH;\
60 |             \HI: ; AK: "
61 | 
62 | -----------
63 | -- Demos --
64 | -----------
65 | 
66 | demo1 :: IO ()
67 | demo1 = case backtrackingSearch australia fastOpts of
68 |     Nothing -> putStrLn "No solution found."
69 |     Just a  -> putStrLn "Solution found:" >> print a
70 | 
71 | demo2 :: IO ()
72 | demo2 = case backtrackingSearch usa fastOpts of
73 |     Nothing -> putStrLn "No solution found."
74 |     Just a  -> putStrLn "Solution found:" >> print a
75 | 
--------------------------------------------------------------------------------
/src/AI/Search/Example/NQueens.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE FlexibleInstances #-}
2 | 
3 | module AI.Search.Example.NQueens where
4 | 
5 | import qualified Data.List as L
6 | 
7 | import AI.Search.Core
8 | import AI.Util.Util
9 | 
10 | ----------------------
11 | -- N Queens Problem --
12 | ----------------------
13 | 
14 | -- |Data structure to define an N-Queens problem (the problem is defined by
15 | -- the size of the board).
16 | data NQueens s a = NQ { sizeNQ :: Int } deriving (Show)
17 | 
18 | -- |Update the state of the N-Queens board by playing a queen at (c,r).
19 | updateNQ :: (Int,Int) -> [Maybe Int] -> [Maybe Int]
20 | updateNQ (c,r) s = insert c (Just r) s
21 | 
22 | -- |Would putting two queens in (r1,c1) and (r2,c2) conflict?
23 | conflict :: Int -> Int -> Int -> Int -> Bool
24 | conflict r1 c1 r2 c2 =
25 |     r1 == r2 || c1 == c2 || r1-c1 == r2-c2 || r1+c1 == r2+c2
26 | 
27 | -- |Would placing a queen at (row,col) conflict with anything?
28 | conflicted :: [Maybe Int] -> Int -> Int -> Bool
29 | conflicted state row col = or $ map f (enumerate state)
30 |     where
31 |         f (_, Nothing) = False
32 |         f (c, Just r)  = if c == col && r == row
33 |                             then False
34 |                             else conflict row col r c
35 | 
36 | -- |N-Queens is an instance of Problem.
37 | instance Problem NQueens [Maybe Int] (Int,Int) where
38 |     initial (NQ n) = replicate n Nothing
39 | 
40 |     -- @L.elemIndex Nothing s@ finds the index of the first column in s
41 |     -- that doesn't yet have a queen.
42 |     successor (NQ n) s = case L.elemIndex Nothing s of
43 |         Nothing -> []
44 |         Just i  -> zip actions (map (`updateNQ` s) actions)
45 |             where
46 |                 actions = map ((,) i) [0..n-1]
47 | 
48 |     goalTest (NQ n) s = if last s == Nothing
49 |                             then False
50 |                             else not .
or $ map f (enumerate s) 51 | where 52 | f (c,Nothing) = False 53 | f (c,Just r) = conflicted s r c 54 | 55 | -- |An example N-Queens problem on an 8x8 grid. 56 | nQueens :: NQueens [Maybe Int] (Int,Int) 57 | nQueens = NQ 8 -------------------------------------------------------------------------------- /src/AI/Search/Example/Sudoku.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE TypeSynonymInstances, FlexibleInstances #-} 2 | 3 | module AI.Search.Example.Sudoku where 4 | 5 | import Control.Monad 6 | import Data.Map (Map, (!)) 7 | 8 | import qualified Data.List as L 9 | import qualified Data.Map as M 10 | 11 | import AI.Search.CSP 12 | import AI.Util.Util 13 | 14 | ------------ 15 | -- Sudoku -- 16 | ------------ 17 | 18 | data Sudoku v a = Sudoku (Domain String Char) deriving Show 19 | 20 | instance CSP Sudoku String Char where 21 | vars s = squares 22 | domains (Sudoku dom) = dom 23 | neighbours s = M.fromList peers 24 | constraints s x xv y yv = (xv /= yv) || not (x `elem` neighbours s ! y) 25 | 26 | 27 | 28 | cross :: [a] -> [a] -> [[a]] 29 | cross xs ys = [ [x,y] | x <- xs, y <- ys ] 30 | 31 | digits = "123456789" 32 | rows = "abcdefghi" 33 | cols = digits 34 | squares = cross rows cols 35 | unitlist = [ cross rows c | c <- map return cols ] ++ 36 | [ cross r cols | r <- map return rows ] ++ 37 | [ cross rs cs | rs <- ["abc","def","ghi"], cs <- ["123","456","789"] ] 38 | units = [ (s, [ u | u <- unitlist, s `elem` u ]) | s <- squares ] 39 | peers = [ (s, L.delete s $ L.nub $ concat u) | (s,u) <- units ] 40 | 41 | parseGrid :: String -> Sudoku String Char 42 | parseGrid grid = 43 | Sudoku $ foldr update (mkUniversalMap squares digits) initial 44 | where 45 | update (x,y) = if y `elem` digits 46 | then M.insert x [y] 47 | else M.insert x digits 48 | initial = zip squares $ filter (`elem` ( "0." ++ digits)) grid 49 | 50 | ------------------------------------------ 51 | -- Example Sudokus (from Project Euler) -- 52 | ------------------------------------------ 53 | 54 | sudoku1 = parseGrid "003020600900305001001806400008102900700000008006708200002609500800203009005010300" 55 | sudoku2 = parseGrid "200080300060070084030500209000105408000000000402706000301007040720040060004010003" 56 | sudoku3 = parseGrid "000000907000420180000705026100904000050000040000507009920108000034059000507000000" 57 | 58 | ----------- 59 | -- Demos -- 60 | ----------- 61 | 62 | demo :: IO () 63 | demo = do 64 | let sudokus = [sudoku1,sudoku2,sudoku3] 65 | forM_ sudokus $ \s -> case backtrackingSearch s fastOpts of 66 | Nothing -> putStrLn "No solution found." 67 | Just sol -> do putStrLn "Solution found:" 68 | print sol 69 | -------------------------------------------------------------------------------- /src/AI/Search/Example/TicTacToe.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE TypeSynonymInstances, FlexibleInstances #-} 2 | 3 | module AI.Search.Example.TicTacToe where 4 | 5 | import Data.Map (Map) 6 | import Data.Maybe (catMaybes) 7 | import qualified Data.Map as M 8 | import qualified Data.List as L 9 | 10 | import AI.Search.Adversarial 11 | import AI.Util.Util 12 | 13 | ---------------------------------- 14 | -- Tic Tac Toe on a H x V board -- 15 | ---------------------------------- 16 | 17 | -- |Data type for K-in-a-row tic tac toe, on a H x V board. 
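--
-- For example, @TTT 3 3 3@ is the standard 3x3 game (see 'ticTacToe' below),
-- and @TTT 7 6 4@ is the board that "AI.Search.Example.Connect4" wraps.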
18 | data TicTacToe s a = TTT { hT :: Int, vT :: Int, kT :: Int } deriving (Show)
19 | 
20 | -- |A move in tic tac toe is a pair of integers indicating the row and column,
21 | -- indexed from zero.
22 | type TTMove = (Int,Int)
23 | 
24 | -- |Each counter in tic-tac-toe is either an @O@ or an @X@.
25 | data TTCounter = O | X deriving (Eq,Show)
26 | 
27 | -- |A tic tac toe board is a map from board positions to counters. Note that
28 | -- @M.lookup (x,y) board@ will return @Nothing@ if square @(x,y)@ is empty.
29 | type TTBoard = Map TTMove TTCounter
30 | 
31 | -- |The state of a tic tac toe game is defined by the board. We also store the
32 | -- player whose move is next, the utility of this state (which is only nonzero
33 | -- if the state is terminal) and the size of the board, for convenience.
34 | data TTState = TTS
35 |     { boardTT   :: TTBoard
36 |     , toMoveTT  :: TTCounter
37 |     , utilityTT :: Utility
38 |     , limsTT    :: (Int,Int,Int) }
39 | 
40 | -- |This 'Game' instance defines the rules of tic tac toe. Note that whenever
41 | -- a move is made, we compute the utility of the newly created state on the
42 | -- fly. This avoids having to write an expensive function to decide if any
43 | -- player has won for a specific board state. The game is over when either
44 | -- a player has won, or there are no legal moves left to make.
45 | instance Game TicTacToe TTState TTMove where
46 |     initial (TTT h v k) = TTS M.empty O 0 (h,v,k)
47 | 
48 |     toMove _ s = if toMoveTT s == O then Max else Min
49 | 
50 |     legalMoves (TTT h v _) (TTS board _ _ _) =
51 |         [ (i,j) | i <- [0..h-1], j <- [0..v-1], M.notMember (i,j) board ]
52 | 
53 |     makeMove g move s@(TTS board p _ n) =
54 |         let u = computeUtility s move
55 |         in TTS (M.insert move p board) (other p) u n
56 | 
57 |     utility _ s p = let u = utilityTT s in if p == Max then u else -u
58 | 
59 |     terminalTest g s = utilityTT s /= 0 || null (legalMoves g s)
60 | 
61 |     heuristic _ = heuristicTTT [1,-1]
62 | 
63 | -- |A 3x3 instance of tic tac toe.
64 | ticTacToe :: TicTacToe TTState TTMove
65 | ticTacToe = TTT 3 3 3
66 | 
67 | -- |A useful function that interchanges @O@s and @X@s.
68 | other :: TTCounter -> TTCounter
69 | other O = X
70 | other X = O
71 | 
72 | -- |In our game, @Max@ always plays the @O@ counter and @Min@ plays @X@.
73 | counter :: Player -> TTCounter
74 | counter Max = O
75 | counter Min = X
76 | 
77 | -- |Helper function that computes the utility of a state after a particular
78 | -- move is played.
79 | computeUtility :: TTState -> TTMove -> Utility
80 | computeUtility s@(TTS _ player _ _) move = if kInARow s move player
81 |     then if player == O then posInf else negInf
82 |     else 0
83 | 
84 | -- |Given the current state of the board, return @True@ if putting a counter
85 | -- into a specific square would win the game.
86 | kInARow :: TTState -> TTMove -> TTCounter -> Bool
87 | kInARow state move player = f (1,0) || f (0,1) || f (1,1) || f (1,-1)
88 |     where
89 |         f = kInARow' state move player
90 | 
91 | -- |A helper function for 'kInARow'. Given a state of the board, does adding a
92 | -- counter to a specific square give k-in-a-row in the specified direction?
93 | kInARow' :: TTState -> TTMove -> TTCounter -> (Int,Int) -> Bool
94 | kInARow' (TTS board _ _ (_,_,k)) (x,y) p (dx,dy) = n1 + n2 - 1 >= k
95 |     where
96 |         board' = M.insert (x,y) p board
97 |         fw = map (`M.lookup` board') ( zip [x,x+dx..] [y,y+dy..] )
98 |         bk = map (`M.lookup` board') ( zip [x,x-dx..] [y,y-dy..]
)
99 |         n1 = length $ takeWhile (== Just p) fw
100 |         n2 = length $ takeWhile (== Just p) bk
101 | 
102 | --------------------------
103 | -- Displaying the Board --
104 | --------------------------
105 | 
106 | -- |The Show instance for 'TTState' creates a human-readable representation of
107 | -- the board.
108 | instance Show TTState where
109 |     show s = concat $ concat $ L.intersperse [row] $
110 |              map ((++["\n"]) . L.intersperse "|") (toChars s)
111 |         where
112 |             (h,_,_) = limsTT s
113 |             row = (concat $ replicate (h-1) "---+") ++ "---\n"
114 | 
115 | -- |A helper function for @Show TTState@ that converts each position on the
116 | -- board to its @Char@ representation.
117 | toChars :: TTState -> [[String]]
118 | toChars (TTS board _ _ (h,v,_)) = reverse $ map (map f) board'
119 |     where
120 |         board' = [ [ M.lookup (i,j) board | i <- [0..h-1] ] | j <- [0..v-1] ]
121 |         f (Just O) = " O "
122 |         f (Just X) = " X "
123 |         f Nothing  = "   "
124 | 
125 | ------------------------
126 | -- Compute Heuristics --
127 | ------------------------
128 | 
129 | -- |A heuristic function for Tic Tac Toe. The value is a weighted
130 | -- combination of simpler heuristic functions.
131 | heuristicTTT :: [Double] -> TTState -> Player -> Utility
132 | heuristicTTT weights s p = sum $ zipWith (*) weights [n1,n2,n3,n4]
133 |     where
134 |         n1 = fromIntegral (numWinningLines p s)
135 |         n2 = fromIntegral (numWinningLines (opponent p) s)
136 |         n3 = fromIntegral (numThreats p s)
137 |         n4 = fromIntegral (numThreats (opponent p) s)
138 | 
139 | -- |Compute the number of winning lines heuristic for a particular player. A
140 | -- winning line is defined to be a line of k squares, which contains at least
141 | -- one of your counters, and none of the opponent's counters.
142 | numWinningLines :: Player -> TTState -> Int
143 | numWinningLines p s = length $ filter (isWinningLine $ counter p) (allLines s)
144 | 
145 | -- |Compute the number of threats heuristic for a particular player. A threat
146 | -- cell is a cell that would win the game if it were filled in, but is not
147 | -- blockable on the next move (i.e. it is not on the bottom row and does not
148 | -- have a counter directly beneath it).
149 | numThreats :: Player -> TTState -> Int
150 | numThreats p s@(TTS _ _ _ (h,v,_)) = length $ filter (isThreat s p) xs
151 |     where
152 |         xs = [ (i,j) | i <- [0..h-1], j <- [0..v-1] ]
153 | 
154 | -- |Return @True@ if a cell is a threat cell.
155 | isThreat :: TTState -> Player -> (Int,Int) -> Bool
156 | isThreat s@(TTS board _ _ _) p (x,y) =
157 |     y /= 0 && (x,y-1) `M.notMember` board && kInARow s (x,y) (counter p)
158 | 
159 | -- |Return @True@ if a line of pieces is a winning line.
160 | isWinningLine :: TTCounter -> [Maybe TTCounter] -> Bool
161 | isWinningLine c xs = c `elem` ys && not (other c `elem` ys)
162 |     where
163 |         ys = catMaybes xs
164 | 
165 | -- |Return a list of all of the lines on the board.
166 | allLines :: TTState -> [[Maybe TTCounter]]
167 | allLines s = concat [ linesInDir s (1,0), linesInDir s (0,1)
168 |                     , linesInDir s (1,1), linesInDir s (1,-1) ]
169 | 
170 | -- |Return all of the lines on the board in the specified direction.
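--
-- Directions are unit steps @(dx,dy)@: @(1,0)@ scans horizontal lines,
-- @(0,1)@ vertical lines, and @(1,1)@ / @(1,-1)@ the two diagonals, matching
-- the four cases in 'allLines'.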
171 | linesInDir :: TTState -> (Int,Int) -> [[Maybe TTCounter]] 172 | linesInDir s@(TTS board _ _ (h,v,k)) dir = 173 | map (\p -> lineThrough s p dir) pts 174 | where 175 | pts = case dir of 176 | (1,0) -> [ (x,y) | x <- [0..h-k], y <- [0..v-1] ] 177 | (0,1) -> [ (x,y) | x <- [0..h-1], y <- [0..v-k] ] 178 | (1,1) -> [ (x,y) | x <- [0..h-k], y <- [0..v-k] ] 179 | (1,-1) -> [ (x,y) | x <- [0..h-k], y <- [k-1..v-1] ] 180 | 181 | -- |Return the line starting in cell (x,y) and continuing in direction (dx,dy) 182 | lineThrough :: TTState -> (Int,Int) -> (Int,Int) -> [Maybe TTCounter] 183 | lineThrough (TTS board _ _ (h,v,k)) (x,y) (dx,dy) = 184 | take k $ map (`M.lookup` board) ( zip [x,x+dx..] [y,y+dy..] ) 185 | 186 | ---------- 187 | -- Demo -- 188 | ---------- 189 | 190 | -- |Play a game of tic-tac-toe against a player using the minimax algorithm 191 | -- with full search. This player is impossible to beat - the best you can 192 | -- do is to draw. 193 | demo :: IO () 194 | demo = do playGameIO ticTacToe queryPlayer minimaxFullSearchPlayer 195 | return () -------------------------------------------------------------------------------- /src/AI/Search/Informed.hs: -------------------------------------------------------------------------------- 1 | module AI.Search.Informed where 2 | 3 | import AI.Search.Core 4 | import AI.Util.Queue 5 | 6 | --------------------------------- 7 | -- Informed (Heuristic) Search -- 8 | --------------------------------- 9 | 10 | -- |Type synonym for heuristic functions. In principle they can take any 11 | -- information at a search node into account, including cost already incurred 12 | -- at this node, depth of the node, or the state reached so far. 13 | type Heuristic s a = Node s a -> Double 14 | 15 | -- |Best-first tree search takes a function that scores each potential successor 16 | -- and prefers to explore nodes with the lowest score first. 17 | bestFirstTreeSearch :: (Problem p s a) => 18 | Heuristic s a -- ^ Function to score each node 19 | -> p s a -- ^ Problem 20 | -> Maybe (Node s a) 21 | bestFirstTreeSearch f = treeSearch (newPriorityQueue f) 22 | 23 | -- |Best-first graph search keeps track of states that have already been visited 24 | -- and won't visit the same state twice. 25 | bestFirstGraphSearch :: (Problem p s a, Ord s) => 26 | Heuristic s a -- ^ Function to score each node 27 | -> p s a -- ^ Problem 28 | -> Maybe (Node s a) 29 | bestFirstGraphSearch f = graphSearch (newPriorityQueue f) 30 | 31 | -- |Minimum cost search preferentially explores nodes with the lowest cost 32 | -- accrued, to guarantee that it finds the best path to the solution. 33 | uniformCostSearch :: (Problem p s a, Ord s) => p s a -> Maybe (Node s a) 34 | uniformCostSearch prob = bestFirstGraphSearch cost prob 35 | 36 | -- |Greedy best-first search preferentially explores nodes with the lowest 37 | -- cost remaining to the goal, ignoring cost already accrued. 38 | greedyBestFirstSearch :: (Problem p s a, Ord s) => p s a -> Maybe (Node s a) 39 | greedyBestFirstSearch prob = bestFirstGraphSearch (heuristic prob) prob 40 | 41 | -- |A* search takes a heuristic function that estimates how close each state is 42 | -- to the goal. It combines this with the path cost so far to get a total 43 | -- score, and preferentially explores nodes with a lower score. 
It is optimal
44 | -- whenever the heuristic function is admissible, i.e. it never overestimates the true cost of reaching the goal.
45 | aStarSearch :: (Problem p s a, Ord s) =>
46 |                Heuristic s a    -- ^ Heuristic function
47 |             -> p s a            -- ^ Problem
48 |             -> Maybe (Node s a)
49 | aStarSearch h = bestFirstGraphSearch (\n -> h n + cost n)
50 | 
51 | -- |A variant on A* search that uses the heuristic function defined by the
52 | -- problem.
53 | aStarSearch' :: (Problem p s a, Ord s) => p s a -> Maybe (Node s a)
54 | aStarSearch' prob = aStarSearch (heuristic prob) prob
--------------------------------------------------------------------------------
/src/AI/Search/Local.hs:
--------------------------------------------------------------------------------
1 | module AI.Search.Local
2 |     ( hillClimbingSearch
3 |     , Schedule
4 |     , expSchedule
5 |     , simulatedAnnealing
6 |     ) where
7 | 
8 | import AI.Search.Core
9 | import AI.Util.Util
10 | 
11 | -------------------
12 | -- Hill Climbing --
13 | -------------------
14 | 
15 | -- |From the initial node, keep choosing the neighbour with the highest value,
16 | -- stopping when no neighbour is better.
17 | hillClimbingSearch :: (Problem p s a) => p s a -> Node s a
18 | hillClimbingSearch prob = go (root prob)
19 |     where
20 |         go node = if value neighbour <= value node
21 |                     then node
22 |                     else go neighbour
23 |             where
24 |                 neighbour = argMax (expand prob node) value
25 | 
26 | -------------------------
27 | -- Simulated Annealing --
28 | -------------------------
29 | 
30 | -- |Data type for an annealing schedule.
31 | type Schedule = Int -> Double
32 | 
33 | -- |One possible schedule function for simulated annealing.
34 | expSchedule :: Double -> Int -> Schedule
35 | expSchedule lambda limit k = if k < limit
36 |     then exp (-lambda * fromIntegral k)
37 |     else 0
38 | 
39 | -- |Simulated annealing search. At each stage a random neighbour node is picked,
40 | -- and we move to that node if its value is higher than the current
41 | -- node. If its value is lower, then we move to it with some probability
42 | -- depending on the current 'temperature'. The temperature is gradually
43 | -- reduced according to an annealing schedule, making random jumps less likely
44 | -- as the algorithm progresses.
45 | simulatedAnnealing :: (Problem p s a) => Schedule -> p s a -> IO (Node s a)
46 | simulatedAnnealing schedule prob = go 0 (root prob)
47 |     where
48 |         go k current = let t = schedule k in
49 |             if t == 0
50 |                 then return current
51 |                 else do
52 |                     next <- randomChoiceIO (expand prob current)
53 |                     let deltaE = value next - value current
54 |                     jump <- probabilityIO (exp $ deltaE / t)
55 |                     if deltaE > 0 || jump
56 |                         then go (k+1) next
57 |                         else go (k+1) current
58 | 
--------------------------------------------------------------------------------
/src/AI/Search/Uninformed.hs:
--------------------------------------------------------------------------------
1 | module AI.Search.Uninformed where
2 | 
3 | import AI.Search.Core
4 | import AI.Util.Queue
5 | 
6 | ----------------------------------
7 | -- Uninformed Search Algorithms --
8 | ----------------------------------
9 | 
10 | -- |Search the deepest nodes in the search tree first.
11 | depthFirstTreeSearch :: (Problem p s a) => p s a -> Maybe (Node s a)
12 | depthFirstTreeSearch = treeSearch []
13 | 
14 | -- |Search the shallowest nodes in the search tree first.
15 | breadthFirstTreeSearch :: (Problem p s a) => p s a -> Maybe (Node s a)
16 | breadthFirstTreeSearch = treeSearch (newQueue :: FifoQueue (Node s a))
17 | 
18 | -- |Search the deepest nodes in the graph first.
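-- Like 'depthFirstTreeSearch' this instantiates the fringe with a list used
-- as a LIFO queue, but repeated states are pruned via the closed set: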
19 | depthFirstGraphSearch :: (Problem p s a, Ord s) => p s a -> Maybe (Node s a) 20 | depthFirstGraphSearch = graphSearch [] 21 | 22 | -- |Search the shallowest nodes in the graph first. 23 | breadthFirstGraphSearch :: (Problem p s a, Ord s) => p s a -> Maybe (Node s a) 24 | breadthFirstGraphSearch = graphSearch (newQueue :: FifoQueue (Node s a)) 25 | 26 | -- |Return type for depth-limited search. We need this as there are two types of 27 | -- failure - either we establish that the problem has no solutions ('Fail') or 28 | -- we can't find any solutions within the depth limit ('Cutoff'). 29 | data DepthLimited a = Fail | Cutoff | Ok a deriving (Show) 30 | 31 | -- |Depth-first search with a depth limit. If the depth limit is reached we 32 | -- return 'Cutoff', otherwise return 'Fail' (if no solution is found) or 'Ok' 33 | -- (if a solution is found) which take the place of Nothing and Just in the 34 | -- other search functions. 35 | depthLimitedSearch :: (Problem p s a) => 36 | Int -- ^ Depth limit 37 | -> p s a -- ^ Problem 38 | -> DepthLimited (Node s a) 39 | depthLimitedSearch lim prob = recursiveDLS (root prob) prob lim 40 | where 41 | recursiveDLS node p lim 42 | | goalTest p (state node) = Ok node 43 | | depth node == lim = Cutoff 44 | | otherwise = filt False $ map go (expand prob node) 45 | where 46 | go node = recursiveDLS node p lim 47 | 48 | filt cutoff [] = if cutoff then Cutoff else Fail 49 | filt cutoff (Ok node : _) = Ok node 50 | filt cutoff (Fail : rest) = filt cutoff rest 51 | filt cutoff (Cutoff : rest) = filt True rest 52 | 53 | -- |Repeatedly try depth-limited search with an increasing depth limit. 54 | iterativeDeepeningSearch :: (Problem p s a) => p s a -> Maybe (Node s a) 55 | iterativeDeepeningSearch prob = go 1 56 | where 57 | go lim = case depthLimitedSearch lim prob of 58 | Cutoff -> go (lim + 1) 59 | Fail -> Nothing 60 | Ok n -> Just n -------------------------------------------------------------------------------- /src/AI/Test/Learning/LinearRegression.hs: -------------------------------------------------------------------------------- 1 | module AI.Test.Learning.LinearRegression (runAllTests) where 2 | 3 | import Data.Packed.Matrix 4 | import Data.Packed.Vector 5 | import Numeric.Container 6 | import Test.QuickCheck 7 | 8 | import AI.Learning.LinearRegression 9 | import AI.Util.Matrix 10 | import AI.Test.Util 11 | 12 | -- |Regressing against a column of zeros should return a zero result vector. 
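--
-- (With a zero response the least-squares coefficients of @y = X b@ are
-- exactly zero, so the property can safely use exact equality rather than an
-- approximate floating-point comparison.)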
13 | testRegressionAgainstZeros :: Gen Bool
14 | testRegressionAgainstZeros = do
15 |     m <- choose (1,10)
16 |     n <- choose (m,100)
17 |     x <- arbitraryGaussianMatrix (n,m) :: Gen (Matrix Double)
18 |     let y = constant 0 n
19 |         b = constant 0 m
20 |         bSample = regress x y
21 |     return (bSample == b)
22 | 
23 | 
24 | allTests =
25 |     [ testRegressionAgainstZeros ]
26 | 
27 | runAllTests = mapM_ quickCheck allTests
--------------------------------------------------------------------------------
/src/AI/Test/Main.hs:
--------------------------------------------------------------------------------
1 | module AI.Test.Main where
2 | 
3 | import qualified AI.Test.Learning.LinearRegression as LR
4 | 
5 | run = do
6 |     LR.runAllTests
--------------------------------------------------------------------------------
/src/AI/Test/Util.hs:
--------------------------------------------------------------------------------
1 | module AI.Test.Util where
2 | 
3 | import Data.Packed.Vector
4 | import Data.Packed.Matrix
5 | import Foreign.Storable
6 | import Numeric.Container
7 | import Test.QuickCheck hiding ((><))
8 | 
9 | arbitraryGaussianVector :: Int -> Gen (Vector Double)
10 | arbitraryGaussianVector n = do
11 |     seed <- arbitrary
12 |     return (randomVector seed Gaussian n)
13 | 
14 | arbitraryUniformVector :: Int -> Gen (Vector Double)
15 | arbitraryUniformVector n = do
16 |     seed <- arbitrary
17 |     return (randomVector seed Uniform n)
18 | 
19 | arbitraryGaussianMatrix :: (Int,Int) -> Gen (Matrix Double)
20 | arbitraryGaussianMatrix (n,m) = do
21 |     seed <- arbitrary
22 |     let mu  = constant 0 m
23 |         cov = ident m
24 |     return (gaussianSample seed n mu cov)
25 | 
26 | arbitraryUniformMatrix :: (Int,Int) -> Gen (Matrix Double)
27 | arbitraryUniformMatrix (n,m) = do
28 |     seed <- arbitrary
29 |     return (uniformSample seed n (replicate m (0,1)))
30 | 
31 | --instance (Storable a, Arbitrary a) => Arbitrary (Matrix a) where
32 | --    arbitrary = do
33 | --        n <- arbitrary `suchThat` \n -> n > 0 && n < 10
34 | --        m <- arbitrary `suchThat` \m -> m > 0 && m < 100
35 | --        s <- arbitrary
36 | 
37 | --        return $ (n>
--------------------------------------------------------------------------------
/src/AI/Util/Array.hs:
--------------------------------------------------------------------------------
1 | module AI.Util.Array where
2 | 
3 | import AI.Util.Util
4 | 
5 | import qualified Data.List as L
6 | 
7 | 
8 | -- |Take a list of binary digits, with the most significant digit first, and
9 | -- return the integer that the digits represent. The result can be used to
10 | -- index into a flat list that stores the values of an N-dimensional
11 | -- binary array.
12 | ndSubRef :: [Int] -> Int
13 | ndSubRef = L.foldl' (\a d -> 2 * a + d) 0
14 | 
15 | -- |Return the index of the first occurrence of a particular element in a list.
16 | indexOf :: Eq a => [a] -> a -> Int
17 | indexOf xs x = case L.elemIndex x xs of
18 |     Nothing -> error "Element not found -- INDEXOF"
19 |     Just i  -> i
20 | 
21 | -- |Given a list of variables and a list of fixings, return a list (index,value)
22 | -- which can be used to 'subSlice' a conditional probability vector.
23 | getIxVector :: Eq e => [e] -> [(e,Bool)] -> [(Int,Bool)]
24 | getIxVector vars [] = []
25 | getIxVector vars ((v,x):rest) = if v `elem` vars
26 |     then (vars `indexOf` v, x) : getIxVector vars rest
27 |     else getIxVector vars rest
28 | 
29 | -- |This function returns the indexes to take an (n-1)-dimensional subslice of
30 | -- an n-dimensional array. The first argument gives n, the number of
31 | -- dimensions of the array. The second argument gives the index being fixed.
32 | -- The indexes returned can be used to index into an array of length 2^(n-1).
33 | subSliceIdx :: Int -> (Int, Bool) -> [Int]
34 | subSliceIdx ndim (i,x) = filter f [0 .. 2^ndim - 1]
35 |     where
36 |         f = if x then select else not . select
37 |         select n = (n `div` 2 ^ (ndim - i - 1)) `mod` 2 == 0
38 | 
39 | -- |This function returns the indexes to perform an arbitrary subslice of an
40 | -- n-dimensional array, by fixing a subset of its indexes.
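--
-- For example, for a 2x2 array stored as a flat list of length 4, fixing
-- dimension 0 to 'True' selects the first row:
--
-- > subSliceIdxs 2 [(0,True)]  ==  [0,1]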
41 | subSliceIdxs :: Int -> [(Int,Bool)] -> [Int]
42 | subSliceIdxs ndim []   = [0 .. 2^ndim - 1]
43 | subSliceIdxs ndim idxs = L.foldl1' L.intersect $ map (subSliceIdx ndim) idxs
44 | 
45 | -- |Given a 1-dimensional array storing the values of an N-dimensional array,
46 | -- take a subslice by fixing one of the dimensions to either 0 or 1.
47 | subSlice1 :: [a] -> (Int,Bool) -> [a]
48 | subSlice1 xs (i,x) = xs `elemsAt` subSliceIdx (log2 $ length xs) (i,x)
49 | 
50 | -- |Given a 1-dimensional array storing the values of an N-dimensional array,
51 | -- take a subslice by fixing a subset of the indexes to either 0 or 1.
52 | subSlice :: [a] -> [(Int,Bool)] -> [a]
53 | subSlice xs is = xs `elemsAt` subSliceIdxs (log2 $ length xs) is
54 | 
55 | -- |Base 2 logarithm for 'Int's.
56 | log2 :: Int -> Int
57 | log2 n = go n 0 where go n x = if n == 1 then x else go (n `div` 2) (x+1)
58 | 
--------------------------------------------------------------------------------
/src/AI/Util/Graph.hs:
--------------------------------------------------------------------------------
1 | module AI.Util.Graph
2 |     ( Graph
3 |     , toGraph
4 |     , fromGraph
5 |     , getNodes
6 |     , getNeighbours
7 |     , getEdge
8 |     , addEdge
9 |     , addUndirectedEdge
10 |     , parseGraph) where
11 | 
12 | import Data.Map (Map, (!))
13 | import qualified Data.Map as M
14 | import qualified Data.List as L
15 | import qualified Data.Text as T
16 | 
17 | -----------------------
18 | -- Unweighted Graphs --
19 | -----------------------
20 | 
21 | -- |Type for unweighted graphs.
22 | type Graph a = Map a [a]
23 | 
24 | -- |Create a directed graph from an adjacency list.
25 | toGraph :: (Ord a) => [(a, [a])] -> Graph a
26 | toGraph = M.fromList
27 | 
28 | -- |Create an undirected graph from an adjacency list.
29 | toUndirectedGraph :: Ord a => [(a, [a])] -> Graph a
30 | toUndirectedGraph xs = fromPairRep . symmetrize .
31 |                        toPairRep $ toGraph xs
32 | 
33 | -- |Convert an unweighted graph to its adjacency list representation.
34 | fromGraph :: Graph a -> [(a, [a])]
35 | fromGraph = M.toList
36 | 
37 | -- |Get a list of the nodes of the graph.
38 | getNodes :: Graph a -> [a]
39 | getNodes = M.keys
40 | 
41 | -- |Get a list of the outbound links from node @a@.
42 | getNeighbours :: Ord a => a -> Graph a -> [a]
43 | getNeighbours a g = case M.lookup a g of
44 |     Nothing -> error "Vertex not found in graph -- GETNEIGHBOURS"
45 |     Just ls -> ls
46 | 
47 | -- |Return 'True' if and only if an edge exists between @x@ and @y@.
48 | getEdge :: Ord a => a -> a -> Graph a -> Bool
49 | getEdge x y g = case M.lookup x g of
50 |     Nothing -> error "Vertex not found in graph -- GETEDGE"
51 |     Just ys -> y `elem` ys
52 | 
53 | -- |Add an edge between two vertices to a 'Graph'.
54 | addEdge :: Ord a => a -> a -> Graph a -> Graph a
55 | addEdge x y graph = M.adjust (y:) x graph
56 | 
57 | -- |Add an undirected edge between two vertices to a 'Graph'.
58 | addUndirectedEdge :: Ord a => a -> a -> Graph a -> Graph a
59 | addUndirectedEdge x y graph = addEdge y x (addEdge x y graph)
60 | 
61 | -- |Convert an unweighted graph to its ordered pair representation.
62 | toPairRep :: Graph a -> [(a,a)]
63 | toPairRep xs = [ (a,b) | (a,bs) <- fromGraph xs, b <- bs ]
64 | 
65 | -- |Convert an unweighted graph from its ordered pair representation.
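--
-- For example (only nodes with outgoing edges appear as keys, and neighbour
-- lists come out in reverse insertion order):
--
-- > fromPairRep [(1,2),(1,3),(2,3)]  ==  toGraph [ (1,[3,2]), (2,[3]) ]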
66 | fromPairRep :: Ord a => [(a,a)] -> Graph a
67 | fromPairRep xs = go xs M.empty
68 |     where
69 |         go []          m = m
70 |         go ((a,b):xs) m = go xs (M.insert a newList m)
71 |             where
72 |                 newList = b : case M.lookup a m of
73 |                     Nothing -> []
74 |                     Just l  -> l
75 | 
76 | -- |Take an unweighted graph in ordered pair representation and add in all of
77 | -- the reverse links, so that the resulting graph is undirected.
78 | symmetrize :: Eq a => [(a,a)] -> [(a,a)]
79 | symmetrize xs = L.nub $ concat [ [(a,b),(b,a)] | (a,b) <- xs ]
80 | 
81 | -- |Parse an unweighted graph from a string. The string must be a semicolon-
82 | -- separated list of associations between nodes and neighbours. Each association
83 | -- has the head node on the left, followed by a colon, followed by a list of
84 | -- neighbours, for example:
85 | --
86 | -- > "A: B C; B: C D; C: D"
87 | --
88 | -- It is not necessary to specify reverse links - they will be added
89 | -- automatically.
90 | parseGraph :: String -> Graph String
91 | parseGraph str = toUndirectedGraph $ textToStr $ splitNbrs $
92 |                  parseNodes $ splitNodes $ T.pack str
93 |     where
94 |         splitNodes = map T.strip . T.split (== ';')
95 |         parseNodes = map listToPair . map (T.split (== ':'))
96 |         splitNbrs  = map (\(x,y) -> (x, T.words y))
97 |         textToStr  = map (\(x,y) -> (T.unpack x, map T.unpack y))
98 |         listToPair [x,y] = (x,y)
-------------------------------------------------------------------------------- /src/AI/Util/Matrix.hs: --------------------------------------------------------------------------------
1 | {-# LANGUAGE FlexibleContexts #-}
2 | 
3 | module AI.Util.Matrix where
4 | 
5 | import Control.Monad.Random hiding (fromList)
6 | import Foreign.Storable (Storable)
7 | import Numeric.LinearAlgebra
8 | 
9 | -- |Return the size of a matrix as a 2-tuple.
10 | size :: Matrix a -> (Int,Int)
11 | size x = (rows x, cols x)
12 | 
13 | -- |Concatenate matrices horizontally.
14 | horzcat :: Element a => [Matrix a] -> Matrix a
15 | horzcat = fromBlocks . return
16 | 
17 | -- |Concatenate matrices vertically.
18 | vertcat :: Element a => [Matrix a] -> Matrix a
19 | vertcat = fromBlocks . map return
20 | 
21 | -- |Add a column of ones to a matrix.
22 | addOnes :: Matrix Double -> Matrix Double
23 | addOnes x = fromBlocks [[1, x]]
24 | 
25 | -- |Create a row matrix.
26 | row :: [Double] -> Matrix Double
27 | row = asRow . fromList
28 | 
29 | -- |Create a column matrix.
30 | column :: [Double] -> Matrix Double
31 | column = asColumn . fromList
32 | 
33 | --------------------------
34 | -- Functions on Vectors --
35 | --------------------------
36 | 
37 | takeVector :: Storable a => Int -> Vector a -> Vector a
38 | takeVector n v = subVector 0 n v
39 | 
40 | dropVector :: Storable a => Int -> Vector a -> Vector a
41 | dropVector n v = subVector n (dim v - n) v
42 | 
43 | sumVector :: (Num a, Storable a) => Vector a -> a
44 | sumVector xs = foldVector (+) 0 xs
45 | 
46 | prodVector :: (Num a, Storable a) => Vector a -> a
47 | prodVector xs = foldVector (*) 1 xs
48 | 
49 | ---------------------------
50 | -- Functions on Matrices --
51 | ---------------------------
52 | 
53 | mapRows :: Element a => (Vector a -> b) -> Matrix a -> [b]
54 | mapRows f m = map f (toRows m)
55 | 
56 | mapCols :: Element a => (Vector a -> b) -> Matrix a -> [b]
57 | mapCols f m = map f (toColumns m)
58 | 
59 | eachRow :: (Element a, Element b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
60 | eachRow f = fromRows . mapRows f
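-- For instance, to scale every row of a matrix (an informal sketch; 'scale'
-- and '(><)' come from hmatrix, and the printed layout of the result matrix
-- is abbreviated here):
--
-- >>> eachRow (scale 2) ((2><2) [1,2,3,4])
-- (2><2) [ 2.0, 4.0, 6.0, 8.0 ]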
61 | 
62 | eachCol :: (Element a, Element b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
63 | eachCol f = fromColumns . mapCols f
64 | 
65 | sumRows :: (Element a, Num (Vector a)) => Matrix a -> Vector a
66 | sumRows m = sum $ toRows m
67 | 
68 | sumCols :: (Element a, Num (Vector a)) => Matrix a -> Vector a
69 | sumCols m = sum $ toColumns m
70 | 
71 | sumMatrix :: Matrix Double -> Double
72 | sumMatrix = sumVector . sum . toRows
73 | 
74 | ------------------------
75 | -- Subset Referencing --
76 | ------------------------
77 | 
78 | subRefVec :: Storable a => Vector a -> [Int] -> Vector a
79 | subRefVec v is = fromList $ map (v@>) is
80 | 
81 | subRefRows :: Element a => Matrix a -> [Int] -> Matrix a
82 | subRefRows m is = fromRows $ map (r!!) is where r = toRows m
83 | 
84 | subRefCols :: Element a => Matrix a -> [Int] -> Matrix a
85 | subRefCols m is = fromColumns $ map (c!!) is where c = toColumns m
86 | 
-------------------------------------------------------------------------------- /src/AI/Util/ProbDist.hs: --------------------------------------------------------------------------------
1 | {-# LANGUAGE FlexibleInstances #-}
2 | 
3 | module AI.Util.ProbDist where
4 | 
5 | import Control.Applicative
6 | import Control.Monad
7 | import Control.Monad.Random
8 | import Data.Map (Map)
9 | import GHC.Float
10 | 
11 | import qualified Data.List as L
12 | import qualified Data.Map as M
13 | 
14 | type Prob = Float
15 | 
16 | data Dist a = D { unD :: [(a,Prob)] }
17 | 
18 | instance Functor Dist where
19 |     fmap f (D xs) = D [ (f x,p) | (x,p) <- xs ]
20 | 
21 | instance Applicative Dist where
22 |     pure x = D [(x,1)]
23 |     (D fs) <*> (D xs) = D $ [ (f x,p*q) | (f,q) <- fs, (x,p) <- xs ]
24 | 
25 | instance Monad Dist where
26 |     return x = D [(x,1)]
27 |     (D xs) >>= f = D [ (y,p*q) | (x,p) <- xs, (y,q) <- unD (f x) ]
28 | 
29 | -- |Map over the values of a probability distribution.
30 | mapD :: (a -> b) -> Dist a -> Dist b
31 | mapD = fmap
32 | 
33 | -- |Map over the probabilities of a probability distribution.
34 | mapP :: (Prob -> Prob) -> Dist a -> Dist a
35 | mapP f (D xs) = D (map (\(x,p) -> (x,f p)) xs)
36 | 
37 | -- |Filter the values of a probability distribution according to some predicate.
38 | filterD :: (a -> Bool) -> Dist a -> Dist a
39 | filterD test (D xs) = D $ filter (test . fst) xs
40 | 
41 | -- |Return the distribution that results from conditioning on some predicate.
42 | -- Note that if the condition is not satisfied by any member of the initial
43 | -- distribution, this procedure will give nonsensical results.
44 | (|||) :: Dist a -> (a -> Bool) -> Dist a
45 | p ||| condition = normalize (filterD condition p)
46 | 
47 | -- |Return the probability that the predicate is satisfied by a distribution.
48 | (??) :: (a -> Bool) -> Dist a -> Prob
49 | test ?? dist = sum . probs $ filterD test dist
50 | 
51 | -- |Return a list of all the values in a distribution.
52 | vals :: Dist a -> [a]
53 | vals (D xs) = map fst xs
54 | 
55 | -- |Return a list of all the probabilities in a distribution.
56 | probs :: Dist a -> [Prob]
57 | probs (D xs) = map snd xs
58 | 
59 | -- |Check if a given @Dist@ value satisfies the conditions required to be a
60 | -- probability distribution, i.e. the probability associated with each element
61 | -- is non-negative, and the probabilities sum to 1.
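--
-- For example:
--
-- >>> isDist (boolD 0.5)
-- True
-- >>> isDist (mapP (*2) (boolD 0.5))
-- False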
62 | isDist :: Dist a -> Bool
63 | isDist p@(D xs) = firstAxiomHolds && secondAxiomHolds
64 |     where
65 |         firstAxiomHolds  = all (\x -> snd x >= 0) xs
66 |         secondAxiomHolds = sum (probs p) =~ 1
67 | 
68 | -- |Normalize a probability distribution. It may be necessary to call this after
69 | -- filtering some values from a distribution.
70 | normalize :: Dist a -> Dist a
71 | normalize p = if total =~ 1.0
72 |     then p
73 |     else mapP (/total) p where total = sum (probs p)
74 | 
75 | -- |Collect equal values in a probability distribution. Since we cannot restrict
76 | -- the values of a probability distribution to those with an @Eq@ instance, it
77 | -- may sometimes be necessary to call this function to avoid explosive growth
78 | -- in the number of elements in a distribution.
79 | collect :: (Ord a) => Dist a -> Dist a
80 | collect (D xs) = D $ M.toList $ M.fromListWith (+) xs
81 | 
82 | -- |Apply Bayes rule to a distribution. This is really just a convenience
83 | -- function. It normalizes the distribution, and collects results. Typically
84 | -- you would use it immediately after a @do@ block that uses the 'condition'
85 | -- function to filter out unnecessary values.
86 | bayes :: Ord a => Dist a -> Dist a
87 | bayes = collect . normalize
88 | 
89 | -- |Condition over an event in a do block. This can be used to filter out
90 | -- unwanted results from a distribution. Note that if any filtering is done,
91 | -- the distribution returned will be unnormalized. This is equivalent to
92 | -- 'guard' - in fact we could just make 'Dist' into an instance of 'MonadPlus'.
93 | condition :: Bool -> Dist ()
94 | condition True  = return ()
95 | condition False = D []
96 | 
97 | -------------------
98 | -- Show Instance --
99 | -------------------
100 | 
101 | instance Show a => Show (Dist a) where
102 |     show (D xs) = concat $ L.intersperse "\n" $ map disp xs
103 |         where
104 |             disp (x,p) = show x ++ replicate (pad x) ' ' ++ showProb p
105 |             pad x = n - length (show x) + 2
106 |             n = maximum $ map (length . show . fst) xs
107 | 
108 | showProb :: Prob -> String
109 | showProb p = show intPart ++ "." ++ show fracPart ++ "%"
110 |     where
111 |         digits   = round (1000 * p)
112 |         intPart  = digits `div` 10
113 |         fracPart = digits `mod` 10
114 | 
115 | -----------------------
116 | -- Numeric Functions --
117 | -----------------------
118 | 
119 | -- |Approximate floating point equality, with a tolerance of 1 part in 1000.
120 | (=~) :: (Ord a, Fractional a) => a -> a -> Bool
121 | x =~ y = abs (x/y - 1) < 0.001
122 | 
123 | ----------------------
124 | -- Convert to Float --
125 | ----------------------
126 | 
127 | -- |Type class for data which can be converted to a floating point number.
128 | class ToFloat a where
129 |     toFloat :: a -> Float
130 | 
131 | instance ToFloat Float where
132 |     toFloat = id
133 | 
134 | instance ToFloat Double where
135 |     toFloat = double2Float
136 | 
137 | instance ToFloat Int where
138 |     toFloat = fromIntegral
139 | 
140 | instance ToFloat Integer where
141 |     toFloat = fromIntegral
142 | 
143 | --------------------------------
144 | -- Functions on Distributions --
145 | --------------------------------
146 | 
147 | -- |Compute the expectation of a numeric distribution. The expectation is
148 | -- defined to be
149 | --
150 | -- > sum (x_i * p_i) for i = 1 .. n
151 | --
152 | -- This is only defined for distributions over data that can be cast to Float.
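--
-- >>> expectation (uniform [0,1,2,3])
-- 1.5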
153 | expectation :: ToFloat a => Dist a -> Prob
154 | expectation (D xs) = sum $ [ toFloat x * p | (x,p) <- xs ]
155 | 
156 | -- |Compute the entropy of a distribution, returning the result in /nats/.
157 | -- Note that it is necessary to collect like results first, to ensure that
158 | -- the true entropy is calculated.
159 | entropy :: Ord a => Dist a -> Prob
160 | entropy (D xs) = negate $ sum [ if p /= 0 then p * log p else 0 | (_,p) <- xs ]
161 | 
162 | -- |Compute the entropy of a distribution, returning the result in /bits/.
163 | -- Note that it is necessary to collect like results first, to ensure that
164 | -- the true entropy is calculated.
165 | entropyBits :: Ord a => Dist a -> Prob
166 | entropyBits d = entropy d / log 2
167 | 
168 | -------------------------------
169 | -- Probability Distributions --
170 | -------------------------------
171 | 
172 | -- |A trivial probability distribution that always takes the same value.
173 | certainly :: a -> Dist a
174 | certainly = return
175 | 
176 | -- |The Bernoulli distribution takes one of two possible values.
177 | bernoulli :: Prob -> a -> a -> Dist a
178 | bernoulli p a b = D [(a,p), (b,1-p)]
179 | 
180 | -- |Bernoulli distribution over a boolean variable.
181 | boolD :: Prob -> Dist Bool
182 | boolD p = bernoulli p True False
183 | 
184 | -- |A uniform distribution over a finite list assigns equal probability to each
185 | -- of the elements of the list.
186 | uniform :: [a] -> Dist a
187 | uniform xs = D $ zip xs (repeat p) where p = 1 / fromIntegral (length xs)
188 | 
189 | -- |A weighted distribution over a finite list. The weights give the relative
190 | -- probabilities attached to each outcome.
191 | weighted :: [(a,Int)] -> Dist a
192 | weighted lst = D $ zip xs ps
193 |     where
194 |         (xs,ws) = unzip lst
195 |         ps = map (\w -> fromIntegral w / fromIntegral (sum ws)) ws
196 | 
197 | -- |Return the empirical distribution over a list, i.e. choose each element
198 | -- in proportion to how many times it appears in the list.
199 | empirical :: Ord a => [a] -> Dist a
200 | empirical xs = weighted $ M.toList $ M.fromListWith (+) $ zip xs (repeat 1)
201 | 
202 | -- |Select @n@ elements from a list without replacement.
203 | select :: Eq a => Int -> [a] -> Dist [a]
204 | select n = mapD (reverse . fst) . selectMany n
205 | 
206 | -- |Select @n@ elements from a list uniformly at random without replacement,
207 | -- also returning the list of remaining elements.
208 | selectMany :: Eq a => Int -> [a] -> Dist ([a],[a])
209 | selectMany 0 xs = return ([],xs)
210 | selectMany n xs = do
211 |     (v, xs1) <- selectOne xs
212 |     (vs,xs2) <- selectMany (n-1) xs1
213 |     return (v:vs,xs2)
214 | 
215 | -- |Select a single element from a list uniformly at random, also returning
216 | -- the list that remains.
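--
-- For example (probabilities shown as rendered by the 'Show' instance):
--
-- >>> selectOne "abc"
-- ('a',"bc")  33.3%
-- ('b',"ac")  33.3%
-- ('c',"ab")  33.3%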
217 | selectOne :: Eq a => [a] -> Dist (a,[a])
218 | selectOne xs = uniform [(x, L.delete x xs) | x <- xs ]
219 | 
220 | ------------------------------
221 | -- Combining Distributions --
222 | ------------------------------
223 | 
224 | joinWith :: Ord c => (a -> b -> c) -> Dist a -> Dist b -> Dist c
225 | joinWith f (D xs) (D ys) =
226 |     collect $ D [ (f x y, p*q) | (x,p) <- xs, (y,q) <- ys ]
227 | 
228 | addD :: (Num a, Ord a) => Dist a -> Dist a -> Dist a
229 | addD = joinWith (+)
230 | 
231 | subD :: (Num a, Ord a) => Dist a -> Dist a -> Dist a
232 | subD = joinWith (-)
233 | 
234 | mulD :: (Num a, Ord a) => Dist a -> Dist a -> Dist a
235 | mulD = joinWith (*)
236 | 
237 | sumD :: (Num a, Ord a) => [Dist a] -> Dist a
238 | sumD = L.foldl' addD (return 0)
239 | 
240 | prodD :: (Num a, Ord a) => [Dist a] -> Dist a
241 | prodD = L.foldl' mulD (return 1)
242 | 
243 | --------------
244 | -- Sampling --
245 | --------------
246 | 
247 | -- |Create a random sampler from a probability distribution.
248 | sample :: MonadRandom m => Dist a -> m a
249 | sample (D []) = error "AI.Util.ProbDist.sample called with empty distribution"
250 | sample (D xs) = do
251 |     v <- getRandomR (0,1)
252 |     return $ fst . head . dropWhile (\(x,p) -> p < v) $ cumulative
253 |     where
254 |         cumulative = scanl1 (\(x,p) (y,q) -> (y,p+q)) xs
255 | 
-------------------------------------------------------------------------------- /src/AI/Util/Queue.hs: --------------------------------------------------------------------------------
1 | module AI.Util.Queue
2 |     (
3 |     -- * Type class and queue functions
4 |       Queue (..)
5 |     , notEmpty
6 |     -- * Queue instances
7 |     , FifoQueue
8 |     , PriorityQueue
9 |     , newPriorityQueue
10 |     ) where
11 | 
12 | import qualified Data.Map as M
13 | 
14 | -- |An abstract Queue class supporting a test for emptiness and push/pop
15 | -- functions. You can override the function 'extend' for performance reasons.
16 | class Queue q where
17 |     -- |Return an empty queue.
18 |     newQueue :: q a
19 | 
20 |     -- |Return 'True' if the queue is empty.
21 |     empty :: q a -> Bool
22 | 
23 |     -- |Pop an element from the front of the queue, also returning
24 |     -- the remaining queue.
25 |     pop :: q a -> (a, q a)
26 | 
27 |     -- |Push a new element into the queue.
28 |     push :: a -> q a -> q a
29 | 
30 |     -- |Push a list of elements into the queue one by one.
31 |     extend :: [a] -> q a -> q a
32 |     extend xs q = foldr push q xs
33 | 
34 | -- |Return 'True' if a queue has any elements remaining.
35 | notEmpty :: Queue q => q a -> Bool
36 | notEmpty = not . empty
37 | 
38 | -- |Lists can represent LIFO queues if 'push' conses new elements onto the
39 | -- front of the queue.
40 | instance Queue [] where
41 |     newQueue = []
42 |     empty    = null
43 |     pop q    = (head q, tail q)
44 |     push     = (:)
45 |     extend   = (++)
46 | 
47 | -- |An amortized O(1) FIFO queue. We maintain a fast push operation by storing
48 | -- the front and back of the queue in separate lists. Whenever the front of the
49 | -- queue is empty, we reverse the back of the queue and put the reversed list
50 | -- at the front. Although it takes O(n) time to reverse the list, each element
51 | -- only needs to be moved once, and so the amortized time is O(1).
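--
-- For example:
--
-- >>> let q = push 2 (push 1 (newQueue :: FifoQueue Int))
-- >>> fst (pop q)
-- 1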
52 | --
53 | -- Code adapted from Eric Kidd:
54 | --
55 | instance Queue FifoQueue where
56 |     newQueue = FifoQueue [] []
57 | 
58 |     empty (FifoQueue [] []) = True
59 |     empty _                 = False
60 | 
61 |     pop (FifoQueue [] [])     = error "Can't pop from an empty queue"
62 |     pop (FifoQueue (x:xs) ys) = (x, FifoQueue xs ys)
63 |     pop (FifoQueue [] ys)     = pop (FifoQueue (reverse ys) [])
64 | 
65 |     push y (FifoQueue xs ys) = FifoQueue xs (y:ys)
66 | 
67 | data FifoQueue a = FifoQueue [a] [a]
68 | 
69 | -- |A priority queue implemented as a map. Both pop and push have O(log n)
70 | -- complexity.
71 | instance Ord k => Queue (PriorityQueue k) where
72 |     newQueue = undefined -- a priority queue needs a key function: use 'newPriorityQueue'
73 |     empty (PQueue q _) = M.null q
74 |     pop (PQueue q f) = (item, PQueue newPQ f)
75 |         where ((key, (item:items)), q') = M.deleteFindMin q
76 |               newPQ = if null items
77 |                         then q'
78 |                         else M.insert key items q'
79 |     push x (PQueue q f) = PQueue newPQ f
80 |         where key = (f x)
81 |               newPQ = case M.lookup key q of
82 |                 Nothing   -> M.insert key [x] q
83 |                 Just vals -> M.insert key (x:vals) q
84 | 
85 | data PriorityQueue k a = PQueue { pqueue :: M.Map k [a], keyfun :: a -> k }
86 | 
87 | newPriorityQueue :: (a -> k) -> PriorityQueue k a
88 | newPriorityQueue f = PQueue M.empty f
-------------------------------------------------------------------------------- /src/AI/Util/Table.hs: --------------------------------------------------------------------------------
1 | {-# LANGUAGE ExistentialQuantification #-}
2 | 
3 | -- |This module contains routines for displaying and printing tables of data.
4 | module AI.Util.Table
5 |     ( Showable(..)
6 |     , printTable
7 |     , showTable
8 |     ) where
9 | 
10 | -- |A Showable is simply a box containing a value which is an instance of 'Show'.
11 | data Showable = forall a. Show a => SB a
12 | 
13 | -- |To convert a 'Showable' to a 'String', just call 'show' on its contents.
14 | instance Show Showable where
15 |     show (SB x) = show x
16 | 
17 | -- |Print a table of data to stdout. You must supply the column width (number
18 | -- of chars) and a list of row and column names.
19 | printTable :: Int           -- ^ Column width
20 |            -> [[Showable]] -- ^ Data
21 |            -> [String]     -- ^ Column names (including the 0th column)
22 |            -> [String]     -- ^ Row names
23 |            -> IO ()
24 | printTable pad xs header rownames =
25 |     mapM_ putStrLn (showTable pad xs header rownames)
26 | 
27 | -- |Return a table as a list of strings, one row per line. This routine is
28 | -- called by 'printTable'.
29 | showTable :: Int            -- ^ Column width
30 |           -> [[Showable]]  -- ^ Data
31 |           -> [String]      -- ^ Column names
32 |           -> [String]      -- ^ Row names
33 |           -> [String]
34 | showTable pad xs header rownames =
35 |     let dashes = replicate (length header) (replicate pad '-')
36 |         hzline = showRow pad "+" dashes
37 |         hdline = showRow pad "|" header
38 |         rows'  = zipWith (:) rownames (map (map show) xs)
39 |         rows   = map (showRow pad "|") rows'
40 |     in [hzline,hdline,hzline] ++ rows ++ [hzline]
41 | 
42 | -- |Convert a single row of a table to a string, padding each cell so that
43 | -- it is of uniform width.
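--
-- For example:
--
-- >>> showRow 4 "|" ["ab","cdef"]
-- "|ab  |cdef|"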
44 | showRow :: Int -> String -> [String] -> String
45 | showRow pad sep xs = sep ++ (concatMap showCell cells)
46 |     where
47 |         trim pad str = let n = length str
48 |                            m = max 0 (pad - n)
49 |                        in take pad str ++ replicate m ' '
50 |         showCell cel = cel ++ sep
51 |         cells = map (trim pad) xs
52 | 
-------------------------------------------------------------------------------- /src/AI/Util/Util.hs: --------------------------------------------------------------------------------
1 | {-# LANGUAGE BangPatterns #-}
2 | 
3 | module AI.Util.Util where
4 | 
5 | import qualified Data.List as L
6 | import qualified Data.Map as M
7 | import qualified Data.Ord as O
8 | import qualified System.Random as R
9 | 
10 | import Control.Concurrent.STM
11 | import Control.DeepSeq
12 | import Control.Monad
13 | import Control.Monad.Error
14 | import Control.Monad.Random
15 | import Data.Map (Map, (!))
16 | import System.CPUTime
17 | import System.Random
18 | import System.Timeout
19 | 
20 | -----------------------
21 | -- Numeric Functions --
22 | -----------------------
23 | 
24 | -- |Positive infinity.
25 | posInf :: Fractional a => a
26 | posInf = 1/0
27 | 
28 | -- |Negative infinity.
29 | negInf :: Fractional a => a
30 | negInf = -1/0
31 | 
32 | -- |Return the mean of a list of numbers.
33 | mean :: Fractional a => [a] -> a
34 | mean xs = total / fromInteger len
35 |     where
36 |         (total,len) = L.foldl' k (0,0) xs
37 |         k (!s,!n) x = (s+x, n+1)
38 | 
39 | ---------------------
40 | -- Maybe Functions --
41 | ---------------------
42 | 
43 | -- |Return 'True' if a 'Maybe' value is 'Nothing', else 'False'.
44 | no :: Maybe a -> Bool
45 | no Nothing = True
46 | no _       = False
47 | 
48 | ---------------------
49 | -- Tuple Functions --
50 | ---------------------
51 | 
52 | -- |Return the first element of a 3-tuple.
53 | fst3 :: (a,b,c) -> a
54 | fst3 (a,_,_) = a
55 | 
56 | -- |Return the second element of a 3-tuple.
57 | snd3 :: (a,b,c) -> b
58 | snd3 (_,b,_) = b
59 | 
60 | -- |Return the third element of a 3-tuple.
61 | thd3 :: (a,b,c) -> c
62 | thd3 (_,_,c) = c
63 | 
64 | ------------------
65 | -- Enumerations --
66 | ------------------
67 | 
68 | enum :: (Enum b, Bounded b) => [b]
69 | enum = [minBound .. maxBound]
70 | 
71 | --------------------
72 | -- List Functions --
73 | --------------------
74 | 
75 | -- |Return 'True' if the list is not null.
76 | notNull :: [a] -> Bool
77 | notNull = not . null
78 | 
79 | -- |Return the elements of a list at the specified indexes.
80 | elemsAt :: [a] -> [Int] -> [a]
81 | elemsAt as is = map (as!!) is
82 | 
83 | -- |Update the element at position i in a list. Fails with a pattern-match
84 | -- error if the index is out of range.
85 | insert :: Int -> a -> [a] -> [a]
86 | insert 0 n (_:xs) = n : xs
87 | insert i n (x:xs) = x : insert (i-1) n xs
88 | 
89 | -- |Delete every occurrence of this element from the list.
90 | deleteEvery :: Eq a => a -> [a] -> [a]
91 | deleteEvery x []     = []
92 | deleteEvery x (y:ys) = if y == x then deleteEvery x ys else y : deleteEvery x ys
93 | 
94 | -- |Delete all the elements of the first list from the second list.
95 | deleteAll :: Eq a => [a] -> [a] -> [a]
96 | deleteAll xs []     = []
97 | deleteAll xs (y:ys) = if y `elem` xs then deleteAll xs ys else y : deleteAll xs ys
98 | 
99 | -- |Return a list of all (ordered) pairs of elements of a list.
100 | orderedPairs :: [a] -> [(a,a)]
101 | orderedPairs xs = [ (x,y) | x <- xs, y <- xs ]
102 | 
103 | -- |Return a list of all (unordered) pairs of elements from a list.
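--
-- >>> unorderedPairs [1,2,3]
-- [(1,2),(1,3),(2,3)]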
104 | unorderedPairs :: [a] -> [(a,a)]
105 | unorderedPairs []     = []
106 | unorderedPairs (x:xs) = [ (x,y) | y <- xs ] ++ unorderedPairs xs
107 | 
108 | -- |Returns a list of pairs. Each pair consists of an element from the list,
109 | -- and the rest of the list with the element removed. This is useful for
110 | -- deleting elements from a list where no 'Eq' instance is defined on elements
111 | -- (eg function types).
112 | --
113 | -- >>> points [1,2,3]
114 | -- [(1,[2,3]),(2,[1,3]),(3,[1,2])]
115 | points :: [a] -> [(a,[a])]
116 | points []     = []
117 | points (a:as) = (a,as) : [ (b,a:bs) | (b,bs) <- points as ]
118 | 
119 | -- |Return 'True' if all elements of the list are equal.
120 | allEqual :: Eq a => [a] -> Bool
121 | allEqual xs = and (zipWith (==) xs (drop 1 xs))
122 | 
123 | -- |Return the most common value in a list.
124 | mode :: Ord b => [b] -> b
125 | mode xs = fst $ L.maximumBy (O.comparing snd) $
126 |             map (\a -> (head a, length a)) $
127 |             L.group $ L.sort xs
128 | 
129 | -- |Return 'True' if the first set is a subset of the second, i.e. if every
130 | -- element of the first set is also an element of the second set.
131 | isSubSet :: Eq a => [a] -> [a] -> Bool
132 | xs `isSubSet` ys = all (`elem` ys) xs
133 | 
134 | -- |Given a list x :: [a], return a new list y :: [(Int,a)] which pairs every
135 | -- element of the list with its position.
136 | enumerate :: [a] -> [(Int,a)]
137 | enumerate = zip [0..]
138 | 
139 | -- |Count the number of elements in a list that satisfy a predicate.
140 | countIf :: (a -> Bool) -> [a] -> Int
141 | countIf p xs = length (filter p xs)
142 | 
143 | -- |Return the element of a list that minimises a function. In case of a tie,
144 | -- return the element closest to the front of the list.
145 | argMin :: Ord b => [a] -> (a -> b) -> a
146 | argMin xs f = L.minimumBy (O.comparing f) xs
147 | 
148 | -- |Return a list of all elements that minimise a given function.
149 | argMinList :: Ord b => [a] -> (a -> b) -> [a]
150 | argMinList xs f = map (xs!!) indices
151 |     where
152 |         ys      = map f xs
153 |         minVal  = minimum ys
154 |         indices = L.findIndices (== minVal) ys
155 | 
156 | -- |Return the element of a list that minimises a function. In case of a tie,
157 | -- choose randomly with the given generator.
158 | argMinRandom :: (Ord b, RandomGen g) => g -> [a] -> (a -> b) -> (a, g)
159 | argMinRandom g xs f = randomChoice g (argMinList xs f)
160 | 
161 | -- |Return the element of a list that minimises a function. In case of a tie,
162 | -- choose randomly.
163 | argMinRandomIO :: Ord b => [a] -> (a -> b) -> IO a
164 | argMinRandomIO xs f = getStdGen >>= \g -> return $ fst $ argMinRandom g xs f
165 | 
166 | -- |Return the element of a list that maximises a function.
167 | argMax :: (Ord b, Num b) => [a] -> (a -> b) -> a
168 | argMax xs f = argMin xs (negate . f)
169 | 
170 | -- |Return a list of all elements that maximise a given function.
171 | argMaxList :: (Ord b, Num b) => [a] -> (a -> b) -> [a]
172 | argMaxList xs f = argMinList xs (negate . f)
173 | 
174 | -- |Return the element of a list that maximises a function. In case of a tie,
175 | -- choose randomly with the given generator.
176 | argMaxRandom :: (Ord b, Num b, RandomGen g) => g -> [a] -> (a -> b) -> (a, g)
177 | argMaxRandom g xs f = argMinRandom g xs (negate . f)
178 | 
179 | -- |Return the element of a list that maximises a function. In case of a tie,
180 | -- choose randomly.
181 | argMaxRandomIO :: (Ord b, Num b) => [a] -> (a -> b) -> IO a
182 | argMaxRandomIO xs f = argMinRandomIO xs (negate . f)
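-- A couple of informal examples of the arg-min/arg-max helpers:
--
-- >>> argMax [1,-3,2] abs
-- -3
-- >>> argMinList ["cc","a","b"] length
-- ["a","b"]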
183 | 
184 | -- |Create a function from a list of (argument, value) pairs.
185 | listToFunction :: (Ord a) => [(a,b)] -> a -> b
186 | listToFunction xs = (M.fromList xs !)
187 | 
188 | -- |Transpose a list of lists. Rows are truncated to the length of the
189 | -- shortest list.
189 | transpose :: [[a]] -> [[a]]
190 | transpose xs = if or (map null xs)
191 |     then []
192 |     else let heads = map head xs
193 |              tails = map tail xs
194 |           in heads : transpose tails
195 | 
196 | -- |Unsafe look up of a variable in an association list.
197 | (%!) :: Eq a => [(a,b)] -> a -> b
198 | (%!) as a = case lookup a as of
199 |     Nothing -> error "Variable not found in list -- AI.Util.Util.%!"
200 |     Just b  -> b
201 | 
202 | -- |Return all lists of 'Bool' of length @n@. For example,
203 | --
204 | -- >>> bools 2
205 | -- [[True,True],[True,False],[False,True],[False,False]]
206 | --
207 | -- The returned list has length @2 ^ n@.
208 | bools :: Int -> [[Bool]]
209 | bools 0 = [[]]
210 | bools n = do
211 |     x  <- [True, False]
212 |     xs <- bools (n-1)
213 |     return (x:xs)
214 | 
215 | -- |Return all subsets of a list.
216 | subsets :: [a] -> [[a]]
217 | subsets = filterM $ const [True,False]
218 | 
219 | ------------------
220 | -- String Utils --
221 | ------------------
222 | 
223 | -- |Remove leading whitespace (spaces or tabs).
224 | lstrip :: String -> String
225 | lstrip = dropWhile (`elem` " \t")
226 | 
227 | -- |Remove trailing whitespace (spaces or tabs).
228 | rstrip :: String -> String
229 | rstrip = reverse . lstrip . reverse
230 | 
231 | -- |Remove both leading and trailing whitespace (spaces or tabs).
232 | strip :: String -> String
233 | strip = rstrip . lstrip
234 | 
235 | -- |Join a list of strings, separating them with commas.
236 | commaSep :: [String] -> String
237 | commaSep xs = concat $ L.intersperse "," xs
238 | 
239 | -------------------
240 | -- Map Functions --
241 | -------------------
242 | 
243 | -- |A universal map maps all keys to the same value.
244 | mkUniversalMap :: Ord k => [k] -> a -> Map k a
245 | mkUniversalMap ks a = M.fromList $ zip ks (repeat a)
246 | 
247 | -------------------------
248 | -- Monadic Combinators --
249 | -------------------------
250 | 
251 | -- |Monadic 'when' statement.
252 | whenM :: Monad m => m Bool -> m () -> m ()
253 | whenM test s = test >>= \p -> when p s
254 | 
255 | -- |Monadic ternary 'if' statement.
256 | ifM :: Monad m => m Bool -> m a -> m a -> m a
257 | ifM test a b = test >>= \p -> if p then a else b
258 | 
259 | -- |Run a REPL-style computation repeatedly, stopping once the prompt returns
260 | -- a result that satisfies the predicate.
260 | untilM :: Monad m => (t -> Bool) -> m t -> (t -> m ()) -> m ()
261 | untilM predicate prompt action = do
262 |     result <- prompt
263 |     if predicate result
264 |         then return ()
265 |         else action result >> untilM predicate prompt action
266 | 
267 | -- |Run a computation, ignoring the result (i.e. run it only for its side
268 | -- effects).
269 | ignoreResult :: Monad m => m a -> m ()
270 | ignoreResult c = c >> return ()
271 | 
272 | -- |Ensure that a monadic computation doesn't throw any errors.
273 | trapError :: MonadError e m => m () -> m ()
274 | trapError c = c `catchError` \_ -> return ()
275 | 
276 | --------------------------
277 | -- Random Numbers (New) --
278 | --------------------------
279 | 
280 | -- |Chooses a single element from a list at random, returning the element
281 | -- chosen and the rest of the list.
282 | selectOne :: Eq a => RandomGen g => [a] -> Rand g (a, [a])
283 | selectOne xs = do
284 |     let n = length xs
285 |     i <- getRandomR (0,n-1)
286 |     let x = xs !! i
287 |     return (x, L.delete x xs)
288 | 
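-- A usage sketch (informal): these actions can be run purely with 'evalRand'
-- from Control.Monad.Random and a fixed generator, e.g.
--
-- > evalRand (selectOne "abc") (mkStdGen 0)
--
-- which returns one character paired with the remaining two; which character
-- is chosen depends on the generator.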
289 | -- |Select a number of elements from a list at random, returning the elements
290 | -- chosen and the rest of the list.
291 | selectMany' :: Eq a => RandomGen g => Int -> [a] -> Rand g ([a], [a])
292 | selectMany' 0 xs = return ([], xs)
293 | selectMany' k xs = do
294 |     (y,  xs')  <- selectOne xs
295 |     (ys, xs'') <- selectMany' (k-1) xs'
296 |     return (y:ys, xs'')
297 | 
298 | -- |Select a number of elements from a list at random, returning the elements
299 | -- chosen.
300 | selectMany :: Eq a => RandomGen g => Int -> [a] -> Rand g [a]
301 | selectMany k = fmap fst . selectMany' k
302 | 
303 | -- |Choose a random element from a list.
304 | sampleOne :: RandomGen g => [a] -> Rand g a
305 | sampleOne [] = error "Empty list -- SAMPLEONE"
306 | sampleOne xs = do
307 |     n <- getRandomR (0, length xs - 1)
308 |     return (xs !! n)
309 | 
310 | -- |Choose @n@ elements with replacement from a list.
311 | sampleWithReplacement :: RandomGen g => Int -> [a] -> Rand g [a]
312 | sampleWithReplacement 0 xs = return []
313 | sampleWithReplacement n xs = do
314 |     y  <- sampleOne xs
315 |     ys <- sampleWithReplacement (n-1) xs
316 |     return (y:ys)
317 | 
318 | -- |Generate a random variable from the 'Enum' and 'Bounded' type class. The
319 | -- 'Int' input specifies how many values are in the enumeration.
320 | getRandomEnum :: (RandomGen g, Enum a, Bounded a) => Int -> Rand g a
321 | getRandomEnum i = getRandomR (0,i-1) >>= return . toEnum
322 | 
323 | --------------------------
324 | -- Random Numbers (Old) --
325 | --------------------------
326 | 
327 | -- |Choose a random element from a list, given a generator.
328 | randomChoice :: RandomGen g => g -> [a] -> (a, g)
329 | randomChoice g [] = error "Empty list -- RANDOMCHOICE"
330 | randomChoice g xs = (xs !! n, next)
331 |     where
332 |         (n, next) = randomR (0, length xs - 1) g
333 | 
334 | -- |Choose a random element from a list, in the IO monad.
335 | randomChoiceIO :: [a] -> IO a
336 | randomChoiceIO xs = getStdGen >>= \g -> return $ fst $ randomChoice g xs
337 | 
338 | -- |Given a random number generator, return 'True' with probability p.
339 | probability :: (RandomGen g, Random a, Ord a, Num a) => g -> a -> (Bool, g)
340 | probability g p = if p' < p then (True, g') else (False, g')
341 |     where
342 |         (p', g') = R.randomR (0,1) g
343 | 
344 | -- |Return @True@ with probability p.
345 | probabilityIO :: (R.Random a, Ord a, Num a) => a -> IO Bool
346 | probabilityIO p = randomIO >>= \q -> return $! (q < p)
347 | 
348 | --------------------
349 | -- IO Combinators --
350 | --------------------
351 | 
352 | -- |Read a line from stdin and return it.
353 | readPrompt :: IO String
354 | readPrompt = putStr "> " >> getLine
355 | 
356 | -- |Compute a pure value and return it along with the number of microseconds
357 | -- taken for the computation.
358 | timed :: (NFData a) => a -> IO (a, Int)
359 | timed x = do
360 |     t1 <- getCPUTime
361 |     r  <- return $!! x
362 |     t2 <- getCPUTime
363 |     let diff = fromIntegral (t2 - t1) `div` 1000000
364 |     return (r, diff)
365 | 
366 | -- |Given a time limit (in microseconds) and a list, compute as many elements
367 | -- of the list as possible within the time limit.
368 | timeLimited :: (NFData a) => Int -> [a] -> IO [a]
369 | timeLimited t xs = do
370 |     v <- newTVarIO []
371 |     timeout t (forceIntoTVar v xs)
372 |     readTVarIO v
373 | 
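-- A usage sketch for 'timeLimited' (informal; how far the list gets forced
-- depends on machine speed):
--
-- > timeLimited 500000 [ sum [1..n] | n <- [1..] ]
--
-- forces as many partial sums as it can within half a second, returning them
-- in reverse order (see 'forceIntoTVar' below).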
374 | -- |Compute the elements of a list one by one, consing them onto the front
375 | -- of a @TVar@ as they are computed. Note that the result list will be
376 | -- in reverse order.
377 | forceIntoTVar :: (NFData a) => TVar [a] -> [a] -> IO ()
378 | forceIntoTVar v xs = mapM_ (forceCons v) xs
379 | 
380 | -- |Force a pure value, and cons it onto the front of a list stored in a @TVar@.
381 | forceCons :: (NFData a) => TVar [a] -> a -> IO ()
382 | forceCons v x = x `deepseq` atomically $ modifyTVar2 v (x:)
383 | 
384 | -- |Modify the value of a transactional variable.
385 | modifyTVar2 :: TVar a -> (a -> a) -> STM ()
386 | modifyTVar2 v f = readTVar v >>= writeTVar v . f
-------------------------------------------------------------------------------- /src/AI/Util/WeightedGraph.hs: --------------------------------------------------------------------------------
1 | module AI.Util.WeightedGraph
2 |     ( WeightedGraph
3 |     , toGraph
4 |     , toUndirectedGraph
5 |     , getNodes
6 |     , getNeighbours
7 |     , getEdge
8 |     , addEdge
9 |     , addUndirectedEdge
10 |     , writeGraphs
11 |     , readGraphs ) where
12 | 
13 | import Control.Monad (forM_)
14 | import Data.Map (Map, (!))
15 | import qualified Data.Map as M
16 | import qualified Data.List as L
17 | import qualified Data.Text as T
18 | import System.IO
19 | 
20 | ---------------------
21 | -- Weighted Graphs --
22 | ---------------------
23 | 
24 | -- |A weighted graph connects vertices (nodes) by edges (actions). Each edge has
25 | -- a weight associated with it. To build a graph, call one of the
26 | -- functions 'toGraph' (for a directed graph) and 'toUndirectedGraph' (for
27 | -- an undirected graph).
28 | type WeightedGraph a b = Map a (Map a b)
29 | 
30 | -- |Create a directed graph from an adjacency list.
31 | toGraph :: Ord a => [(a, [(a,b)])] -> WeightedGraph a b
32 | toGraph xs = M.fromList (map f xs)
33 |     where
34 |         f (a,bs) = (a, M.fromList bs)
35 | 
36 | -- |Create an undirected graph from an adjacency list. The inverse links will
37 | -- be added automatically.
38 | toUndirectedGraph :: (Ord a, Eq b) => [(a,[(a,b)])] -> WeightedGraph a b
39 | toUndirectedGraph conn = fromPairRep . symmetrize . toPairRep $ toGraph conn
40 | 
41 | -- |Get a list of the nodes of the graph.
42 | getNodes :: WeightedGraph a b -> [a]
43 | getNodes = M.keys
44 | 
45 | -- |Get a list of the outbound links from node @a@.
46 | getNeighbours :: Ord a => a -> WeightedGraph a b -> [(a,b)]
47 | getNeighbours a g = case M.lookup a g of
48 |     Nothing -> error "Vertex not found in graph -- GETNEIGHBOURS"
49 |     Just ls -> M.toList ls
50 | 
51 | -- |Get the weight attached to the edge between @x@ and @y@.
52 | getEdge :: Ord a => a -> a -> WeightedGraph a b -> Maybe b
53 | getEdge x y g = case M.lookup x g of
54 |     Nothing -> error "Vertex not found in graph -- GETEDGE"
55 |     Just ys -> M.lookup y ys
56 | 
57 | -- |Add an edge between two vertices to a WeightedGraph.
58 | addEdge :: Ord a => a -> a -> b -> WeightedGraph a b -> WeightedGraph a b
59 | addEdge x y e graph = M.adjust (M.insert y e) x graph
60 | 
61 | -- |Add an undirected edge between two vertices to a WeightedGraph.
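--
-- For example (informal):
--
-- >>> let g = toGraph [("A",[]),("B",[])] :: WeightedGraph String Int
-- >>> getEdge "B" "A" (addUndirectedEdge "A" "B" 7 g)
-- Just 7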
62 | addUndirectedEdge :: Ord a => a -> a -> b -> WeightedGraph a b -> WeightedGraph a b 63 | addUndirectedEdge x y e graph = addEdge y x e (addEdge x y e graph) 64 | 65 | ---------------- 66 | -- Read/Write -- 67 | ---------------- 68 | 69 | writeGraphs :: Show g => FilePath -> [g] -> IO () 70 | writeGraphs filename gs = do 71 | h <- openFile filename WriteMode 72 | forM_ gs $ \g -> hPrint h g 73 | hClose h 74 | 75 | readGraphs :: (Ord a,Read a,Read b) => FilePath -> IO [WeightedGraph a b] 76 | readGraphs filename = do 77 | contents <- readFile filename 78 | return $ map read $ lines contents 79 | 80 | 81 | ---------------------- 82 | -- Helper functions -- 83 | ---------------------- 84 | 85 | -- |Convert a graph to its adjacency list representation. 86 | toAdjacencyList :: WeightedGraph a b -> [(a, [(a,b)])] 87 | toAdjacencyList xs = map g (M.toList xs) 88 | where 89 | g (a,bs) = (a, M.toList bs) 90 | 91 | -- |Convert a graph to its ordered pair representation. 92 | toPairRep :: WeightedGraph a b -> [(a,a,b)] 93 | toPairRep xs = [ (a,b,c) | (a,bs) <- toAdjacencyList xs, (b,c) <- bs ] 94 | 95 | -- |Convert a graph from its ordered pair representation. 96 | fromPairRep :: (Ord a) => [(a,a,b)] -> WeightedGraph a b 97 | fromPairRep xs = go xs M.empty 98 | where 99 | go [] m = m 100 | go ((a,b,c):xs) m = go xs (M.insert a newMap m) 101 | where 102 | newMap = M.insert b c $ case M.lookup a m of 103 | Nothing -> M.empty 104 | Just m' -> m' 105 | 106 | -- |Take a directed graph in ordered pair representation and add in all of the 107 | -- reverse links, so that the resulting graph is undirected. 108 | symmetrize :: (Eq a, Eq b) => [(a,a,b)] -> [(a,a,b)] 109 | symmetrize xs = L.nub $ concat [ [(a,b,c),(b,a,c)] | (a,b,c) <- xs ] 110 | --------------------------------------------------------------------------------