├── README.md ├── cube.go ├── network.go ├── solve └── main.go ├── stats └── main.go └── train └── main.go /README.md: -------------------------------------------------------------------------------- 1 | # Abstract 2 | 3 | I am going to try to come up with God's algorithm for the Rubik's cube. 4 | 5 | # Results 6 | 7 | After ~10 hours of training on a K80, the model gets down to a loss of `1.356`. The stats program outputs stats like the following: 8 | 9 | ``` 10 | 1 moves: 100 % 11 | 2 moves: 100 % 12 | 3 moves: 100 % 13 | 4 moves: 100 % 14 | 5 moves: 100 % 15 | 6 moves: 97.45454545454545 % 16 | 7 moves: 85.89090909090909 % 17 | 8 moves: 67.34545454545454 % 18 | 9 moves: 41.45454545454545 % 19 | 10 moves: 23.854545454545452 % 20 | 11 moves: 10.181818181818182 % 21 | 12 moves: 3.854545454545454 % 22 | 13 moves: 2.4727272727272727 % 23 | 14 moves: 1.090909090909091 % 24 | 15 moves: 0.2181818181818182 % 25 | 16 moves: 0.14545454545454545 % 26 | ``` 27 | 28 | Results after more training pending. 29 | -------------------------------------------------------------------------------- /cube.go: -------------------------------------------------------------------------------- 1 | package godsalg 2 | 3 | import ( 4 | "math" 5 | "math/rand" 6 | 7 | "github.com/unixpickle/gocube" 8 | ) 9 | 10 | // RandomScramble generates a move-based scramble 11 | // of a certain length. 12 | // It returns the inverse of the last move, which doubles 13 | // as the first move of a valid solution. 
14 | func RandomScramble(length int) (*gocube.CubieCube, gocube.Move) {
15 | 	moves := allMoves()
16 | 	res := gocube.SolvedCubieCube()
17 | 	axis := -1
18 | 	// With length <= 0 the loop never runs, so the returned move is
19 | 	// just the inverse of move 0 and is not a meaningful solution move.
20 | 	lastMove := gocube.Move(0)
21 | 	for i := 0; i < length; i++ {
22 | 		move := moves[rand.Intn(len(moves))]
23 | 		// The candidate set is refilled only when the scramble moves to
24 | 		// a new axis (see moveAxis). Until then every face already
25 | 		// turned on the current axis stays excluded, which rules out
26 | 		// redundant sequences such as U U or U D U.
27 | 		if moveAxis(move) != axis {
28 | 			moves = allMoves()
29 | 			axis = moveAxis(move)
30 | 		}
31 | 		for j := 0; j < len(moves); j++ {
32 | 			if moves[j].Face() == move.Face() {
33 | 				// Swap-remove, then revisit index j because a new
34 | 				// element was swapped into it.
35 | 				moves[j] = moves[len(moves)-1]
36 | 				moves = moves[:len(moves)-1]
37 | 				j--
38 | 			}
39 | 		}
40 | 		res.Move(move)
41 | 		lastMove = move
42 | 	}
43 | 	return &res, lastMove.Inverse()
44 | }
45 |
46 | // cubeVectorSize is the length of the vectors returned by CubeVector:
47 | // 48 non-center stickers, each one-hot encoded over 6 colors.
48 | const cubeVectorSize = 8 * 6 * 6
49 |
50 | // CubeVector returns a vectorized representation of
51 | // the stickers of a cube.
52 | //
53 | // Every sticker except the face centers (index i%9 == 4) is one-hot
54 | // encoded over 6 colors, and each entry is standardized with a mean
55 | // of 1/6 and a standard deviation of sqrt(0.13937). The result has
56 | // cubeVectorSize entries.
57 | func CubeVector(c *gocube.CubieCube) []float64 {
58 | 	stickerCube := c.StickerCube()
59 | 	res := make([]float64, cubeVectorSize)
60 |
61 | 	mean := 1.0 / 6
62 | 	stddev := math.Sqrt(0.13937)
63 |
64 | 	var stickerIdx int
65 | 	for i, sticker := range stickerCube[:] {
66 | 		if i%9 == 4 {
67 | 			// Face centers do not move under face turns; skip them.
68 | 			continue
69 | 		}
70 | 		for j := 0; j < 6; j++ {
71 | 			// Assumes sticker colors are numbered 1 through 6.
72 | 			if j == sticker-1 {
73 | 				res[j+stickerIdx] = (1 - mean) / stddev
74 | 			} else {
75 | 				res[j+stickerIdx] = (0 - mean) / stddev
76 | 			}
77 | 		}
78 | 		stickerIdx += 6
79 | 	}
80 |
81 | 	return res
82 | }
83 |
84 | // allMoves returns a fresh slice of every move,
85 | // gocube.Move(0) through gocube.Move(NumMoves-1).
86 | func allMoves() []gocube.Move {
87 | 	res := make([]gocube.Move, NumMoves)
88 | 	for i := range res {
89 | 		res[i] = gocube.Move(i)
90 | 	}
91 | 	return res
92 | }
93 |
94 | // moveAxis groups faces into axes: faces 1 and 2 map to 0, 3 and 4
95 | // to 1, and 5 and 6 to 2 (assuming faces are numbered 1 through 6).
96 | func moveAxis(m gocube.Move) int {
97 | 	return (m.Face() - 1) / 2
98 | }
99 |
--------------------------------------------------------------------------------
/network.go:
--------------------------------------------------------------------------------
1 | package godsalg
2 |
3 | import (
4 | 	"log"
5 | 	"os"
6 |
7 | 	"github.com/unixpickle/anynet"
8 | 	"github.com/unixpickle/anynet/anymisc"
9 | 	"github.com/unixpickle/anyvec"
10 | 	"github.com/unixpickle/serializer"
11 | )
12 |
13 | const (
14 | 	// NumMoves is the number of moves the network chooses between,
15 | 	// i.e. the size of its output layer.
16 | 	NumMoves = 18
17 | )
18 |
19 | // CreateNetwork loads the network stored at path. If that fails, it
20 | // builds a new one: an input layer from cubeVectorSize to 1024 units
21 | // and 30 further 1024x1024 layers, each followed by SELU, then a
22 | // final layer onto NumMoves outputs followed by LogSoftmax.
23 | //
24 | // If path exists but cannot be loaded, the error is logged rather than
25 | // ignored, since callers such as the train command later save over it.
26 | func CreateNetwork(c anyvec.Creator, path string) anynet.Net {
27 | 	const (
28 | 		hiddenSize   = 1024
29 | 		hiddenLayers = 30
30 | 	)
31 |
32 | 	var net anynet.Net
33 | 	err := serializer.LoadAny(path, &net)
34 | 	if err == nil {
35 | 		log.Println("Using existing network.")
36 | 		return net
37 | 	}
38 | 	if _, statErr := os.Stat(path); statErr == nil {
39 | 		log.Printf("Warning: could not load existing network %s: %v", path, err)
40 | 	}
41 |
42 | 	log.Println("Creating new network...")
43 | 	res := anynet.Net{
44 | 		anynet.NewFC(c, cubeVectorSize, hiddenSize),
45 | 		&anymisc.SELU{},
46 | 	}
47 | 	for i := 0; i < hiddenLayers; i++ {
48 | 		res = append(res, anynet.NewFC(c, hiddenSize, hiddenSize), &anymisc.SELU{})
49 | 	}
50 | 	res = append(res, anynet.NewFC(c, hiddenSize, NumMoves), anynet.LogSoftmax)
51 | 	return res
52 | }
53 |
--------------------------------------------------------------------------------
/solve/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"fmt"
5 | 	"math/rand"
6 | 	"os"
7 |
8 | 	"github.com/unixpickle/anydiff"
9 | 	"github.com/unixpickle/anynet"
10 | 	_ "github.com/unixpickle/anyplugin"
11 | 	"github.com/unixpickle/anyvec"
12 | 	"github.com/unixpickle/anyvec/anyvec32"
13 | 	"github.com/unixpickle/gocube"
14 | 	"github.com/unixpickle/godsalg"
15 | 	"github.com/unixpickle/serializer"
16 | )
17 |
18 | const (
19 | 	// BatchSize is the number of rollouts sampled in parallel per attempt.
20 | 	BatchSize = 1000
21 |
22 | 	// MaxSolutionLength is the most moves a sampled solution may use.
23 | 	MaxSolutionLength = 21
24 | )
25 |
26 | // main loads the network named on the command line, reads a cube
27 | // state with gocube.InputStickerCube, and keeps sampling batches of
28 | // rollouts until one of them solves the cube.
29 | func main() {
30 | 	if len(os.Args) != 2 {
31 | 		fmt.Fprintln(os.Stderr, "Usage: solve NETWORK_FILE")
32 | 		os.Exit(1)
33 | 	}
34 | 	var net anynet.Net
35 | 	if err := serializer.LoadAny(os.Args[1], &net); err != nil {
36 | 		fmt.Fprintln(os.Stderr, "Failed to load:", err)
37 | 		os.Exit(1)
38 | 	}
39 |
40 | 	cube, err := gocube.InputStickerCube()
41 | 	if err != nil {
42 | 		fmt.Fprintln(os.Stderr, "Bad input:", err)
43 | 		os.Exit(1)
44 | 	}
45 | 	state, err := cube.CubieCube()
46 | 	if err != nil {
47 | 		fmt.Fprintln(os.Stderr, "Bad state:", err)
48 | 		os.Exit(1)
49 | 	}
50 |
51 | 	for attempt := 0; ; attempt++ {
52 | 		// nil means no rollout succeeded; an already-solved input
53 | 		// yields a non-nil empty solution.
54 | 		if solution := sampleSolution(*state, net); solution != nil {
55 | 			fmt.Println("Solution:", solution)
56 | 			return
57 | 		}
58 | 		fmt.Println("Attempt", attempt, "failed")
59 | 	}
60 | }
61 |
62 | // sampleSolution runs BatchSize independent rollouts from start. At
63 | // each step every rollout samples a move from the network's output
64 | // distribution. It returns the moves of the first rollout to reach
65 | // the solved state within MaxSolutionLength moves, a non-nil empty
66 | // slice if start is already solved, or nil if no rollout succeeds.
67 | func sampleSolution(start gocube.CubieCube, net anynet.Net) []gocube.Move {
68 | 	if start.Solved() {
69 | 		// Must be non-nil: the caller treats nil as failure and would
70 | 		// otherwise retry forever.
71 | 		return []gocube.Move{}
72 | 	}
73 |
74 | 	solutions := make([][]gocube.Move, BatchSize)
75 | 	states := make([]*gocube.CubieCube, BatchSize)
76 | 	for i := range states {
77 | 		c := start
78 | 		states[i] = &c
79 | 	}
80 | 	for step := 0; step < MaxSolutionLength; step++ {
81 | 		var inVec []float64
82 | 		for _, x := range states {
83 | 			inVec = append(inVec, godsalg.CubeVector(x)...)
84 | 		}
85 | 		inRes := anydiff.NewConst(
86 | 			anyvec32.MakeVectorData(anyvec32.MakeNumericList(inVec)),
87 | 		)
88 | 		output := net.Apply(inRes, BatchSize).Output()
89 | 		// Presumably log-probabilities (CreateNetwork ends in
90 | 		// LogSoftmax); exponentiate in place to get probabilities.
91 | 		anyvec.Exp(output)
92 | 		slice := output.Data().([]float32)
93 | 		for j := 0; j < BatchSize; j++ {
94 | 			part := slice[j*godsalg.NumMoves : (j+1)*godsalg.NumMoves]
95 | 			move := selectMoveVector(part)
96 | 			solutions[j] = append(solutions[j], move)
97 | 			states[j].Move(move)
98 | 			// Checked right after each move so that a solve on the
99 | 			// final allowed move is not lost.
100 | 			if states[j].Solved() {
101 | 				return solutions[j]
102 | 			}
103 | 		}
104 | 	}
105 | 	return nil
106 | }
107 |
108 | // selectMoveVector samples an index from vec, treating its entries
109 | // as (approximately normalized) probabilities.
110 | func selectMoveVector(vec []float32) gocube.Move {
111 | 	// Scale by the actual sum so float32 rounding in the softmax
112 | 	// cannot leave part of the range unassigned.
113 | 	var total float32
114 | 	for _, x := range vec {
115 | 		total += x
116 | 	}
117 | 	p := rand.Float32() * total
118 | 	for i, x := range vec {
119 | 		p -= x
120 | 		if p < 0 {
121 | 			return gocube.Move(i)
122 | 		}
123 | 	}
124 | 	// Only reachable through rounding at the top of the range: pick
125 | 	// the last move that has any probability instead of move 0.
126 | 	for i := len(vec) - 1; i >= 0; i-- {
127 | 		if vec[i] > 0 {
128 | 			return gocube.Move(i)
129 | 		}
130 | 	}
131 | 	return 0
132 | }
133 |
--------------------------------------------------------------------------------
/stats/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"fmt"
5 | 	"os"
6 |
7 | 	"github.com/unixpickle/anydiff"
8 | 	"github.com/unixpickle/anynet"
9 | 	_ "github.com/unixpickle/anyplugin"
10 | 	"github.com/unixpickle/anyvec"
11 | 	"github.com/unixpickle/anyvec/anyvec32"
12 | 	"github.com/unixpickle/gocube"
13 | 	"github.com/unixpickle/godsalg"
14 | 	"github.com/unixpickle/serializer"
15 | )
16 |
17 | const (
18 | 	// MaxMoves is the longest scramble evaluated; each round scrambles
19 | 	// one cube for every length from 1 to MaxMoves.
20 | 	MaxMoves = 16
21 |
22 | 	// MaxSolutionLength is the most moves the network may use per solve.
23 | 	MaxSolutionLength = 21
24 | )
25 |
26 | // main loads the network named on the command line and runs rounds of
27 | // greedy solves forever, printing the cumulative success rate for each
28 | // scramble length after every round.
29 | func main() {
30 | 	if len(os.Args) != 2 {
31 | 		die("Usage: stats NETWORK_FILE")
32 | 	}
33 | 	var net anynet.Net
34 | 	if err := serializer.LoadAny(os.Args[1], &net); err != nil {
35 | 		die("Failed to load network:", err)
36 | 	}
37 |
38 | 	histogram := make([]float64, MaxMoves)
39 | 	total := make([]float64, MaxMoves)
40 | 	for {
41 | 		for i, solved := range roundOfSolves(net) {
42 | 			if solved {
43 | 				histogram[i]++
44 | 			}
45 | 			total[i]++
46 | 		}
47 | 		for i := range histogram {
48 | 			pct := histogram[i] / total[i]
49 | 			fmt.Println(i+1, "moves:", pct*100, "%")
50 | 		}
51 | 	}
52 | }
53 |
54 | // roundOfSolves scrambles one cube for each length 1..MaxMoves and lets
55 | // the network solve them greedily (always playing its highest-scoring
56 | // move) for up to MaxSolutionLength moves. Element i of the result
57 | // reports whether the (i+1)-move scramble reached the solved state.
58 | func roundOfSolves(net anynet.Net) []bool {
59 | 	cubes := make([]*gocube.CubieCube, MaxMoves)
60 | 	for i := range cubes {
61 | 		cubes[i], _ = godsalg.RandomScramble(i + 1)
62 | 	}
63 | 	res := make([]bool, MaxMoves)
64 | 	// A cube counts as solved if it is solved at any point, including
65 | 	// after the final move.
66 | 	markSolved := func() {
67 | 		for j, c := range cubes {
68 | 			if c.Solved() {
69 | 				res[j] = true
70 | 			}
71 | 		}
72 | 	}
73 | 	markSolved()
74 |
75 | 	creator := anyvec32.CurrentCreator()
76 | 	for step := 0; step < MaxSolutionLength; step++ {
77 | 		var in []float64
78 | 		for _, c := range cubes {
79 | 			in = append(in, godsalg.CubeVector(c)...)
80 | 		}
81 | 		inRes := anydiff.NewConst(creator.MakeVectorData(creator.MakeNumericList(in)))
82 | 		out := net.Apply(inRes, MaxMoves).Output()
83 | 		for j := range cubes {
84 | 			subVec := out.Slice(godsalg.NumMoves*j, godsalg.NumMoves*(j+1))
85 | 			best := anyvec.MaxIndex(subVec)
86 | 			cubes[j].Move(gocube.Move(best))
87 | 		}
88 | 		markSolved()
89 | 	}
90 | 	return res
91 | }
92 |
93 | // die prints args to standard error and exits with status 1.
94 | func die(args ...interface{}) {
95 | 	fmt.Fprintln(os.Stderr, args...)
96 | 	os.Exit(1)
97 | }
98 |
--------------------------------------------------------------------------------
/train/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"flag"
5 | 	"fmt"
6 | 	"log"
7 | 	"math/rand"
8 | 	"os"
9 | 	"time"
10 |
11 | 	"github.com/unixpickle/anydiff"
12 | 	"github.com/unixpickle/anynet"
13 | 	"github.com/unixpickle/anynet/anyff"
14 | 	"github.com/unixpickle/anynet/anysgd"
15 | 	_ "github.com/unixpickle/anyplugin"
16 | 	"github.com/unixpickle/anyvec"
17 | 	"github.com/unixpickle/anyvec/anyvec32"
18 | 	"github.com/unixpickle/godsalg"
19 | 	"github.com/unixpickle/rip"
20 | 	"github.com/unixpickle/serializer"
21 | )
22 |
23 | const (
24 | 	// MoveCount is the length of each one-hot target vector. It must
25 | 	// equal the network's output size, so it is tied to godsalg.NumMoves.
26 | 	MoveCount = godsalg.NumMoves
27 | )
28 |
29 | // main trains the move-prediction network with Adam on random
30 | // scrambles until interrupted, then saves it to the -out file.
31 | func main() {
32 | 	rand.Seed(time.Now().UnixNano())
33 |
34 | 	var outFile string
35 | 	var batchSize int
36 | 	var stepSize float64
37 | 	var minMoves int
38 | 	var maxMoves int
39 |
40 | 	flag.StringVar(&outFile, "out", "out_net", "output network file")
41 | 	flag.IntVar(&batchSize, "batch", 1000, "SGD batch size")
42 | 	flag.Float64Var(&stepSize, "step", 1e-4, "SGD step size")
43 | 	flag.IntVar(&minMoves, "minmoves", 1, "minimum scramble moves")
44 | 	flag.IntVar(&maxMoves, "maxmoves", 16, "maximum scramble moves")
45 | 	flag.Parse()
46 |
47 | 	// A 0-move scramble has no meaningful target move, and
48 | 	// maxmoves < minmoves would make rand.Intn panic in Fetch.
49 | 	if batchSize < 1 || minMoves < 1 || maxMoves < minMoves {
50 | 		fmt.Fprintln(os.Stderr, "Invalid flags: need -batch >= 1 and 1 <= -minmoves <= -maxmoves")
51 | 		os.Exit(1)
52 | 	}
53 |
54 | 	c := anyvec32.CurrentCreator()
55 | 	net := godsalg.CreateNetwork(c, outFile)
56 |
57 | 	log.Println("Training...")
58 | 	t := &anyff.Trainer{
59 | 		Net:     net,
60 | 		Cost:    anynet.DotCost{},
61 | 		Params:  net.Parameters(),
62 | 		Average: true,
63 | 	}
64 |
65 | 	var iterNum int
66 | 	s := &anysgd.SGD{
67 | 		Fetcher: &Fetcher{
68 | 			Creator:  c,
69 | 			MinMoves: minMoves,
70 | 			MaxMoves: maxMoves,
71 | 		},
72 | 		Gradienter:  t,
73 | 		Transformer: &anysgd.Adam{},
74 | 		Samples:     anysgd.LengthSampleList(batchSize),
75 | 		Rater:       anysgd.ConstRater(stepSize),
76 | 		StatusFunc: func(b anysgd.Batch) {
77 | 			log.Printf("iter %d: cost=%v", iterNum, t.LastCost)
78 | 			iterNum++
79 | 		},
80 | 		BatchSize: batchSize,
81 | 	}
82 |
83 | 	log.Println("Press ctrl+c once to stop...")
84 | 	s.Run(rip.NewRIP().Chan())
85 |
86 | 	if err := serializer.SaveAny(outFile, net); err != nil {
87 | 		fmt.Fprintln(os.Stderr, "Save error:", err)
88 | 		os.Exit(1)
89 | 	}
90 | }
91 |
92 | // Fetcher produces training batches of random scrambles. Each sample
93 | // pairs the CubeVector of a scramble of MinMoves to MaxMoves moves with
94 | // a one-hot target selecting the first move of a solution (the inverse
95 | // of the scramble's last move, as returned by RandomScramble).
96 | type Fetcher struct {
97 | 	Creator  anyvec.Creator
98 | 	MinMoves int
99 | 	MaxMoves int
100 | }
101 |
102 | // Fetch builds a batch containing one random sample per element of s.
103 | // It returns an error instead of panicking if MaxMoves < MinMoves.
104 | func (f *Fetcher) Fetch(s anysgd.SampleList) (anysgd.Batch, error) {
105 | 	if f.MaxMoves < f.MinMoves {
106 | 		return nil, fmt.Errorf("invalid scramble range: MinMoves=%d > MaxMoves=%d", f.MinMoves, f.MaxMoves)
107 | 	}
108 | 	n := s.Len()
109 | 	var inVec []float64
110 | 	// One-hot targets: row i has a single 1 at the index of its move.
111 | 	outVec := make([]float64, n*MoveCount)
112 | 	for i := 0; i < n; i++ {
113 | 		moves := rand.Intn(f.MaxMoves-f.MinMoves+1) + f.MinMoves
114 | 		cube, first := godsalg.RandomScramble(moves)
115 | 		inVec = append(inVec, godsalg.CubeVector(cube)...)
116 | 		outVec[i*MoveCount+int(first)] = 1
117 | 	}
118 |
119 | 	return &anyff.Batch{
120 | 		Inputs: anydiff.NewConst(
121 | 			f.Creator.MakeVectorData(f.Creator.MakeNumericList(inVec)),
122 | 		),
123 | 		Outputs: anydiff.NewConst(
124 | 			f.Creator.MakeVectorData(f.Creator.MakeNumericList(outVec)),
125 | 		),
126 | 		Num: n,
127 | 	}, nil
128 | }
129 |
--------------------------------------------------------------------------------