├── README.md ├── activation.go ├── affine.go ├── anyconv ├── batch_norm.go ├── batch_norm_test.go ├── conv.go ├── conv_test.go ├── conver.go ├── conver_test.go ├── im2row.go ├── image.go ├── image_test.go ├── markup.go ├── max_pool.go ├── max_pool_test.go ├── mean_pool.go ├── mean_pool_test.go ├── padding.go ├── padding_test.go ├── post_train.go ├── post_train_test.go ├── residual.go ├── residual_test.go ├── resize.go ├── resize_test.go └── util.go ├── anyctc ├── conversion.go ├── cost.go ├── cost_test.go ├── decode.go ├── decode_test.go ├── doc.go ├── log_likelihood.go ├── log_likelihood_test.go ├── samples.go └── trainer.go ├── anyff ├── doc.go ├── samples.go └── trainer.go ├── anymisc ├── doc.go ├── gumbel_softmax.go ├── gumbel_softmax_test.go ├── relu_rnn.go ├── relu_rnn_test.go ├── selu.go └── selu_test.go ├── anynet.go ├── anyrnn ├── anyrnn.go ├── bidir.go ├── feedback.go ├── feedback_test.go ├── func.go ├── func_test.go ├── layer.go ├── layer_test.go ├── lstm.go ├── lstm_test.go ├── map.go ├── markov.go ├── markov_test.go ├── markup.go ├── parallel.go ├── parallel_test.go ├── serializer_test.go ├── stack.go ├── stack_test.go ├── vanilla.go ├── vanilla_test.go └── vec_state.go ├── anys2s ├── doc.go ├── samples.go └── trainer.go ├── anys2v ├── doc.go ├── samples.go └── trainer.go ├── anysgd ├── adam.go ├── adam_test.go ├── anysgd.go ├── anysgd_test.go ├── hash_split.go ├── interfaces.go ├── marshal.go ├── marshal_test.go ├── momentum.go ├── rmsprop.go └── util.go ├── cost.go ├── cost_test.go ├── debug.go ├── demo └── mnist │ └── main.go ├── dropout.go ├── fc.go ├── mixer.go ├── param_hider.go └── serializer_test.go /README.md: -------------------------------------------------------------------------------- 1 | # anynet [![GoDoc](https://godoc.org/github.com/unixpickle/anynet?status.svg)](https://godoc.org/github.com/unixpickle/anynet) 2 | 3 | **anynet** is a [neural network](https://en.wikipedia.org/wiki/Artificial_neural_network) framework based on 
[anydiff](https://github.com/unixpickle/anydiff) and [anyvec](https://github.com/unixpickle/anyvec). 4 | 5 | # Supported features 6 | 7 | *anynet* ships with a ton of built-in features: 8 | 9 | * Feed-forward neural networks 10 | * Fully-connected layers 11 | * Convolution 12 | * Dropout 13 | * Max/Mean pooling 14 | * Batch normalization 15 | * Residual connections 16 | * Image scaling 17 | * Image padding 18 | * Recurrent neural networks 19 | * LSTM 20 | * Bidirectional RNNs 21 | * npRNN and IRNN (vanilla RNNs with ReLU activations) 22 | * Training setups 23 | * Vector-to-vector (standard feed-forward) 24 | * Sequence-to-sequence (standard RNN) 25 | * Sequence-to-vector 26 | * Connectionist Temporal Classification 27 | * Miscellaneous 28 | * Gumbel Softmax 29 | 30 | Plenty of stuff is missing from the above list. Luckily, it's easy to write new APIs on top of *anynet*. Here is a non-exhaustive list of packages that work with *anynet*: 31 | 32 | * [unixpickle/anyrl](https://github.com/unixpickle/anyrl) - deep reinforcement learning 33 | * [unixpickle/lazyseq](https://github.com/unixpickle/lazyseq) - memory-efficient RNNs 34 | * [unixpickle/attention](https://github.com/unixpickle/attention) - attention mechanisms 35 | * [unixpickle/rwa](https://github.com/unixpickle/rwa) - a new attention-based RNN architecture 36 | 37 | # TODO 38 | 39 | Here are some minor things I'd like to get done at some point. None of these are very urgent, as *anynet* is already complete for the most part. 
// An Activation identifies one of the built-in activation
// functions.
type Activation int

// These are the standard activation functions.
//
// The numeric values double as the one-byte IDs written by
// Serialize, so the order must stay stable.
const (
	Tanh Activation = iota
	LogSoftmax
	Sigmoid
	ReLU
	Sin
	Exp
)

// DeserializeActivation decodes an Activation from its
// one-byte serialized form.
//
// It fails if d is not exactly one byte long, or if the
// byte is larger than the ID of the last known activation
// (Exp).
func DeserializeActivation(d []byte) (Activation, error) {
	if size := len(d); size != 1 {
		return 0, fmt.Errorf("deserialize Activation: data length (%d) should be 1", size)
	}
	id := Activation(d[0])
	if id <= Exp {
		return id, nil
	}
	return 0, fmt.Errorf("deserialize Activation: unknown activation ID: %d", id)
}
41 | func (a Activation) Apply(in anydiff.Res, n int) anydiff.Res { 42 | switch a { 43 | case Tanh: 44 | return anydiff.Tanh(in) 45 | case LogSoftmax: 46 | inLen := in.Output().Len() 47 | if inLen%n != 0 { 48 | panic("batch size must divide input length") 49 | } 50 | return anydiff.LogSoftmax(in, inLen/n) 51 | case Sigmoid: 52 | return anydiff.Sigmoid(in) 53 | case ReLU: 54 | return anydiff.ClipPos(in) 55 | case Sin: 56 | return anydiff.Sin(in) 57 | case Exp: 58 | return anydiff.Exp(in) 59 | default: 60 | panic(fmt.Sprintf("unknown activation: %d", a)) 61 | } 62 | } 63 | 64 | // SerializerType returns the unique ID used to serialize 65 | // an Activation. 66 | func (a Activation) SerializerType() string { 67 | return "github.com/unixpickle/anynet.Activation" 68 | } 69 | 70 | // Serialize serializes the activation. 71 | func (a Activation) Serialize() ([]byte, error) { 72 | return []byte{byte(a)}, nil 73 | } 74 | -------------------------------------------------------------------------------- /affine.go: -------------------------------------------------------------------------------- 1 | package anynet 2 | 3 | import ( 4 | "github.com/unixpickle/anydiff" 5 | "github.com/unixpickle/anyvec" 6 | "github.com/unixpickle/anyvec/anyvecsave" 7 | "github.com/unixpickle/essentials" 8 | "github.com/unixpickle/serializer" 9 | ) 10 | 11 | func init() { 12 | var a Affine 13 | serializer.RegisterTypedDeserializer(a.SerializerType(), DeserializeAffine) 14 | var c ConstAffine 15 | serializer.RegisterTypedDeserializer(c.SerializerType(), DeserializeConstAffine) 16 | } 17 | 18 | // Affine is a layer which performs component-wise affine 19 | // transformations. 20 | // 21 | // In other words, for every component x[i], it computes 22 | // 23 | // a[i%len(a)]*x[i] + b[i%len(b)] 24 | // 25 | // The scaler vector a and bias vector b are learnable 26 | // variables. 
27 | type Affine struct { 28 | Scalers *anydiff.Var 29 | Biases *anydiff.Var 30 | } 31 | 32 | // NewAffine creates an Affine layer with one scaler and 33 | // bias. 34 | func NewAffine(c anyvec.Creator, scaler, bias float64) *Affine { 35 | scalerVec := c.MakeVector(1) 36 | biasVec := c.MakeVector(1) 37 | scalerVec.AddScalar(c.MakeNumeric(scaler)) 38 | biasVec.AddScalar(c.MakeNumeric(bias)) 39 | return &Affine{ 40 | Scalers: anydiff.NewVar(scalerVec), 41 | Biases: anydiff.NewVar(biasVec), 42 | } 43 | } 44 | 45 | // DeserializeAffine deserializes an Affine layer. 46 | func DeserializeAffine(d []byte) (*Affine, error) { 47 | var s, b *anyvecsave.S 48 | if err := serializer.DeserializeAny(d, &s, &b); err != nil { 49 | return nil, essentials.AddCtx("deserialize Affine", err) 50 | } 51 | return &Affine{ 52 | Scalers: anydiff.NewVar(s.Vector), 53 | Biases: anydiff.NewVar(b.Vector), 54 | }, nil 55 | } 56 | 57 | // Apply applies the layer to a batch of inputs. 58 | // The size of each vector in the batch must be divisible 59 | // by the number of scalers and biases. 60 | func (a *Affine) Apply(in anydiff.Res, n int) anydiff.Res { 61 | if in.Output().Len()%n != 0 { 62 | panic("input size not divisible by batch size") 63 | } 64 | inLen := in.Output().Len() / n 65 | if inLen%a.Scalers.Vector.Len() != 0 || inLen%a.Biases.Vector.Len() != 0 { 66 | panic("scaler and bias count must divide input count") 67 | } 68 | return anydiff.ScaleAddRepeated(in, a.Scalers, a.Biases) 69 | } 70 | 71 | // Parameters returns a slice containing the scalers 72 | // followed by the biases. 73 | func (a *Affine) Parameters() []*anydiff.Var { 74 | return []*anydiff.Var{a.Scalers, a.Biases} 75 | } 76 | 77 | // SerializerType returns the unique ID used to serialize 78 | // an Affine with the serializer package. 79 | func (a *Affine) SerializerType() string { 80 | return "github.com/unixpickle/anynet.Affine" 81 | } 82 | 83 | // Serialize serializes the layer. 
84 | func (a *Affine) Serialize() ([]byte, error) { 85 | return serializer.SerializeAny( 86 | &anyvecsave.S{Vector: a.Scalers.Vector}, 87 | &anyvecsave.S{Vector: a.Biases.Vector}, 88 | ) 89 | } 90 | 91 | // ConstAffine is a layer which performs component-wise 92 | // affine transformations with a constant bias and scaler. 93 | // 94 | // In other words, each component x is transformed via: 95 | // 96 | // a*x + b 97 | // 98 | type ConstAffine struct { 99 | Scale float64 100 | Bias float64 101 | } 102 | 103 | // DeserializeConstAffine deserializes a ConstAffine. 104 | func DeserializeConstAffine(d []byte) (*ConstAffine, error) { 105 | var res ConstAffine 106 | if err := serializer.DeserializeAny(d, &res.Scale, &res.Bias); err != nil { 107 | return nil, essentials.AddCtx("deserialize ConstAffine", err) 108 | } 109 | return &res, nil 110 | } 111 | 112 | // Apply applies the affine transformation. 113 | func (c *ConstAffine) Apply(in anydiff.Res, n int) anydiff.Res { 114 | cr := in.Output().Creator() 115 | return anydiff.AddScalar( 116 | anydiff.Scale(in, cr.MakeNumeric(c.Scale)), 117 | cr.MakeNumeric(c.Bias), 118 | ) 119 | } 120 | 121 | // SerializerType returns the unique ID used to serialize 122 | // a ConstAffine with the serializer package. 123 | func (c *ConstAffine) SerializerType() string { 124 | return "github.com/unixpickle/anynet.ConstAffine" 125 | } 126 | 127 | // Serialize serializes a ConstAffine. 
// BatchNorm is a batch normalization layer.
//
// After a network has finished training, BatchNorm layers
// will typically be replaced with anynet.Affine layers so
// that the normalizations are less noisy.
type BatchNorm struct {
	// InputCount indicates how many components to normalize.
	//
	// For use after a fully-connected layer, this should be
	// the total number of output neurons.
	// For use after a convolutional layer, this should be
	// the number of filters.
	//
	// The length of every input passed to Apply must be a
	// multiple of InputCount.
	InputCount int

	// Post-normalization affine transform.
	//
	// NewBatchNorm creates both with InputCount components,
	// with the scalers set to 1 and the biases set to 0.
	Scalers *anydiff.Var
	Biases  *anydiff.Var

	// Stabilizer prevents numerical instability by adding a
	// small constant to variances to keep them from being 0.
	//
	// If it is 0, a default (defaultBNStabilizer) is used.
	Stabilizer float64
}
44 | func DeserializeBatchNorm(d []byte) (*BatchNorm, error) { 45 | var s, b *anyvecsave.S 46 | var stab serializer.Float64 47 | if err := serializer.DeserializeAny(d, &s, &b, &stab); err != nil { 48 | return nil, essentials.AddCtx("deserialize BatchNorm", err) 49 | } 50 | return &BatchNorm{ 51 | InputCount: s.Vector.Len(), 52 | Scalers: anydiff.NewVar(s.Vector), 53 | Biases: anydiff.NewVar(b.Vector), 54 | Stabilizer: float64(stab), 55 | }, nil 56 | } 57 | 58 | // NewBatchNorm creates a BatchNorm with an input size. 59 | func NewBatchNorm(c anyvec.Creator, inCount int) *BatchNorm { 60 | oneScaler := c.MakeVector(inCount) 61 | oneScaler.AddScalar(c.MakeNumeric(1)) 62 | return &BatchNorm{ 63 | InputCount: inCount, 64 | Scalers: anydiff.NewVar(oneScaler), 65 | Biases: anydiff.NewVar(c.MakeVector(inCount)), 66 | } 67 | } 68 | 69 | // Apply applies the layer to some inputs. 70 | func (b *BatchNorm) Apply(in anydiff.Res, batch int) anydiff.Res { 71 | if in.Output().Len()%b.InputCount != 0 { 72 | panic("invalid input size") 73 | } 74 | return anydiff.Pool(in, func(in anydiff.Res) anydiff.Res { 75 | c := in.Output().Creator() 76 | 77 | negMean := negMeanRows(in, b.InputCount) 78 | secondMoment := meanSquare(in, b.InputCount) 79 | variance := anydiff.Sub(secondMoment, anydiff.Square(negMean)) 80 | 81 | variance = anydiff.AddScalar(variance, c.MakeNumeric(b.stabilizer())) 82 | normalizer := anydiff.Pow(variance, c.MakeNumeric(-0.5)) 83 | 84 | totalScaler := anydiff.Mul(b.Scalers, normalizer) 85 | return anydiff.Pool(totalScaler, func(totalScaler anydiff.Res) anydiff.Res { 86 | return anydiff.ScaleAddRepeated( 87 | in, 88 | totalScaler, 89 | anydiff.Add(b.Biases, anydiff.Mul(negMean, totalScaler)), 90 | ) 91 | }) 92 | }) 93 | } 94 | 95 | // Parameters returns a slice containing the scales and 96 | // biases, in that order. 
97 | func (b *BatchNorm) Parameters() []*anydiff.Var { 98 | return []*anydiff.Var{b.Scalers, b.Biases} 99 | } 100 | 101 | // SerializerType returns the unique ID used to serialize 102 | // a BatchNorm with the serializer package. 103 | func (b *BatchNorm) SerializerType() string { 104 | return "github.com/unixpickle/anynet/anyconv.BatchNorm" 105 | } 106 | 107 | // Serialize serializes the layer. 108 | func (b *BatchNorm) Serialize() ([]byte, error) { 109 | return serializer.SerializeAny( 110 | &anyvecsave.S{Vector: b.Scalers.Vector}, 111 | &anyvecsave.S{Vector: b.Biases.Vector}, 112 | serializer.Float64(b.Stabilizer), 113 | ) 114 | } 115 | 116 | func (b *BatchNorm) stabilizer() float64 { 117 | if b.Stabilizer == 0 { 118 | return defaultBNStabilizer 119 | } else { 120 | return b.Stabilizer 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /anyconv/batch_norm_test.go: -------------------------------------------------------------------------------- 1 | package anyconv 2 | 3 | import ( 4 | "math" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/unixpickle/anydiff" 9 | "github.com/unixpickle/anydiff/anydifftest" 10 | "github.com/unixpickle/anyvec" 11 | "github.com/unixpickle/anyvec/anyvec32" 12 | "github.com/unixpickle/serializer" 13 | ) 14 | 15 | func TestBatchNormSerialize(t *testing.T) { 16 | layer := randomizedBatchNorm(4) 17 | data, err := serializer.SerializeAny(layer) 18 | if err != nil { 19 | t.Fatal(err) 20 | } 21 | var newLayer *BatchNorm 22 | if err := serializer.DeserializeAny(data, &newLayer); err != nil { 23 | t.Fatal(err) 24 | } 25 | if !reflect.DeepEqual(layer, newLayer) { 26 | t.Error("layers differ") 27 | } 28 | } 29 | 30 | func TestBatchNormOutput(t *testing.T) { 31 | layer := &BatchNorm{ 32 | InputCount: 2, 33 | Scalers: anydiff.NewVar(anyvec32.MakeVectorData([]float32{2, -3})), 34 | Biases: anydiff.NewVar(anyvec32.MakeVectorData([]float32{-1.5, 2})), 35 | } 36 | vec := 
anyvec32.MakeVectorData([]float32{ 37 | -0.636299517987754, 1.381820934572628, 1.117062796520384, 38 | -1.032042307499387, -0.603144099627179, 0.937477768422949, 39 | }) 40 | actual := layer.Apply(anydiff.NewConst(vec), 1).Output().Data().([]float32) 41 | expected := []float32{ 42 | -2.953427612694010, -0.723517206628873, 1.325934113323319, 43 | 6.176822169129221, -2.872506500629310, 0.546695037499651, 44 | } 45 | for i, x := range expected { 46 | a := actual[i] 47 | if math.IsNaN(float64(a)) || math.Abs(float64(a-x)) > 1e-3 { 48 | t.Fatalf("expected %v but got %v", expected, actual) 49 | } 50 | } 51 | } 52 | 53 | func TestBatchNormProp(t *testing.T) { 54 | layer := NewBatchNorm(anyvec32.CurrentCreator(), 2) 55 | input := anyvec32.MakeVector(24) 56 | anyvec.Rand(input, anyvec.Normal, nil) 57 | inVar := anydiff.NewVar(input) 58 | 59 | checker := anydifftest.ResChecker{ 60 | F: func() anydiff.Res { 61 | return layer.Apply(inVar, 12) 62 | }, 63 | V: []*anydiff.Var{inVar}, 64 | } 65 | checker.FullCheck(t) 66 | } 67 | -------------------------------------------------------------------------------- /anyconv/conv.go: -------------------------------------------------------------------------------- 1 | // Package anyconv provides various types of layers for 2 | // convolutional neural networks. 3 | package anyconv 4 | 5 | import ( 6 | "errors" 7 | "math" 8 | 9 | "github.com/unixpickle/anydiff" 10 | "github.com/unixpickle/anyvec" 11 | "github.com/unixpickle/anyvec/anyvecsave" 12 | "github.com/unixpickle/essentials" 13 | "github.com/unixpickle/serializer" 14 | ) 15 | 16 | func init() { 17 | var c Conv 18 | serializer.RegisterTypedDeserializer(c.SerializerType(), DeserializeConv) 19 | } 20 | 21 | // Conv is a convolutional layer. 22 | // 23 | // All input and output tensors are row-major depth-minor. 
24 | type Conv struct { 25 | FilterCount int 26 | FilterWidth int 27 | FilterHeight int 28 | 29 | StrideX int 30 | StrideY int 31 | 32 | InputWidth int 33 | InputHeight int 34 | InputDepth int 35 | 36 | Filters *anydiff.Var 37 | Biases *anydiff.Var 38 | 39 | Conver Conver 40 | } 41 | 42 | // DeserializeConv deserialize a Conv. 43 | // 44 | // The Conver is automatically set. 45 | func DeserializeConv(d []byte) (*Conv, error) { 46 | var inW, inH, inD, fW, fH, sX, sY serializer.Int 47 | var f, b *anyvecsave.S 48 | err := serializer.DeserializeAny(d, &inW, &inH, &inD, &fW, &fH, &sX, &sY, &f, &b) 49 | if err != nil { 50 | return nil, essentials.AddCtx("deserialize Conv", err) 51 | } 52 | res := Conv{ 53 | FilterCount: f.Vector.Len() / int(fW*fH*inD), 54 | FilterWidth: int(fW), 55 | FilterHeight: int(fH), 56 | StrideX: int(sX), 57 | StrideY: int(sY), 58 | 59 | InputWidth: int(inW), 60 | InputHeight: int(inH), 61 | InputDepth: int(inD), 62 | 63 | Filters: anydiff.NewVar(f.Vector), 64 | Biases: anydiff.NewVar(b.Vector), 65 | } 66 | res.Conver = CurrentConverMaker()(res) 67 | return &res, nil 68 | } 69 | 70 | // InitRand the biases an filters in a randomized fashion 71 | // and sets the Conver. 72 | func (c *Conv) InitRand(cr anyvec.Creator) { 73 | c.InitZero(cr) 74 | 75 | normalizer := 1 / math.Sqrt(float64(c.FilterWidth*c.FilterHeight*c.InputDepth)) 76 | anyvec.Rand(c.Filters.Vector, anyvec.Normal, nil) 77 | c.Filters.Vector.Scale(cr.MakeNumeric(normalizer)) 78 | } 79 | 80 | // InitZero initializes the layer to zero and sets the 81 | // Conver. 82 | func (c *Conv) InitZero(cr anyvec.Creator) { 83 | filterSize := c.FilterWidth * c.FilterHeight * c.InputDepth 84 | c.Filters = anydiff.NewVar(cr.MakeVector(filterSize * c.FilterCount)) 85 | c.Biases = anydiff.NewVar(cr.MakeVector(c.FilterCount)) 86 | c.Conver = CurrentConverMaker()(*c) 87 | } 88 | 89 | // OutputWidth returns the width of the output tensor. 
90 | func (c *Conv) OutputWidth() int { 91 | w := 1 + (c.InputWidth-c.FilterWidth)/c.StrideX 92 | if w < 0 { 93 | return 0 94 | } else { 95 | return w 96 | } 97 | } 98 | 99 | // OutputHeight returns the height of the output tensor. 100 | func (c *Conv) OutputHeight() int { 101 | h := 1 + (c.InputHeight-c.FilterHeight)/c.StrideY 102 | if h < 0 { 103 | return 0 104 | } else { 105 | return h 106 | } 107 | } 108 | 109 | // OutputDepth returns the depth of the output tensor. 110 | func (c *Conv) OutputDepth() int { 111 | return c.FilterCount 112 | } 113 | 114 | // Apply applies the layer to an input tensor using the 115 | // Conver. 116 | // 117 | // The layer must have been initialized. 118 | func (c *Conv) Apply(in anydiff.Res, batchSize int) anydiff.Res { 119 | return c.Conver.Apply(in, batchSize) 120 | } 121 | 122 | // Parameters returns the layer's parameters. 123 | // The filters come before the biases in the resulting 124 | // slice. 125 | // 126 | // If the layer is uninitialized, the result is nil. 127 | func (c *Conv) Parameters() []*anydiff.Var { 128 | if c.Filters == nil || c.Biases == nil { 129 | return nil 130 | } 131 | return []*anydiff.Var{c.Filters, c.Biases} 132 | } 133 | 134 | // SerializerType returns the unique ID used to serialize 135 | // a Conv with the serializer package. 136 | func (c *Conv) SerializerType() string { 137 | return "github.com/unixpickle/anynet/anyconv.Conv" 138 | } 139 | 140 | // Serialize serializes the layer. 141 | // 142 | // If the layer was not yet initialized, this fails. 
143 | func (c *Conv) Serialize() ([]byte, error) { 144 | if c.Filters == nil || c.Biases == nil { 145 | return nil, errors.New("cannot serialize uninitialized Conv") 146 | } 147 | return serializer.SerializeAny( 148 | serializer.Int(c.InputWidth), 149 | serializer.Int(c.InputHeight), 150 | serializer.Int(c.InputDepth), 151 | serializer.Int(c.FilterWidth), 152 | serializer.Int(c.FilterHeight), 153 | serializer.Int(c.StrideX), 154 | serializer.Int(c.StrideY), 155 | &anyvecsave.S{Vector: c.Filters.Vector}, 156 | &anyvecsave.S{Vector: c.Biases.Vector}, 157 | ) 158 | } 159 | -------------------------------------------------------------------------------- /anyconv/conver_test.go: -------------------------------------------------------------------------------- 1 | // Package anyconv provides various types of layers for 2 | // convolutional neural networks. 3 | package anyconv 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/unixpickle/anydiff" 9 | "github.com/unixpickle/anyvec" 10 | "github.com/unixpickle/anyvec/anyvec64" 11 | ) 12 | 13 | func TestParallelConver(t *testing.T) { 14 | c := anyvec64.CurrentCreator() 15 | conv := Conv{ 16 | FilterCount: 13, 17 | FilterWidth: 4, 18 | FilterHeight: 3, 19 | 20 | StrideX: 2, 21 | StrideY: 3, 22 | 23 | InputWidth: 30, 24 | InputHeight: 20, 25 | InputDepth: 7, 26 | } 27 | conv.InitRand(c) 28 | testConverEquiv(t, MakeDefaultConver(conv).(*conver), 29 | MakeParallelConver(conv).(*conver)) 30 | } 31 | 32 | func testConverEquiv(t *testing.T, c1, c2 *conver) { 33 | c := c1.conv.Filters.Vector.Creator() 34 | inSize := c1.im2row.InputWidth * c1.im2row.InputHeight * c1.im2row.InputDepth 35 | outSize := c1.conv.OutputWidth() * c1.conv.OutputHeight() * c1.conv.OutputDepth() 36 | 37 | batchSize := 32 38 | inBatch := c.MakeVector(inSize * batchSize) 39 | anyvec.Rand(inBatch, anyvec.Normal, nil) 40 | inVar := anydiff.NewVar(inBatch) 41 | 42 | out1 := c1.Apply(inVar, batchSize) 43 | out2 := c2.Apply(inVar, batchSize) 44 | if 
!vecsClose(out1.Output(), out2.Output()) { 45 | t.Error("mismatching output values") 46 | } 47 | 48 | upstream := c.MakeVector(outSize * batchSize) 49 | anyvec.Rand(upstream, anyvec.Normal, nil) 50 | 51 | vars := []*anydiff.Var{inVar, c1.conv.Filters, c1.conv.Biases} 52 | grad1 := anydiff.NewGrad(vars...) 53 | out1.Propagate(upstream.Copy(), grad1) 54 | grad2 := anydiff.NewGrad(vars...) 55 | out2.Propagate(upstream.Copy(), grad2) 56 | 57 | for i, variable := range vars { 58 | g1 := grad1[variable] 59 | g2 := grad2[variable] 60 | if !vecsClose(g1, g2) { 61 | t.Errorf("gradient for variable %d differs", i) 62 | } 63 | } 64 | } 65 | 66 | func vecsClose(v1, v2 anyvec.Vector) bool { 67 | c := v1.Creator() 68 | diff := v1.Copy() 69 | diff.Sub(v2) 70 | maxDiff := anyvec.AbsMax(diff) 71 | thresh := c.MakeNumeric(1e-3) 72 | return c.NumOps().Less(maxDiff, thresh) 73 | } 74 | -------------------------------------------------------------------------------- /anyconv/image.go: -------------------------------------------------------------------------------- 1 | package anyconv 2 | 3 | import ( 4 | "image" 5 | "image/color" 6 | 7 | "github.com/unixpickle/anyvec" 8 | ) 9 | 10 | // ImageToTensor converts an image to a tensor of RGB 11 | // values. 12 | // Values in the tensor range between 0 and 1. 13 | func ImageToTensor(c anyvec.Creator, img image.Image) anyvec.Vector { 14 | w := img.Bounds().Dx() 15 | h := img.Bounds().Dy() 16 | minX := img.Bounds().Min.X 17 | minY := img.Bounds().Min.Y 18 | 19 | res := make([]float64, w*h*3) 20 | idx := 0 21 | for y := 0; y < h; y++ { 22 | for x := 0; x < w; x++ { 23 | r, g, b, _ := img.At(minX+x, minY+y).RGBA() 24 | for _, comp := range []uint32{r, g, b} { 25 | res[idx] = float64(comp) / 0xffff 26 | idx++ 27 | } 28 | } 29 | } 30 | 31 | return c.MakeVectorData(c.MakeNumericList(res)) 32 | } 33 | 34 | // TensorToImage converts a tensor of RGB values into an 35 | // image. 36 | // Values in the tensor are clipped between 0 and 1. 
37 | // 38 | // The anyvec.NumericList type must be []float32 or 39 | // []float64. 40 | func TensorToImage(width, height int, v anyvec.Vector) image.Image { 41 | var rawData []float64 42 | switch data := v.Data().(type) { 43 | case []float64: 44 | rawData = data 45 | case []float32: 46 | rawData = make([]float64, len(data)) 47 | for i, x := range data { 48 | rawData[i] = float64(x) 49 | } 50 | } 51 | 52 | if len(rawData) != width*height*3 { 53 | panic("incorrect tensor size") 54 | } 55 | for i, x := range rawData { 56 | if x < 0 { 57 | rawData[i] = 0 58 | } else if x > 1 { 59 | rawData[i] = 1 60 | } 61 | } 62 | 63 | res := image.NewRGBA(image.Rect(0, 0, width, height)) 64 | idx := 0 65 | for y := 0; y < height; y++ { 66 | for x := 0; x < width; x++ { 67 | var vals [3]uint8 68 | for z := 0; z < 3; z++ { 69 | vals[z] = uint8(rawData[idx]*0xff + 0.5) 70 | idx++ 71 | } 72 | res.SetRGBA(x, y, color.RGBA{R: vals[0], G: vals[1], B: vals[2], A: 0xff}) 73 | } 74 | } 75 | 76 | return res 77 | } 78 | -------------------------------------------------------------------------------- /anyconv/image_test.go: -------------------------------------------------------------------------------- 1 | package anyconv 2 | 3 | import ( 4 | "image" 5 | "image/color" 6 | "testing" 7 | 8 | "github.com/unixpickle/anyvec/anyvec32" 9 | ) 10 | 11 | func TestImageConversion(t *testing.T) { 12 | inImg := image.NewRGBA(image.Rect(1, 1, 3, 3)) 13 | inImg.SetRGBA(1, 1, color.RGBA{R: 0x13, G: 0x37, B: 0x66, A: 0xff}) 14 | inImg.SetRGBA(1, 2, color.RGBA{R: 0x37, G: 0x37, B: 0x66, A: 0xff}) 15 | inImg.SetRGBA(2, 1, color.RGBA{R: 0x10, G: 0x37, B: 0x66, A: 0xff}) 16 | inImg.SetRGBA(2, 2, color.RGBA{R: 0x5, G: 0x37, B: 0x66, A: 0xff}) 17 | 18 | tensor := ImageToTensor(anyvec32.CurrentCreator(), inImg) 19 | outImg := TensorToImage(2, 2, tensor) 20 | 21 | for x := 0; x < 2; x++ { 22 | for y := 0; y < 2; y++ { 23 | oldR, oldG, oldB, _ := inImg.At(x+1, y+1).RGBA() 24 | newR, newG, newB, _ := outImg.At(x, 
y).RGBA() 25 | olds := []uint32{oldR, oldG, oldB} 26 | news := []uint32{newR, newG, newB} 27 | for z, expected := range olds { 28 | a := news[z] 29 | if expected/0x100 != a/0x100 { 30 | t.Errorf("value %d,%d,%d should be %d but got %d", x, y, z, 31 | expected/0x100, a/0x100) 32 | } 33 | } 34 | } 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /anyconv/max_pool_test.go: -------------------------------------------------------------------------------- 1 | package anyconv 2 | 3 | import ( 4 | "math" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/unixpickle/anydiff" 9 | "github.com/unixpickle/anydiff/anydifftest" 10 | "github.com/unixpickle/anyvec" 11 | "github.com/unixpickle/anyvec/anyvec32" 12 | "github.com/unixpickle/serializer" 13 | ) 14 | 15 | func TestMaxPoolSerialize(t *testing.T) { 16 | mp := &MaxPool{ 17 | SpanX: 3, 18 | SpanY: 2, 19 | InputWidth: 15, 20 | InputHeight: 13, 21 | InputDepth: 4, 22 | } 23 | data, err := serializer.SerializeAny(mp) 24 | if err != nil { 25 | t.Fatal(err) 26 | } 27 | var newLayer *MaxPool 28 | if err := serializer.DeserializeAny(data, &newLayer); err != nil { 29 | t.Fatal(err) 30 | } 31 | if !reflect.DeepEqual(newLayer, mp) { 32 | t.Fatal("layers differ") 33 | } 34 | } 35 | 36 | func TestMaxPoolOutput(t *testing.T) { 37 | mp := &MaxPool{ 38 | SpanX: 3, 39 | SpanY: 2, 40 | StrideX: 2, 41 | StrideY: 1, 42 | InputWidth: 15, 43 | InputHeight: 13, 44 | InputDepth: 4, 45 | } 46 | input := anyvec32.MakeVector(15 * 13 * 4 * 2) 47 | anyvec.Rand(input, anyvec.Normal, nil) 48 | 49 | expected := naiveMaxPool(mp, input.Data().([]float32)[:15*13*4]) 50 | expected = append(expected, naiveMaxPool(mp, input.Data().([]float32)[15*13*4:])...) 
51 | actual := mp.Apply(anydiff.NewConst(input), 2).Output().Data().([]float32) 52 | 53 | if len(actual) != len(expected) { 54 | t.Fatalf("expected length %d but got %d", len(expected), len(actual)) 55 | } 56 | 57 | for i, x := range expected { 58 | a := actual[i] 59 | if math.Abs(float64(x-a)) > 1e-3 { 60 | t.Errorf("output %d: should be %f but got %f", i, x, a) 61 | break 62 | } 63 | } 64 | } 65 | 66 | func TestMaxPoolProp(t *testing.T) { 67 | layer := &MaxPool{ 68 | SpanX: 3, 69 | SpanY: 2, 70 | StrideX: 2, 71 | StrideY: 1, 72 | InputWidth: 15, 73 | InputHeight: 13, 74 | InputDepth: 4, 75 | } 76 | img := anyvec32.MakeVector(15 * 13 * 4 * 2) 77 | anyvec.Rand(img, anyvec.Uniform, nil) 78 | inVar := anydiff.NewVar(img) 79 | 80 | checker := anydifftest.ResChecker{ 81 | F: func() anydiff.Res { 82 | return layer.Apply(inVar, 2) 83 | }, 84 | V: []*anydiff.Var{inVar}, 85 | Delta: 1e-5, 86 | Prec: 1e-2, 87 | } 88 | checker.FullCheck(t) 89 | } 90 | 91 | func naiveMaxPool(m *MaxPool, img []float32) []float32 { 92 | var res []float32 93 | for y := 0; y+m.SpanY <= m.InputHeight; y += m.StrideY { 94 | for x := 0; x+m.SpanX <= m.InputWidth; x += m.StrideX { 95 | for z := 0; z < m.InputDepth; z++ { 96 | res = append(res, maxInRegion(m, x, y, z, img)) 97 | } 98 | } 99 | } 100 | return res 101 | } 102 | 103 | func maxInRegion(m *MaxPool, x, y, z int, img []float32) float32 { 104 | value := float32(math.Inf(-1)) 105 | for subY := 0; subY < m.SpanY; subY++ { 106 | for subX := 0; subX < m.SpanX; subX++ { 107 | idx := ((subY+y)*m.InputWidth+subX+x)*m.InputDepth + z 108 | if img[idx] > value { 109 | value = img[idx] 110 | } 111 | } 112 | } 113 | return value 114 | } 115 | -------------------------------------------------------------------------------- /anyconv/mean_pool.go: -------------------------------------------------------------------------------- 1 | package anyconv 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/unixpickle/anydiff" 7 | "github.com/unixpickle/anyvec" 8 | 
"github.com/unixpickle/essentials" 9 | "github.com/unixpickle/serializer" 10 | ) 11 | 12 | func init() { 13 | var m MeanPool 14 | serializer.RegisterTypedDeserializer(m.SerializerType(), DeserializeMeanPool) 15 | } 16 | 17 | // MeanPool is a mean-pooling layer. 18 | // 19 | // It behaves very similarly to MaxPool. 20 | type MeanPool struct { 21 | SpanX int 22 | SpanY int 23 | 24 | StrideX int 25 | StrideY int 26 | 27 | InputWidth int 28 | InputHeight int 29 | InputDepth int 30 | 31 | im2colLock sync.Mutex 32 | im2col anyvec.Mapper 33 | } 34 | 35 | // DeserializeMeanPool deserializes a MeanPool. 36 | func DeserializeMeanPool(d []byte) (*MeanPool, error) { 37 | mp, err := DeserializeMaxPool(d) 38 | if err != nil { 39 | return nil, essentials.AddCtx("deserialize MeanPool", err) 40 | } 41 | return &MeanPool{ 42 | SpanX: mp.SpanX, 43 | SpanY: mp.SpanY, 44 | StrideX: mp.StrideX, 45 | StrideY: mp.StrideY, 46 | InputWidth: mp.InputWidth, 47 | InputHeight: mp.InputHeight, 48 | InputDepth: mp.InputDepth, 49 | }, nil 50 | } 51 | 52 | // OutputWidth returns the output tensor width. 53 | func (m *MeanPool) OutputWidth() int { 54 | return m.maxPool().OutputWidth() 55 | } 56 | 57 | // OutputHeight returns the output tensor height. 58 | func (m *MeanPool) OutputHeight() int { 59 | return m.maxPool().OutputHeight() 60 | } 61 | 62 | // OutputDepth returns the depth of the output tensor. 63 | func (m *MeanPool) OutputDepth() int { 64 | return m.InputDepth 65 | } 66 | 67 | // Apply applies the pooling layer. 
68 | func (m *MeanPool) Apply(in anydiff.Res, batchSize int) anydiff.Res { 69 | m.im2colLock.Lock() 70 | if m.im2col == nil { 71 | m.initIm2Col(in.Output().Creator()) 72 | } 73 | m.im2colLock.Unlock() 74 | 75 | imgSize := m.im2col.InSize() 76 | if in.Output().Len() != batchSize*imgSize { 77 | panic("incorrect input size") 78 | } 79 | 80 | outCount := m.OutputHeight() * m.OutputWidth() * m.OutputDepth() 81 | im2ColTemp := in.Output().Creator().MakeVector(m.im2col.OutSize()) 82 | 83 | var sumResults []anyvec.Vector 84 | for i := 0; i < batchSize; i++ { 85 | subIn := in.Output().Slice(imgSize*i, imgSize*(i+1)) 86 | m.im2col.Map(subIn, im2ColTemp) 87 | sum := anyvec.SumCols(im2ColTemp, outCount) 88 | sumResults = append(sumResults, sum) 89 | } 90 | 91 | out := in.Output().Creator().Concat(sumResults...) 92 | scaler := out.Creator().MakeNumeric(1 / float64(m.SpanX*m.SpanY)) 93 | out.Scale(scaler) 94 | 95 | return &meanPoolRes{ 96 | Layer: m, 97 | N: batchSize, 98 | In: in, 99 | Scaler: scaler, 100 | OutVec: out, 101 | } 102 | } 103 | 104 | // SerializerType returns the unique ID used to serialize 105 | // a MeanPool with the serializer package. 106 | func (m *MeanPool) SerializerType() string { 107 | return "github.com/unixpickle/anynet/anyconv.MeanPool" 108 | } 109 | 110 | // Serialize serializes the layer. 
111 | func (m *MeanPool) Serialize() ([]byte, error) { 112 | return m.maxPool().Serialize() 113 | } 114 | 115 | func (m *MeanPool) maxPool() *MaxPool { 116 | return &MaxPool{ 117 | SpanX: m.SpanX, 118 | SpanY: m.SpanY, 119 | StrideX: m.StrideX, 120 | StrideY: m.StrideY, 121 | InputWidth: m.InputWidth, 122 | InputHeight: m.InputHeight, 123 | InputDepth: m.InputDepth, 124 | } 125 | } 126 | 127 | func (m *MeanPool) initIm2Col(c anyvec.Creator) { 128 | mp := m.maxPool() 129 | mp.initIm2Col(c) 130 | m.im2col = mp.im2col 131 | } 132 | 133 | type meanPoolRes struct { 134 | Layer *MeanPool 135 | N int 136 | In anydiff.Res 137 | Scaler anyvec.Numeric 138 | OutVec anyvec.Vector 139 | } 140 | 141 | func (m *meanPoolRes) Output() anyvec.Vector { 142 | return m.OutVec 143 | } 144 | 145 | func (m *meanPoolRes) Vars() anydiff.VarSet { 146 | return m.In.Vars() 147 | } 148 | 149 | func (m *meanPoolRes) Propagate(u anyvec.Vector, g anydiff.Grad) { 150 | // Scaling u first is more efficient. 151 | u.Scale(m.Scaler) 152 | 153 | mappedU := u.Creator().MakeVector(m.Layer.im2col.OutSize() * m.N) 154 | anyvec.AddChunks(mappedU, u) 155 | 156 | m.In.Propagate(batchMapTranspose(m.Layer.im2col, mappedU), g) 157 | } 158 | -------------------------------------------------------------------------------- /anyconv/mean_pool_test.go: -------------------------------------------------------------------------------- 1 | package anyconv 2 | 3 | import ( 4 | "math" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/unixpickle/anydiff" 9 | "github.com/unixpickle/anydiff/anydifftest" 10 | "github.com/unixpickle/anyvec" 11 | "github.com/unixpickle/anyvec/anyvec32" 12 | "github.com/unixpickle/serializer" 13 | ) 14 | 15 | func TestMeanPoolSerialize(t *testing.T) { 16 | mp := &MeanPool{ 17 | SpanX: 3, 18 | SpanY: 2, 19 | StrideX: 4, 20 | StrideY: 5, 21 | InputWidth: 15, 22 | InputHeight: 13, 23 | InputDepth: 4, 24 | } 25 | data, err := serializer.SerializeAny(mp) 26 | if err != nil { 27 | t.Fatal(err) 28 | } 29 
| var newLayer *MeanPool 30 | if err := serializer.DeserializeAny(data, &newLayer); err != nil { 31 | t.Fatal(err) 32 | } 33 | if !reflect.DeepEqual(newLayer, mp) { 34 | t.Fatal("layers differ") 35 | } 36 | } 37 | 38 | func TestMeanPoolOutput(t *testing.T) { 39 | mp := &MeanPool{ 40 | SpanX: 3, 41 | SpanY: 2, 42 | StrideX: 2, 43 | StrideY: 1, 44 | InputWidth: 15, 45 | InputHeight: 13, 46 | InputDepth: 4, 47 | } 48 | input := anyvec32.MakeVector(15 * 13 * 4 * 2) 49 | anyvec.Rand(input, anyvec.Normal, nil) 50 | 51 | expected := naiveMeanPool(mp, input.Data().([]float32)[:15*13*4]) 52 | expected = append(expected, naiveMeanPool(mp, input.Data().([]float32)[15*13*4:])...) 53 | actual := mp.Apply(anydiff.NewConst(input), 2).Output().Data().([]float32) 54 | 55 | if len(actual) != len(expected) { 56 | t.Fatalf("expected length %d but got %d", len(expected), len(actual)) 57 | } 58 | 59 | for i, x := range expected { 60 | a := actual[i] 61 | if math.Abs(float64(x-a)) > 1e-3 { 62 | t.Errorf("output %d: should be %f but got %f", i, x, a) 63 | break 64 | } 65 | } 66 | } 67 | 68 | func TestMeanPoolProp(t *testing.T) { 69 | layer := &MeanPool{ 70 | SpanX: 3, 71 | SpanY: 2, 72 | StrideX: 2, 73 | StrideY: 2, 74 | InputWidth: 15, 75 | InputHeight: 13, 76 | InputDepth: 4, 77 | } 78 | img := anyvec32.MakeVector(15 * 13 * 4 * 2) 79 | anyvec.Rand(img, anyvec.Normal, nil) 80 | inVar := anydiff.NewVar(img) 81 | 82 | checker := anydifftest.ResChecker{ 83 | F: func() anydiff.Res { 84 | return layer.Apply(inVar, 2) 85 | }, 86 | V: []*anydiff.Var{inVar}, 87 | } 88 | checker.FullCheck(t) 89 | } 90 | 91 | func naiveMeanPool(m *MeanPool, img []float32) []float32 { 92 | var res []float32 93 | for y := 0; y+m.SpanY <= m.InputHeight; y += m.StrideY { 94 | for x := 0; x+m.SpanX <= m.InputWidth; x += m.StrideX { 95 | for z := 0; z < m.InputDepth; z++ { 96 | res = append(res, meanInRegion(m, x, y, z, img)) 97 | } 98 | } 99 | } 100 | return res 101 | } 102 | 103 | func meanInRegion(m *MeanPool, x, y, 
z int, img []float32) float32 { 104 | var sum float32 105 | for subY := 0; subY < m.SpanY && subY+y < m.InputHeight; subY++ { 106 | for subX := 0; subX < m.SpanX && subX+x < m.InputWidth; subX++ { 107 | idx := ((subY+y)*m.InputWidth+subX+x)*m.InputDepth + z 108 | sum += img[idx] 109 | } 110 | } 111 | return sum / float32(m.SpanX*m.SpanY) 112 | } 113 | -------------------------------------------------------------------------------- /anyconv/padding.go: -------------------------------------------------------------------------------- 1 | package anyconv 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/unixpickle/anydiff" 7 | "github.com/unixpickle/anyvec" 8 | "github.com/unixpickle/essentials" 9 | "github.com/unixpickle/serializer" 10 | ) 11 | 12 | func init() { 13 | var p Padding 14 | serializer.RegisterTypedDeserializer(p.SerializerType(), DeserializePadding) 15 | } 16 | 17 | // A Padding layer adds zeros to the border of input 18 | // tensors. 19 | type Padding struct { 20 | InputWidth int 21 | InputHeight int 22 | InputDepth int 23 | 24 | PaddingTop int 25 | PaddingRight int 26 | PaddingBottom int 27 | PaddingLeft int 28 | 29 | mapperLock sync.Mutex 30 | mapper anyvec.Mapper 31 | } 32 | 33 | // DeserializePadding deserializes a Padding. 34 | func DeserializePadding(d []byte) (*Padding, error) { 35 | var inW, inH, inD, pT, pR, pB, pL serializer.Int 36 | err := serializer.DeserializeAny(d, &inW, &inH, &inD, &pT, &pR, &pB, &pL) 37 | if err != nil { 38 | return nil, essentials.AddCtx("deserialize Padding", err) 39 | } 40 | return &Padding{ 41 | InputWidth: int(inW), 42 | InputHeight: int(inH), 43 | InputDepth: int(inD), 44 | 45 | PaddingTop: int(pT), 46 | PaddingRight: int(pR), 47 | PaddingBottom: int(pB), 48 | PaddingLeft: int(pL), 49 | }, nil 50 | } 51 | 52 | // Apply applies the layer. 
53 | func (p *Padding) Apply(in anydiff.Res, batch int) anydiff.Res { 54 | p.mapperLock.Lock() 55 | if p.mapper == nil { 56 | p.initMapper(in.Output().Creator()) 57 | } 58 | p.mapperLock.Unlock() 59 | 60 | if in.Output().Len() != batch*p.mapper.OutSize() { 61 | panic("incorrect input size") 62 | } 63 | return &paddingRes{ 64 | In: in, 65 | Mapper: p.mapper, 66 | OutVec: batchMapTranspose(p.mapper, in.Output()), 67 | } 68 | } 69 | 70 | // SerializerType returns the unique ID used to serialize 71 | // a Padding with the serializer package. 72 | func (p *Padding) SerializerType() string { 73 | return "github.com/unixpickle/anynet/anyconv.Padding" 74 | } 75 | 76 | // Serialize serializes a Padding. 77 | func (p *Padding) Serialize() ([]byte, error) { 78 | return serializer.SerializeAny( 79 | serializer.Int(p.InputWidth), 80 | serializer.Int(p.InputHeight), 81 | serializer.Int(p.InputDepth), 82 | serializer.Int(p.PaddingTop), 83 | serializer.Int(p.PaddingRight), 84 | serializer.Int(p.PaddingBottom), 85 | serializer.Int(p.PaddingLeft), 86 | ) 87 | } 88 | 89 | func (p *Padding) initMapper(c anyvec.Creator) { 90 | newWidth := p.InputWidth + p.PaddingLeft + p.PaddingRight 91 | outSize := newWidth * (p.InputHeight + p.PaddingTop + p.PaddingBottom) * p.InputDepth 92 | table := make([]int, 0, p.InputWidth*p.InputHeight*p.InputDepth) 93 | 94 | for y := 0; y < p.InputHeight; y++ { 95 | yOffset := (y + p.PaddingTop) * newWidth * p.InputDepth 96 | for x := 0; x < p.InputWidth; x++ { 97 | xOffset := yOffset + (x+p.PaddingLeft)*p.InputDepth 98 | for z := 0; z < p.InputDepth; z++ { 99 | table = append(table, xOffset+z) 100 | } 101 | } 102 | } 103 | 104 | p.mapper = c.MakeMapper(outSize, table) 105 | } 106 | 107 | type paddingRes struct { 108 | In anydiff.Res 109 | Mapper anyvec.Mapper 110 | OutVec anyvec.Vector 111 | } 112 | 113 | func (p *paddingRes) Output() anyvec.Vector { 114 | return p.OutVec 115 | } 116 | 117 | func (p *paddingRes) Vars() anydiff.VarSet { 118 | return 
p.In.Vars() 119 | } 120 | 121 | func (p *paddingRes) Propagate(u anyvec.Vector, g anydiff.Grad) { 122 | p.In.Propagate(batchMap(p.Mapper, u), g) 123 | } 124 | -------------------------------------------------------------------------------- /anyconv/padding_test.go: -------------------------------------------------------------------------------- 1 | package anyconv 2 | 3 | import ( 4 | "math" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/unixpickle/anydiff" 9 | "github.com/unixpickle/anydiff/anydifftest" 10 | "github.com/unixpickle/anyvec" 11 | "github.com/unixpickle/anyvec/anyvec32" 12 | "github.com/unixpickle/serializer" 13 | ) 14 | 15 | func TestPaddingSerialize(t *testing.T) { 16 | pl := &Padding{ 17 | InputWidth: 1, 18 | InputHeight: 2, 19 | InputDepth: 3, 20 | 21 | PaddingTop: 4, 22 | PaddingBottom: 5, 23 | PaddingLeft: 6, 24 | PaddingRight: 7, 25 | } 26 | data, err := serializer.SerializeAny(pl) 27 | if err != nil { 28 | t.Fatal(err) 29 | } 30 | var newLayer *Padding 31 | if err := serializer.DeserializeAny(data, &newLayer); err != nil { 32 | t.Fatal(err) 33 | } 34 | if !reflect.DeepEqual(newLayer, pl) { 35 | t.Fatal("layers differ") 36 | } 37 | } 38 | 39 | func TestPaddingOutput(t *testing.T) { 40 | pl := &Padding{ 41 | InputWidth: 3, 42 | InputHeight: 4, 43 | InputDepth: 2, 44 | 45 | PaddingTop: 1, 46 | PaddingBottom: 2, 47 | PaddingLeft: 3, 48 | PaddingRight: 1, 49 | } 50 | 51 | inTensor := anyvec32.MakeVectorData([]float32{ 52 | 3.868200, 1.104760, 0.360270, 0.046398, 0.800748, -0.579334, 53 | -0.540134, -0.095748, -0.240087, 0.298587, 0.018990, 0.481808, 54 | -0.656787, -0.061479, 1.997873, 0.108665, 1.788285, 0.222048, 55 | 1.153895, 0.780207, -0.655182, 0.495345, -0.244460, -0.841344, 56 | }) 57 | 58 | expected := []float32{ 59 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 60 | 0, 0, 0, 0, 0, 0, 3.868200, 1.104760, 0.360270, 0.046398, 0.800748, -0.579334, 0, 0, 61 | 0, 0, 0, 0, 0, 0, -0.540134, -0.095748, -0.240087, 0.298587, 0.018990, 0.481808, 0, 0, 
62 | 0, 0, 0, 0, 0, 0, -0.656787, -0.061479, 1.997873, 0.108665, 1.788285, 0.222048, 0, 0, 63 | 0, 0, 0, 0, 0, 0, 1.153895, 0.780207, -0.655182, 0.495345, -0.244460, -0.841344, 0, 0, 64 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 65 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 66 | } 67 | actual := pl.Apply(anydiff.NewConst(inTensor), 1).Output().Data().([]float32) 68 | 69 | if len(actual) != len(expected) { 70 | t.Fatalf("len should be %d but got %d", len(expected), len(actual)) 71 | } 72 | 73 | for i, x := range expected { 74 | a := actual[i] 75 | if math.IsNaN(float64(a)) || math.Abs(float64(a-x)) > 1e-3 { 76 | t.Errorf("value %d: should be %f but got %f", i, x, a) 77 | } 78 | } 79 | } 80 | 81 | func TestPaddingProp(t *testing.T) { 82 | layer := &Padding{ 83 | InputWidth: 3, 84 | InputHeight: 4, 85 | InputDepth: 2, 86 | 87 | PaddingTop: 1, 88 | PaddingBottom: 2, 89 | PaddingLeft: 3, 90 | PaddingRight: 1, 91 | } 92 | img := anyvec32.MakeVector(3 * 4 * 2 * 2) 93 | anyvec.Rand(img, anyvec.Uniform, nil) 94 | inVar := anydiff.NewVar(img) 95 | 96 | checker := anydifftest.ResChecker{ 97 | F: func() anydiff.Res { 98 | return layer.Apply(inVar, 2) 99 | }, 100 | V: []*anydiff.Var{inVar}, 101 | } 102 | checker.FullCheck(t) 103 | } 104 | -------------------------------------------------------------------------------- /anyconv/post_train.go: -------------------------------------------------------------------------------- 1 | package anyconv 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/unixpickle/anydiff" 7 | "github.com/unixpickle/anynet" 8 | "github.com/unixpickle/anynet/anyff" 9 | "github.com/unixpickle/anynet/anysgd" 10 | "github.com/unixpickle/anyvec" 11 | "github.com/unixpickle/essentials" 12 | ) 13 | 14 | // A PostTrainer uses a list of samples to replace 15 | // BatchNorm layers with hard-wired affine transforms. 16 | // It automatically looks for BatchNorm layers inside 17 | // Residual blocks. 
18 | type PostTrainer struct { 19 | Samples anysgd.SampleList 20 | 21 | // Fetcher must return *anyff.Batch instances. 22 | Fetcher anysgd.Fetcher 23 | 24 | // BatchSize specifies how many samples to feed to the 25 | // network at once. 26 | BatchSize int 27 | 28 | Net anynet.Net 29 | 30 | // StatusFunc, if non-nil, is called for every BatchNorm 31 | // layer right before the layer is replaced. 32 | StatusFunc func(bn *BatchNorm) 33 | } 34 | 35 | // Run performs layer replacement. 36 | // 37 | // It returns with an error if the Fetcher failed. 38 | // However, even in the event of an error, some layers may 39 | // already be replaced with affine transforms. 40 | func (p *PostTrainer) Run() error { 41 | return affinizeNet(p.Net, func(bn *BatchNorm, subNet anynet.Net) (*anynet.Affine, error) { 42 | mean, stddev, err := momentsFromOutputs(bn, p.evaluateBatch(subNet)) 43 | if err != nil { 44 | return nil, essentials.AddCtx("post train", err) 45 | } 46 | scaler := bn.Scalers.Vector.Copy() 47 | scaler.Div(stddev) 48 | bias := bn.Biases.Vector.Copy() 49 | mean.Mul(scaler) 50 | bias.Sub(mean) 51 | if p.StatusFunc != nil { 52 | p.StatusFunc(bn) 53 | } 54 | return &anynet.Affine{ 55 | Scalers: anydiff.NewVar(scaler), 56 | Biases: anydiff.NewVar(bias), 57 | }, nil 58 | }) 59 | } 60 | 61 | func (p *PostTrainer) evaluateBatch(subNet anynet.Net) <-chan *postTrainerOutput { 62 | resChan := make(chan *postTrainerOutput, 1) 63 | batchChan := make(chan anysgd.Batch, 1) 64 | batchSizesChan := make(chan int, 1) 65 | batchErrChan := make(chan error, 1) 66 | go func() { 67 | defer close(batchChan) 68 | defer close(batchSizesChan) 69 | defer close(batchErrChan) 70 | for i := 0; i < p.Samples.Len(); i += p.BatchSize { 71 | bs := p.Samples.Len() - i 72 | if bs > p.BatchSize { 73 | bs = p.BatchSize 74 | } 75 | slice := p.Samples.Slice(i, i+bs) 76 | if batch, err := p.Fetcher.Fetch(slice); err == nil { 77 | batchChan <- batch 78 | batchSizesChan <- bs 79 | } else { 80 | batchErrChan <- err 
81 | return 82 | } 83 | } 84 | }() 85 | go func() { 86 | defer close(resChan) 87 | for batch := range batchChan { 88 | bs := <-batchSizesChan 89 | inVec := batch.(*anyff.Batch).Inputs 90 | outVec := subNet.Apply(inVec, bs).Output().Copy() 91 | resChan <- &postTrainerOutput{Vec: outVec} 92 | } 93 | if err := <-batchErrChan; err != nil { 94 | resChan <- &postTrainerOutput{Err: err} 95 | } 96 | }() 97 | return resChan 98 | } 99 | 100 | func affinizeNet(n anynet.Net, f func(*BatchNorm, anynet.Net) (*anynet.Affine, error)) error { 101 | for i, layer := range n { 102 | if bn, ok := layer.(*BatchNorm); ok { 103 | if a, err := f(bn, n[:i]); err != nil { 104 | return err 105 | } else { 106 | n[i] = a 107 | } 108 | continue 109 | } 110 | 111 | r, ok := layer.(*Residual) 112 | if !ok { 113 | continue 114 | } 115 | 116 | for _, part := range []*anynet.Layer{&r.Layer, &r.Projection} { 117 | if *part == nil { 118 | continue 119 | } 120 | switch layer := (*part).(type) { 121 | case anynet.Net: 122 | subF := func(bn *BatchNorm, subNet anynet.Net) (*anynet.Affine, error) { 123 | return f(bn, append(append(anynet.Net{}, n[:i]...), subNet...)) 124 | } 125 | err := affinizeNet(layer, subF) 126 | if err != nil { 127 | return err 128 | } 129 | case *BatchNorm: 130 | if a, err := f(layer, n[:i]); err != nil { 131 | return err 132 | } else { 133 | *part = a 134 | } 135 | } 136 | } 137 | } 138 | return nil 139 | } 140 | 141 | type postTrainerOutput struct { 142 | Err error 143 | Vec anyvec.Vector 144 | } 145 | 146 | func momentsFromOutputs(b *BatchNorm, c <-chan *postTrainerOutput) (mean, 147 | stddev anyvec.Vector, err error) { 148 | var sum, sqSum anyvec.Vector 149 | var count int 150 | for item := range c { 151 | if item.Err != nil { 152 | return nil, nil, item.Err 153 | } 154 | 155 | count += item.Vec.Len() / b.InputCount 156 | thisSum := anyvec.SumRows(item.Vec, b.InputCount) 157 | item.Vec.Mul(item.Vec.Copy()) 158 | thisSqSum := anyvec.SumRows(item.Vec, b.InputCount) 159 | 160 | if 
sum == nil { 161 | sum = thisSum 162 | sqSum = thisSqSum 163 | } else { 164 | sum.Add(thisSum) 165 | sqSum.Add(thisSqSum) 166 | } 167 | } 168 | if sum == nil { 169 | return nil, nil, errors.New("no samples to average") 170 | } 171 | normalizer := sum.Creator().MakeNumeric(1 / float64(count)) 172 | sum.Scale(normalizer) 173 | sqSum.Scale(normalizer) 174 | 175 | sumSq := sum.Copy() 176 | sumSq.Mul(sum) 177 | 178 | sqSum.Sub(sumSq) 179 | sqSum.AddScalar(sqSum.Creator().MakeNumeric(b.stabilizer())) 180 | anyvec.Pow(sqSum, sqSum.Creator().MakeNumeric(0.5)) 181 | 182 | return sum, sqSum, nil 183 | } 184 | -------------------------------------------------------------------------------- /anyconv/post_train_test.go: -------------------------------------------------------------------------------- 1 | package anyconv 2 | 3 | import ( 4 | "math" 5 | "testing" 6 | 7 | "github.com/unixpickle/anynet" 8 | "github.com/unixpickle/anynet/anyff" 9 | "github.com/unixpickle/anyvec" 10 | "github.com/unixpickle/anyvec/anyvec64" 11 | ) 12 | 13 | func TestPostTrainer(t *testing.T) { 14 | net := anynet.Net{ 15 | anynet.NewFC(anyvec64.CurrentCreator(), 3, 4), 16 | randomizedBatchNorm(2), 17 | &Residual{ 18 | Projection: anynet.Net{ 19 | anynet.NewFC(anyvec64.CurrentCreator(), 4, 5), 20 | randomizedBatchNorm(1), 21 | }, 22 | Layer: anynet.Net{ 23 | anynet.NewFC(anyvec64.CurrentCreator(), 4, 5), 24 | randomizedBatchNorm(5), 25 | }, 26 | }, 27 | &Residual{Layer: randomizedBatchNorm(5)}, 28 | } 29 | 30 | var samples anyff.SliceSampleList 31 | for i := 0; i < 8; i++ { 32 | inVec := anyvec64.MakeVector(3) 33 | anyvec.Rand(inVec, anyvec.Normal, nil) 34 | samples = append(samples, &anyff.Sample{ 35 | Input: inVec, 36 | Output: anyvec64.MakeVector(5), 37 | }) 38 | } 39 | 40 | fetcher := &anyff.Trainer{} 41 | fullBatch, _ := fetcher.Fetch(samples) 42 | 43 | expected := net.Apply(fullBatch.(*anyff.Batch).Inputs, samples.Len()).Output() 44 | 45 | pt := &PostTrainer{ 46 | Samples: samples, 47 | Fetcher: 
fetcher, 48 | BatchSize: 3, 49 | Net: net, 50 | } 51 | if err := pt.Run(); err != nil { 52 | t.Fatal(err) 53 | } 54 | 55 | if _, ok := net[1].(*BatchNorm); ok { 56 | t.Error("first BatchNorm stayed") 57 | } 58 | resid := net[2].(*Residual) 59 | for i, subLayer := range []anynet.Layer{resid.Layer, resid.Projection} { 60 | for j, layer := range subLayer.(anynet.Net) { 61 | if _, ok := layer.(*BatchNorm); ok { 62 | t.Errorf("residual part %d: layer %d is BatchNorm", i, j) 63 | } 64 | } 65 | } 66 | if _, ok := net[3].(*Residual).Layer.(*BatchNorm); ok { 67 | t.Error("second residual's BatchNorm stayed") 68 | } 69 | 70 | actual := net.Apply(fullBatch.(*anyff.Batch).Inputs, samples.Len()).Output() 71 | 72 | for i, x := range expected.Data().([]float64) { 73 | a := actual.Data().([]float64)[i] 74 | if math.Abs(x-a) > 1e-5 || math.IsNaN(a) || math.IsNaN(x) { 75 | t.Errorf("output %d should be %f but got %f", i, x, a) 76 | } 77 | } 78 | } 79 | 80 | func randomizedBatchNorm(inCount int) *BatchNorm { 81 | res := NewBatchNorm(anyvec64.CurrentCreator(), inCount) 82 | anyvec.Rand(res.Scalers.Vector, anyvec.Normal, nil) 83 | anyvec.Rand(res.Biases.Vector, anyvec.Normal, nil) 84 | return res 85 | } 86 | -------------------------------------------------------------------------------- /anyconv/residual.go: -------------------------------------------------------------------------------- 1 | package anyconv 2 | 3 | import ( 4 | "github.com/unixpickle/anydiff" 5 | "github.com/unixpickle/anynet" 6 | "github.com/unixpickle/essentials" 7 | "github.com/unixpickle/serializer" 8 | ) 9 | 10 | func init() { 11 | var r Residual 12 | serializer.RegisterTypedDeserializer(r.SerializerType(), DeserializeResidual) 13 | } 14 | 15 | // Residual implements residual shortcut connections. 16 | type Residual struct { 17 | // Layer is the residual mapping. 18 | Layer anynet.Layer 19 | 20 | // If non-nil, Projection is applied to the original 21 | // input before it is added back to the output of Layer. 
22 | // 23 | // This can be used to deal with residual mappings that 24 | // change the tensor dimensions. 25 | Projection anynet.Layer 26 | } 27 | 28 | // DeserializeResidual deserializes a Residual. 29 | func DeserializeResidual(d []byte) (*Residual, error) { 30 | var layer anynet.Layer 31 | var proj anynet.Net 32 | if err := serializer.DeserializeAny(d, &layer, &proj); err != nil { 33 | return nil, essentials.AddCtx("deserialize Residual", err) 34 | } 35 | res := &Residual{Layer: layer} 36 | if len(proj) == 1 { 37 | res.Projection = proj[0] 38 | } 39 | return res, nil 40 | } 41 | 42 | // Apply applies the layer. 43 | func (r *Residual) Apply(in anydiff.Res, batch int) anydiff.Res { 44 | return anydiff.Pool(in, func(in anydiff.Res) anydiff.Res { 45 | mainOut := r.Layer.Apply(in, batch) 46 | orig := in 47 | if r.Projection != nil { 48 | orig = r.Projection.Apply(in, batch) 49 | } 50 | return anydiff.Add(orig, mainOut) 51 | }) 52 | } 53 | 54 | // Parameters returns the joined parameters of the Layer 55 | // and (if applicable) the Projection. 56 | func (r *Residual) Parameters() []*anydiff.Var { 57 | n := anynet.Net{r.Layer} 58 | if r.Projection != nil { 59 | n = append(n, r.Projection) 60 | } 61 | return n.Parameters() 62 | } 63 | 64 | // SerializerType returns the unique ID used to serialize 65 | // a Residual with the serializer package. 66 | func (r *Residual) SerializerType() string { 67 | return "github.com/unixpickle/anynet/anyconv.Residual" 68 | } 69 | 70 | // Serialize serializes the Residual. 
71 | func (r *Residual) Serialize() ([]byte, error) { 72 | var projLayer anynet.Net 73 | if r.Projection != nil { 74 | projLayer = anynet.Net{r.Projection} 75 | } 76 | return serializer.SerializeAny( 77 | r.Layer, 78 | projLayer, 79 | ) 80 | } 81 | -------------------------------------------------------------------------------- /anyconv/residual_test.go: -------------------------------------------------------------------------------- 1 | package anyconv 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/unixpickle/anynet" 8 | "github.com/unixpickle/anyvec/anyvec32" 9 | "github.com/unixpickle/serializer" 10 | ) 11 | 12 | func TestResidualSerialize(t *testing.T) { 13 | c := anyvec32.CurrentCreator() 14 | r := &Residual{ 15 | Layer: anynet.NewFC(c, 3, 2), 16 | Projection: anynet.NewFC(c, 3, 2), 17 | } 18 | data, err := serializer.SerializeAny(r) 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | var newLayer *Residual 23 | if err = serializer.DeserializeAny(data, &newLayer); err != nil { 24 | t.Fatal(err) 25 | } 26 | if !reflect.DeepEqual(newLayer, r) { 27 | t.Fatal("layers differ") 28 | } 29 | 30 | r = &Residual{ 31 | Layer: anynet.NewFC(c, 3, 2), 32 | } 33 | data, err = serializer.SerializeAny(r) 34 | if err != nil { 35 | t.Fatal(err) 36 | } 37 | if err = serializer.DeserializeAny(data, &newLayer); err != nil { 38 | t.Fatal(err) 39 | } 40 | if !reflect.DeepEqual(newLayer, r) { 41 | t.Fatal("layers differ") 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /anyconv/resize.go: -------------------------------------------------------------------------------- 1 | package anyconv 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/unixpickle/anydiff" 7 | "github.com/unixpickle/anyvec" 8 | "github.com/unixpickle/essentials" 9 | "github.com/unixpickle/serializer" 10 | ) 11 | 12 | func init() { 13 | var r Resize 14 | serializer.RegisterTypedDeserializer(r.SerializerType(), DeserializeResize) 15 | } 16 | 17 | // A 
Resize layer resizes tensors using bilinear 18 | // interpolation. 19 | // 20 | // The output dimensions must be greater than 1. 21 | type Resize struct { 22 | Depth int 23 | 24 | InputWidth int 25 | InputHeight int 26 | OutputWidth int 27 | OutputHeight int 28 | 29 | mappingLock sync.Mutex 30 | neighborMap anyvec.Mapper 31 | neighborWeights anyvec.Vector 32 | } 33 | 34 | // DeserializeResize deserializes a Resize. 35 | func DeserializeResize(d []byte) (*Resize, error) { 36 | var depth, inW, inH, outW, outH serializer.Int 37 | err := serializer.DeserializeAny(d, &depth, &inW, &inH, &outW, &outH) 38 | if err != nil { 39 | return nil, essentials.AddCtx("deserialize Resize", err) 40 | } 41 | return &Resize{ 42 | Depth: int(depth), 43 | InputWidth: int(inW), 44 | InputHeight: int(inH), 45 | OutputWidth: int(outW), 46 | OutputHeight: int(outH), 47 | }, nil 48 | } 49 | 50 | // Apply applies the layer to an input tensor. 51 | func (r *Resize) Apply(in anydiff.Res, batchSize int) anydiff.Res { 52 | if batchSize == 0 { 53 | return anydiff.NewConst(in.Output().Creator().MakeVector(0)) 54 | } 55 | if r.InputWidth == 0 || r.InputHeight == 0 || r.OutputWidth <= 1 || 56 | r.OutputHeight <= 1 || r.Depth == 0 { 57 | panic("tensor dimension out of range") 58 | } 59 | if r.InputWidth*r.InputHeight*r.Depth*batchSize != in.Output().Len() { 60 | panic("incorrect input size") 61 | } 62 | 63 | r.mappingLock.Lock() 64 | if r.neighborMap == nil { 65 | r.initializeMapping(in.Output().Creator()) 66 | } 67 | r.mappingLock.Unlock() 68 | 69 | mapped := batchMap(r.neighborMap, in.Output()) 70 | anyvec.ScaleRepeated(mapped, r.neighborWeights) 71 | out := anyvec.SumCols(mapped, mapped.Len()/4) 72 | 73 | return &resizeRes{ 74 | Layer: r, 75 | In: in, 76 | Out: out, 77 | Batch: batchSize, 78 | } 79 | } 80 | 81 | // SerializerType returns the unique ID used to serialize 82 | // a Resize with the serializer package. 
83 | func (r *Resize) SerializerType() string { 84 | return "github.com/unixpickle/anynet/anyconv.Resize" 85 | } 86 | 87 | // Serialize serializes the Resize. 88 | func (r *Resize) Serialize() ([]byte, error) { 89 | return serializer.SerializeAny( 90 | serializer.Int(r.Depth), 91 | serializer.Int(r.InputWidth), 92 | serializer.Int(r.InputHeight), 93 | serializer.Int(r.OutputWidth), 94 | serializer.Int(r.OutputHeight), 95 | ) 96 | } 97 | 98 | func (r *Resize) initializeMapping(c anyvec.Creator) { 99 | var sources []int 100 | var amounts []float64 101 | xScale := float64(r.InputWidth-1) / float64(r.OutputWidth-1) 102 | yScale := float64(r.InputHeight-1) / float64(r.OutputHeight-1) 103 | for y := 0; y < r.OutputHeight; y++ { 104 | sourceY := yScale * float64(y) 105 | for x := 0; x < r.OutputWidth; x++ { 106 | sourceX := xScale * float64(x) 107 | neighbors, a := r.neighbors(sourceX, sourceY) 108 | for z := 0; z < r.Depth; z++ { 109 | for _, idx := range neighbors[:] { 110 | sources = append(sources, idx+z) 111 | } 112 | amounts = append(amounts, a[:]...) 
113 | } 114 | } 115 | } 116 | r.neighborMap = c.MakeMapper(r.InputWidth*r.InputHeight*r.Depth, sources) 117 | r.neighborWeights = c.MakeVectorData(c.MakeNumericList(amounts)) 118 | } 119 | 120 | func (r *Resize) neighbors(sx, sy float64) ([4]int, [4]float64) { 121 | if sx > float64(r.InputWidth-1) { 122 | sx = float64(r.InputWidth - 1) 123 | } 124 | if sy > float64(r.InputHeight-1) { 125 | sy = float64(r.InputHeight - 1) 126 | } 127 | x1, x2 := int(sx), int(sx+1) 128 | y1, y2 := int(sy), int(sy+1) 129 | if x1 < 0 || y1 < 0 { 130 | x1 = 0 131 | y1 = 0 132 | } 133 | if x2 >= r.InputWidth || y2 >= r.InputHeight { 134 | x2 = r.InputWidth - 1 135 | y2 = r.InputHeight - 1 136 | } 137 | 138 | x1A := 1 - (sx - float64(x1)) 139 | y1A := 1 - (sy - float64(y1)) 140 | 141 | return [4]int{ 142 | r.sourceIndex(x1, y1), 143 | r.sourceIndex(x2, y1), 144 | r.sourceIndex(x1, y2), 145 | r.sourceIndex(x2, y2), 146 | }, [4]float64{ 147 | x1A * y1A, 148 | (1 - x1A) * y1A, 149 | x1A * (1 - y1A), 150 | (1 - x1A) * (1 - y1A), 151 | } 152 | } 153 | 154 | func (r *Resize) sourceIndex(x, y int) int { 155 | return r.Depth * (x + r.InputWidth*y) 156 | } 157 | 158 | type resizeRes struct { 159 | Layer *Resize 160 | In anydiff.Res 161 | Out anyvec.Vector 162 | Batch int 163 | } 164 | 165 | func (r *resizeRes) Output() anyvec.Vector { 166 | return r.Out 167 | } 168 | 169 | func (r *resizeRes) Vars() anydiff.VarSet { 170 | return r.In.Vars() 171 | } 172 | 173 | func (r *resizeRes) Propagate(u anyvec.Vector, g anydiff.Grad) { 174 | repSize := r.Layer.neighborWeights.Len() * r.Batch 175 | mappedDown := r.Layer.neighborWeights.Creator().MakeVector(repSize) 176 | anyvec.AddRepeated(mappedDown, r.Layer.neighborWeights) 177 | anyvec.ScaleChunks(mappedDown, u) 178 | down := batchMapTranspose(r.Layer.neighborMap, mappedDown) 179 | r.In.Propagate(down, g) 180 | } 181 | -------------------------------------------------------------------------------- /anyconv/resize_test.go: 
-------------------------------------------------------------------------------- 1 | package anyconv 2 | 3 | import ( 4 | "math" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/unixpickle/anydiff" 9 | "github.com/unixpickle/anydiff/anydifftest" 10 | "github.com/unixpickle/anyvec" 11 | "github.com/unixpickle/anyvec/anyvec32" 12 | "github.com/unixpickle/serializer" 13 | ) 14 | 15 | func TestResizeSerialize(t *testing.T) { 16 | r := &Resize{ 17 | Depth: 1, 18 | InputWidth: 2, 19 | InputHeight: 3, 20 | OutputWidth: 4, 21 | OutputHeight: 5, 22 | } 23 | data, err := serializer.SerializeAny(r) 24 | if err != nil { 25 | t.Fatal(err) 26 | } 27 | var newR *Resize 28 | if err := serializer.DeserializeAny(data, &newR); err != nil { 29 | t.Fatal(err) 30 | } 31 | if !reflect.DeepEqual(newR, r) { 32 | t.Error("incorrect result") 33 | } 34 | } 35 | 36 | func TestResizeOut(t *testing.T) { 37 | r := &Resize{ 38 | Depth: 2, 39 | InputWidth: 2, 40 | InputHeight: 2, 41 | OutputWidth: 3, 42 | OutputHeight: 4, 43 | } 44 | img := anyvec32.MakeVectorData([]float32{ 45 | 0.5, 0.4, 0.3, 0.2, 46 | 0.1, 0.9, 0.8, 0.7, 47 | }) 48 | expected := []float32{ 49 | 0.5, 0.4, 0.4, 0.3, 0.3, 0.2, 50 | 0.3666667, 0.5666667, 0.4166667, 0.4666667, 0.4666667, 0.3666667, 51 | 0.2333333, 0.7333333, 0.4333333, 0.6333333, 0.6333333, 0.5333333, 52 | 0.1, 0.9, 0.45, 0.8, 0.8, 0.7, 53 | } 54 | actual := r.Apply(anydiff.NewConst(img), 1).Output().Data().([]float32) 55 | if len(actual) != len(expected) { 56 | t.Fatalf("length should be %d but got %d", len(expected), len(actual)) 57 | } 58 | for i, x := range expected { 59 | a := actual[i] 60 | if math.Abs(float64(x-a)) > 1e-4 { 61 | t.Errorf("value %d: should be %f but got %f", i, x, a) 62 | } 63 | } 64 | } 65 | 66 | func TestResizeProp(t *testing.T) { 67 | r := &Resize{ 68 | Depth: 3, 69 | InputWidth: 4, 70 | InputHeight: 7, 71 | OutputWidth: 6, 72 | OutputHeight: 6, 73 | } 74 | img := anyvec32.MakeVector(4 * 7 * 3 * 2) 75 | anyvec.Rand(img, anyvec.Normal, nil) 76 
| inVar := anydiff.NewVar(img) 77 | 78 | checker := anydifftest.ResChecker{ 79 | F: func() anydiff.Res { 80 | return r.Apply(inVar, 2) 81 | }, 82 | V: []*anydiff.Var{inVar}, 83 | } 84 | checker.FullCheck(t) 85 | } 86 | -------------------------------------------------------------------------------- /anyconv/util.go: -------------------------------------------------------------------------------- 1 | package anyconv 2 | 3 | import ( 4 | "github.com/unixpickle/anydiff" 5 | "github.com/unixpickle/anynet" 6 | "github.com/unixpickle/anyvec" 7 | ) 8 | 9 | // Weights returns all of the filters and weight matrices 10 | // in the layer without returning biases, scalers, etc. 11 | // It can extract anynet.Net and Residual layers. 12 | func Weights(l anynet.Layer) []*anydiff.Var { 13 | var res []*anydiff.Var 14 | switch l := l.(type) { 15 | case anynet.Net: 16 | for _, sub := range l { 17 | res = append(res, Weights(sub)...) 18 | } 19 | case *anynet.FC: 20 | res = append(res, l.Weights) 21 | case *Conv: 22 | res = append(res, l.Filters) 23 | case *Residual: 24 | res = append(res, Weights(l.Layer)...) 25 | if l.Projection != nil { 26 | res = append(res, Weights(l.Projection)...) 27 | } 28 | } 29 | return res 30 | } 31 | 32 | type meanRowsRes struct { 33 | In anydiff.Res 34 | Scaler anyvec.Numeric 35 | Out anyvec.Vector 36 | } 37 | 38 | // negMeanRows computes the negative of the mean of the 39 | // rows in a row-major matrix. 
40 | func negMeanRows(in anydiff.Res, cols int) anydiff.Res { 41 | if in.Output().Len()%cols != 0 { 42 | panic("column count must divide input size") 43 | } 44 | rows := in.Output().Len() / cols 45 | scaler := in.Output().Creator().MakeNumeric(-1 / float64(rows)) 46 | out := anyvec.SumRows(in.Output().Copy(), cols) 47 | out.Scale(scaler) 48 | return &meanRowsRes{ 49 | In: in, 50 | Scaler: scaler, 51 | Out: out, 52 | } 53 | } 54 | 55 | func (m *meanRowsRes) Output() anyvec.Vector { 56 | return m.Out 57 | } 58 | 59 | func (m *meanRowsRes) Vars() anydiff.VarSet { 60 | return m.In.Vars() 61 | } 62 | 63 | func (m *meanRowsRes) Propagate(u anyvec.Vector, g anydiff.Grad) { 64 | u.Scale(m.Scaler) 65 | if v, ok := m.In.(*anydiff.Var); ok { 66 | downstream, ok := g[v] 67 | if !ok { 68 | return 69 | } 70 | anyvec.AddRepeated(downstream, u) 71 | } else { 72 | downstream := m.Out.Creator().MakeVector(m.In.Output().Len()) 73 | anyvec.AddRepeated(downstream, u) 74 | m.In.Propagate(downstream, g) 75 | } 76 | 77 | } 78 | 79 | type meanSquareRes struct { 80 | In anydiff.Res 81 | Scaler anyvec.Numeric 82 | Out anyvec.Vector 83 | } 84 | 85 | // meanSquare is like meanRows, but is squares the rows 86 | // before taking their mean. 
87 | func meanSquare(in anydiff.Res, cols int) anydiff.Res { 88 | if in.Output().Len()%cols != 0 { 89 | panic("column count must divide input size") 90 | } 91 | rows := in.Output().Len() / cols 92 | scaler := in.Output().Creator().MakeNumeric(1 / float64(rows)) 93 | squareIn := in.Output().Copy() 94 | squareIn.Mul(in.Output()) 95 | out := anyvec.SumRows(squareIn, cols) 96 | out.Scale(scaler) 97 | return &meanSquareRes{ 98 | In: in, 99 | Scaler: in.Output().Creator().MakeNumeric(2 / float64(rows)), 100 | Out: out, 101 | } 102 | } 103 | 104 | func (m *meanSquareRes) Output() anyvec.Vector { 105 | return m.Out 106 | } 107 | 108 | func (m *meanSquareRes) Vars() anydiff.VarSet { 109 | return m.In.Vars() 110 | } 111 | 112 | func (m *meanSquareRes) Propagate(u anyvec.Vector, g anydiff.Grad) { 113 | u.Scale(m.Scaler) 114 | downstream := m.Out.Creator().MakeVector(m.In.Output().Len()) 115 | anyvec.AddRepeated(downstream, u) 116 | downstream.Mul(m.In.Output()) 117 | m.In.Propagate(downstream, g) 118 | } 119 | 120 | func batchMap(m anyvec.Mapper, in anyvec.Vector) anyvec.Vector { 121 | var mapped []anyvec.Vector 122 | n := in.Len() / m.InSize() 123 | for i := 0; i < n; i++ { 124 | sub := in.Slice(m.InSize()*i, m.InSize()*(i+1)) 125 | newSub := in.Creator().MakeVector(m.OutSize()) 126 | m.Map(sub, newSub) 127 | mapped = append(mapped, newSub) 128 | } 129 | return in.Creator().Concat(mapped...) 130 | } 131 | 132 | func batchMapTranspose(m anyvec.Mapper, in anyvec.Vector) anyvec.Vector { 133 | var mapped []anyvec.Vector 134 | n := in.Len() / m.OutSize() 135 | for i := 0; i < n; i++ { 136 | sub := in.Slice(m.OutSize()*i, m.OutSize()*(i+1)) 137 | newSub := in.Creator().MakeVector(m.InSize()) 138 | m.MapTranspose(sub, newSub) 139 | mapped = append(mapped, newSub) 140 | } 141 | return in.Creator().Concat(mapped...) 
142 | } 143 | -------------------------------------------------------------------------------- /anyctc/conversion.go: -------------------------------------------------------------------------------- 1 | package anyctc 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/unixpickle/anydiff/anyseq" 7 | "github.com/unixpickle/anyvec" 8 | "github.com/unixpickle/anyvec/anyvec64" 9 | ) 10 | 11 | var internalCreator = anyvec64.DefaultCreator{} 12 | 13 | // vectorTo64 creates a vector with []float64 numeric list 14 | // types. 15 | func vectorTo64(v anyvec.Vector) anyvec.Vector { 16 | switch d := v.Data().(type) { 17 | case []float64: 18 | return internalCreator.MakeVectorData(d) 19 | case []float32: 20 | s := make([]float64, len(d)) 21 | for i, x := range d { 22 | s[i] = float64(x) 23 | } 24 | return internalCreator.MakeVectorData(s) 25 | default: 26 | panic(fmt.Sprintf("unsupported numeric type: %T", d)) 27 | } 28 | } 29 | 30 | func batchesTo64(v []*anyseq.Batch) []*anyseq.Batch { 31 | res := make([]*anyseq.Batch, len(v)) 32 | for i, x := range v { 33 | res[i] = &anyseq.Batch{ 34 | Packed: vectorTo64(x.Packed), 35 | Present: x.Present, 36 | } 37 | } 38 | return res 39 | } 40 | 41 | func batchesFrom64(c anyvec.Creator, v []*anyseq.Batch) []*anyseq.Batch { 42 | res := make([]*anyseq.Batch, len(v)) 43 | for i, x := range v { 44 | slice := x.Packed.Data().([]float64) 45 | res[i] = &anyseq.Batch{ 46 | Packed: c.MakeVectorData(c.MakeNumericList(slice)), 47 | Present: x.Present, 48 | } 49 | } 50 | return res 51 | } 52 | -------------------------------------------------------------------------------- /anyctc/cost.go: -------------------------------------------------------------------------------- 1 | package anyctc 2 | 3 | import ( 4 | "github.com/unixpickle/anydiff" 5 | "github.com/unixpickle/anydiff/anyseq" 6 | "github.com/unixpickle/anyvec" 7 | ) 8 | 9 | // Cost computes the negative log likelihood of each label 10 | // based on the outputs from an RNN. 
11 | // The result is a packed vector with one entry per 12 | // sequence in the batch. 13 | // 14 | // For a sequence, suppose that all of the labels are 15 | // bounded between 0 and N. 16 | // Then there should be N+1 outputs at each timestep. 17 | // The first N outputs correspond to the N labels. 18 | // The last output is the special "blank" symbol. 19 | // The outputs are all in the log domain. 20 | // 21 | // The anyvec.Creator must use an anyvec.NumericList type 22 | // []float32 or []float64. 23 | // No other numeric types are supported. 24 | func Cost(seqs anyseq.Seq, labels [][]int) anydiff.Res { 25 | if len(seqs.Output()) == 0 { 26 | return anydiff.NewConst(seqs.Creator().MakeVector(0)) 27 | } 28 | return anydiff.Scale(pool(seqs, func(in [][]anydiff.Res) anydiff.Res { 29 | var res []anydiff.Res 30 | for i, x := range in { 31 | res = append(res, logLikelihood(internalCreator, x, labels[i])) 32 | } 33 | return anydiff.Concat(res...) 34 | }), seqs.Creator().MakeNumeric(-1)) 35 | } 36 | 37 | type poolRes struct { 38 | In anyseq.Seq 39 | Pools []*anydiff.Var 40 | Lengths []int 41 | Res anydiff.Res 42 | OutVec anyvec.Vector 43 | } 44 | 45 | func pool(seqs anyseq.Seq, f func(in [][]anydiff.Res) anydiff.Res) anydiff.Res { 46 | rawData := anyseq.SeparateSeqs(batchesTo64(seqs.Output())) 47 | pools := make([]*anydiff.Var, len(rawData)) 48 | splitPools := make([][]anydiff.Res, len(rawData)) 49 | lengths := make([]int, len(rawData)) 50 | for i, raw := range rawData { 51 | pools[i] = anydiff.NewVar(internalCreator.Concat(raw...)) 52 | splitPools[i] = splitRes(pools[i], len(raw)) 53 | lengths[i] = len(raw) 54 | } 55 | outRes := f(splitPools) 56 | convertedOut := seqs.Creator().MakeVectorData( 57 | seqs.Creator().MakeNumericList(outRes.Output().Data().([]float64)), 58 | ) 59 | return &poolRes{ 60 | In: seqs, 61 | Pools: pools, 62 | Lengths: lengths, 63 | Res: outRes, 64 | OutVec: convertedOut, 65 | } 66 | } 67 | 68 | func (p *poolRes) Output() anyvec.Vector { 69 | 
return p.OutVec 70 | } 71 | 72 | func (p *poolRes) Vars() anydiff.VarSet { 73 | return p.In.Vars() 74 | } 75 | 76 | func (p *poolRes) Propagate(u anyvec.Vector, g anydiff.Grad) { 77 | u = vectorTo64(u) 78 | 79 | tempGrad := anydiff.Grad{} 80 | for _, pvar := range p.Pools { 81 | tempGrad[pvar] = pvar.Vector.Creator().MakeVector(pvar.Vector.Len()) 82 | } 83 | p.Res.Propagate(u, tempGrad) 84 | downstream := make([][]anyvec.Vector, len(p.Pools)) 85 | for i, pvar := range p.Pools { 86 | downstream[i] = splitVec(tempGrad[pvar], p.Lengths[i]) 87 | delete(tempGrad, p.Pools[i]) 88 | } 89 | joinedU := anyseq.ConstSeqList(u.Creator(), downstream).Output() 90 | p.In.Propagate(batchesFrom64(p.In.Creator(), joinedU), g) 91 | } 92 | 93 | func splitVec(vec anyvec.Vector, parts int) []anyvec.Vector { 94 | res := make([]anyvec.Vector, parts) 95 | chunkSize := vec.Len() / parts 96 | for i := range res { 97 | res[i] = vec.Slice(i*chunkSize, (i+1)*chunkSize) 98 | } 99 | return res 100 | } 101 | 102 | func splitRes(res anydiff.Res, parts int) []anydiff.Res { 103 | if parts == 0 { 104 | return nil 105 | } 106 | reses := make([]anydiff.Res, parts) 107 | chunkSize := res.Output().Len() / parts 108 | for i := range reses { 109 | reses[i] = anydiff.Slice(res, i*chunkSize, (i+1)*chunkSize) 110 | } 111 | return reses 112 | } 113 | -------------------------------------------------------------------------------- /anyctc/cost_test.go: -------------------------------------------------------------------------------- 1 | package anyctc 2 | 3 | import ( 4 | "math" 5 | "testing" 6 | 7 | "github.com/unixpickle/anydiff" 8 | "github.com/unixpickle/anydiff/anydifftest" 9 | "github.com/unixpickle/anydiff/anyseq" 10 | "github.com/unixpickle/anyvec" 11 | "github.com/unixpickle/anyvec/anyvec32" 12 | ) 13 | 14 | func TestCostOutputs(t *testing.T) { 15 | probSeqs := [][][]float64{ 16 | {}, 17 | {{0.3, 0.2, 0.5}, {0.1, 0.5, 0.4}}, 18 | {}, 19 | } 20 | labels := [][]int{{}, {0, 1}, {1}} 21 | expectedProbs := 
[]float64{1, 0.3 * 0.5, 0} 22 | inSeqs := logProbSeqs(anyvec32.CurrentCreator(), probSeqs) 23 | negOut := anydiff.Scale(Cost(inSeqs, labels), float32(-1)) 24 | actualProbs := anydiff.Exp(negOut).Output().Data().([]float32) 25 | for i, x := range expectedProbs { 26 | a := float64(actualProbs[i]) 27 | if math.Abs(x-a)/x > testPrecision { 28 | t.Errorf("output %d: expected %f but got %f", i, x, a) 29 | } 30 | } 31 | } 32 | 33 | func TestCostGrad(t *testing.T) { 34 | c := anyvec32.CurrentCreator() 35 | var vars []*anydiff.Var 36 | seqs := anyseq.ResSeq(c, []*anyseq.ResBatch{ 37 | {Packed: randomVar(c, 9, &vars), Present: []bool{true, true, true}}, 38 | {Packed: randomVar(c, 6, &vars), Present: []bool{true, false, true}}, 39 | {Packed: randomVar(c, 3, &vars), Present: []bool{false, false, true}}, 40 | {Packed: randomVar(c, 3, &vars), Present: []bool{false, false, true}}, 41 | }) 42 | labels := [][]int{{1, 0}, {0}, {0, 1, 2}} 43 | ch := anydifftest.ResChecker{ 44 | F: func() anydiff.Res { 45 | return Cost(seqs, labels) 46 | }, 47 | V: vars, 48 | Prec: testPrecision * 3, 49 | Delta: testPrecision, 50 | } 51 | ch.FullCheck(t) 52 | } 53 | 54 | func logProbSeqs(c anyvec.Creator, values [][][]float64) anyseq.Seq { 55 | vecLists := make([][]anyvec.Vector, len(values)) 56 | for i, seq := range values { 57 | vecLists[i] = make([]anyvec.Vector, len(seq)) 58 | for j, x := range seq { 59 | vecLists[i][j] = c.MakeVectorData(c.MakeNumericList(x)) 60 | anyvec.Log(vecLists[i][j]) 61 | } 62 | } 63 | return anyseq.ConstSeqList(c, vecLists) 64 | } 65 | 66 | func randomVar(c anyvec.Creator, n int, vs *[]*anydiff.Var) *anydiff.Var { 67 | v := c.MakeVector(n) 68 | anyvec.Rand(v, anyvec.Normal, nil) 69 | anyvec.LogSoftmax(v, 3) 70 | res := anydiff.NewVar(v) 71 | *vs = append(*vs, res) 72 | return res 73 | } 74 | -------------------------------------------------------------------------------- /anyctc/decode.go: -------------------------------------------------------------------------------- 1 
| package anyctc 2 | 3 | import ( 4 | "math" 5 | "sort" 6 | 7 | "github.com/unixpickle/anydiff/anyseq" 8 | ) 9 | 10 | // BestLabels produces the most likely labelings for the 11 | // output sequences. 12 | // 13 | // The blankThresh argument specifies how greedy the 14 | // search should be with respect to blank symbols. 15 | // Typically, a value close to -1e-3 is sufficient. 16 | // As an example, a blankThresh of -0.0001 means that any 17 | // blank with probability greater than e^-0.0001 is 18 | // treated as if it had a 100% probability. 19 | // 20 | // A blankThresh of zero is not recommended unless the 21 | // input sequences are fairly short. 22 | func BestLabels(seqs anyseq.Seq, blankThresh float64) [][]int { 23 | var res [][]int 24 | for _, seq := range anyseq.SeparateSeqs(batchesTo64(seqs.Output())) { 25 | floatSeq := make([][]float64, len(seq)) 26 | for i, x := range seq { 27 | floatSeq[i] = x.Data().([]float64) 28 | } 29 | res = append(res, prefixSearch(floatSeq, blankThresh)) 30 | } 31 | return res 32 | } 33 | 34 | func prefixSearch(seq [][]float64, blankThresh float64) []int { 35 | var subSeqs [][][]float64 36 | var subSeq [][]float64 37 | for _, x := range seq { 38 | if x[len(x)-1] > blankThresh { 39 | if len(subSeq) > 0 { 40 | subSeqs = append(subSeqs, subSeq) 41 | subSeq = nil 42 | } 43 | } else { 44 | subSeq = append(subSeq, x) 45 | } 46 | } 47 | if len(subSeq) > 0 { 48 | subSeqs = append(subSeqs, subSeq) 49 | } 50 | 51 | var res []int 52 | for _, sub := range subSeqs { 53 | startProb := &labelProb{NoBlank: math.Inf(-1)} 54 | subRes, _ := subPrefixSearch(sub, nil, startProb) 55 | res = append(res, subRes...) 
56 | } 57 | return res 58 | } 59 | 60 | func subPrefixSearch(seq [][]float64, prefix []int, prob *labelProb) ([]int, *labelProb) { 61 | if len(seq) == 0 { 62 | return prefix, prob 63 | } 64 | 65 | extensions := allExtensions(seq[0], prefix, prob) 66 | sort.Sort(extensionSorter(extensions)) 67 | 68 | bestProb := zeroLabelProb() 69 | bestSeq := []int{} 70 | for _, ext := range extensions { 71 | if ext.Prob.Total() > bestProb.Total() { 72 | extended := append(append([]int{}, prefix...), ext.Addition...) 73 | res, finalProb := subPrefixSearch(seq[1:], extended, ext.Prob) 74 | if finalProb.Total() > bestProb.Total() { 75 | bestProb = finalProb 76 | bestSeq = res 77 | } 78 | } 79 | } 80 | 81 | return bestSeq, bestProb 82 | } 83 | 84 | // labelProb represents the probability of a labeling, 85 | // split up into the probability of the labeling without a 86 | // trailing blank and with a trailing blank. 87 | type labelProb struct { 88 | Blank float64 89 | NoBlank float64 90 | } 91 | 92 | func zeroLabelProb() *labelProb { 93 | return &labelProb{Blank: math.Inf(-1), NoBlank: math.Inf(-1)} 94 | } 95 | 96 | func (l *labelProb) Total() float64 { 97 | return addLogs(l.Blank, l.NoBlank) 98 | } 99 | 100 | // possibleExtension represents a possible way to extend a 101 | // labeling (during prefix search). 102 | type possibleExtension struct { 103 | // Tokens added to the labeling by this extension. 104 | Addition []int 105 | 106 | // Probability of the extended labeling. 
107 | Prob *labelProb 108 | } 109 | 110 | func allExtensions(next []float64, label []int, prob *labelProb) []*possibleExtension { 111 | var res []*possibleExtension 112 | for i, compProb := range next[:len(next)-1] { 113 | p := zeroLabelProb() 114 | if len(label) > 0 && i == label[len(label)-1] { 115 | p.NoBlank = compProb + prob.Blank 116 | } else { 117 | p.NoBlank = compProb + prob.Total() 118 | } 119 | res = append(res, &possibleExtension{Addition: []int{i}, Prob: p}) 120 | } 121 | noChangeProb := &labelProb{ 122 | Blank: prob.Total() + next[len(next)-1], 123 | NoBlank: math.Inf(-1), 124 | } 125 | if len(label) > 0 { 126 | last := label[len(label)-1] 127 | noChangeProb.NoBlank = prob.NoBlank + next[last] 128 | } 129 | return append(res, &possibleExtension{Prob: noChangeProb}) 130 | } 131 | 132 | // An extensionSorter sorts possible labeling extensions 133 | // from most to least probable. 134 | type extensionSorter []*possibleExtension 135 | 136 | func (e extensionSorter) Len() int { 137 | return len(e) 138 | } 139 | 140 | func (e extensionSorter) Swap(i, j int) { 141 | e[i], e[j] = e[j], e[i] 142 | } 143 | 144 | func (e extensionSorter) Less(i, j int) bool { 145 | return e[i].Prob.Total() > e[j].Prob.Total() 146 | } 147 | -------------------------------------------------------------------------------- /anyctc/decode_test.go: -------------------------------------------------------------------------------- 1 | package anyctc 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/unixpickle/anydiff/anyseq" 8 | "github.com/unixpickle/anyvec" 9 | "github.com/unixpickle/anyvec/anyvec32" 10 | ) 11 | 12 | func TestBestLabels(t *testing.T) { 13 | var inputs = [][][]float32{ 14 | { 15 | {-9.21034037197618, -0.000100005000333347}, 16 | {-0.105360515657826, -2.302585092994046}, 17 | {-9.21034037197618, -0.000100005000333347}, 18 | {-0.105360515657826, -2.302585092994046}, 19 | {-9.21034037197618, -0.000100005000333347}, 20 | {-9.21034037197618, 
-0.000100005000333347}, 21 | }, 22 | { 23 | {-1.38155105579643e+01, -1.38155105579643e+01, -2.00000199994916e-06}, 24 | // The first label is not more likely, but 25 | // after both timesteps it has a 64% chance 26 | // of being seen in at least one of the two 27 | // timesteps. 28 | {-0.916290731874155, -13.815510557964274, -0.510827290434046}, 29 | {-0.916290731874155, -13.815510557964274, -0.510827290434046}, 30 | {-1.38155105579643e+01, -1.38155105579643e+01, -2.00000199994916e-06}, 31 | {-1.609437912434100, -0.693147180559945, -1.203972804325936}, 32 | }, 33 | { 34 | {-1.38155105579643e+01, -1.38155105579643e+01, -2.00000199994916e-06}, 35 | {-0.916290731874155, -13.815510557964274, -0.510827290434046}, 36 | {-1.38155105579643e+01, -1.38155105579643e+01, -2.00000199994916e-06}, 37 | {-1.609437912434100, -0.693147180559945, -1.203972804325936}, 38 | }, 39 | { 40 | {-0.916290731874155, -13.815510557964274, -0.510827290434046}, 41 | {-1.38155105579643e+01, -1.38155105579643e+01, -2.00000199994916e-06}, 42 | {-1.609437912434100, -0.693147180559945, -1.203972804325936}, 43 | }, 44 | } 45 | 46 | inLists := make([][]anyvec.Vector, len(inputs)) 47 | for i, x := range inputs { 48 | inLists[i] = make([]anyvec.Vector, len(x)) 49 | for j, y := range x { 50 | inLists[i][j] = anyvec32.MakeVectorData(y) 51 | } 52 | } 53 | in := anyseq.ConstSeqList(anyvec32.CurrentCreator(), inLists[1:]) 54 | smallIn := anyseq.ConstSeqList(anyvec32.CurrentCreator(), inLists[:1]) 55 | 56 | var expected = [][]int{ 57 | {0, 0}, 58 | {0, 1}, 59 | {1}, 60 | {1}, 61 | } 62 | 63 | for _, thresh := range []float64{-1e-2, -1e-3, -1e-6, -1e-10} { 64 | actual := append(BestLabels(smallIn, thresh), BestLabels(in, thresh)...)
65 | if !reflect.DeepEqual(actual, expected) { 66 | t.Errorf("thresh %e: expected %v but got %v", thresh, expected, actual) 67 | } 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /anyctc/doc.go: -------------------------------------------------------------------------------- 1 | // Package anyctc implements Connectionist Temporal 2 | // Classification (CTC). 3 | // For more information on CTC, see this paper: 4 | // http://www.cs.toronto.edu/~graves/icml_2006.pdf. 5 | // 6 | // Much of the code in this package was inspired by my old 7 | // CTC package, the source code for which can be found at 8 | // https://github.com/unixpickle/speechrecog/blob/f6a0b091e6b69d9b4541ca2859f5de23fc2d26de/ctc/ctc.go. 9 | package anyctc 10 | -------------------------------------------------------------------------------- /anyctc/log_likelihood_test.go: -------------------------------------------------------------------------------- 1 | package anyctc 2 | 3 | import ( 4 | "math" 5 | "math/rand" 6 | "testing" 7 | 8 | "github.com/unixpickle/anydiff" 9 | "github.com/unixpickle/anydiff/anydifftest" 10 | "github.com/unixpickle/anyvec" 11 | "github.com/unixpickle/anyvec/anyvec64" 12 | ) 13 | 14 | const ( 15 | testSymbolCount = 5 16 | testPrecision = 1e-3 17 | ) 18 | 19 | func TestLogLikelihoodOutputs(t *testing.T) { 20 | c := anyvec64.CurrentCreator() 21 | for i := 0; i < 11; i++ { 22 | labelLen := 5 + rand.Intn(5) 23 | if i == 10 { 24 | labelLen = 0 25 | } 26 | seqLen := labelLen + rand.Intn(5) 27 | label := make([]int, labelLen) 28 | for i := range label { 29 | label[i] = rand.Intn(testSymbolCount) 30 | } 31 | seq, res := createTestSequence(c, seqLen, testSymbolCount) 32 | expected := exactLikelihood(seq, label, -1) 33 | outSlice := logLikelihood(c, res, label).Output().Data().([]float64) 34 | actual := math.Exp(outSlice[0]) 35 | if math.Abs(actual-expected)/math.Abs(expected) > testPrecision { 36 | t.Errorf("LogLikelihood gave log(%e) 
but expected log(%e)", 37 | actual, expected) 38 | } 39 | } 40 | } 41 | 42 | func TestLogLikelihoodGrad(t *testing.T) { 43 | c := anyvec64.CurrentCreator() 44 | label := make([]int, 5) 45 | for i := range label { 46 | label[i] = rand.Intn(testSymbolCount) 47 | } 48 | _, resSeq := createTestSequence(c, len(label)+5, testSymbolCount) 49 | var vars []*anydiff.Var 50 | for _, x := range resSeq { 51 | for v := range x.Vars() { 52 | vars = append(vars, v) 53 | } 54 | } 55 | ch := anydifftest.ResChecker{ 56 | F: func() anydiff.Res { 57 | return logLikelihood(c, resSeq, label) 58 | }, 59 | V: vars, 60 | Prec: testPrecision * 3, 61 | Delta: testPrecision, 62 | } 63 | ch.FullCheck(t) 64 | } 65 | 66 | // createTestSequence creates a test sequence. 67 | // 68 | // The sequence is produced in two forms. 69 | // First, a native sequence of []float64 containing actual 70 | // probabilities is produced. 71 | // Second, a sequence of anydiff.Res is produced with the 72 | // logs of the probabilities. 73 | func createTestSequence(c anyvec.Creator, seqLen, symCount int) ([][]float64, []anydiff.Res) { 74 | res := make([]anydiff.Res, seqLen) 75 | seq := make([][]float64, seqLen) 76 | for i := range seq { 77 | seq[i] = make([]float64, symCount+1) 78 | var probSum float64 79 | for j := range seq[i] { 80 | seq[i][j] = math.Abs(rand.NormFloat64()) 81 | probSum += seq[i][j] 82 | } 83 | for j := range seq[i] { 84 | seq[i][j] /= probSum 85 | } 86 | logVec := make([]float64, len(seq[i])) 87 | for j := range logVec { 88 | logVec[j] = math.Log(seq[i][j]) 89 | } 90 | res[i] = anydiff.NewVar(c.MakeVectorData(c.MakeNumericList(logVec))) 91 | } 92 | return seq, res 93 | } 94 | 95 | // exactLikelihood computes the log likelihood of a label 96 | // naively from a sequence of raw, unlogged probabilities. 
func exactLikelihood(seq [][]float64, label []int, lastSymbol int) float64 {
	// The value returned is a plain probability, not a
	// logarithm: callers compare it against the exponential
	// of the log likelihood under test.
	//
	// seq holds one probability vector per timestep, with
	// the blank symbol stored last.
	// label is the part of the labeling still to be emitted.
	// lastSymbol is the symbol emitted at the previous
	// timestep, or -1 if that timestep was blank (or there
	// was no previous timestep).
	if len(seq) == 0 {
		// With no timesteps left, only an exhausted label is
		// consistent with the outputs.
		if len(label) == 0 {
			return 1
		}
		return 0
	}

	next := seq[0]
	blank := len(next) - 1
	rest := seq[1:]

	var res float64

	// Case 1: emit a blank, which resets the repeat state.
	res += next[blank] * exactLikelihood(rest, label, -1)

	// Case 2: repeat the previous symbol, which adds nothing
	// to the labeling.
	if lastSymbol >= 0 {
		res += next[lastSymbol] * exactLikelihood(rest, label, lastSymbol)
	}

	// Case 3: emit the next label symbol. If it equals the
	// previous symbol, it would merely be a repeat (case 2),
	// so it is excluded here.
	if len(label) > 0 && label[0] != lastSymbol {
		res += next[label[0]] * exactLikelihood(rest, label[1:], label[0])
	}

	return res
}
15 | type Batch struct { 16 | Inputs anyseq.Seq 17 | Labels [][]int 18 | } 19 | 20 | // A Trainer creates batches, computes gradients, and adds 21 | // up costs for CTC. 22 | type Trainer struct { 23 | Func func(anyseq.Seq) anyseq.Seq 24 | Params []*anydiff.Var 25 | 26 | // Average indicates whether or not the total cost should 27 | // be averaged before computing gradients. 28 | // This affects gradients, LastCost, and the output of 29 | // TotalCost(). 30 | Average bool 31 | 32 | // After every gradient computation, LastCost is set to 33 | // the cost from the batch. 34 | LastCost anyvec.Numeric 35 | } 36 | 37 | // Fetch produces a *Batch for the subset of samples. 38 | // The s argument must implement SampleList. 39 | // The batch may not be empty. 40 | func (t *Trainer) Fetch(s anysgd.SampleList) (anysgd.Batch, error) { 41 | if s.Len() == 0 { 42 | return nil, errors.New("fetch batch: empty batch") 43 | } 44 | l := s.(SampleList) 45 | ins := make([][]anyvec.Vector, l.Len()) 46 | outs := make([][]int, l.Len()) 47 | for i := 0; i < l.Len(); i++ { 48 | sample, err := l.GetSample(i) 49 | if err != nil { 50 | return nil, essentials.AddCtx("fetch batch", err) 51 | } 52 | ins[i] = sample.Input 53 | outs[i] = sample.Label 54 | } 55 | return &Batch{ 56 | Inputs: anyseq.ConstSeqList(l.Creator(), ins), 57 | Labels: outs, 58 | }, nil 59 | } 60 | 61 | // TotalCost computes the total cost for the *Batch. 62 | // 63 | // For more information on how this works, see Cost(). 64 | func (t *Trainer) TotalCost(batch anysgd.Batch) anydiff.Res { 65 | b := batch.(*Batch) 66 | actual := t.Func(b.Inputs) 67 | costs := Cost(actual, b.Labels) 68 | sum := anydiff.Sum(costs) 69 | if t.Average { 70 | scaler := sum.Output().Creator().MakeNumeric(1 / float64(costs.Output().Len())) 71 | return anydiff.Scale(sum, scaler) 72 | } else { 73 | return sum 74 | } 75 | } 76 | 77 | // Gradient computes the gradient for the batch's cost. 
78 | // It also sets t.LastCost to the numerical value of the 79 | // total cost. 80 | // 81 | // The b argument must be a *Batch. 82 | func (t *Trainer) Gradient(b anysgd.Batch) anydiff.Grad { 83 | grad, lc := anysgd.CosterGrad(t, b, t.Params) 84 | t.LastCost = lc 85 | return grad 86 | } 87 | -------------------------------------------------------------------------------- /anyff/doc.go: -------------------------------------------------------------------------------- 1 | // Package anyff provides various APIs for training 2 | // feed-forward neural networks with SGD. 3 | package anyff 4 | -------------------------------------------------------------------------------- /anyff/samples.go: -------------------------------------------------------------------------------- 1 | package anyff 2 | 3 | import ( 4 | "github.com/unixpickle/anynet/anysgd" 5 | "github.com/unixpickle/anyvec" 6 | ) 7 | 8 | // A Sample is a training sample for a feed-forward neural 9 | // network. 10 | // It indicates the network's input and the target output. 11 | type Sample struct { 12 | Input anyvec.Vector 13 | Output anyvec.Vector 14 | } 15 | 16 | // A SampleList is an anysgd.SampleList that produces 17 | // feed-forward samples. 18 | type SampleList interface { 19 | anysgd.SampleList 20 | 21 | GetSample(idx int) (*Sample, error) 22 | } 23 | 24 | // A SliceSampleList is a concrete SampleList with 25 | // predetermined samples. 26 | type SliceSampleList []*Sample 27 | 28 | // Len returns the number of samples. 29 | func (s SliceSampleList) Len() int { 30 | return len(s) 31 | } 32 | 33 | // Swap swaps two samples. 34 | func (s SliceSampleList) Swap(i, j int) { 35 | s[i], s[j] = s[j], s[i] 36 | } 37 | 38 | // Slice copies a sub-slice of the list. 39 | func (s SliceSampleList) Slice(i, j int) anysgd.SampleList { 40 | return append(SliceSampleList{}, s[i:j]...) 41 | } 42 | 43 | // GetSample returns the sample at the index. 
44 | func (s SliceSampleList) GetSample(idx int) (*Sample, error) { 45 | return s[idx], nil 46 | } 47 | -------------------------------------------------------------------------------- /anyff/trainer.go: -------------------------------------------------------------------------------- 1 | package anyff 2 | 3 | import ( 4 | "errors" 5 | "runtime" 6 | "sync" 7 | 8 | "github.com/unixpickle/anydiff" 9 | "github.com/unixpickle/anynet" 10 | "github.com/unixpickle/anynet/anysgd" 11 | "github.com/unixpickle/anyvec" 12 | "github.com/unixpickle/essentials" 13 | ) 14 | 15 | // A Batch stores an input and output batch in a packed 16 | // format. 17 | type Batch struct { 18 | Inputs *anydiff.Const 19 | Outputs *anydiff.Const 20 | Num int 21 | } 22 | 23 | // A Trainer can construct batches, compute gradients, and 24 | // tally up costs for feed-forward neural networks. 25 | type Trainer struct { 26 | Net anynet.Layer 27 | Cost anynet.Cost 28 | Params []*anydiff.Var 29 | 30 | // Average indicates whether or not the total cost should 31 | // be averaged before computing gradients. 32 | // This affects gradients, LastCost, and the output of 33 | // TotalCost(). 34 | Average bool 35 | 36 | // After every gradient computation, LastCost is set to 37 | // the cost from the batch. 38 | LastCost anyvec.Numeric 39 | 40 | // MaxGos specifies the maximum goroutines to use 41 | // simultaneously for fetching samples. 42 | // If it is 0, GOMAXPROCS is used. 43 | MaxGos int 44 | } 45 | 46 | // Fetch produces a *Batch for the subset of samples. 47 | // The s argument must implement SampleList. 48 | // The batch may not be empty. 
49 | func (t *Trainer) Fetch(s anysgd.SampleList) (anysgd.Batch, error) { 50 | if s.Len() == 0 { 51 | return nil, errors.New("fetch batch: empty batch") 52 | } 53 | 54 | l := s.(SampleList) 55 | ins := make([]anyvec.Vector, l.Len()) 56 | outs := make([]anyvec.Vector, l.Len()) 57 | 58 | idxChan := make(chan int, l.Len()) 59 | for i := 0; i < l.Len(); i++ { 60 | idxChan <- i 61 | } 62 | close(idxChan) 63 | 64 | maxGos := t.MaxGos 65 | if maxGos == 0 { 66 | maxGos = runtime.GOMAXPROCS(0) 67 | } 68 | 69 | wg := sync.WaitGroup{} 70 | errChan := make(chan error, maxGos) 71 | for i := 0; i < maxGos; i++ { 72 | wg.Add(1) 73 | go func() { 74 | defer wg.Done() 75 | for i := range idxChan { 76 | sample, err := l.GetSample(i) 77 | if err != nil { 78 | errChan <- essentials.AddCtx("fetch batch", err) 79 | return 80 | } 81 | ins[i] = sample.Input 82 | outs[i] = sample.Output 83 | } 84 | }() 85 | } 86 | 87 | wg.Wait() 88 | close(errChan) 89 | 90 | if err := <-errChan; err != nil { 91 | return nil, err 92 | } 93 | 94 | joinedIns := ins[0].Creator().Concat(ins...) 95 | joinedOuts := outs[0].Creator().Concat(outs...) 96 | 97 | return &Batch{ 98 | Inputs: anydiff.NewConst(joinedIns), 99 | Outputs: anydiff.NewConst(joinedOuts), 100 | Num: l.Len(), 101 | }, nil 102 | } 103 | 104 | // TotalCost computes the total cost for the *Batch. 105 | func (t *Trainer) TotalCost(batch anysgd.Batch) anydiff.Res { 106 | b := batch.(*Batch) 107 | outRes := t.Net.Apply(b.Inputs, b.Num) 108 | cost := t.Cost.Cost(b.Outputs, outRes, b.Num) 109 | total := anydiff.Sum(cost) 110 | if t.Average { 111 | divisor := 1 / float64(cost.Output().Len()) 112 | return anydiff.Scale(total, total.Output().Creator().MakeNumeric(divisor)) 113 | } else { 114 | return total 115 | } 116 | } 117 | 118 | // Gradient computes the gradient for the batch's cost. 119 | // It also sets t.LastCost to the numerical value of the 120 | // total cost. 121 | // 122 | // The b argument must be a *Batch. 
123 | func (t *Trainer) Gradient(b anysgd.Batch) anydiff.Grad { 124 | grad, lc := anysgd.CosterGrad(t, b, t.Params) 125 | t.LastCost = lc 126 | return grad 127 | } 128 | -------------------------------------------------------------------------------- /anymisc/doc.go: -------------------------------------------------------------------------------- 1 | // Package anymisc implements neural network components 2 | // that are not mainstream enough to be part of anyrnn, 3 | // anynet, or anydiff. 4 | package anymisc 5 | -------------------------------------------------------------------------------- /anymisc/gumbel_softmax.go: -------------------------------------------------------------------------------- 1 | package anymisc 2 | 3 | import ( 4 | "github.com/unixpickle/anydiff" 5 | "github.com/unixpickle/anyvec" 6 | "github.com/unixpickle/essentials" 7 | "github.com/unixpickle/serializer" 8 | ) 9 | 10 | const gumbelEpsilon = 1e-10 11 | 12 | func init() { 13 | serializer.RegisterTypedDeserializer((&GumbelSoftmax{}).SerializerType(), 14 | DeserializeGumbelSoftmax) 15 | } 16 | 17 | // GumbelSoftmax is an anynet.Layer that implements the 18 | // Gumbel-Softmax distribution. 19 | // 20 | // For more, see https://arxiv.org/abs/1611.01144. 21 | type GumbelSoftmax struct { 22 | Temperature float64 23 | } 24 | 25 | // DeserializeGumbelSoftmax deserializes a GumbelSoftmax. 26 | func DeserializeGumbelSoftmax(d []byte) (*GumbelSoftmax, error) { 27 | var res GumbelSoftmax 28 | if err := serializer.DeserializeAny(d, &res.Temperature); err != nil { 29 | return nil, essentials.AddCtx("deserialize GumbelSoftmax", err) 30 | } 31 | return &res, nil 32 | } 33 | 34 | // Apply applies the Gumbel-Softmax to each vector in the 35 | // batch. 
36 | func (g *GumbelSoftmax) Apply(in anydiff.Res, n int) anydiff.Res { 37 | c := in.Output().Creator() 38 | gumbel := c.MakeVector(in.Output().Len()) 39 | anyvec.Rand(gumbel, anyvec.Uniform, nil) 40 | for i := 0; i < 2; i++ { 41 | gumbel.AddScalar(c.MakeNumeric(gumbelEpsilon)) 42 | anyvec.Log(gumbel) 43 | gumbel.Scale(c.MakeNumeric(-1)) 44 | } 45 | smIn := anydiff.Scale(anydiff.Add(anydiff.NewConst(gumbel), in), 46 | c.MakeNumeric(1/g.Temperature)) 47 | return anydiff.Exp(anydiff.LogSoftmax(smIn, in.Output().Len()/n)) 48 | } 49 | 50 | // SerializerType returns the unique ID used to serialize 51 | // a GumbelSoftmax with the serializer package. 52 | func (g *GumbelSoftmax) SerializerType() string { 53 | return "github.com/unixpickle/anynet/anymisc.GumbelSoftmax" 54 | } 55 | 56 | // Serialize serializes the layer. 57 | func (g *GumbelSoftmax) Serialize() ([]byte, error) { 58 | return serializer.SerializeAny(g.Temperature) 59 | } 60 | -------------------------------------------------------------------------------- /anymisc/gumbel_softmax_test.go: -------------------------------------------------------------------------------- 1 | package anymisc 2 | 3 | import ( 4 | "math" 5 | "testing" 6 | 7 | "github.com/unixpickle/anydiff" 8 | "github.com/unixpickle/anyvec" 9 | "github.com/unixpickle/anyvec/anyvec32" 10 | ) 11 | 12 | func TestGumbelSoftmax(t *testing.T) { 13 | c := anyvec32.CurrentCreator() 14 | vec := c.MakeVectorData(c.MakeNumericList([]float64{ 15 | 1.306852819440055, 16 | 0.796027195674064, 17 | 0.390562087565900, 18 | })) 19 | maxHistogram := map[int]int{} 20 | layer := &GumbelSoftmax{Temperature: 1} 21 | numIters := 10000 22 | for i := 0; i < numIters; i++ { 23 | out := layer.Apply(anydiff.NewConst(vec), 1).Output() 24 | maxHistogram[anyvec.MaxIndex(out)]++ 25 | } 26 | expectedFracs := []float64{0.5, 0.3, 0.2} 27 | for idx, count := range maxHistogram { 28 | frac := float64(count) / float64(numIters) 29 | expected := expectedFracs[idx] 30 | if 
math.Abs(frac-expected) > 1e-2 { 31 | t.Error("bad histogram:", maxHistogram) 32 | break 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /anymisc/relu_rnn.go: -------------------------------------------------------------------------------- 1 | package anymisc 2 | 3 | import ( 4 | "github.com/unixpickle/anynet" 5 | "github.com/unixpickle/anynet/anyrnn" 6 | "github.com/unixpickle/anyvec" 7 | ) 8 | 9 | // NewIRNN creates an RNN with ReLU activations and a 10 | // scaled identity state transition matrix. 11 | // 12 | // This is based on https://arxiv.org/abs/1504.00941. 13 | func NewIRNN(c anyvec.Creator, in, out int, scale float64) *anyrnn.Vanilla { 14 | res := anyrnn.NewVanilla(c, in, out, anynet.ReLU) 15 | res.StateWeights.Vector.Set(identityMatrix(c, out, scale).Data) 16 | return res 17 | } 18 | 19 | // NewNPRNN creates an RNN with ReLU activations and a 20 | // positive-definite transition matrix with bounded 21 | // eigenvalues. 22 | // 23 | // This is based on https://arxiv.org/abs/1511.03771. 
func NewNPRNN(c anyvec.Creator, in, out int) *anyrnn.Vanilla {
	res := anyrnn.NewVanilla(c, in, out, anynet.ReLU)

	// factor is a random out-by-out matrix with normally distributed
	// entries.
	factor := &anyvec.Matrix{
		Data: c.MakeVector(out * out),
		Rows: out,
		Cols: out,
	}
	anyvec.Rand(factor.Data, anyvec.Normal, nil)
	// posDef starts as the identity; the product call then combines it
	// with factor^T * factor scaled by 1/out.
	// NOTE(review): presumably Product computes
	// posDef = alpha*A^T*B + beta*posDef, as in GEMM -- confirm against
	// anyvec.Matrix.Product.
	posDef := identityMatrix(c, out, 1)
	posDef.Product(true, false, c.MakeNumeric(1/float64(out)), factor, factor,
		c.MakeNumeric(1))

	// Rescale by the reciprocal of the estimated largest eigenvalue.
	res.StateWeights.Vector.Set(posDef.Data)
	res.StateWeights.Vector.Scale(inverseLargestEig(posDef))

	return res
}

// identityMatrix creates a size-by-size matrix whose
// entries at indices i*(size+1) (the main diagonal) are
// set to scale.
func identityMatrix(c anyvec.Creator, size int, scale float64) *anyvec.Matrix {
	// Flat indices of the diagonal entries.
	hiddenIndices := make([]int, size)
	for i := range hiddenIndices {
		hiddenIndices[i] = i * (size + 1)
	}
	mapper := c.MakeMapper(size*size, hiddenIndices)
	diagonal := c.MakeVector(size)
	diagonal.AddScalar(c.MakeNumeric(scale))

	// NOTE(review): presumably MapTranspose scatters the diagonal values
	// into res at hiddenIndices -- confirm against anyvec.Mapper.
	res := c.MakeVector(size * size)
	mapper.MapTranspose(diagonal, res)
	return &anyvec.Matrix{Data: res, Rows: size, Cols: size}
}

// inverseLargestEig estimates the reciprocal of the
// largest eigenvalue magnitude of mat using a fixed number
// of power iterations from a random starting vector.
func inverseLargestEig(mat *anyvec.Matrix) anyvec.Numeric {
	const numIters = 100

	c := mat.Data.Creator()
	ops := c.NumOps()
	inVec := c.MakeVector(mat.Cols)
	outVec := c.MakeVector(mat.Cols)

	anyvec.Rand(inVec, anyvec.Normal, nil)

	// Power iteration method: it's slow, but it works.
	for i := 0; i < numIters; i++ {
		// Normalize, multiply by the matrix, then swap the buffers so the
		// product becomes the next input.
		mag := anyvec.Norm(inVec)
		inVec.Scale(ops.Div(c.MakeNumeric(1), mag))
		anyvec.Gemv(false, mat.Rows, mat.Cols, c.MakeNumeric(1), mat.Data,
			mat.Cols, inVec, 1, c.MakeNumeric(0), outVec, 1)
		inVec, outVec = outVec, inVec
	}

	// inVec is now the matrix applied to a unit vector, so its norm is the
	// eigenvalue estimate.
	return ops.Div(c.MakeNumeric(1), anyvec.Norm(inVec))
}

--------------------------------------------------------------------------------
/anymisc/relu_rnn_test.go:
--------------------------------------------------------------------------------
package anymisc

import (
	"reflect"
	"testing"

	"github.com/unixpickle/anyvec"
	"github.com/unixpickle/anyvec/anyvec64"
)

// TestIRNN checks that NewIRNN produces a scaled identity
// state transition matrix.
func TestIRNN(t *testing.T) {
	c := anyvec64.DefaultCreator{}
	rnn := NewIRNN(c, 5, 3, 2)
	expectedDiag := c.MakeVectorData([]float64{
		2, 0, 0,
		0, 2, 0,
		0, 0, 2,
	})
	if !reflect.DeepEqual(rnn.StateWeights.Vector, expectedDiag) {
		t.Errorf("expected %v but got %v", expectedDiag, rnn.StateWeights.Vector)
	}
}

// TestNPRNN checks that the state transition matrix made
// by NewNPRNN never increases the norm of a random vector.
func TestNPRNN(t *testing.T) {
	c := anyvec64.DefaultCreator{}
	rnn := NewNPRNN(c, 5, 4)
	matrix := &anyvec.Matrix{
		Data: rnn.StateWeights.Vector,
		Rows: 4,
		Cols: 4,
	}
	for i := 0; i < 30; i++ {
		inVec := &anyvec.Matrix{
			Data: c.MakeVector(matrix.Cols),
			Rows: matrix.Cols,
			Cols: 1,
		}
		product := &anyvec.Matrix{
			Data: c.MakeVector(matrix.Rows),
			Rows: matrix.Rows,
			Cols: 1,
		}
		anyvec.Rand(inVec.Data, anyvec.Normal, nil)
		inMag := anyvec.Norm(inVec.Data)
		product.Product(false, false, c.MakeNumeric(1), matrix, inVec, c.MakeNumeric(0))
		outMag := anyvec.Norm(product.Data)
		if inMag.(float64) < outMag.(float64) {
			t.Error("an eigenvalue was too big")
		}
	}
}

--------------------------------------------------------------------------------
/anymisc/selu.go:
--------------------------------------------------------------------------------
package anymisc

import (
	"github.com/unixpickle/anydiff"
	"github.com/unixpickle/essentials"
	"github.com/unixpickle/serializer"
)

// Defaults substituted by SELU.Apply when the corresponding
// SELU field is zero.
const (
	seluDefaultAlpha  = 1.6732632423543772848170429916717
	seluDefaultLambda = 1.0507009873554804934193349852946
)

// init registers the SELU deserializer under the layer's
// serializer type ID.
func init() {
	serializer.RegisterTypedDeserializer((&SELU{}).SerializerType(), DeserializeSELU)
}

// SELU implements the scaled exponential linear unit
// activation function.
//
// If a field is 0, a default is used.
// The default values are meant to induce a fixed point of
// mean 0 and variance 1.
//
// For more on SELU, see https://arxiv.org/abs/1706.02515.
type SELU struct {
	Alpha  float64
	Lambda float64
}

// DeserializeSELU deserializes a SELU instance.
//
// The data holds Alpha followed by Lambda.
func DeserializeSELU(d []byte) (*SELU, error) {
	var res SELU
	if err := serializer.DeserializeAny(d, &res.Alpha, &res.Lambda); err != nil {
		return nil, essentials.AddCtx("deserialize SELU", err)
	}
	return &res, nil
}

// Apply applies the activation function.
//
// The result is lambda*(max(x, 0) + alpha*exp(min(x, 0)) - alpha)
// for every component x of the input.
// The batch size n is not used.
func (s *SELU) Apply(in anydiff.Res, n int) anydiff.Res {
	// Zero-valued fields fall back to the defaults.
	alpha := s.Alpha
	lambda := s.Lambda
	if alpha == 0 {
		alpha = seluDefaultAlpha
	}
	if lambda == 0 {
		lambda = seluDefaultLambda
	}

	c := in.Output().Creator()
	return anydiff.Pool(in, func(in anydiff.Res) anydiff.Res {
		// For positive x, negPart is 0, so alpha*exp(0) - alpha cancels
		// and only x remains. For negative x, posPart is 0 and the
		// exponential branch remains.
		posPart := anydiff.ClipPos(in)
		negPart := clipNeg(in)
		return anydiff.Scale(
			anydiff.AddScalar(
				anydiff.Add(
					posPart,
					anydiff.Scale(anydiff.Exp(negPart), c.MakeNumeric(alpha)),
				),
				c.MakeNumeric(-alpha),
			),
			c.MakeNumeric(lambda),
		)
	})
}

// SerializerType returns the unique ID used to serialize
// a SELU with the serializer package.
func (s *SELU) SerializerType() string {
	return "github.com/unixpickle/anynet/anymisc.SELU"
}

// Serialize serializes the SELU.
//
// Alpha is written first, followed by Lambda.
func (s *SELU) Serialize() ([]byte, error) {
	return serializer.SerializeAny(s.Alpha, s.Lambda)
}

// clipNeg keeps the negative components of vec and zeroes
// the rest, computed as -ClipPos(-vec).
func clipNeg(vec anydiff.Res) anydiff.Res {
	c := vec.Output().Creator()
	return anydiff.Scale(
		anydiff.ClipPos(anydiff.Scale(vec, c.MakeNumeric(-1))),
		c.MakeNumeric(-1),
	)
}

--------------------------------------------------------------------------------
/anymisc/selu_test.go:
--------------------------------------------------------------------------------
package anymisc

import (
	"math"
	"reflect"
	"testing"

	"github.com/unixpickle/anydiff"
	"github.com/unixpickle/anyvec"
	"github.com/unixpickle/anyvec/anyvec64"
	"github.com/unixpickle/serializer"
)

// TestSELU compares the layer's output on random inputs
// against a scalar reference implementation.
func TestSELU(t *testing.T) {
	c := anyvec64.DefaultCreator{}
	vec := c.MakeVector(512)
	anyvec.Rand(vec, anyvec.Normal, nil)

	expectedData := make([]float64, vec.Len())
	for i, in := range vec.Data().([]float64) {
		expectedData[i] = scalarSELU(in)
	}
	expected := c.MakeVectorData(expectedData)

	actual := (&SELU{}).Apply(anydiff.NewVar(vec), 4).Output()
	diff := actual.Copy()
	diff.Sub(expected)
	if anyvec.AbsMax(diff).(float64) > 1e-4 {
		t.Error("bad output vector")
	}
}

// TestSELUSerialize checks that a SELU with non-default
// fields survives a serialization round trip.
func TestSELUSerialize(t *testing.T) {
	s := &SELU{Alpha: 3, Lambda: 5}
	data, err := serializer.SerializeAny(s)
	if err != nil {
		t.Fatal(err)
	}

	var s1 *SELU
	if err := serializer.DeserializeAny(data, &s1); err != nil {
		t.Fatal(err)
	}

	if !reflect.DeepEqual(s, s1) {
		t.Fatal("bad value")
	}
}

// scalarSELU is a reference implementation of SELU for a
// single value, using the default alpha and lambda.
func scalarSELU(x float64) float64 {
	if x > 0 {
		return seluDefaultLambda * x
	} else {
		return seluDefaultLambda * (seluDefaultAlpha*math.Exp(x) - seluDefaultAlpha)
	}
}

--------------------------------------------------------------------------------
/anynet.go:
--------------------------------------------------------------------------------
// Package anynet provides APIs for running and training
// artificial neural networks.
// It includes sub-packages for common neural network
// variations.
package anynet

import (
	"fmt"

	"github.com/unixpickle/anydiff"
	"github.com/unixpickle/essentials"
	"github.com/unixpickle/serializer"
)

// init registers the Net deserializer under the Net
// serializer type ID.
func init() {
	var n Net
	serializer.RegisterTypedDeserializer(n.SerializerType(), DeserializeNet)
}

// A Parameterizer is anything with learnable variables.
//
// The parameters of a Parameterizer must be in the same
// order every time Parameters() is called.
type Parameterizer interface {
	Parameters() []*anydiff.Var
}

// AllParameters gathers the parameters from every
// argument that implements Parameterizer, ignoring the
// arguments that don't.
// Parameters are fetched in order from the first to the
// last argument.
func AllParameters(args ...interface{}) []*anydiff.Var {
	var res []*anydiff.Var
	for _, x := range args {
		if p, ok := x.(Parameterizer); ok {
			res = append(res, p.Parameters()...)
		}
	}
	return res
}

// A Layer is a composable computation unit for use in a
// neural network.
// In a feed-forward network, each layer's output is fed
// into the next layer's input.
//
// A Layer's Apply method is inherently batched.
// The input's length must be divisible by the batch size,
// since the batch size indicates how many equally-long
// vectors are packed into the input vector.
52 | type Layer interface { 53 | Apply(in anydiff.Res, batchSize int) anydiff.Res 54 | } 55 | 56 | // A Net evaluates a list of layers, one after another. 57 | type Net []Layer 58 | 59 | // DeserializeNet attempts to deserialize the network. 60 | func DeserializeNet(d []byte) (Net, error) { 61 | slice, err := serializer.DeserializeSlice(d) 62 | if err != nil { 63 | return nil, essentials.AddCtx("deserialize Net", err) 64 | } 65 | res := make(Net, len(slice)) 66 | for i, x := range slice { 67 | if layer, ok := x.(Layer); ok { 68 | res[i] = layer 69 | } else { 70 | return nil, fmt.Errorf("deserialize Net: not a Layer: %T", x) 71 | } 72 | } 73 | return res, nil 74 | } 75 | 76 | // Apply applies the network to a batch. 77 | // If the network contains no layers, the input is 78 | // returned as output. 79 | func (n Net) Apply(in anydiff.Res, batchSize int) anydiff.Res { 80 | for _, l := range n { 81 | in = l.Apply(in, batchSize) 82 | } 83 | return in 84 | } 85 | 86 | // Parameters returns the parameters of the network. 87 | // 88 | // This is equivalent to calling AllParameters on a slice 89 | // containing the layers of n (in order). 90 | func (n Net) Parameters() []*anydiff.Var { 91 | interfaces := make([]interface{}, len(n)) 92 | for i, x := range n { 93 | interfaces[i] = x 94 | } 95 | return AllParameters(interfaces...) 96 | } 97 | 98 | // SerializerType returns the unique ID used to serialize 99 | // a Net with the serializer package. 100 | func (n Net) SerializerType() string { 101 | return "github.com/unixpickle/anynet.Net" 102 | } 103 | 104 | // Serialize attempts to serialize the network. 105 | // If any Layer is not a serializer.Serializer, 106 | // this fails. 
func (n Net) Serialize() ([]byte, error) {
	// Collect the layers as Serializers, rejecting the whole network as
	// soon as one layer cannot be serialized.
	var slice []serializer.Serializer
	for _, x := range n {
		if s, ok := x.(serializer.Serializer); ok {
			slice = append(slice, s)
		} else {
			return nil, fmt.Errorf("not a Serializer: %T", x)
		}
	}
	return serializer.SerializeSlice(slice)
}

--------------------------------------------------------------------------------
/anyrnn/anyrnn.go:
--------------------------------------------------------------------------------
// Package anyrnn implements recurrent neural networks.
package anyrnn

import (
	"github.com/unixpickle/anydiff"
	"github.com/unixpickle/anyvec"
)

// A PresentMap is used to indicate which sequences are
// present in a State and which ones are not.
// A true value indicates present.
//
// See State for more details on how PresentMap is used.
type PresentMap []bool

// NumPresent counts the present sequences, i.e. the
// number of true entries in the map.
func (p PresentMap) NumPresent() int {
	var i int
	for _, x := range p {
		if x {
			i++
		}
	}
	return i
}

// A State stores a batch of internal Block states.
//
// Since RNNs are typically used to evaluate sequences, a
// state has the idea of present and absent sequences.
// A present sequence is one which has not yet terminated,
// thus we need the state of the RNN for that sequence.
// An absent sequence has already finished, so we don't
// need to track its state.
//
// If an RNN is being evaluated on sequences of differing
// lengths, then the idea of present/absent is essential.
type State interface {
	// Present provides information about which sequences
	// have states in the batch.
	Present() PresentMap

	// Reduce creates a copy of the State with a new
	// PresentMap.
	// It is intended to be used to remove states from the
	// batch when some sequences end during RNN evaluation.
	//
	// The PresentMap must be a subset of Present(), meaning
	// that every true value in it must be true in Present().
	// The method is called "Reduce" because it can only be
	// used to remove some states from the batch; it cannot
	// add states to the batch.
	Reduce(PresentMap) State
}

// A StateGrad is an upstream gradient for a State.
// It is used while back-propagating through a Block.
type StateGrad interface {
	// Present provides information about which sequences
	// have upstream gradients in the batch.
	Present() PresentMap

	// Expand inserts zero gradients as necessary to expand
	// the present map.
	// The resulting StateGrad will include all of the
	// sequences from Present() and all of the sequences from
	// the passed PresentMap.
	//
	// Expand is the inverse of State.Reduce().
	Expand(PresentMap) StateGrad
}

// A Block is a differentiable unit in an RNN.
// It receives an input/state batch and produces a batch
// of outputs and new states.
type Block interface {
	// Start produces the start state with a batch size of n.
	Start(n int) State

	// PropagateStart back-propagates through the start
	// state.
	// After this is called, s should not be used again.
	PropagateStart(s StateGrad, g anydiff.Grad)

	// Step applies the block for a single timestep.
	Step(s State, in anyvec.Vector) Res
}

// A Res represents the output of a Block and is used to
// back-propagate through a Block.
type Res interface {
	// State returns the output state batch.
	State() State

	// Output returns the Block outputs.
	Output() anyvec.Vector

	// Vars returns the variables upon which the output
	// depends, including variables from previous states.
	Vars() anydiff.VarSet

	// Propagate propagates the gradient for one timestep.
	// It takes an upstream vector u for the output, an
	// upstream StateGrad s for the output state, and the
	// output gradient to which partials should be added.
	//
	// It returns a downstream vector for the input and a
	// StateGrad for the previous timestep.
	//
	// The upstream state s may be nil, indicating a zero
	// upstream.
	// This is useful for the final timestep, whose state is
	// never used for anything.
	//
	// All upstream objects may be modified.
	// A call to Propagate may change u and s, meaning that s
	// in particular should not be used again.
	//
	// The downstream input vector may be modified by the
	// caller (e.g. as scratch space).
	// Modifying said vector should not affect the returned
	// downstream StateGrad.
	Propagate(u anyvec.Vector, s StateGrad, g anydiff.Grad) (anyvec.Vector, StateGrad)
}

--------------------------------------------------------------------------------
/anyrnn/bidir.go:
--------------------------------------------------------------------------------
package anyrnn

import (
	"github.com/unixpickle/anydiff"
	"github.com/unixpickle/anydiff/anyseq"
	"github.com/unixpickle/anynet"
	"github.com/unixpickle/essentials"
	"github.com/unixpickle/serializer"
)

// init registers the Bidir deserializer under the Bidir
// serializer type ID.
func init() {
	var b Bidir
	serializer.RegisterTypedDeserializer(b.SerializerType(), DeserializeBidir)
}

// Bidir implements a bi-directional RNN.
//
// In a bi-directional RNN, a forward block is evaluated
// on the input sequence, while a backward block is mapped
// over the reversed input sequence.
// Then, outputs from the forward and backward block for
// corresponding timesteps in the original sequence are
// combined using the mixer.
//
// The first input to the Mixer is from the forward block;
// the second is from the backward block.
type Bidir struct {
	Forward  Block
	Backward Block
	Mixer    anynet.Mixer
}

// DeserializeBidir deserializes a Bidir.
//
// The data holds the forward block, the backward block,
// and the mixer, in that order.
func DeserializeBidir(d []byte) (*Bidir, error) {
	var res Bidir
	err := serializer.DeserializeAny(d, &res.Forward, &res.Backward, &res.Mixer)
	if err != nil {
		return nil, essentials.AddCtx("deserialize Bidir", err)
	}
	return &res, nil
}

// Apply applies the bidirectional RNN.
func (b *Bidir) Apply(in anyseq.Seq) anyseq.Seq {
	return anyseq.Pool(in, func(in anyseq.Seq) anyseq.Seq {
		forwOut := Map(in, b.Forward)
		// Run the backward block over the reversed input, then reverse its
		// output again so that timesteps line up with forwOut.
		backOut := anyseq.Reverse(Map(anyseq.Reverse(in), b.Backward))
		return anyseq.MapN(func(n int, v ...anydiff.Res) anydiff.Res {
			// v[0] comes from forwOut and v[1] from backOut.
			return b.Mixer.Mix(v[0], v[1], n)
		}, forwOut, backOut)
	})
}

// Parameters returns the parameters of the blocks and
// Mixer if they implement anynet.Parameterizer.
func (b *Bidir) Parameters() []*anydiff.Var {
	return anynet.AllParameters(b.Forward, b.Backward, b.Mixer)
}

// SerializerType returns the unique ID used to serialize
// a Bidir with the serializer package.
func (b *Bidir) SerializerType() string {
	return "github.com/unixpickle/anynet/anyrnn.Bidir"
}

// Serialize serializes the Bidir.
func (b *Bidir) Serialize() ([]byte, error) {
	// The order here must match the order read by DeserializeBidir.
	return serializer.SerializeAny(b.Forward, b.Backward, b.Mixer)
}

--------------------------------------------------------------------------------
/anyrnn/feedback_test.go:
--------------------------------------------------------------------------------
package anyrnn

import (
	"testing"

	"github.com/unixpickle/anydiff"
	"github.com/unixpickle/anydiff/anydifftest"
	"github.com/unixpickle/anydiff/anyseq"
	"github.com/unixpickle/anynet"
	"github.com/unixpickle/anyvec"
	"github.com/unixpickle/anyvec/anyvec64"
)

// TestFeedbackProp checks the parameter count of a Feedback
// block wrapping an LSTM, and runs a full gradient check on
// the block mapped over a random sequence.
func TestFeedbackProp(t *testing.T) {
	c := anyvec64.CurrentCreator()
	vec := c.MakeVector(2)
	anyvec.Rand(vec, anyvec.Normal, nil)
	block := &Feedback{
		Mixer: &anynet.AddMixer{
			In1: anynet.NewFC(c, 3, 3),
			In2: anynet.NewFC(c, 2, 3),
			Out: anynet.Tanh,
		},
		Block:   NewLSTM(c, 3, 2),
		InitOut: anydiff.NewVar(vec),
	}
	inSeq, inVars := randomTestSequence(c, 3)
	if len(block.Parameters()) != 23 {
		t.Errorf("expected 23 parameters, but got %d", len(block.Parameters()))
	}
	checker := &anydifftest.SeqChecker{
		F: func() anyseq.Seq {
			return Map(inSeq, block)
		},
		V: append(inVars, block.Parameters()...),
	}
	checker.FullCheck(t)
}

--------------------------------------------------------------------------------
/anyrnn/func.go:
--------------------------------------------------------------------------------
package anyrnn

import (
	"github.com/unixpickle/anydiff"
	"github.com/unixpickle/anyvec"
)

// A FuncBlock is a Block which applies a function to
// transform a state-input pair into a state-output pair.
type FuncBlock struct {
	// Func applies the block.
	//
	// If out is nil, newState is used as the output of the
	// block as well as the state.
	Func func(in, state anydiff.Res, batch int) (out, newState anydiff.Res)

	// MakeStart produces the initial state vector.
	MakeStart func(n int) anydiff.Res
}

// Start generates an initial *FuncBlockState.
//
// All n sequences are marked as present.
func (f *FuncBlock) Start(n int) State {
	r := f.MakeStart(n)
	pm := make(PresentMap, n)
	for i := range pm {
		pm[i] = true
	}
	return &FuncBlockState{
		VecState: &VecState{
			Vector:     r.Output(),
			PresentMap: pm,
		},
		StartRes: r,
		V:        r.Vars(),
	}
}

// PropagateStart back-propagates through the start state.
//
// The s argument must be a *FuncBlockState.
func (f *FuncBlock) PropagateStart(s StateGrad, g anydiff.Grad) {
	fs := s.(*FuncBlockState)
	fs.StartRes.Propagate(fs.VecState.Vector, g)
}

// Step applies the block for a timestep.
//
// The s argument must be a *FuncBlockState.
func (f *FuncBlock) Step(s State, in anyvec.Vector) Res {
	fs := s.(*FuncBlockState)
	// Wrap the input and the previous state in fresh variables; Propagate
	// later reads the gradients accumulated for them.
	inPool := anydiff.NewVar(in)
	statePool := anydiff.NewVar(fs.Vector)
	out, state := f.Func(inPool, statePool, s.Present().NumPresent())
	stateVars := anydiff.MergeVarSets(fs.V, state.Vars())
	allVars := stateVars
	if out != nil {
		allVars = anydiff.MergeVarSets(allVars, out.Vars())
	}
	// The temporary pool variables are not reported as dependencies.
	// When out is nil both names refer to the same set, so the second
	// round of deletions has nothing left to remove.
	for _, x := range []anydiff.VarSet{stateVars, allVars} {
		x.Del(inPool)
		x.Del(statePool)
	}
	newState := &FuncBlockState{
		VecState: &VecState{
			PresentMap: fs.PresentMap,
			Vector:     state.Output(),
		},
		V:        stateVars,
		StartRes: fs.StartRes,
	}
	return &funcBlockRes{
		InPool:    inPool,
		StatePool: statePool,
		OutRes:    out,
		StateRes:  state,
		OutState:  newState,
		V:         allVars,
	}
}

// FuncBlockState is the State and StateGrad type used by
// FuncBlock.
//
// V holds the variables the state depends on, and StartRes
// is the result of the block's MakeStart call.
type FuncBlockState struct {
	*VecState
	V        anydiff.VarSet
	StartRes anydiff.Res
}

// Reduce reduces the state to the given sequences.
func (f *FuncBlockState) Reduce(p PresentMap) State {
	// Only the vector state changes; V and StartRes are carried over.
	return &FuncBlockState{
		VecState: f.VecState.Reduce(p).(*VecState),
		V:        f.V,
		StartRes: f.StartRes,
	}
}

// Expand expands the state.
//
// Only the embedded VecState is expanded; V and StartRes
// are carried over unchanged.
func (f *FuncBlockState) Expand(p PresentMap) StateGrad {
	return &FuncBlockState{
		VecState: f.VecState.Expand(p).(*VecState),
		V:        f.V,
		StartRes: f.StartRes,
	}
}

// funcBlockRes is the Res produced by FuncBlock.Step.
//
// InPool and StatePool are the variables that wrapped the
// timestep's input and previous state.
// OutRes is nil if the block's Func returned no separate
// output.
type funcBlockRes struct {
	InPool    *anydiff.Var
	StatePool *anydiff.Var
	OutRes    anydiff.Res
	StateRes  anydiff.Res
	OutState  *FuncBlockState
	V         anydiff.VarSet
}

// State returns the state produced by the timestep.
func (f *funcBlockRes) State() State {
	return f.OutState
}

// Output returns the block output, which is the new state
// vector if the block produced no separate output.
func (f *funcBlockRes) Output() anyvec.Vector {
	if f.OutRes != nil {
		return f.OutRes.Output()
	} else {
		return f.StateRes.Output()
	}
}

// Vars returns the variables upon which the timestep's
// results depend.
func (f *funcBlockRes) Vars() anydiff.VarSet {
	return f.V
}

// Propagate back-propagates through the timestep.
//
// If s is non-nil, it must be a *FuncBlockState.
func (f *funcBlockRes) Propagate(u anyvec.Vector, s StateGrad,
	g anydiff.Grad) (anyvec.Vector, StateGrad) {
	c := f.InPool.Vector.Creator()
	// Temporarily add the pool variables to g so that gradients for the
	// input and previous state are accumulated.
	g[f.InPool] = c.MakeVector(f.InPool.Output().Len())
	g[f.StatePool] = c.MakeVector(f.StatePool.Output().Len())
	if f.OutRes != nil {
		f.OutRes.Propagate(u, g)
		if s != nil {
			v := s.(*FuncBlockState).Vector
			f.StateRes.Propagate(v, g)
		}
	} else {
		// The state doubles as the output, so both upstream gradients
		// apply to StateRes and are summed first.
		if s != nil {
			u.Add(s.(*FuncBlockState).Vector)
		}
		f.StateRes.Propagate(u, g)
	}
	// Pull the pool gradients back out so that g only contains the
	// caller's variables.
	inGrad := g[f.InPool]
	stateGrad := g[f.StatePool]
	delete(g, f.InPool)
	delete(g, f.StatePool)
	return inGrad, &FuncBlockState{
		VecState: &VecState{
			Vector:     stateGrad,
			PresentMap: f.OutState.PresentMap,
		},
		StartRes: f.OutState.StartRes,
	}
}

--------------------------------------------------------------------------------
/anyrnn/func_test.go:
-------------------------------------------------------------------------------- 1 | package anyrnn 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/unixpickle/anydiff" 7 | "github.com/unixpickle/anydiff/anydifftest" 8 | "github.com/unixpickle/anydiff/anyseq" 9 | "github.com/unixpickle/anynet" 10 | "github.com/unixpickle/anyvec" 11 | "github.com/unixpickle/anyvec/anyvec32" 12 | "github.com/unixpickle/anyvec/anyvec64" 13 | ) 14 | 15 | func TestFuncBlock(t *testing.T) { 16 | c := anyvec64.CurrentCreator() 17 | inSeq, inVars := randomTestSequence(c, 3) 18 | stateFC := anynet.NewFC(c, 2, 2) 19 | inputFC := anynet.NewFC(c, 3, 2) 20 | outFC := anynet.NewFC(c, 2, 3) 21 | startState := anydiff.NewVar(c.MakeVector(2)) 22 | anyvec.Rand(startState.Vector, anyvec.Normal, nil) 23 | 24 | params := append(append(append(append(inVars, stateFC.Parameters()...), 25 | inputFC.Parameters()...), outFC.Parameters()...), startState) 26 | 27 | block := &FuncBlock{ 28 | Func: func(in, state anydiff.Res, n int) (out, newState anydiff.Res) { 29 | st := stateFC.Apply(state, n) 30 | it := inputFC.Apply(in, n) 31 | newState = anydiff.Add(st, it) 32 | out = outFC.Apply(newState, n) 33 | return 34 | }, 35 | MakeStart: func(n int) anydiff.Res { 36 | zeroVec := anydiff.NewConst(c.MakeVector(n * startState.Vector.Len())) 37 | return anydiff.AddRepeated(zeroVec, startState) 38 | }, 39 | } 40 | 41 | checker := &anydifftest.SeqChecker{ 42 | F: func() anyseq.Seq { 43 | return Map(inSeq, block) 44 | }, 45 | V: params, 46 | } 47 | checker.FullCheck(t) 48 | } 49 | 50 | func TestFuncBlockNilOut(t *testing.T) { 51 | c := anyvec32.CurrentCreator() 52 | inSeq, inVars := randomTestSequence(c, 3) 53 | fcBlock := anynet.NewFC(c, 3, 3) 54 | startState := anydiff.NewVar(c.MakeVector(3)) 55 | anyvec.Rand(startState.Vector, anyvec.Normal, nil) 56 | 57 | params := append(append(inVars, fcBlock.Parameters()...), startState) 58 | 59 | block := &FuncBlock{ 60 | Func: func(in, state anydiff.Res, n int) (out, newState 
anydiff.Res) { 61 | return nil, anynet.Tanh.Apply(anydiff.Add( 62 | fcBlock.Apply(in, n), 63 | fcBlock.Apply(state, n), 64 | ), n) 65 | }, 66 | MakeStart: func(n int) anydiff.Res { 67 | zeroVec := anydiff.NewConst(c.MakeVector(n * startState.Vector.Len())) 68 | return anydiff.AddRepeated(zeroVec, startState) 69 | }, 70 | } 71 | 72 | checker := &anydifftest.SeqChecker{ 73 | F: func() anyseq.Seq { 74 | return Map(inSeq, block) 75 | }, 76 | V: params, 77 | } 78 | checker.FullCheck(t) 79 | } 80 | -------------------------------------------------------------------------------- /anyrnn/layer.go: -------------------------------------------------------------------------------- 1 | package anyrnn 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/unixpickle/anydiff" 7 | "github.com/unixpickle/anynet" 8 | "github.com/unixpickle/anyvec" 9 | "github.com/unixpickle/essentials" 10 | "github.com/unixpickle/serializer" 11 | ) 12 | 13 | func init() { 14 | var l LayerBlock 15 | serializer.RegisterTypedDeserializer(l.SerializerType(), DeserializeLayerBlock) 16 | } 17 | 18 | // A LayerBlock is a stateless Block that applies a 19 | // feed-forward neural network (or a layer thereof) to its 20 | // inputs. 21 | type LayerBlock struct { 22 | Layer anynet.Layer 23 | } 24 | 25 | // DeserializeLayerBlock deserializes a LayerBlock. 26 | func DeserializeLayerBlock(d []byte) (*LayerBlock, error) { 27 | n, err := anynet.DeserializeNet(d) 28 | if err != nil { 29 | return nil, essentials.AddCtx("deserialize LayerBlock", err) 30 | } 31 | if len(n) != 1 { 32 | return nil, errors.New("deserialize LayerBlock: multiple Layers") 33 | } 34 | return &LayerBlock{Layer: n[0]}, nil 35 | } 36 | 37 | // Start creates an empty start state. 38 | func (l *LayerBlock) Start(n int) State { 39 | pres := make([]bool, n) 40 | for i := range pres { 41 | pres[i] = true 42 | } 43 | return &emptyState{P: pres} 44 | } 45 | 46 | // PropagateStart does nothing, since the block is 47 | // state-less. 
func (l *LayerBlock) PropagateStart(s StateGrad, g anydiff.Grad) {
}

// Step applies the block for a single timestep.
//
// The s argument must be an *emptyState.
func (l *LayerBlock) Step(s State, in anyvec.Vector) Res {
	// Wrap the input in a fresh variable; Propagate later reads the
	// gradient accumulated for it.
	p := anydiff.NewVar(in)
	out := l.Layer.Apply(p, s.Present().NumPresent())
	// NOTE(review): presumably MergeVarSets returns a new set, so that the
	// deletion below leaves out.Vars() untouched -- confirm against anydiff.
	v := anydiff.MergeVarSets(out.Vars())
	v.Del(p)
	return &layerBlockRes{
		S:    s.(*emptyState),
		Pool: p,
		Res:  out,
		V:    v,
	}
}

// Parameters returns the parameters of the layer if it is
// an anynet.Parameterizer.
func (l *LayerBlock) Parameters() []*anydiff.Var {
	return anynet.AllParameters(l.Layer)
}

// SerializerType returns the unique ID used to serialize
// a LayerBlock with the serializer package.
func (l *LayerBlock) SerializerType() string {
	return "github.com/unixpickle/anynet/anyrnn.LayerBlock"
}

// Serialize serializes the block if the Layer can be
// serialized.
//
// The Layer is stored as a single-element anynet.Net.
func (l *LayerBlock) Serialize() ([]byte, error) {
	return anynet.Net{l.Layer}.Serialize()
}

// layerBlockRes is the Res produced by LayerBlock.Step.
//
// Pool is the variable that wrapped the timestep's input,
// and Res is the layer's result for that variable.
type layerBlockRes struct {
	S    *emptyState
	Pool *anydiff.Var
	Res  anydiff.Res
	V    anydiff.VarSet
}

// State returns the state that was passed to Step.
func (l *layerBlockRes) State() State {
	return l.S
}

// Output returns the layer's output for the timestep.
func (l *layerBlockRes) Output() anyvec.Vector {
	return l.Res.Output()
}

// Vars returns the variables upon which the output depends,
// not including the temporary input variable.
func (l *layerBlockRes) Vars() anydiff.VarSet {
	return l.V
}

// Propagate back-propagates through the layer and returns
// the gradient with respect to the timestep's input.
//
// The upstream state s is returned as the downstream state;
// if it is nil, the timestep's own empty state is returned
// instead.
func (l *layerBlockRes) Propagate(u anyvec.Vector, s StateGrad, g anydiff.Grad) (anyvec.Vector,
	StateGrad) {
	// Temporarily add the input variable to g so that its gradient is
	// accumulated into inDown.
	inDown := l.Pool.Vector.Creator().MakeVector(l.Pool.Vector.Len())
	g[l.Pool] = inDown
	l.Res.Propagate(u, g)
	delete(g, l.Pool)
	if s == nil {
		s = l.S
	}
	return inDown, s
}

// emptyState is a State and StateGrad which only records
// which sequences are present.
type emptyState struct {
	P PresentMap
}

// Present returns the stored present map.
func (e *emptyState) Present() PresentMap {
	return e.P
}

// Reduce returns a new empty state with the given map.
func (e *emptyState) Reduce(p PresentMap) State {
	return &emptyState{P: p}
}

// Expand returns a new empty state with the given map.
func (e *emptyState) Expand(p PresentMap) StateGrad {
	return &emptyState{P: p}
}

--------------------------------------------------------------------------------
/anyrnn/layer_test.go:
--------------------------------------------------------------------------------
package anyrnn

import (
	"testing"

	"github.com/unixpickle/anydiff"
	"github.com/unixpickle/anydiff/anydifftest"
	"github.com/unixpickle/anydiff/anyseq"
	"github.com/unixpickle/anynet"
	"github.com/unixpickle/anyvec"
	"github.com/unixpickle/anyvec/anyvec32"
)

// TestLayerBlock checks the parameter count of a LayerBlock
// and runs a full gradient check on the block mapped over a
// random sequence.
func TestLayerBlock(t *testing.T) {
	inSeq, inVars := randomTestSequence(anyvec32.CurrentCreator(), 3)
	block := &LayerBlock{
		Layer: anynet.Net{
			anynet.NewFC(anyvec32.CurrentCreator(), 3, 2),
			anynet.Tanh,
		},
	}
	if len(block.Parameters()) != 2 {
		t.Errorf("expected 2 parameters, but got %d", len(block.Parameters()))
	}
	checker := &anydifftest.SeqChecker{
		F: func() anyseq.Seq {
			return Map(inSeq, block)
		},
		V: append(inVars, block.Parameters()...),
	}
	checker.FullCheck(t)
}

// randomTestSequence creates a random batch of three
// sequences with inSize components per input.
//
// All three sequences are present for the first two
// timesteps; only the first and third are present for the
// remaining three.
// The second return value holds the variables backing each
// timestep.
func randomTestSequence(c anyvec.Creator, inSize int) (anyseq.Seq, []*anydiff.Var) {
	inVars := []*anydiff.Var{}
	inBatches := []*anyseq.ResBatch{}

	// The three slices are parallel: each chunk has a present map, the
	// number of true entries in that map, and a number of timesteps.
	presents := [][]bool{{true, true, true}, {true, false, true}}
	numPres := []int{3, 2}
	chunkLengths := []int{2, 3}

	for chunkIdx, pres := range presents {
		for i := 0; i < chunkLengths[chunkIdx]; i++ {
			vec := c.MakeVector(inSize * numPres[chunkIdx])
			anyvec.Rand(vec, anyvec.Normal, nil)
			v := anydiff.NewVar(vec)
			batch := &anyseq.ResBatch{
				Packed:  v,
				Present: pres,
			}
			inVars = append(inVars, v)
			inBatches = append(inBatches, batch)
		}
	}
	return anyseq.ResSeq(c, inBatches), inVars
}

-------------------------------------------------------------------------------- /anyrnn/lstm_test.go: -------------------------------------------------------------------------------- 1 | package anyrnn 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/unixpickle/anydiff/anydifftest" 7 | "github.com/unixpickle/anydiff/anyseq" 8 | "github.com/unixpickle/anyvec/anyvec32" 9 | ) 10 | 11 | func TestLSTMProp(t *testing.T) { 12 | inSeq, inVars := randomTestSequence(anyvec32.CurrentCreator(), 3) 13 | block := NewLSTM(anyvec32.CurrentCreator(), 3, 2) 14 | if len(block.Parameters()) != 18 { 15 | t.Errorf("expected 18 parameters, but got %d", len(block.Parameters())) 16 | } 17 | checker := &anydifftest.SeqChecker{ 18 | F: func() anyseq.Seq { 19 | return Map(inSeq, block) 20 | }, 21 | V: append(inVars, block.Parameters()...), 22 | } 23 | checker.FullCheck(t) 24 | } 25 | -------------------------------------------------------------------------------- /anyrnn/map.go: -------------------------------------------------------------------------------- 1 | package anyrnn 2 | 3 | import ( 4 | "github.com/unixpickle/anydiff" 5 | "github.com/unixpickle/anydiff/anyseq" 6 | "github.com/unixpickle/anyvec" 7 | ) 8 | 9 | type mapRes struct { 10 | F func(s StateGrad, g anydiff.Grad) 11 | InitPres PresentMap 12 | In anyseq.Seq 13 | Out []*anyseq.Batch 14 | BlockRes []Res 15 | Block Block 16 | V anydiff.VarSet 17 | } 18 | 19 | // Map maps a Block over an input sequence batch, giving 20 | // an output sequence batch. 21 | func Map(s anyseq.Seq, b Block) anyseq.Seq { 22 | inSteps := s.Output() 23 | if len(inSteps) == 0 { 24 | return &mapRes{In: s} 25 | } 26 | 27 | state := b.Start(len(inSteps[0].Present)) 28 | return MapWithStart(s, b, state, func(sg StateGrad, g anydiff.Grad) { 29 | b.PropagateStart(sg, g) 30 | }) 31 | } 32 | 33 | // MapWithStart is like Map, but it takes a customized 34 | // start state rather than using the block's default start 35 | // state. 
36 | // 37 | // During back-propagation, f is called with the upstream 38 | // state gradient for the start state. 39 | func MapWithStart(s anyseq.Seq, b Block, state State, f func(StateGrad, anydiff.Grad)) anyseq.Seq { 40 | inSteps := s.Output() 41 | if len(inSteps) == 0 { 42 | return &mapRes{In: s} 43 | } 44 | 45 | initPres := state.Present() 46 | if inSteps[0].NumPresent() != len(inSteps[0].Present) { 47 | state = state.Reduce(inSteps[0].Present) 48 | } 49 | res := &mapRes{F: f, InitPres: initPres, In: s, Block: b, V: s.Vars()} 50 | 51 | for _, x := range inSteps { 52 | if x.NumPresent() != state.Present().NumPresent() { 53 | state = state.Reduce(x.Present) 54 | } 55 | step := b.Step(state, x.Packed) 56 | res.BlockRes = append(res.BlockRes, step) 57 | res.V = anydiff.MergeVarSets(res.V, step.Vars()) 58 | res.Out = append(res.Out, &anyseq.Batch{ 59 | Packed: step.Output(), 60 | Present: x.Present, 61 | }) 62 | state = step.State() 63 | } 64 | 65 | return res 66 | } 67 | 68 | func (m *mapRes) Creator() anyvec.Creator { 69 | return m.In.Creator() 70 | } 71 | 72 | func (m *mapRes) Output() []*anyseq.Batch { 73 | return m.Out 74 | } 75 | 76 | func (m *mapRes) Vars() anydiff.VarSet { 77 | return m.V 78 | } 79 | 80 | func (m *mapRes) Propagate(u []*anyseq.Batch, g anydiff.Grad) { 81 | if len(u) == 0 { 82 | return 83 | } 84 | 85 | var downstream []*anyseq.Batch 86 | if g.Intersects(m.In.Vars()) { 87 | downstream = make([]*anyseq.Batch, len(u)) 88 | } 89 | 90 | var upState StateGrad 91 | for i := len(m.BlockRes) - 1; i >= 0; i-- { 92 | blockRes := m.BlockRes[i] 93 | if upState != nil { 94 | newPres := blockRes.State().Present() 95 | if newPres.NumPresent() != upState.Present().NumPresent() { 96 | upState = upState.Expand(newPres) 97 | } 98 | } 99 | down, downState := blockRes.Propagate(u[i].Packed, upState, g) 100 | if downstream != nil { 101 | downstream[i] = &anyseq.Batch{Packed: down, Present: u[i].Present} 102 | } 103 | upState = downState 104 | } 105 | 106 | if 
upState != nil { 107 | if m.InitPres.NumPresent() != upState.Present().NumPresent() { 108 | upState = upState.Expand(m.InitPres) 109 | } 110 | m.F(upState, g) 111 | } 112 | 113 | if downstream != nil { 114 | m.In.Propagate(downstream, g) 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /anyrnn/markov.go: -------------------------------------------------------------------------------- 1 | package anyrnn 2 | 3 | import ( 4 | "github.com/unixpickle/anydiff" 5 | "github.com/unixpickle/anyvec" 6 | "github.com/unixpickle/anyvec/anyvecsave" 7 | "github.com/unixpickle/essentials" 8 | "github.com/unixpickle/serializer" 9 | ) 10 | 11 | func init() { 12 | var m Markov 13 | serializer.RegisterTypedDeserializer(m.SerializerType(), DeserializeMarkov) 14 | } 15 | 16 | // Markov is an RNN block that tracks a history of inputs. 17 | type Markov struct { 18 | StartState *anydiff.Var 19 | 20 | // HistorySize is the number of inputs to store. 21 | // Since this does not include the input, the output 22 | // will contain HistorySize+1 packed inputs. 23 | HistorySize int 24 | 25 | // DepthWise controls how vectors from the history 26 | // are concatenated. 27 | // 28 | // If false, then vectors a and b are joined to 29 | // 30 | // 31 | // 32 | // If true, then vectors are joined depth-wise 33 | // 34 | // 35 | // 36 | // In either case, the first component of more 37 | // recent timesteps are packed before those of less 38 | // recent timesteps. 39 | DepthWise bool 40 | } 41 | 42 | // NewMarkov creates a Markov with the history size and 43 | // pre-known input vector size. 
44 | func NewMarkov(c anyvec.Creator, history int, inSize int, depthWise bool) *Markov { 45 | if inSize == 0 || history == 0 { 46 | panic("input and history sizes must be non-zero") 47 | } 48 | return &Markov{ 49 | StartState: anydiff.NewVar(c.MakeVector(history * inSize)), 50 | HistorySize: history, 51 | DepthWise: depthWise, 52 | } 53 | } 54 | 55 | // DeserializeMarkov deserializes a Markov. 56 | func DeserializeMarkov(d []byte) (*Markov, error) { 57 | var res Markov 58 | var vec *anyvecsave.S 59 | err := serializer.DeserializeAny(d, &vec, &res.HistorySize, &res.DepthWise) 60 | if err != nil { 61 | return nil, essentials.AddCtx("deserialize Markov", err) 62 | } 63 | res.StartState = anydiff.NewVar(vec.Vector) 64 | return &res, nil 65 | } 66 | 67 | // Start returns a start state. 68 | func (m *Markov) Start(n int) State { 69 | return m.funcBlock().Start(n) 70 | } 71 | 72 | // PropagateStart propagates through the start state. 73 | func (m *Markov) PropagateStart(sg StateGrad, g anydiff.Grad) { 74 | m.funcBlock().PropagateStart(sg, g) 75 | } 76 | 77 | // Step applies the block, returning a stacked tensor and 78 | // updating the frame history in the state. 79 | func (m *Markov) Step(state State, in anyvec.Vector) Res { 80 | return m.funcBlock().Step(state, in) 81 | } 82 | 83 | // Parameters returns the Markov's parameters. 84 | func (m *Markov) Parameters() []*anydiff.Var { 85 | return []*anydiff.Var{m.StartState} 86 | } 87 | 88 | // SerializerType returns the unique ID used to serialize 89 | // a Markov with the serializer package. 90 | func (m *Markov) SerializerType() string { 91 | return "github.com/unixpickle/anynet/anyrnn.Markov" 92 | } 93 | 94 | // Serialize serializes the Markov. 
95 | func (m *Markov) Serialize() ([]byte, error) { 96 | return serializer.SerializeAny( 97 | &anyvecsave.S{Vector: m.StartState.Vector}, 98 | m.HistorySize, 99 | m.DepthWise, 100 | ) 101 | } 102 | 103 | func (m *Markov) funcBlock() *FuncBlock { 104 | return &FuncBlock{ 105 | Func: func(in, state anydiff.Res, n int) (out, newState anydiff.Res) { 106 | subInSize := in.Output().Len() / n 107 | subStateSize := state.Output().Len() / n 108 | var outs, states []anydiff.Res 109 | for i := 0; i < n; i++ { 110 | subIn := anydiff.Slice(in, i*subInSize, (i+1)*subInSize) 111 | subState := anydiff.Slice(state, i*subStateSize, 112 | (i+1)*subStateSize) 113 | o, s := m.forward(subIn, subState) 114 | outs = append(outs, o) 115 | states = append(states, s) 116 | } 117 | return anydiff.Concat(outs...), anydiff.Concat(states...) 118 | }, 119 | MakeStart: func(n int) anydiff.Res { 120 | var rep []anydiff.Res 121 | for i := 0; i < n; i++ { 122 | rep = append(rep, m.StartState) 123 | } 124 | return anydiff.Concat(rep...) 
125 | }, 126 | } 127 | } 128 | 129 | func (m *Markov) forward(in, state anydiff.Res) (out, newState anydiff.Res) { 130 | oldRows := &anydiff.Matrix{ 131 | Data: state, 132 | Rows: m.HistorySize, 133 | Cols: state.Output().Len() / m.HistorySize, 134 | } 135 | 136 | outRows := *oldRows 137 | outRows.Data = anydiff.Concat(in, oldRows.Data) 138 | outRows.Rows++ 139 | 140 | stateRows := outRows 141 | stateRows.Rows-- 142 | stateRows.Data = anydiff.Slice(outRows.Data, 0, 143 | stateRows.Cols*stateRows.Rows) 144 | 145 | if m.DepthWise { 146 | return anydiff.Transpose(&outRows).Data, stateRows.Data 147 | } else { 148 | return outRows.Data, stateRows.Data 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /anyrnn/markov_test.go: -------------------------------------------------------------------------------- 1 | package anyrnn 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/unixpickle/anydiff/anydifftest" 8 | "github.com/unixpickle/anydiff/anyseq" 9 | "github.com/unixpickle/anyvec" 10 | "github.com/unixpickle/anyvec/anyvec64" 11 | "github.com/unixpickle/serializer" 12 | ) 13 | 14 | func TestMarkovOutput(t *testing.T) { 15 | for _, mode := range []string{"Concat", "DepthWise"} { 16 | t.Run(mode, func(t *testing.T) { 17 | c := anyvec64.DefaultCreator{} 18 | markov := NewMarkov(c, 2, 3, mode == "DepthWise") 19 | markov.StartState.Vector.SetData( 20 | c.MakeNumericList([]float64{-1, -2, -3, -4, -5, -6}), 21 | ) 22 | 23 | inSeq := anyseq.ConstSeqList(c, [][]anyvec.Vector{ 24 | { 25 | c.MakeVectorData(c.MakeNumericList([]float64{1, 2, 3})), 26 | c.MakeVectorData(c.MakeNumericList([]float64{4, 5, 6})), 27 | c.MakeVectorData(c.MakeNumericList([]float64{7, 8, 9})), 28 | }, 29 | }) 30 | 31 | actual := anyseq.SeparateSeqs(Map(inSeq, markov).Output()) 32 | var expected [][]anyvec.Vector 33 | if mode == "Concat" { 34 | expected = [][]anyvec.Vector{ 35 | { 36 | c.MakeVectorData(c.MakeNumericList([]float64{1, 2, 3, -1, 
-2, -3, -4, -5, -6})), 37 | c.MakeVectorData(c.MakeNumericList([]float64{4, 5, 6, 1, 2, 3, -1, -2, -3})), 38 | c.MakeVectorData(c.MakeNumericList([]float64{7, 8, 9, 4, 5, 6, 1, 2, 3})), 39 | }, 40 | } 41 | } else { 42 | expected = [][]anyvec.Vector{ 43 | { 44 | c.MakeVectorData(c.MakeNumericList([]float64{1, -1, -4, 2, -2, -5, 3, -3, -6})), 45 | c.MakeVectorData(c.MakeNumericList([]float64{4, 1, -1, 5, 2, -2, 6, 3, -3})), 46 | c.MakeVectorData(c.MakeNumericList([]float64{7, 4, 1, 8, 5, 2, 9, 6, 3})), 47 | }, 48 | } 49 | } 50 | if !reflect.DeepEqual(actual, expected) { 51 | t.Errorf("expected %#v but got %#v", expected, actual) 52 | } 53 | }) 54 | } 55 | } 56 | 57 | func TestMarkovGradients(t *testing.T) { 58 | c := anyvec64.DefaultCreator{} 59 | markov := NewMarkov(c, 2, 3, true) 60 | markov.StartState.Vector.SetData( 61 | c.MakeNumericList([]float64{-1, -2, -3, -4, -5, -6}), 62 | ) 63 | inSeq, inVars := randomTestSequence(c, 3) 64 | checker := &anydifftest.SeqChecker{ 65 | F: func() anyseq.Seq { 66 | return Map(inSeq, markov) 67 | }, 68 | V: append(inVars, markov.Parameters()...), 69 | } 70 | checker.FullCheck(t) 71 | } 72 | 73 | func TestMarkovBatching(t *testing.T) { 74 | for _, mode := range []string{"Concat", "DepthWise"} { 75 | t.Run(mode, func(t *testing.T) { 76 | c := anyvec64.DefaultCreator{} 77 | markov := NewMarkov(c, 2, 3, mode == "DepthWise") 78 | markov.StartState.Vector.SetData( 79 | c.MakeNumericList([]float64{-1, -2, -3, -4, -5, -6}), 80 | ) 81 | seqs, _ := randomTestSequence(c, 3) 82 | batchOuts := anyseq.SeparateSeqs(Map(seqs, markov).Output()) 83 | for i, batchOut := range batchOuts { 84 | subSeq := anyseq.ConstSeqList( 85 | c, 86 | anyseq.SeparateSeqs(seqs.Output())[i:i+1], 87 | ) 88 | singleOut := anyseq.SeparateSeqs(Map(subSeq, markov).Output())[0] 89 | if !reflect.DeepEqual(batchOut, singleOut) { 90 | t.Errorf("got batch out %v but single out %v", batchOut, singleOut) 91 | } 92 | } 93 | }) 94 | } 95 | } 96 | 97 | func TestMarkovSerialize(t 
*testing.T) { 98 | c := anyvec64.DefaultCreator{} 99 | markov := NewMarkov(c, 2, 3, true) 100 | markov.StartState.Vector.SetData( 101 | c.MakeNumericList([]float64{-1, -2, -3, -4, -5, -6}), 102 | ) 103 | data, err := serializer.SerializeAny(markov) 104 | if err != nil { 105 | t.Fatal(err) 106 | } 107 | var markov1 *Markov 108 | if err := serializer.DeserializeAny(data, &markov1); err != nil { 109 | t.Fatal(err) 110 | } 111 | if !reflect.DeepEqual(markov, markov1) { 112 | t.Error("bad deserialized value") 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /anyrnn/parallel.go: -------------------------------------------------------------------------------- 1 | package anyrnn 2 | 3 | import ( 4 | "github.com/unixpickle/anydiff" 5 | "github.com/unixpickle/anynet" 6 | "github.com/unixpickle/anyvec" 7 | "github.com/unixpickle/essentials" 8 | "github.com/unixpickle/serializer" 9 | ) 10 | 11 | func init() { 12 | serializer.RegisterTypedDeserializer((&Parallel{}).SerializerType(), DeserializeParallel) 13 | } 14 | 15 | // A Parallel block feeds its input to two blocks, then 16 | // merges the blocks' outputs. 17 | type Parallel struct { 18 | Block1 Block 19 | Block2 Block 20 | Mixer anynet.Mixer 21 | } 22 | 23 | // DeserializeParallel deserializes a Parallel. 24 | func DeserializeParallel(d []byte) (*Parallel, error) { 25 | var res Parallel 26 | err := serializer.DeserializeAny(d, &res.Block1, &res.Block2, &res.Mixer) 27 | if err != nil { 28 | return nil, essentials.AddCtx("deserialize Parallel", err) 29 | } 30 | return &res, nil 31 | } 32 | 33 | // Start produces a start state. 34 | func (p *Parallel) Start(n int) State { 35 | return &ParallelState{ 36 | State1: p.Block1.Start(n), 37 | State2: p.Block2.Start(n), 38 | } 39 | } 40 | 41 | // PropagateStart back-propagates through the start state. 
42 | func (p *Parallel) PropagateStart(s StateGrad, g anydiff.Grad) { 43 | sg := s.(*ParallelGrad) 44 | p.Block1.PropagateStart(sg.Grad1, g) 45 | p.Block2.PropagateStart(sg.Grad2, g) 46 | } 47 | 48 | // Step takes a timestep. 49 | func (p *Parallel) Step(s State, in anyvec.Vector) Res { 50 | state := s.(*ParallelState) 51 | res := ¶llelRes{ 52 | Res1: p.Block1.Step(state.State1, in), 53 | Res2: p.Block2.Step(state.State2, in), 54 | } 55 | res.V = anydiff.MergeVarSets(res.Res1.Vars(), res.Res2.Vars()) 56 | res.Pool1 = anydiff.NewVar(res.Res1.Output()) 57 | res.Pool2 = anydiff.NewVar(res.Res2.Output()) 58 | res.OutRes = p.Mixer.Mix(res.Pool1, res.Pool2, s.Present().NumPresent()) 59 | res.OutState = &ParallelState{State1: res.Res1.State(), State2: res.Res2.State()} 60 | return res 61 | } 62 | 63 | // Parameters returns the parameters of the block, which 64 | // are taken from block 1, block 2, and the mixer in that 65 | // order. 66 | func (p *Parallel) Parameters() []*anydiff.Var { 67 | return anynet.AllParameters(p.Block1, p.Block2, p.Mixer) 68 | } 69 | 70 | // SerializerType returns the unique ID used to serialize 71 | // a Parallel with the serializer package. 72 | func (p *Parallel) SerializerType() string { 73 | return "github.com/unixpickle/anynet/anyrnn.Parallel" 74 | } 75 | 76 | // Serialize serializes the block. 77 | func (p *Parallel) Serialize() ([]byte, error) { 78 | return serializer.SerializeAny(p.Block1, p.Block2, p.Mixer) 79 | } 80 | 81 | // ParallelState stores the state of a Parallel block. 82 | type ParallelState struct { 83 | State1 State 84 | State2 State 85 | } 86 | 87 | // Present returns the present map of one of the internal 88 | // states. 89 | func (p *ParallelState) Present() PresentMap { 90 | return p.State1.Present() 91 | } 92 | 93 | // Reduce reduces the internal states. 
94 | func (p *ParallelState) Reduce(pres PresentMap) State { 95 | return &ParallelState{ 96 | State1: p.State1.Reduce(pres), 97 | State2: p.State2.Reduce(pres), 98 | } 99 | } 100 | 101 | // ParallelGrad stores the state gradient of a Parallel 102 | // block. 103 | type ParallelGrad struct { 104 | Grad1 StateGrad 105 | Grad2 StateGrad 106 | } 107 | 108 | // Present returns the present map of one of the internal 109 | // state grads. 110 | func (p *ParallelGrad) Present() PresentMap { 111 | return p.Grad1.Present() 112 | } 113 | 114 | // Expand expands all the internal state grads. 115 | func (p *ParallelGrad) Expand(pres PresentMap) StateGrad { 116 | return &ParallelGrad{ 117 | Grad1: p.Grad1.Expand(pres), 118 | Grad2: p.Grad2.Expand(pres), 119 | } 120 | } 121 | 122 | type parallelRes struct { 123 | Res1 Res 124 | Res2 Res 125 | OutRes anydiff.Res 126 | Pool1 *anydiff.Var 127 | Pool2 *anydiff.Var 128 | OutState *ParallelState 129 | V anydiff.VarSet 130 | } 131 | 132 | func (p *parallelRes) State() State { 133 | return p.OutState 134 | } 135 | 136 | func (p *parallelRes) Output() anyvec.Vector { 137 | return p.OutRes.Output() 138 | } 139 | 140 | func (p *parallelRes) Vars() anydiff.VarSet { 141 | return p.V 142 | } 143 | 144 | func (p *parallelRes) Propagate(u anyvec.Vector, s StateGrad, 145 | g anydiff.Grad) (anyvec.Vector, StateGrad) { 146 | for _, p := range []*anydiff.Var{p.Pool1, p.Pool2} { 147 | g[p] = p.Vector.Creator().MakeVector(p.Vector.Len()) 148 | defer func(p *anydiff.Var) { 149 | delete(g, p) 150 | }(p) 151 | } 152 | p.OutRes.Propagate(u, g) 153 | var sg1, sg2 StateGrad 154 | if s != nil { 155 | sg := s.(*ParallelGrad) 156 | sg1, sg2 = sg.Grad1, sg.Grad2 157 | } 158 | inGrad1, downGrad1 := p.Res1.Propagate(g[p.Pool1], sg1, g) 159 | inGrad2, downGrad2 := p.Res2.Propagate(g[p.Pool2], sg2, g) 160 | inGrad1.Add(inGrad2) 161 | return inGrad1, &ParallelGrad{Grad1: downGrad1, Grad2: downGrad2} 162 | } 163 | 
-------------------------------------------------------------------------------- /anyrnn/parallel_test.go: -------------------------------------------------------------------------------- 1 | package anyrnn 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/unixpickle/anydiff/anydifftest" 7 | "github.com/unixpickle/anydiff/anyseq" 8 | "github.com/unixpickle/anynet" 9 | "github.com/unixpickle/anyvec/anyvec32" 10 | ) 11 | 12 | func TestParallelProp(t *testing.T) { 13 | inSeq, inVars := randomTestSequence(anyvec32.CurrentCreator(), 3) 14 | block1 := NewLSTM(anyvec32.CurrentCreator(), 3, 1) 15 | block2 := NewLSTM(anyvec32.CurrentCreator(), 3, 2) 16 | block := &Parallel{ 17 | Block1: block1, 18 | Block2: block2, 19 | Mixer: anynet.ConcatMixer{}, 20 | } 21 | if len(block.Parameters()) != 36 { 22 | t.Errorf("expected 36 parameters, but got %d", len(block.Parameters())) 23 | } 24 | checker := &anydifftest.SeqChecker{ 25 | F: func() anyseq.Seq { 26 | return Map(inSeq, block) 27 | }, 28 | V: append(inVars, block.Parameters()...), 29 | } 30 | checker.FullCheck(t) 31 | } 32 | -------------------------------------------------------------------------------- /anyrnn/serializer_test.go: -------------------------------------------------------------------------------- 1 | package anyrnn 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/unixpickle/anydiff" 8 | "github.com/unixpickle/anynet" 9 | "github.com/unixpickle/anyvec" 10 | "github.com/unixpickle/anyvec/anyvec32" 11 | "github.com/unixpickle/serializer" 12 | ) 13 | 14 | func TestLayerSerialize(t *testing.T) { 15 | testSerialize(t, &LayerBlock{ 16 | Layer: anynet.Tanh, 17 | }) 18 | } 19 | 20 | func TestStackSerialize(t *testing.T) { 21 | testSerialize(t, Stack{ 22 | &LayerBlock{Layer: anynet.Tanh}, 23 | &LayerBlock{Layer: anynet.LogSoftmax}, 24 | }) 25 | } 26 | 27 | func TestVanillaSerialize(t *testing.T) { 28 | v := NewVanilla(anyvec32.CurrentCreator(), 3, 2, anynet.Tanh) 29 | 30 | // Make sure the biases are 
different than the init state. 31 | v.Biases.Vector.AddScalar(float32(1)) 32 | 33 | testSerialize(t, v) 34 | } 35 | 36 | func TestLSTMGateSerialize(t *testing.T) { 37 | g := NewLSTMGate(anyvec32.CurrentCreator(), 3, 2, anynet.Sigmoid) 38 | 39 | // Make sure the biases are different than the init state. 40 | g.Biases.Vector.AddScalar(float32(1)) 41 | 42 | testSerialize(t, g) 43 | } 44 | 45 | func TestLSTMSerialize(t *testing.T) { 46 | testSerialize(t, NewLSTM(anyvec32.CurrentCreator(), 3, 2)) 47 | } 48 | 49 | func TestBidirSerialize(t *testing.T) { 50 | c := anyvec32.CurrentCreator() 51 | b := &Bidir{ 52 | Forward: NewVanilla(c, 5, 3, anynet.Tanh), 53 | Backward: NewVanilla(c, 5, 2, anynet.Tanh), 54 | Mixer: &anynet.AddMixer{ 55 | In1: anynet.NewFC(c, 3, 2), 56 | In2: anynet.NewFC(c, 2, 2), 57 | Out: anynet.Tanh, 58 | }, 59 | } 60 | testSerialize(t, b) 61 | } 62 | 63 | func TestFeedbackSerialize(t *testing.T) { 64 | c := anyvec32.CurrentCreator() 65 | vec := c.MakeVector(2) 66 | anyvec.Rand(vec, anyvec.Normal, nil) 67 | testSerialize(t, &Feedback{ 68 | Mixer: anynet.ConcatMixer{}, 69 | Block: NewLSTM(c, 3, 2), 70 | InitOut: anydiff.NewVar(vec), 71 | }) 72 | } 73 | 74 | func TestParallelSerialize(t *testing.T) { 75 | c := anyvec32.CurrentCreator() 76 | b := &Parallel{ 77 | Block1: NewVanilla(c, 5, 3, anynet.Tanh), 78 | Block2: NewVanilla(c, 5, 2, anynet.Tanh), 79 | Mixer: &anynet.AddMixer{ 80 | In1: anynet.NewFC(c, 3, 2), 81 | In2: anynet.NewFC(c, 2, 2), 82 | Out: anynet.Tanh, 83 | }, 84 | } 85 | testSerialize(t, b) 86 | } 87 | 88 | func testSerialize(t *testing.T, obj serializer.Serializer) { 89 | data, err := serializer.SerializeWithType(obj) 90 | if err != nil { 91 | t.Fatal(err) 92 | } 93 | newObj, err := serializer.DeserializeWithType(data) 94 | if err != nil { 95 | t.Fatal(err) 96 | } 97 | if !reflect.DeepEqual(obj, newObj) { 98 | t.Errorf("expected %v but got %v", obj, newObj) 99 | } 100 | } 101 | 
-------------------------------------------------------------------------------- /anyrnn/stack.go: -------------------------------------------------------------------------------- 1 | package anyrnn 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/unixpickle/anydiff" 7 | "github.com/unixpickle/anynet" 8 | "github.com/unixpickle/anyvec" 9 | "github.com/unixpickle/essentials" 10 | "github.com/unixpickle/serializer" 11 | ) 12 | 13 | func init() { 14 | var s Stack 15 | serializer.RegisterTypedDeserializer(s.SerializerType(), DeserializeStack) 16 | } 17 | 18 | // A Stack is a meta-Block for composing Blocks. 19 | // In a Stack, the first Block's output is fed as input to 20 | // the next Block, etc. 21 | // 22 | // An empty Stack is invalid. 23 | type Stack []Block 24 | 25 | // DeserializeStack deserializes a Stack. 26 | func DeserializeStack(d []byte) (Stack, error) { 27 | blockSlice, err := serializer.DeserializeSlice(d) 28 | if err != nil { 29 | return nil, essentials.AddCtx("deserialize Stack", err) 30 | } 31 | res := make(Stack, len(blockSlice)) 32 | for i, x := range blockSlice { 33 | if b, ok := x.(Block); ok { 34 | res[i] = b 35 | } else { 36 | return nil, fmt.Errorf("deserialize Stack: type is not a Block: %T", x) 37 | } 38 | } 39 | return res, nil 40 | } 41 | 42 | // Start produces a start StackState. 43 | func (s Stack) Start(n int) State { 44 | s.assertNonEmpty() 45 | res := make(StackState, len(s)) 46 | for i, x := range s { 47 | res[i] = x.Start(n) 48 | } 49 | return res 50 | } 51 | 52 | // PropagateStart back-propagates through the start state. 53 | func (s Stack) PropagateStart(sg StateGrad, g anydiff.Grad) { 54 | for i, x := range s { 55 | x.PropagateStart(sg.(StackGrad)[i], g) 56 | } 57 | } 58 | 59 | // Step applies the block for a single timestep. 
60 | func (s Stack) Step(st State, in anyvec.Vector) Res { 61 | res := &stackRes{V: anydiff.VarSet{}} 62 | inVec := in 63 | for i, x := range s { 64 | inState := st.(StackState)[i] 65 | blockRes := x.Step(inState, inVec) 66 | inVec = blockRes.Output() 67 | res.Reses = append(res.Reses, blockRes) 68 | res.OutState = append(res.OutState, blockRes.State()) 69 | res.V = anydiff.MergeVarSets(res.V, blockRes.Vars()) 70 | } 71 | return res 72 | } 73 | 74 | // Parameters gathers the parameters of all the sub-blocks 75 | // that implement anynet.Parameterizer. 76 | func (s Stack) Parameters() []*anydiff.Var { 77 | var res []*anydiff.Var 78 | for _, x := range s { 79 | if p, ok := x.(anynet.Parameterizer); ok { 80 | res = append(res, p.Parameters()...) 81 | } 82 | } 83 | return res 84 | } 85 | 86 | // SerializerType returns the unique ID used to serialize 87 | // a Stack with the serializer package. 88 | func (s Stack) SerializerType() string { 89 | return "github.com/unixpickle/anynet/anyrnn.Stack" 90 | } 91 | 92 | // Serialize serializes the Stack. 93 | // It only works if every child is a Serializer. 
94 | func (s Stack) Serialize() ([]byte, error) { 95 | var res []serializer.Serializer 96 | for _, x := range s { 97 | if ser, ok := x.(serializer.Serializer); ok { 98 | res = append(res, ser) 99 | } else { 100 | return nil, fmt.Errorf("not a serializer: %T", x) 101 | } 102 | } 103 | return serializer.SerializeSlice(res) 104 | } 105 | 106 | func (s Stack) assertNonEmpty() { 107 | if len(s) == 0 { 108 | panic("empty Stack is invalid") 109 | } 110 | } 111 | 112 | type stackRes struct { 113 | Reses []Res 114 | OutState StackState 115 | V anydiff.VarSet 116 | } 117 | 118 | func (s *stackRes) State() State { 119 | return s.OutState 120 | } 121 | 122 | func (s *stackRes) Output() anyvec.Vector { 123 | return s.Reses[len(s.Reses)-1].Output() 124 | } 125 | 126 | func (s *stackRes) Vars() anydiff.VarSet { 127 | return s.V 128 | } 129 | 130 | func (s *stackRes) Propagate(u anyvec.Vector, sg StateGrad, g anydiff.Grad) (anyvec.Vector, 131 | StateGrad) { 132 | downVec := u 133 | downStates := make(StackGrad, len(s.Reses)) 134 | for i := len(s.Reses) - 1; i >= 0; i-- { 135 | var stateUpstream StateGrad 136 | if sg != nil { 137 | stateUpstream = sg.(StackGrad)[i] 138 | } 139 | down, downState := s.Reses[i].Propagate(downVec, stateUpstream, g) 140 | downVec = down 141 | downStates[i] = downState 142 | } 143 | return downVec, downStates 144 | } 145 | 146 | // StackState is the State type for a Stack. 147 | // 148 | // Each State in the slice corresponds to a Block in the 149 | // Stack. 150 | type StackState []State 151 | 152 | // Present returns the present map of one of the internal 153 | // states. 154 | func (s StackState) Present() PresentMap { 155 | return s[0].Present() 156 | } 157 | 158 | // Reduce reduces all the internal states. 159 | func (s StackState) Reduce(p PresentMap) State { 160 | res := make(StackState, len(s)) 161 | for i, x := range s { 162 | res[i] = x.Reduce(p) 163 | } 164 | return res 165 | } 166 | 167 | // StackGrad is the StateGrad type for a Stack. 
168 | // 169 | // It is pretty much analogous to StackState. 170 | type StackGrad []StateGrad 171 | 172 | // Present returns the present map of one of the internal 173 | // state grads. 174 | func (s StackGrad) Present() PresentMap { 175 | return s[0].Present() 176 | } 177 | 178 | // Expand expands all the internal state grads. 179 | func (s StackGrad) Expand(p PresentMap) StateGrad { 180 | res := make(StackGrad, len(s)) 181 | for i, x := range s { 182 | res[i] = x.Expand(p) 183 | } 184 | return res 185 | } 186 | -------------------------------------------------------------------------------- /anyrnn/stack_test.go: -------------------------------------------------------------------------------- 1 | package anyrnn 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/unixpickle/anydiff" 7 | "github.com/unixpickle/anydiff/anydifftest" 8 | "github.com/unixpickle/anydiff/anyseq" 9 | "github.com/unixpickle/anynet" 10 | "github.com/unixpickle/anyvec" 11 | "github.com/unixpickle/anyvec/anyvec32" 12 | ) 13 | 14 | func TestStackOutput(t *testing.T) { 15 | layer1 := anynet.NewFC(anyvec32.CurrentCreator(), 3, 2) 16 | layer2 := anynet.Tanh 17 | 18 | input := anyvec32.MakeVectorData([]float32{ 19 | 2.098950, -0.645579, 2.106542, 20 | 0.085620, 0.762207, -0.279375, 21 | 0.993967, 2.453542, 1.729150, 22 | -0.971805, -0.315578, -0.306942, 23 | }) 24 | inRes := anydiff.NewConst(input) 25 | expected := anynet.Net{layer1, layer2}.Apply(inRes, 4).Output() 26 | 27 | stacked := Stack{&LayerBlock{Layer: layer1}, &LayerBlock{Layer: layer2}} 28 | state := stacked.Start(4) 29 | actual := stacked.Step(state, input).Output() 30 | 31 | diff := actual.Copy() 32 | diff.Sub(expected) 33 | max := anyvec.AbsMax(diff).(float32) 34 | if max > 1e-3 { 35 | t.Errorf("expected %v but got %v", expected.Data(), actual.Data()) 36 | } 37 | } 38 | 39 | func TestStackProp(t *testing.T) { 40 | inSeq, inVars := randomTestSequence(anyvec32.CurrentCreator(), 3) 41 | block := Stack{ 42 | &LayerBlock{ 43 | Layer: 
anynet.NewFC(anyvec32.CurrentCreator(), 3, 2), 44 | }, 45 | &LayerBlock{ 46 | Layer: anynet.Tanh, 47 | }, 48 | } 49 | if len(block.Parameters()) != 2 { 50 | t.Errorf("expected 2 parameters, but got %d", len(block.Parameters())) 51 | } 52 | checker := &anydifftest.SeqChecker{ 53 | F: func() anyseq.Seq { 54 | return Map(inSeq, block) 55 | }, 56 | V: append(inVars, block.Parameters()...), 57 | } 58 | checker.FullCheck(t) 59 | } 60 | -------------------------------------------------------------------------------- /anyrnn/vanilla_test.go: -------------------------------------------------------------------------------- 1 | package anyrnn 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/unixpickle/anydiff" 10 | "github.com/unixpickle/anydiff/anydifftest" 11 | "github.com/unixpickle/anydiff/anyseq" 12 | "github.com/unixpickle/anynet" 13 | "github.com/unixpickle/anyvec" 14 | "github.com/unixpickle/anyvec/anyvec32" 15 | ) 16 | 17 | func TestVanillaOutput(t *testing.T) { 18 | v := &Vanilla{ 19 | InCount: 2, 20 | OutCount: 3, 21 | StateWeights: anydiff.NewVar(anyvec32.MakeVectorData([]float32{ 22 | 1.013949080929492, 0.651993107300643, 1.962063017373509, 23 | -0.305518636912932, -1.907428571394675, 1.047494354506540, 24 | 0.424928126971939, 1.152028175999884, -0.159508838475856, 25 | })), 26 | InputWeights: anydiff.NewVar(anyvec32.MakeVectorData([]float32{ 27 | 0.578969569547752, -1.738131402776219, 28 | 1.834645967361668, 0.216295204977240, 29 | -0.174414967388466, 0.495420882674173, 30 | })), 31 | Biases: anydiff.NewVar(anyvec32.MakeVectorData([]float32{ 32 | -1.787774931407171, 0.295051579270469, 0.926511486922573, 33 | })), 34 | StartState: anydiff.NewVar(anyvec32.MakeVectorData([]float32{ 35 | 1.137478454905680, 0.318992795539938, -0.263672034086737, 36 | })), 37 | Activation: anynet.Tanh, 38 | } 39 | c := anyvec32.CurrentCreator() 40 | seq := anyseq.ConstSeq(c, []*anyseq.Batch{ 41 | { 42 | Packed: 
anyvec32.MakeVectorData([]float32{1, 2, -1, -3}), 43 | Present: []bool{true, true, false}, 44 | }, 45 | { 46 | Packed: anyvec32.MakeVectorData([]float32{2, -1}), 47 | Present: []bool{false, true, false}, 48 | }, 49 | { 50 | Packed: anyvec32.MakeVectorData([]float32{-1.5, 2}), 51 | Present: []bool{false, true, false}, 52 | }, 53 | }) 54 | actual := Map(seq, v).Output() 55 | expected := []*anyseq.Batch{ 56 | { 57 | Packed: anyvec32.MakeVectorData([]float32{ 58 | -0.999078474068952, 0.869277718016662, 0.989782342296173, 59 | 0.998757641871079, -0.997864863237772, 0.468039628893154, 60 | }), 61 | Present: []bool{true, true, false}, 62 | }, 63 | { 64 | Packed: anyvec32.MakeVectorData([]float32{ 65 | 0.983305063303649, 0.999982959700321, -0.615398128318720, 66 | }), 67 | Present: []bool{false, true, false}, 68 | }, 69 | { 70 | Packed: anyvec32.MakeVectorData([]float32{ 71 | -0.999977199810133, -0.999883828747684, 0.999089273224046, 72 | }), 73 | Present: []bool{false, true, false}, 74 | }, 75 | } 76 | if !seqsEquivalent(actual, expected) { 77 | t.Errorf("expected %s but got %s", seqString(expected), seqString(actual)) 78 | } 79 | } 80 | 81 | func TestVanillaProp(t *testing.T) { 82 | block := NewVanilla(anyvec32.CurrentCreator(), 3, 2, anynet.Tanh) 83 | inSeq, inVars := randomTestSequence(anyvec32.CurrentCreator(), 3) 84 | if len(block.Parameters()) != 4 { 85 | t.Errorf("expected 4 parameters, but got %d", len(block.Parameters())) 86 | } 87 | checker := &anydifftest.SeqChecker{ 88 | F: func() anyseq.Seq { 89 | return Map(inSeq, block) 90 | }, 91 | V: append(inVars, block.Parameters()...), 92 | } 93 | checker.FullCheck(t) 94 | } 95 | 96 | func seqString(s []*anyseq.Batch) string { 97 | var parts []string 98 | for _, x := range s { 99 | parts = append(parts, fmt.Sprintf("{Packed: %v, Present: %v}", x.Packed.Data(), 100 | x.Present)) 101 | } 102 | return "[" + strings.Join(parts, " ") + "]" 103 | } 104 | 105 | func seqsEquivalent(s1, s2 []*anyseq.Batch) bool { 106 | if 
len(s1) != len(s2) { 107 | return false 108 | } 109 | for i, b1 := range s1 { 110 | b2 := s2[i] 111 | if !reflect.DeepEqual(b1.Present, b2.Present) { 112 | return false 113 | } 114 | diff := b1.Packed.Copy() 115 | diff.Sub(b2.Packed) 116 | max := anyvec.AbsMax(diff) 117 | switch max := max.(type) { 118 | case float32: 119 | if max > 1e-3 { 120 | return false 121 | } 122 | case float64: 123 | if max > 1e-5 { 124 | return false 125 | } 126 | default: 127 | panic(fmt.Sprintf("unsupported numeric type: %T", max)) 128 | } 129 | } 130 | return true 131 | } 132 | -------------------------------------------------------------------------------- /anyrnn/vec_state.go: -------------------------------------------------------------------------------- 1 | package anyrnn 2 | 3 | import ( 4 | "github.com/unixpickle/anydiff" 5 | "github.com/unixpickle/anydiff/anyseq" 6 | "github.com/unixpickle/anyvec" 7 | ) 8 | 9 | // A VecState is a State and/or StateGrad that can be 10 | // expressed as a vector. 11 | type VecState struct { 12 | Vector anyvec.Vector 13 | PresentMap PresentMap 14 | } 15 | 16 | // NewVecState generates a VecState with the vector 17 | // repeated n times. 18 | func NewVecState(v anyvec.Vector, n int) *VecState { 19 | rep := v.Creator().MakeVector(v.Len() * n) 20 | anyvec.AddRepeated(rep, v) 21 | p := make([]bool, n) 22 | for i := range p { 23 | p[i] = true 24 | } 25 | return &VecState{ 26 | Vector: rep, 27 | PresentMap: p, 28 | } 29 | } 30 | 31 | // Present returns the PresentMap. 32 | func (v *VecState) Present() PresentMap { 33 | return v.PresentMap 34 | } 35 | 36 | // Reduce generates a new *VecState with a subset of the 37 | // chunks in v. 38 | func (v *VecState) Reduce(p PresentMap) State { 39 | b := &anyseq.Batch{Packed: v.Vector, Present: v.PresentMap} 40 | res := b.Reduce(p) 41 | return &VecState{Vector: res.Packed, PresentMap: p} 42 | } 43 | 44 | // Expand expands the *VecState by inserting zero chunks 45 | // where necessary, producing a new *VecState. 
46 | func (v *VecState) Expand(p PresentMap) StateGrad { 47 | b := &anyseq.Batch{Packed: v.Vector, Present: v.PresentMap} 48 | res := b.Expand(p) 49 | return &VecState{Vector: res.Packed, PresentMap: p} 50 | } 51 | 52 | // PropagateStart propagates the contents of the vector, 53 | // treated as a batched upstream gradient, through the 54 | // variable. 55 | // 56 | // All sequences must be present. 57 | func (v *VecState) PropagateStart(va *anydiff.Var, g anydiff.Grad) { 58 | for _, x := range v.PresentMap { 59 | if !x { 60 | panic("all sequences must be present") 61 | } 62 | } 63 | if dest, ok := g[va]; ok { 64 | dest.Add(anyvec.SumRows(v.Vector, v.Vector.Len()/len(v.PresentMap))) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /anys2s/doc.go: -------------------------------------------------------------------------------- 1 | // Package anys2s is for sequence-to-sequence learning for 2 | // recurrent neural networks. 3 | package anys2s 4 | -------------------------------------------------------------------------------- /anys2s/samples.go: -------------------------------------------------------------------------------- 1 | package anys2s 2 | 3 | import ( 4 | "sort" 5 | 6 | "github.com/unixpickle/anynet/anysgd" 7 | "github.com/unixpickle/anyvec" 8 | ) 9 | 10 | // A Sample is a training sequence with a corresponding 11 | // desired output sequence. 12 | type Sample struct { 13 | Input []anyvec.Vector 14 | Output []anyvec.Vector 15 | } 16 | 17 | // A SampleList is an anysgd.SampleList that produces 18 | // sequence-to-sequence samples. 19 | type SampleList interface { 20 | anysgd.SampleList 21 | 22 | GetSample(idx int) (*Sample, error) 23 | Creator() anyvec.Creator 24 | } 25 | 26 | // A SortableSampleList is a SampleList with an extra 27 | // LenAt method for efficiently getting the length of an 28 | // input sequence. 
29 | type SortableSampleList interface { 30 | SampleList 31 | 32 | LenAt(idx int) int 33 | } 34 | 35 | // A SortSampleList wraps a SampleList and ensures that 36 | // samples will be sorted within reasonably small chunks. 37 | // This is often beneficial for RNNs on a GPU, since it 38 | // helps to keep batch sizes stable across timesteps. 39 | type SortSampleList struct { 40 | SortableSampleList 41 | 42 | // BatchSize is the size of the chunks that should be 43 | // sorted. 44 | BatchSize int 45 | } 46 | 47 | // Slice produces a subset of the SortSampleList. 48 | func (s *SortSampleList) Slice(i, j int) anysgd.SampleList { 49 | sliced := s.SortableSampleList.Slice(i, j) 50 | return &SortSampleList{ 51 | SortableSampleList: sliced.(SortableSampleList), 52 | BatchSize: s.BatchSize, 53 | } 54 | } 55 | 56 | // PostShuffle sorts batches of sequences. 57 | func (s *SortSampleList) PostShuffle() { 58 | for i := 0; i < s.Len(); i += s.BatchSize { 59 | bs := s.BatchSize 60 | if bs > s.Len()-i { 61 | bs = s.Len() - i 62 | } 63 | s := &sorter{S: s.SortableSampleList, Start: i, End: i + bs} 64 | sort.Sort(s) 65 | } 66 | } 67 | 68 | type sorter struct { 69 | S SortableSampleList 70 | Start int 71 | End int 72 | } 73 | 74 | func (s *sorter) Len() int { 75 | return s.End - s.Start 76 | } 77 | 78 | func (s *sorter) Swap(i, j int) { 79 | s.S.Swap(i+s.Start, j+s.Start) 80 | } 81 | 82 | func (s *sorter) Less(i, j int) bool { 83 | return s.S.LenAt(i+s.Start) < s.S.LenAt(j+s.Start) 84 | } 85 | -------------------------------------------------------------------------------- /anys2s/trainer.go: -------------------------------------------------------------------------------- 1 | package anys2s 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/unixpickle/anydiff" 7 | "github.com/unixpickle/anydiff/anyseq" 8 | "github.com/unixpickle/anynet" 9 | "github.com/unixpickle/anynet/anysgd" 10 | "github.com/unixpickle/anyvec" 11 | "github.com/unixpickle/essentials" 12 | ) 13 | 14 | // A Batch 
stores an input and output batch in a packed 15 | // format. 16 | type Batch struct { 17 | Inputs anyseq.Seq 18 | Outputs anyseq.Seq 19 | } 20 | 21 | // A Trainer creates batches, computes gradients, and adds 22 | // up costs for a sequence-to-sequence mapping. 23 | type Trainer struct { 24 | Func func(anyseq.Seq) anyseq.Seq 25 | Cost anynet.Cost 26 | Params []*anydiff.Var 27 | 28 | // Average indicates whether or not the total cost should 29 | // be averaged before computing gradients. 30 | // This affects gradients, LastCost, and the output of 31 | // TotalCost(). 32 | Average bool 33 | 34 | // After every gradient computation, LastCost is set to 35 | // the cost from the batch. 36 | LastCost anyvec.Numeric 37 | } 38 | 39 | // Fetch produces a *Batch for the subset of samples. 40 | // The s argument must implement SampleList. 41 | // The batch may not be empty. 42 | func (t *Trainer) Fetch(s anysgd.SampleList) (anysgd.Batch, error) { 43 | if s.Len() == 0 { 44 | return nil, errors.New("fetch batch: empty batch") 45 | } 46 | l := s.(SampleList) 47 | ins := make([][]anyvec.Vector, l.Len()) 48 | outs := make([][]anyvec.Vector, l.Len()) 49 | for i := 0; i < l.Len(); i++ { 50 | sample, err := l.GetSample(i) 51 | if err != nil { 52 | return nil, essentials.AddCtx("fetch batch", err) 53 | } 54 | ins[i] = sample.Input 55 | outs[i] = sample.Output 56 | } 57 | return &Batch{ 58 | Inputs: anyseq.ConstSeqList(l.Creator(), ins), 59 | Outputs: anyseq.ConstSeqList(l.Creator(), outs), 60 | }, nil 61 | } 62 | 63 | // TotalCost computes the total cost for the *Batch. 
64 | func (t *Trainer) TotalCost(batch anysgd.Batch) anydiff.Res { 65 | b := batch.(*Batch) 66 | actual := t.Func(b.Inputs) 67 | 68 | if len(actual.Output()) != len(b.Outputs.Output()) { 69 | panic("mismatching actual and desired sequence shapes") 70 | } 71 | 72 | var idx int 73 | var costCount int 74 | allCosts := anyseq.Map(actual, func(a anydiff.Res, n int) anydiff.Res { 75 | batch := b.Outputs.Output()[idx] 76 | if batch.NumPresent() != n { 77 | panic("mismatching actual and desired sequence shapes") 78 | } 79 | costCount += n 80 | idx++ 81 | c := t.Cost.Cost(anydiff.NewConst(batch.Packed), a, n) 82 | return c 83 | }) 84 | 85 | sum := anydiff.Sum(anyseq.Sum(allCosts)) 86 | if t.Average { 87 | scaler := sum.Output().Creator().MakeNumeric(1 / float64(costCount)) 88 | return anydiff.Scale(sum, scaler) 89 | } else { 90 | return sum 91 | } 92 | } 93 | 94 | // Gradient computes the gradient for the batch's cost. 95 | // It also sets t.LastCost to the numerical value of the 96 | // total cost. 97 | // 98 | // The b argument must be a *Batch. 99 | func (t *Trainer) Gradient(b anysgd.Batch) anydiff.Grad { 100 | grad, lc := anysgd.CosterGrad(t, b, t.Params) 101 | t.LastCost = lc 102 | return grad 103 | } 104 | -------------------------------------------------------------------------------- /anys2v/doc.go: -------------------------------------------------------------------------------- 1 | // Package anys2v is for sequence-to-vector learning. 2 | // It can be applied to RNN architectures that encode 3 | // sequences down to fixed-length vectors. 4 | package anys2v 5 | -------------------------------------------------------------------------------- /anys2v/samples.go: -------------------------------------------------------------------------------- 1 | package anys2v 2 | 3 | import ( 4 | "github.com/unixpickle/anynet/anysgd" 5 | "github.com/unixpickle/anyvec" 6 | ) 7 | 8 | // A Sample is a training sequence with a corresponding 9 | // desired output vector. 
// A Batch stores an input and output batch in a packed
// format.
// No input sequences may be empty.
type Batch struct {
	// Inputs holds the packed input sequences.
	Inputs anyseq.Seq

	// Outputs holds the desired output vectors.
	Outputs *anydiff.Const
}

// A Trainer creates batches, computes gradients, and adds
// up costs for a sequence-to-vector mapping.
type Trainer struct {
	// Func maps a batch of input sequences to a result
	// which is compared against the desired outputs.
	Func func(anyseq.Seq) anydiff.Res

	// Cost is the cost function applied to the desired and
	// actual outputs.
	Cost anynet.Cost

	// Params are the variables for which gradients are
	// computed.
	Params []*anydiff.Var

	// Average indicates whether or not the total cost should
	// be averaged before computing gradients.
	// This affects gradients, LastCost, and the output of
	// TotalCost().
	Average bool

	// After every gradient computation, LastCost is set to
	// the cost from the batch.
	LastCost anyvec.Numeric
}
43 | func (t *Trainer) Fetch(s anysgd.SampleList) (anysgd.Batch, error) { 44 | if s.Len() == 0 { 45 | return nil, errors.New("fetch batch: empty batch") 46 | } 47 | l := s.(SampleList) 48 | ins := make([][]anyvec.Vector, l.Len()) 49 | outs := make([]anyvec.Vector, l.Len()) 50 | for i := 0; i < l.Len(); i++ { 51 | sample, err := l.GetSample(i) 52 | if err != nil { 53 | return nil, essentials.AddCtx("fetch batch", err) 54 | } 55 | if len(sample.Input) == 0 { 56 | return nil, errors.New("fetch batch: empty sequence") 57 | } 58 | ins[i] = sample.Input 59 | outs[i] = sample.Output 60 | } 61 | cr := outs[0].Creator() 62 | return &Batch{ 63 | Inputs: anyseq.ConstSeqList(cr, ins), 64 | Outputs: anydiff.NewConst(cr.Concat(outs...)), 65 | }, nil 66 | } 67 | 68 | // TotalCost computes the total cost for the *Batch. 69 | func (t *Trainer) TotalCost(batch anysgd.Batch) anydiff.Res { 70 | b := batch.(*Batch) 71 | n := 0 72 | if len(b.Inputs.Output()) > 0 { 73 | n = b.Inputs.Output()[0].NumPresent() 74 | } 75 | outRes := t.Func(b.Inputs) 76 | cost := t.Cost.Cost(b.Outputs, outRes, n) 77 | total := anydiff.Sum(cost) 78 | if t.Average { 79 | divisor := 1 / float64(cost.Output().Len()) 80 | return anydiff.Scale(total, total.Output().Creator().MakeNumeric(divisor)) 81 | } else { 82 | return total 83 | } 84 | } 85 | 86 | // Gradient computes the gradient for the batch's cost. 87 | // It also sets t.LastCost to the numerical value of the 88 | // total cost. 89 | // 90 | // The b argument must be a *Batch. 
// Defaults used when the corresponding Adam field is 0.
const (
	adamDefaultDecayRate1 = 0.9
	adamDefaultDecayRate2 = 0.999
	adamDefaultDamping    = 1e-8
)

// Adam implements the adaptive moments SGD technique
// described in https://arxiv.org/pdf/1412.6980.pdf.
//
// Most of this code is taken from
// https://github.com/unixpickle/sgd/blob/0e3d4c9d317b1095d02febdaedf802f6d1dbd5b1/adam.go.
type Adam struct {
	// These are decay rates for the first and second
	// moments of the gradient.
	// If these are 0, defaults as suggested in the
	// original Adam paper are used.
	DecayRate1, DecayRate2 float64

	// Damping is used to prevent divisions by zero.
	// This should be very small.
	// If it is 0, a default is used.
	Damping float64

	// Vars is used by the marshalling routines to
	// assign an ordering to the variables.
	// It is only used by the MarshalBinary and
	// UnmarshalBinary methods.
	Vars []*anydiff.Var

	// firstMoment and secondMoment are the running
	// averages of the gradient and of the squared
	// gradient, respectively.
	// They are nil until the first call to Transform,
	// unless restored by UnmarshalBinary.
	firstMoment  anydiff.Grad
	secondMoment anydiff.Grad

	// iteration counts the calls to Transform so far.
	// It is used to compute the scaling factor applied
	// to the first moment.
	iteration float64
}
49 | func (a *Adam) Transform(realGrad anydiff.Grad) anydiff.Grad { 50 | a.updateMoments(realGrad) 51 | 52 | a.iteration++ 53 | scalingFactor := math.Sqrt(1-math.Pow(a.decayRate(2), a.iteration)) / 54 | (1 - math.Pow(a.decayRate(1), a.iteration)) 55 | damping := a.damping() 56 | for variable, vec := range realGrad { 57 | firstVec := a.firstMoment[variable] 58 | secondVec := a.secondMoment[variable] 59 | 60 | vec.Set(firstVec) 61 | vec.Scale(vec.Creator().MakeNumeric(scalingFactor)) 62 | 63 | divisor := secondVec.Copy() 64 | divisor.AddScalar(divisor.Creator().MakeNumeric(damping)) 65 | anyvec.Pow(divisor, divisor.Creator().MakeNumeric(0.5)) 66 | vec.Div(divisor) 67 | } 68 | 69 | return realGrad 70 | } 71 | 72 | // MarshalBinary marshals the hyperparameters and current 73 | // state into a binary format. 74 | // 75 | // This requires that a.Vars contains all and only the 76 | // variables contained in gradients passed to Transform. 77 | // If Transform has never been called, then MarshalBinary 78 | // will always succeed. 79 | func (a *Adam) MarshalBinary() (data []byte, err error) { 80 | defer essentials.AddCtxTo("marshal Adam", &err) 81 | 82 | moment1Data, err := marshalGradient(a.Vars, a.firstMoment) 83 | if err != nil { 84 | return nil, err 85 | } 86 | moment2Data, err := marshalGradient(a.Vars, a.secondMoment) 87 | if err != nil { 88 | return nil, err 89 | } 90 | 91 | return serializer.SerializeAny( 92 | a.DecayRate1, 93 | a.DecayRate2, 94 | a.Damping, 95 | moment1Data, 96 | moment2Data, 97 | a.iteration, 98 | ) 99 | } 100 | 101 | // UnmarshalBinary performs the inverse of MarshalBinary. 102 | // 103 | // Like MarshalBinary, this requires a.Vars to be set. 104 | // 105 | // If this fails, the old contents of the instance may 106 | // have been partially overwritten. 
107 | func (a *Adam) UnmarshalBinary(data []byte) (err error) { 108 | defer essentials.AddCtxTo("unmarshal Adam", &err) 109 | 110 | var moment1Data, moment2Data []byte 111 | err = serializer.DeserializeAny(data, &a.DecayRate1, &a.DecayRate2, 112 | &a.Damping, &moment1Data, &moment2Data, &a.iteration) 113 | if err != nil { 114 | return 115 | } 116 | 117 | a.firstMoment, err = unmarshalGradient(a.Vars, moment1Data) 118 | if err != nil { 119 | return 120 | } 121 | a.secondMoment, err = unmarshalGradient(a.Vars, moment2Data) 122 | return 123 | } 124 | 125 | func (a *Adam) updateMoments(grad anydiff.Grad) { 126 | if a.firstMoment == nil { 127 | a.firstMoment = copyGrad(grad) 128 | scaleGrad(a.firstMoment, 1-a.decayRate(1)) 129 | } else { 130 | decayRate := a.decayRate(1) 131 | for variable, vec := range grad { 132 | momentVec := a.firstMoment[variable] 133 | rollingAverage(momentVec, vec.Copy(), decayRate) 134 | } 135 | } 136 | 137 | if a.secondMoment == nil { 138 | a.secondMoment = copyGrad(grad) 139 | for _, v := range a.secondMoment { 140 | anyvec.Pow(v, v.Creator().MakeNumeric(2)) 141 | } 142 | scaleGrad(a.secondMoment, 1-a.decayRate(2)) 143 | } else { 144 | decayRate := a.decayRate(2) 145 | for variable, vec := range grad { 146 | momentVec := a.secondMoment[variable] 147 | anyvec.Pow(vec, vec.Creator().MakeNumeric(2)) 148 | rollingAverage(momentVec, vec, decayRate) 149 | } 150 | } 151 | } 152 | 153 | func (a *Adam) decayRate(moment int) float64 { 154 | if moment == 1 { 155 | return valueOrDefault(a.DecayRate1, adamDefaultDecayRate1) 156 | } else if moment == 2 { 157 | return valueOrDefault(a.DecayRate2, adamDefaultDecayRate2) 158 | } else { 159 | panic("invalid moment.") 160 | } 161 | } 162 | 163 | func (a *Adam) damping() float64 { 164 | if a.Damping != 0 { 165 | return a.Damping 166 | } else { 167 | return adamDefaultDamping 168 | } 169 | } 170 | 171 | // rollingAverage computes 172 | // 173 | // oldVec := oldVec + (1-decayRate)*(newVec - oldVec) 174 | // 175 | 
// It overwrites newVec in the process. 176 | func rollingAverage(oldVec, newVec anyvec.Vector, decayRate float64) { 177 | newVec.Sub(oldVec) 178 | newVec.Scale(newVec.Creator().MakeNumeric(1 - decayRate)) 179 | oldVec.Add(newVec) 180 | } 181 | -------------------------------------------------------------------------------- /anysgd/adam_test.go: -------------------------------------------------------------------------------- 1 | package anysgd 2 | 3 | import ( 4 | "math" 5 | "testing" 6 | 7 | "github.com/unixpickle/anydiff" 8 | "github.com/unixpickle/anyvec/anyvec32" 9 | "github.com/unixpickle/anyvec/anyvec64" 10 | ) 11 | 12 | func TestAdamValues(t *testing.T) { 13 | v := anydiff.NewVar(anyvec32.MakeVector(2)) 14 | g := anydiff.Grad{v: anyvec32.MakeVector(2)} 15 | 16 | a := Adam{DecayRate1: 0.9, DecayRate2: 0.99, Damping: 1e-8} 17 | 18 | scaling1 := math.Sqrt(1-0.99) / (1 - 0.9) 19 | scaling2 := math.Sqrt(1-math.Pow(0.99, 2)) / (1 - math.Pow(0.9, 2)) 20 | 21 | inputGrads := [][]float32{{1, -2}, {2, 1}} 22 | expectedOuts := [][]float64{ 23 | { 24 | 0.1 * scaling1 / math.Sqrt(0.01+1e-8), 25 | 0.1 * -2 * scaling1 / math.Sqrt(0.04+1e-8), 26 | }, 27 | { 28 | scaling2 * (0.9*0.1 + 0.1*2) / math.Sqrt((0.99*0.01+0.01*4)+1e-8), 29 | scaling2 * (-0.9*0.2 + 0.1) / math.Sqrt((0.99*0.04+0.01)+1e-8), 30 | }, 31 | } 32 | 33 | for i, input := range inputGrads { 34 | g[v].SetData(input) 35 | actual := a.Transform(g)[v].Data().([]float32) 36 | expected := expectedOuts[i] 37 | for j, x := range expected { 38 | act := actual[j] 39 | if math.IsNaN(float64(act)) || math.IsNaN(x) || 40 | math.Abs(float64(act)-x) > 1e-3 { 41 | t.Errorf("time %d out %d: expected %f, got %f", i, j, x, act) 42 | } 43 | } 44 | } 45 | } 46 | 47 | func TestAdamTraining(t *testing.T) { 48 | if testing.Short() { 49 | t.Skip("skipping in short mode") 50 | } 51 | stop := newTestStopper(100000) 52 | g := newTestGradienter() 53 | s := &SGD{ 54 | Fetcher: testFetcher{}, 55 | Gradienter: g, 56 | Transformer: &Adam{}, 
57 | Samples: newTestSampleList(), 58 | Rater: ConstRater(0.001), 59 | StatusFunc: stop.StatusFunc, 60 | BatchSize: 1, 61 | } 62 | 63 | s.Run(stop.Chan()) 64 | 65 | if g.errorMargin() > 1e-2 { 66 | x, y := g.current() 67 | t.Errorf("bad solution: %f, %f", x, y) 68 | } 69 | } 70 | 71 | func TestAdamMarshal(t *testing.T) { 72 | c := anyvec64.DefaultCreator{} 73 | a := &Adam{ 74 | DecayRate1: 0.3, 75 | DecayRate2: 0.4, 76 | Damping: 0.2, 77 | Vars: randomVars(c), 78 | } 79 | testMarshal(t, a, a.Vars) 80 | } 81 | -------------------------------------------------------------------------------- /anysgd/anysgd.go: -------------------------------------------------------------------------------- 1 | // Package anysgd provides tools for Stochastic Gradient 2 | // Descent. 3 | // It is intended to be used for Machine Learning, but it 4 | // can be applied to other areas as well. 5 | package anysgd 6 | 7 | import "github.com/unixpickle/anydiff" 8 | 9 | // SGD performs stochastic gradient descent. 10 | type SGD struct { 11 | // Fetcher is used to obtain Batches for mini-batch 12 | // slices of the sample list. 13 | Fetcher Fetcher 14 | 15 | // Gradienter is used to compute initial, untransformed 16 | // gradients for each mini-batch. 17 | Gradienter Gradienter 18 | 19 | // Transformer, if non-nil, is used to transform each 20 | // gradient before the step. 21 | Transformer Transformer 22 | 23 | // Samples is the list of training samples to use for 24 | // training. 25 | // It will be shuffled and re-shuffled as needed. 26 | // 27 | // The list may not be empty. 28 | Samples SampleList 29 | 30 | // Rater determines the learning rate for each step. 31 | Rater Rater 32 | 33 | // BatchSize is the mini-batch size. 34 | // If it is 0, then the entire sample list is used at 35 | // every iteration. 36 | BatchSize int 37 | 38 | // StatusFunc, if non-nil, is called before every 39 | // iteration with the next mini-batch. 
40 | StatusFunc func(batch Batch) 41 | 42 | // NumProcessed keeps track of the number of samples that 43 | // have been passed to Gradienter so far. 44 | // It is used to compute the epoch for Rater. 45 | // Most of the time, this should be initialized to 0. 46 | NumProcessed int 47 | } 48 | 49 | // Run runs SGD until doneChan is closed or the fetcher 50 | // returns an error. 51 | // 52 | // Run is not thread-safe, and you should never modify the 53 | // struct's fields while Run is active. 54 | // However, you may safely read from s.NumProcessed during 55 | // calls to s.StatusFunc. 56 | func (s *SGD) Run(doneChan <-chan struct{}) error { 57 | return s.streamGradients(doneChan, func(g anydiff.Grad) { 58 | if s.Transformer != nil { 59 | g = s.Transformer.Transform(g) 60 | } 61 | scaleGrad(g, -s.Rater.Rate(s.epoch())) 62 | g.AddToVars() 63 | }) 64 | } 65 | 66 | // RunAvg is like Run, but n mini-batch gradients are 67 | // averaged together before a step is taken. 68 | // 69 | // The StatusFunc is called for every mini-batch gradient, 70 | // not just for each gradient step. 71 | func (s *SGD) RunAvg(n int, doneChan <-chan struct{}) error { 72 | if n < 1 { 73 | panic("n out of bounds") 74 | } else if n == 1 { 75 | return s.Run(doneChan) 76 | } 77 | 78 | var sum anydiff.Grad 79 | var count int 80 | return s.streamGradients(doneChan, func(g anydiff.Grad) { 81 | if sum == nil { 82 | var vars []*anydiff.Var 83 | for v := range g { 84 | vars = append(vars, v) 85 | } 86 | sum = anydiff.NewGrad(vars...) 
87 | } 88 | 89 | for v, x := range sum { 90 | x.Add(g[v]) 91 | } 92 | 93 | count++ 94 | if count == n { 95 | scaleGrad(sum, 1/float64(n)) 96 | 97 | addMe := sum 98 | if s.Transformer != nil { 99 | addMe = s.Transformer.Transform(addMe) 100 | } 101 | scaleGrad(addMe, -s.Rater.Rate(s.epoch())) 102 | addMe.AddToVars() 103 | 104 | sum.Clear() 105 | count = 0 106 | } 107 | }) 108 | } 109 | 110 | func (s *SGD) batchSize(remaining int) int { 111 | if s.BatchSize == 0 || s.BatchSize > remaining { 112 | return remaining 113 | } else { 114 | return s.BatchSize 115 | } 116 | } 117 | 118 | func (s *SGD) streamGradients(doneChan <-chan struct{}, f func(anydiff.Grad)) error { 119 | if s.Samples.Len() == 0 { 120 | panic("cannot run SGD with empty sample list") 121 | } 122 | 123 | errChan := make(chan error, 1) 124 | batchChan := make(chan *batchInfo) 125 | 126 | go func() { 127 | idx := s.Samples.Len() 128 | for { 129 | select { 130 | case <-doneChan: 131 | return 132 | default: 133 | } 134 | remaining := s.Samples.Len() - idx 135 | if remaining == 0 { 136 | Shuffle(s.Samples) 137 | idx = 0 138 | remaining = s.Samples.Len() 139 | } 140 | batchSize := s.batchSize(remaining) 141 | batchSlice := s.Samples.Slice(idx, idx+batchSize) 142 | idx += batchSize 143 | batch, err := s.Fetcher.Fetch(batchSlice) 144 | if err != nil { 145 | errChan <- err 146 | return 147 | } 148 | select { 149 | case batchChan <- &batchInfo{batch, batchSize}: 150 | case <-doneChan: 151 | return 152 | } 153 | } 154 | }() 155 | 156 | for { 157 | select { 158 | case <-doneChan: 159 | return nil 160 | default: 161 | } 162 | 163 | var info *batchInfo 164 | select { 165 | case info = <-batchChan: 166 | case err := <-errChan: 167 | return err 168 | case <-doneChan: 169 | return nil 170 | } 171 | 172 | if s.StatusFunc != nil { 173 | s.StatusFunc(info.Batch) 174 | select { 175 | case <-doneChan: 176 | return nil 177 | default: 178 | } 179 | } 180 | 181 | s.NumProcessed += info.Size 182 | 183 | grad := 
s.Gradienter.Gradient(info.Batch) 184 | f(grad) 185 | } 186 | } 187 | 188 | func (s *SGD) epoch() float64 { 189 | return float64(s.NumProcessed) / float64(s.Samples.Len()) 190 | } 191 | 192 | type batchInfo struct { 193 | Batch Batch 194 | Size int 195 | } 196 | -------------------------------------------------------------------------------- /anysgd/anysgd_test.go: -------------------------------------------------------------------------------- 1 | package anysgd 2 | 3 | import ( 4 | "math" 5 | "testing" 6 | 7 | "github.com/unixpickle/anydiff" 8 | "github.com/unixpickle/anyvec/anyvec32" 9 | ) 10 | 11 | type testSample struct { 12 | X2 float64 13 | Y2 float64 14 | XY float64 15 | X float64 16 | Y float64 17 | } 18 | 19 | func (t *testSample) Apply(x, y anydiff.Res) anydiff.Res { 20 | mk := x.Output().Creator().MakeNumeric 21 | a := anydiff.Scale(anydiff.Mul(x, x), mk(t.X2)) 22 | b := anydiff.Scale(anydiff.Mul(y, y), mk(t.Y2)) 23 | c := anydiff.Scale(anydiff.Mul(x, y), mk(t.XY)) 24 | d := anydiff.Scale(x, mk(t.X)) 25 | e := anydiff.Scale(y, mk(t.Y)) 26 | return anydiff.Add( 27 | anydiff.Add(a, b), 28 | anydiff.Add(anydiff.Add(c, d), e), 29 | ) 30 | } 31 | 32 | type testSampleList []*testSample 33 | 34 | func newTestSampleList() testSampleList { 35 | // Together, these polynomials add up to 3x^2+3xy-2x+y^2. 36 | // The global minimum is (x = 4/3, y = -2). 37 | return testSampleList{ 38 | {X2: 2, X: -1, XY: 0, Y2: 0.5}, 39 | {X2: -1, X: 0, XY: 2, Y2: 0.5}, 40 | {X2: 2, X: -1, XY: 1, Y2: 0}, 41 | } 42 | } 43 | 44 | func (t testSampleList) Len() int { 45 | return len(t) 46 | } 47 | 48 | func (t testSampleList) Swap(i, j int) { 49 | t[i], t[j] = t[j], t[i] 50 | } 51 | 52 | func (t testSampleList) Slice(i, j int) SampleList { 53 | return append(testSampleList{}, t[i:j]...) 
54 | } 55 | 56 | type testFetcher struct{} 57 | 58 | func (t testFetcher) Fetch(s SampleList) (Batch, error) { 59 | return s, nil 60 | } 61 | 62 | type testStopper struct { 63 | callsRemaining int 64 | channel chan struct{} 65 | } 66 | 67 | func newTestStopper(calls int) *testStopper { 68 | return &testStopper{callsRemaining: calls, channel: make(chan struct{})} 69 | } 70 | 71 | func (t *testStopper) StatusFunc(b Batch) { 72 | t.callsRemaining-- 73 | if t.callsRemaining == 0 { 74 | close(t.channel) 75 | } 76 | } 77 | 78 | func (t *testStopper) Chan() <-chan struct{} { 79 | return t.channel 80 | } 81 | 82 | type testGradienter struct { 83 | X *anydiff.Var 84 | Y *anydiff.Var 85 | } 86 | 87 | func newTestGradienter() *testGradienter { 88 | c := anyvec32.DefaultCreator{} 89 | return &testGradienter{ 90 | X: anydiff.NewVar(c.MakeVector(1)), 91 | Y: anydiff.NewVar(c.MakeVector(1)), 92 | } 93 | } 94 | 95 | func (t *testGradienter) Gradient(batch Batch) anydiff.Grad { 96 | s := batch.(testSampleList) 97 | var cost anydiff.Res 98 | for _, x := range s { 99 | res := x.Apply(t.X, t.Y) 100 | if cost == nil { 101 | cost = res 102 | } else { 103 | cost = anydiff.Add(cost, res) 104 | } 105 | } 106 | grad := anydiff.Grad{ 107 | t.X: t.X.Vector.Creator().MakeVector(1), 108 | t.Y: t.Y.Vector.Creator().MakeVector(1), 109 | } 110 | oneVec := t.X.Vector.Creator().MakeVectorData( 111 | t.X.Vector.Creator().MakeNumericList([]float64{1}), 112 | ) 113 | cost.Propagate(oneVec, grad) 114 | return grad 115 | } 116 | 117 | func (t *testGradienter) current() (x, y float64) { 118 | x32 := t.X.Vector.Data().([]float32)[0] 119 | y32 := t.Y.Vector.Data().([]float32)[0] 120 | return float64(x32), float64(y32) 121 | } 122 | 123 | func (t *testGradienter) errorMargin() float64 { 124 | x, y := t.current() 125 | return math.Max( 126 | math.Abs(float64(x)-4.0/3), 127 | math.Abs(float64(y)+2), 128 | ) 129 | } 130 | 131 | func TestSGD(t *testing.T) { 132 | if testing.Short() { 133 | t.Skip("skipping in 
short mode") 134 | } 135 | 136 | stop := newTestStopper(400000) 137 | g := newTestGradienter() 138 | s := &SGD{ 139 | Fetcher: testFetcher{}, 140 | Gradienter: g, 141 | Samples: newTestSampleList(), 142 | Rater: ConstRater(0.0002), 143 | StatusFunc: stop.StatusFunc, 144 | BatchSize: 1, 145 | } 146 | 147 | s.Run(stop.Chan()) 148 | 149 | if g.errorMargin() > 1e-2 { 150 | x, y := g.current() 151 | t.Errorf("bad solution: %f, %f", x, y) 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /anysgd/hash_split.go: -------------------------------------------------------------------------------- 1 | package anysgd 2 | 3 | import ( 4 | "encoding/binary" 5 | "io" 6 | "math" 7 | "sort" 8 | ) 9 | 10 | // A Hasher is a SampleList with the added capability to 11 | // produce a hash for a given sample. 12 | type Hasher interface { 13 | SampleList 14 | Hash(i int) []byte 15 | } 16 | 17 | // HashSplit partitions a Hasher. 18 | // It can be used to deterministically split data up into 19 | // separate validation and training samples. 20 | // 21 | // The Hasher h will be re-ordered as needed for internal 22 | // computations. 23 | // 24 | // The leftRatio argument specifies the expected fraction 25 | // of samples that should end up on the left partition. 26 | func HashSplit(h Hasher, leftRatio float64) (left, right SampleList) { 27 | if leftRatio == 0 { 28 | return h.Slice(0, 0), h 29 | } else if leftRatio == 1 { 30 | return h, h.Slice(0, 0) 31 | } 32 | cutoff := hashCutoff(leftRatio) 33 | insertIdx := 0 34 | for i := 0; i < h.Len(); i++ { 35 | hash := h.Hash(i) 36 | if compareHashes(hash, cutoff) < 0 { 37 | h.Swap(insertIdx, i) 38 | insertIdx++ 39 | } 40 | } 41 | splitIdx := sort.Search(h.Len(), func(i int) bool { 42 | return compareHashes(h.Hash(i), cutoff) >= 0 43 | }) 44 | return h.Slice(0, splitIdx), h.Slice(splitIdx, h.Len()) 45 | } 46 | 47 | // Most of the following code was taken from my old sgd package. 
// hashCutoff converts a ratio into an 8-byte big-endian fixed-point
// fraction, suitable for comparison against hashes with
// compareHashes.
//
// For ratios in [0, 1), byte i of the result is the i-th base-256
// digit of the ratio, so 0.5 becomes 80 00 00 00 00 00 00 00.
//
// Out-of-range inputs saturate instead of wrapping around when
// truncated to a byte: ratios <= 0 (and NaN) produce all zero bytes,
// and ratios >= 1 produce all 0xFF bytes.
func hashCutoff(ratio float64) []byte {
	res := make([]byte, 8)

	// The negated comparison also catches NaN.
	if !(ratio > 0) {
		return res
	}
	if ratio >= 1 {
		for i := range res {
			res[i] = 0xff
		}
		return res
	}

	// Here 0 < ratio < 1, so every extracted digit is in [0, 255].
	for i := range res {
		ratio *= 256
		digit := int(ratio)
		ratio -= float64(digit)
		res[i] = byte(digit)
	}
	return res
}
// A Rater determines the learning rate given the epoch
// number.
// An "epoch" is a full pass over the training set, so
// fractional epochs are possible.
type Rater interface {
	// Rate returns the learning rate to use at the given
	// (possibly fractional) epoch.
	Rate(epoch float64) float64
}
88 | Slice(i, j int) SampleList 89 | } 90 | 91 | // PostShuffler is used to notify a SampleList that it has 92 | // been shuffled, allowing it to perform any sample 93 | // re-ordering it likes. 94 | // 95 | // For example, you might use a PostShuffler to make sure 96 | // that "compatible" samples are close to each other so 97 | // they end up in the same mini-batch. 98 | type PostShuffler interface { 99 | PostShuffle() 100 | } 101 | 102 | // A Coster computes differentiable costs for a Batch. 103 | // The resulting cost vectors should have one component. 104 | type Coster interface { 105 | TotalCost(b Batch) anydiff.Res 106 | } 107 | -------------------------------------------------------------------------------- /anysgd/marshal.go: -------------------------------------------------------------------------------- 1 | package anysgd 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/unixpickle/anydiff" 7 | "github.com/unixpickle/anyvec/anyvecsave" 8 | "github.com/unixpickle/serializer" 9 | ) 10 | 11 | var errVarsGradMismatch = errors.New("variable list does not match gradients") 12 | 13 | func marshalGradient(vars []*anydiff.Var, grad anydiff.Grad) ([]byte, error) { 14 | if grad == nil { 15 | return []byte{}, nil 16 | } 17 | if len(vars) != len(grad) { 18 | return nil, errVarsGradMismatch 19 | } 20 | 21 | var vecObjs []interface{} 22 | for _, v := range vars { 23 | vec, ok := grad[v] 24 | if !ok { 25 | return nil, errVarsGradMismatch 26 | } 27 | vecObjs = append(vecObjs, &anyvecsave.S{Vector: vec}) 28 | } 29 | 30 | return serializer.SerializeAny(vecObjs...) 
31 | } 32 | 33 | func unmarshalGradient(vars []*anydiff.Var, data []byte) (anydiff.Grad, error) { 34 | if len(data) == 0 { 35 | return nil, nil 36 | } 37 | 38 | var dests []interface{} 39 | for _ = range vars { 40 | dests = append(dests, new(*anyvecsave.S)) 41 | } 42 | if err := serializer.DeserializeAny(data, dests...); err != nil { 43 | return nil, err 44 | } 45 | 46 | res := anydiff.Grad{} 47 | for i, v := range vars { 48 | vec := (*dests[i].(**anyvecsave.S)).Vector 49 | if vec.Len() != v.Vector.Len() { 50 | return nil, errors.New("bad vector length") 51 | } else if vec.Creator() != v.Vector.Creator() { 52 | return nil, errors.New("bad vector creator") 53 | } 54 | res[v] = vec 55 | } 56 | 57 | return res, nil 58 | } 59 | -------------------------------------------------------------------------------- /anysgd/marshal_test.go: -------------------------------------------------------------------------------- 1 | package anysgd 2 | 3 | import ( 4 | "math/rand" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/unixpickle/anydiff" 9 | "github.com/unixpickle/anyvec" 10 | "github.com/unixpickle/anyvec/anyvec64" 11 | ) 12 | 13 | func TestGradientMarshal(t *testing.T) { 14 | c := anyvec64.DefaultCreator{} 15 | vars := randomVars(c) 16 | grad := randomGrad(vars) 17 | 18 | data, err := marshalGradient(vars, grad) 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | 23 | newGrad, err := unmarshalGradient(vars, data) 24 | if err != nil { 25 | t.Fatal(err) 26 | } 27 | 28 | if !reflect.DeepEqual(grad, newGrad) { 29 | t.Error("gradient mismatch") 30 | } 31 | } 32 | 33 | func testMarshal(t *testing.T, inst TransformMarshaler, v []*anydiff.Var) { 34 | var inGrads []anydiff.Grad 35 | var outGrads []anydiff.Grad 36 | var checkpoints [][]byte 37 | 38 | for i := 0; i < 5; i++ { 39 | inGrad := randomGrad(v) 40 | data, err := inst.MarshalBinary() 41 | if err != nil { 42 | t.Fatal(err) 43 | } 44 | outGrad := copyGrad(inst.Transform(copyGrad(inGrad))) 45 | inGrads = append(inGrads, inGrad) 46 
| checkpoints = append(checkpoints, data) 47 | outGrads = append(outGrads, outGrad) 48 | } 49 | 50 | for _, i := range []int{2, 0, 3, 4, 1} { 51 | err := inst.UnmarshalBinary(checkpoints[i]) 52 | if err != nil { 53 | t.Fatal(err) 54 | } 55 | in := inGrads[i] 56 | out := inst.Transform(in) 57 | if !reflect.DeepEqual(out, outGrads[i]) { 58 | t.Errorf("gradient %d came out wrong", i) 59 | } 60 | } 61 | } 62 | 63 | func randomVars(c anyvec.Creator) []*anydiff.Var { 64 | vars := []*anydiff.Var{} 65 | for i := 0; i < 20; i++ { 66 | size := i * rand.Intn(3) 67 | vec := c.MakeVector(size) 68 | anyvec.Rand(vec, anyvec.Normal, nil) 69 | v := anydiff.NewVar(vec) 70 | vars = append(vars, v) 71 | } 72 | return vars 73 | } 74 | 75 | func randomGrad(vars []*anydiff.Var) anydiff.Grad { 76 | resGrad := anydiff.Grad{} 77 | for _, v := range vars { 78 | gradVec := v.Vector.Creator().MakeVector(v.Vector.Len()) 79 | anyvec.Rand(gradVec, anyvec.Normal, nil) 80 | resGrad[v] = gradVec 81 | } 82 | return resGrad 83 | } 84 | -------------------------------------------------------------------------------- /anysgd/momentum.go: -------------------------------------------------------------------------------- 1 | package anysgd 2 | 3 | import "github.com/unixpickle/anydiff" 4 | 5 | // Momentum implements SGD with momentum. 6 | // 7 | // The transformed gradient v is computed as 8 | // 9 | // v := momentum * v + grad 10 | type Momentum struct { 11 | Momentum float64 12 | rolling anydiff.Grad 13 | } 14 | 15 | // Transform transforms the gradient using momentum. 16 | // 17 | // This is not thread-safe. 
18 | func (m *Momentum) Transform(g anydiff.Grad) anydiff.Grad { 19 | if m.rolling == nil { 20 | m.rolling = copyGrad(g) 21 | return g 22 | } 23 | for v, x := range m.rolling { 24 | x.Scale(x.Creator().MakeNumeric(m.Momentum)) 25 | x.Add(g[v]) 26 | g[v].Set(x) 27 | } 28 | return g 29 | } 30 | -------------------------------------------------------------------------------- /anysgd/rmsprop.go: -------------------------------------------------------------------------------- 1 | package anysgd 2 | 3 | import ( 4 | "github.com/unixpickle/anydiff" 5 | "github.com/unixpickle/anyvec" 6 | ) 7 | 8 | const ( 9 | rmspropDefaultDecayRate = 0.9 10 | rmspropDefaultDamping = 1e-8 11 | ) 12 | 13 | // RMSProp implements the RMSProp regularizer; see: 14 | // http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf. 15 | type RMSProp struct { 16 | // The decay rate for the running average. 17 | // If it is 0, a default of 0.9 is used. 18 | DecayRate float64 19 | 20 | // Damping is used to prevent divisions by zero. 21 | // This should be very small. 22 | // If it is 0, a default is used. 23 | Damping float64 24 | 25 | moment anydiff.Grad 26 | } 27 | 28 | // Transform transforms the gradient using RMSProp. 29 | // 30 | // This is not thread-safe. 
31 | func (r *RMSProp) Transform(realGrad anydiff.Grad) anydiff.Grad { 32 | if r.moment == nil { 33 | r.moment = anydiff.Grad{} 34 | for v, grad := range realGrad { 35 | sq := grad.Copy() 36 | anyvec.Pow(sq, sq.Creator().MakeNumeric(2)) 37 | r.moment[v] = sq 38 | } 39 | } else { 40 | for v, grad := range realGrad { 41 | sq := grad.Copy() 42 | anyvec.Pow(sq, sq.Creator().MakeNumeric(2)) 43 | sq.Sub(r.moment[v]) 44 | sq.Scale(sq.Creator().MakeNumeric(1 - r.decayRate())) 45 | r.moment[v].Add(sq) 46 | } 47 | } 48 | for v, grad := range realGrad { 49 | div := r.moment[v].Copy() 50 | div.AddScalar(div.Creator().MakeNumeric(r.damping())) 51 | anyvec.Pow(div, div.Creator().MakeNumeric(-0.5)) 52 | grad.Mul(div) 53 | } 54 | return realGrad 55 | } 56 | 57 | func (r *RMSProp) decayRate() float64 { 58 | if r.DecayRate == 0 { 59 | return rmspropDefaultDecayRate 60 | } else { 61 | return r.DecayRate 62 | } 63 | } 64 | 65 | func (r *RMSProp) damping() float64 { 66 | if r.Damping == 0 { 67 | return rmspropDefaultDamping 68 | } else { 69 | return r.Damping 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /anysgd/util.go: -------------------------------------------------------------------------------- 1 | package anysgd 2 | 3 | import ( 4 | "math" 5 | "math/rand" 6 | 7 | "github.com/unixpickle/anydiff" 8 | "github.com/unixpickle/anyvec" 9 | ) 10 | 11 | // Shuffle shuffles a list of samples. 12 | // If the list implements PostShuffler, then PostShuffle 13 | // is called after the shuffle completes. 14 | func Shuffle(s SampleList) { 15 | for i := 0; i < s.Len(); i++ { 16 | j := i + rand.Intn(s.Len()-i) 17 | s.Swap(i, j) 18 | } 19 | if p, ok := s.(PostShuffler); ok { 20 | p.PostShuffle() 21 | } 22 | } 23 | 24 | // LengthSampleList is a SampleList that carries no 25 | // information other than its length. 26 | // 27 | // This may be used in conjunction with a Fetcher that 28 | // generates data dynamically. 
// An ExpRater is a Rater which returns
//
//     Bias + Coeff*Decay^t
//
// This is a standard kind of exponential decay schedule
// used for SGD.
type ExpRater struct {
	Bias  float64
	Coeff float64
	Decay float64
}

// Rate computes the rate for time t, i.e. the Bias plus the Coeff
// scaled by Decay raised to the power t.
func (e *ExpRater) Rate(t float64) float64 {
	decayed := math.Pow(e.Decay, t)
	return e.Bias + e.Coeff*decayed
}
// valueOrDefault returns def when val is zero and val otherwise.
// It lets a zero-valued configuration field stand for "use the
// default".
func valueOrDefault(val, def float64) float64 {
	if val == 0 {
		return def
	}
	return val
}
}, []float32{ 53 | (1.0 / 3) * (0.3132616875 + 0.6931471806), 54 | (1.0 / 3) * (0.02538560221 + 1.7015424088 + 0.3132616875), 55 | }, 2) 56 | }) 57 | } 58 | 59 | func TestHinge(t *testing.T) { 60 | testCost(t, Hinge{}, []float32{ 61 | 1, -1, -1, 1, 1, 1, -1, -1, 62 | }, []float32{ 63 | 0.5, 1, -2, 0.9, -2, 2, -1.5, -0.9, 64 | }, []float32{ 65 | 2.5, 0.1, 3, 0.1, 66 | }, 4) 67 | } 68 | 69 | func TestMultiHinge(t *testing.T) { 70 | t.Run("WW", func(t *testing.T) { 71 | testCost(t, WestonWatkins, []float32{ 72 | 0, 1, 0, 73 | 0, 0, 1, 74 | 1, 0, 0, 75 | }, []float32{ 76 | 1, 2.5, 2, 77 | -2, -5, -3, 78 | -5, -2, -3, 79 | }, []float32{ 80 | 0.5, 2, 7, 81 | }, 3) 82 | }) 83 | t.Run("CS", func(t *testing.T) { 84 | testCost(t, CrammerSinger, []float32{ 85 | 0, 1, 0, 86 | 0, 0, 1, 87 | }, []float32{ 88 | 1, 2.5, 2, 89 | -2, -5, -3, 90 | }, []float32{ 91 | 0.5, 2, 92 | }, 2) 93 | }) 94 | } 95 | 96 | func TestMultiHingeProp(t *testing.T) { 97 | v1 := anydiff.NewVar(anyvec32.MakeVectorData([]float32{1, 2, 2.5, 3, 3.5, 4, 4.5, 5})) 98 | v2 := anydiff.NewVar(anyvec32.MakeVectorData([]float32{0, 0, 1, 0, 1, 0, 0, 0})) 99 | 100 | t.Run("WW", func(t *testing.T) { 101 | checker := &anydifftest.ResChecker{ 102 | F: func() anydiff.Res { 103 | return WestonWatkins.Cost(v2, v1, 2) 104 | }, 105 | V: []*anydiff.Var{v1}, 106 | } 107 | checker.FullCheck(t) 108 | }) 109 | 110 | t.Run("CS", func(t *testing.T) { 111 | checker := &anydifftest.ResChecker{ 112 | F: func() anydiff.Res { 113 | return CrammerSinger.Cost(v2, v1, 2) 114 | }, 115 | V: []*anydiff.Var{v1}, 116 | } 117 | checker.FullCheck(t) 118 | }) 119 | } 120 | 121 | func testCost(t *testing.T, c Cost, desired, output, expected []float32, n int) { 122 | desiredRes := anydiff.NewConst(anyvec32.MakeVectorData(desired)) 123 | outputRes := anydiff.NewConst(anyvec32.MakeVectorData(output)) 124 | 125 | actual := c.Cost(desiredRes, outputRes, n).Output().Data().([]float32) 126 | 127 | for i, x := range expected { 128 | a := actual[i] 129 | 
if math.IsNaN(float64(a)) || math.Abs(float64(x-a)) > 1e-3 { 130 | t.Errorf("component %d: expected %f but got %f", i, x, a) 131 | } 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /debug.go: -------------------------------------------------------------------------------- 1 | package anynet 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | 7 | "github.com/unixpickle/anydiff" 8 | "github.com/unixpickle/anyvec" 9 | "github.com/unixpickle/serializer" 10 | ) 11 | 12 | func init() { 13 | serializer.RegisterTypedDeserializer((&Debug{}).SerializerType(), DeserializeDebug) 14 | } 15 | 16 | // Debug is a layer which logs statistics about its 17 | // inputs. 18 | // Besides logging, the Debug layer does nothing to 19 | // interfere with the flow of values in a network. 20 | type Debug struct { 21 | // Writer to which stats are printed. 22 | // If nil, os.Stdout is used. 23 | Writer io.Writer 24 | 25 | ID string 26 | PrintRaw bool 27 | PrintMean bool 28 | PrintVariance bool 29 | } 30 | 31 | // DeserializeDebug deserializes a Debug layer. 32 | // The Writer will be nil. 33 | func DeserializeDebug(d []byte) (*Debug, error) { 34 | var res Debug 35 | err := serializer.DeserializeAny(d, &res.ID, &res.PrintRaw, &res.PrintMean, 36 | &res.PrintVariance) 37 | if err != nil { 38 | return nil, err 39 | } 40 | return &res, nil 41 | } 42 | 43 | // Apply logs information about its input. 44 | // The input is returned, untouched. 
45 | func (d *Debug) Apply(in anydiff.Res, n int) anydiff.Res { 46 | if d.PrintRaw { 47 | d.println("batch of", n, "values:", in.Output().Data()) 48 | } 49 | cols := in.Output().Len() / n 50 | if d.PrintMean || d.PrintVariance { 51 | mean := anyvec.SumRows(in.Output(), cols) 52 | normalizer := mean.Creator().MakeNumeric(1 / float64(n)) 53 | mean.Scale(normalizer) 54 | if d.PrintMean { 55 | d.println("mean:", mean.Data()) 56 | } 57 | if d.PrintVariance { 58 | two := mean.Creator().MakeNumeric(2) 59 | squared := in.Output().Copy() 60 | anyvec.Pow(squared, two) 61 | variance := anyvec.SumRows(squared, cols) 62 | variance.Scale(normalizer) 63 | anyvec.Pow(mean, two) 64 | variance.Sub(mean) 65 | d.println("variance:", variance.Data()) 66 | } 67 | } 68 | return in 69 | } 70 | 71 | // SerializerType returns the unique ID used to serialize 72 | // a Debug layer with the serializer package. 73 | func (d *Debug) SerializerType() string { 74 | return "github.com/unixpickle/anynet.Debug" 75 | } 76 | 77 | // Serialize serializes the layer. 78 | func (d *Debug) Serialize() ([]byte, error) { 79 | return serializer.SerializeAny(d.ID, d.PrintRaw, d.PrintMean, d.PrintVariance) 80 | } 81 | 82 | func (d *Debug) println(args ...interface{}) { 83 | newArgs := append([]interface{}{"Debug (" + d.ID + "):"}, args...) 84 | if d.Writer == nil { 85 | fmt.Println(newArgs...) 86 | } else { 87 | fmt.Fprintln(d.Writer, newArgs...) 
88 | } 89 | } 90 | -------------------------------------------------------------------------------- /demo/mnist/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/unixpickle/anydiff" 7 | "github.com/unixpickle/anynet" 8 | "github.com/unixpickle/anynet/anyff" 9 | "github.com/unixpickle/anynet/anysgd" 10 | "github.com/unixpickle/anyvec" 11 | "github.com/unixpickle/anyvec/anyvec32" 12 | "github.com/unixpickle/mnist" 13 | "github.com/unixpickle/rip" 14 | ) 15 | 16 | var Creator anyvec.Creator 17 | 18 | func main() { 19 | log.Println("Setting up...") 20 | 21 | Creator = anyvec32.CurrentCreator() 22 | 23 | network := anynet.Net{ 24 | anynet.NewFC(Creator, 28*28, 300), 25 | anynet.Tanh, 26 | anynet.NewFC(Creator, 300, 10), 27 | anynet.LogSoftmax, 28 | } 29 | 30 | t := &anyff.Trainer{ 31 | Net: network, 32 | Cost: anynet.DotCost{}, 33 | Params: network.Parameters(), 34 | Average: true, 35 | } 36 | 37 | var iterNum int 38 | s := &anysgd.SGD{ 39 | Fetcher: t, 40 | Gradienter: t, 41 | Transformer: &anysgd.Adam{}, 42 | Samples: mnist.LoadTrainingDataSet().AnyNetSamples(Creator), 43 | Rater: anysgd.ConstRater(0.001), 44 | StatusFunc: func(b anysgd.Batch) { 45 | log.Printf("iter %d: cost=%v", iterNum, t.LastCost) 46 | iterNum++ 47 | }, 48 | BatchSize: 100, 49 | } 50 | 51 | log.Println("Press ctrl+c once to stop...") 52 | s.Run(rip.NewRIP().Chan()) 53 | 54 | log.Println("Computing statistics...") 55 | printStats(network) 56 | } 57 | 58 | func printStats(net anynet.Net) { 59 | ts := mnist.LoadTestingDataSet() 60 | cf := func(in []float64) int { 61 | vec := Creator.MakeVectorData(Creator.MakeNumericList(in)) 62 | inRes := anydiff.NewConst(vec) 63 | res := net.Apply(inRes, 1).Output() 64 | return anyvec.MaxIndex(res) 65 | } 66 | log.Println("Validation:", ts.NumCorrect(cf)) 67 | log.Println("Histogram:", ts.CorrectnessHistogram(cf)) 68 | } 69 | 
-------------------------------------------------------------------------------- /dropout.go: -------------------------------------------------------------------------------- 1 | package anynet 2 | 3 | import ( 4 | "github.com/unixpickle/anydiff" 5 | "github.com/unixpickle/anyvec" 6 | "github.com/unixpickle/essentials" 7 | "github.com/unixpickle/serializer" 8 | ) 9 | 10 | func init() { 11 | var d Dropout 12 | serializer.RegisterTypedDeserializer(d.SerializerType(), DeserializeDropout) 13 | } 14 | 15 | // A Dropout layer applies dropout regularization. 16 | // When disabled, a dropout layer scales its input to 17 | // compute the "expected output". 18 | type Dropout struct { 19 | Enabled bool 20 | 21 | // The probability of keeping any given input. 22 | KeepProb float64 23 | } 24 | 25 | // DeserializeDropout deserializes a Dropout. 26 | func DeserializeDropout(d []byte) (*Dropout, error) { 27 | var enabled serializer.Int 28 | var keepProb serializer.Float64 29 | if err := serializer.DeserializeAny(d, &enabled, &keepProb); err != nil { 30 | return nil, essentials.AddCtx("deserialize Dropout", err) 31 | } 32 | return &Dropout{ 33 | Enabled: enabled == 1, 34 | KeepProb: float64(keepProb), 35 | }, nil 36 | } 37 | 38 | // Apply applies the layer. 39 | func (d *Dropout) Apply(in anydiff.Res, n int) anydiff.Res { 40 | c := in.Output().Creator() 41 | if !d.Enabled { 42 | return anydiff.Scale(in, c.MakeNumeric(d.KeepProb)) 43 | } 44 | mask := c.MakeVector(in.Output().Len()) 45 | anyvec.Rand(mask, anyvec.Uniform, nil) 46 | anyvec.LessThan(mask, c.MakeNumeric(d.KeepProb)) 47 | return anydiff.Mul(in, anydiff.NewConst(mask)) 48 | } 49 | 50 | // SerializerType returns the unique ID used to serialize 51 | // a Dropout with the serializer package. 52 | func (d *Dropout) SerializerType() string { 53 | return "github.com/unixpickle/anynet.Dropout" 54 | } 55 | 56 | // Serialize serializes the Dropout. 
57 | func (d *Dropout) Serialize() ([]byte, error) { 58 | enabledFlag := serializer.Int(0) 59 | if d.Enabled { 60 | enabledFlag = 1 61 | } 62 | return serializer.SerializeAny(enabledFlag, serializer.Float64(d.KeepProb)) 63 | } 64 | -------------------------------------------------------------------------------- /fc.go: -------------------------------------------------------------------------------- 1 | package anynet 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "math" 7 | 8 | "github.com/unixpickle/anydiff" 9 | "github.com/unixpickle/anyvec" 10 | "github.com/unixpickle/anyvec/anyvecsave" 11 | "github.com/unixpickle/essentials" 12 | "github.com/unixpickle/serializer" 13 | ) 14 | 15 | func init() { 16 | var f FC 17 | serializer.RegisterTypedDeserializer(f.SerializerType(), DeserializeFC) 18 | } 19 | 20 | // FC is a fully-connected layer. 21 | type FC struct { 22 | InCount int 23 | OutCount int 24 | Weights *anydiff.Var 25 | Biases *anydiff.Var 26 | } 27 | 28 | // DeserializeFC attempts to deserialize an FC. 29 | func DeserializeFC(d []byte) (*FC, error) { 30 | var weights, biases *anyvecsave.S 31 | if err := serializer.DeserializeAny(d, &weights, &biases); err != nil { 32 | return nil, essentials.AddCtx("deserialize FC", err) 33 | } 34 | inCount := weights.Vector.Len() / biases.Vector.Len() 35 | outCount := biases.Vector.Len() 36 | if inCount*outCount != weights.Vector.Len() { 37 | return nil, errors.New("deserialize FC: invalid matrix dimensions") 38 | } 39 | return &FC{ 40 | InCount: inCount, 41 | OutCount: outCount, 42 | Weights: anydiff.NewVar(weights.Vector), 43 | Biases: anydiff.NewVar(biases.Vector), 44 | }, nil 45 | } 46 | 47 | // NewFC creates a new, randomized FC. 48 | // The randomization scheme targets an output variance of 49 | // 1, given that the input variance is 1. 
50 | func NewFC(c anyvec.Creator, in, out int) *FC { 51 | res := NewFCZero(c, in, out) 52 | anyvec.Rand(res.Weights.Vector, anyvec.Normal, nil) 53 | res.Weights.Vector.Scale(c.MakeNumeric(1 / math.Sqrt(float64(in)))) 54 | return res 55 | } 56 | 57 | // NewFCZero creates a new, zero'd out FC. 58 | func NewFCZero(c anyvec.Creator, in, out int) *FC { 59 | return &FC{ 60 | InCount: in, 61 | OutCount: out, 62 | Weights: anydiff.NewVar(c.MakeVector(in * out)), 63 | Biases: anydiff.NewVar(c.MakeVector(out)), 64 | } 65 | } 66 | 67 | // Apply applies the fully-connected layer to a batch of 68 | // inputs. 69 | func (f *FC) Apply(in anydiff.Res, batch int) anydiff.Res { 70 | if batch*f.InCount != in.Output().Len() { 71 | panic(fmt.Sprintf("input length should be %d, but got %d", 72 | batch*f.InCount, in.Output().Len())) 73 | } 74 | weightMat := &anydiff.Matrix{ 75 | Data: f.Weights, 76 | Rows: f.OutCount, 77 | Cols: f.InCount, 78 | } 79 | inMat := &anydiff.Matrix{ 80 | Data: in, 81 | Rows: batch, 82 | Cols: f.InCount, 83 | } 84 | weighted := anydiff.MatMul(false, true, inMat, weightMat) 85 | return anydiff.AddRepeated(weighted.Data, f.Biases) 86 | } 87 | 88 | // AddBias adds a scaler to the biases. 89 | // It returns f for convenience. 90 | func (f *FC) AddBias(val anyvec.Numeric) *FC { 91 | f.Biases.Vector.AddScalar(val) 92 | return f 93 | } 94 | 95 | // Parameters returns a slice containing the weights 96 | // and the biases, in that order. 97 | func (f *FC) Parameters() []*anydiff.Var { 98 | return []*anydiff.Var{f.Weights, f.Biases} 99 | } 100 | 101 | // SerializerType returns the unique ID used to serialize 102 | // an FC with the serializer package. 103 | func (f *FC) SerializerType() string { 104 | return "github.com/unixpickle/anynet.FC" 105 | } 106 | 107 | // Serialize serializes the FC. 
108 | func (f *FC) Serialize() ([]byte, error) { 109 | weights := &anyvecsave.S{Vector: f.Weights.Vector} 110 | biases := &anyvecsave.S{Vector: f.Biases.Vector} 111 | return serializer.SerializeAny(weights, biases) 112 | } 113 | -------------------------------------------------------------------------------- /mixer.go: -------------------------------------------------------------------------------- 1 | package anynet 2 | 3 | import ( 4 | "github.com/unixpickle/anydiff" 5 | "github.com/unixpickle/essentials" 6 | "github.com/unixpickle/serializer" 7 | ) 8 | 9 | func init() { 10 | var a AddMixer 11 | serializer.RegisterTypedDeserializer(a.SerializerType(), DeserializeAddMixer) 12 | var c ConcatMixer 13 | serializer.RegisterTypedDeserializer(c.SerializerType(), DeserializeConcatMixer) 14 | } 15 | 16 | // A Mixer combines batches of inputs from two different 17 | // sources into a single vector. 18 | type Mixer interface { 19 | Mix(in1, in2 anydiff.Res, batch int) anydiff.Res 20 | } 21 | 22 | // An AddMixer combines two inputs by applying layers to 23 | // each of them, adding the results together, and then 24 | // applying an output layer to the sum. 25 | type AddMixer struct { 26 | In1 Layer 27 | In2 Layer 28 | Out Layer 29 | } 30 | 31 | // DeserializeAddMixer deserializes an AddMixer. 32 | func DeserializeAddMixer(d []byte) (*AddMixer, error) { 33 | var res AddMixer 34 | if err := serializer.DeserializeAny(d, &res.In1, &res.In2, &res.Out); err != nil { 35 | return nil, essentials.AddCtx("deserialize AddMixer", err) 36 | } 37 | return &res, nil 38 | } 39 | 40 | // Mix applies a.In1 to in1 and a.In2 to in2, then adds 41 | // the results, then applies a.Out. 42 | func (a *AddMixer) Mix(in1, in2 anydiff.Res, batch int) anydiff.Res { 43 | return a.Out.Apply(anydiff.Add( 44 | a.In1.Apply(in1, batch), 45 | a.In2.Apply(in2, batch), 46 | ), batch) 47 | } 48 | 49 | // Parameters gets the parameters of all the layers that 50 | // implement Parameterizer. 
51 | func (a *AddMixer) Parameters() []*anydiff.Var { 52 | return AllParameters(a.In1, a.In2, a.Out) 53 | } 54 | 55 | // SerializerType returns the unique ID used to serialize 56 | // an AddMixer with the serializer package. 57 | func (a *AddMixer) SerializerType() string { 58 | return "github.com/unixpickle/anynet.AddMixer" 59 | } 60 | 61 | // Serialize attempts to serialize the AddMixer. 62 | func (a *AddMixer) Serialize() ([]byte, error) { 63 | return serializer.SerializeAny(a.In1, a.In2, a.Out) 64 | } 65 | 66 | // A ConcatMixer mixes inputs by concatenating inputs. 67 | type ConcatMixer struct{} 68 | 69 | // DeserializeConcatMixer deserializes a ConcatMixer. 70 | func DeserializeConcatMixer(d []byte) (ConcatMixer, error) { 71 | return ConcatMixer{}, nil 72 | } 73 | 74 | // Mix produces a vector of concatenated vectors, like 75 | // [in1[0], in2[0], in1[1], in2[1], ...], where in1[n] 76 | // represents the n-th vector in the batch represented 77 | // by in1. 78 | func (c ConcatMixer) Mix(in1, in2 anydiff.Res, batch int) anydiff.Res { 79 | return anydiff.Pool(in1, func(in1 anydiff.Res) anydiff.Res { 80 | return anydiff.Pool(in2, func(in2 anydiff.Res) anydiff.Res { 81 | var res []anydiff.Res 82 | v1Len := in1.Output().Len() / batch 83 | v2Len := in2.Output().Len() / batch 84 | for i := 0; i < batch; i++ { 85 | res = append(res, anydiff.Slice(in1, i*v1Len, (i+1)*v1Len), 86 | anydiff.Slice(in2, i*v2Len, (i+1)*v2Len)) 87 | } 88 | return anydiff.Concat(res...) 89 | }) 90 | }) 91 | } 92 | 93 | // SerializerType returns the unique ID used to serialize 94 | // a ConcatMixer with the serializer package. 95 | func (c ConcatMixer) SerializerType() string { 96 | return "github.com/unixpickle/anynet.ConcatMixer" 97 | } 98 | 99 | // Serialize serializes the instance. 
100 | func (c ConcatMixer) Serialize() ([]byte, error) { 101 | return []byte{}, nil 102 | } 103 | -------------------------------------------------------------------------------- /param_hider.go: -------------------------------------------------------------------------------- 1 | package anynet 2 | 3 | import ( 4 | "github.com/unixpickle/anydiff" 5 | "github.com/unixpickle/anydiff/anyfwd" 6 | "github.com/unixpickle/essentials" 7 | "github.com/unixpickle/serializer" 8 | ) 9 | 10 | func init() { 11 | var p ParamHider 12 | serializer.RegisterTypedDeserializer(p.SerializerType(), DeserializeParamHider) 13 | } 14 | 15 | // A ParamHider wraps a Layer and does not implement 16 | // Parameterizer, thus effectively freezing the parameters 17 | // of the layer. 18 | type ParamHider struct { 19 | Layer Layer 20 | } 21 | 22 | // DeserializeParamHider deserializes a ParamHider. 23 | func DeserializeParamHider(d []byte) (*ParamHider, error) { 24 | var p ParamHider 25 | if err := serializer.DeserializeAny(d, &p.Layer); err != nil { 26 | return nil, essentials.AddCtx("deserialize ParamHider", err) 27 | } 28 | return &p, nil 29 | } 30 | 31 | // Apply applies the layer. 32 | func (p *ParamHider) Apply(in anydiff.Res, n int) anydiff.Res { 33 | return p.Layer.Apply(in, n) 34 | } 35 | 36 | // SerializerType returns the unique ID used to serialize 37 | // a ParamHider with the serializer package. 38 | func (p *ParamHider) SerializerType() string { 39 | return "github.com/unixpickle/anynet.ParamHider" 40 | } 41 | 42 | // Serialize serializes the ParamHider. 43 | func (p *ParamHider) Serialize() ([]byte, error) { 44 | return serializer.SerializeAny(p.Layer) 45 | } 46 | 47 | // MakeFwd converts the contained layer to use forward 48 | // automatic differentiation. 49 | // 50 | // This works so long as anyfwd.MakeFwd works on the 51 | // contained layer. 
52 | func (p *ParamHider) MakeFwd(c *anyfwd.Creator) { 53 | anyfwd.MakeFwd(c, p.Layer) 54 | } 55 | -------------------------------------------------------------------------------- /serializer_test.go: -------------------------------------------------------------------------------- 1 | package anynet 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/unixpickle/anydiff" 8 | "github.com/unixpickle/anyvec/anyvec32" 9 | "github.com/unixpickle/serializer" 10 | ) 11 | 12 | func TestActivationSerialize(t *testing.T) { 13 | a1 := Tanh 14 | a2 := LogSoftmax 15 | a3 := Sigmoid 16 | a4 := ReLU 17 | a5 := Sin 18 | a6 := Exp 19 | data, err := serializer.SerializeAny(a1, a2, a3, a4, a5, a6) 20 | if err != nil { 21 | t.Fatal(err) 22 | } 23 | var newA1, newA2, newA3, newA4, newA5, newA6 Activation 24 | err = serializer.DeserializeAny(data, &newA1, &newA2, &newA3, &newA4, &newA5, &newA6) 25 | if err != nil { 26 | t.Fatal(err) 27 | } 28 | if newA1 != a1 { 29 | t.Error("Tanh failed") 30 | } 31 | if newA2 != a2 { 32 | t.Error("LogSoftmax failed") 33 | } 34 | if newA3 != a3 { 35 | t.Error("Sigmoid failed") 36 | } 37 | if newA4 != a4 { 38 | t.Error("ReLU failed") 39 | } 40 | if newA5 != a5 { 41 | t.Error("Sin failed") 42 | } 43 | if newA6 != a6 { 44 | t.Error("Exp failed") 45 | } 46 | } 47 | 48 | func TestFCSerialize(t *testing.T) { 49 | fc := NewFC(anyvec32.DefaultCreator{}, 7, 5) 50 | data, err := serializer.SerializeAny(fc) 51 | if err != nil { 52 | t.Fatal(err) 53 | } 54 | var newFC *FC 55 | if err := serializer.DeserializeAny(data, &newFC); err != nil { 56 | t.Fatal(err) 57 | } 58 | if !reflect.DeepEqual(fc, newFC) { 59 | t.Fatal("incorrect result") 60 | } 61 | } 62 | 63 | func TestAffineSerialize(t *testing.T) { 64 | affine := &Affine{ 65 | Scalers: anydiff.NewVar(anyvec32.MakeVectorData([]float32{1, 2, -1})), 66 | Biases: anydiff.NewVar(anyvec32.MakeVectorData([]float32{-3, 1})), 67 | } 68 | data, err := serializer.SerializeAny(affine) 69 | if err != nil { 70 | 
t.Fatal(err) 71 | } 72 | var newAffine *Affine 73 | if err := serializer.DeserializeAny(data, &newAffine); err != nil { 74 | t.Fatal(err) 75 | } 76 | if !reflect.DeepEqual(affine, newAffine) { 77 | t.Fatal("incorrect result") 78 | } 79 | } 80 | 81 | func TestConstAffineSerialize(t *testing.T) { 82 | affine := &ConstAffine{ 83 | Scale: 2, 84 | Bias: 3.14, 85 | } 86 | data, err := serializer.SerializeAny(affine) 87 | if err != nil { 88 | t.Fatal(err) 89 | } 90 | var newAffine *ConstAffine 91 | if err := serializer.DeserializeAny(data, &newAffine); err != nil { 92 | t.Fatal(err) 93 | } 94 | if !reflect.DeepEqual(affine, newAffine) { 95 | t.Fatal("incorrect result") 96 | } 97 | } 98 | 99 | func TestDropoutSerialize(t *testing.T) { 100 | do := &Dropout{Enabled: true, KeepProb: 0.335} 101 | data, err := serializer.SerializeAny(do) 102 | if err != nil { 103 | t.Fatal(err) 104 | } 105 | var do1 *Dropout 106 | if err := serializer.DeserializeAny(data, &do1); err != nil { 107 | t.Fatal(err) 108 | } 109 | if !reflect.DeepEqual(do, do1) { 110 | t.Fatal("incorrect result") 111 | } 112 | } 113 | 114 | func TestNetSerialize(t *testing.T) { 115 | net := Net{Tanh, LogSoftmax} 116 | data, err := serializer.SerializeAny(net) 117 | if err != nil { 118 | t.Fatal(err) 119 | } 120 | var net1 Net 121 | if err := serializer.DeserializeAny(data, &net1); err != nil { 122 | t.Fatal(err) 123 | } 124 | if !reflect.DeepEqual(net, net1) { 125 | t.Fatal("networks not equal") 126 | } 127 | } 128 | 129 | func TestAddMixerSerializer(t *testing.T) { 130 | c := anyvec32.DefaultCreator{} 131 | a := &AddMixer{ 132 | In1: NewFC(c, 5, 3), 133 | In2: NewFC(c, 2, 3), 134 | Out: NewFC(c, 3, 1), 135 | } 136 | data, err := serializer.SerializeAny(a) 137 | if err != nil { 138 | t.Fatal(err) 139 | } 140 | var a1 *AddMixer 141 | if err := serializer.DeserializeAny(data, &a1); err != nil { 142 | t.Fatal(err) 143 | } 144 | if !reflect.DeepEqual(a1, a) { 145 | t.Error("incorrect result") 146 | } 147 | } 148 | 149 | 
func TestConcatMixerSerializer(t *testing.T) { 150 | mixer := ConcatMixer{} 151 | data, err := serializer.SerializeAny(mixer) 152 | if err != nil { 153 | t.Fatal(err) 154 | } 155 | var mixer1 ConcatMixer 156 | if err := serializer.DeserializeAny(data, &mixer1); err != nil { 157 | t.Fatal(err) 158 | } 159 | if !reflect.DeepEqual(mixer1, mixer) { 160 | t.Error("incorrect result") 161 | } 162 | } 163 | 164 | func TestParamHiderSerialize(t *testing.T) { 165 | net := &ParamHider{Layer: Tanh} 166 | data, err := serializer.SerializeAny(net) 167 | if err != nil { 168 | t.Fatal(err) 169 | } 170 | var net1 *ParamHider 171 | if err := serializer.DeserializeAny(data, &net1); err != nil { 172 | t.Fatal(err) 173 | } 174 | if !reflect.DeepEqual(net, net1) { 175 | t.Fatal("incorrect result") 176 | } 177 | } 178 | 179 | func TestDebugSerialize(t *testing.T) { 180 | net := &Debug{ID: "hey", PrintRaw: true, PrintVariance: true} 181 | data, err := serializer.SerializeAny(net) 182 | if err != nil { 183 | t.Fatal(err) 184 | } 185 | var net1 *Debug 186 | if err := serializer.DeserializeAny(data, &net1); err != nil { 187 | t.Fatal(err) 188 | } 189 | if !reflect.DeepEqual(net, net1) { 190 | t.Fatal("incorrect result") 191 | } 192 | } 193 | --------------------------------------------------------------------------------