├── .travis.yml ├── .gitignore ├── project ├── plugins.sbt ├── .zip └── .ensime ├── report3.pdf ├── src ├── main │ └── scala │ │ ├── learning │ │ ├── Learning.scala │ │ ├── DeepVLearning.scala │ │ ├── ExpReplay.scala │ │ ├── TDLambda.scala │ │ ├── NoveltyVLearning.scala │ │ ├── DeepQLearning.scala │ │ ├── QLearning.scala │ │ ├── OfflineRL.scala │ │ ├── NoveltyAEVLearning.scala │ │ ├── NoveltyExploration.scala │ │ └── DeepExploration.scala │ │ ├── policy │ │ ├── UCT.scala │ │ ├── TreeSearchP.scala │ │ ├── MCST.scala │ │ ├── MCTS.scala │ │ ├── Policy.scala │ │ ├── VPolicy.scala │ │ ├── TreeSearch.scala │ │ ├── QPolicy.scala │ │ └── VNovelty.scala │ │ ├── Rand.scala │ │ ├── backend │ │ ├── SeparableCompGraph.scala │ │ ├── Backends.scala │ │ ├── Backend.scala │ │ └── BuildNN.scala │ │ ├── Charts.scala │ │ ├── mdp │ │ ├── Game2048.scala │ │ ├── MDP.scala │ │ ├── GameDeepDQN.scala │ │ ├── Game6561.scala │ │ └── Grid6561.scala │ │ ├── SelfPlay.scala │ │ └── Conf.scala └── example │ └── scala │ ├── TestDeep.scala │ ├── RLApp.scala │ └── TestNovelty.scala ├── LICENSE └── README.md /.travis.yml: -------------------------------------------------------------------------------- 1 | language: scala 2 | scala: 3 | - 2.11.7 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | maxconf* 2 | .ensime_cache/ 3 | target/ 4 | #* 5 | .#* -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.0.0") 2 | -------------------------------------------------------------------------------- /project/.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rubenfiszel/scala-drl/HEAD/project/.zip -------------------------------------------------------------------------------- /report3.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rubenfiszel/scala-drl/HEAD/report3.pdf -------------------------------------------------------------------------------- /src/main/scala/learning/Learning.scala: -------------------------------------------------------------------------------- 1 | package drl.learning 2 | 3 | import drl.mdp.MDP._ 4 | import drl.backend.Backend._ 5 | 6 | abstract class RL[S: Statable, B: NeuralN]() extends Iterator[PreBatch[S]] { 7 | 8 | } 9 | -------------------------------------------------------------------------------- /src/main/scala/policy/UCT.scala: -------------------------------------------------------------------------------- 1 | package drl.policy 2 | 3 | import drl.mdp.MDP._ 4 | 5 | case class UCT(branches: Int = 100) extends PolicyV { 6 | 7 | def nextAction[S: Valuable](s: S) = { 8 | val mct = new MCTree(s) 9 | for (i <- 1 to branches) 10 | mct.iter() 11 | mct.bestMove 12 | } 13 | 14 | } 15 | -------------------------------------------------------------------------------- /src/main/scala/Rand.scala: -------------------------------------------------------------------------------- 1 | package drl 2 | 3 | import util.Random 4 | 5 | object Rand { 6 | 7 | var r = new Random(1234) 8 | 9 | def setSeed(i: Int) = 10 | r.setSeed(i) 11 | 12 | def choose[A](s: Seq[A]):A = 13 | s(r.nextInt(s.length)) 14 | 15 | def nextInt(n: Int) = 16 | r.nextInt(n) 17 | 18 | def nextBool(f: Float) 
= 19 | r.nextFloat() < f 20 | 21 | def nextFloat = 22 | r.nextFloat 23 | 24 | 25 | } 26 | -------------------------------------------------------------------------------- /src/main/scala/policy/TreeSearchP.scala: -------------------------------------------------------------------------------- 1 | package drl.policy 2 | 3 | import drl.mdp.MDP._ 4 | 5 | case class TreeSearchP(depth: Int=1, expectimax: Boolean = false, heuristic: Boolean = true) extends PolicyV { 6 | 7 | def nextAction[S: Valuable](s: S) = { 8 | 9 | def eval(s: S): Value = 10 | if (heuristic) 11 | s.heuristic 12 | else 13 | s.value 14 | 15 | if (expectimax) 16 | TreeSearch.expectimax(s, depth, eval)._3 17 | else 18 | TreeSearch.maxmax(s, depth, eval)._3 19 | } 20 | 21 | } 22 | -------------------------------------------------------------------------------- /src/main/scala/policy/MCST.scala: -------------------------------------------------------------------------------- 1 | package drl.policy 2 | 3 | import drl.mdp.MDP._ 4 | 5 | case class MCST(branches: Int = 100) extends PolicyQ { 6 | 7 | def mtc[S: Statable](s: S, a: A) = { 8 | var r = 0f 9 | for (i <- (1 to branches)) { 10 | var (cs, rw) = s.applyTransition(a) 11 | r += rw 12 | while (cs.canContinue) { 13 | val na = RandomQ.nextAction(cs) 14 | var (ns, rw) = cs.applyTransition(na) 15 | cs = ns 16 | r += rw 17 | } 18 | } 19 | r 20 | } 21 | 22 | def nextAction[S: Statable](s: S) = { 23 | val r = s.availableActions.map(a => (a, mtc(s, a))) 24 | println(r) 25 | r.maxBy(_._2)._1 26 | } 27 | 28 | } 29 | -------------------------------------------------------------------------------- /src/main/scala/learning/DeepVLearning.scala: -------------------------------------------------------------------------------- 1 | package drl.learning 2 | 3 | import drl._ 4 | import drl.policy._ 5 | import drl.mdp.MDP._ 6 | import drl.backend.Backend._ 7 | 8 | class DeepVLearning[S: Valuable, B: NeuralN](vconf: VConf, nconf: Either[NConf, B], deconf: DeepExplorationConf, offrlconf: OfflineRLConf) extends DeepExploration[Valuable, S, B](nconf, deconf, offrlconf){ 9 | 10 | lazy val outputWidth = 1 11 | 12 | val combined = VPolicyCombined(model, offrlconf.disc) 13 | 14 | val pols = (0 until deconf.nbHead).map(x => EpsGreedyVPolicyHead(model, x, offrlconf.maxEpsilon, offrlconf.disc)) 15 | 16 | def targetHead(l: List[SARS[S]], k:Int):Array[Array[Float]] = { 17 | val b = (Rand.nextBool(0.5f)) || deconf.nbHead == 1 18 | TDLambda.tdErr(l, k, b, model, vconf.lambda, offrlconf.gamma, offrlconf.disc).map(Array(_)) 19 | } 20 | 21 | def getNext(nb:Int) = { 22 | val k = Rand.nextInt(deconf.nbHead) 23 | sample(pols(k)) 24 | } 25 | 26 | 27 | } 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Ruben Fiszel 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/main/scala/learning/ExpReplay.scala: -------------------------------------------------------------------------------- 1 | package drl.learning 2 | 3 | import drl._ 4 | import drl.mdp.MDP._ 5 | 6 | case class ExpReplayConf( 7 | batchSize: Int, 8 | minPoolFactor: Int, 9 | maxPoolFactor: Int 10 | ) 11 | class ExpReplay[S: Statable](expconf: ExpReplayConf) { 12 | 13 | var memory: List[SARS[S]] = List() 14 | var memSize = 0 15 | 16 | def reset() = { 17 | memSize = 0 18 | memory = List() 19 | } 20 | 21 | def add(lgame: List[SARS[S]]) = { 22 | memory :::= lgame 23 | memSize += lgame.length 24 | } 25 | 26 | def isEnough = 27 | memSize >= expconf.batchSize*expconf.minPoolFactor 28 | 29 | def isTooMuch = 30 | memSize > expconf.batchSize*expconf.maxPoolFactor 31 | 32 | def removeLast = { 33 | memSize -= 1 34 | memory = memory.dropRight(1) 35 | } 36 | 37 | def clean() = { 38 | while(isTooMuch) 39 | removeLast 40 | } 41 | 42 | def get(n: Int) = { 43 | 44 | var fetchSize = 0 45 | var fetched: List[SARS[S]] = List() 46 | 47 | while (fetchSize < n) { 48 | val ind = Rand.nextInt(memory.length) 49 | fetched ::= memory(ind) 50 | fetchSize += 1 51 | } 52 | fetched 53 | } 54 | 55 | } 56 | -------------------------------------------------------------------------------- /src/main/scala/learning/TDLambda.scala: -------------------------------------------------------------------------------- 1 | package drl.learning 2 | 3 | import drl.mdp.MDP._ 4 | import drl.backend.Backend._ 5 | 6 | object TDLambda { 7 | 8 | def tdErr[S: Valuable, B: NeuralN](l:List[SARS[S]], k:Int, b:Boolean, model:B, lambda:Float, gamma:Float, disc:Float):Array[Float] = { 9 | 10 | val last = l.last._4 11 | val ls: List[S] = 12 | l.map(_._1) :+ last 13 | 14 | val rewards = l.map(_._3) 15 | val preEv = 16 | model.output(ls).apply(k).map(_(0)) 17 | // if (!Conf.targetNPeriod.isDefined) 18 | // else 19 | // target.output(ls).apply(k).map(_(0)) 20 | 21 | val evals = preEv.toArray 22 | 23 | if (!last.canContinue) 24 | evals(evals.length-1) = 0f 25 | 26 | val dt = (0 until ls.length-1).map(i => rewards(i)/disc + gamma*evals(i+1) - evals(i)) 27 | 28 | var r = model.output(ls).apply(k).map(_(0)).init.toArray 29 | 30 | 31 | val br = r.toIndexedSeq 32 | 33 | if (b) { 34 | 35 | for (i <- (0 until r.length) ) { 36 | var ld = 1f 37 | var j = i 38 | while (j < r.length && ld > 0.005) { 39 | r(i) += ld*dt(j) 40 | ld *= lambda 41 | j += 1 42 | } 43 | } 44 | 45 | // println(evals.toList + " " + rewards + " " + dt + "\n \n") 46 | 47 | } 48 | 49 | r 50 | } 51 | 52 | 53 | } 54 | -------------------------------------------------------------------------------- /src/main/scala/learning/NoveltyVLearning.scala: -------------------------------------------------------------------------------- 1 | package drl.learning 2 | 3 | import drl._ 4 | import drl.policy._ 5 | import drl.mdp.MDP._ 6 | import drl.backend.Backend._ 7 | 8 | class NoveltyVLearning[S: Valuable, B: NeuralN](vconf: VConf, 
nconf: Either[NConf, B], novconf: NoveltyConf, offrlconf: OfflineRLConf, earlytermination: Option[EarlyTermination] = None) extends NoveltyExploration[Valuable, S, B](nconf, novconf, offrlconf, earlytermination){ 9 | 10 | 11 | lazy val outputWidth = 1 12 | 13 | val novelty:B = buildNoveltyModel(implicitly[Statable[S]].featureSize) 14 | 15 | val pol = new EpsGreedy(offrlconf.maxEpsilon, GreedyNoveltyVPolicy[B](model, novelty, offrlconf.disc, novconf.beta)) 16 | 17 | val combined = VPolicyHead(model, 0, offrlconf.disc) 18 | 19 | def fitNovelty(l:List[SARS[S]]) = { 20 | val targets = l.map(_._4.toInput).toArray 21 | val inputs = l.map(x => x._1.inputWithAction(x._2)) 22 | novelty.buildAndFit(inputs, Array(targets), (x:Array[Float]) => x) 23 | } 24 | 25 | def target(l: List[SARS[S]]):Array[Array[Array[Float]]] = { 26 | val r = Array(TDLambda.tdErr(l, 0, true, model, vconf.lambda, offrlconf.gamma, offrlconf.disc).map(Array(_))) 27 | fitNovelty(l) 28 | // TestNovelty.maxt(l) 29 | // TestNovelty.test(novelty, None) 30 | r 31 | } 32 | 33 | 34 | def getNext(nb:Int) = { 35 | sample(pol) 36 | } 37 | 38 | 39 | } 40 | -------------------------------------------------------------------------------- /src/main/scala/learning/DeepQLearning.scala: -------------------------------------------------------------------------------- 1 | package drl.learning 2 | 3 | import drl._ 4 | import drl.policy._ 5 | import drl.mdp.MDP._ 6 | import drl.backend.Backend._ 7 | 8 | 9 | class DeepQLearning[S: Statable, B: NeuralN](qconf: QConf, nconf:Either[NConf, B], deconf: DeepExplorationConf, offrlconf: OfflineRLConf, earlytermination: Option[EarlyTermination] = None) extends DeepExploration[Statable, S, B](nconf, deconf, offrlconf, earlytermination) { 10 | 11 | val combined = QPolicyCombined(model, deconf.nbHead) 12 | 13 | lazy val outputWidth = 14 | implicitly[Statable[S]].allActions.length 15 | 16 | val expRepQ = new ExpReplay[S](ExpReplayConf(offrlconf.batchSize, qconf.minPoolFactor, qconf.maxPoolFactor)) 17 | 18 | val pols = (0 until deconf.nbHead).map(x => EpsGreedyQPolicyHead(model, x, offrlconf.maxEpsilon)) 19 | 20 | def targetHead(l: List[SARS[S]], k:Int):Array[Array[Float]] = { 21 | l.map(x => { 22 | val b = (Rand.nextBool(0.5f)) || deconf.nbHead == 1 23 | val r = QLearning.ql(x, b, k, model, target, deconf.targetNPeriod, qconf.zeroImpossible, offrlconf.gamma, offrlconf.disc) 24 | r 25 | }).toArray 26 | } 27 | 28 | def getNext(nb:Int) = { 29 | var fetched = List[SARS[S]]() 30 | 31 | val k = Rand.nextInt(deconf.nbHead) 32 | val (episode, score) = sample(pols(k)) 33 | 34 | if (qconf.expRep) { 35 | expRepQ.add(episode) 36 | fetched = expRepQ.get(nb) 37 | expRepQ.clean() 38 | } 39 | else { 40 | fetched = episode 41 | } 42 | (fetched, score) 43 | } 44 | 45 | } 46 | -------------------------------------------------------------------------------- /src/example/scala/TestDeep.scala: -------------------------------------------------------------------------------- 1 | package drl 2 | 3 | import drl.policy._ 4 | import drl.mdp._ 5 | import drl.learning._ 6 | import drl.mdp.MDP._ 7 | import drl.backend.Backend._ 8 | 9 | 10 | object TestDeep { 11 | 12 | def chart100L[S: Statable, B: NeuralN](qconf: QConf, nconf:NConf, deconf: DeepExplorationConf, offrlconf: OfflineRLConf) = { 13 | Charts.createChart("median", "median", 1) 14 | Charts.show() 15 | var scores = List[(Int, Int)]() 16 | for (i <- (5 to 100)) { 17 | GameDeepConf.gameL = i 18 | val median = medianOver3Seeds(qconf, nconf, deconf, offrlconf) 19 | scores ::= ((i, median)) 
20 | Charts.addValue(i, median, "median") 21 | } 22 | println(scores) 23 | scores 24 | } 25 | 26 | def optimalIteration(l: List[(Int, Float)]) = { 27 | var max = 2000 28 | var current: Int = l.head._1 29 | var currentL = 0 30 | for (e <- l) { 31 | if (currentL == 100 && max == 2000) 32 | max = current 33 | if (e._2 >= 9.5f) 34 | currentL += 1 35 | else { 36 | current = e._1 37 | currentL = 0 38 | } 39 | } 40 | max 41 | } 42 | def medianOver3Seeds[S: Statable, B: NeuralN](qconf: QConf, nconf:NConf, deconf: DeepExplorationConf, offrlconf: OfflineRLConf) = { 43 | val l = (0 to 2) 44 | .map(x => offrlconf.copy(seed = Rand.nextInt(10000))) 45 | .map(noffc => SelfPlay.trainModelRLDeepQ(qconf, scala.Left(nconf), deconf, noffc, Some(RepeatedValue(9.5f, 101))).getScores) 46 | .map(optimalIteration) 47 | .sorted 48 | l(1) 49 | } 50 | 51 | } 52 | -------------------------------------------------------------------------------- /src/main/scala/policy/MCTS.scala: -------------------------------------------------------------------------------- 1 | package drl.policy 2 | 3 | import drl.mdp.MDP._ 4 | import drl.Rand 5 | 6 | class MCTree[S: Valuable](s: S, val action: Action = null, var total:Float = 0, var simul:Int = 0, var children: List[MCTree[S]]= List()) { 7 | 8 | def bestMove = 9 | children.maxBy(_.score).action 10 | 11 | def score = 12 | simul 13 | 14 | def iter() = { 15 | val lt = select 16 | val r = lt.last.runSimulation 17 | lt.foreach(_.add(r)) 18 | } 19 | 20 | def isEmpty = 21 | simul == 0 22 | 23 | def allStats = 24 | children.forall(!_.isEmpty) 25 | 26 | def expand() = { 27 | val aA = s.availableActions 28 | val gs = aA.map(a => (s.applyTransition(a), a)).toList 29 | children = gs.map(x => new MCTree[S](x._1._1, x._2)) 30 | } 31 | 32 | def selectScore(psimul:Int) = 33 | total.toFloat/simul + Math.sqrt(2)*1000*Math.sqrt(Math.log(psimul)/simul) 34 | 35 | def choose(ts: List[MCTree[S]]) = 36 | ts.maxBy(_.selectScore(simul)) 37 | 38 | def add(score: Float) = { 39 | total += score 40 | simul += 1 41 | } 42 | 43 | def select: List[MCTree[S]] = 44 | if (isEmpty|| !s.canContinue) 45 | List(this) 46 | else { 47 | if (children.isEmpty) 48 | expand() 49 | if (allStats) 50 | this :: choose(children).select 51 | else 52 | this :: children(Rand.nextInt(children.length)).select 53 | } 54 | 55 | def runSimulation = { 56 | var cs = s 57 | while (cs.canContinue) { 58 | cs = RandomQ.applyNext(cs)._1 59 | } 60 | cs.value 61 | } 62 | 63 | } 64 | -------------------------------------------------------------------------------- /src/main/scala/policy/Policy.scala: -------------------------------------------------------------------------------- 1 | package drl.policy 2 | 3 | import drl.mdp._ 4 | import drl.Rand 5 | import drl.mdp.MDP._ 6 | import drl.backend.Backend._ 7 | 8 | 9 | 10 | trait Policy[G[_] <: Statable[_]]{ 11 | 12 | def nextAction[S: G](s: S): A 13 | 14 | def applyNext[S: G: Statable](s: S): (S, R) = 15 | s.applyTransition(nextAction[S](s)) 16 | } 17 | 18 | trait PolicyQ extends Policy[Statable] 19 | trait PolicyV extends Policy[Valuable] 20 | 21 | 22 | object RandomQ extends PolicyQ { 23 | 24 | def nextAction[S: Statable](s: S) = 25 | s.randomAction 26 | 27 | } 28 | 29 | object RandomV extends PolicyV { 30 | 31 | def nextAction[S: Valuable](s: S) = 32 | s.randomAction 33 | 34 | } 35 | 36 | 37 | 38 | case class MixedPolicy[A[_] <: Statable[_]](odd: Float, p1:Policy[A], p2:Policy[A]) extends Policy[A] { 39 | 40 | def nextAction[S: A](s: S) = { 41 | if (Rand.nextBool(odd)) 42 | p1.nextAction(s) 43 | else 
44 | p2.nextAction(s) 45 | } 46 | } 47 | 48 | 49 | class EpsGreedy[S: Statable, A[S] <: Statable[S]](max:Float, pol:Policy[A]) extends Policy[A] { 50 | 51 | var nb = 0 52 | 53 | def nextAction[S: A](s: S) = { 54 | nb += 1 55 | val tresh = 1- (nb.toFloat/max).min(0.99f) 56 | val rand = Rand.nextBool(tresh) 57 | if (rand) { 58 | val a = s.randomAction 59 | // println(a + " " + tresh + " " + nb) 60 | a 61 | } 62 | else 63 | pol.nextAction(s) 64 | } 65 | } 66 | 67 | 68 | case class OppPolicy[A[_] <: Statable[_]](p1:Policy[A], p2:Policy[A]) extends Policy[A] { 69 | 70 | var i = -1 71 | 72 | def nextAction[S: A](s: S) = { 73 | i += 1 74 | if (i%2 == 0) 75 | p1.nextAction(s) 76 | else 77 | p2.nextAction(s) 78 | } 79 | 80 | 81 | } 82 | -------------------------------------------------------------------------------- /src/main/scala/backend/SeparableCompGraph.scala: -------------------------------------------------------------------------------- 1 | package drl.backend 2 | 3 | import org.deeplearning4j.nn.graph.ComputationGraph 4 | import org.nd4j.linalg.api.ndarray.INDArray 5 | import org.nd4j.linalg.dataset._ 6 | import java.io._ 7 | import org.deeplearning4j.util._ 8 | 9 | trait SeparableCompGraph { 10 | 11 | def output(ind: INDArray): Array[INDArray] 12 | 13 | def fit(targets: MultiDataSet): Unit 14 | 15 | def cloneM(): SeparableCompGraph 16 | 17 | def save(filename: String): Boolean 18 | 19 | def outputHead(ind: INDArray, head:Int) = 20 | output(ind).apply(head) 21 | 22 | } 23 | 24 | class SingleCompGraph(cg: ComputationGraph) extends SeparableCompGraph { 25 | 26 | def output(ind: INDArray) = 27 | cg.output(ind) 28 | 29 | def fit(targets: MultiDataSet): Unit = { 30 | cg.fit(targets) 31 | } 32 | 33 | def cloneM() = 34 | new SingleCompGraph(cg.clone()) 35 | 36 | def save(filename: String) = { 37 | ModelSerializer.writeModel(cg, new File(filename), true) 38 | true 39 | } 40 | 41 | 42 | } 43 | 44 | class SeparatedCompGraph(cgs: IndexedSeq[ComputationGraph]) extends SeparableCompGraph { 45 | 46 | override def outputHead(ind: INDArray, k:Int) = 47 | cgs(k).output(ind)(0) 48 | 49 | def output(ind: INDArray) = 50 | cgs.map(_.output(ind)(0)).toArray 51 | 52 | def fit(targets: MultiDataSet): Unit = { 53 | cgs.zipWithIndex.foreach(x => { 54 | val (cg, ind) = x 55 | // println(targets.getFeatures(0)) 56 | // println(targets.getLabels(0)) 57 | cg.fit(new DataSet(targets.getFeatures(0), targets.getLabels(ind))) 58 | }) 59 | } 60 | 61 | def cloneM() = 62 | new SeparatedCompGraph(cgs.map(_.clone())) 63 | 64 | def save(filename: String) = { 65 | false 66 | } 67 | 68 | 69 | } 70 | -------------------------------------------------------------------------------- /src/example/scala/RLApp.scala: -------------------------------------------------------------------------------- 1 | package drl 2 | 3 | import drl.policy._ 4 | import drl.mdp._ 5 | import drl.learning._ 6 | import drl.backend._ 7 | import drl.mdp.MDP._ 8 | import drl.backend.Backend._ 9 | import drl.backend.Backends._ 10 | 11 | import drl.mdp.Game2048._ 12 | 13 | object RLApp extends App { 14 | 15 | val seed = 1234567 16 | val numEx = 40000 17 | 18 | type Game = Game6561 19 | val fc = Game2048Conf.fullconf(numEx, seed)//GameDeepConf.fullconf(numEx, seed) 20 | 21 | // type Game = GameDeepDQN 22 | // val fc = GameDeepAEConf.fullconf(numEx, seed)//GameDeepConf.fullconf(numEx, seed) 23 | 24 | 25 | Rand.setSeed(seed) 26 | 27 | 28 | 29 | 30 | 31 | SelfPlay.trainModelRLDeepQ[Game, SeparableCompGraph](fc.qconf, scala.Left(fc.nconf), fc.deconf, fc.offrlconf) 32 | 
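// The call above wires the whole pipeline: DeepQLearning is an Iterator[PreBatch[Game]]
// (nbHead epsilon-greedy heads from DeepExplorationConf, experience replay when QConf.expRep
// is set) and trainModelRLDeepQ feeds it to the backend with model.fit(qlearning). The
// commented-out calls below swap in the TD(lambda) value learner or the novelty-driven
// learners against the same configuration bundle `fc`.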
//SelfPlay.trainModelRLDeepV[Game, SeparableCompGraph](fc.vconf, scala.Left(fc.nconf), fc.deconf, fc.offrlconf) 33 | // SelfPlay.trainNoveltyAEV[Game, SeparableCompGraph](fc.vconf, scala.Left(fc.nconf), fc.novconf, fc.offrlconf) 34 | // SelfPlay.trainModelRLDeepQ[Game, SeparableCompGraph](fc.qconf, scala.Left(fc.nconf), fc.deconf, fc.offrlconf) 35 | // SelfPlay.trainNoveltyV[Game, SeparableCompGraph](fc.vconf, scala.Left(fc.nconf), fc.novconf, fc.offrlconf) 36 | 37 | } 38 | 39 | /* 40 | 41 | val evalP = TreeSearchP() 42 | val valP = TreeSearchP(1, false, false) 43 | val randomP = RandomQ 44 | val uct = UCT(1000) 45 | 46 | val pl = List( 47 | // MixedPolicy(1.0f, evalP, randomP), 48 | // MixedPolicy(0.95f, evalP, randomP), 49 | // MixedPolicy(0.99f, evalP, randomP) 50 | // MixedPolicy(0.0f, evalP, randomP), 51 | TreeSearchP(5, false, true) 52 | // MixedPolicy(1.0f, valP, randomP) 53 | // OppPolicy(valP, evalP), 54 | // OppPolicy(evalP, randomP) 55 | // uct 56 | ) 57 | 58 | for (p <- pl) { 59 | () 60 | 61 | } 62 | */ 63 | -------------------------------------------------------------------------------- /src/main/scala/Charts.scala: -------------------------------------------------------------------------------- 1 | package drl 2 | 3 | import drl.mdp._ 4 | 5 | import drl.mdp.MDP._ 6 | import drl.backend.Backend._ 7 | 8 | import org.jfree.chart.renderer.xy._ 9 | import scalax.chart.api._ 10 | import scalax.chart._ 11 | 12 | import java.io._ 13 | 14 | object Charts { 15 | var series: Map[String, XYSeries] = Map() 16 | var seriesToFile: Map[String, PrintWriter] = Map() 17 | var seriesToChart: Map[String, String] = Map() 18 | var avgs: Map[String, List[Float]] = Map() 19 | var nbValues: Map[String, Int] = Map() 20 | var charts: Map[String, XYChart] = Map() 21 | 22 | def createChart(serie: String, chart: String, nbValue: Int) = { 23 | val f = new File(serie) 24 | if (f.exists()) 25 | f.delete() 26 | f.createNewFile() 27 | seriesToFile += ((serie, new PrintWriter(new FileOutputStream(serie),true))) 28 | series += ((serie, new XYSeries(serie))) 29 | seriesToChart += ((serie, chart)) 30 | nbValues += ((serie, nbValue)) 31 | avgs += ((serie, List())) 32 | } 33 | 34 | def addValue(x:Float, y: Float, serie: String) = { 35 | 36 | val nbAvg = nbValues(serie) 37 | val ser = series(serie) 38 | avgs += ((serie, (y::avgs(serie)).take(nbAvg))) 39 | var ma = avgs(serie).sum/nbAvg 40 | swing.Swing onEDT { 41 | ser.add(x, ma) 42 | seriesToFile(serie).write(x.toString + " " + ma.toString() + "\n") 43 | seriesToFile(serie).flush() 44 | } 45 | } 46 | def show() = { 47 | val grp = seriesToChart.toList.groupBy(_._2) 48 | grp.foreach(x => 49 | if (!charts.contains(x._1)) { 50 | val chart = XYLineChart(x._2.map(y => series(y._1))) 51 | chart.plot.setRenderer(new XYLineAndShapeRenderer(false, true)) 52 | charts += ((x._1, chart)) 53 | chart.show() 54 | } 55 | ) 56 | } 57 | 58 | def save() = { 59 | charts.foreach(x => 60 | x._2.saveAsPDF(x._1) 61 | ) 62 | } 63 | 64 | 65 | } 66 | -------------------------------------------------------------------------------- /src/main/scala/policy/VPolicy.scala: -------------------------------------------------------------------------------- 1 | package drl.policy 2 | 3 | import drl.mdp.MDP._ 4 | import drl.backend.Backend._ 5 | 6 | case class EpsGreedyVPolicyHead[B: NeuralN, S: Valuable](model: B, k: Int, maxEpsilon: Int, disc: Float) extends EpsGreedy[S, Valuable](maxEpsilon, VPolicyHead(model, k, disc)) 7 | 8 | case class VPolicyCombined[B: NeuralN](model: B, disc:Float) extends PolicyV 
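// Ensemble policy used for evaluation: for each action it averages the predictions of all
// heads over the potential successor states, weights each successor by its probability (Odd)
// and adds the immediate reward scaled by 1/disc, then picks the action with the highest total.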
{ 9 | 10 | 11 | def potentials[S: Valuable](l: Seq[((A, (S, Reward, Odd)), Array[Float])]) = { 12 | def potential(e: ((A, (S, Reward, Odd)), Array[Float])) = { 13 | e._1._2._3*(e._2.sum/e._2.length + e._1._2._2/disc) 14 | } 15 | l.map(potential).sum 16 | } 17 | 18 | def nextAction[S: Valuable](s: S) = { 19 | val aA = s.availableActions 20 | val ps = aA.map(a => s.potentialStates(a).map(x => (a, x))).flatten 21 | val states = ps.map(_._2._1) 22 | val evals = model.output(states.toList).map(_.map(_(0)).toArray) 23 | val shifted = (0 until evals(0).length).map(y => (0 until evals.length).map(z => evals(z)(y)).toArray).toArray 24 | ps.zip(shifted).groupBy(_._1._1).mapValues(potentials[S]).maxBy(_._2)._1 25 | } 26 | 27 | } 28 | 29 | 30 | case class VPolicyHead[B: NeuralN](model: B, k:Int, disc:Float) extends PolicyV { 31 | 32 | 33 | def potentials[S: Valuable](l: Seq[((A, (S, Reward, Odd)), Float)]) = { 34 | def potential(e: ((A, (S, Reward, Odd)), Float)) = { 35 | e._1._2._3*(e._2 + e._1._2._2/disc) 36 | } 37 | l.map(potential).sum 38 | } 39 | 40 | 41 | def nextAction[S: Valuable](s: S) = { 42 | 43 | val aA = s.availableActions 44 | val ps = aA.map(a => s.potentialStates(a).map(x => (a, x))).flatten 45 | val states = ps.map(_._2._1) 46 | val evals = model.outputHead(states.toList, k).map(_(0)).toArray 47 | // println(s) 48 | // println(evals.toList) 49 | val r = ps.zip(evals).groupBy(_._1._1).mapValues(potentials[S]).maxBy(_._2) 50 | r._1 51 | } 52 | 53 | } 54 | -------------------------------------------------------------------------------- /src/main/scala/learning/QLearning.scala: -------------------------------------------------------------------------------- 1 | package drl.learning 2 | 3 | import drl.mdp.MDP._ 4 | import drl.backend.Backend._ 5 | import drl.policy._ 6 | 7 | 8 | object QLearning { 9 | 10 | 11 | def ql[S: Statable, B: NeuralN](entry:SARS[S], b:Boolean, head: Int, model: B, target: B, targetNPeriod: Option[Int], zeroImpossible:Boolean, gamma:Float, disc:Float):Array[Float] = { 12 | 13 | def getScore(s:S, head:Int = 0, cg:B = model):Array[Float] = { 14 | cg.outputS(s).apply(head) 15 | } 16 | 17 | 18 | def getHeadScore(s:S) = 19 | getScore(s, head, model) 20 | 21 | val (s1, a, r, s2) = entry 22 | 23 | val ow = implicitly[Statable[S]].outputWidth 24 | 25 | var ar = Array.fill(ow)(0f) 26 | if (zeroImpossible) { 27 | val qs = DiscRewardQ.filterPossible(s1, getHeadScore(s1)) 28 | qs.foreach(x => 29 | ar(x._2) = x._1 30 | ) 31 | } 32 | else 33 | ar = getHeadScore(s1) 34 | 35 | if (b) { 36 | 37 | if (!s1.canContinue) { 38 | ar = Array.fill(ow)(0f) 39 | } 40 | 41 | else { 42 | val nqs:Float = 43 | if (s2.canContinue) { 44 | if (targetNPeriod.isDefined) { 45 | val maxA = DiscRewardQ.filterPossible(s2, getHeadScore(s2)).maxBy(_._1)._2 46 | DiscRewardQ.filterPossible(s2, getScore(s2, head, target)).find(_._2 == maxA).get._1 47 | } 48 | else { 49 | DiscRewardQ.filterPossible(s2, getHeadScore(s2)).maxBy(_._1)._1 50 | } 51 | } 52 | else { 53 | 0f 54 | } 55 | 56 | val ind = s1.F.actionToIndex(a) 57 | ar(ind) = (r/disc + gamma*nqs) 58 | if (!s2.canContinue) 59 | println("NQS: " + nqs) 60 | } 61 | } 62 | 63 | if (false && !s2.canContinue) { 64 | val hs = getHeadScore(s1) 65 | println(entry + " " + head) 66 | println(ar.toList) 67 | println(hs.toList) 68 | 69 | } 70 | ar 71 | } 72 | 73 | 74 | } 75 | -------------------------------------------------------------------------------- /src/example/scala/TestNovelty.scala: 
-------------------------------------------------------------------------------- 1 | package drl 2 | 3 | import drl.policy._ 4 | import drl.mdp._ 5 | import drl.learning._ 6 | import drl.mdp.MDP._ 7 | import drl.backend.Backend._ 8 | import java.text.DecimalFormat 9 | 10 | 11 | object TestNovelty { 12 | 13 | val formatter = new DecimalFormat("#.##") 14 | 15 | val L = 20 16 | 17 | def g(t:Int) = 18 | GameDeepDQN(t.max(0).min(L), t.max(0), L) 19 | 20 | val games = (-1 to L+1).map(g).toList 21 | /* 22 | def sas(t:Int) = 23 | (g(t), RightM, g(t+1)) 24 | 25 | def lsas(t:Int) = 26 | (g(t), RightM, g(t+1).copy(t=(t-2).max(0))) 27 | 28 | 29 | val games = (1 to L+5).map(turn => 30 | (0 until turn).map(t => 31 | sas(t) 32 | ).toList:::List(lsas(turn))) 33 | */ 34 | 35 | def test[B: NeuralN](novelty: B, autoEncode: Option[B]) { 36 | /* 37 | games.foreach(l => { 38 | val targets = NoveltyPolicy.autoEncode(l.map(_._3), autoEncode) 39 | val inputs = NoveltyPolicy.input(l.map(x => (x._1, x._2)), autoEncode) 40 | val dis2 = NoveltyPolicy.noveltyDis(l.map(x => (x._1, x._2, x._3)), novelty, 1f, Some(autoEncode)).toList 41 | println("DIS: " + dis2) 42 | } 43 | // novelty.outputHead(games) 44 | ) 45 | */ 46 | 47 | //val aeg = NoveltyPolicy.autoEncode(games, autoEncode) 48 | val right = (0 until games.length-2).map(i => (games(1+i), RightM, games(i+2))).toList 49 | val left = (0 until games.length-2).map(i => (games(i+1), LeftM, games(i))).toList 50 | val nr = NoveltyPolicy.noveltyDis(right, novelty, 1f, autoEncode).map(formatter.format(_)) 51 | val nl = NoveltyPolicy.noveltyDis(left, novelty, 1f, autoEncode).map(formatter.format(_)) 52 | val r = nl.zip(nr).zip(counts) 53 | println(r.mkString(" | ")) 54 | 55 | } 56 | 57 | val counts = Array.fill(L+1)(0) 58 | 59 | Charts.createChart("maxt", "maxt", 1) 60 | var i = 0 61 | def maxt[S: Statable](l:List[SARS[S]]) = { 62 | 63 | i += 1 64 | val gds = l.map(_._4.asInstanceOf[GameDeepDQN].t) 65 | val m = gds.max 66 | 67 | gds.foreach(x => 68 | counts(x) += 1 69 | ) 70 | 71 | println("MMMMM: " + m) 72 | Charts.addValue(i, m, "maxt") 73 | } 74 | Charts.show() 75 | 76 | } 77 | -------------------------------------------------------------------------------- /src/main/scala/policy/TreeSearch.scala: -------------------------------------------------------------------------------- 1 | package drl.policy 2 | 3 | sealed trait Opponent 4 | case object PerfectCollab extends Opponent 5 | case object RandomO extends Opponent 6 | 7 | import drl.mdp.MDP._ 8 | 9 | object TreeSearch { 10 | 11 | type Search[S] = (S, Reward, A) 12 | 13 | def searchScore[S: Statable](s: S, odepth: Int, eval: S => Value, opponent: Opponent, opponentTurn: Boolean = false): Search[S] = { 14 | 15 | def rec(s: S, depth: Int): Search[S] = { 16 | 17 | def isOpponentTurn = 18 | odepth-depth%2 == 0 19 | 20 | def getMaxScore = { 21 | 22 | var maxMove: A = null 23 | var max = Float.MinValue 24 | var leaf: S = s.F.zero 25 | 26 | for (m <- s.availableActions) { 27 | val ng = s.applyTransition(m)._1 28 | val search = rec(ng, depth-1) 29 | 30 | val game = search._1 31 | val scr = search._2 32 | 33 | if (scr > max) { 34 | maxMove = m 35 | max = scr 36 | leaf = game 37 | } 38 | } 39 | (leaf, max, maxMove) 40 | } 41 | 42 | def getExpectiMax = { 43 | 44 | var sum = 0f 45 | var leaf: S = s.F.zero 46 | var maxMove: A = null 47 | 48 | var nb = 0f 49 | 50 | for (m <- s.availableActions) { 51 | val ng = s.applyTransition(m)._1 52 | val search = rec(ng, depth-1) 53 | 54 | maxMove = m 55 | sum += search._2 56 | nb += 1 57 | leaf = 
search._1 58 | 59 | } 60 | 61 | (leaf, sum/nb, maxMove) 62 | 63 | } 64 | 65 | if (depth == 0 || !s.canContinue) 66 | (s, eval(s), null) 67 | else 68 | opponent match { 69 | case RandomO if opponentTurn => { 70 | getExpectiMax 71 | } 72 | case _ => 73 | getMaxScore 74 | } 75 | } 76 | 77 | rec(s, odepth) 78 | 79 | } 80 | 81 | 82 | def maxmax[S: Statable](s: S, depth: Int, eval: S => Value): Search[S] = 83 | searchScore(s, depth, eval, PerfectCollab) 84 | 85 | def expectimax[S: Statable](s: S, depth: Int, eval: S => Value): Search[S] = 86 | searchScore(s, depth, eval, RandomO) 87 | 88 | } 89 | -------------------------------------------------------------------------------- /src/main/scala/policy/QPolicy.scala: -------------------------------------------------------------------------------- 1 | package drl.policy 2 | 3 | import drl.Rand 4 | import drl.mdp.MDP._ 5 | import drl.backend.Backend._ 6 | 7 | case class EpsGreedyQPolicyHead[B: NeuralN, S: Statable](model: B, k: Int, maxEpsilon: Int) extends EpsGreedy[S, Statable](maxEpsilon, QPolicyHead(model, k)) 8 | 9 | case class QPolicyHead[B: NeuralN](model: B, k:Int) extends PolicyQ { 10 | 11 | def max[S: Statable](l:Array[Float], s:S) = { 12 | val (r, n) = DiscRewardQ.findMaxInd(DiscRewardQ.filterPossible(s, l)) 13 | // println(r + " " + n)// + " " + DiscRewardQ.filterPossible(g, l.toIndexedSeq).mkString(", ")) 14 | n 15 | } 16 | 17 | 18 | def nextAction[S: Statable](s: S) = { 19 | val o = model.outputHeadS(s, k) 20 | val r = s.F.allActions( 21 | max(o,s) 22 | ) 23 | // println(r + " " + s + " " + s.toInput.toList + " " + o.toList + " " + k) 24 | r 25 | } 26 | } 27 | 28 | case class QPolicyCombined[B: NeuralN](model: B, nbHead: Int) extends PolicyQ { 29 | 30 | 31 | def max[S: Statable](l:Array[Array[Float]], s: S) = { 32 | val maxWithInd = (0 until nbHead).map(y => DiscRewardQ.findMaxInd( 33 | DiscRewardQ.filterPossible(s, (0 until l.head.length).map(x => { 34 | l(y)(x) 35 | }).toArray) 36 | )).groupBy(_._2).map(x => (x._1, x._2.map(_._1))) 37 | 38 | var maxInd, nb = -1 39 | var maxSum = Float.MinValue 40 | maxWithInd.foreach(x => 41 | if (x._2.length > nb || (x._2.length == nb && x._2.sum > maxSum)) { 42 | maxInd = x._1 43 | maxSum = x._2.sum 44 | nb = x._2.length 45 | } 46 | ) 47 | /* 48 | println(" maxInd: " + maxInd ) 49 | println("MAXXX:" + maxWithInd) 50 | println(l.map(_.toList.map(_.toString.take(5)).mkString(" ")).toList.mkString("\n")) 51 | println() 52 | */ 53 | maxInd 54 | } 55 | 56 | def nextAction[S: Statable](s: S) = { 57 | s.F.allActions( 58 | max( 59 | model.outputS(s), 60 | s) 61 | ) 62 | } 63 | 64 | } 65 | 66 | object DiscRewardQ { 67 | 68 | def findMaxInd(l:Array[(Float, Int)]) = 69 | l.maxBy(_._1) 70 | 71 | 72 | def filterPossible[S: Statable](s: S, mvs: Array[Float]) = { 73 | val indexs = s.availableActions.map(x => s.F.actionToIndex(x)) 74 | mvs.zipWithIndex.filter(x => indexs.contains(x._2)) 75 | } 76 | 77 | } 78 | -------------------------------------------------------------------------------- /src/main/scala/learning/OfflineRL.scala: -------------------------------------------------------------------------------- 1 | package drl.learning 2 | 3 | import drl._ 4 | import drl.policy._ 5 | import scala.{Right => ERight, Left => ELeft} 6 | import drl.mdp.MDP._ 7 | import drl.backend.Backend._ 8 | 9 | case class OfflineRLConf( 10 | numEx: Int, 11 | seed: Int, 12 | maxEpsilon: Int = 200, 13 | disc: Float = 1f, 14 | batchSize: Int = 32, 15 | gamma: Float = 0.99f 16 | // val maskRand = 0.5f//0.5f 17 | ) 18 | 19 | case class 
VConf( 20 | lambda: Float = 0.8f 21 | ) 22 | 23 | case class QConf( 24 | expRep: Boolean = true, 25 | zeroImpossible: Boolean = false, 26 | minPoolFactor: Int = 30, 27 | maxPoolFactor: Int = 35 28 | ) 29 | 30 | abstract class OfflineRL[S: Statable, B: NeuralN](nconf:Either[NConf, B], offRLconf: OfflineRLConf) extends RL[S, B]() { 31 | 32 | type Score = Float 33 | 34 | var nbFetch = 0 35 | 36 | val model = nconf match { 37 | case ERight(mdl) => mdl 38 | case ELeft(conf) => buildModel(conf) 39 | } 40 | 41 | def getModel() = model 42 | def buildModel(conf: NConf): B 43 | def getNext(nb:Int): (List[SARS[S]], Score) 44 | def target(l: List[SARS[S]]):Array[Array[Array[Float]]] 45 | 46 | def sample[Q[_] <: Statable[_], S: Q: Statable](pol:Policy[Q]) = { 47 | val F = implicitly[Statable[S]] 48 | val episode = OfflineRL.episode(F.zero)(pol) 49 | val lv = episode.map(_._3).sum 50 | println("Ep Score: "+ lv) 51 | (episode, lv) 52 | } 53 | 54 | 55 | def output[S: Statable](states: IndexedSeq[S], scores:IndexedSeq[Array[Array[Float]]]) = { 56 | (states, scores.map(_.toIndexedSeq)) 57 | } 58 | 59 | def hasNext() = 60 | offRLconf.numEx > nbFetch 61 | 62 | def next(n: Int): PreBatch[S] 63 | 64 | def next() = { 65 | 66 | nbFetch += 1 67 | 68 | next(offRLconf.batchSize) 69 | 70 | } 71 | 72 | } 73 | 74 | 75 | 76 | object OfflineRL { 77 | 78 | def episode[Q[_] <: Statable[_], S: Q: Statable](s: S)(pol:Policy[Q]) = { 79 | var exps: List[SARS[S]] = List() 80 | var cs = s.F.zero 81 | while (cs.canContinue) { 82 | val action = pol.nextAction(cs) 83 | val (ns, reward) = cs.applyTransition(action) 84 | exps ::= ((cs, action, reward, ns)) 85 | cs = ns 86 | } 87 | // games ::= (cg, NoMove, 0f, cg) //to train to 0f the end 88 | exps.reverse 89 | 90 | } 91 | 92 | } 93 | -------------------------------------------------------------------------------- /src/main/scala/mdp/Game2048.scala: -------------------------------------------------------------------------------- 1 | package drl.mdp 2 | 3 | import drl._ 4 | import drl.mdp.MDP._ 5 | 6 | case class Game2048(grid: GridIS6561, turn: Int) { 7 | 8 | lazy val toInput = { 9 | grid.toInput2048 10 | } 11 | 12 | def canTurn(dir: Direction) = { 13 | fullMove(Turn(dir))._1.grid != grid 14 | } 15 | 16 | def move(mv: Move6561) = { 17 | 18 | val ng = mv match { 19 | case Turn(dir) => 20 | val (ng, r) = grid.move(dir) 21 | (ng, r) 22 | } 23 | 24 | (Game2048(ng._1, turn+1), ng._2) 25 | 26 | } 27 | 28 | def randomPlace() = { 29 | val es = grid.emptySpots 30 | if (!es.isEmpty) { 31 | val spot = Rand.choose(grid.emptySpots) 32 | copy(grid = grid.place(spot._1, spot._2, new Piece(1, Red, 2)).get) 33 | } else 34 | this 35 | } 36 | 37 | def fullMove(m: Move6561) = { 38 | val (ng, r) = move(m) 39 | (ng.randomPlace(), r) 40 | } 41 | 42 | def eval = 43 | grid.eval 44 | 45 | def value = 46 | grid.value 47 | 48 | lazy val availableMoveNext = { 49 | if (!grid.emptySpots.isEmpty) 50 | Game6561.ALL_TURNS filter (x => canTurn(x.dir)) toList 51 | else 52 | List() 53 | } 54 | // List(availableMoveGen(turn+1).map(x => (x, move(x).get)).maxBy(_._2.grid.eval)._1) 55 | 56 | } 57 | 58 | object Game2048 { 59 | 60 | implicit object Game6561V extends Randomizable[Game6561] { 61 | 62 | type CAction = Move6561 63 | 64 | val allActions = Game6561.moves 65 | 66 | val zero = Game6561(Grid6561(3), 0, 0) 67 | 68 | def realizeTransition(g: Game6561, m: CAction) = { 69 | val ng = g.move(m).get 70 | (ng, ng.value - g.value) 71 | } 72 | 73 | def potentialStates(g: Game6561, a: A): IndexedSeq[(Game6561, Reward, Odd)] = { 
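// A single realized successor is returned with probability (Odd) 1f, so value policies
// that sum over potential states (e.g. VPolicyHead) reduce to greedy evaluation of that
// one outcome for this game.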
74 | val (ng, rw) = realizeTransition(g, cAction(a)) 75 | IndexedSeq((ng, rw, 1f)) 76 | } 77 | 78 | def availableActions(g: Game6561) = 79 | g.availableMoveNext 80 | 81 | def value(g: Game6561) = 82 | g.value 83 | 84 | def heuristic(g: Game6561) = 85 | g.eval.toFloat 86 | 87 | def toInput(g: Game6561) = 88 | g.toInput 89 | 90 | def toString(g: Game6561) = 91 | g.toString 92 | 93 | def genRandom() = 94 | Game6561(Grid6561.random(Game6561Conf.gameL, 3), Rand.nextInt(Game6561Conf.gameL), 0) 95 | 96 | 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /src/main/scala/SelfPlay.scala: -------------------------------------------------------------------------------- 1 | package drl 2 | 3 | import org.slf4j.LoggerFactory 4 | 5 | import drl.policy._ 6 | import drl.learning._ 7 | import drl.mdp.MDP._ 8 | import drl.backend.Backend._ 9 | 10 | object SelfPlay { 11 | 12 | lazy val log = LoggerFactory.getLogger(getClass()) 13 | 14 | 15 | def trainModelRLDeepQ[S: Statable, B: NeuralN](qconf: QConf, nconf:Either[NConf, B], deconf: DeepExplorationConf, offrlconf: OfflineRLConf, earlyTerm: Option[EarlyTermination] = None) = { 16 | 17 | val qlearning = 18 | new DeepQLearning[S, B](qconf, nconf, deconf, offrlconf, earlyTerm) 19 | 20 | val model = qlearning.getModel 21 | model.fit(qlearning) 22 | qlearning 23 | } 24 | 25 | def trainModelRLDeepV[S: Valuable, B: NeuralN](vconf: VConf, nconf:Either[NConf, B], deconf: DeepExplorationConf, offrlconf: OfflineRLConf) = { 26 | 27 | val vlearning = 28 | new DeepVLearning[S, B](vconf, nconf, deconf, offrlconf) 29 | 30 | val model = vlearning.getModel 31 | 32 | model.fit(vlearning) 33 | model 34 | } 35 | 36 | def trainNoveltyV[S: Valuable, B: NeuralN](vconf: VConf, nconf:Either[NConf, B], novconf: NoveltyConf, offrlconf: OfflineRLConf) = { 37 | 38 | val vlearning = 39 | new NoveltyVLearning[S, B](vconf, nconf, novconf, offrlconf) 40 | 41 | val model = vlearning.getModel 42 | 43 | model.fit(vlearning) 44 | model 45 | 46 | 47 | } 48 | 49 | 50 | def trainNoveltyAEV[S: Randomizable, B: NeuralN](vconf: VConf, nconf:Either[NConf, B], novconf: NoveltyConf, offrlconf: OfflineRLConf) = { 51 | 52 | val vlearning = 53 | new NoveltyAEVLearning[S, B](vconf, nconf, novconf, offrlconf) 54 | 55 | val model = vlearning.getModel 56 | 57 | model.fit(vlearning) 58 | model 59 | 60 | 61 | } 62 | 63 | 64 | 65 | 66 | 67 | def test[Q[_] <: Statable[_], S: Q: Statable](p:Policy[Q], track: Boolean=false) = { 68 | 69 | val F = implicitly[Statable[S]] 70 | var s = F.zero 71 | var r = 0f 72 | var i = 0 73 | while (s.canContinue) { 74 | if (track) { 75 | log.info(p.getClass.toString + " turn: " + i + " score: " + r) 76 | println(s) 77 | } 78 | val a = p.nextAction(s) 79 | // println(a) 80 | val (ns, rw) = s.applyTransition(a) 81 | s = ns 82 | r += rw 83 | i += 1 84 | } 85 | println("Score: " + r + " Turn: " + i) 86 | r 87 | } 88 | 89 | 90 | def average[Q[_] <: Statable[_], S: Q: Statable](p: Policy[Q], nb: Int = 100) = { 91 | val values = (1 to nb).map(x => test(p)) 92 | (values.sum/values.length.toFloat, values) 93 | } 94 | 95 | } 96 | -------------------------------------------------------------------------------- /src/main/scala/policy/VNovelty.scala: -------------------------------------------------------------------------------- 1 | package drl.policy 2 | 3 | import drl.mdp.MDP._ 4 | import drl.backend.Backend._ 5 | 6 | 7 | case class GreedyNoveltyVPolicy[B: NeuralN](model: B, novelty:B, disc:Float, beta:Float, ae:Option[B] = None) extends PolicyV { 
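// Greedy value policy with an intrinsic novelty bonus: each candidate successor's value
// estimate is augmented by beta times the Euclidean prediction error of the `novelty`
// network (computed in the auto-encoder's latent space when `ae` is given); see
// NoveltyPolicy.nextActions below.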
8 | 9 | def nextAction[S: Valuable](s: S) = { 10 | NoveltyPolicy.nextActions(s, model, novelty, disc, beta, ae).maxBy(_._2)._1 11 | } 12 | 13 | } 14 | 15 | 16 | object NoveltyPolicy { 17 | 18 | def eucDistance(x1: Array[Float], x2:Array[Float]) = { 19 | // println("DIS: " +x1.mkString(" ") + "\n" + x2.mkString(" ")) 20 | Math.sqrt(x1.zip(x2).map(x => x._1 - x._2).map(x => x*x).sum).toFloat 21 | } 22 | 23 | def potentials[S: Valuable](l: Seq[((A, (S, Reward, Odd)), Float)], disc:Float) = { 24 | def potential(e: ((A, (S, Reward, Odd)), Float)) = { 25 | e._1._2._3*(e._2 + e._1._2._2/disc) 26 | } 27 | l.map(potential).sum 28 | } 29 | 30 | def noveltyDis[S: Valuable, B: NeuralN](ls: List[(S, A, S)], novelty:B, beta:Float, ae: Option[B]) = { 31 | // val sts = ae.map(_.output()) 32 | val ars = 33 | if (ae.isEmpty) 34 | ls.map(x => x._1.inputWithAction(x._2)) 35 | else 36 | input(ls.map(x => (x._1, x._2)), ae.get) 37 | 38 | val outputAE:Array[Array[Float]] = 39 | if (ae.isEmpty) 40 | ls.map(x => x._3.toInput).toArray 41 | else 42 | autoEncode(ls.map(x => x._3), ae.get) 43 | 44 | novelty.output(ars.toList, (x: Array[Float]) => x)(0).zip(outputAE).map(x => beta*eucDistance(x._1, x._2)) 45 | } 46 | 47 | def autoEncode[S: Valuable, B: NeuralN](ls: List[S], ae: B) = 48 | // ls.map(_.toInput.take(10)) 49 | ae.outputHead(ls, 1) 50 | 51 | def input[S: Valuable, B: NeuralN](ls: List[(S, A)], ae: B) = 52 | autoEncode(ls.map(_._1), ae).zip(ls).map(x => x._1 ++ x._2._1.actionEncoding(x._2._2)).toList 53 | 54 | def nextActions[S: Valuable, B: NeuralN](s: S, model: B, novelty: B, disc: Float, beta:Float, ae: Option[B] = None) = { 55 | // val F = implicitly[Statable[S]] 56 | val aA = s.availableActions 57 | val ps = aA.map(a => s.potentialStates(a).map(x => (a, x))).flatten 58 | val states = ps.map(_._2._1) 59 | // val noveltys = ps.map(x => x._2._1.inputWithAction(x._1)) 60 | val noveltys:Array[Float] = noveltyDis(ps.map(x => (s, x._1, x._2._1)).toList, novelty, beta, ae) 61 | val evals = model.outputHead(states.toList, 0).map(_(0)).toArray.zip(noveltys).map(x => x._1 + x._2) 62 | 63 | // println("STATE: " +s) 64 | println("NOV: " + noveltys.mkString(" ")) 65 | println("EVAL: " + evals.toList) 66 | 67 | ps.zip(evals).groupBy(_._1._1).mapValues(x => potentials[S](x, disc)) 68 | } 69 | 70 | } 71 | -------------------------------------------------------------------------------- /src/main/scala/learning/NoveltyAEVLearning.scala: -------------------------------------------------------------------------------- 1 | package drl.learning 2 | 3 | import drl._ 4 | import drl.policy._ 5 | import drl.mdp.MDP._ 6 | import drl.backend.Backend._ 7 | 8 | 9 | class NoveltyAEVLearning[S: Randomizable, B: NeuralN](vconf: VConf, nconf: Either[NConf, B], novconf: NoveltyConf, offrlconf: OfflineRLConf, earlytermination: Option[EarlyTermination] = None) extends NoveltyExploration[Valuable, S, B](nconf, novconf, offrlconf, earlytermination){ 10 | 11 | 12 | lazy val outputWidth = 1 13 | 14 | val load = false 15 | 16 | 17 | val novelty:B = buildNoveltyModel(implicitly[Statable[S]].featureSize/2) 18 | val autoEncode:B = 19 | if (load) 20 | implicitly[NeuralN[B]].load("autoencode") 21 | else 22 | implicitly[NeuralN[B]].buildAE() 23 | 24 | val pol = GreedyNoveltyVPolicy[B](model, novelty, offrlconf.disc, novconf.beta, Some(autoEncode)) 25 | 26 | val combined = VPolicyHead(model, 0, offrlconf.disc) 27 | 28 | if (!load) { 29 | trainAutoEncoder() 30 | autoEncode.save("autoencode") 31 | } 32 | 33 | trainPredict() 34 | 35 | def 
trainAutoEncoder() = { 36 | 37 | for (i <- 1 to 20000) { 38 | println("autoEncoder fit: "+i) 39 | val inputs = List.fill(1000)(implicitly[Randomizable[S]].genRandom) 40 | val out2 = autoEncode.outputHead(inputs, 1) 41 | autoEncode.buildAndFit(inputs, Array(inputs.map(_.toInput).toArray, out2), (x:S) => x.toInput) 42 | } 43 | } 44 | 45 | def trainPredict() = { 46 | 47 | def genRandom = { 48 | val st = implicitly[Randomizable[S]].genRandom 49 | val a = st.randomAction 50 | (st, a) 51 | } 52 | for (i <- 1 to 0) { 53 | println("predict fit: "+i) 54 | val gen = List.fill(1000)(genRandom) 55 | val r = gen.map(x => x._1.applyTransition(x._2)) 56 | fitNovelty(gen.zip(r).map(x => (x._1._1, x._1._2, x._2._2, x._2._1))) 57 | } 58 | } 59 | 60 | def fitNovelty(l:List[SARS[S]]) = { 61 | val targets = NoveltyPolicy.autoEncode(l.map(_._4), autoEncode) 62 | val inputs = NoveltyPolicy.input(l.map(x => (x._1, x._2)), autoEncode) 63 | val dis2 = NoveltyPolicy.noveltyDis(l.map(x => (x._1, x._2, x._4)), novelty, 1f, Some(autoEncode)).sum 64 | println("DIS: " + dis2) 65 | novelty.buildAndFit(inputs, Array(targets), (x:Array[Float]) => x) 66 | } 67 | 68 | def target(l: List[SARS[S]]):Array[Array[Array[Float]]] = { 69 | val r = Array(TDLambda.tdErr(l, 0, true, model, vconf.lambda, offrlconf.gamma, offrlconf.disc).map(Array(_))) 70 | // trainPredict() 71 | // TestNovelty.maxt(l) 72 | // TestNovelty.test(novelty, Some(autoEncode)) 73 | fitNovelty(l) 74 | r 75 | } 76 | 77 | 78 | def getNext(nb:Int) = { 79 | sample(pol) 80 | } 81 | 82 | 83 | } 84 | -------------------------------------------------------------------------------- /src/main/scala/mdp/MDP.scala: -------------------------------------------------------------------------------- 1 | package drl.mdp 2 | 3 | import drl.Rand 4 | 5 | object MDP { 6 | 7 | trait Action 8 | 9 | type Value = Float 10 | type Reward = Float 11 | 12 | 13 | type A = Action 14 | type R = Reward 15 | type V = Value 16 | type SARS[S] = (S, A, R, S) 17 | 18 | type Odd = Float 19 | 20 | trait Statable[State] { //extends Monoid[State] 21 | 22 | type CAction <: Action 23 | 24 | def realizeTransition(state: State, action: CAction): (State, Reward) 25 | 26 | def applyTransition(state: State, action: A): (State, Reward) = 27 | realizeTransition(state, cAction(action)) 28 | 29 | def availableActions(state: State): Seq[A] 30 | 31 | // def zeroMove: A 32 | 33 | def zero: State 34 | 35 | def allActions: IndexedSeq[A] 36 | 37 | lazy val outputWidth = 38 | allActions.length 39 | 40 | def featureSize = 41 | toInput(zero).length 42 | 43 | lazy val actionToIndex = 44 | allActions.zipWithIndex toMap 45 | 46 | def actionEncoding(action: A) = { 47 | val ind = actionToIndex(action) 48 | val ar = Array.fill(outputWidth)(0f) 49 | ar(ind) = 1f 50 | ar 51 | } 52 | 53 | 54 | def toString(state: State): String 55 | 56 | def toInput(state: State): Array[Float] 57 | 58 | def cAction(a: Action) = a.asInstanceOf[CAction] 59 | 60 | } 61 | 62 | trait Valuable[S] extends Statable[S]{ 63 | 64 | def value(state: S): Value 65 | 66 | def heuristic(state: S): Float 67 | 68 | def potentialStates(state: S, action: A): IndexedSeq[(S, Reward, Odd)] 69 | 70 | // def potStates(state: S, action: A): IndexedSeq[(S, Reward, Odd)] = 71 | // potentialStates(S, cAction(action)) 72 | 73 | } 74 | 75 | trait Randomizable[S] extends Valuable[S]{ 76 | 77 | def genRandom(): S 78 | 79 | } 80 | 81 | implicit class StatableOps[S: Statable](state: S) { 82 | 83 | val F = implicitly[Statable[S]] 84 | 85 | override def toString = 86 | F.toString(state) 
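// These extension methods simply delegate to the Statable type-class instance F, e.g.
// `state.applyTransition(a)` stands for `implicitly[Statable[S]].applyTransition(state, a)`.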
87 | 88 | def toInput = 89 | F.toInput(state) 90 | 91 | def applyTransition(action: A) = 92 | F.applyTransition(state, action) 93 | 94 | def availableActions: Seq[A] = 95 | F.availableActions(state) 96 | 97 | def randomAction: A = 98 | Rand.choose(availableActions) 99 | 100 | def availableActionsIndexs = 101 | availableActions.map(F.actionToIndex) 102 | 103 | def canContinue = 104 | !(F availableActions(state) isEmpty) 105 | 106 | def actionEncoding(a: Action) = 107 | F.actionEncoding(a) 108 | 109 | def inputWithAction(a: Action) = { 110 | val inp = toInput 111 | val aE = actionEncoding(a) 112 | // println(inp.length + " " + aE.length) 113 | inp ++ aE 114 | } 115 | 116 | } 117 | 118 | 119 | implicit class ValuableOps[V: Valuable](v: V) { 120 | 121 | val V = implicitly[Valuable[V]] 122 | 123 | def value = 124 | V.value(v) 125 | 126 | def heuristic = 127 | V.heuristic(v) 128 | 129 | def potentialStates(a: A): IndexedSeq[(V, Reward, Odd)] = 130 | V.potentialStates(v, a) 131 | 132 | } 133 | 134 | 135 | 136 | } 137 | -------------------------------------------------------------------------------- /src/main/scala/learning/NoveltyExploration.scala: -------------------------------------------------------------------------------- 1 | package drl.learning 2 | 3 | import drl._ 4 | import drl.policy._ 5 | import drl.Charts 6 | import drl.mdp.MDP._ 7 | import drl.backend.Backend._ 8 | 9 | case class NoveltyConf( 10 | combinedTestFrequency:Int = 5, 11 | nbAverages:Int = 1, 12 | beta: Float = 0.01f 13 | ) 14 | 15 | abstract class NoveltyExploration[Q[_] <: Statable[_], S: Q: Statable, B: NeuralN](nconf: Either[NConf, B], noveltyconf: NoveltyConf, offrlconf: OfflineRLConf, earlytermination: Option[EarlyTermination] = None) extends OfflineRL[S, B](nconf, offrlconf){ 16 | 17 | createCharts() 18 | 19 | def outputWidth: Int 20 | def combined: Policy[Q] 21 | 22 | var scores = List[(Int, Float)]() 23 | 24 | def buildNoveltyModel(inputSize: Int) = { 25 | val backend = implicitly[NeuralN[B]] 26 | val F = implicitly[Statable[S]] 27 | val fS = inputSize + F.outputWidth 28 | println("Novelty") 29 | val confnn = ConfNN( 30 | offrlconf.seed, 31 | inputSize, 32 | 0.01f, 33 | None, 34 | Some(0.05f), 35 | 1, 36 | 0, 37 | 3, 38 | 0, 39 | 128, 40 | 0.9f, 41 | RMSProp, 42 | ReLu 43 | ) 44 | val noveltyModel = backend.build[S](confnn, Some(fS)) 45 | noveltyModel 46 | } 47 | 48 | def buildModel(onconf: NConf) = { 49 | println("Model") 50 | val backend = implicitly[NeuralN[B]] 51 | val confnn = ConfNN( 52 | offrlconf.seed, 53 | outputWidth, 54 | onconf.learningRate, 55 | onconf.l1, 56 | onconf.l2, 57 | 1, 58 | onconf.commonHeight, 59 | onconf.headHeight, 60 | onconf.commonWidth, 61 | onconf.headWidth, 62 | onconf.momentum, 63 | onconf.updater, 64 | onconf.activation 65 | ) 66 | backend.build[S](confnn) 67 | } 68 | 69 | 70 | var repeated = 0 71 | 72 | def next(nb:Int):PreBatch[S] = { 73 | 74 | if (nbFetch%noveltyconf.combinedTestFrequency == 0) { 75 | val avgScore = SelfPlay.average(combined, noveltyconf.nbAverages)._1 76 | chartCombined(avgScore) 77 | if (earlytermination.exists(x => x match { 78 | case RepeatedValue(value, _) if (avgScore >= value) => true 79 | case _ => false 80 | })) 81 | repeated += 1 82 | else 83 | repeated = 0 84 | 85 | scores ::= ((nbFetch, avgScore)) 86 | } 87 | 88 | val (episode, score) = getNext(nb) 89 | chartHead(score) 90 | val stateList = episode.map(_._1) 91 | (stateList, target(episode)) 92 | 93 | } 94 | 95 | override def hasNext() = { 96 | super.hasNext() && earlytermination.forall(x => x match 
{ 97 | case RepeatedValue(_, rep) if (repeated >= rep) => 98 | false 99 | case _ => 100 | println(repeated) 101 | true 102 | }) 103 | } 104 | 105 | 106 | def createCharts() = { 107 | Charts.createChart("head-1", "score2", 1) 108 | Charts.createChart("head-ma10", "score", 100) 109 | Charts.createChart("combined-1", "score3", 100) 110 | Charts.show() 111 | } 112 | 113 | def chartHead(score:Float) { 114 | Charts.addValue(nbFetch, score, "head-1") 115 | Charts.addValue(nbFetch, score, "head-ma10") 116 | } 117 | 118 | def chartCombined(score: Float) { 119 | Charts.addValue(nbFetch, score, "combined-1") 120 | } 121 | 122 | 123 | 124 | } 125 | -------------------------------------------------------------------------------- /src/main/scala/learning/DeepExploration.scala: -------------------------------------------------------------------------------- 1 | package drl.learning 2 | 3 | import drl._ 4 | import drl.policy._ 5 | import drl.mdp.MDP._ 6 | import drl.backend.Backend._ 7 | 8 | trait EarlyTermination 9 | case class RepeatedValue(value:Float, repearted: Int) extends EarlyTermination 10 | 11 | case class DeepExplorationConf( 12 | nbHead: Int = 10, 13 | targetNPeriod: Option[Int] = None, 14 | combinedTestFrequency:Int = 5, 15 | nbAverages:Int = 1 16 | ) 17 | 18 | abstract class DeepExploration[Q[_] <: Statable[_], S: Q: Statable, B: NeuralN](nconf: Either[NConf, B], deconf: DeepExplorationConf, offrlconf: OfflineRLConf, earlytermination: Option[EarlyTermination] = None) extends OfflineRL[S, B](nconf, offrlconf){ 19 | 20 | createCharts() 21 | var target = model.cloneM() 22 | val outputWidth: Int 23 | var scores = List[(Int, Float)]() 24 | 25 | def buildModel(onconf: NConf) = 26 | buildModel(onconf, outputWidth) 27 | 28 | def buildModel(onconf: NConf, ow: Int) = { 29 | val backend = implicitly[NeuralN[B]] 30 | val confnn = ConfNN( 31 | offrlconf.seed, 32 | ow, 33 | onconf.learningRate, 34 | onconf.l1, 35 | onconf.l2, 36 | deconf.nbHead, 37 | onconf.commonHeight, 38 | onconf.headHeight, 39 | onconf.commonWidth, 40 | onconf.headWidth, 41 | onconf.momentum, 42 | onconf.updater, 43 | onconf.activation 44 | ) 45 | backend.build[S](confnn) 46 | } 47 | 48 | 49 | def getScores = 50 | scores 51 | 52 | def combined: Policy[Q] 53 | 54 | def targetHead(l: List[SARS[S]], k:Int): Array[Array[Float]] 55 | 56 | def target(l: List[SARS[S]]):Array[Array[Array[Float]]] = 57 | (0 until deconf.nbHead).map(k => { 58 | targetHead(l, k) 59 | }).toArray 60 | 61 | var repeated = 0 62 | def next(nb:Int):PreBatch[S] = { 63 | 64 | if (deconf.targetNPeriod.exists(p => nbFetch%p == 0)) 65 | cloneTarget() 66 | 67 | if (nbFetch%deconf.combinedTestFrequency == 0) { 68 | val avgScore = SelfPlay.average(combined, deconf.nbAverages)._1 69 | chartCombined(avgScore) 70 | if (earlytermination.exists(x => x match { 71 | case RepeatedValue(value, _) if (avgScore >= value) => true 72 | case _ => false 73 | })) 74 | repeated += 1 75 | else 76 | repeated = 0 77 | 78 | scores ::= ((nbFetch, avgScore)) 79 | } 80 | 81 | val (episode, score) = getNext(nb) 82 | chartHead(score) 83 | val stateList = episode.map(_._1) 84 | (stateList, target(episode)) 85 | 86 | } 87 | 88 | override def hasNext() = { 89 | super.hasNext() && earlytermination.forall(x => x match { 90 | case RepeatedValue(_, rep) if (repeated >= rep) => 91 | false 92 | case _ => 93 | println(repeated) 94 | true 95 | }) 96 | } 97 | 98 | def cloneTarget() = { 99 | println("CLONED") 100 | target = model.cloneM 101 | } 102 | 103 | 104 | def createCharts() = { 105 | 
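// Three series: "head-1" is the raw episode score of the currently sampled head,
// "head-ma10" the same score smoothed by Charts over a 100-value window (despite the name),
// and "combined-1" the test score of the combined (ensemble) policy.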
Charts.createChart("head-1", "score2", 1) 106 | Charts.createChart("head-ma10", "score", 100) 107 | Charts.createChart("combined-1", "score3", 100) 108 | Charts.show() 109 | } 110 | 111 | def chartHead(score:Float) { 112 | Charts.addValue(nbFetch, score, "head-1") 113 | Charts.addValue(nbFetch, score, "head-ma10") 114 | } 115 | 116 | def chartCombined(score: Float) { 117 | Charts.addValue(nbFetch, score, "combined-1") 118 | } 119 | 120 | 121 | } 122 | -------------------------------------------------------------------------------- /src/main/scala/backend/Backends.scala: -------------------------------------------------------------------------------- 1 | package drl.backend 2 | 3 | 4 | import org.deeplearning4j.nn.graph.ComputationGraph 5 | import org.deeplearning4j.optimize.listeners.ScoreIterationListener 6 | import org.nd4j.linalg.api.ndarray.INDArray 7 | import org.deeplearning4j.util._ 8 | import org.nd4j.linalg.factory.Nd4j 9 | import org.nd4j.linalg.dataset._ 10 | import java.io._ 11 | 12 | import drl.mdp.MDP._ 13 | import drl.backend.Backend._ 14 | 15 | object Backends { 16 | 17 | implicit object dl4j extends NeuralN[SeparableCompGraph] { 18 | 19 | // Nd4j.ENFORCE_NUMERICAL_STABILITY = true 20 | 21 | type NN = SeparableCompGraph 22 | 23 | type Batch = MultiDataSet 24 | type Input = INDArray 25 | 26 | 27 | def build[S: Statable](conf: ConfNN, inputSize:Option[Int]):NN = { 28 | if (conf.commonHeight > 0) { 29 | val cgConf = BuildNN.buildCG[S](conf, inputSize) 30 | val model = new ComputationGraph(cgConf) 31 | model.init() 32 | model.setListeners(new ScoreIterationListener(1)) 33 | new SingleCompGraph(model) 34 | } else { 35 | println(conf.nbHead + " ....") 36 | val confs = (1 to conf.nbHead).map(i => conf.copy(seed = conf.seed + i, nbHead = 1)) 37 | val cgConfs = confs.map(c => BuildNN.buildCG[S](c, inputSize)) 38 | val models = cgConfs.map(c => new ComputationGraph(c)) 39 | models.foreach(_.init) 40 | models.foreach(_.setListeners(new ScoreIterationListener(1))) 41 | val r = new SeparatedCompGraph(models) 42 | println(conf.nbHead + " nbHead") 43 | r 44 | } 45 | } 46 | 47 | def buildAE[S: Statable]() = 48 | BuildNN.buildAE[S]() 49 | 50 | def load(filename: String) = { 51 | val cg = ModelSerializer.restoreComputationGraph(new File(filename)) 52 | new SingleCompGraph(cg) 53 | } 54 | 55 | 56 | def save(nn: NN, filename: String) = 57 | nn.save(filename) 58 | 59 | def genInput(ar: Array[Array[Float]]) = { 60 | // println(ar.length + " " + ar(0).length) 61 | Nd4j.create(ar) 62 | } 63 | 64 | def output(nn: NN, inp: Input) = { 65 | val o = nn.output(inp) 66 | o.map(h => (0 until h.rows()).map(i => h.getRow(i).dup().data().asFloat()).toArray) 67 | 68 | } 69 | 70 | 71 | def buildBatch[A](seqs: List[A], target: Array[Array[Array[Float]]], f: A => Array[Float]): Batch = { 72 | val in = genInput(seqs.map(f).toArray) 73 | val out = target.map(ht => Nd4j.create(ht)) 74 | new MultiDataSet(Array(in), out) 75 | } 76 | 77 | def cloneM(nn: NN) = 78 | nn.cloneM() 79 | 80 | def fit(nn: NN, b: Batch) = { 81 | // println(b.getFeatures(0).shape.toList + " " + b.getLabels.map(_.shape.toList).toList) 82 | nn.fit(b) 83 | // println("D") 84 | } 85 | 86 | 87 | /* def output[S: Statable](nn: NN, ls: List[S]) = { 88 | val o = nn.output(genInput(ls)) 89 | o.map(h => (0 until ls.length).map(i => h.getRow(i).dup().data().asFloat()).toArray) 90 | } 91 | / 92 | def outputWithAction[S: Statable](nn: NN, ls: List[(S, A)]) = { 93 | val o = nn.output(genInputWithAction(ls))(0) 94 | (0 until ls.length).map(i => 
o.getRow(i).dup().data().asFloat()).toArray 95 | } 96 | 97 | 98 | override def outputHead[S: Statable](nn: NN, ls: List[S], head:Int) = { 99 | val o = nn.outputHead(genInput(ls), head) 100 | // println(ls) 101 | (0 until ls.length).map(i => o.getRow(i).dup().data().asFloat()).toArray 102 | } 103 | */ 104 | 105 | } 106 | 107 | } 108 | -------------------------------------------------------------------------------- /src/main/scala/mdp/GameDeepDQN.scala: -------------------------------------------------------------------------------- 1 | package drl.mdp 2 | 3 | import drl._ 4 | import drl.Rand 5 | import drl.mdp.MDP._ 6 | 7 | sealed trait MoveDeep extends Action 8 | case object RightM extends MoveDeep 9 | case object LeftM extends MoveDeep 10 | 11 | case class GameDeepDQN(t: Int, turn: Int, N:Int) { 12 | 13 | def canContinue = 14 | turn < N + 9 15 | 16 | def move(m: MoveDeep) = { 17 | val ns = m match { 18 | case RightM => copy(t = (t+1).min(N), turn = turn + 1) 19 | case LeftM => copy(t = (t-1).max(0), turn = turn + 1) 20 | } 21 | val reward = 22 | if (ns.t == 0) 23 | 1f/1000 24 | else if (ns.t == N) 25 | 1f 26 | else 27 | 0f 28 | (ns, reward) 29 | } 30 | 31 | } 32 | 33 | case class GameDeepStochDQN(t: Int, turn: Int) { 34 | 35 | val N = 6 36 | 37 | def canContinue = 38 | turn < N + 9 39 | 40 | def move(m: MoveDeep) = { 41 | val r = Rand.nextFloat 42 | val ns = m match { 43 | case RightM if r < 0.5f=> copy(t = (t+1).min(N), turn = turn + 1) 44 | case _ => copy(t = (t-1).max(0), turn = turn + 1) 45 | } 46 | val reward = 47 | if (ns.t == 0) 48 | 1f/100 49 | else if (ns.t == N) 50 | 1f 51 | else 52 | 0f 53 | (ns, reward) 54 | 55 | } 56 | 57 | 58 | } 59 | 60 | object GameDeepStochDQN { 61 | 62 | implicit object GameDeepStochDQNS extends Valuable[GameDeepStochDQN] { 63 | 64 | type CAction = MoveDeep 65 | 66 | val allActions = IndexedSeq(RightM, LeftM) 67 | 68 | val zero = GameDeepStochDQN(0, 0) 69 | 70 | def realizeTransition(g: GameDeepStochDQN, m: CAction) = { 71 | g.move(m) 72 | } 73 | 74 | def potentialStates(g: GameDeepStochDQN, a: A): IndexedSeq[(GameDeepStochDQN, Reward, Odd)] = { 75 | val (ng, rw) = realizeTransition(g, cAction(a)) 76 | IndexedSeq((ng, rw, 1f)) 77 | } 78 | 79 | 80 | def value(g: GameDeepStochDQN) = 81 | 0f //FIx 82 | 83 | def heuristic(g: GameDeepStochDQN) = 84 | 0f //FIX 85 | 86 | def availableActions(g: GameDeepStochDQN) = 87 | if (g.canContinue) 88 | allActions 89 | else 90 | IndexedSeq() 91 | 92 | def toInput(g: GameDeepStochDQN) = 93 | Array.fill(g.t+1)(1f) ++ Array.fill(g.N - g.t)(0f) 94 | 95 | 96 | def toString(g: GameDeepStochDQN) = 97 | g.toString 98 | 99 | } 100 | 101 | } 102 | 103 | object GameDeepDQN { 104 | 105 | 106 | implicit object GameDeepDQNS extends Randomizable[GameDeepDQN] { 107 | 108 | type CAction = MoveDeep 109 | 110 | val allActions = IndexedSeq(RightM, LeftM) 111 | 112 | def zero = GameDeepDQN(0, 0, GameDeepConf.gameL) 113 | 114 | def realizeTransition(g: GameDeepDQN, m: CAction) = { 115 | g.move(m) 116 | } 117 | 118 | def potentialStates(g: GameDeepDQN, a: A): IndexedSeq[(GameDeepDQN, Reward, Odd)] = { 119 | val (ng, rw) = realizeTransition(g, cAction(a)) 120 | IndexedSeq((ng, rw, 1f)) 121 | } 122 | 123 | def genRandom() = 124 | zero.copy(t = Rand.nextInt(GameDeepConf.gameL), turn = Rand.nextInt(GameDeepConf.gameL+9)) 125 | 126 | def value(g: GameDeepDQN) = 127 | 0f //FIx 128 | 129 | def heuristic(g: GameDeepDQN) = 130 | 0f //FIX 131 | 132 | def availableActions(g: GameDeepDQN) = 133 | if (g.canContinue) 134 | allActions 135 | else 136 | 
IndexedSeq() 137 | 138 | def toInput(g: GameDeepDQN) = 139 | // Array.fill(g.t+1)(1f) ++ Array.fill(g.N - g.t)(0f) ++ Array.fill(1)(g.turn/GameDeepConf.gameL.toFloat) 140 | Array.fill(g.t.min(g.N-1))(0f) ++ Array.fill(1)(1f) ++ Array.fill((g.N - g.t-1).max(0))(0f) //++ Array(g.turn/GameDeepConf.gameL.toFloat) 141 | // Array(g.t/g.N.toFloat) 142 | 143 | 144 | def toString(g: GameDeepDQN) = 145 | g.toString 146 | 147 | } 148 | 149 | 150 | } 151 | -------------------------------------------------------------------------------- /src/main/scala/backend/Backend.scala: -------------------------------------------------------------------------------- 1 | package drl.backend 2 | 3 | import drl.mdp.MDP._ 4 | 5 | object Backend { 6 | 7 | type PreBatch[S] = (List[S], Array[Array[Array[Float]]]) 8 | 9 | sealed trait Updater 10 | object RMSProp extends Updater 11 | object Nesterovs extends Updater 12 | 13 | sealed trait Activation 14 | object ReLu extends Activation 15 | 16 | 17 | case class ConfNN( 18 | seed: Int, 19 | outputWidth: Int, 20 | learningRate: Float = 0.005f, 21 | l1: Option[Float] = None, 22 | l2: Option[Float] = Some(0.005f), 23 | nbHead:Int = 10, 24 | commonHeight:Int = 3, 25 | headHeight:Int = 1, 26 | commonWidth:Int = 64, 27 | headWidth:Int = 128, 28 | momentum: Float = 0.9f, 29 | updater: Updater = Nesterovs, 30 | activation: Activation = ReLu 31 | ) 32 | 33 | trait NeuralN[NN] { 34 | 35 | type Batch 36 | type Input 37 | 38 | def build[S: Statable](conf: ConfNN, inputsize:Option[Int]=None): NN 39 | 40 | def buildAE[S: Statable](): NN 41 | 42 | def load(filename: String): NN 43 | def save(nn: NN, filename: String): Boolean 44 | def cloneM(nn: NN): NN 45 | def buildBatch[A](seqs: List[A], target: Array[Array[Array[Float]]], f: A => Array[Float]): Batch 46 | 47 | def buildBatch[S: Statable](seqs: List[S], target: Array[Array[Array[Float]]]): Batch = 48 | buildBatch(seqs, target, (x:S) => x.toInput) 49 | 50 | def fit(nn: NN, b: Batch): Unit 51 | def genInput(ar: Array[Array[Float]]): Input 52 | 53 | def output(nn: NN, inp: Input): Array[Array[Array[Float]]] 54 | 55 | def output[A](nn: NN, la: List[A], f: A => Array[Float]):Array[Array[Array[Float]]] = 56 | output(nn, genInput(la.map(f).toArray)) 57 | 58 | def output[S: Statable](nn: NN, ls: List[S]): Array[Array[Array[Float]]] = 59 | output(nn, ls, (x:S) => x.toInput) 60 | 61 | def outputWithAction[S: Statable](nn: NN, ls: List[(S, A)]): Array[Array[Float]] = 62 | output(nn, ls, (x: (S,A)) => x._1.inputWithAction(x._2))(0) 63 | 64 | def outputHead[S: Statable](nn: NN, ls: List[S], head:Int) = 65 | output(nn, ls).apply(head) 66 | 67 | } 68 | 69 | implicit class NeuralNOps[NN: NeuralN](n: NN) { 70 | 71 | val F = implicitly[NeuralN[NN]] 72 | 73 | type FBatch = F.Batch//NeuralN[NN]#Batch 74 | 75 | def buildBatch[A](seqs: List[A], target: Array[Array[Array[Float]]], f: A => Array[Float]): FBatch = 76 | F.buildBatch(seqs, target, f) 77 | 78 | def save(filename: String): Boolean = 79 | F.save(n, filename) 80 | 81 | def fitB(b: FBatch):Unit = 82 | F.fit(n, b) 83 | 84 | def buildAndFit[A](seqs: List[A], target: Array[Array[Array[Float]]], f: A => Array[Float]) = { 85 | val batch = buildBatch(seqs, target, f) 86 | fitB(batch) 87 | } 88 | 89 | def fit(bi: Iterator[FBatch]):Unit = 90 | bi.foreach(fitB) 91 | 92 | def fit[S: Statable](si: Iterator[PreBatch[S]]):Unit = 93 | for (s <- si) { 94 | val b = F.buildBatch(s._1, s._2) 95 | fitB(b) 96 | } 97 | 98 | def output[A](la: List[A], f: A => Array[Float]):Array[Array[Array[Float]]] = 99 | F.output(n, la, 
f) 100 | 101 | def output[S: Statable](ls: List[S]): Array[Array[Array[Float]]] = 102 | F.output(n, ls) 103 | 104 | def outputHead[S: Statable](ls: List[S], head:Int): Array[Array[Float]] = 105 | F.outputHead(n, ls, head) 106 | 107 | def outputHeadS[S: Statable](s: S, head:Int): Array[Float] = { 108 | outputHead(List(s), head).apply(0) 109 | } 110 | 111 | def outputWithAction[S: Statable](ls: List[(S, A)]): Array[Array[Float]] = 112 | F.outputWithAction(n, ls) 113 | 114 | 115 | def outputS[S: Statable](s: S): Array[Array[Float]] = 116 | output(List(s)).map(_(0)) 117 | 118 | def cloneM():NN = 119 | F.cloneM(n) 120 | 121 | 122 | } 123 | 124 | 125 | } 126 | -------------------------------------------------------------------------------- /src/main/scala/mdp/Game6561.scala: -------------------------------------------------------------------------------- 1 | package drl.mdp 2 | 3 | import drl._ 4 | import drl.mdp.MDP._ 5 | 6 | sealed trait Move6561 extends MDP.Action 7 | case object Start extends Move6561 8 | case class Turn(dir: Direction) extends Move6561 9 | case class Place(x: Int, y: Int, c: Color) extends Move6561 10 | 11 | 12 | case class Game6561(grid: Grid6561, turn: Int, maxScore: Float) { 13 | 14 | lazy val canContinue = 15 | turn < Game6561Conf.gameL && !availableMoveNext.isEmpty 16 | 17 | 18 | lazy val toInput = 19 | grid.toInput6561 :+ (1f-turn.toFloat/Game6561Conf.gameL) :+ (turn%5.toFloat)/5 20 | 21 | def canTurn(dir: Direction) = 22 | move(Turn(dir)).exists(_.grid != grid) 23 | 24 | 25 | 26 | def move(mv: Move6561) = { 27 | 28 | val ng = mv match { 29 | case Turn(dir) if Game6561.placeColor(turn).isEmpty => 30 | Some(grid.move(dir)._1) 31 | case Place(x, y, c) if Game6561.placeColor(turn).exists(_ == c) => 32 | grid.place(x, y, new Piece(1, c, 3)) 33 | } 34 | 35 | ng.map(g => Game6561(g, turn+1, maxScore.max(grid.value))) 36 | 37 | } 38 | 39 | 40 | def eval = 41 | grid.eval 42 | 43 | def value = 44 | grid.value 45 | 46 | def randomMove = { 47 | val aM = availableMoveNext 48 | aM(Rand.nextInt(aM.length)) 49 | } 50 | 51 | def applyRandomMove = 52 | if (canContinue) 53 | move(randomMove).get 54 | else 55 | this 56 | 57 | def applyEvalMove = { 58 | val aM = availableMoveNext 59 | aM.map(m => move(m).get).maxBy(_.grid.eval) 60 | } 61 | 62 | 63 | 64 | lazy val availableMoveNext = { 65 | val aM = 66 | if (turn == Game6561Conf.gameL) 67 | List() 68 | else 69 | (turn % 5) match { 70 | case 3 | 4 => Game6561.ALL_TURNS filter (x => canTurn(x.dir)) toList 71 | case t => 72 | val color = Game6561.placeColor(t).get 73 | grid.emptySpots.map(p => Place(p._1, p._2, color)).toList 74 | } 75 | aM 76 | } 77 | // List(availableMoveGen(turn+1).map(x => (x, move(x).get)).maxBy(_._2.grid.eval)._1) 78 | 79 | } 80 | 81 | 82 | object Game6561 { 83 | 84 | val ALL_TURNS = IndexedSeq(Turn(Up), Turn(Down), Turn(Right), Turn(Left)) 85 | 86 | val colors = List(Blue, Red, Gray) 87 | 88 | val moves = (0 to 51).map( i => 89 | if (i < 48) 90 | Place((i%16)%4, (i%16)/4, colors(i/16)) 91 | else 92 | ALL_TURNS(i-48) 93 | ) 94 | 95 | 96 | 97 | def player(turn: Int) = 98 | (turn+1)%2 99 | 100 | def placeColor(turn: Int): Option[Color] = 101 | (turn % 5) match { 102 | case 3 | 4 => None 103 | case 0 => Some(Blue) 104 | case 1 => Some(Red) 105 | case 2 => Some(Gray) 106 | } 107 | 108 | 109 | def newGame(mult:Int) = 110 | Game6561(Grid6561(mult), 0, 0) 111 | 112 | 113 | implicit object Game2048V extends Valuable[Game2048] { 114 | 115 | type CAction = Move6561 116 | 117 | val allActions = Game6561.ALL_TURNS 118 | 119 | val 
zero = Game2048(Grid6561(2), 0) 120 | 121 | def realizeTransition(g: Game2048, m: CAction) = { 122 | val ng = g.fullMove(m) 123 | // (ng, ng.value - g.value) 124 | ng 125 | } 126 | 127 | def potentialStates(g: Game2048, a: A): IndexedSeq[(Game2048, Reward, Odd)] = { 128 | val (ng, rw) = g.move(cAction(a)) 129 | val ngs = ng.grid.emptySpots.map(spot => ng.copy(grid = ng.grid.place(spot._1, spot._2, new Piece(1, Red, 2)).get)) 130 | ngs.map(x => (x, rw, 1f/ngs.length)).toIndexedSeq 131 | } 132 | 133 | def availableActions(g: Game2048) = { 134 | g.availableMoveNext 135 | } 136 | 137 | def value(g: Game2048) = 138 | g.value 139 | 140 | def heuristic(g: Game2048) = 141 | g.eval.toFloat 142 | 143 | def toInput(g: Game2048) = 144 | g.toInput 145 | 146 | def toString(g: Game2048) = 147 | g.toString 148 | 149 | } 150 | 151 | } 152 | -------------------------------------------------------------------------------- /src/main/scala/Conf.scala: -------------------------------------------------------------------------------- 1 | package drl 2 | 3 | import drl.backend.Backend._ 4 | import drl.learning._ 5 | 6 | case class NConf( 7 | learningRate: Float = 0.005f, 8 | l1: Option[Float] = None, 9 | l2: Option[Float] = Some(0.005f), 10 | commonHeight:Int = 0, 11 | headHeight:Int = 3, 12 | commonWidth:Int = 128, 13 | headWidth:Int = 128, 14 | momentum: Float = 0.95f, 15 | updater: Updater = RMSProp, 16 | activation: Activation = ReLu 17 | ) 18 | 19 | 20 | object Game6561Conf { 21 | val maxTileValue = 7 22 | val gameL = 100 23 | } 24 | 25 | case class FullConf(nconf: NConf, deconf: DeepExplorationConf, offrlconf: OfflineRLConf, qconf: QConf, vconf: VConf, novconf: NoveltyConf = NoveltyConf()) 26 | 27 | object Game2048Conf { 28 | val maxTileValue = 29 | 12 30 | 31 | val qconf = QConf( 32 | expRep = true, 33 | zeroImpossible = false, 34 | minPoolFactor = 30, 35 | maxPoolFactor = 35 36 | ) 37 | 38 | val vconf = VConf( 39 | lambda = 0.8f 40 | ) 41 | 42 | val nconf = NConf( 43 | learningRate = 0.005f, 44 | l1 = None, 45 | l2 = Some(0.005f), 46 | commonHeight = 2, 47 | headHeight = 1, 48 | commonWidth = 128, 49 | headWidth = 128, 50 | momentum = 0.95f, 51 | updater = RMSProp, 52 | activation = ReLu 53 | ) 54 | 55 | val deconf = DeepExplorationConf( 56 | nbHead = 10, 57 | targetNPeriod = None, 58 | combinedTestFrequency = 5, 59 | nbAverages = 1 60 | ) 61 | 62 | def offconf(numEX: Int, seed: Int) = 63 | OfflineRLConf(numEX, seed, 1000, 300f) //3200 64 | 65 | def fullconf(numEX: Int, seed: Int) = 66 | FullConf(nconf, deconf, offconf(numEX, seed), qconf, vconf) 67 | 68 | } 69 | 70 | object GameDeepStochConf { 71 | val qconf = QConf(true) 72 | val nconf = NConf() 73 | val vconf = VConf( 74 | lambda = 0f 75 | ) 76 | 77 | 78 | def offconf(numEX: Int, seed: Int) = OfflineRLConf(numEX, seed, 100, 1f) 79 | val deconf = DeepExplorationConf(1) 80 | def fullconf(numEX: Int, seed: Int) = FullConf(nconf, deconf, offconf(numEX, seed), qconf, vconf) 81 | } 82 | 83 | object GameDeepConf { 84 | 85 | var gameL = 86 | 20 87 | 88 | val qconf = QConf( 89 | expRep = true, 90 | zeroImpossible = false, 91 | minPoolFactor = 30, 92 | maxPoolFactor = 35 93 | ) 94 | 95 | val vconf = VConf( 96 | lambda = 0f 97 | ) 98 | 99 | 100 | val nconf = NConf( 101 | learningRate = 0.005f, 102 | l1 = None, 103 | l2 = Some(0.005f), 104 | commonHeight = 0, 105 | headHeight = 3, 106 | commonWidth = 128, 107 | headWidth = 128, 108 | momentum = 0.95f, 109 | updater = RMSProp, 110 | activation = ReLu 111 | ) 112 | 113 | val deconf = DeepExplorationConf( 114 | nbHead 
= 1, 115 | targetNPeriod = None, 116 | combinedTestFrequency = 5, 117 | nbAverages = 1 118 | ) 119 | 120 | def offconf(numEX: Int, seed: Int) = 121 | OfflineRLConf(numEX, seed, 100, 10f) 122 | 123 | def fullconf(numEX: Int, seed: Int) = 124 | FullConf(nconf, deconf, offconf(numEX, seed), qconf, vconf) 125 | } 126 | 127 | 128 | 129 | object GameDeepAEConf { 130 | 131 | var gameL = 132 | 20 133 | 134 | val qconf = QConf( 135 | expRep = true, 136 | zeroImpossible = false, 137 | minPoolFactor = 30, 138 | maxPoolFactor = 35 139 | ) 140 | 141 | val vconf = VConf( 142 | lambda = 0.8f 143 | ) 144 | 145 | 146 | val nconf = NConf( 147 | learningRate = 0.005f, 148 | l1 = None, 149 | l2 = Some(0.005f), 150 | commonHeight = 0, 151 | headHeight = 3, 152 | commonWidth = 128, 153 | headWidth = 128, 154 | momentum = 0.95f, 155 | updater = RMSProp, 156 | activation = ReLu 157 | ) 158 | 159 | val deconf = DeepExplorationConf( 160 | nbHead = 1, 161 | targetNPeriod = None, 162 | combinedTestFrequency = 5, 163 | nbAverages = 1 164 | ) 165 | 166 | def offconf(numEX: Int, seed: Int) = 167 | OfflineRLConf(numEX, seed, 100, 1000f) 168 | 169 | def fullconf(numEX: Int, seed: Int) = 170 | FullConf(nconf, deconf, offconf(numEX, seed), qconf, vconf) 171 | } 172 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/rubenfiszel/scala-drl.svg?branch=master)](https://travis-ci.org/rubenfiszel/scala-drl) 2 | 3 | # Scala Deep Reinforcement Learning 4 | 5 | Source code for my [semester project](https://github.com/rubenfiszel/scala-drl/raw/master/report3.pdf) at the LAI (Laboratory of Artificial Intelligence) of EPFL 6 | 7 | General Markov Decision Process (MDP) deep reinforcement library on top of deeplearning4j. 8 | 9 | 10 | ## Quickstart 11 | 12 | ### Implement one of the MDP typeclass 13 | 14 | The library use the typeclass pattern: You use your own implementation of your MDP and write a typeclass implementation in the scope to make it usable by the library: 15 | 16 | From the most constrained to the less contrained: 17 | 18 | Randomizable is a subset of Valuable which is a subset of Statable 19 | 20 | Thus, the most easy to implement is Statable. 21 | 22 | * Statable enables Q-learning 23 | * Valuable enables TD-Lambda 24 | * Randomizable enables to use the Autoencoder for advanced features 25 | 26 | I assume that you have your own implementation of a MDP. Let's take for example 2048. 
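(The snippet below is the typeclass instance for the bundled 6561 game; a 2048 instance, also included in the library, follows the same pattern.)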
Here is a typeclass implementation: 27 | 28 | ```scala 29 | implicit object Game6561V extends Randomizable[Game6561] { 30 | 31 | type CAction = Move6561 32 | 33 | val allActions = Game6561.moves 34 | 35 | val zero = Game6561(Grid6561(3), 0, 0) 36 | 37 | def realizeTransition(g: Game6561, m: CAction) = { 38 | val ng = g.move(m).get 39 | (ng, ng.value - g.value) 40 | } 41 | 42 | def potentialStates(g: Game6561, a: A): IndexedSeq[(Game6561, Reward, Odd)] = { 43 | val (ng, rw) = realizeTransition(g, cAction(a)) 44 | IndexedSeq((ng, rw, 1f)) 45 | } 46 | 47 | def availableActions(g: Game6561) = 48 | g.availableMoveNext 49 | 50 | def value(g: Game6561) = 51 | g.value 52 | 53 | def heuristic(g: Game6561) = 54 | g.eval.toFloat 55 | 56 | def toInput(g: Game6561) = 57 | g.toInput 58 | 59 | def toString(g: Game6561) = 60 | g.toString 61 | 62 | def genRandom() = 63 | Game6561(Grid6561.random(Game6561Conf.gameL, 3), Rand.nextInt(Game6561Conf.gameL), 0) 64 | } 65 | ``` 66 | 67 | Valuable only requires: 68 | 69 | ```scala 70 | def value(state: S): Value 71 | 72 | def heuristic(state: S): Float 73 | 74 | def potentialStates(state: S, action: A): IndexedSeq[(S, Reward, Odd)] 75 | ``` 76 | 77 | then to apply Q-learning: 78 | 79 | ```scala 80 | import drl.Rand 81 | import drl.backend._ 82 | import drl.mdp.Game2048._ 83 | 84 | Rand.setSeed(seed) 85 | 86 | val nconf:NConf = ... 87 | val deconf: DeepExplorationConf = ... 88 | val offrlconf: OfflineRLConf = ... 89 | val qconf: QConf = ... 90 | 91 | SelfPlay.trainModelRLDeepQ[Game2048, SeparableCompGraph](qconf, scala.Left(nconf), deconf, offrlconf) 92 | ``` 93 | 94 | ## MDP included 95 | 96 | * 2048 97 | * 6561 98 | * Chain MDP 99 | 100 | Easy to add any MDP through the typeclass pattern. 101 | 102 | 103 | ## Current features 104 | 105 | * DQN (Q-learning) 106 | * TD-Lambda 107 | * Deep exploration through novelty incentivising 108 | * Deep exploration through bootstrapping DQN 109 | * Monte-Carlo Search Tree 110 | * Minimax and expectimax 111 | * Deeplearning4j backend 112 | 113 | 114 | 115 | ## MIT License 116 | 117 | Copyright (c) 2016 Ruben Fiszel 118 | 119 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 120 | 121 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 122 | 123 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
124 | -------------------------------------------------------------------------------- /src/main/scala/backend/BuildNN.scala: -------------------------------------------------------------------------------- 1 | package drl.backend 2 | 3 | 4 | import java.nio.file.{Files, Paths} 5 | import org.deeplearning4j.nn.conf.inputs.InputType 6 | import org.deeplearning4j.optimize.listeners.ScoreIterationListener 7 | import org.deeplearning4j.nn.graph.ComputationGraph 8 | import org.deeplearning4j.nn.conf.ComputationGraphConfiguration 9 | import org.apache.commons.io.FileUtils 10 | import org.deeplearning4j.datasets.iterator.DataSetIterator 11 | import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator 12 | import org.deeplearning4j.eval.Evaluation 13 | import org.deeplearning4j.nn.api.{Layer, OptimizationAlgorithm} 14 | import org.deeplearning4j.nn.conf.layers._ 15 | import org.deeplearning4j.nn.conf.{MultiLayerConfiguration, NeuralNetConfiguration, Updater} 16 | import org.deeplearning4j.nn.multilayer.MultiLayerNetwork 17 | import org.deeplearning4j.nn.params.DefaultParamInitializer 18 | import org.deeplearning4j.nn.weights.WeightInit 19 | import org.nd4j.linalg.api.ndarray.INDArray 20 | import org.nd4j.linalg.dataset.{DataSet, SplitTestAndTrain} 21 | import org.nd4j.linalg.factory.Nd4j 22 | import org.nd4j.linalg.lossfunctions.LossFunctions 23 | 24 | //import org.deeplearning4j.ui._ 25 | //import org.deeplearning4j.ui.weights._ 26 | import org.deeplearning4j.nn.api.Model; 27 | 28 | import drl.mdp.MDP._ 29 | import drl.backend.Backend._ 30 | 31 | object BuildNN { 32 | 33 | 34 | def buildOL(out: Int) = { 35 | new OutputLayer.Builder(LossFunctions.LossFunction.MSE) 36 | .nOut(out) 37 | .weightInit(WeightInit.XAVIER) 38 | .activation("identity") 39 | .build() 40 | } 41 | 42 | def buildOLID(out: Int) = { 43 | new OutputLayer.Builder(LossFunctions.LossFunction.MSE) 44 | // .nIn(out) 45 | .nOut(out) 46 | .weightInit(WeightInit.XAVIER) 47 | .learningRate(0f) 48 | .activation("identity") 49 | .build() 50 | } 51 | 52 | 53 | def buildDL(in:Int, out:Int, learningRate: Float):DenseLayer = { 54 | new DenseLayer.Builder() 55 | .nIn(in) 56 | .nOut(out) 57 | .weightInit(WeightInit.RELU) 58 | .learningRate(learningRate) 59 | .biasLearningRate(learningRate) 60 | .activation("relu") 61 | .build() 62 | } 63 | 64 | 65 | def buildAE[S: Statable]() = { 66 | 67 | val fs = implicitly[Statable[S]].featureSize 68 | val nccb = new NeuralNetConfiguration.Builder() 69 | val gb = nccb 70 | .learningRate(0.01) 71 | .iterations(1) //makebetter 72 | .regularization(true) 73 | .l2(0.05) 74 | .updater(Updater.NESTEROVS) 75 | .momentum(0.9f) 76 | .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) 77 | .seed(1234) 78 | .graphBuilder() 79 | .setInputTypes(InputType.feedForward(fs)) 80 | .addInputs("input") 81 | 82 | val layers = List(fs/2, fs/4, fs/2) 83 | 84 | gb.addLayer("L0", buildDL(fs, layers(0), 0.01f), "input") 85 | 86 | (1 until layers.length).foreach(x => 87 | gb.addLayer("L"+x, buildDL(layers(x-1), layers(x), 0.01f), "L"+(x-1)) 88 | ) 89 | 90 | val h = layers.length 91 | 92 | gb.addLayer("L"+h, buildDL(layers(h-1), fs, 0.01f), "L"+(h-1)) 93 | 94 | gb.addLayer("out2", buildOLID(layers(2)), "L2") 95 | val conf = gb.addLayer("out", buildOL(fs), "L"+h) 96 | .setOutputs("out", "out2") 97 | .build() 98 | 99 | println("AE") 100 | val cg = new ComputationGraph(conf) 101 | cg.init() 102 | cg.setListeners(new ScoreIterationListener(1)) 103 | new SingleCompGraph(cg) 104 | } 105 | 106 | 107 | def buildCG[S: 
Statable](conf: ConfNN, inputsize: Option[Int] = None) = { 108 | 109 | println("CG " + conf.nbHead) 110 | // val nccb = splitAndAdd(new NeuralNetConfiguration.Builder()) 111 | val lr = 112 | if (conf.nbHead > 1) 113 | 2*conf.learningRate 114 | else 115 | conf.learningRate 116 | 117 | val nccb = new NeuralNetConfiguration.Builder() 118 | val nconf = nccb 119 | // .learningRate(learningRate) 120 | .learningRate(lr) 121 | .iterations(1) //makebetter 122 | 123 | conf.l1.foreach(x => 124 | nconf 125 | .regularization(true) 126 | .l1(x) 127 | ) 128 | 129 | conf.l2.foreach(x => 130 | nconf 131 | .regularization(true) 132 | .l2(x) 133 | ) 134 | 135 | conf.updater match { 136 | case Nesterovs => nconf.updater(Updater.NESTEROVS) 137 | case RMSProp => nconf.updater(Updater.RMSPROP) 138 | } 139 | 140 | // .dropOut(0.2) 141 | // .updater(Updater.ADAGRAD) 142 | 143 | val gb = nconf 144 | .momentum(conf.momentum) 145 | .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) 146 | .seed(conf.seed) 147 | .graphBuilder() 148 | 149 | 150 | val fs = inputsize.getOrElse(implicitly[Statable[S]].featureSize) 151 | 152 | val os = 153 | conf.outputWidth 154 | // implicitly[Statable[S]].allActions.length 155 | 156 | println("FS: " + fs + " OS:" + os) 157 | gb 158 | .setInputTypes(InputType.feedForward(fs)) 159 | 160 | .addInputs("input") 161 | 162 | if (conf.commonHeight > 0) 163 | gb. 164 | addLayer("L1", buildDL(fs, conf.commonWidth, lr/conf.nbHead), "input") 165 | 166 | for (i <- 2 to conf.commonHeight) 167 | gb.addLayer("L"+i, buildDL(conf.commonWidth, conf.commonWidth, lr/conf.nbHead), "L"+(i-1)) 168 | 169 | if (conf.headHeight > 1) { 170 | 171 | for (i <- (1 to conf.nbHead)) 172 | if (conf.commonHeight > 0) 173 | gb.addLayer("LH"+i+"-1", buildDL(conf.commonWidth, conf.headWidth, lr), "L"+conf.commonHeight) 174 | else 175 | gb.addLayer("LH"+i+"-1", buildDL(fs, conf.headWidth, lr), "input") 176 | 177 | 178 | for { 179 | i <- 1 to conf.nbHead 180 | j <- 2 to conf.headHeight 181 | } 182 | gb.addLayer("LH"+i+"-"+j, buildDL(conf.headWidth, conf.headWidth, lr), "LH"+i+"-"+(j-1)) 183 | 184 | 185 | for (i <- (1 to conf.nbHead)) 186 | gb.addLayer("out"+i, buildOL(os), "LH"+i+"-"+conf.headHeight) 187 | 188 | } else { 189 | 190 | for (i <- (1 to conf.nbHead)) 191 | gb.addLayer("out"+i, buildOL(os), "L"+conf.commonHeight) 192 | } 193 | 194 | val outs = 195 | (1 to conf.nbHead).map(i => "out"+i) 196 | 197 | gb 198 | .setOutputs(outs:_*) 199 | .build() 200 | 201 | 202 | } 203 | 204 | 205 | } 206 | -------------------------------------------------------------------------------- /project/.ensime: -------------------------------------------------------------------------------- 1 | ( 2 | :root-dir "/home/atoll/codecup-ml/project" 3 | :cache-dir "/home/atoll/codecup-ml/project/.ensime_cache" 4 | :scala-compiler-jars ("/home/atoll/.sbt/boot/scala-2.10.6/lib/scala-compiler.jar" "/home/atoll/.sbt/boot/scala-2.10.6/lib/scala-library.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/scala-reflect-2.10.6.jar" "/home/atoll/.sbt/boot/scala-2.10.6/lib/scala-reflect.jar" "/home/atoll/.ivy2/cache/org.scala-lang/scalap/jars/scalap-2.10.6.jar") 5 | :name "codecup-ml-project" 6 | :java-home "/usr/lib/jvm/java-8-openjdk" 7 | :java-flags ("-Xss2m" "-Xms1024m" "-Xmx1024m" "-XX:ReservedCodeCacheSize=128m" "-XX:MaxMetaspaceSize=256m") 8 | :reference-source-roots ("/usr/lib/jvm/java-8-openjdk/src.zip") 9 | :scala-version "2.10.6" 10 | :compiler-args ("-feature" "-deprecation" "-Xlint" "-Yinline-warnings" 
"-Yno-adapted-args" "-Ywarn-dead-code" "-Ywarn-numeric-widen" "-Xfuture") 11 | :formatting-prefs nil 12 | :disable-source-monitoring nil 13 | :disable-class-monitoring nil 14 | :subprojects (( 15 | :name "codecup-ml-project" 16 | :source-roots ("/home/atoll/codecup-ml/project") 17 | :targets nil 18 | :test-targets nil 19 | :depends-on-modules nil 20 | :compile-deps ("/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/actions-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/api-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/apply-macro-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/cache-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/classfile-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/classpath-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/collections-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/command-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/compile-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/compiler-integration-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/compiler-interface-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/compiler-ivy-integration-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/completion-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/control-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/cross-0.13.11.jar" "/home/atoll/.ivy2/cache/scala_2.10/sbt_0.13/org.ensime/ensime-sbt/jars/ensime-sbt-0.4.0.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/incremental-compiler-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/xsbti/interface-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/io-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/ivy-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/ivy-2.3.0-sbt-2cc8d2761242b072cedb0a04cb39435c4fa24f9a.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/jansi-1.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/lib/jansi.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/jawn-parser_2.10-0.6.0.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/jline-2.13.jar" "/home/atoll/.sbt/boot/scala-2.10.6/lib/jline.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/jsch-0.1.46.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/json4s-ast_2.10-3.2.10.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/json4s-core_2.10-3.2.10.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/json4s-support_2.10-0.6.0.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/launcher-interface-1.0.0-M1.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/logging-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/logic-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/main-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/main-settings-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/paranamer-2.6.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/persist-0.13.11.jar" 
"/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/process-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/quasiquotes_2.10-2.0.1.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/relation-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/run-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/sbinary_2.10-0.4.2.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/sbt-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/lib/scala-compiler.jar" "/home/atoll/.sbt/boot/scala-2.10.6/lib/scala-library.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/scala-pickling_2.10-0.10.1.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/scala-reflect-2.10.6.jar" "/home/atoll/.sbt/boot/scala-2.10.6/lib/scala-reflect.jar" "/home/atoll/.ivy2/cache/org.scala-lang/scalap/jars/scalap-2.10.6.jar" "/home/atoll/.ivy2/cache/org.scalariform/scalariform_2.10/jars/scalariform_2.10-0.1.4.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/serialization_2.10-0.1.2.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/task-system-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/tasks-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/test-agent-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/test-interface-1.0.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/testing-0.13.11.jar" "/home/atoll/.sbt/boot/scala-2.10.6/org.scala-sbt/sbt/0.13.11/tracking-0.13.11.jar") 21 | :runtime-deps nil 22 | :test-deps nil 23 | :doc-jars nil 24 | :reference-source-roots ("/home/atoll/.ivy2/cache/org.scala-sbt/actions/srcs/actions-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/api/srcs/api-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/apply-macro/srcs/apply-macro-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/cache/srcs/cache-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/classfile/srcs/classfile-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/classpath/srcs/classpath-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/collections/srcs/collections-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/command/srcs/command-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/compile/srcs/compile-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/compiler-integration/srcs/compiler-integration-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/compiler-interface/srcs/compiler-interface-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/compiler-ivy-integration/srcs/compiler-ivy-integration-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/completion/srcs/completion-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/control/srcs/control-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/cross/srcs/cross-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/scala_2.10/sbt_0.13/org.ensime/ensime-sbt/srcs/ensime-sbt-0.4.0-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/incremental-compiler/srcs/incremental-compiler-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/interface/srcs/interface-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/io/srcs/io-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/ivy/srcs/ivy-0.13.11-sources.jar" 
"/home/atoll/.ivy2/cache/org.scala-sbt.ivy/ivy/srcs/ivy-2.3.0-sbt-2cc8d2761242b072cedb0a04cb39435c4fa24f9a-sources.jar" "/home/atoll/.ivy2/cache/org.fusesource.jansi/jansi/srcs/jansi-1.11-sources.jar" "/home/atoll/.ivy2/cache/org.spire-math/jawn-parser_2.10/srcs/jawn-parser_2.10-0.6.0-sources.jar" "/home/atoll/.ivy2/cache/jline/jline/srcs/jline-2.13-sources.jar" "/home/atoll/.ivy2/cache/com.jcraft/jsch/srcs/jsch-0.1.46-sources.jar" "/home/atoll/.ivy2/cache/org.json4s/json4s-ast_2.10/srcs/json4s-ast_2.10-3.2.10-sources.jar" "/home/atoll/.ivy2/cache/org.json4s/json4s-core_2.10/srcs/json4s-core_2.10-3.2.10-sources.jar" "/home/atoll/.ivy2/cache/org.spire-math/json4s-support_2.10/srcs/json4s-support_2.10-0.6.0-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/launcher-interface/srcs/launcher-interface-1.0.0-M1-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/logging/srcs/logging-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/logic/srcs/logic-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/main/srcs/main-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/main-settings/srcs/main-settings-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/com.thoughtworks.paranamer/paranamer/srcs/paranamer-2.6-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/persist/srcs/persist-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/process/srcs/process-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scalamacros/quasiquotes_2.10/srcs/quasiquotes_2.10-2.0.1-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/relation/srcs/relation-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/run/srcs/run-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-tools.sbinary/sbinary_2.10/srcs/sbinary_2.10-0.4.2-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/sbt/srcs/sbt-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-lang/scala-compiler/srcs/scala-compiler-2.10.6-sources.jar" "/home/atoll/.ivy2/cache/org.scala-lang/scala-library/srcs/scala-library-2.10.6-sources.jar" "/home/atoll/.ivy2/cache/org.scala-lang.modules/scala-pickling_2.10/srcs/scala-pickling_2.10-0.10.1-sources.jar" "/home/atoll/.ivy2/cache/org.scala-lang/scala-reflect/srcs/scala-reflect-2.10.6-sources.jar" "/home/atoll/.ivy2/cache/org.scala-lang/scalap/srcs/scalap-2.10.6-sources.jar" "/home/atoll/.ivy2/cache/org.scalariform/scalariform_2.10/srcs/scalariform_2.10-0.1.4-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/serialization_2.10/srcs/serialization_2.10-0.1.2-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/task-system/srcs/task-system-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/tasks/srcs/tasks-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/test-agent/srcs/test-agent-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/test-interface/srcs/test-interface-1.0-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/testing/srcs/testing-0.13.11-sources.jar" "/home/atoll/.ivy2/cache/org.scala-sbt/tracking/srcs/tracking-0.13.11-sources.jar"))) 25 | ) 26 | -------------------------------------------------------------------------------- /src/main/scala/mdp/Grid6561.scala: -------------------------------------------------------------------------------- 1 | package drl.mdp 2 | 3 | import drl._ 4 | 5 | sealed trait Color 6 | case object Blue extends Color 7 | case object Red extends Color 8 | case object Gray extends Color 9 | 10 | 11 | case class Piece(value: Int, color: Color, mult: Int) { 12 | 13 | def combine = 14 | copy(value = value*mult) 15 | 16 | def log = 17 
| Math.round(Math.log(value.toDouble)/Math.log(mult) + 1).toInt 18 | 19 | override def toString() = { 20 | val c = color match { 21 | case Blue => "B" 22 | case Red => "R" 23 | case Gray => "G" 24 | } 25 | c + log.toString 26 | } 27 | } 28 | 29 | object Grid6561 { 30 | 31 | def apply(mult: Int): GridIS6561 = 32 | GridIS6561(IndexedSeq.fill(4,4)(None), mult) 33 | 34 | def exponential(max:Int) = 35 | max-(Math.sqrt(Rand.nextFloat*(max+1)*(max+1))).floor 36 | 37 | def randomPiece6561(nbP:Int, dominant:Color, max:Int, mult: Int) = 38 | if (Rand.nextInt(16) < nbP) { 39 | val v = Math.pow(3, exponential(max)).toInt 40 | val c = 41 | if (Rand.nextFloat < 0.3) 42 | dominant 43 | else 44 | Rand.choose(Seq(Blue, Red, Gray)) 45 | 46 | Some(Piece(v, c, mult)) 47 | } 48 | else { 49 | None 50 | } 51 | 52 | def random(turn:Int, mult:Int): GridIS6561 = { 53 | val nbP = (Rand.nextInt(17)+Rand.nextInt(17))/2 54 | val c = Rand.choose(Seq(Blue, Red, Gray)) 55 | val max = (Math.log(turn/5)/Math.log(2)).toInt + 1 56 | val is = 57 | GridIS6561(IndexedSeq.fill(4,4)(randomPiece6561(nbP, c, max, mult)), mult) 58 | if (is.value > turn*4) 59 | random(turn, mult) 60 | else 61 | is 62 | } 63 | 64 | } 65 | 66 | /* 67 | grid coordinates expressed as yx: 68 | 69 | 11 12 13 13 70 | 21 22 23 24 71 | 31 32 33 34 72 | 41 42 43 44 73 | 74 | */ 75 | 76 | sealed trait Direction 77 | case object Up extends Direction 78 | case object Down extends Direction 79 | case object Right extends Direction 80 | case object Left extends Direction 81 | 82 | trait Grid6561 { 83 | 84 | def reward: Int 85 | def evalOpp: Double 86 | def eval: Double 87 | def value: Float 88 | def toInput2048: Array[Float] 89 | def toInput6561: Array[Float] 90 | def allSum: Array[Float] 91 | def place(x:Int, y: Int, piece: Piece): Option[Grid6561] 92 | def move(dir: Direction): (Grid6561, Float) 93 | def emptySpots: Seq[(Int, Int)] 94 | def get(x:Int, y:Int):Option[Piece] 95 | 96 | } 97 | 98 | 99 | 100 | case class GridIS6561(grid: IndexedSeq[IndexedSeq[Option[Piece]]], mult: Int) extends Grid6561 { 101 | 102 | def get(x: Int, y:Int) = 103 | grid(y)(x) 104 | 105 | def row(y:Int) = 106 | grid(y) 107 | 108 | def col(x:Int) = 109 | (0 to 3).map(y => grid(y)(x)).toIndexedSeq 110 | 111 | def all = 112 | (0 to 15).map(i => grid(i/4)(i%4)).filter(_.isDefined).map(_.get) 113 | 114 | def place(x:Int, y:Int, piece: Piece):Option[GridIS6561] = 115 | if (get(x, y) == None) 116 | Some(copy(grid = grid.updated(y, grid(y).updated(x, Some(piece))))) 117 | else 118 | None 119 | 120 | def move(dir: Direction = Up) = { 121 | 122 | val ar = Array.fill[Option[Piece]](4, 4)(None) 123 | 124 | val ns = 125 | (dir == Up || dir == Down) //NORTHSOUTH 126 | val dr = 127 | (dir == Down || dir == Right) //DOWNRIGHT 128 | val range = 129 | if (dr) 130 | (3 to 0 by -1) 131 | else 132 | (0 to 3) 133 | 134 | var merge = 0f 135 | 136 | for (i <- 0 to 3) { 137 | 138 | var bef:Option[Piece] = None 139 | var inc = 140 | if (dr) 141 | 4 142 | else 143 | -1 144 | 145 | for (j <- range) { 146 | 147 | val c = 148 | if (ns) 149 | grid(j)(i) 150 | else 151 | grid(i)(j) 152 | 153 | 154 | c match { 155 | case Some(x) if (bef.exists(_ == x)) => 156 | 157 | val comb = Some(x.combine) 158 | 159 | merge += comb.get.value 160 | 161 | if (ns) 162 | ar(inc)(i) = comb 163 | else 164 | ar(i)(inc) = comb 165 | 166 | bef = None 167 | 168 | 169 | case Some(x) if (bef.exists(_.value == x.value)) => 170 | if (ns) 171 | ar(inc)(i) = None 172 | else 173 | ar(i)(inc) = None 174 | 175 | if (dr) 176 | inc += 1 177 | else 178 | 
inc -= 1 179 | 180 | bef = None 181 | 182 | 183 | case sx@Some(x) => 184 | if (dr) 185 | inc -= 1 186 | else 187 | inc += 1 188 | 189 | if (ns) 190 | ar(inc)(i) = sx 191 | else 192 | ar(i)(inc) = sx 193 | 194 | bef = sx 195 | 196 | case None => () 197 | } 198 | } 199 | } 200 | (copy(grid = ar.map(_.toIndexedSeq).toIndexedSeq), merge) 201 | } 202 | 203 | 204 | def emptySpots = 205 | for { 206 | i <- 0 to 3 207 | j <- 0 to 3 if grid(j)(i) == None 208 | } yield (i, j) 209 | 210 | def value = 211 | grid.map(_.map(_.map(_.value).getOrElse(0)).sum).sum 212 | 213 | def edge = 214 | List(0, 3).map(row(_).map(_.map(_.value).getOrElse(0)).sum).sum + 215 | List(0, 3).map(col(_).map(_.map(_.value).getOrElse(0)).sum).sum 216 | 217 | 218 | def empty_squares = { 219 | grid.map(_.map(_.map(_.value.toFloat).getOrElse(0.5f)).sum).sum 220 | } 221 | 222 | def bestSum = { 223 | allSum.max 224 | } 225 | 226 | lazy val groupBycolor = 227 | (Piece(0, Red, mult)::Piece(0, Blue, mult)::Piece(0, Gray, mult)::all.toList).groupBy(_.color).map(x => (x._1, x._2.map(_.value).sum)) 228 | 229 | def allSum = { 230 | val a = groupBycolor.map(_._2) 231 | a.map(_.toFloat).toArray 232 | // a.map(x => if(x==a.max) x else 0f) 233 | // (l(0), l(1), l(2)) 234 | } 235 | 236 | def countMerge(xs:List[Piece], c: Color):Int = { 237 | lazy val (a, b) = (xs.head, xs.tail.head) 238 | if (xs.length < 2) 0 239 | else if (a == b && a.color == c) 2 + countMerge(xs.tail.tail, c) 240 | else countMerge(xs.tail, c) 241 | } 242 | 243 | def countBadMerge(xs:List[Piece], c: Color):Int = { 244 | lazy val (a, b) = (xs.head, xs.tail.head) 245 | if (xs.length < 2) 0 246 | else if (a.value == b.value && a.color == c && a.color == b.color ) 2 + countBadMerge(xs.tail.tail, c) 247 | else countBadMerge(xs.tail, c) 248 | } 249 | 250 | 251 | def countDestroy(xs:List[Piece], c: Color):Int = { 252 | lazy val (a, b) = (xs.head, xs.tail.head) 253 | if (xs.length < 2) 0 254 | else if (a.value == b.value && (a.color == c || b.color == c) && a.color != b.color ) 2 + countDestroy(xs.tail.tail, c) 255 | else countDestroy(xs.tail, c) 256 | } 257 | 258 | def countBadDestroy(xs:List[Piece], c: Color):Int = { 259 | lazy val (a, b) = (xs.head, xs.tail.head) 260 | if (xs.length < 2) 0 261 | else if (a.value == b.value && (a.color != c && b.color != c) && a.color != b.color ) 2 + countBadDestroy(xs.tail.tail, c) 262 | else countBadDestroy(xs.tail, c) 263 | } 264 | 265 | val p = Array(0, 1, 8, 27, 64, 125, 217, 343) 266 | 267 | def countMonoton(xs:List[Option[Piece]], c: Color):Int = { 268 | lazy val (a, b) = (xs.head.getOrElse(Piece(0, c, mult)), xs.tail.head) 269 | if (xs.length < 2) 0 270 | else if (b.isDefined && b.get.color == c && a.color == c && b.get.value > a.value) { 271 | p(b.get.log) - p(a.log) + countMonoton(xs.tail, c) 272 | } 273 | else countMonoton(xs.tail, c) 274 | } 275 | 276 | def merges(c: Color) = { 277 | val l = ((0 to 3).map(row).toList ::: (0 to 3).map(col).toList).map(_.foldLeft(List[Piece]())((acc, pos) => pos.map(_::acc).getOrElse(acc))).map(x => ((countMerge(x, c), countBadMerge(x, c)), (countDestroy(x, c), countBadDestroy(x, c)))).unzip 278 | 279 | val m = ((0 to 3).map(row).toList ::: (0 to 3).map(col).toList).map(_.toList).map(x => countMonoton(x, c).min(countMonoton(x.reverse, c))) 280 | 281 | 282 | val l1 = l._1.unzip 283 | val l2 = l._2.unzip 284 | (l1._1.sum, l1._2.sum, l2._1.sum, l2._2.sum, m.sum) 285 | } 286 | def empties = { 287 | emptySpots.size 288 | } 289 | 290 | def reward = 291 | empties + bestSum.toInt 292 | 293 | def heuristic = { 
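// Hand-tuned static evaluation: a weighted sum that rewards empty squares, merges of the
// dominant color, the dominant color's total sum, and destroys among the other colors,
// while penalizing destroys that involve the dominant color, the other colors' total sum,
// and bad merges; the monotonicity term (merge._5) is computed but left out of the score below.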
294 | 295 | 296 | val EVAL_EMPTY_WEIGHT = 1.0; 297 | val EVAL_MERGE_WEIGHT = 0.5; 298 | val EVAL_SUM_WEIGHT = 2.0; 299 | val EVAL_DESTROYS_WEIGHT = 0.75; 300 | val EVAL_MONOTONOCITY_WEIGHT = 0.001; 301 | val EVAL_BAD_MERGE_WEIGHT = 0.5; 302 | val EVAL_BAD_SUM_WEIGHT = 0.1; 303 | val EVAL_BAD_DESTROYS_WEIGHT = 0.2; 304 | 305 | 306 | val bestSum = allSum.max 307 | val bestColor = groupBycolor.filter(_._2 == bestSum).head._1 308 | val badSum = allSum.sum - bestSum 309 | val merge = merges(bestColor) 310 | 311 | EVAL_EMPTY_WEIGHT * empties + 312 | EVAL_MERGE_WEIGHT * merge._1 + 313 | EVAL_SUM_WEIGHT * bestSum + 314 | EVAL_BAD_DESTROYS_WEIGHT * merge._4 - 315 | EVAL_DESTROYS_WEIGHT * merge._3 - 316 | EVAL_BAD_SUM_WEIGHT * badSum - 317 | EVAL_BAD_MERGE_WEIGHT * merge._2 318 | // EVAL_MONOTONOCITY_WEIGHT * merge._5 319 | 320 | } 321 | 322 | def parameters = { 323 | val bestSum = allSum.max 324 | val bestColor = groupBycolor.filter(_._2 == bestSum).head._1 325 | val badSum = allSum.sum - bestSum 326 | val merge = merges(bestColor) 327 | 328 | Array(empties, bestSum, badSum, merge._1, merge._2, merge._3, merge._4, merge._5).map(_.toFloat) 329 | } 330 | 331 | def eval = 332 | heuristic 333 | 334 | def evalOpp = 335 | bestSum 336 | 337 | 338 | def cellToInputColor(opt: Option[Piece], color:Color):Float = 339 | opt.map(x => { 340 | if (color == x.color) 341 | x.value.toFloat 342 | else 343 | 0f 344 | }).getOrElse(0f) 345 | 346 | def cellToInput(opt: Option[Piece]):List[Float] = 347 | opt.map(x => { 348 | val value = x.value.toFloat 349 | x.color match { 350 | case Blue => List(value, 0f, 0f) 351 | case Red => List(0f, value, 0f) 352 | case Gray => List(0f, 0f, value) 353 | } 354 | }).getOrElse(List(0f, 0f, 0f)) 355 | 356 | 357 | def pow(i:Int, j:Int) = BigInt(i).pow(j).intValue 358 | 359 | lazy val toInput6561: Array[Float] = { 360 | val clrs = for { 361 | c <- List(Blue, Red, Gray) 362 | i <- 0 to 3 363 | j <- 0 to 3 364 | } yield { 365 | if (grid(i)(j).exists(_.color == c)) 366 | 1f 367 | else 368 | 0f 369 | } 370 | val inp = for { 371 | c <- List(Blue, Red, Gray) 372 | v <- 0 to Game6561Conf.maxTileValue 373 | i <- 0 to 3 374 | j <- 0 to 3 375 | } yield { 376 | if (grid(i)(j).equals(Some(Piece(pow(3, v), c, mult)))) 377 | 1f 378 | else { 379 | 0f 380 | } 381 | } 382 | (clrs.toList:::inp.toList).toArray 383 | } 384 | 385 | 386 | lazy val toInput2048: Array[Float] = { 387 | val inp = for { 388 | c <- List(Red) 389 | v <- 0 to Game2048Conf.maxTileValue 390 | i <- 0 to 3 391 | j <- 0 to 3 392 | } yield { 393 | if (grid(i)(j).equals(Some(Piece(pow(2, v), c, mult)))) 394 | 1f 395 | else { 396 | 0f 397 | } 398 | 399 | } 400 | (inp.toArray) 401 | } 402 | 403 | 404 | override def toString() = { 405 | val ESP = 3 406 | var str = "" 407 | str += "value: " + value + "\n" 408 | for (y <- 0 to 3) { 409 | str += "| " 410 | for (x <- 0 to 3) { 411 | val s = get(x, y).map(_.toString).getOrElse("_") 412 | val l = s.length/2.0 413 | str += (" "*ESP).drop(l.floor.toInt) + s + (" "*ESP).drop(l.ceil.toInt) 414 | } 415 | str += " |\n" 416 | } 417 | str 418 | } 419 | } 420 | --------------------------------------------------------------------------------