├── README.md ├── activation.go ├── block.go ├── block_test.go ├── experiments ├── audioset │ ├── README.md │ ├── cuda.go │ ├── eval_classes.txt │ ├── main.go │ ├── model.go │ └── train.go ├── omniglot │ ├── README.md │ ├── accuracy.go │ ├── analysis.go │ ├── cuda.go │ ├── debug.go │ ├── log │ │ ├── 15_classes │ │ │ ├── lstm_log.txt │ │ │ └── sgdstore_log.txt │ │ ├── 5_classes │ │ │ ├── lstm_log.txt │ │ │ ├── sgdstore_log.txt │ │ │ └── vanilla_log.txt │ │ ├── lr_search │ │ │ ├── sgdstore_0001_log.txt │ │ │ ├── sgdstore_0003_log.txt │ │ │ └── sgdstore_001_log.txt │ │ └── mem_search │ │ │ ├── sgdstore_4rh_log.txt │ │ │ └── sgdstore_deep_log.txt │ ├── main.go │ ├── model.go │ ├── plot │ │ ├── log_to_matlab.sh │ │ ├── plot.png │ │ ├── plot15.png │ │ ├── plot_15.sh │ │ ├── plot_5.sh │ │ ├── plot_lr.png │ │ ├── plot_lr.sh │ │ ├── plot_mem.png │ │ ├── plot_mem.sh │ │ └── smooth_data.m │ ├── samples.go │ └── train.go └── poly_approx │ ├── README.md │ ├── log │ ├── lstm.txt │ ├── sgdstore_1step.txt │ ├── sgdstore_2step.txt │ └── sgdstore_3step.txt │ ├── main.go │ ├── model.go │ ├── plot │ ├── log_to_matlab.sh │ ├── plot.png │ ├── plot_file.sh │ └── smooth_data.m │ └── samples.go ├── net.go └── net_test.go /README.md: -------------------------------------------------------------------------------- 1 | # sgdstore 2 | 3 | This is a memory-augmented neural network that uses a neural network as a storage device. Particularly, a *controller network* provides training examples for a *storage network* at every timestep. The *storage network* is trained on these samples with SGD at every timestep (in a differentiable manner). The controller can then query the storage network by feeding it inputs and seeing the corresponding outputs. The end result is that the storage network serves as a memory bank which is "written to" via SGD. 4 | 5 | # Hypotheses 6 | 7 | Neural networks seem to provide a lot of desirable properties as memory modules: 8 | 9 | * They can store a lot of information. 10 | * They can compress information. 11 | * They can generalize to new information. 12 | * They can interpolate between training samples. 13 | 14 | In a sense, a neural network can be seen as a key-value store which tries to generalize to new keys. This seems like the perfect memory structure for a memory-augmented neural network. 15 | 16 | # Results 17 | 18 | Preliminary results on polynomial approximation looked promising. See [experiments/poly_approx](experiments/poly_approx). After seeing those results, I decided to scale up to a harder meta-learning task. 19 | 20 | On the Omniglot handwriting dataset, the model (controlled by a vanilla RNN) outperformed an LSTM in training time (measured in epochs) and accuracy. See [experiments/omniglot](experiments/omniglot). 21 | -------------------------------------------------------------------------------- /activation.go: -------------------------------------------------------------------------------- 1 | package sgdstore 2 | 3 | import ( 4 | "github.com/unixpickle/anydiff" 5 | "github.com/unixpickle/anynet" 6 | "github.com/unixpickle/anyvec" 7 | ) 8 | 9 | // Activation is an activation function. 10 | type Activation int 11 | 12 | // Supported activation functions. 13 | const ( 14 | Tanh Activation = iota 15 | ReLU 16 | ) 17 | 18 | // Forward applies the activation function in the forward 19 | // direction. 
20 | func (a Activation) Forward(in anydiff.Res) anydiff.Res { 21 | switch a { 22 | case Tanh: 23 | return anydiff.Tanh(in) 24 | case ReLU: 25 | return anydiff.ClipPos(in) 26 | } 27 | panic("unsupported activation") 28 | } 29 | 30 | // Backward applies backward propagation, given the output 31 | // from the forward pass and the upstream vector. 32 | func (a Activation) Backward(out, upstream anydiff.Res) anydiff.Res { 33 | switch a { 34 | case Tanh: 35 | return anydiff.Mul(anydiff.Complement(anydiff.Square(out)), upstream) 36 | case ReLU: 37 | mask := out.Output().Copy() 38 | anyvec.GreaterThan(mask, mask.Creator().MakeNumeric(0)) 39 | return anydiff.Mul(upstream, anydiff.NewConst(mask)) 40 | } 41 | panic("unsupported activation") 42 | } 43 | 44 | // Layer returns a compatible anynet.Layer. 45 | func (a Activation) Layer() anynet.Layer { 46 | switch a { 47 | case Tanh: 48 | return anynet.Tanh 49 | case ReLU: 50 | return anynet.ReLU 51 | } 52 | panic("unsupported activation") 53 | } 54 | -------------------------------------------------------------------------------- /block.go: -------------------------------------------------------------------------------- 1 | package sgdstore 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | 7 | "github.com/unixpickle/anydiff" 8 | "github.com/unixpickle/anynet" 9 | "github.com/unixpickle/anynet/anyrnn" 10 | "github.com/unixpickle/anyvec" 11 | "github.com/unixpickle/anyvec/anyvecsave" 12 | "github.com/unixpickle/essentials" 13 | "github.com/unixpickle/serializer" 14 | ) 15 | 16 | func init() { 17 | serializer.RegisterTypedDeserializer((&Block{}).SerializerType(), DeserializeBlock) 18 | } 19 | 20 | // Block is an RNN block that uses a Net as its memory. 21 | type Block struct { 22 | InitParams []*anydiff.Var 23 | Activation Activation 24 | 25 | // Gates which transform the input into various vectors 26 | // used to train and query the current Net. 27 | TrainInput anynet.Layer 28 | TrainTarget anynet.Layer 29 | StepSize anynet.Layer 30 | Query anynet.Layer 31 | 32 | // Steps is the number of SGD steps to take at each 33 | // timestep. 34 | Steps int 35 | } 36 | 37 | // LinearBlock creates a Block with linear gates. 38 | // 39 | // The blockIn argument specifies the input size for the 40 | // block. 41 | // 42 | // The trainBatch and queryBatch arguments specify the 43 | // batch sizes to use for training and querying, 44 | // respectively. 45 | // 46 | // The numSteps argument specifies the number of training 47 | // steps to take at each timestep. 48 | // The lrBias argument specifies the approximate initial 49 | // learning rate (i.e. step size). 50 | // 51 | // The layerSizes specify the sizes of the layers at every 52 | // point in the network. 53 | // The first layer corresponds to the network's input, and 54 | // the last to the network's output. 55 | // Thus, there must be at least two layer sizes. 
56 | // 57 | // The block's output size can be computed as 58 | // 59 | // queryBatch * layerSizes[len(layerSizes)-1] 60 | // 61 | func LinearBlock(c anyvec.Creator, blockIn, trainBatch, queryBatch, numSteps int, 62 | lrBias float64, activation Activation, layerSizes ...int) *Block { 63 | if len(layerSizes) < 2 { 64 | panic("not enough layer sizes") 65 | } else if trainBatch < 1 || queryBatch < 1 { 66 | panic("invalid batch size") 67 | } 68 | 69 | res := &Block{ 70 | TrainInput: anynet.NewFC(c, blockIn, trainBatch*layerSizes[0]), 71 | TrainTarget: anynet.Net{ 72 | anynet.NewFC(c, blockIn, trainBatch*layerSizes[len(layerSizes)-1]), 73 | activation.Layer(), 74 | }, 75 | StepSize: anynet.Net{ 76 | anynet.NewFC(c, blockIn, 1).AddBias(c.MakeNumeric(math.Log(lrBias))), 77 | anynet.Exp, 78 | }, 79 | Query: anynet.NewFC(c, blockIn, queryBatch*layerSizes[0]), 80 | Steps: numSteps, 81 | Activation: activation, 82 | } 83 | 84 | layerSize := layerSizes[0] 85 | for _, size := range layerSizes[1:] { 86 | fc := anynet.NewFC(c, layerSize, size) 87 | layerSize = size 88 | res.InitParams = append(res.InitParams, fc.Parameters()...) 89 | } 90 | 91 | return res 92 | } 93 | 94 | // DeserializeBlock deserializes a Block. 95 | func DeserializeBlock(d []byte) (block *Block, err error) { 96 | defer essentials.AddCtxTo("deserialize sgdstore.Block", &err) 97 | var vecData []byte 98 | block = &Block{} 99 | err = serializer.DeserializeAny(d, &vecData, &block.TrainInput, &block.TrainTarget, 100 | &block.StepSize, &block.Query, &block.Steps) 101 | if err != nil { 102 | return nil, err 103 | } 104 | savedVecs, err := serializer.DeserializeSlice(vecData) 105 | if err != nil { 106 | return nil, err 107 | } 108 | for _, vecObj := range savedVecs { 109 | if vec, ok := vecObj.(*anyvecsave.S); ok { 110 | block.InitParams = append(block.InitParams, anydiff.NewVar(vec.Vector)) 111 | } else { 112 | return nil, fmt.Errorf("expected vector but got %T", vecObj) 113 | } 114 | } 115 | return 116 | } 117 | 118 | // Start produces a start state. 119 | func (b *Block) Start(n int) anyrnn.State { 120 | res := &State{Params: make([]*anyrnn.VecState, len(b.InitParams))} 121 | for i, p := range b.InitParams { 122 | res.Params[i] = anyrnn.NewVecState(p.Vector, n) 123 | } 124 | return res 125 | } 126 | 127 | // PropagateStart propagates through the start state. 128 | func (b *Block) PropagateStart(s anyrnn.StateGrad, g anydiff.Grad) { 129 | state := s.(*State) 130 | for i, paramVar := range b.InitParams { 131 | state.Params[i].PropagateStart(paramVar, g) 132 | } 133 | } 134 | 135 | // Step evaluates the block. 136 | func (b *Block) Step(s anyrnn.State, in anyvec.Vector) anyrnn.Res { 137 | state := s.(*State) 138 | inPool := anydiff.NewVar(in) 139 | netPool := state.pool() 140 | present := state.Present() 141 | n := present.NumPresent() 142 | gateOuts := b.applyGates(inPool, n) 143 | 144 | allRes := anydiff.PoolMulti(gateOuts, func(gateOuts []anydiff.Res) anydiff.MultiRes { 145 | poolReses := make([]anydiff.Res, len(netPool)) 146 | for i, x := range netPool { 147 | poolReses[i] = x 148 | } 149 | net := &Net{ 150 | Parameters: anydiff.Fuse(poolReses...), 151 | Num: n, 152 | Activation: b.Activation, 153 | } 154 | batchSize := gateOuts[0].Output().Len() / (net.InSize() * n) 155 | newNet := net.Train(gateOuts[0], gateOuts[1], gateOuts[2], batchSize, b.Steps) 156 | return anydiff.PoolMulti(newNet.Parameters, 157 | func(newParams []anydiff.Res) anydiff.MultiRes { 158 | net1 := *newNet 159 | net1.Parameters = anydiff.Fuse(newParams...) 
160 | batchSize := gateOuts[3].Output().Len() / (net.InSize() * n)
161 | applied := net1.Apply(gateOuts[3], batchSize)
162 | newReses := append([]anydiff.Res{applied}, newParams...)
163 | return anydiff.Fuse(newReses...)
164 | })
165 | })
166 |
167 | newState := &State{
168 | Params: make([]*anyrnn.VecState, len(state.Params)),
169 | }
170 | for i, newVec := range allRes.Outputs()[1:] {
171 | newState.Params[i] = &anyrnn.VecState{
172 | PresentMap: state.Params[i].PresentMap,
173 | Vector: newVec,
174 | }
175 | }
176 | v := anydiff.NewVarSet(b.Parameters()...)
177 |
178 | return &blockRes{
179 | InPool: inPool,
180 | NetPools: netPool,
181 | OutVec: allRes.Outputs()[0],
182 | OutState: newState,
183 | AllRes: allRes,
184 | V: v,
185 | }
186 | }
187 |
188 | // Parameters returns the block's parameters, including
189 | // the parameters of the gates.
190 | func (b *Block) Parameters() []*anydiff.Var {
191 | gateParams := anynet.AllParameters(b.TrainInput, b.TrainTarget, b.StepSize, b.Query)
192 | return append(gateParams, b.InitParams...)
193 | }
194 |
195 | // SerializerType returns the unique ID used to serialize
196 | // a Block with the serializer package.
197 | func (b *Block) SerializerType() string {
198 | return "github.com/unixpickle/sgdstore.Block"
199 | }
200 |
201 | // Serialize serializes the block.
202 | func (b *Block) Serialize() ([]byte, error) {
203 | var savedVecs []serializer.Serializer
204 | for _, v := range b.InitParams {
205 | savedVecs = append(savedVecs, &anyvecsave.S{Vector: v.Vector})
206 | }
207 | vecData, err := serializer.SerializeSlice(savedVecs)
208 | if err != nil {
209 | return nil, err
210 | }
211 | return serializer.SerializeAny(
212 | serializer.Bytes(vecData),
213 | b.TrainInput,
214 | b.TrainTarget,
215 | b.StepSize,
216 | b.Query,
217 | b.Steps,
218 | )
219 | }
220 |
221 | // applyGates returns a vector of the form:
222 | //
223 | // [trainIn, trainTarget, step, query]
224 | //
225 | func (b *Block) applyGates(x anydiff.Res, n int) anydiff.MultiRes {
226 | gates := []anynet.Layer{b.TrainInput, b.TrainTarget, b.StepSize, b.Query}
227 | var outs []anydiff.Res
228 | for _, gate := range gates {
229 | outs = append(outs, gate.Apply(x, n))
230 | }
231 | return anydiff.Fuse(outs...)
232 | }
233 |
234 | // State is the anyrnn.State and anyrnn.StateGrad type for
235 | // a Block.
236 | type State struct {
237 | Params []*anyrnn.VecState
238 | }
239 |
240 | // Present returns the present sequence map.
241 | func (s *State) Present() anyrnn.PresentMap {
242 | return s.Params[0].Present()
243 | }
244 |
245 | // Reduce removes states.
246 | func (s *State) Reduce(p anyrnn.PresentMap) anyrnn.State {
247 | res := &State{Params: make([]*anyrnn.VecState, len(s.Params))}
248 | for i, param := range s.Params {
249 | res.Params[i] = param.Reduce(p).(*anyrnn.VecState)
250 | }
251 | return res
252 | }
253 |
254 | // Expand inserts gradients.
255 | func (s *State) Expand(p anyrnn.PresentMap) anyrnn.StateGrad { 256 | res := &State{Params: make([]*anyrnn.VecState, len(s.Params))} 257 | for i, param := range s.Params { 258 | res.Params[i] = param.Expand(p).(*anyrnn.VecState) 259 | } 260 | return res 261 | } 262 | 263 | func (s *State) pool() []*anydiff.Var { 264 | res := make([]*anydiff.Var, len(s.Params)) 265 | for i, packed := range s.Params { 266 | res[i] = anydiff.NewVar(packed.Vector) 267 | } 268 | return res 269 | } 270 | 271 | type blockRes struct { 272 | InPool *anydiff.Var 273 | NetPools []*anydiff.Var 274 | OutVec anyvec.Vector 275 | OutState *State 276 | AllRes anydiff.MultiRes 277 | V anydiff.VarSet 278 | } 279 | 280 | func (b *blockRes) State() anyrnn.State { 281 | return b.OutState 282 | } 283 | 284 | func (b *blockRes) Output() anyvec.Vector { 285 | return b.OutVec 286 | } 287 | 288 | func (b *blockRes) Vars() anydiff.VarSet { 289 | return b.V 290 | } 291 | 292 | func (b *blockRes) Propagate(u anyvec.Vector, s anyrnn.StateGrad, 293 | g anydiff.Grad) (anyvec.Vector, anyrnn.StateGrad) { 294 | allUpstream := make([]anyvec.Vector, len(b.AllRes.Outputs())) 295 | allUpstream[0] = u 296 | if s == nil { 297 | for i, x := range allUpstream { 298 | if x == nil { 299 | size := b.AllRes.Outputs()[i].Len() 300 | allUpstream[i] = u.Creator().MakeVector(size) 301 | } 302 | } 303 | } else { 304 | sg := s.(*State) 305 | for i, vecs := range sg.Params { 306 | allUpstream[i+1] = vecs.Vector 307 | } 308 | } 309 | 310 | for _, p := range b.pools() { 311 | g[p] = p.Vector.Creator().MakeVector(p.Vector.Len()) 312 | defer func(g anydiff.Grad, p *anydiff.Var) { 313 | delete(g, p) 314 | }(g, p) 315 | } 316 | 317 | b.AllRes.Propagate(allUpstream, g) 318 | 319 | paramGrad := make([]*anyrnn.VecState, len(b.NetPools)) 320 | for i, netPool := range b.NetPools { 321 | paramGrad[i] = &anyrnn.VecState{ 322 | Vector: g[netPool], 323 | PresentMap: b.OutState.Present(), 324 | } 325 | } 326 | 327 | return g[b.InPool], &State{Params: paramGrad} 328 | } 329 | 330 | func (b *blockRes) pools() []*anydiff.Var { 331 | return append([]*anydiff.Var{b.InPool}, b.NetPools...) 
332 | } 333 | -------------------------------------------------------------------------------- /block_test.go: -------------------------------------------------------------------------------- 1 | package sgdstore 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/unixpickle/anydiff" 7 | "github.com/unixpickle/anydiff/anydifftest" 8 | "github.com/unixpickle/anydiff/anyseq" 9 | "github.com/unixpickle/anynet" 10 | "github.com/unixpickle/anynet/anyrnn" 11 | "github.com/unixpickle/anyvec" 12 | "github.com/unixpickle/anyvec/anyvec32" 13 | "github.com/unixpickle/anyvec/anyvec64" 14 | ) 15 | 16 | func TestBlockGradients(t *testing.T) { 17 | c := anyvec64.CurrentCreator() 18 | inSeq, inVars := randomTestSequence(3) 19 | block := &Block{ 20 | InitParams: []*anydiff.Var{ 21 | anydiff.NewVar(anyvec64.MakeVector(4 * 2)), 22 | anydiff.NewVar(anyvec64.MakeVector(2)), 23 | }, 24 | TrainInput: anynet.NewFC(c, 3, 4*2), 25 | TrainTarget: anynet.Net{ 26 | anynet.NewFC(c, 3, 2*2), 27 | anynet.Tanh, 28 | }, 29 | StepSize: anynet.Net{ 30 | anynet.NewFC(c, 3, 1), 31 | anynet.Exp, 32 | }, 33 | Query: anynet.NewFC(c, 3, 4*2), 34 | Steps: 1, 35 | } 36 | if len(block.Parameters()) != 10 { 37 | t.Errorf("expected 10 parameters, but got %d", len(block.Parameters())) 38 | } 39 | for _, param := range block.Parameters() { 40 | anyvec.Rand(param.Vector, anyvec.Normal, nil) 41 | // Prevent gradient explosion, which causes the tests to 42 | // fail because of bad approximations. 43 | param.Vector.Scale(c.MakeNumeric(0.5)) 44 | } 45 | checker := &anydifftest.SeqChecker{ 46 | F: func() anyseq.Seq { 47 | return anyrnn.Map(inSeq, block) 48 | }, 49 | V: append(inVars, block.Parameters()...), 50 | } 51 | checker.FullCheck(t) 52 | } 53 | 54 | func BenchmarkBlock(b *testing.B) { 55 | c := anyvec32.CurrentCreator() 56 | block := LinearBlock(c, 512, 4, 4, 1, 0.1, 128, 256, 128) 57 | startState := block.Start(8) 58 | inVec := c.MakeVector(startState.Present().NumPresent() * 512) 59 | anyvec.Rand(inVec, anyvec.Normal, nil) 60 | 61 | b.Run("Forward", func(b *testing.B) { 62 | for i := 0; i < b.N; i++ { 63 | block.Step(startState, inVec) 64 | } 65 | }) 66 | b.Run("Backward", func(b *testing.B) { 67 | upstream := inVec.Copy() 68 | grad := anydiff.NewGrad(block.Parameters()...) 
69 | b.ResetTimer()
70 | for i := 0; i < b.N; i++ {
71 | out := block.Step(startState, inVec)
72 | out.Propagate(upstream, nil, grad)
73 | }
74 | })
75 | }
76 |
77 | // randomTestSequence is borrowed from
78 | // https://github.com/unixpickle/anynet/blob/6a8bd570b702861f3c1260a6916723beea6bf296/anyrnn/layer_test.go#L34
79 | func randomTestSequence(inSize int) (anyseq.Seq, []*anydiff.Var) {
80 | inVars := []*anydiff.Var{}
81 | inBatches := []*anyseq.ResBatch{}
82 |
83 | presents := [][]bool{{true, true, true}, {true, false, true}}
84 | numPres := []int{3, 2}
85 | chunkLengths := []int{2, 3}
86 |
87 | for chunkIdx, pres := range presents {
88 | for i := 0; i < chunkLengths[chunkIdx]; i++ {
89 | vec := anyvec64.MakeVector(inSize * numPres[chunkIdx])
90 | anyvec.Rand(vec, anyvec.Normal, nil)
91 | v := anydiff.NewVar(vec)
92 | batch := &anyseq.ResBatch{
93 | Packed: v,
94 | Present: pres,
95 | }
96 | inVars = append(inVars, v)
97 | inBatches = append(inBatches, batch)
98 | }
99 | }
100 | return anyseq.ResSeq(anyvec64.CurrentCreator(), inBatches), inVars
101 | }
102 |
-------------------------------------------------------------------------------- /experiments/audioset/README.md: --------------------------------------------------------------------------------
1 | # AudioSet
2 |
3 | In this experiment, I adapt the [AudioSet dataset](https://research.google.com/audioset/) to the domain of meta-learning. There are 527 sound classes in AudioSet; I split these into 477 training classes and [50 evaluation classes](eval_classes.txt). The task is very similar to the task from [Omniglot](../omniglot): the model is presented with sample after sample and has to predict the randomly assigned class of each sample.
4 |
5 | Unlike in Omniglot, there are two RNN components in the model for this experiment. There is the meta-learning RNN, which is also present in Omniglot. There is also the feature RNN, which takes variable-length audio segments and converts them to fixed-length vectors.
6 |
7 | # Initial results
8 |
9 | Initial results do not look good. Random guessing would have a loss of `-ln(1/5) ≈ 1.609`. Training loss gets down to about 1.54. Evaluation loss stays at about 1.615, which is worse than random. Thus, the model is overfitting, and even then it's barely fitting the training data.
10 |
11 | The above results were with roughly 18k samples, all taken from the official evaluation subset of AudioSet. I have just downloaded another 19k samples from the test set, so that extra data may prove very helpful.
12 |
13 | Besides using more data, I will also experiment with higher-capacity models. More classes may be needed for the model not to overfit, in which case data augmentation may be necessary. For example, I might double the number of classes by reversing samples, speeding them up, or overlaying them.
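To make the augmentation idea concrete, here is a minimal sketch of those three transformations, assuming clips are raw PCM stored as `[]float64` slices. None of these helpers exist in this repository; each derived clip would simply be assigned a new class label:

```go
package main

import "fmt"

// reversePCM returns a time-reversed copy of a PCM clip.
func reversePCM(pcm []float64) []float64 {
	res := make([]float64, len(pcm))
	for i, x := range pcm {
		res[len(pcm)-1-i] = x
	}
	return res
}

// speedUp crudely speeds up a clip by keeping every factor-th
// sample (no resampling or pitch correction).
func speedUp(pcm []float64, factor int) []float64 {
	var res []float64
	for i := 0; i < len(pcm); i += factor {
		res = append(res, pcm[i])
	}
	return res
}

// overlay mixes two clips by averaging them, truncated to the
// shorter clip.
func overlay(a, b []float64) []float64 {
	n := len(a)
	if len(b) < n {
		n = len(b)
	}
	res := make([]float64, n)
	for i := 0; i < n; i++ {
		res[i] = (a[i] + b[i]) / 2
	}
	return res
}

func main() {
	clip := []float64{0.1, 0.2, 0.3, 0.4}
	fmt.Println(reversePCM(clip))
	fmt.Println(speedUp(clip, 2))
	fmt.Println(overlay(clip, reversePCM(clip)))
}
```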
14 | -------------------------------------------------------------------------------- /experiments/audioset/cuda.go: -------------------------------------------------------------------------------- 1 | // +build cuda 2 | 3 | package main 4 | 5 | import ( 6 | "github.com/unixpickle/anyvec/anyvec32" 7 | "github.com/unixpickle/cudavec" 8 | ) 9 | 10 | func init() { 11 | handle, err := cudavec.NewHandleDefault() 12 | if err != nil { 13 | panic(err) 14 | } 15 | anyvec32.Use(&cudavec.Creator32{Handle: handle}) 16 | } 17 | -------------------------------------------------------------------------------- /experiments/audioset/eval_classes.txt: -------------------------------------------------------------------------------- 1 | /m/02x984l 2 | /m/07rgkc5 3 | /m/07rdhzs 4 | /m/07rv4dm 5 | /m/07plct2 6 | /m/02mfyn 7 | /m/07rcgpl 8 | /m/07pczhz 9 | /m/02k_mr 10 | /m/0l14t7 11 | /m/07ryjzk 12 | /m/03t3fj 13 | /m/01jg1z 14 | /m/01w250 15 | /m/027m70_ 16 | /m/07p_0gm 17 | /m/0239kh 18 | /m/01g90h 19 | /m/016622 20 | /m/06hps 21 | /m/07n_g 22 | /m/07rqsjt 23 | /m/07r4gkf 24 | /m/07qfgpx 25 | /m/01qbl 26 | /m/07pjjrj 27 | /m/0242l 28 | /t/dd00134 29 | /m/03wvsk 30 | /m/03qtq 31 | /m/0l156k 32 | /m/01jg02 33 | /m/07pn_8q 34 | /m/0j6m2 35 | /m/06rvn 36 | /m/02hnl 37 | /m/0d31p 38 | /m/0bm02 39 | /m/0dl83 40 | /m/05_wcq 41 | /m/03l9g 42 | /m/03w41f 43 | /m/0j2kx 44 | /m/07phxs1 45 | /m/07rbp7_ 46 | /m/02pjr4 47 | /m/06xkwv 48 | /m/01m4t 49 | /m/0mbct 50 | /m/0150b9 51 | -------------------------------------------------------------------------------- /experiments/audioset/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "github.com/unixpickle/essentials" 8 | ) 9 | 10 | func main() { 11 | if len(os.Args) < 2 { 12 | fmt.Fprintln(os.Stderr, os.Args[0], " [args | -help]") 13 | fmt.Fprintln(os.Stderr) 14 | fmt.Fprintln(os.Stderr, "Subcommands:") 15 | fmt.Fprintln(os.Stderr, " train train a new or existing model") 16 | fmt.Fprintln(os.Stderr) 17 | os.Exit(1) 18 | } 19 | switch os.Args[1] { 20 | case "train": 21 | Train(os.Args[2:]) 22 | default: 23 | essentials.Die("unknown sub-command:", os.Args[1]) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /experiments/audioset/model.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/unixpickle/anydiff" 5 | "github.com/unixpickle/anynet" 6 | "github.com/unixpickle/anynet/anyrnn" 7 | "github.com/unixpickle/anyvec" 8 | "github.com/unixpickle/anyvec/anyvec32" 9 | "github.com/unixpickle/essentials" 10 | "github.com/unixpickle/sgdstore" 11 | ) 12 | 13 | func learnerBlock(name string, sgdSteps, numFeatures, numOut int) anyrnn.Block { 14 | c := anyvec32.CurrentCreator() 15 | 16 | inLayer := normInputLayer(c, numFeatures, numOut) 17 | 18 | switch name { 19 | case "sgdstore": 20 | return anyrnn.Stack{ 21 | inLayer, 22 | anyrnn.NewVanilla(c, numFeatures+numOut, 384, anynet.Tanh), 23 | anyrnn.NewVanilla(c, 384, 384, anynet.Tanh), 24 | sgdstore.LinearBlock(c, 384, 16, 2, sgdSteps, 0.2, 25 | sgdstore.Tanh, 32, 256, 32), 26 | &anyrnn.LayerBlock{ 27 | Layer: anynet.Net{ 28 | anynet.NewFC(c, 64, 64), 29 | anynet.Tanh, 30 | anynet.NewFC(c, 64, numOut), 31 | anynet.LogSoftmax, 32 | }, 33 | }, 34 | } 35 | case "lstm": 36 | return anyrnn.Stack{ 37 | inLayer, 38 | anyrnn.NewLSTM(c, numFeatures+numOut, 384), 39 | anyrnn.NewLSTM(c, 384, 384), 40 | &anyrnn.LayerBlock{ 
41 | Layer: anynet.Net{ 42 | anynet.NewFC(c, 384, numOut), 43 | anynet.LogSoftmax, 44 | }, 45 | }, 46 | } 47 | case "vanilla": 48 | return anyrnn.Stack{ 49 | inLayer, 50 | anyrnn.NewVanilla(c, numFeatures+numOut, 384, anynet.Tanh), 51 | anyrnn.NewVanilla(c, 384, 384, anynet.Tanh), 52 | &anyrnn.LayerBlock{ 53 | Layer: anynet.Net{ 54 | anynet.NewFC(c, 384, numOut), 55 | anynet.LogSoftmax, 56 | }, 57 | }, 58 | } 59 | default: 60 | essentials.Die("unknown model:", name) 61 | panic("unreachable") 62 | } 63 | } 64 | 65 | func featureBlock(numFeatures, blockSize int) anyrnn.Block { 66 | c := anyvec32.CurrentCreator() 67 | return anyrnn.NewLSTM(c, blockSize, numFeatures).ScaleInWeights(c.MakeNumeric(7)) 68 | } 69 | 70 | func normInputLayer(c anyvec.Creator, numFeatures, numOut int) anyrnn.Block { 71 | affine := &anynet.Affine{ 72 | Scalers: anydiff.NewVar(c.MakeVector(numFeatures + numOut)), 73 | Biases: anydiff.NewVar(c.MakeVector(numFeatures + numOut)), 74 | } 75 | 76 | // Scaling the one-hot vector for the last timestep 77 | // tends to improve performance. 78 | outScale := affine.Scalers.Vector.Slice(numFeatures, numFeatures+numOut) 79 | outScale.Scale(c.MakeNumeric(16)) 80 | 81 | return &anyrnn.LayerBlock{Layer: affine} 82 | } 83 | -------------------------------------------------------------------------------- /experiments/audioset/train.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "io/ioutil" 7 | "log" 8 | "os" 9 | "strings" 10 | 11 | "github.com/unixpickle/anydiff" 12 | "github.com/unixpickle/anydiff/anyseq" 13 | "github.com/unixpickle/anynet" 14 | "github.com/unixpickle/anynet/anyrnn" 15 | "github.com/unixpickle/anynet/anysgd" 16 | "github.com/unixpickle/anyvec" 17 | "github.com/unixpickle/anyvec/anyvec32" 18 | "github.com/unixpickle/audioset" 19 | "github.com/unixpickle/audioset/metaset" 20 | "github.com/unixpickle/essentials" 21 | "github.com/unixpickle/rip" 22 | "github.com/unixpickle/serializer" 23 | ) 24 | 25 | func Train(args []string) { 26 | var dataDir string 27 | var dataCSV string 28 | var evalClassPath string 29 | 30 | var learnerNetPath string 31 | var featureNetPath string 32 | var modelType string 33 | var stepSize float64 34 | var batchSize int 35 | var sgdSteps int 36 | var numClasses int 37 | var episodeLen int 38 | 39 | var audioFeatureSize int 40 | var pcmChunkSize int 41 | 42 | fs := flag.NewFlagSet("train", flag.ExitOnError) 43 | 44 | fs.StringVar(&dataDir, "datadir", "", "directory of AudioSet samples") 45 | fs.StringVar(&dataCSV, "datacsv", "", "path to AudioSet CSV file") 46 | fs.StringVar(&evalClassPath, "dataclasses", "eval_classes.txt", "path to eval classes") 47 | 48 | fs.StringVar(&learnerNetPath, "learner", "out_learner", "learner net output path") 49 | fs.StringVar(&featureNetPath, "features", "out_features", "feature net output path") 50 | fs.StringVar(&modelType, "model", "sgdstore", 51 | "model type (sgdstore, lstm, or vanilla)") 52 | fs.Float64Var(&stepSize, "step", 0.001, "SGD step size") 53 | fs.IntVar(&batchSize, "batch", 16, "SGD batch size") 54 | fs.IntVar(&sgdSteps, "steps", 1, "steps per sgdstore") 55 | fs.IntVar(&numClasses, "classes", 5, "classes per episode") 56 | fs.IntVar(&episodeLen, "eplen", 50, "episode length") 57 | 58 | fs.IntVar(&audioFeatureSize, "audiofeats", 128, "audio feature vector size") 59 | fs.IntVar(&pcmChunkSize, "chunksize", 512, "PCM sample chunk size") 60 | 61 | fs.Parse(args) 62 | 63 | if dataDir == "" || dataCSV == "" { 64 
| essentials.Die("Required flags: -datadir and -datacsv. See -help.") 65 | } 66 | 67 | var learner, features anyrnn.Block 68 | if err := serializer.LoadAny(learnerNetPath, &learner); err != nil { 69 | log.Println("Creating new learner...") 70 | learner = learnerBlock(modelType, sgdSteps, audioFeatureSize, numClasses) 71 | } else { 72 | log.Println("Loaded learner.") 73 | } 74 | if err := serializer.LoadAny(featureNetPath, &features); err != nil { 75 | log.Println("Creating new feature net...") 76 | features = featureBlock(audioFeatureSize, pcmChunkSize) 77 | } else { 78 | log.Println("Loaded feature net.") 79 | } 80 | 81 | allSamples, err := audioset.ReadSet(dataDir, dataCSV) 82 | if err != nil { 83 | essentials.Die(err) 84 | } 85 | 86 | training, eval := metaset.Split(allSamples, readEvalClasses(evalClassPath)) 87 | 88 | log.Printf("Got %d samples: %d training, %d eval", len(allSamples), len(training), 89 | len(eval)) 90 | 91 | trainer := &metaset.Trainer{ 92 | Creator: anyvec32.CurrentCreator(), 93 | FeatureFunc: func(seq anyseq.Seq) anydiff.Res { 94 | return anyseq.Tail(anyrnn.Map(seq, features)) 95 | }, 96 | LearnerFunc: func(eps anyseq.Seq) anyseq.Seq { 97 | return anyrnn.Map(eps, learner) 98 | }, 99 | Params: anynet.AllParameters(learner, features), 100 | Set: training, 101 | NumClasses: numClasses, 102 | NumSteps: episodeLen, 103 | ChunkSize: pcmChunkSize, 104 | Average: true, 105 | } 106 | 107 | valBatches := fetchEvalBatches(*trainer, eval, batchSize) 108 | 109 | var iter int 110 | sgd := &anysgd.SGD{ 111 | Fetcher: trainer, 112 | Gradienter: trainer, 113 | Transformer: &anysgd.RMSProp{}, 114 | Samples: anysgd.LengthSampleList(batchSize), 115 | Rater: anysgd.ConstRater(stepSize), 116 | BatchSize: batchSize, 117 | StatusFunc: func(b anysgd.Batch) { 118 | if iter%4 == 0 { 119 | batch := <-valBatches 120 | valCost := anyvec.Sum(trainer.TotalCost(batch).Output()) 121 | log.Printf("iter %d: cost=%v validation=%v", iter, trainer.LastCost, 122 | valCost) 123 | } else { 124 | log.Printf("iter %d: cost=%v", iter, trainer.LastCost) 125 | } 126 | iter++ 127 | }, 128 | } 129 | 130 | if err := sgd.Run(rip.NewRIP().Chan()); err != nil { 131 | fmt.Fprintln(os.Stderr, err) 132 | } 133 | 134 | if err := serializer.SaveAny(learnerNetPath, learner); err != nil { 135 | essentials.Die(err) 136 | } 137 | if err := serializer.SaveAny(featureNetPath, features); err != nil { 138 | essentials.Die(err) 139 | } 140 | } 141 | 142 | func readEvalClasses(path string) []string { 143 | data, err := ioutil.ReadFile(path) 144 | if err != nil { 145 | essentials.Die(err) 146 | } 147 | return strings.Fields(string(data)) 148 | } 149 | 150 | func fetchEvalBatches(t metaset.Trainer, set audioset.Set, size int) <-chan anysgd.Batch { 151 | res := make(chan anysgd.Batch, 1) 152 | go func() { 153 | defer close(res) 154 | t.Set = set 155 | for { 156 | batch, err := t.Fetch(anysgd.LengthSampleList(size)) 157 | if err != nil { 158 | essentials.Die(err) 159 | } 160 | res <- batch 161 | } 162 | }() 163 | return res 164 | } 165 | -------------------------------------------------------------------------------- /experiments/omniglot/README.md: -------------------------------------------------------------------------------- 1 | # Omniglot 2 | 3 | In an episode of this experiment, the model is presented with a sequence of either 50 or 100 images. It knows beforehand that there are 5 (or 15) different kinds of characters with different labels. However, it does not know what the labels are, nor has it seen the exact characters before. 
How well can a model learn to do this task, if you train it with enough background knowledge? 4 | 5 | # 5-class task 6 | 7 | In this task, the sequences are of length 50 and there are 5 different classes. Here are the accuracy numbers for the sgdstore model. 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 |
| Instance 1 | Instance 2 | Instance 3 | Instance 4 | Instance 11 |
| ---------- | ---------- | ---------- | ---------- | ----------- |
| 35.76%     | 89.38%     | 93.72%     | 95.25%     | 96.74%      |
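For reference, an "Instance k" column measures accuracy on the k-th appearance of a class within an episode, so later instances test how well the model recalls earlier examples. The sketch below shows that bookkeeping; it mirrors the `accuracy` subcommand further down in this repository, but the function itself is hypothetical and uses zero-based buckets:

```go
// instanceAccuracy buckets predictions by how many times the true
// class has already appeared in the episode. Bucket 0 corresponds
// to "Instance 1" in the tables here.
func instanceAccuracy(predicted, truth []int) map[int]float64 {
	seen := map[int]int{}        // bucket -> total predictions
	correct := map[int]int{}     // bucket -> correct predictions
	occurrences := map[int]int{} // class -> appearances so far
	for t, label := range truth {
		k := occurrences[label]
		if predicted[t] == label {
			correct[k]++
		}
		seen[k]++
		occurrences[label]++
	}
	res := map[int]float64{}
	for k, total := range seen {
		res[k] = float64(correct[k]) / float64(total)
	}
	return res
}
```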
25 |
26 | Note that these results look worse than the results in [Santoro et al.](https://arxiv.org/abs/1605.06065). This is likely because I use a different training/evaluation split: in that paper, the network is meta-trained on more data and tested on less data, whereas I use the original background/evaluation split from [Lake et al.](http://science.sciencemag.org/content/350/6266/1332). The models I trained did indeed overfit slightly, indicating that more training data would be helpful.
27 |
28 | The following graph shows, for three different models, the validation error over time (measured in episodes) during meta-training. The sgdstore model clearly does the best, but the LSTM catches up after much more training. The vanilla RNN (which is used as a controller for sgdstore) does terribly on its own. I will update the graph after I have run the sgdstore model for longer.
29 |
30 | ![Training plot](plot/plot.png)
31 |
32 | Training in this experiment was done with a batch size of 64 and a step size of 0.0003. It is likely that hyper-parameter tuning would yield much better results. I discovered in the 15-class task that batches of 16 work much better in terms of data efficiency. I have yet to test larger step sizes, but I bet those will help too.
33 |
34 | # 15-class task
35 |
36 | In this task, the sequences are of length 100 and there are 15 different classes. Here are the accuracy results for the sgdstore model:
37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 |
| Instance 1 | Instance 2 | Instance 3 | Instance 4 | Instance 10 | Instance 11 |
| ---------- | ---------- | ---------- | ---------- | ----------- | ----------- |
| 9.73%      | 77.07%     | 84.73%     | 87.39%     | 87.74%      | 90.91%      |
56 | 57 | Here is a plot of training over time. In this case, I used a batch size of 16 and a learning rate of 0.0003. It is clear that the sgdstore model learns an order of magnitude faster than the LSTM, but once again the LSTM does eventually catch up: 58 | 59 | ![Training plot](plot/plot15.png) 60 | 61 | # Hyper-parameter exploration 62 | 63 | The above experiments were with a learning rate of 0.0003. From the following graph, it's clear that the model would learn faster (in the short term) with a learning rate of 0.001: 64 | 65 | ![Learning rate comparison plot](plot/plot_lr.png) 66 | 67 | Also, the memory modules in the above experiments were single-layer MLPs with 256 hidden units and two read heads. By "two read heads", I mean that the controller got to run two samples through the memory network. Here are two variations: one with "deep memory" (two layers of 256 units), another with four read heads. It is clear that deep memory helps in the long run, perhaps just due to the extra capacity: 68 | 69 | ![Memory structure comparison plot](plot/plot_mem.png) 70 | 71 | Here are the accuracy measurements for the "deep memory" model: 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 |
| Set   | 1st   | 2nd   | 3rd   | 10th  | 11th  |
| ----- | ----- | ----- | ----- | ----- | ----- |
| Eval  | 11.4% | 81.7% | 86.5% | 91.2% | 91.0% |
| Train | 13.4% | 88.6% | 92.7% | 95.6% | 96.8% |
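In terms of this repository's API, these variants only change the arguments to `sgdstore.LinearBlock(creator, blockIn, trainBatch, queryBatch, numSteps, lrBias, activation, layerSizes...)`. Here is a sketch of the three memory configurations compared above, assuming the 384-unit controller from `model.go` (the variable names are mine):

```go
package main

import (
	"github.com/unixpickle/anyvec/anyvec32"
	"github.com/unixpickle/sgdstore"
)

func main() {
	c := anyvec32.CurrentCreator()
	steps := 1

	// Baseline: one hidden layer of 256 units, two read heads
	// (queryBatch = 2), exactly as in model.go.
	baseline := sgdstore.LinearBlock(c, 384, 16, 2, steps, 0.2,
		sgdstore.Tanh, 32, 256, 32)

	// Four read heads: a larger query batch.
	fourHeads := sgdstore.LinearBlock(c, 384, 16, 4, steps, 0.2,
		sgdstore.Tanh, 32, 256, 32)

	// "Deep memory": two hidden layers of 256 units.
	deepMemory := sgdstore.LinearBlock(c, 384, 16, 2, steps, 0.2,
		sgdstore.Tanh, 32, 256, 256, 32)

	_, _, _ = baseline, fourHeads, deepMemory
}
```

Note that the block's output size is `queryBatch * layerSizes[len(layerSizes)-1]`, so the four-read-head variant also requires a wider layer after the memory.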
99 | 100 | # A note on LSTM results 101 | 102 | I have found that I can get much better LSTM results than the ones reported in *Santoro et al.*. They stop training after 100,000 episodes, which seems arbitrary (almost like it was chosen to make their model look good, since it learns faster). I don't want to confuse learning speed with model capacity, which *Santoro et al.* seems to do. 103 | 104 | I use two-layer LSTMs with 384 cells per layer. This is likely much more capacity than Santoro et al. allow for their LSTMs. I think it would be unfair not to give LSTMs the benefit of the doubt, even if their learning speeds are a lot worse than memory-augmented neural networks. 105 | -------------------------------------------------------------------------------- /experiments/omniglot/accuracy.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | 7 | "github.com/unixpickle/anydiff/anyseq" 8 | "github.com/unixpickle/anynet/anyrnn" 9 | "github.com/unixpickle/anynet/anys2s" 10 | "github.com/unixpickle/anyvec" 11 | "github.com/unixpickle/essentials" 12 | "github.com/unixpickle/omniglot" 13 | "github.com/unixpickle/serializer" 14 | ) 15 | 16 | func Accuracy(args []string) { 17 | var modelPath string 18 | var dataPath string 19 | var batchSize int 20 | var numClasses int 21 | var episodeLen int 22 | 23 | fs := flag.NewFlagSet("accuracy", flag.ExitOnError) 24 | fs.StringVar(&modelPath, "model", "model_out", "model path") 25 | fs.StringVar(&dataPath, "data", "", "path to evaluation data") 26 | fs.IntVar(&batchSize, "batch", 16, "") 27 | fs.IntVar(&numClasses, "classes", 5, "classes per episode") 28 | fs.IntVar(&episodeLen, "eplen", 50, "episode length") 29 | fs.Parse(args) 30 | 31 | if dataPath == "" { 32 | essentials.Die("Required flag: -data. 
See -help.") 33 | } 34 | 35 | var model anyrnn.Block 36 | if err := serializer.LoadAny(modelPath, &model); err != nil { 37 | essentials.Die(err) 38 | } 39 | 40 | data, err := omniglot.ReadSet(dataPath) 41 | if err != nil { 42 | essentials.Die(err) 43 | } 44 | data = data.Augment() 45 | 46 | tr := &anys2s.Trainer{} 47 | samples := &Samples{ 48 | Length: batchSize, 49 | Sets: data.ByClass(), 50 | Augment: false, 51 | NumClasses: numClasses, 52 | NumTimesteps: episodeLen, 53 | } 54 | 55 | totalSeen := map[int]int{} 56 | totalCorrect := map[int]int{} 57 | 58 | for { 59 | batch, err := tr.Fetch(samples) 60 | if err != nil { 61 | essentials.Die(err) 62 | } 63 | b := batch.(*anys2s.Batch) 64 | actual := anyseq.SeparateSeqs(anyrnn.Map(b.Inputs, model).Output()) 65 | expected := anyseq.SeparateSeqs(b.Outputs.Output()) 66 | 67 | for i, actualSeq := range actual { 68 | expectedSeq := expected[i] 69 | seenCounts := map[int]int{} 70 | for j, a := range actualSeq { 71 | predicted := anyvec.MaxIndex(a) 72 | correct := anyvec.MaxIndex(expectedSeq[j]) 73 | seen := seenCounts[correct] 74 | if correct == predicted { 75 | totalCorrect[seen]++ 76 | } 77 | totalSeen[seen]++ 78 | seenCounts[correct]++ 79 | } 80 | } 81 | 82 | printAccuracies(totalSeen, totalCorrect) 83 | } 84 | } 85 | 86 | func printAccuracies(seen, correct map[int]int) { 87 | for i := 0; i <= 10; i++ { 88 | percent := 100 * float64(correct[i]) / float64(seen[i]) 89 | fmt.Printf("Instance %d: %.2f%%\n", i, percent) 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /experiments/omniglot/analysis.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "math" 7 | 8 | "github.com/unixpickle/anynet" 9 | "github.com/unixpickle/anynet/anyrnn" 10 | "github.com/unixpickle/anyvec" 11 | "github.com/unixpickle/essentials" 12 | "github.com/unixpickle/serializer" 13 | ) 14 | 15 | func Analysis(args []string) { 16 | var modelPath string 17 | fs := flag.NewFlagSet("analysis", flag.ExitOnError) 18 | fs.StringVar(&modelPath, "model", "", "path to model file") 19 | fs.Parse(args) 20 | 21 | if modelPath == "" { 22 | essentials.Die("Required flag: -model. 
See -help.") 23 | } 24 | 25 | var model anyrnn.Block 26 | if err := serializer.LoadAny(modelPath, &model); err != nil { 27 | essentials.Die(err) 28 | } 29 | 30 | for i, p := range model.(anynet.Parameterizer).Parameters() { 31 | v := p.Vector 32 | fmt.Printf("Param %d: mean=%f stddev=%f\n", i, computeMean(v), 33 | math.Sqrt(float64(computeVariance(v)))) 34 | } 35 | } 36 | 37 | func computeMean(vec anyvec.Vector) float32 { 38 | return anyvec.Sum(vec).(float32) / float32(vec.Len()) 39 | } 40 | 41 | func computeVariance(vec anyvec.Vector) float32 { 42 | mean := computeMean(vec) 43 | sq := vec.Copy() 44 | anyvec.Pow(sq, float32(2)) 45 | moment2 := computeMean(sq) 46 | return moment2 - mean*mean 47 | } 48 | -------------------------------------------------------------------------------- /experiments/omniglot/cuda.go: -------------------------------------------------------------------------------- 1 | // +build cuda 2 | 3 | package main 4 | 5 | import ( 6 | "github.com/unixpickle/anyvec/anyvec32" 7 | "github.com/unixpickle/cudavec" 8 | ) 9 | 10 | func init() { 11 | handle, err := cudavec.NewHandleDefault() 12 | if err != nil { 13 | panic(err) 14 | } 15 | anyvec32.Use(&cudavec.Creator32{Handle: handle}) 16 | } 17 | -------------------------------------------------------------------------------- /experiments/omniglot/debug.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "os" 7 | 8 | "github.com/unixpickle/anynet" 9 | "github.com/unixpickle/anynet/anyrnn" 10 | "github.com/unixpickle/essentials" 11 | "github.com/unixpickle/serializer" 12 | "github.com/unixpickle/sgdstore" 13 | ) 14 | 15 | func Debug(args []string) { 16 | var inPath, outPath string 17 | var logStepSize bool 18 | var logTrainIn bool 19 | var logTrainTarget bool 20 | fs := flag.NewFlagSet("debug", flag.ExitOnError) 21 | fs.StringVar(&inPath, "in", "model_out", "input model path") 22 | fs.StringVar(&outPath, "out", "", "output model path") 23 | fs.BoolVar(&logStepSize, "logstep", false, "log step size") 24 | fs.BoolVar(&logTrainIn, "logtrain", false, "log training inputs") 25 | fs.BoolVar(&logTrainTarget, "logtarget", false, "log training targets") 26 | fs.Parse(args) 27 | 28 | if inPath == "" || outPath == "" { 29 | essentials.Die("Required flag: -in and -out. See -help.") 30 | } 31 | 32 | if !logStepSize && !logTrainIn && !logTrainTarget { 33 | fmt.Fprintln(os.Stderr, "Warning: no new debug layers. 
See -help.") 34 | } 35 | 36 | var model anyrnn.Block 37 | if err := serializer.LoadAny(inPath, &model); err != nil { 38 | essentials.Die(err) 39 | } 40 | 41 | stack := model.(anyrnn.Stack) 42 | for _, layer := range stack { 43 | if par, ok := layer.(*anyrnn.Parallel); ok { 44 | layer = par.Block2 45 | } 46 | if block, ok := layer.(*sgdstore.Block); ok { 47 | if logStepSize { 48 | net := block.StepSize.(anynet.Net) 49 | block.StepSize = append(net, debugLayer("step")) 50 | } 51 | if logTrainIn { 52 | net := block.TrainInput 53 | block.TrainInput = anynet.Net{net, debugLayer("train")} 54 | } 55 | if logTrainTarget { 56 | net := block.TrainTarget.(anynet.Net) 57 | block.TrainTarget = append(net, debugLayer("target")) 58 | } 59 | } 60 | } 61 | 62 | if err := serializer.SaveAny(outPath, model); err != nil { 63 | essentials.Die(err) 64 | } 65 | } 66 | 67 | func debugLayer(id string) anynet.Layer { 68 | return &anynet.Debug{ID: id, PrintRaw: true} 69 | } 70 | -------------------------------------------------------------------------------- /experiments/omniglot/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "github.com/unixpickle/essentials" 8 | ) 9 | 10 | func main() { 11 | if len(os.Args) < 2 { 12 | fmt.Fprintln(os.Stderr, os.Args[0], " [args | -help]") 13 | fmt.Fprintln(os.Stderr) 14 | fmt.Fprintln(os.Stderr, "Subcommands:") 15 | fmt.Fprintln(os.Stderr, " train train a new or existing model") 16 | fmt.Fprintln(os.Stderr, " analysis dump weight statistics") 17 | fmt.Fprintln(os.Stderr, " accuracy evaluate model") 18 | fmt.Fprintln(os.Stderr, " debug add debug layers to model") 19 | fmt.Fprintln(os.Stderr) 20 | os.Exit(1) 21 | } 22 | switch os.Args[1] { 23 | case "train": 24 | Train(os.Args[2:]) 25 | case "analysis": 26 | Analysis(os.Args[2:]) 27 | case "accuracy": 28 | Accuracy(os.Args[2:]) 29 | case "debug": 30 | Debug(os.Args[2:]) 31 | default: 32 | essentials.Die("unknown sub-command:", os.Args[1]) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /experiments/omniglot/model.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/unixpickle/anydiff" 5 | "github.com/unixpickle/anynet" 6 | "github.com/unixpickle/anynet/anyrnn" 7 | "github.com/unixpickle/anyvec" 8 | "github.com/unixpickle/anyvec/anyvec32" 9 | "github.com/unixpickle/essentials" 10 | "github.com/unixpickle/sgdstore" 11 | ) 12 | 13 | func NewModel(name string, sgdSteps, outCount int) anyrnn.Block { 14 | c := anyvec32.CurrentCreator() 15 | numPixels := ImageSize * ImageSize 16 | 17 | switch name { 18 | case "sgdstore": 19 | return anyrnn.Stack{ 20 | normInputLayer(c, outCount, numPixels), 21 | anyrnn.NewVanilla(c, numPixels+outCount, 384, anynet.Tanh), 22 | anyrnn.NewVanilla(c, 384, 384, anynet.Tanh), 23 | sgdstore.LinearBlock(c, 384, 16, 2, sgdSteps, 0.2, sgdstore.Tanh, 24 | 32, 256, 32), 25 | &anyrnn.LayerBlock{ 26 | Layer: anynet.Net{ 27 | anynet.NewFC(c, 64, 64), 28 | anynet.Tanh, 29 | anynet.NewFC(c, 64, outCount), 30 | anynet.LogSoftmax, 31 | }, 32 | }, 33 | } 34 | case "parasgdstore": 35 | return anyrnn.Stack{ 36 | normInputLayer(c, outCount, numPixels), 37 | anyrnn.NewVanilla(c, numPixels+outCount, 384, anynet.Tanh), 38 | anyrnn.NewVanilla(c, 384, 384, anynet.Tanh), 39 | &anyrnn.Parallel{ 40 | Block1: &anyrnn.LayerBlock{Layer: anynet.Net{}}, 41 | Block2: sgdstore.LinearBlock(c, 384, 16, 2, 
sgdSteps, 0.2,
42 | sgdstore.Tanh, 32, 256, 32),
43 | Mixer: &anynet.AddMixer{
44 | In1: anynet.NewFC(c, 384, 64),
45 | In2: anynet.NewFC(c, 64, 64),
46 | Out: anynet.Tanh,
47 | },
48 | },
49 | &anyrnn.LayerBlock{
50 | Layer: anynet.Net{
51 | anynet.NewFC(c, 64, outCount),
52 | anynet.LogSoftmax,
53 | },
54 | },
55 | }
56 | case "lstm":
57 | return anyrnn.Stack{
58 | normInputLayer(c, outCount, numPixels),
59 | anyrnn.NewLSTM(c, numPixels+outCount, 384),
60 | anyrnn.NewLSTM(c, 384, 384).ScaleInWeights(c.MakeNumeric(2)),
61 | &anyrnn.LayerBlock{
62 | Layer: anynet.Net{
63 | anynet.NewFC(c, 384, outCount),
64 | anynet.LogSoftmax,
65 | },
66 | },
67 | }
68 | case "vanilla":
69 | return anyrnn.Stack{
70 | normInputLayer(c, outCount, numPixels),
71 | anyrnn.NewVanilla(c, numPixels+outCount, 384, anynet.Tanh),
72 | anyrnn.NewVanilla(c, 384, 384, anynet.Tanh),
73 | &anyrnn.LayerBlock{
74 | Layer: anynet.Net{
75 | anynet.NewFC(c, 384, outCount),
76 | anynet.LogSoftmax,
77 | },
78 | },
79 | }
80 | default:
81 | essentials.Die("unknown model:", name)
82 | panic("unreachable")
83 | }
84 | }
85 |
86 | func CountParams(b anyrnn.Block) int {
87 | var res int
88 | for _, p := range anynet.AllParameters(b) {
89 | res += p.Vector.Len()
90 | }
91 | return res
92 | }
93 |
94 | func normInputLayer(c anyvec.Creator, numOut, numPixels int) anyrnn.Block {
95 | affine := &anynet.Affine{
96 | Scalers: anydiff.NewVar(c.MakeVector(numPixels + numOut)),
97 | Biases: anydiff.NewVar(c.MakeVector(numPixels + numOut)),
98 | }
99 | affine.Scalers.Vector.AddScalar(c.MakeNumeric(4))
100 |
101 | modified := affine.Scalers.Vector.Slice(numPixels, numPixels+numOut)
102 | modified.Scale(c.MakeNumeric(4))
103 |
104 | modified = affine.Biases.Vector.Slice(0, numPixels)
105 | modified.AddScalar(c.MakeNumeric(-4 * 0.92))
106 |
107 | return &anyrnn.LayerBlock{Layer: affine}
108 | }
109 |
-------------------------------------------------------------------------------- /experiments/omniglot/plot/log_to_matlab.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ $# -ne 2 ]; then
4 | echo 'Usage: to_matlab.sh <log_file> <x_step>' >&2
5 | exit 1
6 | fi
7 |
8 | filename=$(basename "$1")
9 | varname=${filename%.txt}
10 |
11 | yvalues=$(cat "$1" | grep 'validation=[0-9]' |
12 | sed -E 's/^.*validation=([0-9\.]*).*$/ \1/g')
13 | numlines=$(echo "$yvalues" | wc -l | sed -E 's/ //g')
14 |
15 | echo "${varname}_x = [1:$2:${numlines}*$2];"
16 | echo "${varname}_y = ["
17 | echo "$yvalues"
18 | echo "];"
19 |
-------------------------------------------------------------------------------- /experiments/omniglot/plot/plot.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/unixpickle/sgdstore/82d72611cde5e1b77d7b289d6e2d8973d513ae45/experiments/omniglot/plot/plot.png
-------------------------------------------------------------------------------- /experiments/omniglot/plot/plot15.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/unixpickle/sgdstore/82d72611cde5e1b77d7b289d6e2d8973d513ae45/experiments/omniglot/plot/plot15.png
-------------------------------------------------------------------------------- /experiments/omniglot/plot/plot_15.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo 'Creating plot_file.m ...'
4 | 5 | TO_MATLAB="./log_to_matlab.sh" 6 | $TO_MATLAB ../log/15_classes/lstm_log.txt 64 >plot_file.m 7 | $TO_MATLAB ../log/15_classes/sgdstore_log.txt 64 >>plot_file.m 8 | echo -n 'plot(' >>plot_file.m 9 | first=1 10 | for name in lstm sgdstore 11 | do 12 | if [ $first -eq 1 ]; then 13 | first=0 14 | else 15 | echo -n ', ' >>plot_file.m 16 | fi 17 | echo -n "${name}_log_x, smooth_data(${name}_log_y), " >>plot_file.m 18 | echo -n "'linewidth', 2, ';${name};'" >>plot_file.m 19 | done 20 | echo ");" >>plot_file.m 21 | -------------------------------------------------------------------------------- /experiments/omniglot/plot/plot_5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo 'Creating plot_file.m ...' 4 | 5 | TO_MATLAB="./log_to_matlab.sh" 6 | $TO_MATLAB ../log/5_classes/lstm_log.txt 256 >plot_file.m 7 | $TO_MATLAB ../log/5_classes/sgdstore_log.txt 256 >>plot_file.m 8 | $TO_MATLAB ../log/5_classes/vanilla_log.txt 256 >>plot_file.m 9 | echo -n 'plot(' >>plot_file.m 10 | first=1 11 | for name in lstm sgdstore vanilla 12 | do 13 | if [ $first -eq 1 ]; then 14 | first=0 15 | else 16 | echo -n ', ' >>plot_file.m 17 | fi 18 | echo -n "${name}_log_x, smooth_data(${name}_log_y), " >>plot_file.m 19 | echo -n "'linewidth', 2, ';${name};'" >>plot_file.m 20 | done 21 | echo ");" >>plot_file.m 22 | -------------------------------------------------------------------------------- /experiments/omniglot/plot/plot_lr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unixpickle/sgdstore/82d72611cde5e1b77d7b289d6e2d8973d513ae45/experiments/omniglot/plot/plot_lr.png -------------------------------------------------------------------------------- /experiments/omniglot/plot/plot_lr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo 'Creating plot_file.m ...' 4 | 5 | TO_MATLAB="./log_to_matlab.sh" 6 | $TO_MATLAB ../log/lr_search/sgdstore_0001_log.txt 64 >plot_file.m 7 | $TO_MATLAB ../log/lr_search/sgdstore_001_log.txt 64 >>plot_file.m 8 | $TO_MATLAB ../log/lr_search/sgdstore_0003_log.txt 64 >>plot_file.m 9 | echo -n 'plot(' >>plot_file.m 10 | first=1 11 | for name in sgdstore_0001 sgdstore_0003 sgdstore_001 12 | do 13 | cleanName=$(echo $name | tr _ ' ') 14 | if [ $first -eq 1 ]; then 15 | first=0 16 | else 17 | echo -n ', ' >>plot_file.m 18 | fi 19 | echo -n "${name}_log_x, smooth_data(${name}_log_y), " >>plot_file.m 20 | echo -n "'linewidth', 2, ';${cleanName};'" >>plot_file.m 21 | done 22 | echo ");" >>plot_file.m 23 | -------------------------------------------------------------------------------- /experiments/omniglot/plot/plot_mem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unixpickle/sgdstore/82d72611cde5e1b77d7b289d6e2d8973d513ae45/experiments/omniglot/plot/plot_mem.png -------------------------------------------------------------------------------- /experiments/omniglot/plot/plot_mem.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo 'Creating plot_file.m ...' 
4 | 5 | TO_MATLAB="./log_to_matlab.sh" 6 | $TO_MATLAB ../log/mem_search/sgdstore_4rh_log.txt 64 >plot_file.m 7 | $TO_MATLAB ../log/mem_search/sgdstore_deep_log.txt 64 >>plot_file.m 8 | $TO_MATLAB ../log/15_classes/sgdstore_log.txt 64 >>plot_file.m 9 | echo -n 'plot(' >>plot_file.m 10 | first=1 11 | for name in sgdstore_4rh sgdstore_deep sgdstore 12 | do 13 | cleanName=$(echo $name | tr _ ' ') 14 | if [ $first -eq 1 ]; then 15 | first=0 16 | else 17 | echo -n ', ' >>plot_file.m 18 | fi 19 | echo -n "${name}_log_x, smooth_data(${name}_log_y), " >>plot_file.m 20 | echo -n "'linewidth', 2, ';${cleanName};'" >>plot_file.m 21 | done 22 | echo ");" >>plot_file.m 23 | -------------------------------------------------------------------------------- /experiments/omniglot/plot/smooth_data.m: -------------------------------------------------------------------------------- 1 | function [smoothed] = smooth_data(data) 2 | rolling = mean(data(1:10)); 3 | smoothed = zeros(size(data)); 4 | for i=1:size(data) 5 | rolling = rolling + 0.05*(data(i)-rolling); 6 | smoothed(i) = rolling; 7 | end 8 | end 9 | -------------------------------------------------------------------------------- /experiments/omniglot/samples.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "math/rand" 5 | 6 | "github.com/unixpickle/anynet/anys2s" 7 | "github.com/unixpickle/anynet/anysgd" 8 | "github.com/unixpickle/anyvec" 9 | "github.com/unixpickle/anyvec/anyvec32" 10 | "github.com/unixpickle/omniglot" 11 | ) 12 | 13 | const ImageSize = 20 14 | 15 | // Samples is a dummy anys2s.SampleList which generates 16 | // random episodes. 17 | type Samples struct { 18 | Length int 19 | Sets []omniglot.Set 20 | Augment bool 21 | 22 | NumClasses int 23 | NumTimesteps int 24 | } 25 | 26 | func (s *Samples) Len() int { 27 | return s.Length 28 | } 29 | 30 | func (s *Samples) Swap(i, j int) { 31 | } 32 | 33 | func (s *Samples) Slice(i, j int) anysgd.SampleList { 34 | return &Samples{ 35 | Length: j - i, 36 | Sets: s.Sets, 37 | Augment: s.Augment, 38 | NumClasses: s.NumClasses, 39 | NumTimesteps: s.NumTimesteps, 40 | } 41 | } 42 | 43 | func (s *Samples) Creator() anyvec.Creator { 44 | return anyvec32.CurrentCreator() 45 | } 46 | 47 | func (s *Samples) GetSample(i int) (*anys2s.Sample, error) { 48 | c := s.Creator() 49 | samples, classes := s.episode() 50 | oneHot := make([]float64, s.NumClasses) 51 | var inputs, outputs []anyvec.Vector 52 | for i, sample := range samples { 53 | img, err := sample.Image(s.Augment, ImageSize) 54 | if err != nil { 55 | return nil, err 56 | } 57 | class := classes[i] 58 | inVec := append(omniglot.Tensor(img), oneHot...) 
59 | oneHot = make([]float64, s.NumClasses) 60 | oneHot[class] = 1 61 | inputs = append(inputs, c.MakeVectorData(c.MakeNumericList(inVec))) 62 | outputs = append(outputs, c.MakeVectorData(c.MakeNumericList(oneHot))) 63 | } 64 | return &anys2s.Sample{Input: inputs, Output: outputs}, nil 65 | } 66 | 67 | func (s *Samples) episode() (samples []*omniglot.AugSample, classes []int) { 68 | for class, setIdx := range rand.Perm(len(s.Sets))[:s.NumClasses] { 69 | set := s.Sets[setIdx] 70 | for _, x := range set { 71 | samples = append(samples, x) 72 | classes = append(classes, class) 73 | } 74 | } 75 | for i := 0; i < len(samples); i++ { 76 | idx := rand.Intn(len(samples)-i) + i 77 | samples[i], samples[idx] = samples[idx], samples[i] 78 | classes[i], classes[idx] = classes[idx], classes[i] 79 | } 80 | if len(samples) > s.NumTimesteps { 81 | samples = samples[:s.NumTimesteps] 82 | classes = classes[:s.NumTimesteps] 83 | } 84 | return 85 | } 86 | -------------------------------------------------------------------------------- /experiments/omniglot/train.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "log" 7 | "math/rand" 8 | "os" 9 | "time" 10 | 11 | "github.com/unixpickle/anydiff/anyseq" 12 | "github.com/unixpickle/anynet" 13 | "github.com/unixpickle/anynet/anyrnn" 14 | "github.com/unixpickle/anynet/anys2s" 15 | "github.com/unixpickle/anynet/anysgd" 16 | "github.com/unixpickle/anyvec" 17 | "github.com/unixpickle/essentials" 18 | "github.com/unixpickle/omniglot" 19 | "github.com/unixpickle/rip" 20 | "github.com/unixpickle/serializer" 21 | ) 22 | 23 | func Train(args []string) { 24 | var trainingPath string 25 | var testingPath string 26 | var modelPath string 27 | var modelType string 28 | var stepSize float64 29 | var sgdSteps int 30 | var batchSize int 31 | var numClasses int 32 | var episodeLen int 33 | 34 | fs := flag.NewFlagSet("train", flag.ExitOnError) 35 | fs.StringVar(&trainingPath, "training", "", "training data directory") 36 | fs.StringVar(&testingPath, "testing", "", "testing data directory") 37 | fs.StringVar(&modelPath, "out", "model_out", "model output path") 38 | fs.StringVar(&modelType, "model", "sgdstore", 39 | "model type (sgdstore, lstm, parasgdstore, or vanilla)") 40 | fs.Float64Var(&stepSize, "step", 0.001, "SGD step size") 41 | fs.IntVar(&sgdSteps, "steps", 1, "steps per sgdstore") 42 | fs.IntVar(&batchSize, "batch", 16, "SGD batch size") 43 | fs.IntVar(&numClasses, "classes", 5, "classes per episode") 44 | fs.IntVar(&episodeLen, "eplen", 50, "episode length") 45 | 46 | fs.Parse(args) 47 | 48 | rand.Seed(time.Now().UnixNano()) 49 | 50 | if trainingPath == "" || testingPath == "" { 51 | essentials.Die("Required flags: -training and -testing. 
See -help.") 52 | } 53 | 54 | var model anyrnn.Block 55 | if err := serializer.LoadAny(modelPath, &model); err != nil { 56 | log.Println("Creating new model.") 57 | model = NewModel(modelType, sgdSteps, numClasses) 58 | } else { 59 | log.Println("Loaded model.") 60 | } 61 | 62 | training, err := omniglot.ReadSet(trainingPath) 63 | if err != nil { 64 | essentials.Die(err) 65 | } 66 | training = training.Augment() 67 | 68 | testing, err := omniglot.ReadSet(testingPath) 69 | if err != nil { 70 | essentials.Die(err) 71 | } 72 | testing = testing.Augment() 73 | 74 | samples := &Samples{ 75 | Length: batchSize, 76 | Sets: training.ByClass(), 77 | Augment: true, 78 | NumClasses: numClasses, 79 | NumTimesteps: episodeLen, 80 | } 81 | testSamples := *samples 82 | testSamples.Sets = testing.ByClass() 83 | trainer := &anys2s.Trainer{ 84 | Func: func(s anyseq.Seq) anyseq.Seq { 85 | return anyrnn.Map(s, model) 86 | }, 87 | Params: model.(anynet.Parameterizer).Parameters(), 88 | Cost: anynet.DotCost{}, 89 | Average: true, 90 | } 91 | 92 | var iter int 93 | sgd := &anysgd.SGD{ 94 | Gradienter: trainer, 95 | Fetcher: trainer, 96 | Transformer: &anysgd.RMSProp{}, 97 | Rater: anysgd.ConstRater(stepSize), 98 | Samples: samples, 99 | BatchSize: batchSize, 100 | StatusFunc: func(b anysgd.Batch) { 101 | if iter%4 == 0 { 102 | batch, err := trainer.Fetch(&testSamples) 103 | if err != nil { 104 | essentials.Die(err) 105 | } 106 | cost := trainer.TotalCost(batch) 107 | log.Printf("iter %d: cost=%v validation=%f", iter, trainer.LastCost, 108 | anyvec.Sum(cost.Output())) 109 | } else { 110 | log.Printf("iter %d: cost=%v", iter, trainer.LastCost) 111 | } 112 | iter++ 113 | }, 114 | } 115 | 116 | if err := sgd.Run(rip.NewRIP().Chan()); err != nil { 117 | fmt.Fprintln(os.Stderr, err) 118 | } 119 | 120 | if err := serializer.SaveAny(modelPath, model); err != nil { 121 | essentials.Die(err) 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /experiments/poly_approx/README.md: -------------------------------------------------------------------------------- 1 | # Polynomial Approximation 2 | 3 | In this experiment, an RNN is trained on a simple meta-learning task. During each episode of this task, the RNN has to learn to approximate a new, random polynomial. 4 | 5 | Before each episode, the trainer generates a random third-degree polynomial of three variables (e.g. `xyz + 0.5xy^2 + 1 + ...`). At each timestep of the episode, the RNN receives a tuple `(out_{t-1}, x, y, z)` and must predict the polynomial's value at `(x, y, z)`. At timestep `t`, `out_{t-1}` is the value which should have been predicted at timestep `t-1`. This way, the RNN can immediately see its mistakes and learn from them. Each episode is 64 timesteps in total. The RNN's goal is to minimize the mean-squared error of its predictions over the entire episode. 6 | 7 | # Results 8 | 9 | For these experiments, I tried an LSTM and three variants of the same sgdstore model. The LSTM had 51521 parameters, whereas the sgdstore models only had 6578 parameters. Nonetheless, the LSTM was much faster in terms of wall-clock time. 10 | 11 | I tried three different settings for the sgdstore models. The SS-1, SS-2, and SS-3 models run SGD at each timestep for 1, 2, and 3 iterations respectively. I was expecting SS-3 to win out, since the model gets to train its MLP for longer at each timestep. However, SS-3 did not learn as fast as SS-1 in the short-run. 
I stopped SS-2 early because I thought the results of SS-1 and SS-3 would be more meaningful. 12 | 13 | Here's a graph of the MSE (y-axis) versus the iteration (x-axis; batches of 16): 14 | 15 | ![Model comparison graph](plot/plot.png) 16 | 17 | The graph makes it look like the sgdstore model learns faster than the LSTM. However, keep in mind that the LSTM is several times faster in terms of wall-clock time. 18 | -------------------------------------------------------------------------------- /experiments/poly_approx/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "log" 6 | 7 | "github.com/unixpickle/anydiff/anyseq" 8 | "github.com/unixpickle/anynet" 9 | "github.com/unixpickle/anynet/anyrnn" 10 | "github.com/unixpickle/anynet/anys2s" 11 | "github.com/unixpickle/anynet/anysgd" 12 | "github.com/unixpickle/rip" 13 | ) 14 | 15 | const EpisodeLen = 64 16 | 17 | func main() { 18 | var batchSize int 19 | var stepSize float64 20 | var modelName string 21 | var sgdSteps int 22 | 23 | flag.IntVar(&batchSize, "batch", 16, "SGD batch size") 24 | flag.Float64Var(&stepSize, "step", 0.0003, "SGD step size") 25 | flag.StringVar(&modelName, "model", "sgdstore", "RNN type (sgdstore or lstm)") 26 | flag.IntVar(&sgdSteps, "sgdsteps", 2, "SGD steps for sgdstore") 27 | 28 | flag.Parse() 29 | 30 | model := NewModel(modelName, sgdSteps) 31 | 32 | log.Printf("Model has %d parameters.", CountParams(model)) 33 | 34 | samples := SampleList(batchSize) 35 | 36 | trainer := &anys2s.Trainer{ 37 | Func: func(s anyseq.Seq) anyseq.Seq { 38 | return anyrnn.Map(s, model) 39 | }, 40 | Cost: anynet.MSE{}, 41 | Params: model.(anynet.Parameterizer).Parameters(), 42 | Average: true, 43 | } 44 | var iter int 45 | sgd := &anysgd.SGD{ 46 | Fetcher: trainer, 47 | Gradienter: trainer, 48 | Transformer: &anysgd.Adam{}, 49 | Samples: samples, 50 | BatchSize: batchSize, 51 | Rater: anysgd.ConstRater(stepSize), 52 | StatusFunc: func(b anysgd.Batch) { 53 | log.Printf("iter %d: cost=%v", iter, trainer.LastCost) 54 | iter++ 55 | }, 56 | } 57 | sgd.Run(rip.NewRIP().Chan()) 58 | } 59 | -------------------------------------------------------------------------------- /experiments/poly_approx/model.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/unixpickle/anydiff" 5 | "github.com/unixpickle/anynet" 6 | "github.com/unixpickle/anynet/anyrnn" 7 | "github.com/unixpickle/anyvec/anyvec32" 8 | "github.com/unixpickle/essentials" 9 | "github.com/unixpickle/sgdstore" 10 | ) 11 | 12 | func NewModel(name string, sgdSteps int) anyrnn.Block { 13 | c := anyvec32.CurrentCreator() 14 | switch name { 15 | case "sgdstore": 16 | return anyrnn.Stack{ 17 | &anyrnn.Feedback{ 18 | InitOut: anydiff.NewVar(c.MakeVector(32)), 19 | Mixer: anynet.ConcatMixer{}, 20 | Block: anyrnn.Stack{ 21 | &anyrnn.LayerBlock{ 22 | Layer: anynet.Net{ 23 | anynet.NewFC(c, 4+32, 32), 24 | anynet.Tanh, 25 | anynet.NewFC(c, 32, 32), 26 | anynet.Tanh, 27 | }, 28 | }, 29 | sgdstore.LinearBlock(c, 32, 2, 2, sgdSteps, 1, sgdstore.Tanh, 30 | 16, 32, 16), 31 | }, 32 | }, 33 | &anyrnn.LayerBlock{Layer: anynet.NewFC(c, 32, 1)}, 34 | } 35 | case "lstm": 36 | return anyrnn.Stack{ 37 | anyrnn.NewLSTM(c, 4, 64), 38 | anyrnn.NewLSTM(c, 64, 64), 39 | &anyrnn.LayerBlock{Layer: anynet.NewFC(c, 64, 1)}, 40 | } 41 | default: 42 | essentials.Die("unknown model:", name) 43 | panic("unreachable") 44 | } 45 | } 46 | 47 | func CountParams(b 
anyrnn.Block) int { 48 | var res int 49 | for _, p := range anynet.AllParameters(b) { 50 | res += p.Vector.Len() 51 | } 52 | return res 53 | } 54 | -------------------------------------------------------------------------------- /experiments/poly_approx/plot/log_to_matlab.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -ne 1 ]; then 4 | echo 'Usage: log_to_matlab.sh <log-file>' >&2 5 | exit 1 6 | fi 7 | 8 | filename=$(basename "$1") 9 | varname=${filename%.txt} 10 | 11 | yvalues=$(cat "$1" | grep 'cost=[0-9]' | 12 | sed -E 's/^.*cost=([0-9\.]*).*$/ \1/g') 13 | numlines=$(echo "$yvalues" | wc -l | sed -E 's/ //g') 14 | 15 | echo "${varname}_x = [1:$numlines];" 16 | echo "${varname}_y = [" 17 | echo "$yvalues" 18 | echo "];" 19 | -------------------------------------------------------------------------------- /experiments/poly_approx/plot/plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unixpickle/sgdstore/82d72611cde5e1b77d7b289d6e2d8973d513ae45/experiments/poly_approx/plot/plot.png -------------------------------------------------------------------------------- /experiments/poly_approx/plot/plot_file.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo 'Creating plot_file.m ...' 4 | 5 | TO_MATLAB="./log_to_matlab.sh" 6 | $TO_MATLAB ../log/sgdstore_1step.txt >plot_file.m 7 | $TO_MATLAB ../log/sgdstore_2step.txt >>plot_file.m 8 | $TO_MATLAB ../log/sgdstore_3step.txt >>plot_file.m 9 | $TO_MATLAB ../log/lstm.txt >>plot_file.m 10 | echo -n 'plot(' >>plot_file.m 11 | for i in {1..3} 12 | do 13 | echo -n "sgdstore_${i}step_x, smooth_data(sgdstore_${i}step_y), " >>plot_file.m 14 | echo -n "'linewidth', 2, ';SS-${i};', " >>plot_file.m 15 | done 16 | echo "lstm_x(1:10500), smooth_data(lstm_y(1:10500)), 'linewidth', 2, ';LSTM;');" >>plot_file.m 17 | -------------------------------------------------------------------------------- /experiments/poly_approx/plot/smooth_data.m: -------------------------------------------------------------------------------- 1 | function [smoothed] = smooth_data(data) 2 | rolling = mean(data(1:10)); 3 | smoothed = zeros(size(data)); 4 | for i=1:length(data) 5 | rolling = rolling + 0.005*(data(i)-rolling); 6 | smoothed(i) = rolling; 7 | end 8 | end 9 | -------------------------------------------------------------------------------- /experiments/poly_approx/samples.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "math" 5 | "math/rand" 6 | 7 | "github.com/unixpickle/anynet/anys2s" 8 | "github.com/unixpickle/anynet/anysgd" 9 | "github.com/unixpickle/anyvec" 10 | "github.com/unixpickle/anyvec/anyvec32" 11 | ) 12 | 13 | type SampleList int 14 | 15 | func (s SampleList) Len() int { 16 | return int(s) 17 | } 18 | 19 | func (s SampleList) Swap(i, j int) { 20 | } 21 | 22 | func (s SampleList) Slice(i, j int) anysgd.SampleList { 23 | return SampleList(j - i) 24 | } 25 | 26 | func (s SampleList) Creator() anyvec.Creator { 27 | return anyvec32.CurrentCreator() 28 | } 29 | 30 | func (s SampleList) GetSample(i int) (*anys2s.Sample, error) { 31 | poly := RandomPoly() 32 | 33 | var sample anys2s.Sample 34 | var lastValue float32 35 | 36 | for i := 0; i < EpisodeLen; i++ { 37 | inVec := make([]float32, 4) 38 | for j := range inVec[1:] { 39 | inVec[j+1] = float32(rand.NormFloat64()) 40 | } 41 | inVec[0] = lastValue 42 | lastValue = 
poly.Eval(inVec[1:]) 43 | outVec := []float32{lastValue} 44 | sample.Input = append(sample.Input, anyvec32.MakeVectorData(inVec)) 45 | sample.Output = append(sample.Output, anyvec32.MakeVectorData(outVec)) 46 | } 47 | 48 | return &sample, nil 49 | } 50 | 51 | type Term struct { 52 | X int 53 | Y int 54 | Z int 55 | 56 | Coeff float32 57 | } 58 | 59 | func (t Term) Eval(coord []float32) float32 { 60 | res := t.Coeff 61 | for i, pow := range []int{t.X, t.Y, t.Z} { 62 | res *= float32(math.Pow(float64(coord[i]), float64(pow))) 63 | } 64 | return res 65 | } 66 | 67 | type Poly []Term 68 | 69 | func RandomPoly() Poly { 70 | var poly Poly 71 | for x := 0; x <= 3; x++ { 72 | for y := 0; y <= 3-x; y++ { 73 | for z := 0; z <= 3-(x+y); z++ { 74 | poly = append(poly, Term{ 75 | X: x, 76 | Y: y, 77 | Z: z, 78 | Coeff: float32(rand.NormFloat64()) / 5, 79 | }) 80 | } 81 | } 82 | } 83 | return poly 84 | } 85 | 86 | func (p Poly) Eval(coord []float32) float32 { 87 | var sum float32 88 | for _, t := range p { 89 | sum += t.Eval(coord) 90 | } 91 | return sum 92 | } 93 | -------------------------------------------------------------------------------- /net.go: -------------------------------------------------------------------------------- 1 | package sgdstore 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/unixpickle/anydiff" 7 | ) 8 | 9 | // Net is a batch of dynamic feed-forward multi-layer 10 | // perceptrons. 11 | // 12 | // Each layer is implicitly followed by an activation. 13 | type Net struct { 14 | // Parameters stores the weights and biases of the 15 | // network. 16 | // Each even index corresponds to a batch of weight 17 | // matrices. 18 | // Each odd index corresponds to a batch of bias vectors. 19 | // Matrices are row-major. 20 | // 21 | // This should not be empty. 22 | Parameters anydiff.MultiRes 23 | 24 | // Num is the number of networks in the batch. 25 | Num int 26 | 27 | // Activation is the activation function. 28 | Activation Activation 29 | } 30 | 31 | // Apply applies the networks to a batch of input batches, 32 | // producing a batch of output batches. 33 | func (n *Net) Apply(inBatch anydiff.Res, batchSize int) anydiff.Res { 34 | return anydiff.Unfuse(n.Parameters, func(params []anydiff.Res) anydiff.Res { 35 | if len(params)%2 != 0 { 36 | panic("mismatching bias and weight count") 37 | } 38 | for i := 0; i < len(params); i += 2 { 39 | inBatch = n.applyLayer(params[i], params[i+1], inBatch, batchSize, n.Num) 40 | } 41 | return inBatch 42 | }) 43 | } 44 | 45 | // InSize calculates the input size of the network using 46 | // the dimensions of the first layer. 47 | // 48 | // This is invariant to n.Num. 49 | func (n *Net) InSize() int { 50 | if len(n.Parameters.Outputs()) < 2 { 51 | panic("network cannot be empty") 52 | } 53 | return n.Parameters.Outputs()[0].Len() / n.Parameters.Outputs()[1].Len() 54 | } 55 | 56 | // Train performs SGD training on the batch. 57 | // 58 | // The input, target, and stepSize needn't be pooled by 59 | // the caller. 
60 | func (n *Net) Train(inBatch, target, stepSize anydiff.Res, batchSize, 61 | numSteps int) *Net { 62 | if stepSize.Output().Len() != n.Num { 63 | panic("invalid stepSize length") 64 | } 65 | ins := anydiff.Fuse(inBatch, target, stepSize) 66 | newParams := anydiff.PoolMulti(ins, func(s []anydiff.Res) anydiff.MultiRes { 67 | inBatch, target, stepSize := s[0], s[1], s[2] 68 | net := n 69 | for i := 0; i < numSteps; i++ { 70 | net = net.step(inBatch, target, stepSize, batchSize) 71 | } 72 | return net.Parameters 73 | }) 74 | return &Net{Parameters: newParams, Num: n.Num, Activation: n.Activation} 75 | } 76 | 77 | // step performs a step of gradient descent and returns 78 | // the new network. 79 | // 80 | // The input, target, and stepSize should be pooled by the 81 | // caller. 82 | func (n *Net) step(inBatch, target, stepSize anydiff.Res, batchSize int) *Net { 83 | newParams := anydiff.PoolMulti(n.Parameters, func(params []anydiff.Res) anydiff.MultiRes { 84 | grad := n.applyBackprop(params, inBatch, target, batchSize, n.Num) 85 | return anydiff.PoolMulti(grad, func(grads []anydiff.Res) anydiff.MultiRes { 86 | var newParams []anydiff.Res 87 | for i, g := range grads[1:] { 88 | gMat := &anydiff.Matrix{ 89 | Data: g, 90 | Rows: n.Num, 91 | Cols: g.Output().Len() / n.Num, 92 | } 93 | p := anydiff.Add(params[i], anydiff.ScaleRows(gMat, stepSize).Data) 94 | newParams = append(newParams, p) 95 | } 96 | return anydiff.Fuse(newParams...) 97 | }) 98 | }) 99 | return &Net{Parameters: newParams, Num: n.Num, Activation: n.Activation} 100 | } 101 | 102 | // applyLayer applies a single layer. 103 | func (n *Net) applyLayer(weights, biases, inBatch anydiff.Res, batchSize, 104 | numNets int) anydiff.Res { 105 | inMat, weightMat := layerMats(weights, biases, inBatch, batchSize, numNets) 106 | inBatch = anydiff.BatchedMatMul(false, true, inMat, weightMat).Data 107 | return n.Activation.Forward(batchedAddRepeated(inBatch, biases, numNets)) 108 | } 109 | 110 | // applyBackprop applies the networks and performs 111 | // backward-propagation. 112 | // The result is [inGrad, param1Grad, param2Grad, ...]. 113 | // The caller should pool the input parameters. 
114 | func (n *Net) applyBackprop(params []anydiff.Res, in, target anydiff.Res, 115 | batchSize, numNets int) anydiff.MultiRes { 116 | if len(params) == 0 { 117 | scaler := target.Output().Creator().MakeNumeric( 118 | 2 / float64(target.Output().Len()/numNets), 119 | ) 120 | if target.Output().Len() != in.Output().Len() { 121 | panic(fmt.Sprintf("target length %d (expected %d)", target.Output().Len(), 122 | in.Output().Len())) 123 | } 124 | return anydiff.Fuse(anydiff.Scale(anydiff.Sub(target, in), scaler)) 125 | } 126 | inMat, weightMat := layerMats(params[0], params[1], in, batchSize, numNets) 127 | matOut := anydiff.BatchedMatMul(false, true, inMat, weightMat).Data 128 | biasOut := batchedAddRepeated(matOut, params[1], numNets) 129 | actOut := n.Activation.Forward(biasOut) 130 | return anydiff.PoolFork(actOut, func(actOut anydiff.Res) anydiff.MultiRes { 131 | nextOut := n.applyBackprop(params[2:], actOut, target, batchSize, numNets) 132 | return anydiff.PoolMulti(nextOut, func(x []anydiff.Res) anydiff.MultiRes { 133 | outGrad := x[0] 134 | laterGrads := x[1:] 135 | pg := n.Activation.Backward(actOut, outGrad) 136 | return anydiff.PoolFork(pg, func(pg anydiff.Res) anydiff.MultiRes { 137 | productGrad := &anydiff.MatrixBatch{ 138 | Data: pg, 139 | Rows: batchSize, 140 | Cols: weightMat.Rows, 141 | Num: numNets, 142 | } 143 | weightGrad := anydiff.BatchedMatMul(true, false, productGrad, inMat).Data 144 | biasGrad := batchedSumRows(productGrad) 145 | inGrad := anydiff.BatchedMatMul(false, false, productGrad, weightMat).Data 146 | ourGrad := []anydiff.Res{inGrad, weightGrad, biasGrad} 147 | return anydiff.Fuse(append(ourGrad, laterGrads...)...) 148 | }) 149 | }) 150 | }) 151 | } 152 | 153 | func layerMats(weights, biases, inBatch anydiff.Res, batchSize, numNets int) (inMat, 154 | weightMat *anydiff.MatrixBatch) { 155 | outSize := biases.Output().Len() / numNets 156 | inSize := weights.Output().Len() / (outSize * numNets) 157 | if inSize*batchSize*numNets != inBatch.Output().Len() { 158 | panic(fmt.Sprintf("input size %d should be %d", 159 | inBatch.Output().Len()/(batchSize*numNets), inSize)) 160 | } 161 | inMat = &anydiff.MatrixBatch{ 162 | Data: inBatch, 163 | Rows: batchSize, 164 | Cols: inSize, 165 | Num: numNets, 166 | } 167 | weightMat = &anydiff.MatrixBatch{ 168 | Data: weights, 169 | Rows: outSize, 170 | Cols: inSize, 171 | Num: numNets, 172 | } 173 | return 174 | } 175 | 176 | func batchedAddRepeated(vec, biases anydiff.Res, n int) anydiff.Res { 177 | return anydiff.Pool(vec, func(vec anydiff.Res) anydiff.Res { 178 | return anydiff.Pool(biases, func(biases anydiff.Res) anydiff.Res { 179 | biasVecs := splitVec(biases, n) 180 | var res []anydiff.Res 181 | for i, v := range splitVec(vec, n) { 182 | b := biasVecs[i] 183 | res = append(res, anydiff.AddRepeated(v, b)) 184 | } 185 | return anydiff.Concat(res...) 186 | }) 187 | }) 188 | } 189 | 190 | func batchedSumRows(m *anydiff.MatrixBatch) anydiff.Res { 191 | return anydiff.Pool(m.Data, func(data anydiff.Res) anydiff.Res { 192 | var sums []anydiff.Res 193 | for _, matData := range splitVec(data, m.Num) { 194 | matrix := &anydiff.Matrix{Data: matData, Rows: m.Rows, Cols: m.Cols} 195 | sums = append(sums, anydiff.SumRows(matrix)) 196 | } 197 | return anydiff.Concat(sums...) 
198 | }) 199 | } 200 | 201 | func splitVec(vec anydiff.Res, n int) []anydiff.Res { 202 | chunkSize := vec.Output().Len() / n 203 | var chunks []anydiff.Res 204 | for i := 0; i < n; i++ { 205 | chunks = append(chunks, anydiff.Slice(vec, i*chunkSize, (i+1)*chunkSize)) 206 | } 207 | return chunks 208 | } 209 | -------------------------------------------------------------------------------- /net_test.go: -------------------------------------------------------------------------------- 1 | package sgdstore 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/unixpickle/anydiff" 7 | "github.com/unixpickle/anynet" 8 | "github.com/unixpickle/anyvec" 9 | "github.com/unixpickle/anyvec/anyvec32" 10 | "github.com/unixpickle/anyvec/anyvec64" 11 | ) 12 | 13 | func TestNetApply(t *testing.T) { 14 | c := anyvec64.CurrentCreator() 15 | realNet, virtualNet := randomNetwork(c) 16 | 17 | inVec := c.MakeVector(12) 18 | anyvec.Rand(inVec, anyvec.Normal, nil) 19 | input := anydiff.NewVar(inVec) 20 | 21 | expectedOut := realNet.Apply(input, 4).Output() 22 | actualOut := virtualNet.Apply(input, 4).Output() 23 | 24 | diff := actualOut.Copy() 25 | diff.Sub(expectedOut) 26 | maxDiff := anyvec.AbsMax(diff).(float64) 27 | 28 | if maxDiff > 1e-4 { 29 | t.Errorf("expected %v but got %v", expectedOut.Data(), actualOut.Data()) 30 | } 31 | } 32 | 33 | func TestNetTrain(t *testing.T) { 34 | c := anyvec64.CurrentCreator() 35 | realNet, virtualNet := randomNetwork(c) 36 | 37 | inVec := c.MakeVector(12) 38 | anyvec.Rand(inVec, anyvec.Normal, nil) 39 | input := anydiff.NewVar(inVec) 40 | 41 | targetVec := c.MakeVector(8) 42 | anyvec.Rand(targetVec, anyvec.Normal, nil) 43 | target := anydiff.NewVar(targetVec) 44 | 45 | stepSize := c.MakeVector(1) 46 | stepSize.AddScalar(c.MakeNumeric(0.1)) 47 | 48 | trained := virtualNet.Train(input, target, anydiff.NewConst(stepSize), 4, 2) 49 | actual := trained.Parameters.Outputs() 50 | 51 | for i := 0; i < 2; i++ { 52 | out := realNet.Apply(input, 4) 53 | cost := anynet.MSE{}.Cost(target, out, 1) 54 | grad := anydiff.NewGrad(realNet.Parameters()...) 
55 | cost.Propagate(stepSize.Copy(), grad) 56 | grad.Scale(c.MakeNumeric(-1)) 57 | grad.AddToVars() 58 | } 59 | for i, a := range actual { 60 | x := realNet.Parameters()[i].Vector 61 | diff := x.Copy() 62 | diff.Sub(a) 63 | maxDiff := anyvec.AbsMax(diff).(float64) 64 | if maxDiff > 1e-4 { 65 | t.Error("bad value for layer", i) 66 | } 67 | } 68 | } 69 | 70 | func TestNetBatched(t *testing.T) { 71 | c := anyvec64.CurrentCreator() 72 | _, net1 := randomNetwork(c) 73 | _, net2 := randomNetwork(c) 74 | joined := joinNets(net1, net2) 75 | 76 | inBatch := anydiff.NewVar(c.MakeVector(3 * 5 * 2)) 77 | target := anydiff.NewVar(c.MakeVector(2 * 5 * 2)) 78 | 79 | anyvec.Rand(inBatch.Vector, anyvec.Normal, nil) 80 | 81 | t.Run("Apply", func(t *testing.T) { 82 | out1 := net1.Apply(anydiff.Slice(inBatch, 0, 3*5), 5).Output() 83 | out2 := net2.Apply(anydiff.Slice(inBatch, 3*5, 3*5*2), 5).Output() 84 | actual := joined.Apply(inBatch, 5).Output() 85 | expected := c.Concat(out1, out2) 86 | if actual.Len() != expected.Len() { 87 | t.Error("length mismatch") 88 | return 89 | } 90 | diff := expected.Copy() 91 | diff.Sub(actual) 92 | maxDiff := anyvec.AbsMax(diff) 93 | if maxDiff.(float64) > 1e-4 { 94 | t.Errorf("bad output: expected %v but got %v", expected, actual) 95 | } 96 | }) 97 | 98 | t.Run("Train", func(t *testing.T) { 99 | stepSize := anydiff.NewVar(c.MakeVectorData([]float64{0.1, 0.2})) 100 | trained1 := net1.Train(anydiff.Slice(inBatch, 0, 3*5), 101 | anydiff.Slice(target, 0, 2*5), anydiff.Slice(stepSize, 0, 1), 5, 2) 102 | trained2 := net2.Train(anydiff.Slice(inBatch, 3*5, 3*5*2), 103 | anydiff.Slice(target, 2*5, 2*5*2), anydiff.Slice(stepSize, 1, 2), 5, 2) 104 | actual := joined.Train(inBatch, target, stepSize, 5, 2) 105 | expected := joinNets(trained1, trained2) 106 | for i, xParam := range expected.Parameters.Outputs() { 107 | aParam := actual.Parameters.Outputs()[i] 108 | diff := xParam.Copy() 109 | diff.Sub(aParam) 110 | maxDiff := anyvec.AbsMax(diff) 111 | if maxDiff.(float64) > 1e-4 { 112 | t.Errorf("bad training result: expected %v but got %v", xParam, aParam) 113 | } 114 | } 115 | }) 116 | } 117 | 118 | func BenchmarkNetwork(b *testing.B) { 119 | c := anyvec32.CurrentCreator() 120 | realNet := anynet.Net{ 121 | anynet.NewFC(c, 128, 256), 122 | anynet.NewFC(c, 256, 128), 123 | } 124 | var netParams []anydiff.Res 125 | for i, param := range realNet.Parameters() { 126 | if i%2 == 0 { 127 | anyvec.Rand(param.Vector, anyvec.Normal, nil) 128 | } 129 | netParams = append(netParams, param) 130 | } 131 | net := &Net{Parameters: anydiff.Fuse(netParams...), Num: 1} 132 | 133 | inBatch := anydiff.NewVar(c.MakeVector(512)) 134 | target := anydiff.NewVar(c.MakeVector(512)) 135 | anyvec.Rand(inBatch.Vector, anyvec.Normal, nil) 136 | anyvec.Rand(target.Vector, anyvec.Normal, nil) 137 | 138 | stepSize := anydiff.NewConst(c.MakeVector(1)) 139 | stepSize.Vector.AddScalar(float32(0.1)) 140 | 141 | b.Run("Forward", func(b *testing.B) { 142 | for i := 0; i < b.N; i++ { 143 | net.Train(inBatch, target, stepSize, 4, 1) 144 | } 145 | }) 146 | b.Run("Backward", func(b *testing.B) { 147 | grad := anydiff.NewGrad(append([]*anydiff.Var{inBatch, target}, 148 | realNet.Parameters()...)...) 
149 | upstream := make([]anyvec.Vector, len(netParams)) 150 | for i, p := range netParams { 151 | upstream[i] = c.MakeVector(p.Output().Len()) 152 | anyvec.Rand(upstream[i], anyvec.Normal, nil) 153 | } 154 | b.ResetTimer() 155 | for i := 0; i < b.N; i++ { 156 | out := net.Train(inBatch, target, stepSize, 4, 1) 157 | out.Parameters.Propagate(upstream, grad) 158 | } 159 | }) 160 | } 161 | 162 | func randomNetwork(c anyvec.Creator) (anynet.Net, *Net) { 163 | realNet := anynet.Net{ 164 | anynet.NewFC(c, 3, 5), 165 | anynet.Tanh, 166 | anynet.NewFC(c, 5, 4), 167 | anynet.Tanh, 168 | anynet.NewFC(c, 4, 2), 169 | anynet.Tanh, 170 | } 171 | var netParams []anydiff.Res 172 | for i, param := range realNet.Parameters() { 173 | if i%2 == 0 { 174 | anyvec.Rand(param.Vector, anyvec.Normal, nil) 175 | } 176 | netParams = append(netParams, param) 177 | } 178 | return realNet, &Net{Parameters: anydiff.Fuse(netParams...), Num: 1} 179 | } 180 | 181 | func joinNets(n1, n2 *Net) *Net { 182 | return &Net{ 183 | Num: 2, 184 | Parameters: anydiff.PoolMulti(n1.Parameters, func(p1 []anydiff.Res) anydiff.MultiRes { 185 | return anydiff.PoolMulti(n2.Parameters, func(p2 []anydiff.Res) anydiff.MultiRes { 186 | var reses []anydiff.Res 187 | for i, x := range p1 { 188 | reses = append(reses, anydiff.Concat(x, p2[i])) 189 | } 190 | return anydiff.Fuse(reses...) 191 | }) 192 | }), 193 | } 194 | } 195 | --------------------------------------------------------------------------------
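For anyone tracing through `net.go` above, here's a minimal, dependency-free sketch of what one inner SGD step boils down to, assuming a single network in the batch (`Num = 1`), one tanh layer, scalar input and output, and batch size 1. The variable names (`w`, `b`, `x`, `target`, `eta`) are illustrative and not part of the package:

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	w, b := 0.5, 0.1      // weight matrix and bias, reduced to scalars
	x, target := 1.0, 0.9 // one training pair, as fed to Net.Train
	eta := 0.1            // per-network step size

	// Forward pass: applyLayer computes act(in * W^T + b).
	y := math.Tanh(w*x + b)

	// applyBackprop seeds backpropagation with 2*(target-y)/outLen, the
	// *negative* of the MSE gradient, which is why step adds the result
	// to the parameters instead of subtracting it. Here outLen is 1.
	outGrad := 2 * (target - y) / 1

	// Activation.Backward for Tanh: (1 - out^2) * upstream.
	pg := (1 - y*y) * outGrad

	w += eta * pg * x // weight update (BatchedMatMul of pg and the input)
	b += eta * pg     // bias update (batchedSumRows of pg)

	fmt.Printf("after one step: w=%.4f b=%.4f\n", w, b)
}
```

`Net.Train` repeats this update `numSteps` times per timestep; the batched code generalizes the same arithmetic to `Num` networks, multiple layers, and multi-row input batches.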