├── README.md
├── char-rnn
│   └── main.go
├── hmm.go
├── lstm.go
├── markov.go
├── model.go
└── samples.go

/README.md:
--------------------------------------------------------------------------------
# char-rnn

I have recreated [karpathy/char-rnn](https://github.com/karpathy/char-rnn) with my own RNN package. It works fairly well, and I have used it to generate [some cool results](#example).

# Usage

First, gather a folder with a bunch of text files in it (or with one big text file in it). Let's call this `path/to/text`.

Now, install [Go](https://golang.org/doc/install) and set up a GOPATH. Once this is done, you are ready to install char-rnn itself:

```
$ go get -u github.com/unixpickle/char-rnn
$ cd $GOPATH/src/github.com/unixpickle/char-rnn/char-rnn
$ go build
```

This will generate an executable called `char-rnn` in your current directory.

## Installing with CUDA

If you have CUDA bindings set up (instructions [here](https://godoc.org/github.com/unixpickle/cuda#hdr-Building)), you can use CUDA like so:

```
$ go get -u -tags cuda github.com/unixpickle/char-rnn/...
$ cd $GOPATH/src/github.com/unixpickle/char-rnn/char-rnn
$ go build -tags cuda
```

## Training

You can train char-rnn on some data as follows:

```
$ ./char-rnn train lstm /path/to/lstm /path/to/sample/directory
2016/06/22 17:52:53 Loaded model from file.
2016/06/22 17:52:53 Training LSTM on 22308 samples...
2016/06/22 17:59:29 Epoch 0: cost=1857807.936353
...
```

You can set the `GOMAXPROCS` environment variable to specify how many CPU threads to use for training.

If the `/path/to/lstm` file already exists, it will be loaded as an LSTM and training will resume where it left off. Otherwise, a new LSTM will be created.

It may take a while to train the LSTM. On karpathy's [tinyshakespeare](https://github.com/karpathy/char-rnn/tree/6f9487a6fe5b420b7ca9afb0d7c078e37c1d1b4e/data/tinyshakespeare), it took my Intel NUC (quad-core i3 at 1.7GHz) roughly 18 hours to train reasonably well (although for much of that time I was only using a single CPU core).

To pause or stop training, press Ctrl+C exactly once. This will finish the current mini-batch and then terminate the program. Once the program has terminated, the trained LSTM will be saved to `/path/to/lstm`. **Note:** if you hit Ctrl+C more than once, the program will terminate without saving.
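
If you would rather drive training from your own Go program instead of the CLI, the sketch below mirrors what the `train` sub-command does (load samples, train an LSTM, serialize it to disk). This is an illustrative sketch based on the package API shown in the source files further down, not a separately documented entry point:

```
package main

import (
    "io/ioutil"
    "log"

    charrnn "github.com/unixpickle/char-rnn"
    "github.com/unixpickle/serializer"
)

func main() {
    // Load fixed-size text chunks from a directory of text files.
    samples := charrnn.ReadSampleList("path/to/text")

    // Create an LSTM model and apply its training flags.
    model := &charrnn.LSTM{}
    model.TrainingFlags().Parse([]string{"-hidden", "512", "-layers", "2"})

    // Train runs until the process receives Ctrl+C, just like the CLI.
    model.Train(samples)

    // Save the model so a later run can resume training or generate text.
    encoded, err := serializer.SerializeWithType(model)
    if err != nil {
        log.Fatal(err)
    }
    if err := ioutil.WriteFile("path/to/lstm", encoded, 0755); err != nil {
        log.Fatal(err)
    }
}
```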

## Generating text

Once you have trained an LSTM, you can use it to generate a block of text. You must decide how much text to generate (e.g. 1000 characters, like below):

```
$ ./char-rnn gen /path/to/lstm -length 1000
```

# Example

I ran a GRU on the output of `ls -l /usr/bin` and then generated some dir listings:

```
-rwxr-xr-x 35 root wheel 821 Aug 23 2015 iptab5.18
-r-xr-xr-x 1 root wheel 3659 Sep 28 2015 instmodse
-rwxr-xr-x 1 root wheel 75 Oct 25 2015 info3eal -> /System/Library/Frameworks/JavaVM.framework/Versions/Current/Commands/rmic
lrwxr-xr-x 1 root wheel 84 Oct 25 2015 javmap -> /System/Library/Frameworks/JavaVM.framework/Versions/Current/Commands/kchase
-rwxr-xr-x 1 root wheel 59576 Oct 17 2015 anplrac
-rwxr-xr-x 1 root wheel 77 Oct 25 2015 edbsc -> cling
-r-xr-xr-x 1 root wheel 18176 Oct 17 2015 nv5.16
-rwxr-xr-x 1 root wheel 17204 Aug 22 2015 pod2readme5.16
-rwxr-xr-x 35 root wheel 811 Aug 23 2015 lwp-download5.16
-r-xr-xr-x 1 root wheel 3573 Aug 22 2015 dbiprof5.18
-rwxr-xr-x 1 root wheel 23368 Oct 17 2015 enice
-rwxr-xr-x 1 root wheel 43 Oct 25 2015 jstat -> /System/Library/Frameworks/JavaVM.framework/Versions/Current/Commands/intext
-rwxr-xr-x 1 root wheel 77 Oct 25 2015 netalloc.5 -> ../../System/Library/Frameworks/Python.framework/Versions/-arwervim
lrwxr-xr-x 1 root wheel 82 Oct 25 2015 j0 -> vmeadsrad
-rwxr-xr-x 1 root wheel 18176 Oct 17 2015 gzeratex
-rwxr-xr-x 1 root wheel 1947 Aug 22 2015 config_data5.16
-rwxr-xr-x 1 root wheel 9151 Aug 23 2015 ifstroc5.16
-rwxr-xr-x 1 root wheel 2 Oct 25 2015 viaevketat-cvisthar -> 2toc2.6
```

--------------------------------------------------------------------------------
/char-rnn/main.go:
--------------------------------------------------------------------------------
package main

import (
    "fmt"
    "io/ioutil"
    "log"
    "math/rand"
    "os"
    "time"

    _ "github.com/unixpickle/anyplugin"
    charrnn "github.com/unixpickle/char-rnn"
    "github.com/unixpickle/serializer"
)

var Models = []charrnn.Model{&charrnn.LSTM{}, &charrnn.Markov{}, &charrnn.HMM{}}

const OutputPermissions = 0755

func main() {
    rand.Seed(time.Now().UnixNano())
    if len(os.Args) < 2 {
        dieUsage()
    }
    subCmd := os.Args[1]
    switch subCmd {
    case "train":
        trainCommand()
    case "gen":
        genCommand()
    case "help":
        helpCommand()
    default:
        dieUsage()
    }
}

func trainCommand() {
    if len(os.Args) < 5 {
        dieUsage()
    }

    modelFile := os.Args[3]

    model := modelForName(os.Args[2])
    samples := charrnn.ReadSampleList(os.Args[4])
    modelData, err := ioutil.ReadFile(modelFile)

    if err == nil {
        x, desErr := serializer.DeserializeWithType(modelData)
        if desErr != nil {
            fmt.Fprintln(os.Stderr, "Failed to deserialize model:", desErr)
            os.Exit(1)
        }
        var ok bool
        model, ok = x.(charrnn.Model)
        if !ok {
            fmt.Fprintf(os.Stderr, "Loaded type was not a model but a %T\n", x)
            os.Exit(1)
        }
        log.Println("Loaded model from file.")
    } else {
        log.Println("Created new model.")
    }

    model.TrainingFlags().Parse(os.Args[5:])
    model.Train(samples)

    encoded, err := serializer.SerializeWithType(model)
    if err != nil {
        fmt.Fprintln(os.Stderr, "Failed to serialize model:", err)
        os.Exit(1)
    }
    if err := ioutil.WriteFile(modelFile, encoded, OutputPermissions); err != nil {
        fmt.Fprintln(os.Stderr, "Failed to save:", err)
        os.Exit(1)
    }
}

func genCommand() {
    if len(os.Args) < 3 {
        dieUsage()
    }

    modelData, err := ioutil.ReadFile(os.Args[2])
    if err != nil {
        fmt.Fprintln(os.Stderr, "Failed to read model:", err)
        os.Exit(1)
    }

    x, err := serializer.DeserializeWithType(modelData)
    if err != nil {
        fmt.Fprintln(os.Stderr, err)
        os.Exit(1)
    }

    model, ok := x.(charrnn.Model)
    if !ok {
        fmt.Fprintf(os.Stderr, "Loaded type was not a model but a %T\n", x)
        os.Exit(1)
    }

    model.GenerationFlags().Parse(os.Args[3:])
    model.Generate()
}

func helpCommand() {
    if len(os.Args) != 3 {
        dieUsage()
    }
    m := modelForName(os.Args[2])
    fmt.Fprintf(os.Stderr, "Usage for training:\n\n")
    m.TrainingFlags().PrintDefaults()
    fmt.Fprintf(os.Stderr, "\nUsage for generation:\n\n")
    m.GenerationFlags().PrintDefaults()
}

func dieUsage() {
    fmt.Fprintln(os.Stderr, "Usage: char-rnn train <model> <model-file> <sample-dir> [args]\n"+
        "       char-rnn gen <model-file> [args]\n"+
        "       char-rnn help <model>\n\n"+
        "Available models:")
    for _, m := range Models {
        fmt.Fprintln(os.Stderr, "  "+m.Name())
    }
    fmt.Fprintln(os.Stderr, "\nEnvironment variables:")
    fmt.Fprintf(os.Stderr, "  TEXT_CHUNK_SIZE       chars per sample (default %d)\n",
        charrnn.TextChunkSize)
    fmt.Fprintln(os.Stderr, "  TEXT_CHUNK_HEAD_ONLY  only use heads of samples")
    fmt.Fprintln(os.Stderr)
    os.Exit(1)
}

func modelForName(name string) charrnn.Model {
    for _, m := range Models {
        if m.Name() == name {
            return m
        }
    }
    fmt.Fprintln(os.Stderr, "no such model: "+name)
    dieUsage()
    return nil
}
--------------------------------------------------------------------------------
/hmm.go:
--------------------------------------------------------------------------------
package charrnn

import (
    "bytes"
    "encoding/gob"
    "flag"
    "fmt"
    "log"
    "runtime"
    "sync"

    "github.com/unixpickle/anynet/anysgd"
    "github.com/unixpickle/essentials"
    "github.com/unixpickle/hmm"
    "github.com/unixpickle/rip"
    "github.com/unixpickle/serializer"
)

func init() {
    var h HMM
    serializer.RegisterTypedDeserializer(h.SerializerType(), DeserializeHMM)

    gob.Register(hmm.TabularEmitter{})
}

// HMM is a Model for a character-level hidden Markov
// model.
type HMM struct {
    HMM       *hmm.HMM
    NumStates int

    Validation float64
}

func DeserializeHMM(d []byte) (*HMM, error) {
    dec := gob.NewDecoder(bytes.NewReader(d))
    var res *HMM
    if err := dec.Decode(&res); err != nil {
        return nil, essentials.AddCtx("deserialize HMM", err)
    }
    return res, nil
}

func (h *HMM) Name() string {
    return "hmm"
}

func (h *HMM) TrainingFlags() *flag.FlagSet {
    f := flag.NewFlagSet("hmm", flag.ExitOnError)
    f.IntVar(&h.NumStates, "states", 200, "number of hidden states")
    f.Float64Var(&h.Validation, "validation", 0.1, "validation fraction")
    return f
}

func (h *HMM) GenerationFlags() *flag.FlagSet {
    return flag.NewFlagSet("hmm", flag.ExitOnError)
}

func (h *HMM) Train(s SampleList) {
    validation, training := anysgd.HashSplit(s, h.Validation)
    log.Printf("Training: %d samples (%d bytes)", training.Len(),
        training.(SampleList).Bytes())
    log.Printf("Validation: %d samples (%d bytes)", validation.Len(),
        validation.(SampleList).Bytes())

    if h.HMM == nil {
        h.initModel()
    }

    log.Println("Computing initial loss...")
    log.Printf("initial: train_loss=%f val_loss=%f", h.meanLoss(training),
        h.meanLoss(validation))

    log.Println("Training (press ctrl+c to terminate)...")
    r := rip.NewRIP()
    var iter int
    for !r.Done() {
        h.HMM = hmm.BaumWelch(h.HMM, h.samplesToChan(training), 0)
        log.Printf("iter %d: train_loss=%f val_loss=%f", iter,
            h.meanLoss(training), h.meanLoss(validation))
        iter++
    }
}

func (h *HMM) Generate() {
    _, seq := h.HMM.Sample(nil)
    for _, character := range seq {
        fmt.Print(string([]byte{character.(byte)}))
    }
    fmt.Println()
}

func (h *HMM) SerializerType() string {
    return "github.com/unixpickle/char-rnn.HMM"
}

func (h *HMM) Serialize() ([]byte, error) {
    var buf bytes.Buffer
    enc := gob.NewEncoder(&buf)
    if err := enc.Encode(h); err != nil {
        return nil, essentials.AddCtx("serialize HMM", err)
    }
    return buf.Bytes(), nil
}

func (h *HMM) initModel() {
    var states []hmm.State
    for i := 0; i < h.NumStates; i++ {
        states = append(states, i)
    }
    var obses []hmm.Obs
    for i := 0; i < 0x100; i++ {
        obses = append(obses, byte(i))
    }
    h.HMM = hmm.RandomHMM(nil, states, 0, obses)
}

func (h *HMM) meanLoss(samples anysgd.SampleList) float64 {
    var total float64
    var divisor int

    var lock sync.Mutex
    var wg sync.WaitGroup

    ch := h.samplesToChan(samples)
    for i := 0; i < runtime.GOMAXPROCS(0); i++ {
        wg.Add(1)
        go func() {
            for sample := range ch {
                loss := hmm.LogLikelihood(h.HMM, sample)

                lock.Lock()
                // Add 1 for the terminal symbol.
                divisor += len(sample) + 1
                total += loss
                lock.Unlock()
            }
            wg.Done()
        }()
    }

    wg.Wait()

    return total / float64(divisor)
}

func (h *HMM) samplesToChan(samples anysgd.SampleList) <-chan []hmm.Obs {
    res := make(chan []hmm.Obs, 1)
    go func() {
        for _, seq := range samples.(SampleList) {
            var obses []hmm.Obs
            for _, b := range seq {
                obses = append(obses, b)
            }
            res <- obses
        }
        close(res)
    }()
    return res
}
--------------------------------------------------------------------------------
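
The HMM model above hands the actual fitting to the `github.com/unixpickle/hmm` package. As a rough, self-contained sketch of that flow (assuming the hmm package behaves exactly as the calls in hmm.go suggest: `RandomHMM`, `BaumWelch` over a channel of observation sequences, and `Sample`), a toy two-state byte HMM could be fit to a couple of strings like this. The output is gibberish, but it exercises the same loop as `HMM.Train`:

```
package main

import (
    "fmt"

    "github.com/unixpickle/hmm"
)

func main() {
    // Two hidden states and one observation per possible byte value,
    // mirroring HMM.initModel above (just much smaller).
    states := []hmm.State{0, 1}
    var obses []hmm.Obs
    for i := 0; i < 0x100; i++ {
        obses = append(obses, byte(i))
    }
    model := hmm.RandomHMM(nil, states, 0, obses)

    corpus := [][]byte{[]byte("abab"), []byte("abba")}

    // A few Baum-Welch passes; each pass consumes a channel of
    // observation sequences, just like HMM.Train does.
    for i := 0; i < 10; i++ {
        ch := make(chan []hmm.Obs, len(corpus))
        for _, seq := range corpus {
            var obs []hmm.Obs
            for _, b := range seq {
                obs = append(obs, b)
            }
            ch <- obs
        }
        close(ch)
        model = hmm.BaumWelch(model, ch, 0)
    }

    // Draw one sequence from the fitted chain.
    _, seq := model.Sample(nil)
    for _, b := range seq {
        fmt.Print(string([]byte{b.(byte)}))
    }
    fmt.Println()
}
```
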
/lstm.go:
--------------------------------------------------------------------------------
package charrnn

import (
    "flag"
    "fmt"
    "log"
    "math"
    "math/rand"

    "github.com/unixpickle/anydiff/anyseq"
    "github.com/unixpickle/anynet"
    "github.com/unixpickle/anynet/anyrnn"
    "github.com/unixpickle/anynet/anys2s"
    "github.com/unixpickle/anynet/anysgd"
    "github.com/unixpickle/anyvec"
    "github.com/unixpickle/anyvec/anyvec32"
    "github.com/unixpickle/essentials"
    "github.com/unixpickle/lazyseq"
    "github.com/unixpickle/lazyseq/lazyrnn"
    "github.com/unixpickle/rip"
    "github.com/unixpickle/serializer"
)

func init() {
    var l LSTM
    serializer.RegisterTypedDeserializer(l.SerializerType(), DeserializeLSTM)
}

// LSTM is a Model for long short-term memory RNNs.
type LSTM struct {
    lstmTrainingFlags
    lstmGenerationFlags

    Block anyrnn.Block
}

func DeserializeLSTM(d []byte) (*LSTM, error) {
    var b anyrnn.Block
    if err := serializer.DeserializeAny(d, &b); err != nil {
        return nil, err
    }
    return &LSTM{Block: b}, nil
}

func (l *LSTM) Train(samples SampleList) {
    if l.Block == nil {
        l.createModel()
    }

    validation, training := anysgd.HashSplit(samples, l.Validation)

    t := &anys2s.Trainer{
        Func: func(s anyseq.Seq) anyseq.Seq {
            if l.LowMem {
                inSeq := lazyseq.Lazify(s)
                ival := int(math.Sqrt(float64(len(s.Output()))))
                ival = essentials.MaxInt(ival, 1)
                out := lazyrnn.FixedHSM(ival, true, inSeq, l.Block)
                return lazyseq.Unlazify(out)
            } else {
                return anyrnn.Map(s, l.Block)
            }
        },
        Cost:    anynet.DotCost{},
        Params:  l.Block.(anynet.Parameterizer).Parameters(),
        Average: true,
    }

    log.Printf("Training: %d samples (%d bytes)", training.Len(),
        training.(SampleList).Bytes())
    log.Printf("Validation: %d samples (%d bytes)", validation.Len(),
        validation.(SampleList).Bytes())

    var iter int
    sgd := &anysgd.SGD{
        Fetcher:     t,
        Gradienter:  t,
        Transformer: &anysgd.Adam{},
        Samples: &anys2s.SortSampleList{
            SortableSampleList: training.(SampleList),
            BatchSize:          l.SortBatch,
        },
        Rater: anysgd.ConstRater(l.StepSize),
        StatusFunc: func(b anysgd.Batch) {
            if validation.Len() == 0 {
                log.Printf("iter %d: cost=%v", iter, t.LastCost)
                iter++
                return
            }

            vSize := l.BatchSize
            if vSize > validation.Len() {
                vSize = validation.Len()
            }
            anysgd.Shuffle(validation)
            validationBatch, _ := t.Fetch(validation.Slice(0, vSize))
            v := anyvec.Sum(t.TotalCost(validationBatch.(*anys2s.Batch)).Output())

            log.Printf("iter %d: cost=%v validation=%v", iter, t.LastCost, v)
            iter++
        },
        BatchSize: l.BatchSize,
    }

    log.Println("Training (ctrl+c to stop)...")
    l.setDropout(true)
    defer l.setDropout(false)
    sgd.Run(rip.NewRIP().Chan())
}

func (l *LSTM) Generate() {
    state := l.Block.Start(1)

    last := oneHotAscii(0)
    seedBytes := []byte(l.Seed)
    for i := 0; i < l.Length; i++ {
        res := l.Block.Step(state, last)
        ch := sampleSoftmax(res.Output(), l.Temperature)
        // Force the seed text for the first len(seedBytes) steps.
        if i < len(seedBytes) {
            ch = int(seedBytes[i])
        }

        fmt.Print(string([]byte{byte(ch)}))

        v := make([]float32, CharCount)
        v[ch] = 1
        last = anyvec32.MakeVectorData(v)
        state = res.State()
    }

    fmt.Println()
}

func (l *LSTM) Name() string {
    return "lstm"
}

func (l *LSTM) SerializerType() string {
    return "github.com/unixpickle/char-rnn.LSTM"
}

func (l *LSTM) Serialize() ([]byte, error) {
    return serializer.SerializeAny(l.Block)
}

func (l *LSTM) createModel() {
    block := anyrnn.Stack{}
    inCount := CharCount
    scaler := anyvec32.MakeNumeric(16)
    for i := 0; i < l.Layers; i++ {
        lstm := anyrnn.NewLSTM(anyvec32.CurrentCreator(), inCount, l.Hidden)
        dropout := &anyrnn.LayerBlock{Layer: &anynet.Dropout{KeepProb: l.Dropout}}
        block = append(block, lstm.ScaleInWeights(scaler), dropout)
        inCount = l.Hidden
    }
    block = append(block, &anyrnn.LayerBlock{
        Layer: anynet.Net{
            anynet.NewFC(anyvec32.CurrentCreator(), inCount, CharCount),
            anynet.LogSoftmax,
        },
    })
    var size int
    for _, p := range block.Parameters() {
        size += p.Vector.Len()
    }
    log.Printf("Created LSTM with %d parameters.", size)
    l.Block = block
}

func (l *LSTM) setDropout(enabled bool) {
    for _, block := range l.Block.(anyrnn.Stack) {
        if block, ok := block.(*anyrnn.LayerBlock); ok {
            if do, ok := block.Layer.(*anynet.Dropout); ok {
                do.Enabled = enabled
            }
        }
    }
}

type lstmTrainingFlags struct {
    StepSize   float64
    Validation float64
    Dropout    float64
    Hidden     int
    Layers     int
    BatchSize  int
    SortBatch  int
    LowMem     bool
}

func (l *lstmTrainingFlags) TrainingFlags() *flag.FlagSet {
    res := flag.NewFlagSet("lstm", flag.ExitOnError)
    res.IntVar(&l.Hidden, "hidden", 512, "hidden neuron count")
    res.IntVar(&l.Layers, "layers", 2, "LSTM layer count")
    res.Float64Var(&l.StepSize, "step", 0.001, "step size")
    res.Float64Var(&l.Validation, "validation", 0.1, "validation fraction")
    res.Float64Var(&l.Dropout, "dropout", 0.6, "dropout remain probability")
    res.IntVar(&l.BatchSize, "batch", 32, "SGD batch size")
    res.IntVar(&l.SortBatch, "sortbatch", 128, "sample sort batch size")
    res.BoolVar(&l.LowMem, "lowmem", false, "use asymptotic memory saving algorithms")
    return res
}

type lstmGenerationFlags struct {
    Length      int
    Seed        string
    Temperature float64
}

func (l *lstmGenerationFlags) GenerationFlags() *flag.FlagSet {
    res := flag.NewFlagSet("lstm", flag.ExitOnError)
    res.IntVar(&l.Length, "length", 100, "generated string length")
    res.StringVar(&l.Seed, "seed", "", "text to start with")
    res.Float64Var(&l.Temperature, "temperature", 1, "softmax temperature")
    return res
}

func sampleSoftmax(vec anyvec.Vector, temp float64) int {
    scaled := vec.Copy()
    scaled.Scale(vec.Creator().MakeNumeric(1 / temp))
    anyvec.LogSoftmax(scaled, scaled.Len())

    p := rand.Float64()
    for i, x := range scaled.Data().([]float32) {
        p -= math.Exp(float64(x))
        if p < 0 {
            return i
        }
    }
    return CharCount - 1
}
--------------------------------------------------------------------------------
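
The `-temperature` generation flag only rescales the logits before sampling, as in `sampleSoftmax` above. The self-contained sketch below uses plain Go and `float64` slices instead of anyvec vectors (the `sample` helper is hypothetical) to show the same idea: divide by the temperature, apply a softmax, then pick an index by inverse CDF. Low temperatures concentrate the draws on the largest logit, while high temperatures flatten the distribution:

```
package main

import (
    "fmt"
    "math"
    "math/rand"
)

// sample draws an index from softmax(logits / temp).
func sample(logits []float64, temp float64) int {
    // Scale by 1/temp, with the usual max-subtraction for numeric stability.
    scaled := make([]float64, len(logits))
    maxVal := math.Inf(-1)
    for i, x := range logits {
        scaled[i] = x / temp
        if scaled[i] > maxVal {
            maxVal = scaled[i]
        }
    }
    var sum float64
    for i := range scaled {
        scaled[i] = math.Exp(scaled[i] - maxVal)
        sum += scaled[i]
    }
    // Inverse-CDF sampling, like the loop at the end of sampleSoftmax.
    p := rand.Float64() * sum
    for i, x := range scaled {
        p -= x
        if p < 0 {
            return i
        }
    }
    return len(logits) - 1
}

func main() {
    logits := []float64{2, 1, 0.5}
    for _, temp := range []float64{0.2, 1, 5} {
        counts := make([]int, len(logits))
        for i := 0; i < 10000; i++ {
            counts[sample(logits, temp)]++
        }
        fmt.Printf("temperature %.1f: %v\n", temp, counts)
    }
}
```
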
/markov.go:
--------------------------------------------------------------------------------
package charrnn

import (
    "encoding/json"
    "flag"
    "fmt"
    "log"
    "math"
    "math/rand"

    "github.com/unixpickle/anynet/anysgd"
    "github.com/unixpickle/serializer"
)

func init() {
    var m Markov
    serializer.RegisterTypedDeserializer(m.SerializerType(), DeserializeMarkov)
}

const entropySoftener = 1e-5

// Markov is a Model for a character-level Markov chain.
type Markov struct {
    Table   map[string]map[byte]float64
    History int

    Validation float64 `json:"-"`
}

func DeserializeMarkov(d []byte) (*Markov, error) {
    var res Markov
    if err := json.Unmarshal(d, &res); err != nil {
        return nil, err
    }
    return &res, nil
}

func (m *Markov) Name() string {
    return "markov"
}

func (m *Markov) TrainingFlags() *flag.FlagSet {
    f := flag.NewFlagSet("markov", flag.ExitOnError)
    f.IntVar(&m.History, "history", 3, "character history size")
    f.Float64Var(&m.Validation, "validation", 0.1, "validation fraction")
    return f
}

func (m *Markov) GenerationFlags() *flag.FlagSet {
    return flag.NewFlagSet("markov", flag.ExitOnError)
}

func (m *Markov) Train(s SampleList) {
    validation, training := anysgd.HashSplit(s, m.Validation)
    log.Printf("Training: %d samples (%d bytes)", training.Len(),
        training.(SampleList).Bytes())
    log.Printf("Validation: %d samples (%d bytes)", validation.Len(),
        validation.(SampleList).Bytes())

    m.Table = map[string]map[byte]float64{}
    totals := map[string]float64{}

    log.Println("Producing chain...")
    for _, sample := range training.(SampleList) {
        stateBytes := []byte{}
        // The appended 0 byte marks the end of the sample.
        for _, ch := range append(append([]byte{}, sample...), 0) {
            stateStr := string(stateBytes)
            if m.Table[stateStr] == nil {
                m.Table[stateStr] = map[byte]float64{}
            }
            m.Table[stateStr][ch]++
            totals[stateStr]++

            stateBytes = m.appendState(stateBytes, ch)
        }
    }

    log.Println("Normalizing chain...")
    for state, total := range totals {
        for k, v := range m.Table[state] {
            m.Table[state][k] = v / total
        }
    }

    log.Println("Computing cross-entropy...")

    log.Println("Training entropy:", m.averageEntropy(training.(SampleList)))
    log.Println("Validation entropy:", m.averageEntropy(validation.(SampleList)))
}

func (m *Markov) Generate() {
    state := []byte{}
    for {
        next := m.selectRandom(state)
        if next == 0 {
            break
        }
        fmt.Print(string([]byte{next}))
        state = m.appendState(state, next)
    }
    fmt.Println()
}

func (m *Markov) SerializerType() string {
    return "github.com/unixpickle/char-rnn.Markov"
}

func (m *Markov) Serialize() ([]byte, error) {
    return json.Marshal(m)
}

func (m *Markov) averageEntropy(s SampleList) float64 {
    var totalEntropy float64
    var charCount float64
    for _, sample := range s {
        totalEntropy += m.sampleEntropy(sample)
        charCount += float64(len(sample))
    }
    return totalEntropy / charCount
}

func (m *Markov) sampleEntropy(sample []byte) float64 {
    entropy := 0.0
    state := []byte{}
    for _, b := range sample {
        p := m.Table[string(state)][b]
        if p == 0 {
            p = entropySoftener
        }
        entropy += math.Log(p)
        state = m.appendState(state, b)
    }
    return -entropy
}

func (m *Markov) selectRandom(state []byte) byte {
    next := m.Table[string(state)]
    if len(next) == 0 {
        return 0
    }
    selection := rand.Float64()
    for b, prob := range next {
        selection -= prob
        if selection < 0 {
            return b
        }
    }
    return 0
}

func (m *Markov) appendState(state []byte, b byte) []byte {
    // Keep at most History bytes of context by dropping the oldest byte.
    state = append(state, b)
    if len(state) > m.History {
        copy(state, state[1:])
        state = state[:len(state)-1]
    }
    return state
}
--------------------------------------------------------------------------------
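
To make the table built by `Train` concrete, here is a tiny self-contained sketch (an illustrative example, not part of the package) of the same counting-and-normalizing scheme with a one-character history on the string "abab"; a zero byte marks the end of a sample, exactly as above:

```
package main

import "fmt"

func main() {
    table := map[string]map[byte]float64{}
    totals := map[string]float64{}

    sample := []byte("abab")
    state := []byte{}
    // Count transitions, appending the 0 byte as the end-of-sample marker.
    for _, ch := range append(append([]byte{}, sample...), 0) {
        key := string(state)
        if table[key] == nil {
            table[key] = map[byte]float64{}
        }
        table[key][ch]++
        totals[key]++
        // One-character history: keep only the latest byte.
        state = []byte{ch}
    }

    // Normalize counts into conditional probabilities.
    for key, total := range totals {
        for ch := range table[key] {
            table[key][ch] /= total
        }
    }

    // After "a" the chain always saw "b"; after "b" it saw "a" half the
    // time and the terminal 0 byte half the time.
    fmt.Printf("P(next | %q) = %v\n", "a", table["a"])
    fmt.Printf("P(next | %q) = %v\n", "b", table["b"])
}
```
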
/model.go:
--------------------------------------------------------------------------------
package charrnn

import (
    "flag"

    "github.com/unixpickle/serializer"
)

// A Model is a trainable language model for predicting
// characters in a string.
type Model interface {
    serializer.Serializer

    Name() string

    TrainingFlags() *flag.FlagSet
    GenerationFlags() *flag.FlagSet

    Train(samples SampleList)
    Generate()
}
--------------------------------------------------------------------------------
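
Every generator in this repository plugs into the CLI through this single interface. As an illustrative sketch (not a file in the repository), a trivial model that ignores its training data and emits random printable characters would only need the methods below before it could be listed in the `Models` slice in `char-rnn/main.go`:

```
package charrnn

import (
    "flag"
    "fmt"
    "math/rand"
)

// Uniform is a toy Model that ignores its training data and
// generates uniformly random printable ASCII characters.
type Uniform struct {
    Length int
}

func (u *Uniform) Name() string { return "uniform" }

func (u *Uniform) TrainingFlags() *flag.FlagSet {
    return flag.NewFlagSet("uniform", flag.ExitOnError)
}

func (u *Uniform) GenerationFlags() *flag.FlagSet {
    f := flag.NewFlagSet("uniform", flag.ExitOnError)
    f.IntVar(&u.Length, "length", 100, "generated string length")
    return f
}

// Train does nothing: this model has no parameters to fit.
func (u *Uniform) Train(samples SampleList) {}

func (u *Uniform) Generate() {
    for i := 0; i < u.Length; i++ {
        fmt.Print(string([]byte{byte(rand.Intn(0x5f) + 0x20)}))
    }
    fmt.Println()
}

func (u *Uniform) SerializerType() string {
    return "github.com/unixpickle/char-rnn.Uniform"
}

func (u *Uniform) Serialize() ([]byte, error) {
    return []byte{}, nil
}
```
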
/samples.go:
--------------------------------------------------------------------------------
package charrnn

import (
    "crypto/md5"
    "fmt"
    "io/ioutil"
    "os"
    "path/filepath"
    "strconv"
    "strings"

    "github.com/unixpickle/anynet/anys2s"
    "github.com/unixpickle/anynet/anysgd"
    "github.com/unixpickle/anyvec"
    "github.com/unixpickle/anyvec/anyvec32"
)

const (
    TextChunkSize = 1 << 10
    CharCount     = 256
)

type SampleList [][]byte

func ReadSampleList(dir string) SampleList {
    contents, err := ioutil.ReadDir(dir)
    if err != nil {
        fmt.Fprintln(os.Stderr, err)
        os.Exit(1)
    }

    var result SampleList

    chunkSize := TextChunkSize
    headOnly := false
    if csVar := os.Getenv("TEXT_CHUNK_SIZE"); csVar != "" {
        chunkSize, err = strconv.Atoi(csVar)
        if err != nil {
            fmt.Fprintln(os.Stderr, "Invalid TEXT_CHUNK_SIZE value:", csVar)
            os.Exit(1)
        }
    }
    if os.Getenv("TEXT_CHUNK_HEAD_ONLY") != "" {
        headOnly = true
    }

    for _, item := range contents {
        if strings.HasPrefix(item.Name(), ".") {
            continue
        }
        p := filepath.Join(dir, item.Name())
        textContents, err := ioutil.ReadFile(p)
        if err != nil {
            fmt.Fprintln(os.Stderr, err)
            os.Exit(1)
        }
        for i := 0; i < len(textContents); i += chunkSize {
            bs := chunkSize
            if bs > len(textContents)-i {
                bs = len(textContents) - i
            }
            result = append(result, textContents[i:i+bs])
            if headOnly {
                break
            }
        }
    }

    return result
}

func (s SampleList) Len() int {
    return len(s)
}

func (s SampleList) Swap(i, j int) {
    s[i], s[j] = s[j], s[i]
}

func (s SampleList) Slice(start, end int) anysgd.SampleList {
    return append(SampleList{}, s[start:end]...)
}

func (s SampleList) LenAt(idx int) int {
    return len(s[idx])
}

func (s SampleList) GetSample(idx int) (*anys2s.Sample, error) {
    return seqForChunk(s[idx]), nil
}

func (s SampleList) Creator() anyvec.Creator {
    return anyvec32.CurrentCreator()
}

func (s SampleList) Hash(idx int) []byte {
    res := md5.Sum(s[idx])
    return res[:]
}

// Bytes computes the total bytes across all samples.
func (s SampleList) Bytes() int {
    sum := 0
    for _, x := range s {
        sum += len(x)
    }
    return sum
}

// seqForChunk builds a next-character prediction sample: the input at each
// timestep is the previous byte (0 for the first step) and the output is
// the current byte.
func seqForChunk(chunk []byte) *anys2s.Sample {
    var res anys2s.Sample
    for i, x := range chunk {
        res.Output = append(res.Output, oneHotAscii(x))
        if i == 0 {
            res.Input = append(res.Input, oneHotAscii(0))
        } else {
            res.Input = append(res.Input, oneHotAscii(chunk[i-1]))
        }
    }
    return &res
}

func oneHotAscii(b byte) anyvec.Vector {
    res := make([]float32, CharCount)
    res[int(b)] = 1
    return anyvec32.MakeVectorData(res)
}
--------------------------------------------------------------------------------
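
The training targets come straight from this chunking: for each chunk, the network is fed the previous byte (a zero byte at the first step) and must predict the current byte, as in `seqForChunk` above. A small self-contained sketch of that offset with plain byte slices instead of one-hot vectors (the helper names here are made up for illustration):

```
package main

import "fmt"

// chunkText splits text into fixed-size chunks, like ReadSampleList.
func chunkText(text []byte, chunkSize int) [][]byte {
    var chunks [][]byte
    for i := 0; i < len(text); i += chunkSize {
        end := i + chunkSize
        if end > len(text) {
            end = len(text)
        }
        chunks = append(chunks, text[i:end])
    }
    return chunks
}

// inputsAndTargets mirrors seqForChunk: the input at step i is the byte at
// i-1 (0 for the first step) and the target is the byte at i.
func inputsAndTargets(chunk []byte) (inputs, targets []byte) {
    for i, b := range chunk {
        if i == 0 {
            inputs = append(inputs, 0)
        } else {
            inputs = append(inputs, chunk[i-1])
        }
        targets = append(targets, b)
    }
    return
}

func main() {
    for _, chunk := range chunkText([]byte("hello world"), 4) {
        in, out := inputsAndTargets(chunk)
        fmt.Printf("chunk %q: inputs %v -> targets %v\n", chunk, in, out)
    }
}
```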