├── experiments
│   ├── char_rnn
│   │   ├── graphs
│   │   │   ├── max_weight.png
│   │   │   └── mean_weight.png
│   │   └── README.md
│   └── sentiment
│       ├── graphs
│       │   ├── hybrid_overfit.png
│       │   └── validation_graph.png
│       ├── cuda.go
│       ├── README.md
│       └── main.go
├── test_gen
│   └── output_test.m
├── README.md
├── rwa_test.go
└── rwa.go
--------------------------------------------------------------------------------
/experiments/char_rnn/graphs/max_weight.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/unixpickle/rwa/HEAD/experiments/char_rnn/graphs/max_weight.png
--------------------------------------------------------------------------------
/experiments/char_rnn/graphs/mean_weight.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/unixpickle/rwa/HEAD/experiments/char_rnn/graphs/mean_weight.png
--------------------------------------------------------------------------------
/experiments/sentiment/graphs/hybrid_overfit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/unixpickle/rwa/HEAD/experiments/sentiment/graphs/hybrid_overfit.png
--------------------------------------------------------------------------------
/experiments/sentiment/graphs/validation_graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/unixpickle/rwa/HEAD/experiments/sentiment/graphs/validation_graph.png
--------------------------------------------------------------------------------
/experiments/sentiment/cuda.go:
--------------------------------------------------------------------------------
// +build cuda

package main

import (
	"github.com/unixpickle/anyvec/anyvec32"
	"github.com/unixpickle/cudavec"
)

// init routes anyvec32 through the GPU-backed cudavec creator when the
// program is built with the "cuda" build tag.
func init() {
	handle, err := cudavec.NewHandleDefault()
	if err != nil {
		panic(err)
	}
	anyvec32.Use(&cudavec.Creator32{Handle: handle})
}
--------------------------------------------------------------------------------
/test_gen/output_test.m:
--------------------------------------------------------------------------------
initState = [0.074941; -1.132446]; % unsquashed initial hidden state

W_u = [0.63997 -0.21826 0.88730; 0.25009 0.11063 0.72248];
b_u = [0.051791; 0.479197];

W_g_in = [0.0099470 0.8399450 -0.2081483; 0.9820264 0.3257544 0.1064337];
b_g_in = [-0.515952; 0.055721];
W_g_h = [0.97504 0.35937; -0.37616 0.69398];
b_g_h = [0.60679; 0.44104];

W_a_in = [0.90352 0.25258 0.76472; 0.73948 0.98564 0.20552];
b_a_in = [-0.50516; 0.73835];
W_a_h = [0.26337 0.28690; -0.29784 -0.79788];
b_a_h = [-0.017014; 0.803208];

ins = [0.29989 0.37573 0.45905; 0.36990 0.29873 0.14858; 0.50296 0.22233 0.78369];

num = [0; 0];   % running numerator: sum of weights .* u .* g
denom = [0; 0]; % running denominator: sum of weights
h = tanh(initState);
for i=1:3
  input = ins(:,i);
  weights = exp(W_a_in*input + b_a_in + W_a_h*h + b_a_h); % e^(a(x, h))
  uOut = W_u*input + b_u;                                 % u(x)
  gOut = tanh(W_g_in*input + b_g_in + W_g_h*h + b_g_h);   % tanh(g(x, h))
  num += weights .* uOut .* gOut;  % += is Octave syntax, not plain MATLAB
  denom += weights;
  h = tanh(num ./ denom) % no semicolon: prints the outputs used in rwa_test.go
end
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Recurrent Weighted Average

This is a re-implementation of the architecture described in [Machine Learning on Sequential Data Using a Recurrent Weighted Average](https://arxiv.org/abs/1703.01253).
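
For quick reference, here is the recurrence being implemented, in my own notation (chosen to match `test_gen/output_test.m`): `u` is a learned affine map of the input, `g` and `a` mix the input with the previous hidden state, and the initial hidden state is a learned parameter.

```
z_t = u(x_t) \odot \tanh(g(x_t, h_{t-1}))
h_t = \tanh\left( \frac{\sum_{i \le t} e^{a(x_i, h_{i-1})} \odot z_i}
                       {\sum_{i \le t} e^{a(x_i, h_{i-1})}} \right)
```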

# Hypotheses

As the sequence gets longer and longer, the running average could become more and more "saturated" (i.e. new time-steps matter less and less). This might cause the network to have more and more trouble forming short-term memories as the sequence goes on. As a result, the network might do poorly at precise tasks like text character prediction.

If the above concern is actually an issue, perhaps the long-term benefits of RWAs could still be leveraged by stacking an RWA on top of an LSTM.

# Results

Here are the experiments I have run:

* [char-rnn](experiments/char_rnn) - RWAs can learn to model language character-by-character, although LSTMs are faster and better.
* [sentiment](experiments/sentiment) - A hybrid LSTM-RWA model learns to predict the sentiment of tweets faster than a plain LSTM.
--------------------------------------------------------------------------------
/experiments/sentiment/README.md:
--------------------------------------------------------------------------------
# Experiment: Sentiment Analysis

This tests the RWA on a pretty routine task: Twitter sentiment analysis. The model is fed a tweet byte-by-byte. The network's last output is used to classify the tweet as "positive" or "negative".

# Data

This uses the data from [Sentiment140](http://help.sentiment140.com/for-students/).

# Models

In this experiment, I compare three different character-level models. Each model reads a tweet as a sequence of bytes, then produces a single classification at the last time-step.

Here are the three models:

* Pure LSTM: a two-layer LSTM model, with 384 cells per layer.
* Pure RWA: a two-layer RWA model, with 384 cells per layer.
* Hybrid: a two-layer model where the first layer is an LSTM and the second is an RWA.

The hybrid model was inspired by the idea that the LSTM could deal with short-term dependencies (e.g. the spellings of words) while the RWA could deal with long-term dependencies. Essentially, the LSTM processes the raw character inputs and feeds higher-level data for the RWA to process on a longer timescale.

# Results

I trained the models for 1.5 epochs on this task. I found that the hybrid model learns significantly faster than the LSTM, especially at the beginning of training. Here is a graph comparing all three models, with the number of samples along the x-axis and validation cross-entropy on the y-axis:

![validation graph](graphs/validation_graph.png)

The pure LSTM achieved an average validation cross-entropy of 0.39 during training. At the end of training, the validation accuracy was 86.1%. Note that the original Sentiment140 paper only managed to achieve 83.0% accuracy, and that required a ton of feature engineering.

The hybrid LSTM-RWA achieved an average validation cross-entropy of 0.37 during training, but only 83.8% validation accuracy. This can likely be explained by sample noise near the end of training. If the learning rate were annealed during training, we would likely find that the LSTM and hybrid models had similar validation accuracies at the end. It is clear from the average cross-entropy (which, as a moving average, is more reliable than accuracy) that the hybrid model was doing slightly better than the pure LSTM.

The pure RWA achieved an accuracy of only 79.4% at the end of training, not to mention its much worse cross-entropy.
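
For reference, here is roughly how the hybrid stack from the list above is assembled (a simplified sketch of the `blockMaker` logic in [main.go](main.go); `newHybrid` is a hypothetical helper name, not a function in this repository):

```go
package main

import (
	"github.com/unixpickle/anynet/anyrnn"
	"github.com/unixpickle/anyvec/anyvec32"
	"github.com/unixpickle/rwa"
)

// newHybrid builds the hybrid stack the way main.go does: an LSTM over
// the raw one-hot byte inputs (input weights scaled up by 16), feeding
// an RWA on the second layer. Swap the layer types for the pure models.
func newHybrid(hidden int) anyrnn.Block {
	c := anyvec32.CurrentCreator()
	return anyrnn.Stack{
		anyrnn.NewLSTM(c, 0x100, hidden).ScaleInWeights(float32(16)),
		rwa.NewRWA(c, hidden, hidden).ScaleInWeights(float32(1)),
	}
}

func main() {
	_ = newHybrid(384) // 384 cells per layer, as in the experiments
}
```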

As an extra experiment, I trained the hybrid model for a few more epochs with a smaller learning rate. It overfit:

![overfitting graph](graphs/hybrid_overfit.png)
--------------------------------------------------------------------------------
/experiments/char_rnn/README.md:
--------------------------------------------------------------------------------
# Experiment: char-rnn

I created a [char-rnn branch](https://github.com/unixpickle/char-rnn/tree/rwa) that uses RWA. The results exceeded my expectations.

In this experiment, a two-layer RWA with 512 hidden units per layer is trained to predict the next character in a string. In particular, strings are random 100-byte sequences taken from [Android app descriptions](https://github.com/unixpickle/appdescs). These strings may be taken from *anywhere* within a description: mid-sentence, mid-HTML-tag, etc.

Quantitatively, the network achieves a cross-entropy loss of around 1.40 nats per byte (with virtually no overfitting). I trained the network for 8.9 epochs (about 80K batches of 32 samples each). This took ~10 hours on a Titan X GPU. An equivalent LSTM (with 2 layers and 512 cells each) achieves the same cross-entropy in about 6K batches (7.5% of the number of training steps). After 16K batches, the same LSTM achieves a cross-entropy closer to 1.28 nats, at which point it seems to plateau.

Qualitatively, we can look at some strings generated by the trained RWA model:

```
id that you. The tourist now! If you is not eurislicy of kids7 rin cluide free!
Complete whil
```

```
ifmills for your own phone Backbook Do Ore can egen Process.<
p>
```

```
S Provide you will be sticker applications that you have to your account!> Animated Dis
```

```
sight to get the Samsung Galaxy AS Kids cash chat news undeessage with up tok tasks!
The most
```

--------------------------------------------------------------------------------
/rwa_test.go:
--------------------------------------------------------------------------------
package rwa

import (
	"math"
	"testing"

	"github.com/unixpickle/anydiff"
	"github.com/unixpickle/anydiff/anydifftest"
	"github.com/unixpickle/anydiff/anyseq"
	"github.com/unixpickle/anynet/anyrnn"
	"github.com/unixpickle/anyvec"
	"github.com/unixpickle/anyvec/anyvec32"
)

func TestRWAOutput(t *testing.T) {
	block := NewRWA(anyvec32.CurrentCreator(), 3, 2)
	paramVals := [][]float32{
		{0.074941, -1.132446},
		{0.63997, -0.21826, 0.88730, 0.25009, 0.11063, 0.72248},
		{0.051791, 0.479197},
		{0.0099470, 0.8399450, -0.2081483, 0.9820264, 0.3257544, 0.1064337},
		{-0.515952, 0.055721},
		{0.97504, 0.35937, -0.37616, 0.69398},
		{0.60679, 0.44104},
		{0.90352, 0.25258, 0.76472, 0.73948, 0.98564, 0.20552},
		{-0.50516, 0.73835},
		{0.26337, 0.28690, -0.29784, -0.79788},
		{-0.017014, 0.803208},
	}
	for i, x := range block.Parameters() {
		x.Vector.SetData(paramVals[i])
	}

	seq := anyseq.ConstSeqList(anyvec32.CurrentCreator(), [][]anyvec.Vector{
		{
			anyvec32.MakeVectorData([]float32{0.29989, 0.36990, 0.50296}),
			anyvec32.MakeVectorData([]float32{0.37573, 0.29873, 0.22233}),
			anyvec32.MakeVectorData([]float32{0.45905, 0.14858, 0.78369}),
		},
	})
	out := anyrnn.Map(seq, block).Output()
	expected := [][]float32{
		{0.049205, 0.329647},
		{0.12153, 0.39991},
		{0.20841, 0.49782},
	}
	for i, aVec := range out {
		a := aVec.Packed.Data().([]float32)
		x := expected[i]
		for j := range a {
			if math.IsNaN(float64(a[j])) {
				t.Errorf("step %d component %d: got NaN", i, j)
			} else if math.Abs(float64(x[j]-a[j])) > 1e-3 {
				t.Errorf("step %d component %d: expected %v but got %v",
					i, j, x[j], a[j])
			}
		}
	}
}

func TestRWAGradients(t *testing.T) {
	inSeq, inVars := randomTestSequence(3)
	block := NewRWA(anyvec32.CurrentCreator(), 3, 2)
	if len(block.Parameters()) != 11 {
		t.Errorf("expected 11 parameters, but got %d", len(block.Parameters()))
	}
	checker := &anydifftest.SeqChecker{
		F: func() anyseq.Seq {
			return anyrnn.Map(inSeq, block)
		},
		V: append(inVars, block.Parameters()...),
	}
	checker.FullCheck(t)
}

// randomTestSequence is borrowed from
// https://github.com/unixpickle/anynet/blob/6a8bd570b702861f3c1260a6916723beea6bf296/anyrnn/layer_test.go#L34
func randomTestSequence(inSize int) (anyseq.Seq, []*anydiff.Var) {
	inVars := []*anydiff.Var{}
	inBatches := []*anyseq.ResBatch{}

	presents := [][]bool{{true, true, true}, {true, false, true}}
	numPres := []int{3, 2}
	chunkLengths := []int{2, 3}

	for chunkIdx, pres := range presents {
		for i := 0; i < chunkLengths[chunkIdx]; i++ {
			vec := anyvec32.MakeVector(inSize * numPres[chunkIdx])
			anyvec.Rand(vec, anyvec.Normal, nil)
			v := anydiff.NewVar(vec)
			batch := &anyseq.ResBatch{
				Packed:  v,
				Present: pres,
			}
			inVars = append(inVars, v)
			inBatches = append(inBatches, batch)
		}
	}
	return anyseq.ResSeq(anyvec32.CurrentCreator(), inBatches), inVars
}
--------------------------------------------------------------------------------
/experiments/sentiment/main.go:
--------------------------------------------------------------------------------
package main

import (
	"encoding/csv"
	"flag"
	"log"
	"math/rand"
	"os"
	"sync"
	"time"

	"github.com/unixpickle/anydiff"
	"github.com/unixpickle/anydiff/anyseq"
	"github.com/unixpickle/anynet"
	"github.com/unixpickle/anynet/anyrnn"
	"github.com/unixpickle/anynet/anys2s"
	"github.com/unixpickle/anynet/anysgd"
	"github.com/unixpickle/anyvec"
	"github.com/unixpickle/anyvec/anyvec32"
	"github.com/unixpickle/essentials"
	"github.com/unixpickle/rip"
	"github.com/unixpickle/rwa"
	"github.com/unixpickle/serializer"
)

func main() {
	serializer.RegisterTypedDeserializer((&Model{}).SerializerType(), DeserializeModel)
	rand.Seed(time.Now().UnixNano())

	var modelPath string
	var trainingPath string
	var testingPath string
	var batchSize int
	var stepSize float64
	var hidden int
	var lstm bool
	var hybrid bool
	flag.StringVar(&modelPath, "out", "out_net", "output model file")
	flag.StringVar(&trainingPath, "training", "", "training data file")
	flag.StringVar(&testingPath, "testing", "", "testing data file")
	flag.IntVar(&batchSize, "batch", 16, "SGD batch size")
	flag.Float64Var(&stepSize, "step", 0.001, "SGD step size")
	flag.IntVar(&hidden, "hidden", 384, "number of hidden units")
	flag.BoolVar(&lstm, "lstm", false, "use LSTM instead of RWA")
	flag.BoolVar(&hybrid, "hybrid", false, "use hybrid LSTM-RWA model")
	flag.Parse()

	if trainingPath == "" || testingPath == "" {
		essentials.Die("Required flags: -testing and -training. See -help.")
	}

	c := anyvec32.CurrentCreator()

	var model *Model
	if err := serializer.LoadAny(modelPath, &model); err != nil {
		log.Println("Creating new model...")
		blockMaker := func(in, out int, inScale float32) anyrnn.Block {
			if lstm || (hybrid && inScale > 1) {
				return anyrnn.NewLSTM(c, in, out).ScaleInWeights(inScale)
			} else {
				return rwa.NewRWA(c, in, out).ScaleInWeights(inScale)
			}
		}
		model = &Model{
			Block: anyrnn.Stack{
				blockMaker(0x100, hidden, 16),
				blockMaker(hidden, hidden, 1),
			},
			Out: anynet.Net{
				anynet.NewFC(c, hidden, 1),
			},
		}
	} else {
		log.Println("Loaded existing model.")
	}

	log.Println("Loading samples...")
	training, err := ReadSampleList(trainingPath)
	if err != nil {
		essentials.Die("Load training data:", err)
	}
	validation, err := ReadSampleList(testingPath)
	if err != nil {
		essentials.Die("Load testing data:", err)
	}

	log.Println("Training (ctrl+c to end)...")
	trainer := &anys2s.Trainer{
		Func:    model.Apply,
		Cost:    anynet.SigmoidCE{},
		Params:  model.Parameters(),
		Average: true,
	}
	var iter int
	sgd := &anysgd.SGD{
		Fetcher:     trainer,
		Gradienter:  trainer,
		Transformer: &anysgd.Adam{},
		Samples:     training,
		Rater:       anysgd.ConstRater(stepSize),
		BatchSize:   batchSize,
		StatusFunc: func(b anysgd.Batch) {
			if iter%4 == 0 {
				anysgd.Shuffle(validation)
				bs := essentials.MinInt(batchSize, validation.Len())
				batch, _ := trainer.Fetch(validation.Slice(0, bs))
				cost := anyvec.Sum(trainer.TotalCost(batch.(*anys2s.Batch)).Output())
				log.Printf("iter %d: cost=%v validation=%v", iter, trainer.LastCost, cost)
			} else {
				log.Printf("iter %d: cost=%v", iter, trainer.LastCost)
			}
			iter++
		},
	}
	sgd.Run(rip.NewRIP().Chan())

	log.Println("Saving model...")
	if err := serializer.SaveAny(modelPath, model); err != nil {
		essentials.Die("Save model:", err)
	}

	log.Println("Computing validation accuracy...")
	log.Printf("Validation accuracy: %f", model.Accuracy(validation))
}

type Model struct {
	Block anyrnn.Block
	Out   anynet.Layer
}

func DeserializeModel(d []byte) (*Model, error) {
	var res Model
	if err := serializer.DeserializeAny(d, &res.Block, &res.Out); err != nil {
		return nil, err
	}
	return &res, nil
}

func (m *Model) Apply(in anyseq.Seq) anyseq.Seq {
	n := in.Output()[0].NumPresent()
	latent := anyseq.Tail(anyrnn.Map(in, m.Block))
	outs := m.Out.Apply(latent, n)
	return anyseq.ResSeq(in.Creator(), []*anyseq.ResBatch{
		&anyseq.ResBatch{
			Present: in.Output()[0].Present,
			Packed:  outs,
		},
	})
}

func (m *Model) Accuracy(samples *SampleList) float64 {
	// Don't bother using batches; this is never called
	// during the training inner loop, anyway.
	var correct int
	total := samples.Len()
	for i := 0; i < samples.Len(); i++ {
		sample, _ := samples.GetSample(i)
		inSeq := anyseq.ConstSeqList(anyvec32.CurrentCreator(),
			[][]anyvec.Vector{sample.Input})
		out := m.Out.Apply(anyseq.Tail(anyrnn.Map(inSeq, m.Block)), 1)
		actual := anyvec.Sum(out.Output()).(float32) > 0
		desired := samples.Sentiments[i]
		if actual == desired {
			correct++
		}
	}
	return float64(correct) / float64(total)
}

func (m *Model) Parameters() []*anydiff.Var {
	var res []*anydiff.Var
	for _, x := range []interface{}{m.Block, m.Out} {
		if x, ok := x.(anynet.Parameterizer); ok {
			res = append(res, x.Parameters()...)
		}
	}
	return res
}

func (m *Model) SerializerType() string {
	return "github.com/unixpickle/rwa/experiments/sentiment.Model"
}

func (m *Model) Serialize() ([]byte, error) {
	return serializer.SerializeAny(m.Block, m.Out)
}

type SampleList struct {
	Tweets     []string
	Sentiments []bool

	cacheLock *sync.RWMutex
	charCache map[byte]anyvec.Vector
}

func ReadSampleList(csvFile string) (*SampleList, error) {
	f, err := os.Open(csvFile)
	if err != nil {
		return nil, err
	}
	defer f.Close()
	r := csv.NewReader(f)
	rows, err := r.ReadAll()
	if err != nil {
		return nil, err
	}
	var res SampleList
	for _, row := range rows {
		if row[0] == "2" {
			continue
		}
		res.Sentiments = append(res.Sentiments, row[0] == "4")
		res.Tweets = append(res.Tweets, row[len(row)-1])
	}
	res.cacheLock = &sync.RWMutex{}
	res.charCache = map[byte]anyvec.Vector{}
	return &res, nil
}

func (s *SampleList) Len() int {
	return len(s.Tweets)
}

func (s *SampleList) Swap(i, j int) {
	s.Tweets[i], s.Tweets[j] = s.Tweets[j], s.Tweets[i]
	s.Sentiments[i], s.Sentiments[j] = s.Sentiments[j], s.Sentiments[i]
}

func (s *SampleList) Slice(i, j int) anysgd.SampleList {
	return &SampleList{
		Tweets:     append([]string{}, s.Tweets[i:j]...),
		Sentiments: append([]bool{}, s.Sentiments[i:j]...),
		cacheLock:  s.cacheLock,
		charCache:  s.charCache,
	}
}

func (s *SampleList) Creator() anyvec.Creator {
	return anyvec32.CurrentCreator()
}

func (s *SampleList) GetSample(i int) (*anys2s.Sample, error) {
	data := []byte(s.Tweets[i])
	var input []anyvec.Vector
	for _, x := range data {
		input = append(input, s.charVec(x))
	}
	classVal := float32(0)
	if s.Sentiments[i] {
		classVal = 1
	}
	output := []anyvec.Vector{anyvec32.MakeVectorData([]float32{classVal})}
	return &anys2s.Sample{
		Input:  input,
		Output: output,
	}, nil
}

func (s *SampleList) charVec(char byte) anyvec.Vector {
	s.cacheLock.RLock()
	res, ok := s.charCache[char]
	s.cacheLock.RUnlock()
	if !ok {
		data := make([]float32, 0x100)
		data[int(char)] = 1
		res = anyvec32.MakeVectorData(data)
		s.cacheLock.Lock()
		s.charCache[char] = res
		s.cacheLock.Unlock()
	}
	return res
}
--------------------------------------------------------------------------------
/rwa.go:
--------------------------------------------------------------------------------
// Package rwa implements the Recurrent Weighted Average
// RNN defined in https://arxiv.org/pdf/1703.01253.pdf.
package rwa

import (
	"math"

	"github.com/unixpickle/anydiff"
	"github.com/unixpickle/anynet"
	"github.com/unixpickle/anynet/anyrnn"
	"github.com/unixpickle/anyvec"
	"github.com/unixpickle/anyvec/anyvecsave"
	"github.com/unixpickle/serializer"
)

func init() {
	serializer.RegisterTypedDeserializer((&RWA{}).SerializerType(), DeserializeRWA)
}

// RWA is a Recurrent Weighted Average RNN block.
type RWA struct {
	// SquashFunc is used to squash the rolling average.
	SquashFunc anynet.Layer

	// Init is the unsquashed start state.
	Init *anydiff.Var

	// Encoder is u(x) from the paper.
	Encoder *anynet.FC

	// Masker is g(x,h) from the paper.
	Masker *anynet.AddMixer

	// Context is a(x,h) from the paper.
	Context *anynet.AddMixer
}

// NewRWA creates a randomized RWA with the given number
// of inputs and hidden units.
func NewRWA(c anyvec.Creator, inSize, stateSize int) *RWA {
	oneVec := c.MakeVector(stateSize)
	oneVec.AddScalar(c.MakeNumeric(1))
	return &RWA{
		SquashFunc: anynet.Tanh,
		Init:       anydiff.NewVar(c.MakeVector(stateSize)),
		Encoder:    anynet.NewFC(c, inSize, stateSize),
		Masker: &anynet.AddMixer{
			In1: anynet.NewFC(c, inSize, stateSize),
			In2: anynet.NewFC(c, stateSize, stateSize),
			Out: anynet.Net{},
		},
		Context: &anynet.AddMixer{
			In1: anynet.NewFC(c, inSize, stateSize),
			In2: anynet.NewFC(c, stateSize, stateSize),
			Out: anynet.Net{},
		},
	}
}

// DeserializeRWA deserializes an RWA.
func DeserializeRWA(d []byte) (*RWA, error) {
	var res RWA
	var initVec *anyvecsave.S
	err := serializer.DeserializeAny(d, &res.SquashFunc, &initVec, &res.Encoder,
		&res.Masker, &res.Context)
	if err != nil {
		return nil, err
	}
	res.Init = anydiff.NewVar(initVec.Vector)
	return &res, nil
}

// ScaleInWeights scales all of the weights that modify
// input vectors.
// It returns r for convenience.
func (r *RWA) ScaleInWeights(scaler anyvec.Numeric) *RWA {
	mats := []*anydiff.Var{r.Encoder.Weights, r.Masker.In1.(*anynet.FC).Weights,
		r.Context.In1.(*anynet.FC).Weights}
	for _, gate := range mats {
		gate.Vector.Scale(scaler)
	}
	return r
}

// Start generates an initial *State.
func (r *RWA) Start(n int) anyrnn.State {
	c := r.Init.Vector.Creator()
	zeroDenom := c.MakeVector(r.Init.Vector.Len())
	zeroDenom.AddScalar(c.MakeNumeric(math.Inf(-1)))
	return &State{
		First:     true,
		MaxWeight: anyrnn.NewVecState(zeroDenom, n),
		Hidden:    anyrnn.NewVecState(r.Init.Vector, n),
		Num:       anyrnn.NewVecState(c.MakeVector(r.Init.Vector.Len()), n),
		Denom:     anyrnn.NewVecState(c.MakeVector(r.Init.Vector.Len()), n),
	}
}

// PropagateStart propagates through the start state.
func (r *RWA) PropagateStart(s anyrnn.StateGrad, g anydiff.Grad) {
	s.(*State).Hidden.PropagateStart(r.Init, g)
}

// Step performs a timestep.
func (r *RWA) Step(s anyrnn.State, in anyvec.Vector) anyrnn.Res {
	batch := s.Present().NumPresent()
	state := s.(*State)
	c := in.Creator()

	inPool := anydiff.NewVar(in)
	maxPool := anydiff.NewVar(state.MaxWeight.Vector)
	hiddenPool := anydiff.NewVar(state.Hidden.Vector)
	numPool := anydiff.NewVar(state.Num.Vector)
	denomPool := anydiff.NewVar(state.Denom.Vector)

	hidden := r.SquashFunc.Apply(hiddenPool, batch)

	intermed := anydiff.Fuse(hidden)
	outs := anydiff.PoolMulti(intermed, func(reses []anydiff.Res) anydiff.MultiRes {
		hidden := reses[0]
		weightLog := r.Context.Mix(inPool, hidden, batch)
		inEnc := r.Encoder.Apply(inPool, batch)
		inMask := anydiff.Tanh(r.Masker.Mix(inPool, hidden, batch))
		z := anydiff.Mul(inEnc, inMask)

		intermed := anydiff.Fuse(weightLog, z)
		return anydiff.PoolMulti(intermed, func(reses []anydiff.Res) anydiff.MultiRes {
			weightLog := reses[0]
			z := reses[1]
			var newMax, maxAdjust anydiff.Res
			if state.First {
				newMax = weightLog
				zeros := c.MakeVector(maxPool.Vector.Len())
				maxAdjust = anydiff.NewConst(zeros)
			} else {
				newMax = anydiff.ElemMax(weightLog, maxPool)
				maxAdjust = anydiff.Exp(anydiff.Sub(maxPool, newMax))
			}
			weight := anydiff.Exp(anydiff.Sub(weightLog, newMax))
			newNum := anydiff.Add(
				anydiff.Mul(numPool, maxAdjust),
				anydiff.Mul(z, weight),
			)
			newDenom := anydiff.Add(anydiff.Mul(denomPool, maxAdjust), weight)
			intermed := anydiff.Fuse(newMax, newNum, newDenom)
			return anydiff.PoolMulti(intermed, func(reses []anydiff.Res) anydiff.MultiRes {
				newMax := reses[0]
				newNum := reses[1]
				newDenom := reses[2]
				invDenom := anydiff.Pow(newDenom, c.MakeNumeric(-1))
				unsquashedHidden := anydiff.Mul(newNum, invDenom)
				intermed := anydiff.Fuse(unsquashedHidden)
				return anydiff.PoolMulti(intermed, func(reses []anydiff.Res) anydiff.MultiRes {
					squashedHidden := r.SquashFunc.Apply(reses[0], batch)
					return anydiff.Fuse(squashedHidden, newMax, unsquashedHidden,
						newNum, newDenom)
				})
			})
		})
	})

	return &blockRes{
		V:      anydiff.NewVarSet(r.Parameters()...),
		OutVec: outs.Outputs()[0],
		OutState: &State{
			MaxWeight: &anyrnn.VecState{
				PresentMap: state.Present(),
				Vector:     outs.Outputs()[1],
			},
			Hidden: &anyrnn.VecState{
				PresentMap: state.Present(),
				Vector:     outs.Outputs()[2],
			},
			Num: &anyrnn.VecState{
				PresentMap: state.Present(),
				Vector:     outs.Outputs()[3],
			},
			Denom: &anyrnn.VecState{
				PresentMap: state.Present(),
				Vector:     outs.Outputs()[4],
			},
		},
		InPool:     inPool,
		MaxPool:    maxPool,
		HiddenPool: hiddenPool,
		NumPool:    numPool,
		DenomPool:  denomPool,
		Res:        outs,
	}
}

// Parameters returns the block's parameters.
func (r *RWA) Parameters() []*anydiff.Var {
	res := []*anydiff.Var{r.Init}
	for _, x := range []interface{}{r.SquashFunc, r.Encoder, r.Masker, r.Context} {
		if p, ok := x.(anynet.Parameterizer); ok {
			res = append(res, p.Parameters()...)
		}
	}
	return res
}

// SerializerType returns the unique ID used to serialize
// an RWA with the serializer package.
func (r *RWA) SerializerType() string {
	return "github.com/unixpickle/rwa.RWA"
}

// Serialize serializes an RWA.
func (r *RWA) Serialize() ([]byte, error) {
	return serializer.SerializeAny(
		r.SquashFunc,
		&anyvecsave.S{Vector: r.Init.Vector},
		r.Encoder,
		r.Masker,
		r.Context,
	)
}

// State stores the hidden state of an RWA block or the
// gradient of such a state.
//
// The MaxWeight field stores the maximum (log domain)
// weight from any past timestep.
// It is used to ensure numerical stability.
//
// The Num and Denom fields store the current numerator
// and denominator, each divided by e^MaxWeight.
// They both start at 0.
//
// The Hidden field stores the previous, unsquashed hidden
// state.
//
// It is necessary for the Hidden field to be separate
// from the Num and Denom fields so that the network can
// be evaluated at the first timestep.
type State struct {
	First bool

	MaxWeight *anyrnn.VecState
	Hidden    *anyrnn.VecState
	Num       *anyrnn.VecState
	Denom     *anyrnn.VecState
}

func (s *State) Present() anyrnn.PresentMap {
	return s.Hidden.Present()
}

func (s *State) Reduce(p anyrnn.PresentMap) anyrnn.State {
	return &State{
		First:     s.First,
		MaxWeight: s.MaxWeight.Reduce(p).(*anyrnn.VecState),
		Hidden:    s.Hidden.Reduce(p).(*anyrnn.VecState),
		Num:       s.Num.Reduce(p).(*anyrnn.VecState),
		Denom:     s.Denom.Reduce(p).(*anyrnn.VecState),
	}
}

func (s *State) Expand(p anyrnn.PresentMap) anyrnn.StateGrad {
	return &State{
		First:     s.First,
		MaxWeight: s.MaxWeight.Expand(p).(*anyrnn.VecState),
		Hidden:    s.Hidden.Expand(p).(*anyrnn.VecState),
		Num:       s.Num.Expand(p).(*anyrnn.VecState),
		Denom:     s.Denom.Expand(p).(*anyrnn.VecState),
	}
}

type blockRes struct {
	OutState *State
	OutVec   anyvec.Vector
	V        anydiff.VarSet

	InPool     *anydiff.Var
	MaxPool    *anydiff.Var
	HiddenPool *anydiff.Var
	NumPool    *anydiff.Var
	DenomPool  *anydiff.Var

	Res anydiff.MultiRes
}

func (b *blockRes) State() anyrnn.State {
	return b.OutState
}

func (b *blockRes) Output() anyvec.Vector {
	return b.OutVec
}

func (b *blockRes) Vars() anydiff.VarSet {
	return b.V
}

func (b *blockRes) Propagate(u anyvec.Vector, s anyrnn.StateGrad, g anydiff.Grad) (anyvec.Vector,
	anyrnn.StateGrad) {
	c := u.Creator()
	down := make([]anyvec.Vector, 5)
	down[0] = u
	if s != nil {
		st := s.(*State)
		down[1] = st.MaxWeight.Vector
		down[2] = st.Hidden.Vector
		down[3] = st.Num.Vector
		down[4] = st.Denom.Vector
	} else {
		down[1] = c.MakeVector(b.OutState.MaxWeight.Vector.Len())
		down[2] = c.MakeVector(b.OutState.Hidden.Vector.Len())
		down[3] = c.MakeVector(b.OutState.Num.Vector.Len())
		down[4] = c.MakeVector(b.OutState.Denom.Vector.Len())
	}
	for _, x := range b.pools() {
		g[x] = c.MakeVector(x.Vector.Len())
	}
	b.Res.Propagate(down, g)

	inDown := g[b.InPool]
	stateDown := &State{
		MaxWeight: &anyrnn.VecState{
			PresentMap: b.OutState.Present(),
			Vector:     g[b.MaxPool],
		},
		Hidden: &anyrnn.VecState{
			PresentMap: b.OutState.Present(),
			Vector:     g[b.HiddenPool],
		},
		Num: &anyrnn.VecState{
			PresentMap: b.OutState.Present(),
			Vector:     g[b.NumPool],
		},
		Denom: &anyrnn.VecState{
			PresentMap: b.OutState.Present(),
			Vector:     g[b.DenomPool],
		},
	}

	for _, x := range b.pools() {
		delete(g, x)
	}

	return inDown, stateDown
}

func (b *blockRes) pools() []*anydiff.Var {
	return []*anydiff.Var{b.InPool, b.MaxPool, b.HiddenPool, b.NumPool, b.DenomPool}
}
--------------------------------------------------------------------------------
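
A closing note on the least obvious part of `Step`: the attention weights `e^(a_t)` can easily overflow, so the state tracks the running maximum of the log-domain weights and keeps `Num`/`Denom` divided by `e^MaxWeight`. Below is a minimal scalar sketch of that same update (a hypothetical standalone program, not part of the package):

```go
package main

import (
	"fmt"
	"math"
)

// Numerically stable running weighted average, mirroring what RWA.Step
// does per hidden unit: num and denom are stored divided by exp(maxW),
// so no exp call ever sees a large positive argument.
func main() {
	maxW := math.Inf(-1) // RWA.Start initializes MaxWeight to -Inf
	num, denom := 0.0, 0.0
	logWeights := []float64{2, 700, 5} // exp(700) alone would overflow a float64
	values := []float64{0.1, 0.2, 0.3} // stands in for z_t = u(x) * tanh(g(x, h))
	for i, a := range logWeights {
		newMax := math.Max(a, maxW)
		adjust := math.Exp(maxW - newMax) // rescales old num/denom; 0 on the first step
		w := math.Exp(a - newMax)         // current weight relative to the new max
		num = num*adjust + w*values[i]
		denom = denom*adjust + w
		maxW = newMax
		fmt.Println("h =", math.Tanh(num/denom)) // squashed running average
	}
}
```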