├── doc.go ├── .gitignore ├── .travis.yml ├── utils ├── utils_test.go └── utils.go ├── LICENSE ├── temporalMemoryConnections_test.go ├── encoders ├── dateEncoder_test.go ├── scalerEncoder_test.go ├── dateEncoder.go └── scalerEncoder.go ├── temporalPooler_test.go ├── spatialPoolerCompute_test.go ├── sparseBinaryMatrix_test.go ├── denseBinaryMatrix_test.go ├── temporalMemoryConnections.go ├── README.md ├── segmentUpdate.go ├── trivialPredictor.go ├── denseBinaryMatrix.go ├── sparseBinaryMatrix.go ├── temporalPoolerPrint.go ├── spatialPoolerBoost_test.go ├── temporalPoolerStats.go ├── segment.go ├── temporalMemory_test.go ├── temporalMemory.go └── spatialPooler.go /doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | htm contains ports of Numenta's spatial and temporal poolers as they are 3 | currently implemented in the Nupic framework. 4 | */ 5 | package htm 6 | 7 | import ( 8 | _ "github.com/nupic-community/htm/utils" 9 | ) 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | install: 4 | - go get github.com/stretchr/testify/assert 5 | - go get github.com/cznic/mathutil 6 | - go get github.com/gonum/floats 7 | - go get github.com/skelterjohn/go.matrix 8 | - go get github.com/zacg/floats 9 | - go get github.com/zacg/go.matrix 10 | - go get 
github.com/zacg/ints 11 | - go get github.com/zacg/testify/assert -------------------------------------------------------------------------------- /utils/utils_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | ) 7 | 8 | func TestFillSliceWithIdxInt(t *testing.T) { 9 | vals := make([]int, 3) 10 | FillSliceWithIdxInt(vals) 11 | expected := []int{0, 1, 2} 12 | assert.Equal(t, expected, vals) 13 | } 14 | 15 | func TestCartProductInt(t *testing.T) { 16 | vals := [][]int{ 17 | {1, 2, 3, 4}, 18 | {5, 6, 7, 8}, 19 | {10, 11, 12, 13}, 20 | } 21 | 22 | result := CartProductInt(vals) 23 | 24 | assert.Equal(t, 64, len(result)) 25 | assert.Equal(t, []int{1, 5, 10}, result[0]) 26 | assert.Equal(t, []int{2, 5, 12}, result[18]) 27 | assert.Equal(t, []int{3, 8, 13}, result[47]) 28 | 29 | vals = [][]int{ 30 | {1, 2}, 31 | {2, 3}, 32 | {0, 1}, 33 | } 34 | 35 | result = CartProductInt(vals) 36 | 37 | assert.Equal(t, 8, len(result)) 38 | 39 | } 40 | 41 | func TestProdInt(t *testing.T) { 42 | 43 | vals := []int{32, 32} 44 | expected := 1024 45 | 46 | actual := ProdInt(vals) 47 | 48 | assert.Equal(t, expected, actual) 49 | 50 | } 51 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 Zac Gross 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice 
and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /temporalMemoryConnections_test.go: -------------------------------------------------------------------------------- 1 | package htm 2 | 3 | import ( 4 | "github.com/zacg/testify/assert" 5 | "testing" 6 | ) 7 | 8 | func TestNumColumns(t *testing.T) { 9 | c := NewTemporalMemoryConnections(0, 32, []int{64, 64}) 10 | assert.Equal(t, 64*64, c.NumberOfColumns()) 11 | } 12 | 13 | func TestNumCells(t *testing.T) { 14 | c := NewTemporalMemoryConnections(0, 32, []int{64, 64}) 15 | assert.Equal(t, 64*64*32, c.NumberOfcells()) 16 | } 17 | 18 | func TestUpdateSynapsePermanence(t *testing.T) { 19 | c := NewTemporalMemoryConnections(1000, 32, []int{64, 64}) 20 | c.CreateSegment(0) 21 | c.CreateSynapse(0, 483, 0.1284) 22 | c.UpdateSynapsePermanence(0, 0.2496) 23 | assert.Equal(t, 0.2496, c.DataForSynapse(0).Permanence) 24 | } 25 | 26 | func TestCellsForColumn1D(t *testing.T) { 27 | c := NewTemporalMemoryConnections(1000, 5, []int{2048}) 28 | expectedCells := []int{5, 6, 7, 8, 9} 29 | assert.Equal(t, expectedCells, c.CellsForColumn(1)) 30 | } 31 | 32 | func TestCellsForColumn2D(t *testing.T) { 33 | c := NewTemporalMemoryConnections(1000, 4, []int{64, 64}) 34 | expectedCells := []int{256, 257, 258, 259} 35 | assert.Equal(t, expectedCells, c.CellsForColumn(64)) 36 | } 37 | 
-------------------------------------------------------------------------------- /encoders/dateEncoder_test.go: -------------------------------------------------------------------------------- 1 | package encoders 2 | 3 | import ( 4 | "github.com/nupic-community/htm/utils" 5 | "github.com/stretchr/testify/assert" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func TestSimpleDateEncoding(t *testing.T) { 11 | 12 | p := NewDateEncoderParams() 13 | p.SeasonWidth = 3 14 | p.DayOfWeekWidth = 1 15 | p.WeekendWidth = 3 16 | p.TimeOfDayWidth = 5 17 | de := NewDateEncoder(p) 18 | 19 | // season is aaabbbcccddd (1 bit/month) TODO should be <<3? 20 | // should be 000000000111 (centered on month 11 - Nov) 21 | seasonExpected := utils.Make1DBool([]int{0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1}) 22 | // week is SMTWTFS 23 | // differs from python implementation 24 | dayOfWeekExpected := utils.Make1DBool([]int{0, 0, 0, 0, 1, 0, 0}) 25 | // not a weekend, so it should be "False" 26 | weekendExpected := utils.Make1DBool([]int{1, 1, 1, 0, 0, 0}) 27 | // time of day has radius of 4 hours and w of 5 so each bit = 240/5 min = 48min 28 | // 14:55 is minute 14*60 + 55 = 895; 895/48 = bit 18.6 29 | // should be 30 bits total (30 * 48 minutes = 24 hours) 30 | timeOfDayExpected := utils.Make1DBool([]int{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0}) 31 | 32 | d := time.Date(2010, 11, 4, 14, 55, 0, 0, time.UTC) 33 | encoded := de.Encode(d) 34 | t.Log(utils.Bool2Int(encoded)) 35 | 36 | expected := append(seasonExpected, dayOfWeekExpected...) 37 | expected = append(expected, weekendExpected...) 38 | expected = append(expected, timeOfDayExpected...) 
// boolRange builds a boolean vector of the given length in which the
// positions from start through end (inclusive) are set to true and all
// other positions are false. Indices must lie within [0, length) or the
// write will panic, matching the usual slice-bounds contract.
func boolRange(start int, end int, length int) []bool {
	bits := make([]bool, length)
	for pos := start; pos <= end; pos++ {
		bits[pos] = true
	}
	return bits
}
// GenerateRandSequence returns a boolean vector of the given size with
// exactly width distinct, randomly chosen positions set to true (width
// is clamped into [0, size]).
//
// The previous implementation drew width indices independently with
// rand.Intn, so colliding draws could silently produce fewer than width
// active bits; sampling a prefix of a random permutation guarantees the
// requested count of distinct active positions.
func GenerateRandSequence(size int, width int) []bool {
	if width < 0 {
		width = 0
	}
	if width > size {
		width = size
	}

	input := make([]bool, size)
	for _, idx := range rand.Perm(size)[:width] {
		input[idx] = true
	}
	return input
}
21 | */ 22 | 23 | sp := NewSpatialPooler(spParams) 24 | 25 | // Create a set of input vectors as well as various numpy vectors we will 26 | // need to retrieve data from the SP 27 | numRecords := 100 28 | 29 | inputMatrix := make([][]bool, numRecords) 30 | for i := range inputMatrix { 31 | inputMatrix[i] = make([]bool, sp.numInputs) 32 | for j := range inputMatrix[i] { 33 | inputMatrix[i][j] = rand.Float64() > 0.8 34 | } 35 | } 36 | 37 | // With learning off and no prior training we should get no winners 38 | y := make([]bool, sp.numColumns) 39 | for _, input := range inputMatrix { 40 | utils.FillSliceBool(y, false) 41 | sp.Compute(input, false, y, sp.InhibitColumns) 42 | assert.Equal(t, 0, utils.CountTrue(y)) 43 | } 44 | 45 | // With learning on we should get the requested number of winners 46 | for _, input := range inputMatrix { 47 | utils.FillSliceBool(y, false) 48 | sp.Compute(input, true, y, sp.InhibitColumns) 49 | assert.Equal(t, sp.NumActiveColumnsPerInhArea, utils.CountTrue(y)) 50 | 51 | } 52 | 53 | // With learning off and some prior training we should get the requested 54 | // number of winners 55 | for _, input := range inputMatrix { 56 | utils.FillSliceBool(y, false) 57 | sp.Compute(input, false, y, sp.InhibitColumns) 58 | assert.Equal(t, sp.NumActiveColumnsPerInhArea, utils.CountTrue(y)) 59 | } 60 | 61 | } 62 | 63 | func TestBasicCompute1(t *testing.T) { 64 | 65 | spParams := NewSpParams() 66 | spParams.InputDimensions = []int{30} 67 | spParams.ColumnDimensions = []int{50} 68 | spParams.GlobalInhibition = true 69 | 70 | basicComputeLoop(t, spParams) 71 | } 72 | 73 | func TestBasicCompute2(t *testing.T) { 74 | 75 | spParams := NewSpParams() 76 | spParams.InputDimensions = []int{100} 77 | spParams.ColumnDimensions = []int{100} 78 | spParams.GlobalInhibition = true 79 | spParams.SynPermActiveInc = 0 80 | spParams.SynPermInactiveDec = 0 81 | 82 | basicComputeLoop(t, spParams) 83 | 84 | } 85 | 
-------------------------------------------------------------------------------- /encoders/scalerEncoder_test.go: -------------------------------------------------------------------------------- 1 | package encoders 2 | 3 | import ( 4 | "github.com/nupic-community/htm/utils" 5 | "github.com/stretchr/testify/assert" 6 | "testing" 7 | ) 8 | 9 | func TestSimpleEncoding(t *testing.T) { 10 | 11 | p := NewScalerEncoderParams(3, 1, 8) 12 | p.N = 14 13 | p.Periodic = true 14 | //p.Verbosity = 5 15 | 16 | e := NewScalerEncoder(p) 17 | 18 | encoded := e.Encode(1, false) 19 | t.Log(encoded) 20 | expected := utils.Make1DBool([]int{1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) 21 | assert.True(t, len(encoded) == 14) 22 | assert.Equal(t, expected, encoded) 23 | 24 | encoded = e.Encode(2, false) 25 | expected = utils.Make1DBool([]int{0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) 26 | assert.True(t, len(encoded) == 14) 27 | assert.Equal(t, expected, encoded) 28 | 29 | encoded = e.Encode(3, false) 30 | expected = utils.Make1DBool([]int{0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0}) 31 | assert.True(t, len(encoded) == 14) 32 | assert.Equal(t, expected, encoded) 33 | 34 | } 35 | 36 | func TestWideEncoding(t *testing.T) { 37 | 38 | p := NewScalerEncoderParams(5, 0, 24) 39 | p.Periodic = true 40 | //p.Verbosity = 5 41 | p.Radius = 4 42 | e := NewScalerEncoder(p) 43 | 44 | encoded := e.Encode(14.916666666666666, false) 45 | t.Log(encoded) 46 | expected := utils.Make1DBool([]int{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0}) 47 | 48 | assert.True(t, len(encoded) == 30) 49 | assert.Equal(t, utils.Bool2Int(expected), utils.Bool2Int(encoded)) 50 | 51 | } 52 | 53 | func TestNarrowEncoding(t *testing.T) { 54 | 55 | p := NewScalerEncoderParams(3, 0, 1) 56 | p.Periodic = false 57 | //p.Verbosity = 5 58 | p.Radius = 1 59 | e := NewScalerEncoder(p) 60 | 61 | encoded := make([]bool, 6) 62 | e.EncodeToSlice(0, false, encoded) 63 | t.Log(encoded) 64 | expected := 
utils.Make1DBool([]int{1, 1, 1, 0, 0, 0}) 65 | 66 | assert.True(t, len(encoded) == 6) 67 | assert.Equal(t, utils.Bool2Int(expected), utils.Bool2Int(encoded)) 68 | 69 | } 70 | 71 | func TestSimpleDecoding(t *testing.T) { 72 | 73 | p := NewScalerEncoderParams(3, 1, 8) 74 | p.Radius = 1.5 75 | p.Periodic = true 76 | //p.Verbosity = 5 77 | 78 | e := NewScalerEncoder(p) 79 | 80 | // Test with a "hole" 81 | encoded := utils.Make1DBool([]int{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}) 82 | expected := []utils.TupleFloat{utils.TupleFloat{7.5, 7.5}} 83 | actual := e.Decode(encoded) 84 | assert.Equal(t, expected, actual) 85 | 86 | // Test with something wider than w, and with a hole, and wrapped 87 | encoded = utils.Make1DBool([]int{1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}) 88 | expected = []utils.TupleFloat{utils.TupleFloat{7.5, 8}, utils.TupleFloat{1, 1}} 89 | actual = e.Decode(encoded) 90 | assert.Equal(t, expected, actual) 91 | 92 | // Test with something wider than w, no hole 93 | encoded = utils.Make1DBool([]int{1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0}) 94 | expected = []utils.TupleFloat{utils.TupleFloat{1.5, 2.5}} 95 | actual = e.Decode(encoded) 96 | assert.Equal(t, expected, actual) 97 | 98 | // 1 99 | encoded = utils.Make1DBool([]int{1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) 100 | expected = []utils.TupleFloat{utils.TupleFloat{1, 1}} 101 | actual = e.Decode(encoded) 102 | assert.Equal(t, expected, actual) 103 | 104 | // 2 105 | encoded = utils.Make1DBool([]int{0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) 106 | expected = []utils.TupleFloat{utils.TupleFloat{2, 2}} 107 | actual = e.Decode(encoded) 108 | assert.Equal(t, expected, actual) 109 | 110 | } 111 | -------------------------------------------------------------------------------- /sparseBinaryMatrix_test.go: -------------------------------------------------------------------------------- 1 | package htm 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | ) 7 | 8 | //Tests getting/setting 
values 9 | func TestGetSet(t *testing.T) { 10 | 11 | sm := NewSparseBinaryMatrix(10, 10) 12 | sm.Set(2, 4, true) 13 | sm.Set(6, 5, true) 14 | sm.Set(7, 5, false) 15 | 16 | if !sm.Get(2, 4) { 17 | t.Errorf("Was false expected true @ [2,4]") 18 | } 19 | 20 | if !sm.Get(6, 5) { 21 | t.Errorf("Was false expected true @ [6,5]") 22 | } 23 | 24 | if sm.Get(7, 5) { 25 | t.Errorf("Was true expected false @ [7,5]") 26 | } 27 | 28 | } 29 | 30 | func TestRowReplace(t *testing.T) { 31 | sm := NewSparseBinaryMatrix(10, 10) 32 | sm.Set(2, 4, true) 33 | sm.Set(6, 5, true) 34 | sm.Set(7, 5, true) 35 | sm.Set(8, 8, true) 36 | 37 | if !sm.Get(8, 8) { 38 | t.Errorf("Was false expected true @ [8,8]") 39 | } 40 | 41 | newRow := make([]bool, 10) 42 | newRow[6] = true 43 | sm.ReplaceRow(8, newRow) 44 | 45 | if !sm.Get(8, 6) { 46 | t.Errorf("Was false expected true @ [8,6]") 47 | } 48 | 49 | if sm.Get(8, 8) { 50 | t.Errorf("Was true expected false @ [8,8]") 51 | } 52 | 53 | } 54 | 55 | func TestReplaceRowByIndices(t *testing.T) { 56 | sm := NewSparseBinaryMatrix(10, 10) 57 | 58 | indices := make([]int, 3) 59 | indices[0] = 3 60 | indices[1] = 9 61 | indices[2] = 6 62 | sm.ReplaceRowByIndices(4, indices) 63 | 64 | if !sm.Get(4, 3) { 65 | t.Errorf("Was false expected true @ [4,3]") 66 | } 67 | 68 | if !sm.Get(4, 9) { 69 | t.Errorf("Was false expected true @ [4,9]") 70 | } 71 | 72 | if !sm.Get(4, 6) { 73 | t.Errorf("Was false expected true @ [4,6]") 74 | } 75 | 76 | if sm.Get(4, 5) { 77 | t.Errorf("Was true expected false @ [4,5]") 78 | } 79 | 80 | if sm.Get(4, 0) { 81 | t.Errorf("Was true expected false @ [4,0]") 82 | } 83 | 84 | indices = make([]int, 3) 85 | indices[0] = 4 86 | 87 | sm.ReplaceRowByIndices(4, indices) 88 | if sm.Get(4, 3) { 89 | t.Errorf("Was true expected false @ [4,3]") 90 | } 91 | 92 | if sm.Get(4, 9) { 93 | t.Errorf("Was true expected false @ [4,9]") 94 | } 95 | 96 | if !sm.Get(4, 4) { 97 | t.Errorf("Was false expected true @ [4,4]") 98 | } 99 | 100 | } 101 | 102 | func 
TestGetRowIndices(t *testing.T) { 103 | sm := NewSparseBinaryMatrix(10, 10) 104 | 105 | indices := make([]int, 3) 106 | indices[0] = 3 107 | indices[1] = 6 108 | indices[2] = 9 109 | sm.ReplaceRowByIndices(4, indices) 110 | 111 | indResult := sm.GetRowIndices(4) 112 | 113 | if len(indResult) != len(indices) { 114 | t.Errorf("Len was %v expected %v", len(indResult), len(indices)) 115 | } 116 | 117 | for i := 0; i < 3; i++ { 118 | if indResult[i] != indices[i] { 119 | t.Errorf("Was %v expected %v", indResult, indices) 120 | } 121 | } 122 | 123 | } 124 | 125 | func TestGetRowAndSum(t *testing.T) { 126 | sm := NewSparseBinaryMatrix(4, 5) 127 | 128 | sm.SetRowFromDense(0, []bool{true, false, true, true, false}) 129 | sm.SetRowFromDense(1, []bool{false, false, false, true, false}) 130 | sm.SetRowFromDense(2, []bool{false, false, false, false, false}) 131 | sm.SetRowFromDense(3, []bool{true, true, true, true, true}) 132 | 133 | t.Log(sm.ToString()) 134 | t.Log(sm.Entries()) 135 | i := []bool{true, false, true, true, false} 136 | 137 | result := sm.RowAndSum(i) 138 | 139 | assert.Equal(t, 3, result[0]) 140 | assert.Equal(t, 1, result[1]) 141 | assert.Equal(t, 0, result[2]) 142 | assert.Equal(t, 3, result[3]) 143 | 144 | } 145 | 146 | func TestSetRowFromDense(t *testing.T) { 147 | 148 | } 149 | 150 | func TestNewFromDense(t *testing.T) { 151 | sbm := NewSparseBinaryMatrixFromDense([][]bool{ 152 | {true, true, true}, 153 | {false, false, false}, 154 | {false, true, false}, 155 | {true, false, true}, 156 | }) 157 | 158 | assert.Equal(t, 4, sbm.Height) 159 | assert.Equal(t, 3, sbm.Width) 160 | assert.Equal(t, true, sbm.Get(3, 2)) 161 | assert.Equal(t, []bool{false, true, false}, sbm.GetDenseRow(2)) 162 | 163 | } 164 | -------------------------------------------------------------------------------- /denseBinaryMatrix_test.go: -------------------------------------------------------------------------------- 1 | package htm 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 
5 | "github.com/zacg/go.matrix" 6 | "math/rand" 7 | "testing" 8 | ) 9 | 10 | //Tests getting/setting values 11 | func TestDenseGetSet(t *testing.T) { 12 | 13 | sm := NewDenseBinaryMatrix(10, 10) 14 | sm.Set(2, 4, true) 15 | sm.Set(6, 5, true) 16 | sm.Set(7, 5, false) 17 | 18 | if !sm.Get(2, 4) { 19 | t.Errorf("Was false expected true @ [2,4]") 20 | } 21 | 22 | if !sm.Get(6, 5) { 23 | t.Errorf("Was false expected true @ [6,5]") 24 | } 25 | 26 | if sm.Get(7, 5) { 27 | t.Errorf("Was true expected false @ [7,5]") 28 | } 29 | 30 | } 31 | 32 | func TestDenseRowReplace(t *testing.T) { 33 | sm := NewDenseBinaryMatrix(10, 10) 34 | sm.Set(2, 4, true) 35 | sm.Set(6, 5, true) 36 | sm.Set(7, 5, true) 37 | sm.Set(8, 8, true) 38 | 39 | if !sm.Get(8, 8) { 40 | t.Errorf("Was false expected true @ [8,8]") 41 | } 42 | 43 | newRow := make([]bool, 10) 44 | newRow[6] = true 45 | sm.ReplaceRow(8, newRow) 46 | 47 | if !sm.Get(8, 6) { 48 | t.Errorf("Was false expected true @ [8,6]") 49 | } 50 | 51 | if sm.Get(8, 8) { 52 | t.Errorf("Was true expected false @ [8,8]") 53 | } 54 | 55 | } 56 | 57 | func TestDenseReplaceRowByIndices(t *testing.T) { 58 | sm := NewDenseBinaryMatrix(10, 10) 59 | 60 | indices := make([]int, 3) 61 | indices[0] = 3 62 | indices[1] = 9 63 | indices[2] = 6 64 | sm.ReplaceRowByIndices(4, indices) 65 | 66 | if !sm.Get(4, 3) { 67 | t.Errorf("Was false expected true @ [4,3]") 68 | } 69 | 70 | if !sm.Get(4, 9) { 71 | t.Errorf("Was false expected true @ [4,9]") 72 | } 73 | 74 | if !sm.Get(4, 6) { 75 | t.Errorf("Was false expected true @ [4,6]") 76 | } 77 | 78 | if sm.Get(4, 5) { 79 | t.Errorf("Was true expected false @ [4,5]") 80 | } 81 | 82 | if sm.Get(4, 0) { 83 | t.Errorf("Was true expected false @ [4,0]") 84 | } 85 | 86 | indices = make([]int, 3) 87 | indices[0] = 4 88 | 89 | sm.ReplaceRowByIndices(4, indices) 90 | if sm.Get(4, 3) { 91 | t.Errorf("Was true expected false @ [4,3]") 92 | } 93 | 94 | if sm.Get(4, 9) { 95 | t.Errorf("Was true expected false @ [4,9]") 96 | } 
97 | 98 | if !sm.Get(4, 4) { 99 | t.Errorf("Was false expected true @ [4,4]") 100 | } 101 | 102 | } 103 | 104 | func TestDenseGetRowIndices(t *testing.T) { 105 | sm := NewDenseBinaryMatrix(10, 10) 106 | 107 | indices := make([]int, 3) 108 | indices[0] = 3 109 | indices[1] = 6 110 | indices[2] = 9 111 | sm.ReplaceRowByIndices(4, indices) 112 | 113 | indResult := sm.GetRowIndices(4) 114 | 115 | if len(indResult) != len(indices) { 116 | t.Errorf("Len was %v expected %v", len(indResult), len(indices)) 117 | } 118 | 119 | t.Log("len", len(indResult)) 120 | t.Log("indResult", indResult) 121 | 122 | for i := 0; i < 3; i++ { 123 | if indResult[i] != indices[i] { 124 | t.Errorf("Was %v expected %v", indResult, indices) 125 | } 126 | } 127 | 128 | } 129 | 130 | func TestDenseGetRowAndSum(t *testing.T) { 131 | sm := NewDenseBinaryMatrix(4, 5) 132 | 133 | sm.SetRowFromDense(0, []bool{true, false, true, true, false}) 134 | sm.SetRowFromDense(1, []bool{false, false, false, true, false}) 135 | sm.SetRowFromDense(2, []bool{false, false, false, false, false}) 136 | sm.SetRowFromDense(3, []bool{true, true, true, true, true}) 137 | 138 | t.Log(sm.ToString()) 139 | t.Log(sm.Entries()) 140 | i := []bool{true, false, true, true, false} 141 | 142 | result := sm.RowAndSum(i) 143 | 144 | assert.Equal(t, 3, result[0]) 145 | assert.Equal(t, 1, result[1]) 146 | assert.Equal(t, 0, result[2]) 147 | assert.Equal(t, 3, result[3]) 148 | 149 | } 150 | 151 | func TestDenseNewFromDense(t *testing.T) { 152 | sbm := NewDenseBinaryMatrixFromDense([][]bool{ 153 | {true, true, true}, 154 | {false, false, false}, 155 | {false, true, false}, 156 | {true, false, true}, 157 | }) 158 | 159 | assert.Equal(t, 4, sbm.Height) 160 | assert.Equal(t, 3, sbm.Width) 161 | assert.Equal(t, true, sbm.Get(3, 2)) 162 | assert.Equal(t, []bool{false, true, false}, sbm.GetDenseRow(2)) 163 | 164 | } 165 | 166 | func BenchmarkDenseSet(t *testing.B) { 167 | elms := make(map[int]float64, 1258291) 168 | m := 
// TmSynapse holds the data for a single synapse: the segment it lives
// on, the presynaptic (source) cell, and its permanence value.
type TmSynapse struct {
	Segment    int
	SourceCell int
	Permanence float64
}

/*
Structure holds data representing the connectivity of a layer of cells,
that the TM operates on.
*/
type TemporalMemoryConnections struct {
	ColumnDimensions []int
	CellsPerColumn   int

	// segments[i] is the cell that segment i belongs to.
	segments []int
	// synapses[i] is the data record for synapse i.
	synapses []*TmSynapse

	// Reverse-lookup indexes.
	synapsesForSegment    [][]int // segment index -> synapse indices
	synapsesForSourceCell [][]int // cell index -> synapse indices
	segmentsForCell       [][]int // cell index -> segment indices

	segmentIndex int
	synIndex     int

	maxSynapseCount int
}

// NewTemporalMemoryConnections creates the connectivity store for a
// layer with the given column dimensions and cells per column, able to
// hold up to maxSynCount synapses. Panics on empty dimensions or a
// non-positive cellsPerColumn.
func NewTemporalMemoryConnections(maxSynCount int, cellsPerColumn int, colDimensions []int) *TemporalMemoryConnections {
	if len(colDimensions) < 1 {
		panic("Column dimensions must be greater than 0")
	}

	if cellsPerColumn < 1 {
		panic("Number of cells per column must be greater than 0")
	}

	c := new(TemporalMemoryConnections)
	c.maxSynapseCount = maxSynCount
	c.CellsPerColumn = cellsPerColumn
	c.ColumnDimensions = colDimensions

	// Total number of cells = cellsPerColumn * product of column dims.
	numCells := cellsPerColumn
	for _, dim := range colDimensions {
		numCells *= dim
	}

	c.synapses = make([]*TmSynapse, 0, c.maxSynapseCount)
	//TODO: calc better size
	c.segments = make([]int, 0, 50000)
	// BUG FIX: these two tables are indexed by *cell* index
	// (0..numCells-1), not by segment index. Sizing them by
	// cap(c.segments) (50000) caused an index-out-of-range panic for
	// layers with more cells than that (e.g. 64x64 columns * 32 cells
	// per column = 131072 cells).
	c.segmentsForCell = make([][]int, numCells)
	c.synapsesForSourceCell = make([][]int, numCells)
	// synapsesForSegment is indexed by segment index; it shares the
	// segment-capacity sizing above. NOTE(review): creating more than
	// 50000 segments would still overrun it — pre-existing cap, confirm
	// whether it needs the same treatment.
	c.synapsesForSegment = make([][]int, cap(c.segments))

	return c
}

// CreateSynapse adds a synapse on the given segment, connecting from
// sourceCell with the given initial permanence, and returns the new
// synapse's data record. The synapse is also registered in the
// per-segment and per-source-cell indexes so it can be found from
// either direction.
//
// NOTE(review): permanence is not range-checked here (unlike
// UpdateSynapsePermanence) and maxSynapseCount is not enforced —
// confirm whether that is intentional.
func (tmc *TemporalMemoryConnections) CreateSynapse(segment int, sourceCell int, permanence float64) *TmSynapse {
	idx := len(tmc.synapses)

	data := &TmSynapse{
		Segment:    segment,
		SourceCell: sourceCell,
		Permanence: permanence,
	}
	tmc.synapses = append(tmc.synapses, data)

	// Maintain the reverse-lookup indexes.
	tmc.synapsesForSegment[segment] = append(tmc.synapsesForSegment[segment], idx)
	tmc.synapsesForSourceCell[sourceCell] = append(tmc.synapsesForSourceCell[sourceCell], idx)

	return data
}

// CreateSegment creates a new segment on the specified cell and returns
// the new segment's index.
func (tmc *TemporalMemoryConnections) CreateSegment(cell int) int {
	idx := len(tmc.segments)
	tmc.segments = append(tmc.segments, cell)
	tmc.segmentsForCell[cell] = append(tmc.segmentsForCell[cell], idx)
	return idx
}

// UpdateSynapsePermanence sets the permanence for an existing synapse,
// panicking if the value lies outside [0, 1].
func (tmc *TemporalMemoryConnections) UpdateSynapsePermanence(synapse int, permanence float64) {
	if permanence < 0 || permanence > 1 {
		panic("invalid permanence value")
	}
	tmc.synapses[synapse].Permanence = permanence
}

// ColumnForCell returns the index of the column that a cell belongs to.
func (tmc *TemporalMemoryConnections) ColumnForCell(cell int) int {
	return cell / tmc.CellsPerColumn
}
116 | func (tmc *TemporalMemoryConnections) CellsForColumn(column int) []int { 117 | start := tmc.CellsPerColumn * column 118 | result := make([]int, tmc.CellsPerColumn) 119 | for idx, _ := range result { 120 | result[idx] = start + idx 121 | } 122 | return result 123 | } 124 | 125 | //Returns the cell that a segment belongs to. 126 | func (tmc *TemporalMemoryConnections) CellForSegment(segment int) int { 127 | return tmc.segments[segment] 128 | } 129 | 130 | //Returns the segments that belong to a cell. 131 | func (tmc *TemporalMemoryConnections) SegmentsForCell(cell int) []int { 132 | return tmc.segmentsForCell[cell] 133 | } 134 | 135 | //Returns synapse data for specified index 136 | func (tmc *TemporalMemoryConnections) DataForSynapse(synapse int) *TmSynapse { 137 | return tmc.synapses[synapse] 138 | } 139 | 140 | //Returns the synapses on a segment. 141 | func (tmc *TemporalMemoryConnections) SynapsesForSegment(segment int) []int { 142 | return tmc.synapsesForSegment[segment] 143 | } 144 | 145 | //Returns the synapses for the source cell that they synapse on. 146 | func (tmc *TemporalMemoryConnections) SynapsesForSourceCell(sourceCell int) []int { 147 | return tmc.synapsesForSourceCell[sourceCell] 148 | } 149 | 150 | // Helpers 151 | 152 | //Returns the number of columns in this layer. 153 | func (tmc *TemporalMemoryConnections) NumberOfColumns() int { 154 | return ints.Prod(tmc.ColumnDimensions) 155 | } 156 | 157 | //Returns the number of cells in this layer. 
158 | func (tmc *TemporalMemoryConnections) NumberOfcells() int { 159 | return tmc.NumberOfColumns() * tmc.CellsPerColumn 160 | } 161 | 162 | //Validation 163 | 164 | func (tmc *TemporalMemoryConnections) validatePermanence(permanence float64) { 165 | if permanence < 0 || permanence > 1 { 166 | panic("invalid permanence value") 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | htm 2 | === 3 | 4 | Hierarchical Temporal Memory Implementation in Golang 5 | 6 | [![GoDoc](https://godoc.org/github.com/nupic-community/htm?status.png)](https://godoc.org/github.com/nupic-community/htm) 7 | [![Build Status](https://travis-ci.org/nupic-community/htm.svg?branch=master)](https://travis-ci.org/nupic-community/htm) 8 | 9 | This is a direct port of the spatial & temporal poolers, temporal memory, and encoders as they currently exist in Numenta's Nupic Project. This project was done as a learning exercise, no effort has been made to optimize this implementation and it was not designed for production use. 10 | 11 | The Nupic project basically demonstrates the CLA, a single stage of the cortical hierarchy. Eventually this same code can be extended to form a full HTM hierarchy. https://github.com/numenta/nupic 12 | 13 | ##Changes From Numentas Implementation 14 | * Temporal pooler ephemeral state is stored in strongly typed struct rather than a hashmap. t-1 vars have "last" appended to their names. 
15 | * Temporal pooler params stored in "params" sub struct 16 | * Binary data structures are used rather than ints 17 | * No C++ dependency everything is written in Go 18 | 19 | ##Current State of Project 20 | * Temporal and Spatial poolers pass basic tests 21 | * Temporal memory passes basic unit tests 22 | * Basic scaler encoder implemented 23 | 24 | ##Todo 25 | ~~* Finish temporal unit tests~~ 26 | * Implement a better sparse binary matrix structure with versions optimized for col or row heavy access. 27 | * Implement better binary datastructure 28 | * Refactor to be more idiomatic Go. It is basically a line for line port of the python implementation, it could be refactored to make better use of Go's type system. 29 | * Implement some of the common encoders 30 | 31 | ##Examples 32 | 33 | ###Temporal Pooler 34 | ```go 35 | package main 36 | 37 | import ( 38 | "fmt" 39 | "github.com/zacg/htm" 40 | "github.com/nupic-community/htmutils" 41 | ) 42 | 43 | func main() { 44 | tps := htm.NewTemporalPoolerParams() 45 | tps.Verbosity = 0 46 | tps.NumberOfCols = 50 47 | tps.CellsPerColumn = 2 48 | tps.ActivationThreshold = 8 49 | tps.MinThreshold = 10 50 | tps.InitialPerm = 0.5 51 | tps.ConnectedPerm = 0.5 52 | tps.NewSynapseCount = 10 53 | tps.PermanenceDec = 0.0 54 | tps.PermanenceInc = 0.1 55 | tps.GlobalDecay = 0 56 | tps.BurnIn = 1 57 | tps.PamLength = 10 58 | tps.CollectStats = true 59 | tp := htm.NewTemporalPooler(*tps) 60 | 61 | //Mock encoding of ABCDE 62 | inputs := make([][]bool, 5) 63 | inputs[0] = boolRange(0, 9, 50) //bits 0-9 are "on" 64 | inputs[1] = boolRange(10, 19, 50) //bits 10-19 are "on" 65 | inputs[2] = boolRange(20, 29, 50) //bits 20-29 are "on" 66 | inputs[3] = boolRange(30, 39, 50) //bits 30-39 are "on" 67 | inputs[4] = boolRange(40, 49, 50) //bits 40-49 are "on" 68 | 69 | //Learn 5 sequences above 70 | for i := 0; i < 10; i++ { 71 | for p := 0; p < 5; p++ { 72 | tp.Compute(inputs[p], true, false) 73 | } 74 | tp.Reset() 75 | } 76 | 77 | //Predict 
sequences 78 | for i := 0; i < 4; i++ { 79 | tp.Compute(inputs[i], false, true) 80 | p := tp.DynamicState.InfPredictedState 81 | 82 | fmt.Printf("Predicted: %v From input: %v \n", p.NonZeroRows(), utils.OnIndices(inputs[i])) 83 | 84 | } 85 | 86 | } 87 | 88 | //helper method for creating boolean sequences 89 | func boolRange(start int, end int, length int) []bool { 90 | result := make([]bool, length) 91 | for i := start; i <= end; i++ { 92 | result[i] = true 93 | } 94 | return result 95 | } 96 | 97 | 98 | ``` 99 | 100 | ###Spatial Pooler 101 | ```go 102 | package main 103 | 104 | import ( 105 | "fmt" 106 | "github.com/davecheney/profile" 107 | "github.com/zacg/htm" 108 | "github.com/nupic-community/htmutils" 109 | "math/rand" 110 | ) 111 | 112 | func main() { 113 | 114 | ssp := htm.NewSpParams() 115 | ssp.ColumnDimensions = []int{64, 64} 116 | ssp.InputDimensions = []int{32, 32} 117 | ssp.PotentialRadius = ssp.NumInputs() 118 | ssp.NumActiveColumnsPerInhArea = int(0.02 * float64(ssp.NumColumns())) 119 | ssp.GlobalInhibition = true 120 | ssp.SynPermActiveInc = 0.01 121 | ssp.SpVerbosity = 10 122 | sp := htm.NewSpatialPooler(ssp) 123 | 124 | 125 | activeArray := make([]bool, sp.NumColumns()) 126 | inputVector := make([]bool, sp.NumInputs()) 127 | 128 | for idx, _ := range inputVector { 129 | inputVector[idx] = rand.Intn(5) >= 2 130 | } 131 | 132 | sp.Compute(inputVector, true, activeArray, sp.InhibitColumns) 133 | 134 | fmt.Println("Active Indices:", utils.OnIndices(activeArray)) 135 | 136 | } 137 | 138 | ``` 139 | 140 | ###Temporal Memory 141 | ```go 142 | 143 | tmp := NewTemporalMemoryParams() 144 | tmp.MaxNewSynapseCount = 1000 145 | 146 | tm := NewTemporalMemory(tmp) 147 | 148 | ``` 149 | 150 | ###Encoding 151 | ```go 152 | 153 | //Create new scaler encoder 154 | p := NewScalerEncoderParams(3, 1, 8) 155 | p.Radius = 1.5 156 | p.Periodic = true 157 | p.Verbosity = 5 158 | e := NewScalerEncoder(p) 159 | 160 | //Encode "1" 161 | encoded := e.Encode(1, false) 162 | 
163 | //Print results
164 | fmt.Printf("1 Encoded as: %v\n", utils.Bool2Int(encoded))
165 | 
166 | ```
167 | 
168 | ```go
169 | 
170 | //Create new date encoder
171 | p := NewDateEncoderParams()
172 | p.SeasonWidth = 3
173 | p.DayOfWeekWidth = 1
174 | p.WeekendWidth = 3
175 | p.TimeOfDayWidth = 5
176 | p.Verbosity = 5
177 | de := NewDateEncoder(p)
178 | 
179 | d := time.Date(2010, 11, 4, 14, 55, 0, 0, time.UTC)
180 | encoded := de.Encode(d)
181 | 
182 | //Print results
183 | fmt.Printf("%v Encoded as: %v\n", d, utils.Bool2Int(encoded))
184 | 
185 | ```
-------------------------------------------------------------------------------- /segmentUpdate.go: --------------------------------------------------------------------------------
1 | package htm
2 | 
3 | import (
4 | 	"fmt"
5 | 	"github.com/nupic-community/htm/utils"
6 | 
7 | 	//"github.com/cznic/mathutil"
8 | 	//"github.com/skelterjohn/go.matrix"
9 | 	//"math"
10 | 	//"math/rand"
11 | 	//"sort"
12 | )
13 | 
14 | type SynapseUpdateState struct {
15 | 	New bool
16 | 	Index int
17 | 	CellIndex int //only set when new
18 | }
19 | 
20 | type SegmentUpdate struct {
21 | 	columnIdx int
22 | 	cellIdx int
23 | 	segment *Segment
24 | 	activeSynapses []SynapseUpdateState
25 | 	sequenceSegment bool
26 | 	phase1Flag bool
27 | 	weaklyPredicting bool
28 | 	lrnIterationIdx int
29 | }
30 | 
31 | type UpdateState struct {
32 | 	//creationdate refers to iteration idx
33 | 	CreationDate int
34 | 	Update *SegmentUpdate
35 | }
36 | 
37 | /*
38 | Store a dated potential segment update. The "date" (iteration index) is used
39 | later to determine whether the update is too old and should be forgotten.
40 | This is controlled by parameter segUpdateValidDuration.
41 | */ 42 | func (tp *TemporalPooler) addToSegmentUpdates(c, i int, segUpdate *SegmentUpdate) { 43 | if segUpdate == nil || len(segUpdate.activeSynapses) == 0 { 44 | return 45 | } 46 | 47 | // key = (column index, cell index in column) 48 | key := utils.TupleInt{} 49 | key.A = c 50 | key.B = i 51 | 52 | newUpdate := UpdateState{tp.lrnIterationIdx, segUpdate} 53 | if tp.segmentUpdates == nil { 54 | tp.segmentUpdates = make(map[utils.TupleInt][]UpdateState, 1000) 55 | } 56 | if _, ok := tp.segmentUpdates[key]; ok { 57 | tp.segmentUpdates[key] = append(tp.segmentUpdates[key], newUpdate) 58 | } else { 59 | tp.segmentUpdates[key] = []UpdateState{newUpdate} 60 | } 61 | 62 | } 63 | 64 | /* 65 | This function applies segment update information to a segment in a 66 | cell. 67 | 68 | Synapses on the active list get their permanence counts incremented by 69 | permanenceInc. All other synapses get their permanence counts decremented 70 | by permanenceDec. 71 | 72 | We also increment the positiveActivations count of the segment. 73 | 74 | param segUpdate SegmentUpdate instance 75 | returns True if some synapses were decremented to 0 and the segment is a 76 | candidate for trimming 77 | */ 78 | func (segUpdate *SegmentUpdate) adaptSegments(tp *TemporalPooler) bool { 79 | // This will be set to True if detect that any syapses were decremented to 0 80 | trimSegment := false 81 | 82 | // segUpdate.segment is None when creating a new segment 83 | //c, i, segment := segUpdate.columnIdx, segUpdate.cellIdx, segUpdate.segment 84 | c := segUpdate.columnIdx 85 | i := segUpdate.cellIdx 86 | segment := segUpdate.segment 87 | 88 | // update.activeSynapses can be empty. 89 | // If not, it can contain either or both integers and tuples. 90 | // The integers are indices of synapses to update. 91 | // The tuples represent new synapses to create (src col, src cell in col). 92 | // We pre-process to separate these various element types. 
93 | // synToCreate is not empty only if positiveReinforcement is True. 94 | // NOTE: the synapse indices start at *1* to skip the segment flags. 95 | activeSynapses := segUpdate.activeSynapses 96 | 97 | var synToUpdate []int 98 | for _, val := range activeSynapses { 99 | if !val.New { 100 | synToUpdate = append(synToUpdate, val.Index) 101 | } 102 | } 103 | 104 | //fmt.Printf("Entering adapt seg %v %v \n", segment, len(activeSynapses)) 105 | 106 | if segment != nil { 107 | 108 | if tp.params.Verbosity >= 4 { 109 | fmt.Printf("Reinforcing segment #%v for cell[%v,%v] \n", segment.segId, c, i) 110 | } 111 | 112 | //modify existing segment 113 | // Mark it as recently useful 114 | segment.lastActiveIteration = tp.lrnIterationIdx 115 | 116 | // Update frequency and positiveActivations 117 | segment.positiveActivations++ 118 | segment.dutyCycle(true, false) 119 | 120 | // First, decrement synapses that are not active 121 | lastSynIndex := len(segment.syns) - 1 122 | 123 | var inactiveSynIndices []int 124 | for i := 0; i < lastSynIndex+1; i++ { 125 | if !utils.ContainsInt(i, synToUpdate) { 126 | inactiveSynIndices = append(inactiveSynIndices, i) 127 | } 128 | } 129 | 130 | trimSegment = segment.updateSynapses(inactiveSynIndices, -tp.params.PermanenceDec) 131 | 132 | // Now, increment active synapses 133 | var activeSynIndices []int 134 | for _, val := range activeSynapses { 135 | if val.Index <= lastSynIndex { 136 | activeSynIndices = append(activeSynIndices, val.Index) 137 | } 138 | } 139 | 140 | segment.updateSynapses(activeSynIndices, tp.params.PermanenceInc) 141 | 142 | // Finally, create new synapses if needed 143 | var synsToAdd []SynapseUpdateState 144 | for _, val := range activeSynapses { 145 | if val.New { 146 | synsToAdd = append(synsToAdd, val) 147 | } 148 | } 149 | 150 | // If we have fixed resources, get rid of some old syns if necessary 151 | if tp.params.MaxSynapsesPerSegment > 0 && len(synsToAdd)+len(segment.syns) > tp.params.MaxSynapsesPerSegment { 152 | 
numToFree := (len(segment.syns) + len(synsToAdd)) - tp.params.MaxSynapsesPerSegment 153 | segment.freeNSynapses(numToFree, inactiveSynIndices) 154 | } 155 | 156 | for _, val := range synsToAdd { 157 | segment.AddSynapse(val.Index, val.CellIndex, tp.params.InitialPerm) 158 | } 159 | 160 | } else { 161 | //create new segment 162 | newSegment := NewSegment(tp, segUpdate.sequenceSegment) 163 | 164 | for _, val := range activeSynapses { 165 | newSegment.AddSynapse(val.Index, val.CellIndex, tp.params.InitialPerm) 166 | } 167 | 168 | if tp.params.Verbosity >= 3 { 169 | fmt.Printf("New segment #%v for cell[%v,%v] \n", tp.segId, c, i) 170 | fmt.Print(newSegment.ToString()) 171 | } 172 | 173 | tp.cells[c][i] = append(tp.cells[c][i], *newSegment) 174 | } 175 | 176 | return trimSegment 177 | } 178 | -------------------------------------------------------------------------------- /trivialPredictor.go: -------------------------------------------------------------------------------- 1 | package htm 2 | 3 | import ( 4 | "fmt" 5 | "github.com/cznic/mathutil" 6 | "github.com/nupic-community/htm/utils" 7 | //"github.com/skelterjohn/go.matrix" 8 | //"math" 9 | //"math/rand" 10 | //"sort" 11 | //"github.com/gonum/floats" 12 | "github.com/zacg/ints" 13 | ) 14 | 15 | /* 16 | (n = half the number of average input columns on) 17 | "random" - predict n random columns 18 | "zeroth" - predict the n most common columns learned from the input 19 | "last" - predict the last input 20 | "all" - predict all columns 21 | "lots" - predict the 2n most common columns learned from the input 22 | 23 | Both "random" and "all" should give a prediction score of zero" 24 | */ 25 | 26 | type PredictorMethod int 27 | 28 | const ( 29 | Random PredictorMethod = 1 30 | Zeroth PredictorMethod = 2 31 | Last PredictorMethod = 3 32 | All PredictorMethod = 4 33 | Lots PredictorMethod = 5 34 | ) 35 | 36 | type TrivialPredictorState struct { 37 | ActiveState []bool 38 | ActiveStateLast []bool 39 | PredictedState []bool 
40 | PredictedStateLast []bool 41 | Confidence []float64 42 | ConfidenceLast []float64 43 | } 44 | 45 | type TrivialPredictor struct { 46 | NumOfCols int 47 | Methods []PredictorMethod 48 | Verbosity int 49 | InternalStats map[PredictorMethod]*TpStats 50 | State map[PredictorMethod]TrivialPredictorState 51 | ColumnCount []int 52 | AverageDensity float64 53 | } 54 | 55 | func MakeTrivialPredictor(numberOfCols int, methods []PredictorMethod) *TrivialPredictor { 56 | tp := new(TrivialPredictor) 57 | 58 | for _, method := range methods { 59 | tps := TrivialPredictorState{} 60 | tps.ActiveState = make([]bool, numberOfCols) 61 | tps.ActiveStateLast = make([]bool, numberOfCols) 62 | tps.Confidence = make([]float64, numberOfCols) 63 | tps.ConfidenceLast = make([]float64, numberOfCols) 64 | tps.PredictedState = make([]bool, numberOfCols) 65 | tps.PredictedStateLast = make([]bool, numberOfCols) 66 | tp.State[method] = tps 67 | 68 | tp.InternalStats[method] = new(TpStats) 69 | } 70 | 71 | // Number of times each column has been active during learning 72 | tp.ColumnCount = make([]int, numberOfCols) 73 | 74 | // Running average of input density 75 | tp.AverageDensity = 0.05 76 | 77 | return tp 78 | } 79 | 80 | /* 81 | 82 | */ 83 | 84 | func (tp *TrivialPredictor) infer(activeColumns []int) { 85 | 86 | numColsToPredict := int(0.5 + tp.AverageDensity*float64(tp.NumOfCols)) 87 | 88 | //for method in self.methods: 89 | for _, method := range tp.Methods { 90 | // Copy t-1 into t 91 | copy(tp.State[method].ActiveStateLast, tp.State[method].ActiveState) 92 | copy(tp.State[method].PredictedStateLast, tp.State[method].PredictedState) 93 | copy(tp.State[method].ConfidenceLast, tp.State[method].Confidence) 94 | 95 | utils.FillSliceBool(tp.State[method].ActiveState, false) 96 | utils.FillSliceBool(tp.State[method].PredictedState, false) 97 | utils.FillSliceFloat64(tp.State[method].Confidence, 0.0) 98 | 99 | for _, val := range activeColumns { 100 | tp.State[method].ActiveState[val] = true 
101 | } 102 | 103 | var predictedCols []int 104 | 105 | switch method { 106 | case Random: 107 | // Randomly predict N columns 108 | //predictedCols = RandomInts(numColsToPredict, tp.NumOfCols) 109 | break 110 | case Zeroth: 111 | // Always predict the top N most frequent columns 112 | var inds []int 113 | ints.Argsort(tp.ColumnCount, inds) 114 | predictedCols = inds[len(inds)-numColsToPredict:] 115 | break 116 | case Last: 117 | // Always predict the last input 118 | for idx, val := range tp.State[method].ActiveState { 119 | if val { 120 | predictedCols = append(predictedCols, idx) 121 | } 122 | } 123 | break 124 | case All: 125 | // Always predict all columns 126 | for i := 0; i < tp.NumOfCols; i++ { 127 | predictedCols = append(predictedCols, i) 128 | } 129 | break 130 | case Lots: 131 | // Always predict 2 * the top N most frequent columns 132 | numColsToPredict := mathutil.Min(2*numColsToPredict, tp.NumOfCols) 133 | var inds []int 134 | ints.Argsort(tp.ColumnCount, inds) 135 | predictedCols = inds[len(inds)-numColsToPredict:] 136 | 137 | break 138 | default: 139 | panic("prediction method not implemented") 140 | } 141 | 142 | for _, val := range predictedCols { 143 | tp.State[method].PredictedState[val] = true 144 | tp.State[method].Confidence[val] = 1.0 145 | } 146 | 147 | if tp.Verbosity > 1 { 148 | fmt.Println("Random prediction:", method) 149 | fmt.Println(" numColsToPredict:", numColsToPredict) 150 | fmt.Println(predictedCols) 151 | } 152 | 153 | } 154 | 155 | } 156 | 157 | /* 158 | Do one iteration of the temporal pooler learning. 
159 | Returns TP output 160 | */ 161 | 162 | func (tp *TrivialPredictor) learn(activeColumns []int) { 163 | // Running average of bottom up density 164 | density := float64(len(activeColumns)) / float64(tp.NumOfCols) 165 | 166 | tp.AverageDensity = 0.95*tp.AverageDensity + 0.05*density 167 | 168 | // Running count of how often each column has been active 169 | for _, val := range activeColumns { 170 | tp.ColumnCount[val]++ 171 | } 172 | 173 | // Do "inference" 174 | tp.infer(activeColumns) 175 | } 176 | 177 | /* 178 | Reset the state of all cells. 179 | This is normally used between sequences while training. All internal states 180 | are reset to 0. 181 | */ 182 | 183 | func (tp *TrivialPredictor) reset() { 184 | 185 | for _, method := range tp.Methods { 186 | 187 | utils.FillSliceBool(tp.State[method].ActiveState, false) 188 | utils.FillSliceBool(tp.State[method].ActiveStateLast, false) 189 | utils.FillSliceBool(tp.State[method].PredictedState, false) 190 | utils.FillSliceBool(tp.State[method].PredictedStateLast, false) 191 | utils.FillSliceFloat64(tp.State[method].Confidence, 0.0) 192 | utils.FillSliceFloat64(tp.State[method].ConfidenceLast, 0.0) 193 | 194 | stats := tp.InternalStats[method] 195 | stats.NInfersSinceReset = 0 196 | stats.CurPredictionScore = 0.0 197 | stats.CurPredictionScore2 = 0.0 198 | stats.FalseNegativeScoreTotal = 0.0 199 | stats.FalsePositiveScoreTotal = 0.0 200 | stats.CurExtra = 0.0 201 | stats.CurMissing = 0.0 202 | tp.InternalStats[method] = stats 203 | } 204 | 205 | } 206 | 207 | /* 208 | Reset the learning and inference stats. This will usually be called by 209 | user code at the start of each inference run (for a particular data set). 
210 | */ 211 | 212 | func (tp *TrivialPredictor) resetStats() { 213 | 214 | tp.reset() 215 | 216 | //Additionally, reset all of the "total" values 217 | for _, method := range tp.Methods { 218 | 219 | stats := tp.InternalStats[method] 220 | stats.NInfersSinceReset = 0 221 | stats.NPredictions = 0 222 | stats.PredictionScoreTotal = 0 223 | stats.PredictionScoreTotal2 = 0 224 | stats.FalseNegativeScoreTotal = 0 225 | stats.FalsePositiveScoreTotal = 0 226 | stats.PctExtraTotal = 0.0 227 | stats.PctMissingTotal = 0.0 228 | stats.TotalMissing = 0.0 229 | stats.TotalExtra = 0.0 230 | tp.InternalStats[method] = stats 231 | } 232 | } 233 | -------------------------------------------------------------------------------- /denseBinaryMatrix.go: -------------------------------------------------------------------------------- 1 | package htm 2 | 3 | import ( 4 | "bytes" 5 | "github.com/nupic-community/htm/utils" 6 | //"math" 7 | ) 8 | 9 | //Sparse binary matrix stores indexes of non-zero entries in matrix 10 | //to conserve space 11 | type DenseBinaryMatrix struct { 12 | Width int 13 | Height int 14 | entries []bool 15 | } 16 | 17 | //Create new sparse binary matrix of specified size 18 | func NewDenseBinaryMatrix(height, width int) *DenseBinaryMatrix { 19 | m := &DenseBinaryMatrix{} 20 | m.Height = height 21 | m.Width = width 22 | m.entries = make([]bool, width*height) 23 | return m 24 | } 25 | 26 | //Create sparse binary matrix from specified dense matrix 27 | func NewDenseBinaryMatrixFromDense(values [][]bool) *DenseBinaryMatrix { 28 | if len(values) < 1 { 29 | panic("No values specified.") 30 | } 31 | 32 | m := NewDenseBinaryMatrix(len(values), len(values[0])) 33 | for r := 0; r < m.Height; r++ { 34 | m.SetRowFromDense(r, values[r]) 35 | } 36 | return m 37 | } 38 | 39 | //Create sparse binary matrix from specified dense matrix 40 | func NewDenseBinaryMatrixFromDense1D(values []bool, rows, cols int) *DenseBinaryMatrix { 41 | if len(values) < 1 { 42 | panic("No values 
specified.")
	}
	if len(values) != rows*cols {
		panic("Invalid size")
	}

	m := NewDenseBinaryMatrix(rows, cols)

	for r := 0; r < m.Height; r++ {
		m.SetRowFromDense(r, values[r*cols:(r*cols)+cols])
	}

	return m
}

// Creates a dense binary matrix from specified integer array
// (any values greater than 0 are true)
func NewDenseBinaryMatrixFromInts(values [][]int) *DenseBinaryMatrix {
	if len(values) < 1 {
		panic("No values specified.")
	}

	m := NewDenseBinaryMatrix(len(values), len(values[0]))

	for r := 0; r < m.Height; r++ {
		for c := 0; c < m.Width; c++ {
			if values[r][c] > 0 {
				m.Set(r, c, true)
			}
		}
	}

	return m
}

//Converts a flat entries index to its row/col position
func (sm *DenseBinaryMatrix) toIndex(index int) (row int, col int) {
	row = index / sm.Width
	col = index % sm.Width
	return
}

//Returns all true/on indices
func (sm *DenseBinaryMatrix) Entries() []SparseEntry {
	result := make([]SparseEntry, 0, int(float64(len(sm.entries))*0.3))
	for idx, val := range sm.entries {
		if val {
			i, j := sm.toIndex(idx)
			result = append(result, SparseEntry{i, j})
		}
	}
	return result
}

//Returns a flattened dense representation
func (sm *DenseBinaryMatrix) Flatten() []bool {
	result := make([]bool, sm.Height*sm.Width)
	for _, val := range sm.Entries() {
		result[(val.Row*sm.Width)+val.Col] = true
	}
	return result
}

//Get value at row,col position. Indices wrap around the matrix
//dimensions; negative indices count back from the end.
func (sm *DenseBinaryMatrix) Get(row int, col int) bool {
	row = row % sm.Height
	if row < 0 {
		// BUG FIX: was sm.Height - row. In Go, % keeps the sign of the
		// dividend, so row is negative here; subtracting it produced an
		// index >= Height (out of range). Adding the negative remainder
		// wraps correctly into [0, Height).
		row = sm.Height + row
	}
	col = col % sm.Width
	if col < 0 {
		// BUG FIX: was sm.Width - col (out of range for negative col).
		col = sm.Width + col
	}

	return sm.entries[row*sm.Width+col]
}

//Set value at row,col position. Indices wrap around the matrix
//dimensions; negative indices count back from the end.
func (sm *DenseBinaryMatrix) Set(row int, col int, value bool) {
	row = row % sm.Height
	if row < 0 {
		// BUG FIX: was sm.Height - row (see Get for details).
		row = sm.Height + row
	}
	col = col % sm.Width
	if col < 0 {
		// BUG FIX: was sm.Width - col (see Get for details).
		col = sm.Width + col
	}
	sm.entries[row*sm.Width+col] = value
}

//Replaces specified row with values, assumes values is ordered
//correctly
func (sm *DenseBinaryMatrix) ReplaceRow(row int, values []bool) {
	sm.validateRowCol(row, len(values))

	for i := 0; i < sm.Width; i++ {
		sm.Set(row, i, values[i])
	}
}

//Replaces row with true values at specified indices
func (sm *DenseBinaryMatrix) ReplaceRowByIndices(row int, indices []int) {
	sm.validateRow(row)
	start := row * sm.Width
	for i := 0; i < sm.Width; i++ {
		sm.entries[start+i] = utils.ContainsInt(i, indices)
	}
}

//Returns dense row
func (sm *DenseBinaryMatrix) GetDenseRow(row int) []bool {
	sm.validateRow(row)
	result := make([]bool, sm.Width)

	start := row * sm.Width
	for i := 0; i < sm.Width; i++ {
		result[i] = sm.entries[start+i]
	}

	return result
}

//Returns a rows "on" indices
func (sm *DenseBinaryMatrix) GetRowIndices(row int) []int {
	result := make([]int, 0, sm.Width)
	start := row * sm.Width
	for i := 0; i < sm.Width; i++ {
		if sm.entries[start+i] {
			result = append(result, i)
		}
	}
	return result
}

//Sets a sparse row from dense representation
func (sm *DenseBinaryMatrix) SetRowFromDense(row int, denseRow []bool) {
	//TODO: speed this up
	sm.validateRowCol(row, len(denseRow))
	for i := 0; i < sm.Width; i++ {
		sm.Set(row, i, denseRow[i])
	}
}

//In a normal matrix this would be multiplication in binary terms
//we just and then sum the true entries
func (sm *DenseBinaryMatrix) RowAndSum(row []bool) []int {
sm.validateCol(len(row)) 189 | result := make([]int, sm.Height) 190 | 191 | for idx, val := range sm.entries { 192 | if val { 193 | r, c := sm.toIndex(idx) 194 | if row[c] { 195 | result[r]++ 196 | } 197 | } 198 | } 199 | 200 | return result 201 | } 202 | 203 | //Returns row indexes with at least 1 true column 204 | func (sm *DenseBinaryMatrix) NonZeroRows() []int { 205 | counts := make(map[int]int, sm.Height) 206 | 207 | for idx, val := range sm.entries { 208 | if val { 209 | r, _ := sm.toIndex(idx) 210 | counts[r]++ 211 | 212 | } 213 | } 214 | 215 | result := make([]int, 0, sm.Height) 216 | for k, v := range counts { 217 | if v > 0 && !utils.ContainsInt(k, result) { 218 | result = append(result, k) 219 | } 220 | } 221 | return result 222 | } 223 | 224 | //Returns # of rows with at least 1 true value 225 | func (sm *DenseBinaryMatrix) TotalTrueRows() int { 226 | return len(sm.NonZeroRows()) 227 | } 228 | 229 | //Returns total true entries 230 | func (sm *DenseBinaryMatrix) TotalNonZeroCount() int { 231 | return len(sm.Entries()) 232 | } 233 | 234 | // Ors 2 matrices 235 | func (sm *DenseBinaryMatrix) Or(sm2 *DenseBinaryMatrix) *DenseBinaryMatrix { 236 | result := NewDenseBinaryMatrix(sm.Height, sm.Width) 237 | 238 | for _, val := range sm.Entries() { 239 | result.Set(val.Row, val.Col, true) 240 | } 241 | 242 | for _, val := range sm2.Entries() { 243 | result.Set(val.Row, val.Col, true) 244 | } 245 | 246 | return result 247 | } 248 | 249 | //Clears all entries 250 | func (sm *DenseBinaryMatrix) Clear() { 251 | utils.FillSliceBool(sm.entries, false) 252 | } 253 | 254 | //Fills specified row with specified value 255 | func (sm *DenseBinaryMatrix) FillRow(row int, val bool) { 256 | for j := 0; j < sm.Width; j++ { 257 | sm.Set(row, j, val) 258 | } 259 | } 260 | 261 | //Copys a matrix 262 | func (sm *DenseBinaryMatrix) Copy() *DenseBinaryMatrix { 263 | if sm == nil { 264 | return nil 265 | } 266 | 267 | result := new(DenseBinaryMatrix) 268 | result.Width = sm.Width 269 
| result.Height = sm.Height 270 | result.entries = make([]bool, len(sm.entries)) 271 | for idx, val := range sm.entries { 272 | result.entries[idx] = val 273 | } 274 | 275 | return result 276 | } 277 | 278 | func (sm *DenseBinaryMatrix) ToString() string { 279 | var buffer bytes.Buffer 280 | 281 | for r := 0; r < sm.Height; r++ { 282 | for c := 0; c < sm.Width; c++ { 283 | if sm.Get(r, c) { 284 | buffer.WriteByte('1') 285 | } else { 286 | buffer.WriteByte('0') 287 | } 288 | } 289 | buffer.WriteByte('\n') 290 | } 291 | 292 | return buffer.String() 293 | } 294 | 295 | func (sm *DenseBinaryMatrix) validateCol(col int) { 296 | if col > sm.Width { 297 | panic("Specified row is wider than matrix.") 298 | } 299 | } 300 | 301 | func (sm *DenseBinaryMatrix) validateRow(row int) { 302 | if row > sm.Height { 303 | panic("Specified row is out of bounds.") 304 | } 305 | } 306 | 307 | func (sm *DenseBinaryMatrix) validateRowCol(row int, col int) { 308 | sm.validateCol(col) 309 | sm.validateRow(row) 310 | } 311 | -------------------------------------------------------------------------------- /encoders/dateEncoder.go: -------------------------------------------------------------------------------- 1 | package encoders 2 | 3 | import ( 4 | "fmt" 5 | //"github.com/cznic/mathutil" 6 | //"github.com/zacg/floats" 7 | //"github.com/nupic-community/htm" 8 | "github.com/nupic-community/htm/utils" 9 | //"github.com/zacg/ints" 10 | //"math" 11 | "time" 12 | ) 13 | 14 | /* 15 | Params for the date encoder 16 | */ 17 | type DateEncoderParams struct { 18 | HolidayWidth int 19 | HolidayRadius float64 20 | SeasonWidth int 21 | SeasonRadius float64 22 | DayOfWeekWidth int 23 | DayOfWeekRadius float64 24 | WeekendWidth int 25 | WeekendRadius float64 26 | TimeOfDayWidth int 27 | TimeOfDayRadius float64 28 | //CustomDays int 29 | Name string 30 | //list of holidays stored as {mm,dd} 31 | Holidays []utils.TupleInt 32 | } 33 | 34 | func NewDateEncoderParams() *DateEncoderParams { 35 | p := 
new(DateEncoderParams) 36 | 37 | //set defaults 38 | p.SeasonWidth = 3 39 | p.DayOfWeekWidth = 1 40 | p.WeekendWidth = 3 41 | p.TimeOfDayWidth = 5 42 | 43 | p.SeasonRadius = 91.5 //days 44 | p.DayOfWeekRadius = 1 45 | p.TimeOfDayRadius = 4 46 | p.WeekendRadius = 1 47 | p.HolidayRadius = 1 48 | 49 | p.Holidays = []utils.TupleInt{{12, 25}} 50 | 51 | return p 52 | } 53 | 54 | /* 55 | Date encoder encodes a datetime to a SDR. Params allow for tuning 56 | for specific date attributes 57 | */ 58 | type DateEncoder struct { 59 | DateEncoderParams 60 | seasonEncoder *ScalerEncoder 61 | holidayEncoder *ScalerEncoder 62 | dayOfWeekEncoder *ScalerEncoder 63 | weekendEncoder *ScalerEncoder 64 | timeOfDayEncoder *ScalerEncoder 65 | 66 | width int 67 | seasonOffset int 68 | weekendOffset int 69 | dayOfWeekOffset int 70 | holidayOffset int 71 | timeOfDayOffset int 72 | } 73 | 74 | /* 75 | Intializes a new date encoder 76 | */ 77 | func NewDateEncoder(params *DateEncoderParams) *DateEncoder { 78 | de := new(DateEncoder) 79 | 80 | de.DateEncoderParams = *params 81 | 82 | de.width = 0 83 | 84 | if params.SeasonWidth != 0 { 85 | // Ignore leapyear differences -- assume 366 days in a year 86 | // Radius = 91.5 days = length of season 87 | // Value is number of days since beginning of year (0 - 355) 88 | 89 | sep := NewScalerEncoderParams(params.SeasonWidth, 0, 366) 90 | sep.Name = "Season" 91 | sep.Periodic = true 92 | sep.Radius = de.SeasonRadius 93 | de.seasonEncoder = NewScalerEncoder(sep) 94 | de.seasonOffset = de.width 95 | de.width += de.seasonEncoder.N 96 | } 97 | 98 | if params.DayOfWeekWidth != 0 { 99 | // Value is day of week (floating point) 100 | // Radius is 1 day 101 | 102 | sep := NewScalerEncoderParams(params.DayOfWeekWidth, 0, 7) 103 | sep.Name = "day of week" 104 | sep.Radius = de.DayOfWeekRadius 105 | sep.Periodic = true 106 | de.dayOfWeekEncoder = NewScalerEncoder(sep) 107 | de.dayOfWeekOffset = de.width 108 | de.width += de.dayOfWeekEncoder.N 109 | } 110 | 111 | 
if params.WeekendWidth != 0 { 112 | // Binary value. Not sure if this makes sense. Also is somewhat redundant 113 | // with dayOfWeek 114 | //Append radius if it was not provided 115 | 116 | sep := NewScalerEncoderParams(params.WeekendWidth, 0, 1) 117 | sep.Name = "weekend" 118 | sep.Radius = params.WeekendRadius 119 | de.weekendEncoder = NewScalerEncoder(sep) 120 | de.weekendOffset = de.width 121 | de.width += de.weekendEncoder.N 122 | } 123 | 124 | if params.HolidayWidth > 0 { 125 | // A "continuous" binary value. = 1 on the holiday itself and smooth ramp 126 | // 0->1 on the day before the holiday and 1->0 on the day after the holiday. 127 | 128 | sep := NewScalerEncoderParams(params.HolidayWidth, 0, 1) 129 | sep.Name = "holiday" 130 | sep.Radius = params.HolidayRadius 131 | de.holidayEncoder = NewScalerEncoder(sep) 132 | de.holidayOffset = de.width 133 | de.width += de.holidayEncoder.N 134 | } 135 | 136 | if params.TimeOfDayWidth > 0 { 137 | // Value is time of day in hours 138 | // Radius = 4 hours, e.g. morning, afternoon, evening, early night, 139 | // late night, etc. 
140 | 141 | sep := NewScalerEncoderParams(params.TimeOfDayWidth, 0, 24) 142 | sep.Name = "time of day" 143 | sep.Radius = params.TimeOfDayRadius 144 | sep.Periodic = true 145 | de.timeOfDayEncoder = NewScalerEncoder(sep) 146 | de.timeOfDayOffset = de.width 147 | de.width += de.timeOfDayEncoder.N 148 | 149 | } 150 | 151 | return de 152 | } 153 | 154 | /* 155 | get season scaler from time 156 | */ 157 | func (de *DateEncoder) getSeasonScaler(date time.Time) float64 { 158 | if de.seasonEncoder == nil { 159 | return 0.0 160 | } 161 | 162 | //make year 0 based 163 | dayOfYear := float64(date.YearDay() - 1) 164 | return dayOfYear 165 | 166 | } 167 | 168 | /* 169 | get day of week scaler from time 170 | */ 171 | func (de *DateEncoder) getDayOfWeekScaler(date time.Time) float64 { 172 | if de.dayOfWeekEncoder == nil { 173 | return 0.0 174 | } 175 | return float64(date.Weekday()) 176 | } 177 | 178 | /* 179 | get weekend scaler from time 180 | */ 181 | func (de *DateEncoder) getWeekendScaler(date time.Time) float64 { 182 | if de.weekendEncoder == nil { 183 | return 0.0 184 | } 185 | dayOfWeek := date.Weekday() 186 | timeOfDay := date.Hour() + date.Minute()/60.0 187 | 188 | // saturday, sunday or friday evening 189 | weekend := 0.0 190 | if dayOfWeek == time.Saturday || 191 | dayOfWeek == time.Sunday || 192 | (dayOfWeek == time.Friday && timeOfDay > 18) { 193 | weekend = 1.0 194 | } 195 | return weekend 196 | } 197 | 198 | /* 199 | get holiday scaler from time 200 | */ 201 | func (de *DateEncoder) getHolidayScaler(date time.Time) float64 { 202 | if de.holidayEncoder == nil { 203 | return 0.0 204 | } 205 | // A "continuous" binary value. = 1 on the holiday itself and smooth ramp 206 | // 0->1 on the day before the holiday and 1->0 on the day after the holiday. 
207 | // Currently the only holiday we know about is December 25 208 | // holidays is a list of holidays that occur on a fixed date every year 209 | val := 0.0 210 | 211 | for _, h := range de.Holidays { 212 | // hdate is midnight on the holiday 213 | hDate := time.Date(date.Year(), time.Month(h.A), h.B, 0, 0, 0, 0, time.UTC) 214 | if date.After(hDate) { 215 | diff := date.Sub(hDate) 216 | if (diff/time.Hour)/24 == 0 { 217 | val = 1 218 | break 219 | } else if (diff/time.Hour)/24 == 1 { 220 | // ramp smoothly from 1 -> 0 on the next day 221 | val = 1.0 - (float64(diff/time.Second) / (86400)) 222 | break 223 | } 224 | } else { 225 | diff := hDate.Sub(date) 226 | if (diff/time.Hour)/24 == 1 { 227 | // ramp smoothly from 0 -> 1 on the previous day 228 | val = 1.0 - (float64(diff/time.Second) / 86400) 229 | } 230 | 231 | } 232 | } 233 | 234 | return val 235 | 236 | } 237 | 238 | /* 239 | 240 | */ 241 | func (de *DateEncoder) getTimeOfDayScaler(date time.Time) float64 { 242 | if de.timeOfDayEncoder == nil { 243 | return 0.0 244 | } 245 | return float64(date.Hour()) + (float64(date.Minute()) / 60.0) 246 | 247 | } 248 | 249 | /* 250 | Encodes input to specifed slice 251 | */ 252 | func (de *DateEncoder) EncodeToSlice(date time.Time, output []bool) { 253 | 254 | learn := false 255 | 256 | // Get a scaler value for each subfield and encode it with the 257 | // appropriate encoder 258 | if de.seasonEncoder != nil { 259 | val := de.getSeasonScaler(date) 260 | de.seasonEncoder.EncodeToSlice(val, learn, output[de.seasonOffset:]) 261 | } 262 | 263 | if de.holidayEncoder != nil { 264 | val := de.getHolidayScaler(date) 265 | de.holidayEncoder.EncodeToSlice(val, learn, output[de.holidayOffset:]) 266 | } 267 | 268 | if de.dayOfWeekEncoder != nil { 269 | val := de.getDayOfWeekScaler(date) 270 | de.dayOfWeekEncoder.EncodeToSlice(val, learn, output[de.dayOfWeekOffset:]) 271 | } 272 | 273 | if de.weekendEncoder != nil { 274 | val := de.getWeekendScaler(date) 275 | 
//SparseEntry is the position of a single non-zero value in the matrix.
type SparseEntry struct {
	Row int
	Col int
}

//SparseBinaryMatrix stores the indexes of non-zero entries
//to conserve space.
type SparseBinaryMatrix struct {
	Width   int
	Height  int
	entries []SparseEntry
}

//NewSparseBinaryMatrix creates a new, empty sparse binary matrix of the
//specified size.
func NewSparseBinaryMatrix(height, width int) *SparseBinaryMatrix {
	m := &SparseBinaryMatrix{}
	m.Height = height
	m.Width = width
	return m
}

//NewSparseBinaryMatrixFromDense creates a sparse binary matrix from the
//specified dense matrix. Panics if values is empty. The width is taken
//from the first row; every row must have at least that many columns.
func NewSparseBinaryMatrixFromDense(values [][]bool) *SparseBinaryMatrix {
	if len(values) < 1 {
		panic("No values specified.")
	}
	m := &SparseBinaryMatrix{}
	m.Height = len(values)
	m.Width = len(values[0])

	for r := 0; r < m.Height; r++ {
		m.SetRowFromDense(r, values[r])
	}

	return m
}

//NewSparseBinaryMatrixFromDense1D creates a sparse binary matrix from a
//flattened (row-major) dense matrix. Panics if values is empty or its
//length does not equal rows*cols.
func NewSparseBinaryMatrixFromDense1D(values []bool, rows, cols int) *SparseBinaryMatrix {
	if len(values) < 1 {
		panic("No values specified.")
	}
	if len(values) != rows*cols {
		panic("Invalid size")
	}

	m := new(SparseBinaryMatrix)
	m.Height = rows
	m.Width = cols

	for r := 0; r < m.Height; r++ {
		m.SetRowFromDense(r, values[r*cols:(r*cols)+cols])
	}

	return m
}

//NewSparseBinaryMatrixFromInts creates a sparse binary matrix from the
//specified integer matrix; any value greater than 0 is treated as true.
func NewSparseBinaryMatrixFromInts(values [][]int) *SparseBinaryMatrix {
	if len(values) < 1 {
		panic("No values specified.")
	}

	m := &SparseBinaryMatrix{}
	m.Height = len(values)
	m.Width = len(values[0])

	for r := 0; r < m.Height; r++ {
		for c := 0; c < m.Width; c++ {
			if values[r][c] > 0 {
				m.Set(r, c, true)
			}
		}
	}

	return m
}

//Entries returns all true/on positions.
func (sm *SparseBinaryMatrix) Entries() []SparseEntry {
	return sm.entries
}

//Flatten returns a flattened (row-major) dense representation.
func (sm *SparseBinaryMatrix) Flatten() []bool {
	result := make([]bool, sm.Height*sm.Width)
	for _, entry := range sm.entries {
		result[(entry.Row*sm.Width)+entry.Col] = true
	}
	return result
}

//Get returns the value at the row,col position.
//Runs in O(n) in the number of non-zero entries.
func (sm *SparseBinaryMatrix) Get(row int, col int) bool {
	for _, entry := range sm.entries {
		if entry.Row == row && entry.Col == col {
			return true
		}
	}
	return false
}

//delete removes the entry at row,col if present; no-op otherwise.
func (sm *SparseBinaryMatrix) delete(row int, col int) {
	for idx, entry := range sm.entries {
		if entry.Row == row && entry.Col == col {
			sm.entries = append(sm.entries[:idx], sm.entries[idx+1:]...)
			break
		}
	}
}

//Set sets the value at the row,col position.
func (sm *SparseBinaryMatrix) Set(row int, col int, value bool) {
	if !value {
		sm.delete(row, col)
		return
	}

	// Avoid storing duplicate entries for the same position.
	if sm.Get(row, col) {
		return
	}

	sm.entries = append(sm.entries, SparseEntry{Row: row, Col: col})
}

//ReplaceRow replaces the specified row with values; values must be
//ordered by column and contain at least Width elements.
func (sm *SparseBinaryMatrix) ReplaceRow(row int, values []bool) {
	sm.validateRowCol(row, len(values))

	for i := 0; i < sm.Width; i++ {
		sm.Set(row, i, values[i])
	}
}

//ReplaceRowByIndices replaces the row with true values at the specified
//column indices and false everywhere else.
func (sm *SparseBinaryMatrix) ReplaceRowByIndices(row int, indices []int) {
	sm.validateRow(row)

	for i := 0; i < sm.Width; i++ {
		val := false
		for x := 0; x < len(indices); x++ {
			if i == indices[x] {
				val = true
				break
			}
		}
		sm.Set(row, i, val)
	}
}

//GetDenseRow returns a dense representation of the specified row.
func (sm *SparseBinaryMatrix) GetDenseRow(row int) []bool {
	sm.validateRow(row)
	result := make([]bool, sm.Width)

	for i := 0; i < len(sm.entries); i++ {
		if sm.entries[i].Row == row {
			result[sm.entries[i].Col] = true
		}
	}

	return result
}

//GetRowIndices returns a row's "on" column indices.
func (sm *SparseBinaryMatrix) GetRowIndices(row int) []int {
	result := []int{}
	for i := 0; i < len(sm.entries); i++ {
		if sm.entries[i].Row == row {
			result = append(result, sm.entries[i].Col)
		}
	}
	return result
}

//SetRowFromDense sets a sparse row from a dense representation.
func (sm *SparseBinaryMatrix) SetRowFromDense(row int, denseRow []bool) {
	sm.validateRowCol(row, len(denseRow))
	for i := 0; i < sm.Width; i++ {
		sm.Set(row, i, denseRow[i])
	}
}

//RowAndSum performs a binary matrix-vector product: each entry is ANDed
//with the corresponding element of row, and the true results are summed
//per matrix row.
func (sm *SparseBinaryMatrix) RowAndSum(row []bool) []int {
	sm.validateCol(len(row))
	result := make([]int, sm.Height)

	for _, entry := range sm.entries {
		if row[entry.Col] {
			result[entry.Row]++
		}
	}

	return result
}

//NonZeroRows returns the indexes of rows with at least 1 true column,
//in order of first appearance in the entry list.
func (sm *SparseBinaryMatrix) NonZeroRows() []int {
	var result []int
	// Track rows already reported; O(n) instead of the previous
	// O(n^2) repeated slice scans.
	seen := make(map[int]struct{}, len(sm.entries))

	for _, entry := range sm.entries {
		if _, ok := seen[entry.Row]; !ok {
			seen[entry.Row] = struct{}{}
			result = append(result, entry.Row)
		}
	}

	return result
}

//TotalTrueRows returns the number of rows with at least 1 true value.
func (sm *SparseBinaryMatrix) TotalTrueRows() int {
	hitRows := make(map[int]struct{}, len(sm.entries))
	for _, entry := range sm.entries {
		hitRows[entry.Row] = struct{}{}
	}
	return len(hitRows)
}

//TotalTrueCols returns the number of cols with at least 1 true value.
func (sm *SparseBinaryMatrix) TotalTrueCols() int {
	hitCols := make(map[int]struct{}, len(sm.entries))
	for _, entry := range sm.entries {
		hitCols[entry.Col] = struct{}{}
	}
	return len(hitCols)
}

//TotalNonZeroCount returns the total number of true entries.
func (sm *SparseBinaryMatrix) TotalNonZeroCount() int {
	return len(sm.entries)
}

//Or returns a new matrix containing the union of the entries of the
//two matrices. The result has sm's dimensions.
func (sm *SparseBinaryMatrix) Or(sm2 *SparseBinaryMatrix) *SparseBinaryMatrix {
	result := NewSparseBinaryMatrix(sm.Height, sm.Width)

	for _, entry := range sm.entries {
		result.Set(entry.Row, entry.Col, true)
	}

	for _, entry := range sm2.entries {
		result.Set(entry.Row, entry.Col, true)
	}

	return result
}

//Clear removes all entries.
func (sm *SparseBinaryMatrix) Clear() {
	sm.entries = nil
}

//FillRow fills the specified row with the specified value.
func (sm *SparseBinaryMatrix) FillRow(row int, val bool) {
	for j := 0; j < sm.Width; j++ {
		sm.Set(row, j, val)
	}
}

//Copy returns a deep copy of the matrix; returns nil for a nil receiver.
func (sm *SparseBinaryMatrix) Copy() *SparseBinaryMatrix {
	if sm == nil {
		return nil
	}

	result := new(SparseBinaryMatrix)
	result.Width = sm.Width
	result.Height = sm.Height
	result.entries = make([]SparseEntry, len(sm.entries))
	copy(result.entries, sm.entries)

	return result
}

//ToString renders the matrix as '0'/'1' characters, one row per line.
func (sm *SparseBinaryMatrix) ToString() string {
	var buffer bytes.Buffer

	for r := 0; r < sm.Height; r++ {
		for c := 0; c < sm.Width; c++ {
			if sm.Get(r, c) {
				buffer.WriteByte('1')
			} else {
				buffer.WriteByte('0')
			}
		}
		buffer.WriteByte('\n')
	}

	return buffer.String()
}

//validateCol panics when a dense row is wider than the matrix.
//col is a length here, so exactly Width is valid.
func (sm *SparseBinaryMatrix) validateCol(col int) {
	if col > sm.Width {
		panic("Specified row is wider than matrix.")
	}
}

//validateRow panics when a row index is out of bounds.
//BUG FIX: previously `row > sm.Height`, which silently accepted the
//out-of-range index row == Height (rows are 0-indexed).
func (sm *SparseBinaryMatrix) validateRow(row int) {
	if row >= sm.Height {
		panic("Specified row is out of bounds.")
	}
}

//validateRowCol validates a row index together with a dense row length.
func (sm *SparseBinaryMatrix) validateRowCol(row int, col int) {
	sm.validateCol(col)
	sm.validateRow(row)
}
-------------------------------------------------------------------------------- 1 | // 2 | // Code related to temporal pooler printing 3 | // 4 | 5 | package htm 6 | 7 | import ( 8 | "fmt" 9 | //"github.com/cznic/mathutil" 10 | //"github.com/zacg/go.matrix" 11 | //"math" 12 | //"math/rand" 13 | //"sort" 14 | //"github.com/gonum/floats" 15 | //"github.com/zacg/ints" 16 | "github.com/nupic-community/htm/utils" 17 | ) 18 | 19 | type SegmentStats struct { 20 | NumSegments int 21 | NumSynapses int 22 | NumActiveSynapses int 23 | DistSegSizes float64 24 | DistNumSegsPerCell float64 25 | DistPermValues float64 26 | DistAges float64 27 | } 28 | 29 | /* 30 | Returns information about the distribution of segments, synapses and 31 | permanence values in the current TP. If requested, also returns information 32 | regarding the number of currently active segments and synapses. 33 | 34 | */ 35 | func (tp *TemporalPooler) calcSegmentStats(collectActiveData bool) SegmentStats { 36 | result := SegmentStats{} 37 | 38 | var distNSegsPerCell map[int]int 39 | var distSegSizes map[int]int 40 | var distPermValues map[int]int 41 | 42 | numAgeBuckets := 20 43 | ageBucketSize := int((tp.lrnIterationIdx + 20) / 20) 44 | 45 | distAges := make(map[int]int, numAgeBuckets) 46 | distAgesLabels := make([]string, numAgeBuckets) 47 | for i := 0; i < numAgeBuckets; i++ { 48 | distAgesLabels[i] = fmt.Sprintf("%v-%v", i*ageBucketSize, (i+1)*ageBucketSize-1) 49 | } 50 | 51 | distNSegsPerCell = make(map[int]int, 1000) 52 | distSegSizes = make(map[int]int, 1000) 53 | distPermValues = make(map[int]int, 1000) 54 | 55 | for _, col := range tp.cells { 56 | for _, cell := range col { 57 | 58 | nSegmentsThisCell := len(cell) 59 | result.NumSegments += nSegmentsThisCell 60 | 61 | if _, ok := distNSegsPerCell[nSegmentsThisCell]; ok { 62 | distNSegsPerCell[nSegmentsThisCell]++ 63 | } else { 64 | distNSegsPerCell[nSegmentsThisCell] = 1 65 | } 66 | 67 | for _, seg := range cell { 68 | nSynapsesThisSeg := 
len(seg.syns) 69 | result.NumSynapses += nSynapsesThisSeg 70 | 71 | if _, ok := distSegSizes[nSynapsesThisSeg]; ok { 72 | distSegSizes[nSynapsesThisSeg]++ 73 | } else { 74 | distSegSizes[nSynapsesThisSeg] = 1 75 | } 76 | 77 | // Accumulate permanence value histogram 78 | for _, syn := range seg.syns { 79 | p := int(syn.Permanence * 10) 80 | if _, ok := distPermValues[p]; ok { 81 | distPermValues[p]++ 82 | } else { 83 | distPermValues[p] = 1 84 | } 85 | } 86 | 87 | // Accumulate segment age histogram 88 | age := tp.lrnIterationIdx - seg.lastActiveIteration 89 | ageBucket := int(age / ageBucketSize) 90 | distAges[ageBucket]++ 91 | 92 | // Get active synapse statistics if requested 93 | if collectActiveData { 94 | if tp.isSegmentActive(seg, tp.DynamicState.InfActiveState) { 95 | result.NumSegments++ 96 | } 97 | for _, syn := range seg.syns { 98 | if tp.DynamicState.InfActiveState.Get(syn.SrcCellIdx, syn.SrcCellCol) { 99 | result.NumActiveSynapses++ 100 | } 101 | } 102 | } 103 | 104 | } 105 | } 106 | } 107 | 108 | return result 109 | } 110 | 111 | /* 112 | Print the list of [column, cellIdx] indices for each of the active 113 | cells in state. 
114 | */ 115 | func (tp *TemporalPooler) printActiveIndices(state *SparseBinaryMatrix, andValues bool) { 116 | if state.TotalNonZeroCount() == 0 { 117 | fmt.Println("None") 118 | return 119 | } 120 | 121 | fmt.Println(state.Entries()) 122 | 123 | } 124 | 125 | /* 126 | Prints a cels information 127 | */ 128 | func (tp *TemporalPooler) printCell(c int, i int, onlyActiveSegments bool) { 129 | 130 | cell := tp.cells[c][i] 131 | 132 | if len(cell) > 0 { 133 | fmt.Printf("Column: %v Cell: %v - %v segment(s)", c, i, len(cell)) 134 | for idx, seg := range cell { 135 | isActive := tp.isSegmentActive(seg, tp.DynamicState.InfActiveState) 136 | if !onlyActiveSegments || isActive { 137 | str := " " 138 | if isActive { 139 | str = "*" 140 | } 141 | fmt.Printf("%vSeg: %v", str, idx) 142 | fmt.Println(seg.ToString()) 143 | } 144 | } 145 | } 146 | 147 | } 148 | 149 | /* 150 | Print all cell information 151 | */ 152 | func (tp *TemporalPooler) printCells(predictedOnly bool) { 153 | 154 | if predictedOnly { 155 | fmt.Println("--- PREDICTED CELLS ---") 156 | } else { 157 | fmt.Println("--- ALL CELLS ---") 158 | } 159 | 160 | fmt.Println("Activation threshold:", tp.params.ActivationThreshold) 161 | fmt.Println("min threshold:", tp.params.MinThreshold) 162 | fmt.Println("connected perm:", tp.params.ConnectedPerm) 163 | 164 | for c, col := range tp.cells { 165 | for i := range col { 166 | if !predictedOnly || tp.DynamicState.InfPredictedState.Get(c, i) { 167 | tp.printCell(c, i, predictedOnly) 168 | } 169 | } 170 | } 171 | 172 | } 173 | 174 | /* 175 | Called at the end of inference to print out various diagnostic 176 | information based on the current verbosity level. 
177 | */ 178 | func (tp *TemporalPooler) printComputeEnd(output []bool, learn bool) { 179 | 180 | if tp.params.Verbosity < 3 { 181 | if tp.params.Verbosity >= 1 { 182 | fmt.Println("TP: learn:", learn) 183 | fmt.Printf("TP: active outputs(%v):\n", utils.CountTrue(output)) 184 | fmt.Print(NewSparseBinaryMatrixFromDense1D(output, 185 | tp.params.NumberOfCols, tp.params.CellsPerColumn).ToString()) 186 | } 187 | return 188 | } 189 | 190 | fmt.Println("----- computeEnd summary: ") 191 | fmt.Println("learn:", learn) 192 | bursting := 0 193 | counts := make([]int, tp.DynamicState.InfActiveState.Height) 194 | for _, val := range tp.DynamicState.InfActiveState.Entries() { 195 | counts[val.Row]++ 196 | if counts[val.Row] == tp.DynamicState.InfActiveState.Width { 197 | bursting++ 198 | } 199 | } 200 | fmt.Println("numBurstingCols:", bursting) 201 | fmt.Println("curPredScore2:", tp.internalStats.CurPredictionScore2) 202 | fmt.Println("curFalsePosScore", tp.internalStats.CurFalsePositiveScore) 203 | fmt.Println("1-curFalseNegScore", 1-tp.internalStats.CurFalseNegativeScore) 204 | fmt.Println("avgLearnedSeqLength", tp.avgLearnedSeqLength) 205 | 206 | stats := tp.calcSegmentStats(true) 207 | fmt.Println("numSegments", stats.NumSegments) 208 | 209 | fmt.Printf("----- InfActiveState (%v on) ------\n", tp.DynamicState.InfActiveState.TotalNonZeroCount()) 210 | tp.printActiveIndices(tp.DynamicState.InfActiveState, false) 211 | 212 | if tp.params.Verbosity >= 6 { 213 | //tp.printState(tp.InfActiveState['t']) 214 | //fmt.Println(tp.DynamicState.InfActiveState.ToString()) 215 | } 216 | 217 | fmt.Printf("----- InfPredictedState (%v on)-----\n", tp.DynamicState.InfPredictedState.TotalNonZeroCount()) 218 | tp.printActiveIndices(tp.DynamicState.InfPredictedState, false) 219 | if tp.params.Verbosity >= 6 { 220 | //fmt.Println(tp.DynamicState.InfPredictedState.ToString()) 221 | } 222 | 223 | fmt.Printf("----- LrnActiveState (%v on) ------\n", tp.DynamicState.LrnActiveState.TotalNonZeroCount()) 
224 | tp.printActiveIndices(tp.DynamicState.LrnActiveState, false) 225 | if tp.params.Verbosity >= 6 { 226 | //fmt.Println(tp.DynamicState.LrnActiveState.ToString()) 227 | } 228 | 229 | fmt.Printf("----- LrnPredictedState (%v on)-----\n", tp.DynamicState.LrnPredictedState.TotalNonZeroCount()) 230 | tp.printActiveIndices(tp.DynamicState.LrnPredictedState, false) 231 | if tp.params.Verbosity >= 6 { 232 | //fmt.Println(tp.DynamicState.LrnPredictedState.ToString()) 233 | } 234 | 235 | fmt.Println("----- CellConfidence -----") 236 | //tp.printActiveIndices(tp.DynamicState.CellConfidence, true) 237 | 238 | if tp.params.Verbosity >= 6 { 239 | //TODO: this 240 | //tp.printConfidence(tp.DynamicState.CellConfidence) 241 | for r := 0; r < tp.DynamicState.CellConfidence.Rows(); r++ { 242 | for c := 0; c < tp.DynamicState.CellConfidenceLast.Cols(); c++ { 243 | if tp.DynamicState.CellConfidence.Get(r, c) != 0 { 244 | fmt.Printf("[%v,%v,%v]", r, c, tp.DynamicState.CellConfidence.Get(r, c)) 245 | } 246 | } 247 | } 248 | 249 | } 250 | 251 | fmt.Println("----- ColConfidence -----") 252 | //tp.printActiveIndices(tp.DynamicState.ColConfidence, true) 253 | fmt.Println("----- CellConfidence[t-1] for currently active cells -----") 254 | //cc := matrix.ZerosSparse(tp.DynamicState.CellConfidence.Rows(), tp.DynamicState.CellConfidence.Cols()) 255 | for _, val := range tp.DynamicState.InfActiveState.Entries() { 256 | //cc.Set(val.Row, val.Col, tp.DynamicState.CellConfidence.Get(val.Row, val.Col)) 257 | fmt.Printf("[%v,%v,%v]", val.Row, val.Col, tp.DynamicState.CellConfidence.Get(val.Row, val.Col)) 258 | 259 | } 260 | //fmt.Println(cc.String()) 261 | 262 | if tp.params.Verbosity == 4 { 263 | fmt.Println("Cells, predicted segments only:") 264 | tp.printCells(true) 265 | } else if tp.params.Verbosity >= 5 { 266 | fmt.Println("Cells, all segments:") 267 | tp.printCells(true) 268 | } 269 | 270 | } 271 | -------------------------------------------------------------------------------- 
//Mod returns the Euclidean modulus of a and b: the result is always
//in [0, |b|). Panics if b is zero (big.Int division by zero).
func Mod(a, b int) int {
	x := big.NewInt(int64(a))
	y := big.NewInt(int64(b))
	return int(x.Mod(x, y).Int64())
}
//BoolEq reports whether two bool slices have the same length and
//identical elements. A nil slice and an empty slice compare equal.
func BoolEq(a, b []bool) bool {
	if len(a) != len(b) {
		return false
	}

	for i := range a {
		if a[i] != b[i] {
			return false
		}
	}

	return true
}
//MaxSliceInt returns the max value from the specified int slice.
//Returns 0 for an empty slice.
//BUG FIX: the accumulator previously started at 0, so a slice of all
//negative values incorrectly returned 0.
func MaxSliceInt(values []int) int {
	if len(values) == 0 {
		return 0
	}
	max := values[0]
	for _, v := range values[1:] {
		if v > max {
			max = v
		}
	}
	return max
}

//MaxSliceFloat64 returns the max value from the specified float slice.
//Returns 0 for an empty slice.
//BUG FIX: same all-negative-input fix as MaxSliceInt.
func MaxSliceFloat64(values []float64) float64 {
	if len(values) == 0 {
		return 0.0
	}
	max := values[0]
	for _, v := range values[1:] {
		if v > max {
			max = v
		}
	}
	return max
}
//Make2DBool converts a 2D int matrix to bool: a value of exactly 1
//becomes true, anything else false. Helper for unit tests where int
//literals are easier to read.
func Make2DBool(values [][]int) [][]bool {
	out := make([][]bool, len(values))

	for r, row := range values {
		boolRow := make([]bool, len(row))
		for c, cell := range row {
			boolRow[c] = cell == 1
		}
		out[r] = boolRow
	}

	return out
}
//OnIndices returns the indices of the "on" (true) bits, in order.
//Returns nil when no bits are set.
func OnIndices(s []bool) []int {
	var on []int
	for i, bit := range s {
		if bit {
			on = append(on, i)
		}
	}
	return on
}
//Add returns s followed by every element of t that does not already
//appear in s (duplicates within t itself are kept, matching the
//original membership-test-against-s-only semantics).
func Add(s []int, t []int) []int {
	union := make([]int, 0, len(s)+len(t))
	union = append(union, s...)

	for _, candidate := range t {
		seen := false
		for _, existing := range s {
			if existing == candidate {
				seen = true
				break
			}
		}
		if !seen {
			union = append(union, candidate)
		}
	}
	return union
}
Thus, when sufficiently boosted, every 32 | column should become a winner at some point. We set permanence increment 33 | and decrement to 0 so that winning columns don't change unless they have 34 | been boosted. 35 | 36 | Phase 1: As learning progresses through the first 5 iterations, the first 5 37 | patterns should get distinct output SDRs. The two overlapping input patterns 38 | should have reasonably overlapping output SDRs. The other pattern 39 | combinations should have very little overlap. The boost factor for all 40 | columns should be at 1. At this point least half of the columns should have 41 | never become active and these columns should have duty cycle of 0. Any 42 | columns which have won, should have duty cycles >= 0.2. 43 | 44 | Phase 2: Over the next 45 iterations, boosting should stay at 1 for all 45 | columns since minActiveDutyCycle is only calculated after 50 iterations. The 46 | winning columns should be very similar (identical?) to the columns that won 47 | before. About half of the columns should never become active. At the end of 48 | the this phase, most of these columns should have activity level around 0.2. 49 | It's ok for some columns to have higher activity levels. 50 | 51 | Phase 3: At this point about half or fewer columns have never won. These 52 | should get boosted to maxBoost and start to win. As each one wins, their 53 | boost gets lowered to 1. After 2 batches, the number of columns that 54 | have never won should be 0. Because of the artificially induced thrashing 55 | behavior in this test, all the inputs should now have pretty distinct 56 | patterns. During this process, as soon as a new column wins, the boost value 57 | for that column should be set back to 1. 58 | 59 | Phase 4: Run for 5 iterations without learning on. Boost values and winners 60 | should not change. 
61 | */ 62 | 63 | type boostTest struct { 64 | sp *SpatialPooler 65 | x [][]bool 66 | winningIteration []int 67 | lastSDR [][]bool 68 | } 69 | 70 | //Returns overlap of the 2 specified sdrs 71 | func computeOverlap(x, y []bool) int { 72 | result := 0 73 | for idx, val := range x { 74 | if val && y[idx] { 75 | result++ 76 | } 77 | } 78 | return result 79 | } 80 | 81 | func verifySDRProps(t *testing.T, bt *boostTest) { 82 | /* 83 | Verify that all SDRs have the properties desired for this test. 84 | The bounds for checking overlap are set fairly loosely here since there is 85 | some variance due to randomness and the artificial parameters used in this 86 | test. 87 | */ 88 | 89 | // Verify that all SDR's are unique 90 | for i := 0; i <= 4; i++ { 91 | for j := 1; j <= 4; j++ { 92 | eq := 0 93 | for k := 0; k < len(bt.lastSDR[i]); k++ { 94 | if bt.lastSDR[i][k] == bt.lastSDR[j][k] { 95 | eq++ 96 | } 97 | } 98 | if eq == len(bt.lastSDR[i]) { 99 | //equal 100 | assert.Fail(t, "All SDR's are not unique") 101 | } 102 | } 103 | } 104 | 105 | //Verify that the first two SDR's have some overlap. 106 | expected := computeOverlap(bt.lastSDR[0], bt.lastSDR[1]) > 9 107 | assert.True(t, expected, "First two SDR's don't overlap much") 108 | 109 | // Verify the last three SDR's have low overlap with everyone else. 
110 | for i := 2; i <= 4; i++ { 111 | for j := 0; j <= 4; j++ { 112 | if i != j { 113 | overlap := computeOverlap(bt.lastSDR[i], bt.lastSDR[j]) 114 | expected := overlap < 18 115 | assert.True(t, expected, "One of the last three SDRs has high overlap") 116 | } 117 | } 118 | } 119 | 120 | } 121 | 122 | func phase1(t *testing.T, bt *boostTest) { 123 | y := make([]bool, bt.sp.numColumns) 124 | // Do one training batch through the input patterns 125 | for idx, input := range bt.x { 126 | utils.FillSliceBool(y, false) 127 | bt.sp.Compute(input, true, y, bt.sp.InhibitColumns) 128 | for j, winner := range y { 129 | if winner { 130 | bt.winningIteration[j] = bt.sp.IterationLearnNum 131 | } 132 | } 133 | bt.lastSDR[idx] = y 134 | } 135 | 136 | //The boost factor for all columns should be at 1. 137 | assert.Equal(t, bt.sp.numColumns, utils.CountFloat64(bt.sp.boostFactors, 1), "Boost factors are not all 1") 138 | 139 | //At least half of the columns should have never been active. 140 | winners := utils.CountInt(bt.winningIteration, 0) 141 | assert.True(t, winners >= bt.sp.numColumns/2, "More than half of the columns have been active") 142 | 143 | //All the never-active columns should have duty cycle of 0 144 | //All the at-least-once-active columns should have duty cycle >= 0.2 145 | activeSum := 0.0 146 | for idx, val := range bt.sp.activeDutyCycles { 147 | if bt.winningIteration[idx] == 0 { 148 | //assert.Equal(t, expected, actual, ...) 
149 | activeSum += val 150 | } 151 | } 152 | assert.Equal(t, 0.0, activeSum, "Inactive columns have positive duty cycle.") 153 | 154 | winningMin := 100000.0 155 | for idx, val := range bt.sp.activeDutyCycles { 156 | if bt.winningIteration[idx] > 0 { 157 | if val < winningMin { 158 | winningMin = val 159 | } 160 | } 161 | } 162 | assert.True(t, winningMin >= 0.2, "Active columns have duty cycle that is too low.") 163 | 164 | verifySDRProps(t, bt) 165 | } 166 | 167 | func phase2(t *testing.T, bt *boostTest) { 168 | 169 | y := make([]bool, bt.sp.numColumns) 170 | 171 | // Do 9 training batch through the input patterns 172 | for i := 0; i < 9; i++ { 173 | for idx, input := range bt.x { 174 | utils.FillSliceBool(y, false) 175 | bt.sp.Compute(input, true, y, bt.sp.InhibitColumns) 176 | for j, winner := range y { 177 | if winner { 178 | bt.winningIteration[j] = bt.sp.IterationLearnNum 179 | } 180 | } 181 | bt.lastSDR[idx] = append([]bool(nil), y...) //store a copy: y is cleared and reused on the next iteration 182 | } 183 | } 184 | 185 | // The boost factor for all columns should be at 1. 186 | assert.Equal(t, bt.sp.numColumns, utils.CountFloat64(bt.sp.boostFactors, 1), "Boost factors are not all 1") 187 | 188 | // Roughly half of the columns should have never been active.
189 | winners := utils.CountInt(bt.winningIteration, 0) 190 | assert.True(t, winners >= int(0.4*float64(bt.sp.numColumns)), "More than 60% of the columns have been active") 191 | 192 | // All the never-active columns should have duty cycle of 0 193 | activeSum := 0.0 194 | for idx, val := range bt.sp.activeDutyCycles { 195 | if bt.winningIteration[idx] == 0 { 196 | activeSum += val 197 | } 198 | } 199 | assert.Equal(t, 0.0, activeSum, "Inactive columns have positive duty cycle.") 200 | 201 | dutyAvg := 0.0 202 | dutyCount := 0 203 | for _, val := range bt.sp.activeDutyCycles { 204 | if val > 0 { 205 | dutyAvg += val 206 | dutyCount++ 207 | } 208 | } 209 | 210 | // The average at-least-once-active columns should have duty cycle >= 0.15 211 | // and <= 0.25 212 | dutyAvg = dutyAvg / float64(dutyCount) 213 | assert.True(t, dutyAvg >= 0.15, "Average on-columns duty cycle is too low.") 214 | assert.True(t, dutyAvg <= 0.30, "Average on-columns duty cycle is too high.") 215 | 216 | verifySDRProps(t, bt) 217 | } 218 | 219 | func phase3(t *testing.T, bt *boostTest) { 220 | //Do two more training batches through the input patterns 221 | y := make([]bool, bt.sp.numColumns) 222 | 223 | for i := 0; i < 2; i++ { 224 | for idx, input := range bt.x { 225 | utils.FillSliceBool(y, false) 226 | bt.sp.Compute(input, true, y, bt.sp.InhibitColumns) 227 | for j, winner := range y { 228 | if winner { 229 | bt.winningIteration[j] = bt.sp.IterationLearnNum 230 | } 231 | } 232 | bt.lastSDR[idx] = append([]bool(nil), y...) //store a copy: y is cleared and reused on the next iteration 233 | } 234 | } 235 | 236 | // The boost factor for all columns that just won should be at 1. 237 | for idx, val := range y { 238 | if val { 239 | if bt.sp.boostFactors[idx] != 1 { 240 | assert.Fail(t, "Boost factors of winning columns not 1") 241 | } 242 | } 243 | } 244 | 245 | // By now, every column should have been sufficiently boosted to win at least 246 | // once.
The number of columns that have never won should now be 0 247 | for _, val := range bt.winningIteration { 248 | if val == 0 { 249 | assert.Fail(t, "Expected all columns to have won atleast once.") 250 | } 251 | } 252 | 253 | // Because of the artificially induced thrashing, even the first two patterns 254 | // should have low overlap. Verify that the first two SDR's now have little 255 | // overlap 256 | overlap := computeOverlap(bt.lastSDR[0], bt.lastSDR[1]) 257 | assert.True(t, overlap < 7, "First two SDR's overlap significantly when they should not") 258 | } 259 | 260 | func phase4(t *testing.T, bt *boostTest) { 261 | //Snapshot the boost factors before running with learning off (copy(dst, src)). 262 | boostAtBeg := make([]float64, len(bt.sp.boostFactors)) 263 | copy(boostAtBeg, bt.sp.boostFactors) 264 | 265 | // Do one more iteration through the input patterns with learning OFF 266 | y := make([]bool, bt.sp.numColumns) 267 | for _, input := range bt.x { 268 | utils.FillSliceBool(y, false) 269 | bt.sp.Compute(input, false, y, bt.sp.InhibitColumns) 270 | 271 | // The boost factor for all columns that just won should be at 1.
272 | assert.Equal(t, utils.SumSliceFloat64(boostAtBeg), utils.SumSliceFloat64(bt.sp.boostFactors), "Boost factors changed when learning is off") 273 | } 274 | 275 | } 276 | 277 | func BoostTest(t *testing.T) { 278 | bt := boostTest{} 279 | spParams := NewSpParams() 280 | spParams.InputDimensions = []int{90} 281 | spParams.ColumnDimensions = []int{600} 282 | spParams.PotentialRadius = 90 283 | spParams.PotentialPct = 0.9 284 | spParams.GlobalInhibition = true 285 | spParams.NumActiveColumnsPerInhArea = 60 286 | spParams.MinPctActiveDutyCycle = 0.1 287 | spParams.SynPermActiveInc = 0 288 | spParams.SynPermInactiveDec = 0 289 | spParams.DutyCyclePeriod = 10 290 | bt.sp = NewSpatialPooler(spParams) 291 | 292 | // Create a set of input vectors, x 293 | // B,C,D don't overlap at all with other patterns 294 | bt.x = make([][]bool, 5) 295 | for i := range bt.x { 296 | bt.x[i] = make([]bool, bt.sp.numInputs) 297 | } 298 | 299 | utils.FillSliceRangeBool(bt.x[0], true, 0, 20) 300 | utils.FillSliceRangeBool(bt.x[1], true, 10, 30) 301 | utils.FillSliceRangeBool(bt.x[2], true, 30, 50) 302 | utils.FillSliceRangeBool(bt.x[3], true, 50, 70) 303 | utils.FillSliceRangeBool(bt.x[4], true, 70, 90) 304 | // For each column, this will contain the last iteration number where that 305 | // column was a winner 306 | bt.winningIteration = make([]int, bt.sp.numColumns) 307 | 308 | // For each input vector i, lastSDR[i] contains the most recent SDR output 309 | // by the SP. 
310 | bt.lastSDR = make([][]bool, 5) 311 | 312 | phase1(t, &bt) 313 | phase2(t, &bt) 314 | phase3(t, &bt) 315 | phase4(t, &bt) 316 | } 317 | -------------------------------------------------------------------------------- /temporalPoolerStats.go: -------------------------------------------------------------------------------- 1 | // 2 | // Code related to temporal pooler stats 3 | // 4 | 5 | package htm 6 | 7 | import ( 8 | "fmt" 9 | "github.com/cznic/mathutil" 10 | "github.com/nupic-community/htm/utils" 11 | "github.com/zacg/floats" 12 | "github.com/zacg/go.matrix" 13 | //"math" 14 | //"math/rand" 15 | //"sort" 16 | ) 17 | 18 | type TpStats struct { 19 | NInfersSinceReset int 20 | NPredictions int 21 | PredictionScoreTotal float64 22 | PredictionScoreTotal2 float64 23 | FalseNegativeScoreTotal float64 24 | FalsePositiveScoreTotal float64 25 | PctExtraTotal float64 26 | PctMissingTotal float64 27 | TotalMissing float64 28 | TotalExtra float64 29 | 30 | CurPredictionScore float64 31 | CurPredictionScore2 float64 32 | CurFalseNegativeScore float64 33 | CurFalsePositiveScore float64 34 | CurMissing float64 35 | CurExtra float64 36 | ConfHistogram matrix.DenseMatrix 37 | } 38 | 39 | func (s *TpStats) ToString() string { 40 | result := "Stats: \n" 41 | 42 | result += fmt.Sprintf("nInferSinceReset %v \n", s.NInfersSinceReset) 43 | result += fmt.Sprintf("nPredictions %v \n", s.NPredictions) 44 | result += fmt.Sprintf("PredictionScoreTotal %v \n", s.PredictionScoreTotal) 45 | result += fmt.Sprintf("PredictionScoreTotal2 %v \n", s.PredictionScoreTotal2) 46 | result += fmt.Sprintf("FalseNegativeScoreTotal %v \n", s.FalseNegativeScoreTotal) 47 | result += fmt.Sprintf("FalsePositiveScoreTotal %v \n", s.FalsePositiveScoreTotal) 48 | result += fmt.Sprintf("PctExtraTotal %v \n", s.PctExtraTotal) 49 | result += fmt.Sprintf("PctMissingTotal %v \n", s.PctMissingTotal) 50 | result += fmt.Sprintf("TotalMissing %v \n", s.TotalMissing) 51 | result += fmt.Sprintf("TotalExtra %v \n", 
s.TotalExtra) 52 | result += fmt.Sprintf("CurPredictionScore %v \n", s.CurPredictionScore) 53 | result += fmt.Sprintf("CurPredictionScore2 %v \n", s.CurPredictionScore2) 54 | result += fmt.Sprintf("CurFalseNegativeScore %v \n", s.CurFalseNegativeScore) 55 | result += fmt.Sprintf("CurFalsePositiveScore %v \n", s.CurFalsePositiveScore) 56 | result += fmt.Sprintf("CurMissing %v \n", s.CurMissing) 57 | result += fmt.Sprintf("CurExtra %v \n", s.CurExtra) 58 | result += fmt.Sprintf("ConfHistogram %v \n", s.ConfHistogram.String()) 59 | 60 | return result 61 | } 62 | 63 | type confidence struct { 64 | PredictionScore float64 65 | PositivePredictionScore float64 66 | NegativePredictionScore float64 67 | } 68 | 69 | /* 70 | This function produces goodness-of-match scores for a set of input patterns, 71 | by checking for their presence in the current and predicted output of the 72 | TP. Returns a global count of the number of extra and missing bits, the 73 | confidence scores for each input pattern, and (if requested) the 74 | bits in each input pattern that were not present in the TP's prediction. 75 | 76 | param patternNZs a list of input patterns that we want to check for. Each 77 | element is a list of the non-zeros in that pattern. 78 | param output The output of the TP. If not specified, then use the 79 | TP's current output. This can be specified if you are 80 | trying to check the prediction metric for an output from 81 | the past. 82 | param colConfidence The column confidences. If not specified, then use the 83 | TP's current colConfidence. This can be specified if you 84 | are trying to check the prediction metrics for an output 85 | from the past. 86 | param details if True, also include details of missing bits per pattern. 87 | 88 | returns list containing: 89 | 90 | [ 91 | totalExtras, 92 | totalMissing, 93 | [conf_1, conf_2, ...], 94 | [missing1, missing2, ...] 95 | ] 96 | 97 | retval totalExtras a global count of the number of 'extras', i.e. 
bits that 98 | are on in the current output but not in the or of all the 99 | passed in patterns 100 | retval totalMissing a global count of all the missing bits, i.e. the bits 101 | that are on in the or of the patterns, but not in the 102 | current output 103 | retval conf_i the confidence score for the i'th pattern inpatternsToCheck 104 | This consists of 3 items as a tuple: 105 | (predictionScore, posPredictionScore, negPredictionScore) 106 | retval missing_i the bits in the i'th pattern that were missing 107 | in the output. This list is only returned if details is 108 | True. 109 | */ 110 | func (tp *TemporalPooler) checkPrediction2(patternNZs [][]int, output *SparseBinaryMatrix, 111 | colConfidence []float64, details bool) (int, int, []confidence, []int) { 112 | 113 | // Get the non-zeros in each pattern 114 | numPatterns := len(patternNZs) 115 | 116 | // Compute the union of all the expected patterns 117 | var orAll []int 118 | for _, row := range patternNZs { 119 | for _, col := range row { 120 | if !utils.ContainsInt(col, orAll) { 121 | orAll = append(orAll, col) 122 | } 123 | } 124 | } 125 | 126 | var outputIdxs []int 127 | 128 | // Get the list of active columns in the output 129 | if output == nil { 130 | if tp.CurrentOutput == nil { 131 | panic("Expected tp output") 132 | } 133 | outputIdxs = tp.CurrentOutput.NonZeroRows() 134 | } else { 135 | outputIdxs = output.NonZeroRows() 136 | } 137 | 138 | // Compute the total extra and missing in the output 139 | totalExtras := 0 140 | totalMissing := 0 141 | 142 | for _, val := range outputIdxs { 143 | if !utils.ContainsInt(val, orAll) { 144 | totalExtras++ 145 | } 146 | } 147 | 148 | for _, val := range orAll { 149 | if !utils.ContainsInt(val, outputIdxs) { 150 | totalMissing++ 151 | } 152 | } 153 | 154 | // Get the percent confidence level per column by summing the confidence 155 | // levels of the cells in the column. 
During training, each segment's 156 | // confidence number is computed as a running average of how often it 157 | // correctly predicted bottom-up activity on that column. A cell's 158 | // confidence number is taken from the first active segment found in the 159 | // cell. Note that confidence will only be non-zero for predicted columns. 160 | 161 | if colConfidence == nil { 162 | if tp.params.Verbosity >= 5 { 163 | fmt.Println("Col confidence nil, copying from tp state...") 164 | } 165 | colConfidence = make([]float64, len(tp.DynamicState.ColConfidence)) 166 | copy(colConfidence, tp.DynamicState.ColConfidence) 167 | } 168 | 169 | // Assign confidences to each pattern 170 | var confidences []confidence 171 | 172 | for i := 0; i < numPatterns; i++ { 173 | // Sum of the column confidences for this pattern 174 | //positivePredictionSum = colConfidence[patternNZs[i]].sum() 175 | positivePredictionSum := floats.Sum(floats.SubSet(colConfidence, patternNZs[i])) 176 | 177 | // How many columns in this pattern 178 | positiveColumnCount := len(patternNZs[i]) 179 | 180 | // Sum of all the column confidences 181 | totalPredictionSum := floats.Sum(colConfidence) 182 | // Total number of columns 183 | totalColumnCount := len(colConfidence) 184 | 185 | negativePredictionSum := totalPredictionSum - positivePredictionSum 186 | negativeColumnCount := totalColumnCount - positiveColumnCount 187 | 188 | positivePredictionScore := 0.0 189 | // Compute the average confidence score per column for this pattern 190 | if positiveColumnCount != 0 { 191 | positivePredictionScore = positivePredictionSum 192 | } 193 | 194 | // Compute the average confidence score per column for the other patterns 195 | negativePredictionScore := 0.0 196 | if negativeColumnCount != 0 { 197 | negativePredictionScore = negativePredictionSum 198 | } 199 | 200 | // Scale the positive and negative prediction scores so that they sum to 201 | // 1.0 202 | currentSum := negativePredictionScore + positivePredictionScore 
203 | if currentSum > 0 { 204 | positivePredictionScore *= 1.0 / currentSum 205 | negativePredictionScore *= 1.0 / currentSum 206 | } 207 | 208 | predictionScore := positivePredictionScore - negativePredictionScore 209 | newConf := confidence{predictionScore, positivePredictionScore, negativePredictionScore} 210 | confidences = append(confidences, newConf) 211 | 212 | } 213 | 214 | // Include detail? (bits in each pattern that were missing from the output) 215 | if details { 216 | var missingPatternBits []int 217 | for _, pattern := range patternNZs { 218 | for _, val := range pattern { 219 | if !utils.ContainsInt(val, outputIdxs) && 220 | !utils.ContainsInt(val, missingPatternBits) { 221 | missingPatternBits = append(missingPatternBits, val) 222 | } 223 | } 224 | 225 | } 226 | return totalExtras, totalMissing, confidences, missingPatternBits 227 | } else { 228 | return totalExtras, totalMissing, confidences, nil 229 | } 230 | 231 | } 232 | 233 | /* 234 | Called at the end of learning and inference, this routine will update 235 | a number of stats in our _internalStats dictionary, including our computed 236 | prediction score. 
237 | 238 | param stats internal stats dictionary 239 | param bottomUpNZ list of the active bottom-up inputs 240 | param predictedState The columns we predicted on the last time step (should 241 | match the current bottomUpNZ in the best case) 242 | param colConfidence Column confidences we determined on the last time step 243 | */ 244 | 245 | func (tp *TemporalPooler) updateStatsInferEnd(stats *TpStats, bottomUpNZ []int, 246 | predictedState *SparseBinaryMatrix, colConfidence []float64) { 247 | // Return if not collecting stats 248 | if !tp.params.CollectStats { 249 | return 250 | } 251 | 252 | stats.NInfersSinceReset++ 253 | 254 | // Compute the prediction score, how well the prediction from the last 255 | // time step predicted the current bottom-up input 256 | numExtra2, numMissing2, confidences2, _ := tp.checkPrediction2([][]int{bottomUpNZ}, predictedState, colConfidence, false) 257 | predictionScore := confidences2[0].PredictionScore 258 | positivePredictionScore := confidences2[0].PositivePredictionScore 259 | negativePredictionScore := confidences2[0].NegativePredictionScore 260 | 261 | // Store the stats that don't depend on burn-in 262 | stats.CurPredictionScore2 = predictionScore 263 | stats.CurFalseNegativeScore = negativePredictionScore 264 | stats.CurFalsePositiveScore = positivePredictionScore 265 | 266 | stats.CurMissing = float64(numMissing2) 267 | stats.CurExtra = float64(numExtra2) 268 | 269 | // If we are passed the burn-in period, update the accumulated stats 270 | // Here's what various burn-in values mean: 271 | // 0: try to predict the first element of each sequence and all subsequent 272 | // 1: try to predict the second element of each sequence and all subsequent 273 | // etc. 
274 | if stats.NInfersSinceReset <= tp.params.BurnIn { 275 | return 276 | } 277 | 278 | // Burn-in related stats 279 | stats.NPredictions++ 280 | numExpected := mathutil.Max(1, len(bottomUpNZ)) 281 | 282 | stats.TotalMissing += float64(numMissing2) 283 | stats.TotalExtra += float64(numExtra2) 284 | stats.PctExtraTotal += 100.0 * float64(numExtra2) / float64(numExpected) 285 | stats.PctMissingTotal += 100.0 * float64(numMissing2) / float64(numExpected) 286 | stats.PredictionScoreTotal2 += predictionScore 287 | stats.FalseNegativeScoreTotal += 1.0 - positivePredictionScore 288 | stats.FalsePositiveScoreTotal += negativePredictionScore 289 | 290 | if tp.collectSequenceStats { 291 | // Collect cell confidences for every cell that correctly predicted current 292 | // bottom up input. Normalize confidence across each column 293 | cc := tp.DynamicState.CellConfidence.Copy() 294 | 295 | for r := 0; r < cc.Rows(); r++ { 296 | for c := 0; c < cc.Cols(); c++ { 297 | if !tp.DynamicState.InfActiveState.Get(r, c) { 298 | cc.Set(r, c, 0) 299 | } 300 | } 301 | } 302 | sconf := make([]int, cc.Rows()) 303 | for r := 0; r < cc.Rows(); r++ { 304 | count := 0 305 | for c := 0; c < cc.Cols(); c++ { 306 | if cc.Get(r, c) > 0 { 307 | count++ 308 | } 309 | } 310 | sconf[r] = count 311 | } 312 | 313 | for r := 0; r < cc.Rows(); r++ { 314 | for c := 0; c < cc.Cols(); c++ { 315 | temp := cc.Get(r, c) 316 | cc.Set(r, c, temp/float64(sconf[r])) 317 | } 318 | } 319 | 320 | // Update cell confidence histogram: add column-normalized confidence 321 | // scores to the histogram 322 | stats.ConfHistogram.Add(cc) 323 | } 324 | 325 | } 326 | -------------------------------------------------------------------------------- /segment.go: -------------------------------------------------------------------------------- 1 | package htm 2 | 3 | import ( 4 | "fmt" 5 | "github.com/cznic/mathutil" 6 | //"github.com/skelterjohn/go.matrix" 7 | "math" 8 | //"math/rand" 9 | //"sort" 10 | "github.com/gonum/floats" 11 
| "github.com/nupic-community/htm/utils" 12 | ) 13 | 14 | var SegmentDutyCycleTiers = []int{0, 100, 320, 1000, 15 | 3200, 10000, 32000, 100000, 320000} 16 | 17 | var SegmentDutyCycleAlphas = []float64{0, 0.0032, 0.0010, 0.00032, 18 | 0.00010, 0.000032, 0.00001, 0.0000032, 19 | 0.0000010} 20 | 21 | type Synapse struct { 22 | SrcCellCol int 23 | SrcCellIdx int 24 | Permanence float64 25 | } 26 | 27 | // The Segment struct is a container for all of the segment variables and 28 | //the synapses it owns. 29 | type Segment struct { 30 | tp *TemporalPooler 31 | segId int 32 | isSequenceSeg bool 33 | lastActiveIteration int 34 | positiveActivations int 35 | totalActivations int 36 | lastPosDutyCycle float64 37 | lastPosDutyCycleIteration int 38 | syns []Synapse 39 | } 40 | 41 | //Determines segment equality 42 | func (s *Segment) Equals(seg *Segment) bool { 43 | synsEqual := true 44 | 45 | if len(s.syns) != len(seg.syns) { 46 | return false 47 | } 48 | 49 | for idx, val := range s.syns { 50 | if seg.syns[idx].Permanence != val.Permanence || 51 | seg.syns[idx].SrcCellCol != val.SrcCellCol || 52 | seg.syns[idx].SrcCellIdx != val.SrcCellIdx { 53 | return false 54 | } 55 | } 56 | 57 | return synsEqual && 58 | s.tp == seg.tp && 59 | s.segId == seg.segId && 60 | s.isSequenceSeg == seg.isSequenceSeg && 61 | s.lastActiveIteration == seg.lastActiveIteration && 62 | s.positiveActivations == seg.positiveActivations && 63 | s.totalActivations == seg.totalActivations && 64 | s.lastPosDutyCycle == seg.lastPosDutyCycle && 65 | s.lastPosDutyCycleIteration == seg.lastPosDutyCycleIteration 66 | 67 | } 68 | 69 | //Creates a new segment 70 | func NewSegment(tp *TemporalPooler, isSequenceSeg bool) *Segment { 71 | seg := Segment{} 72 | seg.tp = tp 73 | seg.segId = tp.GetSegId() 74 | seg.isSequenceSeg = isSequenceSeg 75 | seg.lastActiveIteration = tp.lrnIterationIdx 76 | seg.positiveActivations = 1 77 | seg.totalActivations = 1 78 | 79 | seg.lastPosDutyCycle = 1.0 / float64(tp.lrnIterationIdx) 
80 | seg.lastPosDutyCycleIteration = tp.lrnIterationIdx 81 | 82 | //TODO: initialize synapse collection 83 | 84 | return &seg 85 | } 86 | 87 | /* 88 | Compute/update and return the positive activations duty cycle of 89 | this segment. This is a measure of how often this segment is 90 | providing good predictions. 91 | 92 | param active True if segment just provided a good prediction 93 | param readOnly If True, compute the updated duty cycle, but don't change 94 | the cached value. This is used by debugging print statements. 95 | 96 | returns The duty cycle, a measure of how often this segment is 97 | providing good predictions. 98 | 99 | **NOTE:** This method relies on different schemes to compute the duty cycle 100 | based on how much history we have. In order to support this tiered 101 | approach **IT MUST BE CALLED ON EVERY SEGMENT AT EACH DUTY CYCLE TIER** 102 | (ref dutyCycleTiers). 103 | 104 | When we don't have a lot of history yet (first tier), we simply return 105 | number of positive activations / total number of iterations 106 | 107 | After a certain number of iterations have accumulated, it converts into 108 | a moving average calculation, which is updated only when requested 109 | since it can be a bit expensive to compute on every iteration (it uses 110 | the pow() function). 111 | 112 | The duty cycle is computed as follows: 113 | 114 | dc[t] = (1-alpha) * dc[t-1] + alpha * value[t] 115 | 116 | If the value[t] has been 0 for a number of steps in a row, you can apply 117 | all of the updates at once using: 118 | 119 | dc[t] = (1-alpha)^(t-lastT) * dc[lastT] 120 | 121 | We use the alphas and tiers as defined in ref dutyCycleAlphas and 122 | ref dutyCycleTiers. 
123 | */ 124 | func (s *Segment) dutyCycle(active, readOnly bool) float64 { 125 | 126 | // For tier #0, compute it from total number of positive activations seen 127 | if s.tp.lrnIterationIdx <= SegmentDutyCycleTiers[1] { 128 | dutyCycle := float64(s.positiveActivations) / float64(s.tp.lrnIterationIdx) 129 | if !readOnly { 130 | s.lastPosDutyCycleIteration = s.tp.lrnIterationIdx 131 | s.lastPosDutyCycle = dutyCycle 132 | } 133 | return dutyCycle 134 | } 135 | 136 | // How old is our update? 137 | age := s.tp.lrnIterationIdx - s.lastPosDutyCycleIteration 138 | 139 | //If it's already up to date, we can returned our cached value. 140 | if age == 0 && !active { 141 | return s.lastPosDutyCycle 142 | } 143 | 144 | alpha := 0.0 145 | //Figure out which alpha we're using 146 | for i := len(SegmentDutyCycleTiers) - 1; i > 0; i-- { 147 | if s.tp.lrnIterationIdx > SegmentDutyCycleTiers[i] { 148 | alpha = SegmentDutyCycleAlphas[i] 149 | break 150 | } 151 | } 152 | 153 | // Update duty cycle 154 | dutyCycle := math.Pow(1.0-alpha, float64(age)) * s.lastPosDutyCycle 155 | 156 | if active { 157 | dutyCycle += alpha 158 | } 159 | 160 | // Update cached values if not read-only 161 | if !readOnly { 162 | s.lastPosDutyCycleIteration = s.tp.lrnIterationIdx 163 | s.lastPosDutyCycle = dutyCycle 164 | } 165 | 166 | return dutyCycle 167 | } 168 | 169 | /* 170 | Free up some synapses in this segment. We always free up inactive 171 | synapses (lowest permanence freed up first) before we start to free up 172 | active ones. 173 | 174 | param numToFree number of synapses to free up 175 | param inactiveSynapseIndices list of the inactive synapse indices. 
176 | */ 177 | func (s *Segment) freeNSynapses(numToFree int, inactiveSynapseIndices []int) { 178 | //Make sure numToFree isn't larger than the total number of syns we have 179 | if numToFree > len(s.syns) { 180 | panic("Number to free cannot be larger than existing synapses.") 181 | } 182 | 183 | if s.tp.params.Verbosity >= 5 { 184 | fmt.Println("freeNSynapses with numToFree=", numToFree) 185 | fmt.Println("inactiveSynapseIndices= ", inactiveSynapseIndices) 186 | } 187 | 188 | var candidates []int 189 | // Remove the lowest perm inactive synapses first 190 | if len(inactiveSynapseIndices) > 0 { 191 | perms := make([]float64, len(inactiveSynapseIndices)) 192 | for idx := range inactiveSynapseIndices { 193 | perms[idx] = s.syns[inactiveSynapseIndices[idx]].Permanence 194 | } 195 | indexes := make([]int, len(perms)) 196 | floats.Argsort(perms, indexes) 197 | //sort perms ascending; indexes[i] is the position in inactiveSynapseIndices 198 | cSize := mathutil.Min(numToFree, len(perms)) 199 | candidates = make([]int, cSize) 200 | //indexes[0:cSize] 201 | for i := 0; i < cSize; i++ { 202 | candidates[i] = inactiveSynapseIndices[indexes[i]] 203 | } 204 | } 205 | 206 | // Do we need more?
if so, remove the lowest perm active synapses too 207 | var activeSynIndices []int 208 | if len(candidates) < numToFree { 209 | for i := 0; i < len(s.syns); i++ { 210 | if !utils.ContainsInt(i, inactiveSynapseIndices) { 211 | activeSynIndices = append(activeSynIndices, i) 212 | } 213 | } 214 | 215 | perms := make([]float64, len(activeSynIndices)) 216 | for i := range activeSynIndices { 217 | perms[i] = s.syns[activeSynIndices[i]].Permanence 218 | } 219 | indexes := make([]int, len(perms)) 220 | floats.Argsort(perms, indexes) 221 | 222 | moreToFree := numToFree - len(candidates) 223 | //moreCandidates := make([]int, moreToFree) 224 | for i := 0; i < moreToFree; i++ { 225 | candidates = append(candidates, activeSynIndices[indexes[i]]) 226 | } 227 | } 228 | 229 | if s.tp.params.Verbosity >= 4 { 230 | fmt.Printf("Deleting %v synapses from segment to make room for new ones: %v \n", 231 | len(candidates), candidates) 232 | fmt.Println("Before:", s.ToString()) 233 | } 234 | 235 | // Delete candidate syns by copying undeleted to new slice 236 | var newSyns []Synapse 237 | for idx, val := range s.syns { 238 | if !utils.ContainsInt(idx, candidates) { 239 | newSyns = append(newSyns, val) 240 | } 241 | } 242 | s.syns = newSyns 243 | 244 | if s.tp.params.Verbosity >= 4 { 245 | fmt.Println("After:", s.ToString()) 246 | } 247 | 248 | } 249 | 250 | /* 251 | Update a set of synapses in the segment.
252 | 253 | param synapses List of synapse indices to update 254 | param delta How much to add to each permanence 255 | 256 | returns True if synapse reached 0 257 | */ 258 | func (s *Segment) updateSynapses(synapses []int, delta float64) bool { 259 | hitZero := false 260 | 261 | if delta > 0 { 262 | for _, idx := range synapses { 263 | s.syns[idx].Permanence += delta 264 | // Cap synapse permanence at permanenceMax 265 | if s.syns[idx].Permanence > s.tp.params.PermanenceMax { 266 | s.syns[idx].Permanence = s.tp.params.PermanenceMax 267 | } 268 | } 269 | } else { 270 | for _, idx := range synapses { 271 | s.syns[idx].Permanence += delta 272 | // Cap min synapse permanence to 0 in case there is no global decay 273 | if s.syns[idx].Permanence <= 0 { 274 | s.syns[idx].Permanence = 0 275 | hitZero = true 276 | } 277 | } 278 | } 279 | 280 | return hitZero 281 | } 282 | 283 | /* 284 | Adds a new synapse 285 | */ 286 | func (s *Segment) AddSynapse(srcCellCol, srcCellIdx int, perm float64) { 287 | s.syns = append(s.syns, Synapse{srcCellCol, srcCellIdx, perm}) 288 | } 289 | 290 | /* 291 | Return a segmentUpdate data structure containing a list of proposed 292 | changes to segment s. Let activeSynapses be the list of active synapses 293 | where the originating cells have their activeState output = true at time step 294 | t. (This list is empty if s is None since the segment doesn't exist.) 295 | newSynapses is an optional argument that defaults to false. If newSynapses 296 | is true, then newSynapseCount - len(activeSynapses) synapses are added to 297 | activeSynapses. These synapses are randomly chosen from the set of cells 298 | that have learnState = true at timeStep.
299 | */ 300 | func (tp *TemporalPooler) getSegmentActiveSynapses(c int, i int, s *Segment, 301 | activeState *SparseBinaryMatrix, newSynapses bool) *SegmentUpdate { 302 | var activeSynapses []SynapseUpdateState 303 | 304 | if tp.params.Verbosity >= 5 { 305 | fmt.Printf("Entering getSegActiveSyns syns:%v segnil:%v newsyns:%v \n", 0, s == nil, newSynapses) 306 | } 307 | 308 | if s != nil { 309 | for idx, val := range s.syns { 310 | if activeState.Get(val.SrcCellCol, val.SrcCellIdx) { 311 | temp := SynapseUpdateState{} 312 | temp.Index = idx 313 | activeSynapses = append(activeSynapses, temp) 314 | } 315 | } 316 | } 317 | 318 | if newSynapses { 319 | nSynapsesToAdd := tp.params.NewSynapseCount - len(activeSynapses) 320 | newSyns := tp.chooseCellsToLearnFrom(s, nSynapsesToAdd, activeState) 321 | //fmt.Printf("newSyncount: %v \n", len(newSyns)) 322 | for _, val := range newSyns { 323 | temp := SynapseUpdateState{} 324 | temp.Index = val.Row 325 | temp.CellIndex = val.Col 326 | temp.New = true 327 | activeSynapses = append(activeSynapses, temp) 328 | } 329 | } 330 | 331 | // It's still possible that activeSynapses is empty, and this will 332 | // be handled in addToSegmentUpdates 333 | result := new(SegmentUpdate) 334 | result.activeSynapses = activeSynapses 335 | result.columnIdx = c 336 | result.cellIdx = i 337 | result.segment = s 338 | return result 339 | 340 | } 341 | 342 | /* 343 | Print segment information for verbose messaging and debugging. 
344 | This uses the following format: 345 | 346 | ID:54413 True 0.64801 (24/36) 101 [9,1]0.75 [10,1]0.75 [11,1]0.75 347 | 348 | where: 349 | 54413 - is the unique segment id 350 | True - is sequence segment 351 | 0.64801 - moving average duty cycle 352 | (24/36) - (numPositiveActivations / numTotalActivations) 353 | 101 - age, number of iterations since last activated 354 | [9,1]0.75 - synapse from column 9, cell #1, strength 0.75 355 | [10,1]0.75 - synapse from column 10, cell #1, strength 0.75 356 | [11,1]0.75 - synapse from column 11, cell #1, strength 0.75 357 | */ 358 | func (s *Segment) ToString() string { 359 | //ID 360 | result := fmt.Sprintf("ID:%v %v ", s.segId, s.isSequenceSeg) 361 | 362 | //Duty Cycle 363 | result += fmt.Sprintf("%v", s.dutyCycle(false, true)) 364 | 365 | //numPositive/totalActivations 366 | result += fmt.Sprintf(" (%v/%v) ", s.positiveActivations, s.totalActivations) 367 | 368 | //age 369 | result += fmt.Sprintf("%v", s.tp.lrnIterationIdx-s.lastActiveIteration) 370 | 371 | // Print each synapses on this segment as: srcCellCol/srcCellIdx/perm 372 | // if the permanence is above connected, put [] around the synapse coords 373 | for _, syn := range s.syns { 374 | result += fmt.Sprintf(" [%v,%v]%v", syn.SrcCellCol, syn.SrcCellIdx, syn.Permanence) 375 | } 376 | 377 | result += "\n" 378 | 379 | return result 380 | } 381 | -------------------------------------------------------------------------------- /temporalMemory_test.go: -------------------------------------------------------------------------------- 1 | package htm 2 | 3 | import ( 4 | //"fmt" 5 | "github.com/stretchr/testify/assert" 6 | "sort" 7 | "testing" 8 | ) 9 | 10 | func TestPickCellsToLearnOnAvoidDuplicates(t *testing.T) { 11 | tmp := NewTemporalMemoryParams() 12 | tmp.MaxNewSynapseCount = 1000 13 | tm := NewTemporalMemory(tmp) 14 | 15 | connections := tm.Connections 16 | connections.CreateSegment(0) 17 | connections.CreateSynapse(0, 23, 0.6) 18 | 19 | winnerCells := []int{233, 
144} 20 | 21 | // Ensure that no additional (duplicate) cells were picked 22 | assert.Equal(t, winnerCells, tm.pickCellsToLearnOn(2, 0, winnerCells, connections)) 23 | 24 | } 25 | 26 | func TestPickCellsToLearnOn(t *testing.T) { 27 | tmp := NewTemporalMemoryParams() 28 | tm := NewTemporalMemory(tmp) 29 | connections := tm.Connections 30 | connections.CreateSegment(0) 31 | 32 | winnerCells := []int{4, 47, 58, 93} 33 | 34 | result := tm.pickCellsToLearnOn(100, 0, winnerCells, connections) 35 | sort.Ints(result) 36 | assert.Equal(t, []int{4, 47, 58, 93}, result) 37 | assert.Equal(t, []int{}, tm.pickCellsToLearnOn(0, 0, winnerCells, connections)) 38 | assert.Equal(t, []int{4, 58}, tm.pickCellsToLearnOn(2, 0, winnerCells, connections)) 39 | } 40 | 41 | func TestAdaptSegmentToMin(t *testing.T) { 42 | tmp := NewTemporalMemoryParams() 43 | tm := NewTemporalMemory(tmp) 44 | connections := tm.Connections 45 | connections.CreateSegment(0) 46 | connections.CreateSynapse(0, 23, 0.1) 47 | 48 | tm.adaptSegment(0, []int{}, connections) 49 | assert.Equal(t, 0.0, connections.DataForSynapse(0).Permanence) 50 | 51 | // // Now permanence should be at min 52 | tm.adaptSegment(0, []int{}, connections) 53 | assert.Equal(t, 0.0, connections.DataForSynapse(0).Permanence) 54 | 55 | } 56 | 57 | func TestAdaptSegmentToMax(t *testing.T) { 58 | tmp := NewTemporalMemoryParams() 59 | tm := NewTemporalMemory(tmp) 60 | connections := tm.Connections 61 | connections.CreateSegment(0) 62 | connections.CreateSynapse(0, 23, 0.9) 63 | 64 | tm.adaptSegment(0, []int{0}, connections) 65 | assert.Equal(t, 1.0, connections.DataForSynapse(0).Permanence) 66 | 67 | // Now permanence should be at max 68 | tm.adaptSegment(0, []int{0}, connections) 69 | assert.Equal(t, 1.0, connections.DataForSynapse(0).Permanence) 70 | 71 | } 72 | 73 | func TestLeastUsedCell(t *testing.T) { 74 | tmp := NewTemporalMemoryParams() 75 | tmp.ColumnDimensions = []int{2} 76 | tmp.CellsPerColumn = 2 77 | 78 | tm := NewTemporalMemory(tmp) 
79 | 80 | connections := tm.Connections 81 | connections.CreateSegment(0) 82 | connections.CreateSynapse(0, 3, 0.3) 83 | 84 | for i := 0; i < 100; i++ { 85 | assert.Equal(t, 1, tm.getLeastUsedCell(0, connections)) 86 | } 87 | 88 | } 89 | 90 | func TestAdaptSegment(t *testing.T) { 91 | tmp := NewTemporalMemoryParams() 92 | tm := NewTemporalMemory(tmp) 93 | connections := tm.Connections 94 | 95 | connections.CreateSegment(0) 96 | connections.CreateSynapse(0, 23, 0.6) 97 | connections.CreateSynapse(0, 37, 0.4) 98 | connections.CreateSynapse(0, 477, 0.9) 99 | tm.adaptSegment(0, []int{0, 1}, connections) 100 | 101 | assert.Equal(t, 0.7, connections.DataForSynapse(0).Permanence) 102 | assert.Equal(t, 0.5, connections.DataForSynapse(1).Permanence) 103 | assert.Equal(t, 0.8, connections.DataForSynapse(2).Permanence) 104 | 105 | } 106 | 107 | func TestGetConnectedActiveSynapsesForSegment(t *testing.T) { 108 | tmp := NewTemporalMemoryParams() 109 | tm := NewTemporalMemory(tmp) 110 | connections := tm.Connections 111 | 112 | connections.CreateSegment(0) 113 | connections.CreateSynapse(0, 23, 0.6) 114 | connections.CreateSynapse(0, 37, 0.4) 115 | connections.CreateSynapse(0, 477, 0.9) 116 | connections.CreateSegment(1) 117 | connections.CreateSynapse(1, 733, 0.7) 118 | connections.CreateSegment(8) 119 | connections.CreateSynapse(2, 486, 0.9) 120 | 121 | activeSynapsesForSegment := map[int][]int{ 122 | 0: {0, 1}, 123 | 1: {3}, 124 | } 125 | 126 | assert.Equal(t, []int{0}, tm.getConnectedActiveSynapsesForSegment(0, 127 | activeSynapsesForSegment, 128 | 0.5, 129 | connections)) 130 | 131 | assert.Equal(t, []int{3}, tm.getConnectedActiveSynapsesForSegment(1, 132 | activeSynapsesForSegment, 133 | 0.5, 134 | connections)) 135 | 136 | } 137 | 138 | func TestComputeActiveSynapsesNoActivity(t *testing.T) { 139 | tmp := NewTemporalMemoryParams() 140 | tm := NewTemporalMemory(tmp) 141 | connections := tm.Connections 142 | 143 | connections.CreateSegment(0) 144 | 
connections.CreateSynapse(0, 23, 0.6) 145 | connections.CreateSynapse(0, 37, 0.4) 146 | connections.CreateSynapse(0, 477, 0.9) 147 | connections.CreateSegment(1) 148 | connections.CreateSynapse(1, 733, 0.7) 149 | connections.CreateSegment(8) 150 | connections.CreateSynapse(2, 486, 0.9) 151 | activeCells := []int{} 152 | assert.Equal(t, map[int][]int{}, tm.computeActiveSynapses(activeCells, connections)) 153 | 154 | } 155 | 156 | func TestGetBestMatchingSegment(t *testing.T) { 157 | 158 | tmp := NewTemporalMemoryParams() 159 | tmp.MinThreshold = 1 160 | tm := NewTemporalMemory(tmp) 161 | connections := tm.Connections 162 | 163 | connections.CreateSegment(0) 164 | connections.CreateSynapse(0, 23, 0.6) 165 | connections.CreateSynapse(0, 37, 0.4) 166 | connections.CreateSynapse(0, 477, 0.9) 167 | connections.CreateSegment(0) 168 | connections.CreateSynapse(1, 49, 0.9) 169 | connections.CreateSynapse(1, 3, 0.8) 170 | connections.CreateSegment(1) 171 | connections.CreateSynapse(2, 733, 0.7) 172 | connections.CreateSegment(8) 173 | connections.CreateSynapse(3, 486, 0.9) 174 | 175 | activeSynapsesForSegment := map[int][]int{ 176 | 0: []int{0, 1}, 177 | 1: []int{3}, 178 | 2: []int{5}, 179 | } 180 | 181 | bestCell, connectedSyns := tm.getBestMatchingSegment(0, activeSynapsesForSegment, connections) 182 | assert.Equal(t, 0, bestCell) 183 | assert.Equal(t, []int{0, 1}, connectedSyns) 184 | 185 | bestCell, connectedSyns = tm.getBestMatchingSegment(1, activeSynapsesForSegment, connections) 186 | assert.Equal(t, 2, bestCell) 187 | assert.Equal(t, []int{5}, connectedSyns) 188 | 189 | bestCell, connectedSyns = tm.getBestMatchingSegment(8, activeSynapsesForSegment, connections) 190 | assert.Equal(t, -1, bestCell) 191 | assert.Equal(t, []int(nil), connectedSyns) 192 | 193 | bestCell, connectedSyns = tm.getBestMatchingSegment(100, activeSynapsesForSegment, connections) 194 | assert.Equal(t, -1, bestCell) 195 | assert.Equal(t, []int(nil), connectedSyns) 196 | 197 | } 198 | 199 | func 
TestGetBestMatchingCellFewestSegments(t *testing.T) { 200 | tmp := NewTemporalMemoryParams() 201 | tmp.ColumnDimensions = []int{2} 202 | tmp.CellsPerColumn = 2 203 | tmp.MinThreshold = 1 204 | tm := NewTemporalMemory(tmp) 205 | connections := tm.Connections 206 | 207 | connections.CreateSegment(0) 208 | connections.CreateSynapse(0, 3, 0.3) 209 | activeSynapsesForSegment := map[int][]int{} 210 | 211 | for i := 0; i < 100; i++ { 212 | // Never pick cell 0, always pick cell 1 213 | cell, _ := tm.getBestMatchingCell(0, activeSynapsesForSegment, connections) 214 | assert.Equal(t, 1, cell) 215 | } 216 | 217 | } 218 | 219 | func TestGetBestMatchingCell(t *testing.T) { 220 | tmp := NewTemporalMemoryParams() 221 | tmp.MinThreshold = 1 222 | tm := NewTemporalMemory(tmp) 223 | connections := tm.Connections 224 | 225 | connections.CreateSegment(0) 226 | connections.CreateSynapse(0, 23, 0.6) 227 | connections.CreateSynapse(0, 37, 0.4) 228 | connections.CreateSynapse(0, 477, 0.9) 229 | connections.CreateSegment(0) 230 | connections.CreateSynapse(1, 49, 0.9) 231 | connections.CreateSynapse(1, 3, 0.8) 232 | connections.CreateSegment(1) 233 | connections.CreateSynapse(2, 733, 0.7) 234 | connections.CreateSegment(108) 235 | connections.CreateSynapse(3, 486, 0.9) 236 | 237 | activeSynapsesForSegment := map[int][]int{ 238 | 0: []int{0, 1}, 239 | 1: []int{3}, 240 | 2: []int{5}, 241 | } 242 | 243 | bestCell, bestSeg := tm.getBestMatchingCell(0, activeSynapsesForSegment, connections) 244 | assert.Equal(t, 0, bestCell) 245 | assert.Equal(t, 0, bestSeg) 246 | 247 | //randomly picked 248 | bestCell, bestSeg = tm.getBestMatchingCell(3, activeSynapsesForSegment, connections) 249 | assert.Equal(t, 99, bestCell) //random 250 | assert.Equal(t, -1, bestSeg) 251 | 252 | //randomly picked 253 | bestCell, bestSeg = tm.getBestMatchingCell(999, activeSynapsesForSegment, connections) 254 | assert.Equal(t, 31979, bestCell) //random 255 | assert.Equal(t, -1, bestSeg) 256 | 257 | } 258 | 259 | func 
TestComputeActiveSynapses(t *testing.T) { 260 | tmp := NewTemporalMemoryParams() 261 | //tmp.MinThreshold = 1 262 | tm := NewTemporalMemory(tmp) 263 | connections := tm.Connections 264 | 265 | connections.CreateSegment(0) 266 | connections.CreateSynapse(0, 23, 0.6) 267 | connections.CreateSynapse(0, 37, 0.4) 268 | connections.CreateSynapse(0, 477, 0.9) 269 | connections.CreateSegment(1) 270 | connections.CreateSynapse(1, 733, 0.7) 271 | connections.CreateSegment(8) 272 | connections.CreateSynapse(2, 486, 0.9) 273 | activeCells := []int{23, 37, 733, 4973} 274 | 275 | expected := map[int][]int{ 276 | 0: []int{0, 1}, 277 | 1: []int{3}, 278 | } 279 | assert.Equal(t, expected, tm.computeActiveSynapses(activeCells, connections)) 280 | 281 | } 282 | 283 | func TestComputePredictiveCells(t *testing.T) { 284 | 285 | tmp := NewTemporalMemoryParams() 286 | tm := NewTemporalMemory(tmp) 287 | connections := tm.Connections 288 | 289 | connections.CreateSegment(0) 290 | connections.CreateSynapse(0, 23, 0.6) 291 | connections.CreateSynapse(0, 37, 0.5) 292 | connections.CreateSynapse(0, 477, 0.9) 293 | connections.CreateSegment(1) 294 | connections.CreateSynapse(1, 733, 0.7) 295 | connections.CreateSynapse(1, 733, 0.4) 296 | connections.CreateSegment(1) 297 | connections.CreateSynapse(2, 974, 0.9) 298 | connections.CreateSegment(8) 299 | connections.CreateSynapse(3, 486, 0.9) 300 | connections.CreateSegment(100) 301 | 302 | activeSynapsesForSegment := map[int][]int{ 303 | 0: []int{0, 1}, 304 | 1: []int{3, 4}, 305 | 2: []int{5}, 306 | } 307 | 308 | activeSegments, predictiveCells := tm.computePredictiveCells(activeSynapsesForSegment, connections) 309 | //TODO: numentas returns [0] 310 | assert.Equal(t, []int(nil), activeSegments) 311 | assert.Equal(t, []int(nil), predictiveCells) 312 | 313 | } 314 | 315 | func TestLearnOnSegments(t *testing.T) { 316 | tmp := NewTemporalMemoryParams() 317 | tm := NewTemporalMemory(tmp) 318 | connections := tm.Connections 319 | 
connections.CreateSegment(0) 320 | connections.CreateSynapse(0, 23, 0.6) 321 | connections.CreateSynapse(0, 37, 0.4) 322 | connections.CreateSynapse(0, 477, 0.9) 323 | connections.CreateSegment(1) 324 | connections.CreateSynapse(1, 733, 0.7) 325 | connections.CreateSegment(8) 326 | connections.CreateSynapse(2, 486, 0.9) 327 | connections.CreateSegment(100) 328 | 329 | prevActiveSegments := []int{0, 2} 330 | learningSegments := []int{1, 3} 331 | 332 | prevActiveSynapsesForSegment := map[int][]int{ 333 | 0: []int{0, 1}, 334 | 1: []int{3}, 335 | } 336 | 337 | winnerCells := []int{0} 338 | 339 | prevWinnerCells := []int{10, 11, 12, 13, 14} 340 | 341 | tm.learnOnSegments(prevActiveSegments, 342 | learningSegments, 343 | prevActiveSynapsesForSegment, 344 | winnerCells, 345 | prevWinnerCells, 346 | connections) 347 | 348 | //Check segment 0 349 | assert.Equal(t, 0.7, connections.DataForSynapse(0).Permanence) 350 | assert.Equal(t, 0.5, connections.DataForSynapse(1).Permanence) 351 | assert.Equal(t, 0.8, connections.DataForSynapse(2).Permanence) 352 | 353 | //Check segment 1 354 | assert.InEpsilon(t, 0.8, connections.DataForSynapse(3).Permanence, 0.1) 355 | assert.Equal(t, 2, len(connections.synapsesForSegment[1])) 356 | 357 | //Check segment 2 358 | assert.Equal(t, 0.9, connections.DataForSynapse(4).Permanence) 359 | assert.Equal(t, 1, len(connections.synapsesForSegment[2])) 360 | 361 | // Check segment 3 362 | assert.Equal(t, 1, len(connections.synapsesForSegment[3])) 363 | 364 | } 365 | 366 | func TestBurstColumnsEmpty(t *testing.T) { 367 | tmp := NewTemporalMemoryParams() 368 | tm := NewTemporalMemory(tmp) 369 | connections := tm.Connections 370 | activeColumns := []int{} 371 | predictedColumns := []int{} 372 | prevActiveSynapsesForSegment := make(map[int][]int) 373 | 374 | activeCells, winnerCells, learningSegments := tm.burstColumns(activeColumns, 375 | predictedColumns, 376 | prevActiveSynapsesForSegment, 377 | connections) 378 | 379 | assert.Equal(t, []int(nil), 
activeCells) 380 | assert.Equal(t, []int(nil), winnerCells) 381 | assert.Equal(t, []int(nil), learningSegments) 382 | 383 | } 384 | 385 | func TestBurstColumns(t *testing.T) { 386 | tmp := NewTemporalMemoryParams() 387 | tmp.CellsPerColumn = 4 388 | tmp.MinThreshold = 1 389 | tm := NewTemporalMemory(tmp) 390 | connections := tm.Connections 391 | 392 | connections.CreateSegment(0) 393 | connections.CreateSynapse(0, 23, 0.6) 394 | connections.CreateSynapse(0, 37, 0.4) 395 | connections.CreateSynapse(0, 477, 0.9) 396 | connections.CreateSegment(0) 397 | connections.CreateSynapse(1, 49, 0.9) 398 | connections.CreateSynapse(1, 3, 0.8) 399 | connections.CreateSegment(1) 400 | connections.CreateSynapse(2, 733, 0.7) 401 | connections.CreateSegment(108) 402 | connections.CreateSynapse(3, 486, 0.9) 403 | 404 | activeColumns := []int{0, 1, 26} 405 | predictedColumns := []int{26} 406 | 407 | prevActiveSynapsesForSegment := map[int][]int{ 408 | 0: []int{0, 1}, 409 | 1: []int{3}, 410 | 2: []int{5}, 411 | } 412 | 413 | activeCells, winnerCells, learningSegments := tm.burstColumns(activeColumns, 414 | predictedColumns, 415 | prevActiveSynapsesForSegment, 416 | connections) 417 | 418 | assert.Equal(t, []int{0, 1, 2, 3, 4, 5, 6, 7}, activeCells) 419 | assert.Equal(t, []int{0, 5}, winnerCells) //5 is randomly chosen cell 420 | assert.Equal(t, []int{0, 4}, learningSegments) //4 is new segment created 421 | //Check that new segment was added to winner cell (4) in column 1 422 | assert.Equal(t, []int{4}, connections.segmentsForCell[5]) 423 | 424 | } 425 | 426 | func TestActivateCorrectlyPredictiveCellsEmpty(t *testing.T) { 427 | tmp := NewTemporalMemoryParams() 428 | tmp.CellsPerColumn = 4 429 | tmp.MinThreshold = 1 430 | tm := NewTemporalMemory(tmp) 431 | connections := tm.Connections 432 | 433 | prevPredictiveCells := []int{} 434 | activeColumns := []int{} 435 | 436 | activeCells, winnerCells, predictedColumns := tm.activateCorrectlyPredictiveCells(prevPredictiveCells, 437 | 
activeColumns, 438 | connections) 439 | 440 | assert.Equal(t, []int(nil), activeCells) 441 | assert.Equal(t, []int(nil), winnerCells) 442 | assert.Equal(t, []int(nil), predictedColumns) 443 | 444 | // No previous predictive cells 445 | prevPredictiveCells = []int{} 446 | activeColumns = []int{32, 47, 823} 447 | 448 | activeCells, winnerCells, predictedColumns = tm.activateCorrectlyPredictiveCells(prevPredictiveCells, 449 | activeColumns, 450 | connections) 451 | 452 | assert.Equal(t, []int(nil), activeCells) 453 | assert.Equal(t, []int(nil), winnerCells) 454 | assert.Equal(t, []int(nil), predictedColumns) 455 | 456 | // No active columns 457 | prevPredictiveCells = []int{0, 237, 1026, 26337, 26339, 55536} 458 | activeColumns = []int{} 459 | 460 | activeCells, winnerCells, predictedColumns = tm.activateCorrectlyPredictiveCells(prevPredictiveCells, 461 | activeColumns, 462 | connections) 463 | 464 | assert.Equal(t, []int(nil), activeCells) 465 | assert.Equal(t, []int(nil), winnerCells) 466 | assert.Equal(t, []int(nil), predictedColumns) 467 | 468 | } 469 | 470 | func TestActivateCorrectlyPredictiveCells(t *testing.T) { 471 | tmp := NewTemporalMemoryParams() 472 | tm := NewTemporalMemory(tmp) 473 | connections := tm.Connections 474 | prevPredictiveCells := []int{0, 237, 1026, 26337, 26339, 55536} 475 | activeColumns := []int{32, 47, 823} 476 | 477 | activeCells, winnerCells, predictedColumns := tm.activateCorrectlyPredictiveCells(prevPredictiveCells, 478 | activeColumns, 479 | connections) 480 | 481 | assert.Equal(t, []int{1026, 26337, 26339}, activeCells) 482 | assert.Equal(t, []int{1026, 26337, 26339}, winnerCells) 483 | assert.Equal(t, []int{32, 823}, predictedColumns) 484 | 485 | } 486 | -------------------------------------------------------------------------------- /temporalMemory.go: -------------------------------------------------------------------------------- 1 | package htm 2 | 3 | import ( 4 | //"fmt" 5 | "github.com/cznic/mathutil" 6 | // 
"github.com/zacg/floats" 7 | // "github.com/zacg/go.matrix" 8 | "github.com/nupic-community/htm/utils" 9 | //"github.com/zacg/ints" 10 | "math" 11 | "math/rand" 12 | // //"sort" 13 | ) 14 | 15 | /* 16 | Params for intializing temporal memory 17 | */ 18 | type TemporalMemoryParams struct { 19 | //Column dimensions 20 | ColumnDimensions []int 21 | CellsPerColumn int 22 | //If the number of active connected synapses on a segment is at least 23 | //this threshold, the segment is said to be active. 24 | ActivationThreshold int 25 | //Radius around cell from which it can sample to form distal dendrite 26 | //connections. 27 | LearningRadius int 28 | InitialPermanence float64 29 | //If the permanence value for a synapse is greater than this value, it is said 30 | //to be connected. 31 | ConnectedPermanence float64 32 | //If the number of synapses active on a segment is at least this threshold, 33 | //it is selected as the best matching cell in a bursing column. 34 | MinThreshold int 35 | //The maximum number of synapses added to a segment during learning. 
36 | MaxNewSynapseCount int 37 | PermanenceIncrement float64 38 | PermanenceDecrement float64 39 | //rand seed 40 | Seed int 41 | } 42 | 43 | //Create default temporal memory params 44 | func NewTemporalMemoryParams() *TemporalMemoryParams { 45 | p := new(TemporalMemoryParams) 46 | 47 | p.ColumnDimensions = []int{2048} 48 | p.CellsPerColumn = 32 49 | p.ActivationThreshold = 13 50 | p.LearningRadius = 2048 51 | p.InitialPermanence = 0.21 52 | p.ConnectedPermanence = 0.50 53 | p.MinThreshold = 10 54 | p.MaxNewSynapseCount = 20 55 | p.PermanenceIncrement = 0.10 56 | p.PermanenceDecrement = 0.10 57 | p.Seed = 42 58 | 59 | return p 60 | } 61 | 62 | /* 63 | Temporal memory 64 | */ 65 | type TemporalMemory struct { 66 | params *TemporalMemoryParams 67 | ActiveCells []int 68 | PredictiveCells []int 69 | ActiveSegments []int 70 | ActiveSynapsesForSegment map[int][]int 71 | WinnerCells []int 72 | Connections *TemporalMemoryConnections 73 | } 74 | 75 | //Create new temporal memory 76 | func NewTemporalMemory(params *TemporalMemoryParams) *TemporalMemory { 77 | tm := new(TemporalMemory) 78 | tm.params = params 79 | tm.Connections = NewTemporalMemoryConnections(params.MaxNewSynapseCount, 80 | params.CellsPerColumn, params.ColumnDimensions) 81 | //TODO: refactor into encapsulated RNG 82 | rand.Seed(int64(params.Seed)) 83 | return tm 84 | } 85 | 86 | //Feeds input record through TM, performing inference and learning. 87 | //Updates member variables with new state. 
88 | func (tm *TemporalMemory) Compute(activeColumns []int, learn bool) { 89 | 90 | activeCells, winnerCells, activeSynapsesForSegment, activeSegments, predictiveCells := tm.computeFn(activeColumns, 91 | tm.PredictiveCells, 92 | tm.ActiveSegments, 93 | tm.ActiveSynapsesForSegment, 94 | tm.WinnerCells, 95 | tm.Connections, 96 | learn) 97 | 98 | tm.ActiveCells = activeCells 99 | tm.WinnerCells = winnerCells 100 | tm.ActiveSynapsesForSegment = activeSynapsesForSegment 101 | tm.ActiveSegments = activeSegments 102 | tm.PredictiveCells = predictiveCells 103 | 104 | } 105 | 106 | // helper for compute(). 107 | //Returns new state 108 | func (tm *TemporalMemory) computeFn(activeColumns []int, 109 | prevPredictiveCells []int, 110 | prevActiveSegments []int, 111 | prevActiveSynapsesForSegment map[int][]int, 112 | prevWinnerCells []int, 113 | connections *TemporalMemoryConnections, 114 | learn bool) (activeCells []int, 115 | winnerCells []int, 116 | activeSynapsesForSegment map[int][]int, 117 | activeSegments []int, 118 | predictiveCells []int) { 119 | 120 | var predictedColumns []int 121 | 122 | activeCells, winnerCells, predictedColumns = tm.activateCorrectlyPredictiveCells( 123 | prevPredictiveCells, 124 | activeColumns, 125 | connections) 126 | 127 | _activeCells, _winnerCells, learningSegments := tm.burstColumns(activeColumns, 128 | predictedColumns, 129 | prevActiveSynapsesForSegment, 130 | connections) 131 | 132 | utils.Add(activeCells, _activeCells) 133 | utils.Add(winnerCells, _winnerCells) 134 | 135 | if learn { 136 | tm.learnOnSegments(prevActiveSegments, 137 | learningSegments, 138 | prevActiveSynapsesForSegment, 139 | winnerCells, 140 | prevWinnerCells, 141 | connections) 142 | } 143 | 144 | activeSynapsesForSegment = tm.computeActiveSynapses(activeCells, connections) 145 | 146 | activeSegments, predictiveCells = tm.computePredictiveCells(activeSynapsesForSegment, 147 | connections) 148 | 149 | return activeCells, 150 | winnerCells, 151 | 
activeSynapsesForSegment, 152 | activeSegments, 153 | predictiveCells 154 | 155 | } 156 | 157 | //Indicates the start of a new sequence. Resets sequence state of the TM. 158 | func (tm *TemporalMemory) Reset() { 159 | tm.ActiveCells = tm.ActiveCells[:0] 160 | tm.PredictiveCells = tm.PredictiveCells[:0] 161 | tm.ActiveSegments = tm.ActiveSegments[:0] 162 | tm.WinnerCells = tm.WinnerCells[:0] 163 | } 164 | 165 | /* 166 | Phase 1: Activate the correctly predictive cells. 167 | Pseudocode: 168 | - for each prev predictive cell 169 | - if in active column 170 | - mark it as active 171 | - mark it as winner cell 172 | - mark column as predicted 173 | */ 174 | func (tm *TemporalMemory) activateCorrectlyPredictiveCells(prevPredictiveCells []int, 175 | activeColumns []int, 176 | connections *TemporalMemoryConnections) (activeCells []int, 177 | winnerCells []int, 178 | predictedColumns []int) { 179 | 180 | for _, cell := range prevPredictiveCells { 181 | column := connections.ColumnForCell(cell) 182 | if utils.ContainsInt(column, activeColumns) { 183 | activeCells = append(activeCells, cell) 184 | winnerCells = append(winnerCells, cell) 185 | //TODO: change this to a set data structure 186 | if !utils.ContainsInt(column, predictedColumns) { 187 | predictedColumns = append(predictedColumns, column) 188 | } 189 | } 190 | } 191 | 192 | return activeCells, winnerCells, predictedColumns 193 | } 194 | 195 | /* 196 | Phase 2: Burst unpredicted columns. 
197 | Pseudocode: 198 | - for each unpredicted active column 199 | - mark all cells as active 200 | - mark the best matching cell as winner cell 201 | - (learning) 202 | - if it has no matching segment 203 | - (optimization) if there are prev winner cells 204 | - add a segment to it 205 | - mark the segment as learning 206 | */ 207 | func (tm *TemporalMemory) burstColumns(activeColumns []int, 208 | predictedColumns []int, 209 | prevActiveSynapsesForSegment map[int][]int, 210 | connections *TemporalMemoryConnections) (activeCells []int, 211 | winnerCells []int, 212 | learningSegments []int) { 213 | 214 | unpredictedColumns := utils.Complement(activeColumns, predictedColumns) 215 | 216 | for _, column := range unpredictedColumns { 217 | cells := connections.CellsForColumn(column) 218 | activeCells = utils.Add(activeCells, cells) 219 | 220 | bestCell, bestSegment := tm.getBestMatchingCell(column, 221 | prevActiveSynapsesForSegment, 222 | connections) 223 | 224 | winnerCells = append(winnerCells, bestCell) 225 | 226 | if bestSegment == -1 { 227 | //TODO: (optimization) Only do this if there are prev winner cells 228 | bestSegment = connections.CreateSegment(bestCell) 229 | } 230 | //TODO: change to set data structure 231 | if !utils.ContainsInt(bestSegment, learningSegments) { 232 | learningSegments = append(learningSegments, bestSegment) 233 | } 234 | } 235 | 236 | return activeCells, winnerCells, learningSegments 237 | } 238 | 239 | /* 240 | Phase 3: Perform learning by adapting segments. 
241 | Pseudocode: 242 | - (learning) for each prev active or learning segment 243 | - if learning segment or from winner cell 244 | - strengthen active synapses 245 | - weaken inactive synapses 246 | - if learning segment 247 | - add some synapses to the segment 248 | - subsample from prev winner cells 249 | */ 250 | func (tm *TemporalMemory) learnOnSegments(prevActiveSegments []int, 251 | learningSegments []int, 252 | prevActiveSynapsesForSegment map[int][]int, 253 | winnerCells []int, 254 | prevWinnerCells []int, 255 | connections *TemporalMemoryConnections) { 256 | 257 | tm.lrnOnSegments(prevActiveSegments, false, prevActiveSynapsesForSegment, winnerCells, prevWinnerCells, connections) 258 | tm.lrnOnSegments(learningSegments, true, prevActiveSynapsesForSegment, winnerCells, prevWinnerCells, connections) 259 | 260 | } 261 | 262 | //helper 263 | func (tm *TemporalMemory) lrnOnSegments(segments []int, 264 | isLearningSegments bool, 265 | prevActiveSynapsesForSegment map[int][]int, 266 | winnerCells []int, 267 | prevWinnerCells []int, 268 | connections *TemporalMemoryConnections) { 269 | 270 | for _, segment := range segments { 271 | isFromWinnerCell := utils.ContainsInt(connections.CellForSegment(segment), winnerCells) 272 | activeSynapses := tm.getConnectedActiveSynapsesForSegment(segment, 273 | prevActiveSynapsesForSegment, 274 | 0, 275 | connections) 276 | 277 | if isLearningSegments || isFromWinnerCell { 278 | tm.adaptSegment(segment, activeSynapses, connections) 279 | } 280 | 281 | if isLearningSegments { 282 | n := tm.params.MaxNewSynapseCount - len(activeSynapses) 283 | for _, sourceCell := range tm.pickCellsToLearnOn(n, 284 | segment, 285 | winnerCells, 286 | connections) { 287 | connections.CreateSynapse(segment, sourceCell, tm.params.InitialPermanence) 288 | } 289 | } 290 | 291 | } 292 | 293 | } 294 | 295 | /* 296 | Phase 4: Compute predictive cells due to lateral input 297 | on distal dendrites. 
298 | 299 | Pseudocode: 300 | 301 | - for each distal dendrite segment with activity >= activationThreshold 302 | - mark the segment as active 303 | - mark the cell as predictive 304 | */ 305 | func (tm *TemporalMemory) computePredictiveCells(activeSynapsesForSegment map[int][]int, 306 | connections *TemporalMemoryConnections) (activeSegments []int, predictiveCells []int) { 307 | 308 | for segment, _ := range activeSynapsesForSegment { 309 | synapses := tm.getConnectedActiveSynapsesForSegment(segment, 310 | activeSynapsesForSegment, 311 | tm.params.ConnectedPermanence, 312 | connections) 313 | if len(synapses) >= tm.params.ActivationThreshold { 314 | activeSegments = append(activeSegments, segment) 315 | predictiveCells = append(predictiveCells, connections.CellForSegment(segment)) 316 | } 317 | } 318 | 319 | return activeSegments, predictiveCells 320 | } 321 | 322 | // Forward propagates activity from active cells to the synapses that touch 323 | // them, to determine which synapses are active. 324 | func (tm *TemporalMemory) computeActiveSynapses(activeCells []int, 325 | connections *TemporalMemoryConnections) map[int][]int { 326 | 327 | activeSynapsesForSegment := make(map[int][]int) 328 | 329 | for _, cell := range activeCells { 330 | for _, synapse := range connections.SynapsesForSourceCell(cell) { 331 | segment := connections.DataForSynapse(synapse).Segment 332 | activeSynapsesForSegment[segment] = append(activeSynapsesForSegment[segment], synapse) 333 | } 334 | } 335 | 336 | return activeSynapsesForSegment 337 | } 338 | 339 | // Gets the cell with the best matching segment 340 | //(see `TM.getBestMatchingSegment`) that has the largest number of active 341 | //synapses of all best matching segments. 342 | //If none were found, pick the least used cell (see `TM.getLeastUsedCell`). 
343 | func (tm *TemporalMemory) getBestMatchingCell(column int, activeSynapsesForSegment map[int][]int, 344 | connections *TemporalMemoryConnections) (bestCell int, bestSegment int) { 345 | bestCell = -1 346 | bestSegment = -1 347 | 348 | maxSynapses := 0 349 | cells := connections.CellsForColumn(column) 350 | 351 | for _, cell := range cells { 352 | segment, connectedActiveSynapses := tm.getBestMatchingSegment(cell, 353 | activeSynapsesForSegment, 354 | connections) 355 | 356 | if segment > -1 && len(connectedActiveSynapses) > maxSynapses { 357 | maxSynapses = len(connectedActiveSynapses) 358 | bestCell = cell 359 | bestSegment = segment 360 | } 361 | } 362 | 363 | if bestCell == -1 { 364 | bestCell = tm.getLeastUsedCell(column, connections) 365 | } 366 | 367 | return bestCell, bestSegment 368 | } 369 | 370 | // Gets the segment on a cell with the largest number of activate synapses, 371 | // including all synapses with non-zero permanences. 372 | func (tm *TemporalMemory) getBestMatchingSegment(cell int, activeSynapsesForSegment map[int][]int, 373 | connections *TemporalMemoryConnections) (bestSegment int, connectedActiveSynapses []int) { 374 | maxSynapses := tm.params.MinThreshold 375 | bestSegment = -1 376 | 377 | for _, segment := range connections.SegmentsForCell(cell) { 378 | synapses := tm.getConnectedActiveSynapsesForSegment(segment, 379 | activeSynapsesForSegment, 380 | 0, 381 | connections) 382 | 383 | if len(synapses) >= maxSynapses { 384 | maxSynapses = len(synapses) 385 | bestSegment = segment 386 | connectedActiveSynapses = synapses 387 | } 388 | 389 | } 390 | 391 | return bestSegment, connectedActiveSynapses 392 | } 393 | 394 | // Gets the cell with the smallest number of segments. 395 | // Break ties randomly. 
396 | func (tm *TemporalMemory) getLeastUsedCell(column int, connections *TemporalMemoryConnections) int { 397 | cells := connections.CellsForColumn(column) 398 | leastUsedCells := make([]int, 0, len(cells)) 399 | minNumSegments := math.MaxInt64 400 | 401 | for _, cell := range cells { 402 | numSegments := len(connections.SegmentsForCell(cell)) 403 | 404 | if numSegments < minNumSegments { 405 | minNumSegments = numSegments 406 | leastUsedCells = leastUsedCells[:0] 407 | } 408 | 409 | if numSegments == minNumSegments { 410 | leastUsedCells = append(leastUsedCells, cell) 411 | } 412 | } 413 | 414 | //pick random cell 415 | return leastUsedCells[rand.Intn(len(leastUsedCells))] 416 | } 417 | 418 | //Returns the synapses on a segment that are active due to lateral input 419 | //from active cells. 420 | func (tm *TemporalMemory) getConnectedActiveSynapsesForSegment(segment int, 421 | activeSynapsesForSegment map[int][]int, permanenceThreshold float64, connections *TemporalMemoryConnections) []int { 422 | 423 | if _, ok := activeSynapsesForSegment[segment]; !ok { 424 | return []int{} 425 | } 426 | 427 | connectedSynapses := make([]int, 0, len(activeSynapsesForSegment)) 428 | 429 | //TODO: (optimization) Can skip this logic if permanenceThreshold = 0 430 | for _, synIdx := range activeSynapsesForSegment[segment] { 431 | perm := connections.DataForSynapse(synIdx).Permanence 432 | if perm >= permanenceThreshold { 433 | connectedSynapses = append(connectedSynapses, synIdx) 434 | } 435 | } 436 | 437 | return connectedSynapses 438 | } 439 | 440 | // Updates synapses on segment. 441 | // Strengthens active synapses; weakens inactive synapses. 
442 | func (tm *TemporalMemory) adaptSegment(segment int, activeSynapses []int, 443 | connections *TemporalMemoryConnections) { 444 | 445 | for _, synIdx := range connections.SynapsesForSegment(segment) { 446 | syn := connections.DataForSynapse(synIdx) 447 | perm := syn.Permanence 448 | 449 | if utils.ContainsInt(synIdx, activeSynapses) { 450 | perm += tm.params.PermanenceIncrement 451 | } else { 452 | perm -= tm.params.PermanenceDecrement 453 | } 454 | //enforce min/max bounds 455 | perm = math.Max(0.0, math.Min(1.0, perm)) 456 | connections.UpdateSynapsePermanence(synIdx, perm) 457 | } 458 | 459 | } 460 | 461 | //Pick cells to form distal connections to. 462 | func (tm *TemporalMemory) pickCellsToLearnOn(n int, segment int, 463 | winnerCells []int, connections *TemporalMemoryConnections) []int { 464 | 465 | candidates := make([]int, len(winnerCells)) 466 | copy(candidates, winnerCells) 467 | 468 | for _, val := range connections.SynapsesForSegment(segment) { 469 | syn := connections.DataForSynapse(val) 470 | for idx, val := range candidates { 471 | if val == syn.SourceCell { 472 | candidates = append(candidates[:idx], candidates[idx+1:]...) 473 | break 474 | } 475 | } 476 | } 477 | 478 | //Shuffle candidates 479 | for i := range candidates { 480 | j := rand.Intn(i + 1) 481 | candidates[i], candidates[j] = candidates[j], candidates[i] 482 | } 483 | 484 | n = mathutil.Min(n, len(candidates)) 485 | return candidates[:n] 486 | } 487 | -------------------------------------------------------------------------------- /encoders/scalerEncoder.go: -------------------------------------------------------------------------------- 1 | package encoders 2 | 3 | import ( 4 | "fmt" 5 | //"github.com/cznic/mathutil" 6 | //"github.com/zacg/floats" 7 | "github.com/nupic-community/htm" 8 | "github.com/nupic-community/htm/utils" 9 | "github.com/zacg/ints" 10 | "math" 11 | ) 12 | 13 | /* 14 | n -- The number of bits in the output. 
Must be greater than or equal to w 15 | 16 | radius -- Two inputs separated by more than the radius have non-overlapping 17 | representations. Two inputs separated by less than the radius will 18 | in general overlap in at least some of their bits. You can think 19 | of this as the radius of the input. 20 | 21 | resolution -- Two inputs separated by greater than, or equal to the resolution are guaranteed 22 | to have different representations. 23 | */ 24 | type ScalerOutputType int 25 | 26 | const ( 27 | N ScalerOutputType = 1 28 | Radius ScalerOutputType = 2 29 | Resolution ScalerOutputType = 3 30 | ) 31 | 32 | type ScalerEncoderParams struct { 33 | Width int 34 | MinVal float64 35 | MaxVal float64 36 | Periodic bool 37 | OutputType ScalerOutputType 38 | Range float64 39 | Resolution float64 40 | Name string 41 | Radius float64 42 | ClipInput bool 43 | Verbosity int 44 | N int 45 | } 46 | 47 | func NewScalerEncoderParams(width int, minVal float64, maxVal float64) *ScalerEncoderParams { 48 | p := new(ScalerEncoderParams) 49 | 50 | p.Width = width 51 | p.MinVal = minVal 52 | p.MaxVal = maxVal 53 | p.N = 0 54 | p.Radius = 0 55 | p.Resolution = 0 56 | p.Name = "" 57 | p.Verbosity = 0 58 | p.ClipInput = false 59 | 60 | return p 61 | } 62 | 63 | /* 64 | A scalar encoder encodes a numeric (floating point) value into an array 65 | of bits. The output is 0's except for a contiguous block of 1's. The 66 | location of this contiguous block varies continuously with the input value. 67 | 68 | The encoding is linear. If you want a nonlinear encoding, just transform 69 | the scalar (e.g. by applying a logarithm function) before encoding. 70 | It is not recommended to bin the data as a pre-processing step, e.g. 71 | "1" = $0 - $.20, "2" = $.21-$0.80, "3" = $.81-$1.20, etc. as this 72 | removes a lot of information and prevents nearby values from overlapping 73 | in the output. Instead, use a continuous transformation that scales 74 | the data (a piecewise transformation is fine). 
75 | */ 76 | type ScalerEncoder struct { 77 | ScalerEncoderParams 78 | 79 | padding int 80 | halfWidth int 81 | rangeInternal float64 82 | topDownMappingM *htm.SparseBinaryMatrix 83 | topDownValues []float64 84 | bucketValues []float64 85 | //nInternal represents the output area excluding the possible padding on each 86 | nInternal int 87 | } 88 | 89 | func NewScalerEncoder(p *ScalerEncoderParams) *ScalerEncoder { 90 | se := new(ScalerEncoder) 91 | se.ScalerEncoderParams = *p 92 | 93 | if se.Width%2 == 0 { 94 | panic("Width must be an odd number.") 95 | } 96 | 97 | se.halfWidth = (se.Width - 1) / 2 98 | 99 | /* For non-periodic inputs, padding is the number of bits "outside" the range, 100 | on each side. I.e. the representation of minval is centered on some bit, and 101 | there are "padding" bits to the left of that centered bit; similarly with 102 | bits to the right of the center bit of maxval*/ 103 | if !se.Periodic { 104 | se.padding = se.halfWidth 105 | } 106 | 107 | if se.MinVal >= se.MaxVal { 108 | panic("MinVal must be less than MaxVal") 109 | } 110 | 111 | se.rangeInternal = se.MaxVal - se.MinVal 112 | 113 | // There are three different ways of thinking about the representation. Handle 114 | // each case here. 
115 | se.initEncoder(se.Width, se.MinVal, se.MaxVal, se.N, 116 | se.Radius, se.Resolution) 117 | 118 | // nInternal represents the output area excluding the possible padding on each 119 | // side 120 | se.nInternal = se.N - 2*se.padding 121 | 122 | // Our name 123 | if len(se.Name) == 0 { 124 | se.Name = fmt.Sprintf("[%v:%v]", se.MinVal, se.MaxVal) 125 | } 126 | 127 | if se.Width < 21 { 128 | fmt.Println("Number of bits in the SDR must be greater than 21") 129 | } 130 | 131 | return se 132 | } 133 | 134 | /* 135 | helper used to inititalize the encoder 136 | */ 137 | func (se *ScalerEncoder) initEncoder(width int, minval float64, maxval float64, n int, 138 | radius float64, resolution float64) { 139 | //handle 3 diff ways of representation 140 | 141 | if n != 0 { 142 | //crutches ;( 143 | if radius != 0 { 144 | panic("radius is not 0") 145 | } 146 | if resolution != 0 { 147 | panic("resolution is not 0") 148 | } 149 | if n <= width { 150 | panic("n less than width") 151 | } 152 | 153 | se.N = n 154 | 155 | //if (minval is not None and maxval is not None){ 156 | 157 | if !se.Periodic { 158 | se.Resolution = se.rangeInternal / float64(se.N-se.Width) 159 | } else { 160 | se.Resolution = se.rangeInternal / float64(se.N) 161 | } 162 | 163 | se.Radius = float64(se.Width) * se.Resolution 164 | 165 | if se.Periodic { 166 | se.Range = se.rangeInternal 167 | } else { 168 | se.Range = se.rangeInternal + se.Resolution 169 | } 170 | 171 | } else { //n == 0 172 | if radius != 0 { 173 | if resolution != 0 { 174 | panic("resolution not 0") 175 | } 176 | se.Radius = radius 177 | se.Resolution = se.Radius / float64(width) 178 | } else if resolution != 0 { 179 | se.Resolution = resolution 180 | se.Radius = se.Resolution * float64(se.Width) 181 | } else { 182 | panic("One of n, radius, resolution must be set") 183 | } 184 | 185 | if se.Periodic { 186 | se.Range = se.rangeInternal 187 | } else { 188 | se.Range = se.rangeInternal + se.Resolution 189 | } 190 | 191 | nfloat := 
float64(se.Width)*(se.Range/se.Radius) + 2*float64(se.padding) 192 | se.N = int(math.Ceil(nfloat)) 193 | 194 | } 195 | 196 | } 197 | 198 | /* 199 | recalculate encoder parameters and name 200 | */ 201 | func (se *ScalerEncoder) recalcParams() { 202 | se.rangeInternal = se.MaxVal - se.MinVal 203 | 204 | if !se.Periodic { 205 | se.Resolution = se.rangeInternal/float64(se.N) - float64(se.Width) 206 | } else { 207 | se.Resolution = se.rangeInternal / float64(se.N) 208 | } 209 | 210 | se.Radius = float64(se.Width) * se.Resolution 211 | 212 | if se.Periodic { 213 | se.Range = se.rangeInternal 214 | } else { 215 | se.Range = se.rangeInternal + se.Resolution 216 | } 217 | 218 | se.Name = fmt.Sprintf("[%v:%v]", se.MinVal, se.MaxVal) 219 | 220 | } 221 | 222 | /* Return the bit offset of the first bit to be set in the encoder output. 223 | For periodic encoders, this can be a negative number when the encoded output 224 | wraps around. */ 225 | func (se *ScalerEncoder) getFirstOnBit(input float64) int { 226 | 227 | //if input == SENTINEL_VALUE_FOR_MISSING_DATA: 228 | // return [None] 229 | //else: 230 | 231 | if input < se.MinVal { 232 | //Don't clip periodic inputs. Out-of-range input is always an error 233 | if se.ClipInput && !se.Periodic { 234 | 235 | if se.Verbosity > 0 { 236 | fmt.Printf("Clipped input %v=%v to minval %v", se.Name, input, se.MinVal) 237 | } 238 | input = se.MinVal 239 | } else { 240 | panic(fmt.Sprintf("Input %v less than range %v - %v", input, se.MinVal, se.MaxVal)) 241 | } 242 | 243 | if se.Periodic { 244 | 245 | // Don't clip periodic inputs. 
Out-of-range input is always an error 246 | if input >= se.MaxVal { 247 | panic(fmt.Sprintf("input %v greater than periodic range %v - %v", input, se.MinVal, se.MaxVal)) 248 | } 249 | 250 | } else { 251 | 252 | if input > se.MaxVal { 253 | if se.ClipInput { 254 | if se.Verbosity > 0 { 255 | fmt.Printf("Clipped input %v=%v to maxval %v", se.Name, input, se.MaxVal) 256 | } 257 | input = se.MaxVal 258 | } else { 259 | panic(fmt.Sprintf("input %v greater than range (%v - %v)", input, se.MinVal, se.MaxVal)) 260 | } 261 | } 262 | } 263 | } 264 | 265 | centerbin := 0 266 | 267 | if se.Periodic { 268 | centerbin = int((input-se.MinVal)*float64(se.nInternal)/se.Range) + se.padding 269 | } else { 270 | centerbin = int(((input-se.MinVal)+se.Resolution/2)/se.Resolution) + se.padding 271 | } 272 | 273 | // We use the first bit to be set in the encoded output as the bucket index 274 | minbin := centerbin - se.halfWidth 275 | return minbin 276 | } 277 | 278 | /* 279 | Returns bucket index for given input 280 | */ 281 | func (se *ScalerEncoder) getBucketIndices(input float64) []int { 282 | 283 | minbin := se.getFirstOnBit(input) 284 | var bucketIdx int 285 | 286 | // For periodic encoders, the bucket index is the index of the center bit 287 | if se.Periodic { 288 | bucketIdx = minbin + se.halfWidth 289 | if bucketIdx < 0 { 290 | bucketIdx += se.N 291 | } 292 | } else { 293 | // for non-periodic encoders, the bucket index is the index of the left bit 294 | bucketIdx = minbin 295 | } 296 | 297 | return []int{bucketIdx} 298 | } 299 | 300 | /* 301 | Returns encoded input 302 | */ 303 | func (se *ScalerEncoder) Encode(input float64, learn bool) (output []bool) { 304 | output = make([]bool, se.N) 305 | se.EncodeToSlice(input, learn, output) 306 | return output 307 | } 308 | 309 | /* 310 | Encodes input to specified slice. 
Slice should be valid length 311 | */ 312 | func (se *ScalerEncoder) EncodeToSlice(input float64, learn bool, output []bool) { 313 | 314 | // Get the bucket index to use 315 | bucketIdx := se.getFirstOnBit(input) 316 | 317 | //if len(bucketIdx) { 318 | //This shouldn't get hit 319 | // panic("Missing input value") 320 | //TODO output[0:self.n] = 0 TODO: should all 1s, or random SDR be returned instead? 321 | //} else { 322 | // The bucket index is the index of the first bit to set in the output 323 | output = output[:se.N] 324 | 325 | minbin := bucketIdx 326 | maxbin := minbin + 2*se.halfWidth 327 | 328 | if se.Periodic { 329 | 330 | // Handle the edges by computing wrap-around 331 | if maxbin >= se.N { 332 | bottombins := maxbin - se.N + 1 333 | utils.FillSliceRangeBool(output, true, 0, bottombins) 334 | maxbin = se.N - 1 335 | } 336 | if minbin < 0 { 337 | topbins := -minbin 338 | utils.FillSliceRangeBool(output, true, se.N-topbins, (se.N - (se.N - topbins))) 339 | minbin = 0 340 | } 341 | 342 | } 343 | 344 | if minbin < 0 { 345 | panic("invalid minbin") 346 | } 347 | if maxbin >= se.N { 348 | panic("invalid maxbin") 349 | } 350 | 351 | // set the output (except for periodic wraparound) 352 | utils.FillSliceRangeBool(output, true, minbin, (maxbin+1)-minbin) 353 | 354 | if se.Verbosity >= 2 { 355 | fmt.Println("input:", input) 356 | fmt.Printf("half width:%v \n", se.Width) 357 | fmt.Printf("range: %v - %v \n", se.MinVal, se.MaxVal) 358 | fmt.Printf("n: %v width: %v resolution: %v \n", se.N, se.Width, se.Resolution) 359 | fmt.Printf("radius: %v periodic: %v \n", se.Radius, se.Periodic) 360 | fmt.Printf("output: %v \n", output) 361 | } 362 | 363 | //} 364 | 365 | } 366 | 367 | /* 368 | Return the interal topDownMappingM matrix used for handling the 369 | bucketInfo() and topDownCompute() methods. This is a matrix, one row per 370 | category (bucket) where each row contains the encoded output for that 371 | category. 
372 | */ 373 | func (se *ScalerEncoder) getTopDownMapping() *htm.SparseBinaryMatrix { 374 | 375 | //if already calculated return 376 | if se.topDownMappingM != nil { 377 | return se.topDownMappingM 378 | } 379 | 380 | // The input scalar value corresponding to each possible output encoding 381 | if se.Periodic { 382 | se.topDownValues = make([]float64, 0, int(se.MaxVal-se.MinVal)) 383 | start := se.MinVal + se.Resolution/2.0 384 | idx := 0 385 | for i := start; i <= se.MaxVal; i += se.Resolution { 386 | se.topDownValues[idx] = i 387 | idx++ 388 | } 389 | } else { 390 | //Number of values is (max-min)/resolution 391 | se.topDownValues = make([]float64, int(math.Ceil((se.MaxVal-se.MinVal)/se.Resolution))) 392 | end := se.MaxVal + se.Resolution/2.0 393 | idx := 0 394 | for i := se.MinVal; i <= end; i += se.Resolution { 395 | se.topDownValues[idx] = i 396 | idx++ 397 | } 398 | } 399 | 400 | // Each row represents an encoded output pattern 401 | numCategories := len(se.topDownValues) 402 | 403 | se.topDownMappingM = htm.NewSparseBinaryMatrix(numCategories, se.N) 404 | 405 | for i := 0; i < numCategories; i++ { 406 | value := se.topDownValues[i] 407 | value = math.Max(value, se.MinVal) 408 | value = math.Min(value, se.MaxVal) 409 | 410 | outputSpace := se.Encode(value, false) 411 | se.topDownMappingM.SetRowFromDense(i, outputSpace) 412 | } 413 | 414 | return se.topDownMappingM 415 | 416 | } 417 | 418 | /* 419 | Returns input description for bucket. Numenta implementations iface returns 420 | set of tuples to support diff encoder types. 
421 | */ 422 | func (se *ScalerEncoder) getBucketInfo(buckets []int) (value float64, encoding []bool) { 423 | 424 | //ensure topdownmapping matrix is calculated 425 | se.getTopDownMapping() 426 | 427 | // The "category" is simply the bucket index 428 | category := buckets[0] 429 | encoding = se.topDownMappingM.GetDenseRow(category) 430 | 431 | if se.Periodic { 432 | value = (se.MinVal + (se.Resolution / 2.0) + (float64(category) * se.Resolution)) 433 | } else { 434 | value = se.MinVal + (float64(category) * se.Resolution) 435 | } 436 | 437 | return value, encoding 438 | 439 | } 440 | 441 | /* 442 | Returns the value for each bucket defined by the encoder 443 | */ 444 | func (se *ScalerEncoder) getBucketValues() []float64 { 445 | 446 | if se.bucketValues == nil { 447 | topDownMappingM := se.getTopDownMapping() 448 | numBuckets := topDownMappingM.Height 449 | se.bucketValues = make([]float64, numBuckets) 450 | for i := 0; i < numBuckets; i++ { 451 | val, _ := se.getBucketInfo([]int{i}) 452 | se.bucketValues[i] = val 453 | } 454 | } 455 | 456 | return se.bucketValues 457 | } 458 | 459 | /* 460 | top down compute 461 | */ 462 | func (se *ScalerEncoder) topDownCompute(encoded []bool) float64 { 463 | 464 | topDownMappingM := se.getTopDownMapping() 465 | 466 | //find "closest" match 467 | comps := topDownMappingM.RowAndSum(encoded) 468 | _, category := ints.Max(comps) 469 | 470 | val, _ := se.getBucketInfo([]int{category}) 471 | return val 472 | 473 | } 474 | 475 | /* 476 | generates a text description of specified slice of ranges 477 | */ 478 | func (se *ScalerEncoder) generateRangeDescription(ranges []utils.TupleFloat) string { 479 | 480 | desc := "" 481 | numRanges := len(ranges) 482 | for idx, val := range ranges { 483 | if val.A == val.B { 484 | desc += fmt.Sprintf("%v-%v", val.A, val.B) 485 | } else { 486 | desc += fmt.Sprintf("%v", val.A) 487 | } 488 | if idx < numRanges-1 { 489 | desc += "," 490 | } 491 | } 492 | return desc 493 | 494 | } 495 | 496 | /* 497 | 
/*
Decode an encoded sequence. Returns range of values
*/
func (se *ScalerEncoder) Decode(encoded []bool) []utils.TupleFloat {

	if !utils.AnyTrue(encoded) {
		return []utils.TupleFloat{}
	}

	// NOTE(review): this aliases the caller's slice, so the hole-filling
	// below mutates the caller's data — confirm against the Python port,
	// which works on a copy.
	tmpOutput := encoded[:se.N]

	// First, assume the input pool is not sampled 100%, and fill in the
	// "holes" in the encoded representation (which are likely to be present
	// if this is a coincidence that was learned by the SP).

	// Search for portions of the output that have "holes"
	maxZerosInARow := se.halfWidth

	for i := 0; i < maxZerosInARow; i++ {
		// Pattern of the form 1,0,...,0,1 with i zeros in the middle.
		searchSeq := make([]bool, i+3)
		subLen := len(searchSeq)
		searchSeq[0] = true
		searchSeq[subLen-1] = true

		if se.Periodic {
			// Slide the pattern over every position, wrapping indices mod N.
			for j := 0; j < se.N; j++ {
				outputIndices := make([]int, subLen)

				for idx := range outputIndices {
					outputIndices[idx] = (j + idx) % se.N
				}

				if utils.BoolEq(searchSeq, utils.SubsetSliceBool(tmpOutput, outputIndices)) {
					utils.SetIdxBool(tmpOutput, outputIndices, true)
				}
			}

		} else {

			// Non-periodic: plain sliding window, no wrap-around.
			for j := 0; j < se.N-subLen+1; j++ {
				if utils.BoolEq(searchSeq, tmpOutput[j:j+subLen]) {
					utils.FillSliceRangeBool(tmpOutput, true, j, subLen)
				}
			}

		}

	}

	if se.Verbosity >= 2 {
		fmt.Println("raw output:", utils.Bool2Int(encoded[:se.N]))
		fmt.Println("filtered output:", utils.Bool2Int(tmpOutput))
	}

	// ------------------------------------------------------------------------
	// Find each run of 1's in sequence

	nz := utils.OnIndices(tmpOutput)
	// key = start index, value = run length
	runs := make([]utils.TupleInt, 0, len(nz))

	runStart := -1
	runLen := 0

	for idx, val := range tmpOutput {
		if val {
			// increment current run or start a new one
			if runStart == -1 {
				runStart = idx
				runLen = 0
			}
			runLen++
		} else {
			if runStart != -1 {
				runs = append(runs, utils.TupleInt{runStart, runLen})
				runStart = -1
			}

		}
	}

	// Flush a run that extends to the end of the output.
	if runStart != -1 {
		runs = append(runs, utils.TupleInt{runStart, runLen})
		runStart = -1
	}

	// If we have a periodic encoder, merge the first and last run if they
	// both go all the way to the edges
	if se.Periodic && len(runs) > 1 {
		if runs[0].A == 0 && runs[len(runs)-1].A+runs[len(runs)-1].B == se.N {
			runs[len(runs)-1].B += runs[0].B
			runs = runs[1:]
		}
	}

	// ------------------------------------------------------------------------
	// Now, for each group of 1's, determine the "left" and "right" edges, where
	// the "left" edge is inset by halfwidth and the "right" edge is inset by
	// halfwidth.
	// For a group of width w or less, the "left" and "right" edge are both at
	// the center position of the group.

	ranges := make([]utils.TupleFloat, 0, len(runs)+2)

	for _, val := range runs {
		var left, right int
		start := val.A
		length := val.B

		if length <= se.Width {
			right = start + length/2
			left = right
		} else {
			left = start + se.halfWidth
			right = start + length - 1 - se.halfWidth
		}

		var inMin, inMax float64

		// Convert bit positions back to input-space values.
		if !se.Periodic {
			inMin = float64(left-se.padding)*se.Resolution + se.MinVal
			inMax = float64(right-se.padding)*se.Resolution + se.MinVal
		} else {
			inMin = float64(left-se.padding)*se.Range/float64(se.nInternal) + se.MinVal
			inMax = float64(right-se.padding)*se.Range/float64(se.nInternal) + se.MinVal
		}

		// Handle wrap-around if periodic
		if se.Periodic {
			if inMin >= se.MaxVal {
				inMin -= se.Range
				inMax -= se.Range
			}
		}

		// Clip low end
		if inMin < se.MinVal {
			inMin = se.MinVal
		}
		if inMax < se.MinVal {
			inMax = se.MinVal
		}

		// If we have a periodic encoder, and the max is past the edge, break into
		// 2 separate ranges

		if se.Periodic && inMax >= se.MaxVal {
			ranges = append(ranges, utils.TupleFloat{inMin, se.MaxVal})
			ranges = append(ranges, utils.TupleFloat{se.MinVal, inMax - se.Range})
		} else {
			// clip high end
			if inMax > se.MaxVal {
				inMax = se.MaxVal
			}
			if inMin > se.MaxVal {
				inMin = se.MaxVal
			}
			ranges = append(ranges, utils.TupleFloat{inMin, inMax})
		}
	}

	//desc := se.generateRangeDescription(ranges)

	return ranges
}

// ---------------------------------------------------------------------------
// File: spatialPooler.go
// ---------------------------------------------------------------------------

package htm

import (
	"fmt"
	"github.com/cznic/mathutil"
	"github.com/nupic-community/htm/utils"
	"github.com/skelterjohn/go.matrix"
	"math"
	"math/rand"
	"sort"
)

// SpatialPooler is a port of NuPIC's spatial pooler. Exported fields are the
// algorithm parameters; unexported fields are derived internal state.
// (Struct continues below with the synapse-permanence parameters.)
type SpatialPooler struct {
	numColumns                 int // total columns (product of ColumnDimensions)
	numInputs                  int // total input bits (product of InputDimensions)
	ColumnDimensions           []int
	InputDimensions            []int
	PotentialRadius            int
	PotentialPct               float64
	GlobalInhibition           bool
	NumActiveColumnsPerInhArea int
	LocalAreaDensity           float64
	StimulusThreshold          int
	SynPermInactiveDec      float64
	SynPermActiveInc        float64
	SynPermBelowStimulusInc float64
	SynPermConnected        float64
	MinPctOverlapDutyCycles float64
	MinPctActiveDutyCycles  float64
	DutyCyclePeriod         int
	MaxBoost                float64
	SpVerbosity             int

	// Extra parameter settings
	SynPermMin           float64
	SynPermMax           float64
	SynPermTrimThreshold float64
	UpdatePeriod         int
	InitConnectedPct     float64

	// Internal state
	Version           float64
	IterationNum      int
	IterationLearnNum int

	// random seed
	Seed int

	potentialPools *DenseBinaryMatrix
	permanences    *matrix.SparseMatrix
	tieBreaker     []float64

	connectedSynapses *DenseBinaryMatrix
	// redundant with connectedSynapses; kept for efficiency
	connectedCounts []int

	overlapDutyCycles    []float64
	activeDutyCycles     []float64
	minOverlapDutyCycles []float64
	minActiveDutyCycles  []float64
	boostFactors         []float64

	inhibitionRadius int

	spVerbosity int
}

// SpParams bundles the user-supplied construction parameters for
// NewSpatialPooler.
type SpParams struct {
	InputDimensions            []int
	ColumnDimensions           []int
	PotentialRadius            int
	PotentialPct               float64
	GlobalInhibition           bool
	LocalAreaDensity           float64
	NumActiveColumnsPerInhArea int
	StimulusThreshold          int
	SynPermInactiveDec         float64
	SynPermActiveInc           float64
	SynPermConnected           float64
	MinPctOverlapDutyCycle     float64
	MinPctActiveDutyCycle      float64
	DutyCyclePeriod            int
	MaxBoost                   float64
	Seed                       int
	SpVerbosity                int
}

// Initializes default spatial pooler params
func NewSpParams() SpParams {
	sp := SpParams{}

	sp.InputDimensions = []int{32, 32}
	sp.ColumnDimensions = []int{64, 64}
	sp.PotentialRadius = 16
	sp.PotentialPct = 0.5
	sp.GlobalInhibition = false
	sp.LocalAreaDensity = -1.0
	sp.NumActiveColumnsPerInhArea = 10.0
	sp.StimulusThreshold = 0
	sp.SynPermInactiveDec = 0.01
	sp.SynPermActiveInc = 0.1
	sp.SynPermConnected = 0.10
	sp.MinPctOverlapDutyCycle = 0.001
	sp.MinPctActiveDutyCycle = 0.001
	sp.DutyCyclePeriod = 1000
	sp.MaxBoost = 10.0
	sp.Seed = -1
	sp.SpVerbosity = 0

	return sp
}

// Creates a new spatial pooler: validates the parameters, copies them in,
// and initializes the potential pools, permanences and duty-cycle state.
func NewSpatialPooler(spParams SpParams) *SpatialPooler {
	sp := SpatialPooler{}
	// Validate inputs
	sp.numColumns = utils.ProdInt(spParams.ColumnDimensions)
	sp.numInputs = utils.ProdInt(spParams.InputDimensions)

	if sp.numColumns < 1 {
		panic("Must have at least 1 column")
	}
	if sp.numInputs < 1 {
		panic("must have at least 1 input")
	}
	// NOTE(review): NuPIC asserts numActiveColumnsPerInhArea > 0 OR
	// 0 < localAreaDensity <= 0.5; this conjunction differs — confirm intent.
	if spParams.NumActiveColumnsPerInhArea < 1 && (spParams.LocalAreaDensity < 1) && (spParams.LocalAreaDensity >= 0.5) {
		panic("Num active colums invalid")
	}

	sp.InputDimensions = spParams.InputDimensions
	sp.ColumnDimensions = spParams.ColumnDimensions
	// The potential radius cannot exceed the input size.
	sp.PotentialRadius = int(mathutil.Min(spParams.PotentialRadius, sp.numInputs))
	sp.PotentialPct = spParams.PotentialPct
	sp.GlobalInhibition = spParams.GlobalInhibition
	sp.LocalAreaDensity = spParams.LocalAreaDensity
	sp.NumActiveColumnsPerInhArea = spParams.NumActiveColumnsPerInhArea
	sp.StimulusThreshold = spParams.StimulusThreshold
	sp.SynPermInactiveDec = spParams.SynPermInactiveDec
	sp.SynPermActiveInc = spParams.SynPermActiveInc
	sp.SynPermBelowStimulusInc = spParams.SynPermConnected / 10.0
	sp.SynPermConnected = spParams.SynPermConnected
	sp.MinPctOverlapDutyCycles = spParams.MinPctOverlapDutyCycle
	sp.MinPctActiveDutyCycles = spParams.MinPctActiveDutyCycle
	sp.DutyCyclePeriod = spParams.DutyCyclePeriod
	sp.MaxBoost = spParams.MaxBoost
	sp.Seed = spParams.Seed
	sp.SpVerbosity = spParams.SpVerbosity

	// Extra parameter settings
	sp.SynPermMin = 0
	sp.SynPermMax = 1
	sp.SynPermTrimThreshold = sp.SynPermActiveInc / 2.0
	if sp.SynPermTrimThreshold >= sp.SynPermConnected {
		panic("Syn perm threshold >= syn connected.")
	}
	sp.UpdatePeriod = 50
	sp.InitConnectedPct = 0.5

	/*
		# Internal state
		version = 1.0
		iterationNum = 0
		iterationLearnNum = 0
	*/

	/*
		Store the set of all inputs that are within each column's potential pool.
		'potentialPools' is a matrix, whose rows represent cortical columns, and
		whose columns represent the input bits. if potentialPools[i][j] == 1,
		then input bit 'j' is in column 'i's potential pool. A column can only be
		connected to inputs in its potential pool. The indices refer to a
		falttenned version of both the inputs and columns. Namely, irrespective
		of the topology of the inputs and columns, they are treated as being a
		one dimensional array. Since a column is typically connected to only a
		subset of the inputs, many of the entries in the matrix are 0. Therefore
		the the potentialPool matrix is stored using the SparseBinaryMatrix
		class, to reduce memory footprint and compuation time of algorithms that
		require iterating over the data strcuture.
	*/
	sp.potentialPools = NewDenseBinaryMatrix(sp.numColumns, sp.numInputs)

	/*
		Initialize the permanences for each column. Similar to the
		'potentialPools', the permances are stored in a matrix whose rows
		represent the cortial columns, and whose columns represent the input
		bits. if permanences[i][j] = 0.2, then the synapse connecting
		cortical column 'i' to input bit 'j' has a permanence of 0.2. Here we
		also use the SparseMatrix class to reduce the memory footprint and
		computation time of algorithms that require iterating over the data
		structure. This permanence matrix is only allowed to have non-zero
		elements where the potential pool is non-zero.
	*/
	// Assumes 70% sparsity
	elms := make(map[int]float64, int(float64(sp.numColumns*sp.numInputs)*0.3))
	sp.permanences = matrix.MakeSparseMatrix(elms, sp.numColumns, sp.numInputs)

	/*
		Initialize a tiny random tie breaker. This is used to determine winning
		columns where the overlaps are identical.
	*/
	sp.tieBreaker = make([]float64, sp.numColumns)
	for i := 0; i < len(sp.tieBreaker); i++ {
		sp.tieBreaker[i] = 0.01 * rand.Float64()
	}

	/*
		'connectedSynapses' is a similar matrix to 'permanences'
		(rows represent cortial columns, columns represent input bits) whose
		entries represent whether the cortial column is connected to the input
		bit, i.e. its permanence value is greater than 'synPermConnected'. While
		this information is readily available from the 'permanence' matrix,
		it is stored separately for efficiency purposes.
	*/
	sp.connectedSynapses = NewDenseBinaryMatrix(sp.numColumns, sp.numInputs)

	/*
		Stores the number of connected synapses for each column. This is simply
		a sum of each row of 'ConnectedSynapses'. again, while this
		information is readily available from 'ConnectedSynapses', it is
		stored separately for efficiency purposes.
	*/
	sp.connectedCounts = make([]int, sp.numColumns)

	/*
		Initialize the set of permanence values for each columns. Ensure that
		each column is connected to enough input bits to allow it to be
		activated
	*/
	for i := 0; i < sp.numColumns; i++ {
		potential := sp.mapPotential(i, true)
		sp.potentialPools.ReplaceRow(i, potential)
		perm := sp.initPermanence(potential, sp.InitConnectedPct)
		sp.updatePermanencesForColumn(perm, i, true)
	}

	sp.overlapDutyCycles = make([]float64, sp.numColumns)
	sp.activeDutyCycles = make([]float64, sp.numColumns)
	sp.minOverlapDutyCycles = make([]float64, sp.numColumns)
	sp.minActiveDutyCycles = make([]float64, sp.numColumns)
	sp.boostFactors = make([]float64, sp.numColumns)
	for i := 0; i < len(sp.boostFactors); i++ {
		sp.boostFactors[i] = 1.0
	}

	/*
		The inhibition radius determines the size of a column's local
		neighborhood. of a column. A cortical column must overcome the overlap
		score of columns in his neighborhood in order to become actives. This
		radius is updated every learning round. It grows and shrinks with the
		average number of connected synapses per column.
	*/
	sp.inhibitionRadius = 0
	// Note: passes the two averaging methods as function values.
	sp.updateInhibitionRadius(sp.avgConnectedSpanForColumnND, sp.avgColumnsPerInput)

	if sp.spVerbosity > 0 {
		sp.printParameters()
	}

	return &sp
}

// Returns number of inputs
func (sp *SpatialPooler) NumInputs() int {
	return sp.numInputs
}

// Returns number of columns
func (sp *SpatialPooler) NumColumns() int {
	return sp.numColumns
}

// Returns number of inputs
func (ssp *SpParams) NumInputs() int {
	return utils.ProdInt(ssp.InputDimensions)
}

// Returns number of columns
func (ssp *SpParams) NumColumns() int {
	return utils.ProdInt(ssp.ColumnDimensions)
}
This method encapsultes the topology of 286 | the region. It takes the index of the column as an argument and determines 287 | what are the indices of the input vector that are located within the 288 | column's potential pool. The return value is a list containing the indices 289 | of the input bits. The current implementation of the base class only 290 | supports a 1 dimensional topology of columsn with a 1 dimensional topology 291 | of inputs. To extend this class to support 2-D topology you will need to 292 | override this method. Examples of the expected output of this method: 293 | * If the potentialRadius is greater than or equal to the entire input 294 | space, (global visibility), then this method returns an array filled with 295 | all the indices 296 | * If the topology is one dimensional, and the potentialRadius is 5, this 297 | method will return an array containing 5 consecutive values centered on 298 | the index of the column (wrapping around if necessary). 299 | * If the topology is two dimensional (not implemented), and the 300 | potentialRadius is 5, the method should return an array containing 25 301 | '1's, where the exact indices are to be determined by the mapping from 302 | 1-D index to 2-D position. 303 | 304 | Parameters: 305 | ---------------------------- 306 | index: The index identifying a column in the permanence, potential 307 | and connectivity matrices. 308 | wrapAround: A boolean value indicating that boundaries should be 309 | region boundaries ignored. 
310 | */ 311 | func (sp *SpatialPooler) mapPotential(index int, wrapAround bool) []bool { 312 | // Distribute column over inputs uniformly 313 | ratio := float64(index) / float64(mathutil.Max((sp.numColumns-1), 1)) 314 | index = int(float64(sp.numInputs-1) * ratio) 315 | 316 | var indices []int 317 | indLen := 2*sp.PotentialRadius + 1 318 | 319 | for i := 0; i < indLen; i++ { 320 | temp := (i + index - sp.PotentialRadius) 321 | if wrapAround { 322 | temp = temp % sp.numInputs 323 | if temp < 0 { 324 | temp = sp.numInputs + temp 325 | } 326 | } else { 327 | if !(temp >= 0 && temp < sp.numInputs) { 328 | continue 329 | } 330 | } 331 | //no dupes 332 | if !utils.ContainsInt(temp, indices) { 333 | indices = append(indices, temp) 334 | } 335 | } 336 | 337 | // Select a subset of the receptive field to serve as the 338 | // the potential pool 339 | 340 | //shuffle indices 341 | for i := range indices { 342 | j := rand.Intn(i + 1) 343 | indices[i], indices[j] = indices[j], indices[i] 344 | } 345 | 346 | sampleLen := int(utils.RoundPrec(float64(len(indices))*sp.PotentialPct, 0)) 347 | sample := indices[:sampleLen] 348 | //project indices onto input mask 349 | mask := make([]bool, sp.numInputs) 350 | for i, _ := range mask { 351 | mask[i] = utils.ContainsInt(i, sample) 352 | } 353 | 354 | return mask 355 | } 356 | 357 | /* 358 | Returns a randomly generated permanence value for a synapses that is 359 | initialized in a connected state. The basic idea here is to initialize 360 | permanence values very close to synPermConnected so that a small number of 361 | learning steps could make it disconnected or connected. 362 | 363 | Note: experimentation was done a long time ago on the best way to initialize 364 | permanence values, but the history for this particular scheme has been lost. 
365 | */ 366 | 367 | func (sp *SpatialPooler) initPermConnected() float64 { 368 | 369 | p := sp.SynPermConnected + rand.Float64()*sp.SynPermActiveInc/4.0 370 | 371 | // Ensure we don't have too much unnecessary precision. A full 64 bits of 372 | // precision causes numerical stability issues across platforms and across 373 | // implementations 374 | 375 | return float64(int(p*100000)) / 100000.0 376 | } 377 | 378 | /* 379 | Returns a randomly generated permanence value for a synapses that is to be 380 | initialized in a non-connected state. 381 | */ 382 | 383 | func (sp *SpatialPooler) initPermNonConnected() float64 { 384 | p := sp.SynPermConnected * rand.Float64() 385 | 386 | // Ensure we don't have too much unnecessary precision. A full 64 bits of 387 | // precision causes numerical stability issues across platforms and across 388 | // implementations 389 | return float64(int(p*100000)) / 100000.0 390 | } 391 | 392 | /* 393 | Initializes the permanences of a column. The method 394 | returns a 1-D array the size of the input, where each entry in the 395 | array represents the initial permanence value between the input bit 396 | at the particular index in the array, and the column represented by 397 | the 'index' parameter. 398 | 399 | Parameters: 400 | ---------------------------- 401 | potential: A numpy array specifying the potential pool of the column. 402 | Permanence values will only be generated for input bits 403 | corresponding to indices for which the mask value is 1. 404 | connectedPct: A value between 0 and 1 specifying the percent of the input 405 | bits that will start off in a connected state. 406 | */ 407 | 408 | func (sp *SpatialPooler) initPermanence(potential []bool, connectedPct float64) []float64 { 409 | // Determine which inputs bits will start out as connected 410 | // to the inputs. Initially a subset of the input bits in a 411 | // column's potential pool will be connected. 
This number is 412 | // given by the parameter "connectedPct" 413 | 414 | perm := make([]float64, sp.numInputs) 415 | //var perm []int 416 | 417 | for i := 0; i < sp.numInputs; i++ { 418 | if !potential[i] { 419 | continue 420 | } 421 | var temp float64 422 | if utils.RandFloatRange(0.0, 1.0) < connectedPct { 423 | temp = sp.initPermConnected() 424 | } else { 425 | temp = sp.initPermNonConnected() 426 | } 427 | //Exclude low values to save memory 428 | if temp < sp.SynPermTrimThreshold { 429 | temp = 0.0 430 | } 431 | 432 | perm[i] = temp 433 | } 434 | 435 | return perm 436 | } 437 | 438 | /* 439 | This method ensures that each column has enough connections to input bits 440 | to allow it to become active. Since a column must have at least 441 | 'stimulusThreshold' overlaps in order to be considered during the 442 | inhibition phase, columns without such minimal number of connections, even 443 | if all the input bits they are connected to turn on, have no chance of 444 | obtaining the minimum threshold. For such columns, the permanence values 445 | are increased until the minimum number of connections are formed. 446 | 447 | Parameters: 448 | ---------------------------- 449 | perm: An array of permanence values for a column. The array is 450 | "dense", i.e. it contains an entry for each input bit, even 451 | if the permanence value is 0. 452 | mask: the indices of the columns whose permanences need to be 453 | raised. 
454 | */ 455 | 456 | func (sp *SpatialPooler) raisePermanenceToThreshold(perm []float64, mask []int) { 457 | 458 | for i := 0; i < len(perm); i++ { 459 | if perm[i] < sp.SynPermMin { 460 | perm[i] = sp.SynPermMin 461 | } else if perm[i] > sp.SynPermMax { 462 | perm[i] = sp.SynPermMax 463 | } 464 | } 465 | 466 | for { 467 | numConnected := 0 468 | for i := 0; i < len(perm); i++ { 469 | if perm[i] > sp.SynPermConnected { 470 | numConnected++ 471 | } 472 | } 473 | if numConnected >= sp.StimulusThreshold { 474 | return 475 | } 476 | for i := 0; i < len(mask); i++ { 477 | perm[mask[i]] += sp.SynPermBelowStimulusInc 478 | } 479 | } 480 | 481 | } 482 | 483 | /* 484 | This method updates the permanence matrix with a column's new permanence 485 | values. The column is identified by its index, which reflects the row in 486 | the matrix, and the permanence is given in 'dense' form, i.e. a full 487 | arrray containing all the zeros as well as the non-zero values. It is in 488 | charge of implementing 'clipping' - ensuring that the permanence values are 489 | always between 0 and 1 - and 'trimming' - enforcing sparsity by zeroing out 490 | all permanence values below 'synPermTrimThreshold'. It also maintains 491 | the consistency between 'permanences' (the matrix storing the 492 | permanence values), 'connectedSynapses', (the matrix storing the bits 493 | each column is connected to), and 'connectedCounts' (an array storing 494 | the number of input bits each column is connected to). Every method wishing 495 | to modify the permanence matrix should do so through this method. 496 | 497 | Parameters: 498 | ---------------------------- 499 | perm: An array of permanence values for a column. The array is 500 | "dense", i.e. it contains an entry for each input bit, even 501 | if the permanence value is 0. 
index:     The index identifying a column in the permanence, potential
           and connectivity matrices
raisePerm: a boolean value indicating whether the permanence values
           should be raised until a minimum number of synapses are in
           a connected state. Should be set to 'false' when a direct
           assignment is required.
*/

func (sp *SpatialPooler) updatePermanencesForColumn(perm []float64, index int, raisePerm bool) {
	maskPotential := sp.potentialPools.GetRowIndices(index)
	if raisePerm {
		sp.raisePermanenceToThreshold(perm, maskPotential)
	}

	var newConnected []int
	for i := 0; i < len(perm); i++ {
		// Trim: enforce sparsity by zeroing values at or below the trim
		// threshold; such entries cannot be connected, so skip the rest.
		if perm[i] <= sp.SynPermTrimThreshold {
			perm[i] = 0
			continue
		}

		//TODO: can be simplified if syn min/max are always 1/0
		// Clip the permanence into [SynPermMin, SynPermMax].
		if perm[i] < sp.SynPermMin {
			perm[i] = sp.SynPermMin
		}
		if perm[i] > sp.SynPermMax {
			perm[i] = sp.SynPermMax
		}
		// Record the inputs this column is now connected to.
		if perm[i] >= sp.SynPermConnected {
			newConnected = append(newConnected, i)
		}
	}

	//TODO: replace with sparse matrix that indexes by rows
	//sp.permanences.SetRowFromDense(index, perm)
	for i := 0; i < len(perm); i++ {
		sp.permanences.Set(index, i, perm[i])
	}
	// Keep connectedSynapses and connectedCounts consistent with the new
	// permanence row.
	sp.connectedSynapses.ReplaceRowByIndices(index, newConnected)
	sp.connectedCounts[index] = len(newConnected)
}

/*
This is the primary public method of the SpatialPooler class. This
function takes an input vector and outputs the indices of the active columns.
If 'learn' is set to true, this method also updates the permanences of the
columns.

Parameters:
----------------------------
inputVector: an array of 0's and 1's that comprises the input to
             the spatial pooler.
The array will be treated as a one
dimensional array, therefore the dimensions of the array
do not have to match the exact dimensions specified in the
class constructor. In fact, even a list would suffice.
The number of input bits in the vector must, however,
match the number of bits specified by the call to the
constructor. Therefore there must be a '0' or '1' in the
array for every input bit.
learn:       a boolean value indicating whether learning should be
             performed. Learning entails updating the permanence
             values of the synapses, and hence modifying the 'state'
             of the model. Setting learning to 'off' freezes the SP
             and has many uses. For example, you might want to feed in
             various inputs and examine the resulting SDR's.
activeArray: an array whose size is equal to the number of columns.
             Before the function returns this array will be populated
             with 1's at the indices of the active columns, and 0's
             everywhere else.
571 | */ 572 | func (sp *SpatialPooler) Compute(inputVector []bool, learn bool, activeArray []bool, inhibitColumns inhibitColFunc) { 573 | if len(inputVector) != sp.numInputs { 574 | panic("input != numimputs") 575 | } 576 | 577 | sp.updateBookeepingVars(learn) 578 | overlaps := sp.calculateOverlap(inputVector) 579 | boostedOverlaps := make([]float64, len(overlaps)) 580 | // Apply boosting when learning is on 581 | if learn { 582 | for i, val := range sp.boostFactors { 583 | boostedOverlaps[i] = float64(overlaps[i]) * val 584 | } 585 | } 586 | 587 | // Apply inhibition to determine the winning columns 588 | activeColumns := inhibitColumns(boostedOverlaps, sp.inhibitColumnsGlobal, sp.inhibitColumnsLocal) 589 | overlapsf := make([]float64, len(overlaps)) 590 | for i, val := range overlaps { 591 | overlapsf[i] = float64(val) 592 | } 593 | 594 | if learn { 595 | sp.adaptSynapses(inputVector, activeColumns) 596 | sp.updateDutyCycles(overlapsf, activeColumns) 597 | sp.bumpUpWeakColumns() 598 | sp.updateBoostFactors() 599 | if sp.isUpdateRound() { 600 | sp.updateInhibitionRadius(sp.avgConnectedSpanForColumnND, sp.avgColumnsPerInput) 601 | sp.updateMinDutyCycles() 602 | } 603 | 604 | } else { 605 | activeColumns = sp.stripNeverLearned(activeColumns) 606 | } 607 | 608 | if len(activeColumns) > 0 { 609 | for i, _ := range activeArray { 610 | activeArray[i] = utils.ContainsInt(i, activeColumns) 611 | } 612 | } 613 | 614 | } 615 | 616 | /* 617 | Updates counter instance variables each round. 618 | 619 | Parameters: 620 | ---------------------------- 621 | learn: a boolean value indicating whether learning should be 622 | performed. Learning entails updating the permanence 623 | values of the synapses, and hence modifying the 'state' 624 | of the model. setting learning to 'off' might be useful 625 | for indicating separate training vs. testing sets. 
626 | */ 627 | 628 | func (sp *SpatialPooler) updateBookeepingVars(learn bool) { 629 | sp.IterationNum += 1 630 | if learn { 631 | sp.IterationLearnNum += 1 632 | } 633 | 634 | } 635 | 636 | /* 637 | This function determines each column's overlap with the current input 638 | vector. The overlap of a column is the number of synapses for that column 639 | that are connected (permance value is greater than 'synPermConnected') 640 | to input bits which are turned on. Overlap values that are lower than 641 | the 'stimulusThreshold' are ignored. The implementation takes advantage of 642 | the SpraseBinaryMatrix class to perform this calculation efficiently. 643 | 644 | Parameters: 645 | ---------------------------- 646 | inputVector: a numpy array of 0's and 1's that comprises the input to 647 | the spatial pooler. 648 | */ 649 | func (sp *SpatialPooler) calculateOverlap(inputVector []bool) []int { 650 | overlaps := sp.connectedSynapses.RowAndSum(inputVector) 651 | for idx, _ := range overlaps { 652 | if overlaps[idx] < sp.StimulusThreshold { 653 | overlaps[idx] = 0 654 | } 655 | } 656 | return overlaps 657 | } 658 | 659 | func (sp *SpatialPooler) calculateOverlapPct(overlaps []int) []float64 { 660 | result := make([]float64, len(overlaps)) 661 | for idx, val := range overlaps { 662 | result[idx] = float64(val) / float64(sp.connectedCounts[idx]) 663 | } 664 | return result 665 | } 666 | 667 | /* 668 | Similar to _getNeighbors1D and _getNeighbors2D, this function Returns a 669 | list of indices corresponding to the neighbors of a given column. Since the 670 | permanence values are stored in such a way that information about toplogy 671 | is lost. This method allows for reconstructing the toplogy of the inputs, 672 | which are flattened to one array. Given a column's index, its neighbors are 673 | defined as those columns that are 'radius' indices away from it in each 674 | dimension. The method returns a list of the flat indices of these columns. 
675 | Parameters: 676 | ---------------------------- 677 | columnIndex: The index identifying a column in the permanence, potential 678 | and connectivity matrices. 679 | dimensions: An array containg a dimensions for the column space. A 2x3 680 | grid will be represented by [2,3]. 681 | radius: Indicates how far away from a given column are other 682 | columns to be considered its neighbors. In the previous 2x3 683 | example, each column with coordinates: 684 | [2+/-radius, 3+/-radius] is considered a neighbor. 685 | wrapAround: A boolean value indicating whether to consider columns at 686 | the border of a dimensions to be adjacent to columns at the 687 | other end of the dimension. For example, if the columns are 688 | layed out in one deimnsion, columns 1 and 10 will be 689 | considered adjacent if wrapAround is set to true: 690 | [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 691 | */ 692 | 693 | func (sp *SpatialPooler) getNeighborsND(columnIndex int, dimensions []int, radius int, wrapAround bool) []int { 694 | if len(dimensions) < 1 { 695 | panic("Dimensions empty") 696 | } 697 | 698 | bounds := append(dimensions[1:], 1) 699 | bounds = utils.RevCumProdInt(bounds) 700 | 701 | columnCoords := make([]int, len(bounds)) 702 | for j := 0; j < len(bounds); j++ { 703 | columnCoords[j] = utils.Mod(columnIndex/bounds[j], dimensions[j]) 704 | } 705 | 706 | rangeND := make([][]int, len(dimensions)) 707 | for i := 0; i < len(dimensions); i++ { 708 | if wrapAround { 709 | cRange := make([]int, (radius*2)+1) 710 | for j := 0; j < (2*radius)+1; j++ { 711 | cRange[j] = utils.Mod((columnCoords[i]-radius)+j, dimensions[i]) 712 | } 713 | rangeND[i] = cRange 714 | } else { 715 | var cRange []int 716 | for j := 0; j < (radius*2)+1; j++ { 717 | temp := columnCoords[i] - radius + j 718 | if temp >= 0 && temp < dimensions[i] { 719 | cRange = append(cRange, temp) 720 | } 721 | } 722 | rangeND[i] = cRange 723 | } 724 | } 725 | 726 | cp := utils.CartProductInt(rangeND) 727 | var neighbors []int 728 | 
for i := 0; i < len(cp); i++ { 729 | val := utils.DotInt(bounds, cp[i]) 730 | if val != columnIndex && !utils.ContainsInt(val, neighbors) { 731 | neighbors = append(neighbors, val) 732 | } 733 | } 734 | 735 | return neighbors 736 | } 737 | 738 | /* 739 | Perform global inhibition. Performing global inhibition entails picking the 740 | top 'numActive' columns with the highest overlap score in the entire 741 | region. At most half of the columns in a local neighborhood are allowed to 742 | be active. 743 | 744 | Parameters: 745 | ---------------------------- 746 | overlaps: an array containing the overlap score for each column. 747 | The overlap score for a column is defined as the number 748 | of synapses in a "connected state" (connected synapses) 749 | that are connected to input bits which are turned on. 750 | density: The fraction of columns to survive inhibition. 751 | */ 752 | 753 | func (sp *SpatialPooler) inhibitColumnsGlobal(overlaps []float64, density float64) []int { 754 | //calculate num active per inhibition area 755 | numActive := int(density * float64(sp.numColumns)) 756 | ov := make([]utils.TupleInt, len(overlaps)) 757 | //TODO: if overlaps is assumed to be distinct this can be 758 | // simplified 759 | //a = value, b = original index 760 | for i := 0; i < len(ov); i++ { 761 | ov[i].A = int(overlaps[i]) 762 | ov[i].B = i 763 | } 764 | //insert sort overlaps 765 | for i := 1; i < len(ov); i++ { 766 | for j := i; j > 0 && ov[j].A < ov[j-1].A; j-- { 767 | tmp := ov[j] 768 | ov[j] = ov[j-1] 769 | ov[j-1] = tmp 770 | } 771 | } 772 | 773 | result := make([]int, numActive) 774 | for i := 0; i < numActive; i++ { 775 | result[i] = ov[len(ov)-1-i].B 776 | } 777 | 778 | sort.Sort(sort.IntSlice(result)) 779 | 780 | //return indexes of active columns 781 | return result 782 | } 783 | 784 | /* 785 | Performs local inhibition. Local inhibition is performed on a column by 786 | column basis. 
Each column observes the overlaps of its neighbors and is
selected if its overlap score is within the top 'numActive' in its local
neighborhood. At most half of the columns in a local neighborhood are
allowed to be active.

NOTE(review): this method mutates 'overlaps' in place — winning columns
receive a small additive bonus so that later neighborhoods see them as
stronger; verify callers do not reuse the slice expecting it unchanged.

Parameters:
----------------------------
overlaps: an array containing the overlap score for each column.
          The overlap score for a column is defined as the number
          of synapses in a "connected state" (connected synapses)
          that are connected to input bits which are turned on.
density:  The fraction of columns to survive inhibition. This
          value is only an intended target. Since the surviving
          columns are picked in a local fashion, the exact fraction
          of surviving columns is likely to vary.
*/

func (sp *SpatialPooler) inhibitColumnsLocal(overlaps []float64, density float64) []int {
	var activeColumns []int
	// Small bonus added to each winner's score, scaled off the max overlap.
	addToWinners := utils.MaxSliceFloat64(overlaps) / 1000.0

	for i := 0; i < sp.numColumns; i++ {
		// Neighborhood of column i (wrap-around disabled).
		mask := sp.getNeighborsND(i, sp.ColumnDimensions, sp.inhibitionRadius, false)

		// Overlap scores of the neighbors.
		ovSlice := make([]float64, len(mask))
		for idx, val := range mask {
			ovSlice[idx] = overlaps[val]
		}

		// Target number of active columns in this neighborhood (rounded;
		// +1 accounts for the column itself).
		numActive := int(0.5 + density*float64(len(mask)+1))
		// Count how many neighbors strictly beat this column's score.
		numBigger := 0
		for _, ov := range ovSlice {
			if ov > overlaps[i] {
				numBigger++
			}
		}

		// Column i wins if fewer than numActive neighbors beat it.
		if numBigger < numActive {
			activeColumns = append(activeColumns, i)
			// In-place bonus: later neighborhoods see winners as stronger.
			overlaps[i] += addToWinners
		}
	}

	return activeColumns
}

// inhibitColumnsFunc is the signature shared by the global and local
// inhibition strategies: (overlaps, density) -> active column indices.
type inhibitColumnsFunc func([]float64, float64) []int

// inhibitColFunc selects between the two strategies; Compute injects it so
// tests can substitute their own inhibition behavior.
type inhibitColFunc func(overlaps []float64, inhibitColumnsGlobal, inhibitColumnsLocal inhibitColumnsFunc) []int

/*
Performs inhibition.
This method calculates the necessary values needed to 837 | actually perform inhibition and then delegates the task of picking the 838 | active columns to helper functions. 839 | 840 | Parameters: 841 | ---------------------------- 842 | overlaps: an array containing the overlap score for each column. 843 | The overlap score for a column is defined as the number 844 | of synapses in a "connected state" (connected synapses) 845 | that are connected to input bits which are turned on. 846 | 847 | */ 848 | 849 | func (sp *SpatialPooler) InhibitColumns(overlaps []float64, inhibitColumnsGlobal, inhibitColumnsLocal inhibitColumnsFunc) []int { 850 | /* 851 | determine how many columns should be selected in the inhibition phase. 852 | This can be specified by either setting the 'numActiveColumnsPerInhArea' 853 | parameter of the 'localAreaDensity' parameter when initializing the class 854 | */ 855 | density := 0.0 856 | if sp.LocalAreaDensity > 0 { 857 | density = sp.LocalAreaDensity 858 | } else { 859 | inhibitionArea := math.Pow(float64(2*sp.inhibitionRadius+1), float64(len(sp.ColumnDimensions))) 860 | inhibitionArea = math.Min(float64(sp.numColumns), inhibitionArea) 861 | density = float64(sp.NumActiveColumnsPerInhArea) / inhibitionArea 862 | density = math.Min(density, 0.5) 863 | } 864 | 865 | // Add our fixed little bit of random noise to the scores to help break ties. 866 | //overlaps += sp.tieBreaker 867 | for i := 0; i < len(overlaps); i++ { 868 | overlaps[i] += sp.tieBreaker[i] 869 | } 870 | 871 | if sp.GlobalInhibition || 872 | sp.inhibitionRadius > utils.MaxSliceInt(sp.ColumnDimensions) { 873 | return inhibitColumnsGlobal(overlaps, density) 874 | } else { 875 | return inhibitColumnsLocal(overlaps, density) 876 | } 877 | 878 | } 879 | 880 | /* 881 | The primary method in charge of learning. Adapts the permanence values of 882 | the synapses based on the input vector, and the chosen columns after 883 | inhibition round. 
Permanence values are increased for synapses connected to
input bits that are turned on, and decreased for synapses connected to
input bits that are turned off.

Parameters:
----------------------------
inputVector:   an array of 0's and 1's that comprises the input to
               the spatial pooler. There exists an entry in the array
               for every input bit.
activeColumns: an array containing the indices of the columns that
               survived inhibition.
*/
func (sp *SpatialPooler) adaptSynapses(inputVector []bool, activeColumns []int) {
	// Indices of the input bits that are on.
	var onBits []int
	for i, on := range inputVector {
		if on {
			onBits = append(onBits, i)
		}
	}

	// Per-input permanence delta: increment for active inputs, decrement
	// for inactive ones.
	permChanges := make([]float64, sp.numInputs)
	utils.FillSliceFloat64(permChanges, -1*sp.SynPermInactiveDec)
	for _, bit := range onBits {
		permChanges[bit] = sp.SynPermActiveInc
	}

	// Apply the deltas to every winning column, restricted to the inputs
	// in its potential pool.
	for _, col := range activeColumns {
		perm := make([]float64, sp.numInputs)
		pool := sp.potentialPools.GetRowIndices(col)
		for j := 0; j < sp.numInputs; j++ {
			if utils.ContainsInt(j, pool) {
				perm[j] = permChanges[j] + sp.permanences.Get(col, j)
			} else {
				perm[j] = sp.permanences.Get(col, j)
			}
		}
		sp.updatePermanencesForColumn(perm, col, true)
	}
}

/*
Updates the duty cycles for each column. The OVERLAP duty cycle is a moving
average of the number of inputs which overlapped with each column. The
ACTIVITY duty cycle is a moving average of the frequency of activation for
each column.

Parameters:
----------------------------
overlaps: an array containing the overlap score for each column.
          The overlap score for a column is defined as the number
          of synapses in a "connected state" (connected synapses)
          that are connected to input bits which are turned on.
937 | activeColumns: An array containing the indices of the active columns, 938 | the sprase set of columns which survived inhibition 939 | */ 940 | func (sp *SpatialPooler) updateDutyCycles(overlaps []float64, activeColumns []int) { 941 | overlapArray := make([]int, sp.numColumns) 942 | activeArray := make([]int, sp.numColumns) 943 | 944 | for i, val := range overlaps { 945 | if val > 0 { 946 | overlapArray[i] = 1 947 | } 948 | } 949 | 950 | if len(activeColumns) > 0 { 951 | for _, val := range activeColumns { 952 | activeArray[val] = 1 953 | } 954 | } 955 | 956 | period := sp.DutyCyclePeriod 957 | if period > sp.IterationNum { 958 | period = sp.IterationNum 959 | } 960 | 961 | sp.overlapDutyCycles = updateDutyCyclesHelper( 962 | sp.overlapDutyCycles, 963 | overlapArray, 964 | period, 965 | ) 966 | 967 | sp.activeDutyCycles = updateDutyCyclesHelper( 968 | sp.activeDutyCycles, 969 | activeArray, 970 | period, 971 | ) 972 | } 973 | 974 | /* 975 | This method increases the permanence values of synapses of columns whose 976 | activity level has been too low. Such columns are identified by having an 977 | overlap duty cycle that drops too much below those of their peers. The 978 | permanence values for such columns are increased. 979 | */ 980 | func (sp *SpatialPooler) bumpUpWeakColumns() { 981 | var weakColumns []int 982 | for i, val := range sp.overlapDutyCycles { 983 | if val < sp.minOverlapDutyCycles[i] { 984 | weakColumns = append(weakColumns, i) 985 | } 986 | } 987 | 988 | for _, col := range weakColumns { 989 | perm := make([]float64, sp.numInputs) 990 | for j := 0; j < sp.numInputs; j++ { 991 | perm[j] = sp.permanences.Get(col, j) 992 | } 993 | 994 | maskPotential := sp.potentialPools.GetRowIndices(col) 995 | for _, mpot := range maskPotential { 996 | perm[mpot] += sp.SynPermBelowStimulusInc 997 | } 998 | sp.updatePermanencesForColumn(perm, col, false) 999 | } 1000 | 1001 | } 1002 | 1003 | /* 1004 | Update the boost factors for all columns. 
The boost factors are used to
increase the overlap of inactive columns to improve their chances of
becoming active, and hence encourage participation of more columns in the
learning process. This is a line defined as: y = mx + b, i.e. boost =
(1-maxBoost)/minDuty * dutyCycle + maxFiringBoost. Intuitively this means
that columns that have been active enough have a boost factor of 1, meaning
their overlap is not boosted. Columns whose active duty cycle drops too much
below that of their neighbors are boosted depending on how infrequently they
have been active. The more infrequent, the more they are boosted. The exact
boost factor is linearly interpolated between the points (dutyCycle:0,
boost:maxFiringBoost) and (dutyCycle:minDuty, boost:1.0).

        boostFactor
            ^
maxBoost _  |
            |\
            | \
      1  _  |  \ _ _ _ _ _ _ _
            |
            +--------------------> activeDutyCycle
               |
        minActiveDutyCycle
*/

func (sp *SpatialPooler) updateBoostFactors() {
	for i := range sp.boostFactors {
		switch {
		case sp.activeDutyCycles[i] > sp.minActiveDutyCycles[i]:
			// Active often enough: no boost.
			sp.boostFactors[i] = 1.0
		case sp.minActiveDutyCycles[i] > 0:
			// Linear interpolation between (0, MaxBoost) and (minDuty, 1).
			sp.boostFactors[i] = ((1.0 - sp.MaxBoost) /
				sp.minActiveDutyCycles[i] * sp.activeDutyCycles[i]) + sp.MaxBoost
		}
	}
}

/*
isUpdateRound returns true if enough rounds have passed to warrant updates
of the duty cycles.
*/
func (sp *SpatialPooler) isUpdateRound() bool {
	return sp.IterationNum%sp.UpdatePeriod == 0
}

/*
The range of connectedSynapses per column, averaged for each dimension.
This value is used to calculate the inhibition radius. This variation of
the function supports arbitrary column dimensions.
1056 | 1057 | Parameters: 1058 | ---------------------------- 1059 | index: The index identifying a column in the permanence, potential 1060 | and connectivity matrices. 1061 | */ 1062 | 1063 | func (sp *SpatialPooler) avgConnectedSpanForColumnND(index int) float64 { 1064 | dimensions := sp.InputDimensions 1065 | 1066 | bounds := append(dimensions[1:], 1) 1067 | bounds = utils.RevCumProdInt(bounds) 1068 | 1069 | connected := sp.connectedSynapses.GetRowIndices(index) 1070 | if len(connected) == 0 { 1071 | return 0 1072 | } 1073 | 1074 | maxCoord := make([]int, len(dimensions)) 1075 | minCoord := make([]int, len(dimensions)) 1076 | inputMax := 0 1077 | for i := 0; i < len(dimensions); i++ { 1078 | if dimensions[i] > inputMax { 1079 | inputMax = dimensions[i] 1080 | } 1081 | } 1082 | for i := 0; i < len(maxCoord); i++ { 1083 | maxCoord[i] = -1.0 1084 | minCoord[i] = inputMax 1085 | } 1086 | //calc min/max of (i/bounds) % dimensions 1087 | for _, val := range connected { 1088 | for j := 0; j < len(dimensions); j++ { 1089 | coord := (val / bounds[j]) % dimensions[j] 1090 | if coord > maxCoord[j] { 1091 | maxCoord[j] = coord 1092 | } 1093 | if coord < minCoord[j] { 1094 | minCoord[j] = coord 1095 | } 1096 | } 1097 | } 1098 | 1099 | sum := 0 1100 | for i := 0; i < len(dimensions); i++ { 1101 | sum += maxCoord[i] - minCoord[i] + 1 1102 | } 1103 | 1104 | return float64(sum) / float64(len(dimensions)) 1105 | } 1106 | 1107 | /* 1108 | The average number of columns per input, taking into account the topology 1109 | of the inputs and columns. This value is used to calculate the inhibition 1110 | radius. This function supports an arbitrary number of dimensions. If the 1111 | number of column dimensions does not match the number of input dimensions, 1112 | we treat the missing, or phantom dimensions as 'ones'. 
1113 | */ 1114 | 1115 | func (sp *SpatialPooler) avgColumnsPerInput() float64 { 1116 | 1117 | //TODO: extend to support different number of dimensions for inputs and 1118 | // columns 1119 | numDim := mathutil.Max(len(sp.ColumnDimensions), len(sp.InputDimensions)) 1120 | columnDims := sp.ColumnDimensions 1121 | inputDims := sp.InputDimensions 1122 | 1123 | //overlay column dimensions across 1's matrix 1124 | colDim := make([]int, numDim) 1125 | inputDim := make([]int, numDim) 1126 | 1127 | for i := 0; i < numDim; i++ { 1128 | if i < len(columnDims) { 1129 | colDim[i] = columnDims[i] 1130 | } else { 1131 | colDim[i] = 1 1132 | } 1133 | 1134 | if i < numDim { 1135 | inputDim[i] = inputDims[i] 1136 | } else { 1137 | inputDim[i] = 1 1138 | } 1139 | 1140 | } 1141 | 1142 | sum := 0.0 1143 | for i := 0; i < len(inputDim); i++ { 1144 | sum += float64(colDim[i]) / float64(inputDim[i]) 1145 | } 1146 | return sum / float64(numDim) 1147 | } 1148 | 1149 | type avgConnectedSpanForColumnNDFunc func(int) float64 1150 | type avgColumnsPerInputFunc func() float64 1151 | 1152 | /* 1153 | Update the inhibition radius. The inhibition radius is a meausre of the 1154 | square (or hypersquare) of columns that each a column is "conencted to" 1155 | on average. Since columns are are not connected to each other directly, we 1156 | determine this quantity by first figuring out how many *inputs* a column is 1157 | connected to, and then multiplying it by the total number of columns that 1158 | exist for each input. For multiple dimension the aforementioned 1159 | calculations are averaged over all dimensions of inputs and columns. This 1160 | value is meaningless if global inhibition is enabled. 
1161 | */ 1162 | func (sp *SpatialPooler) updateInhibitionRadius(avgConnectedSpanForColumnND avgConnectedSpanForColumnNDFunc, 1163 | avgColumnsPerInput avgColumnsPerInputFunc) { 1164 | 1165 | if sp.GlobalInhibition { 1166 | cmax := utils.MaxSliceInt(sp.ColumnDimensions) 1167 | sp.inhibitionRadius = cmax 1168 | return 1169 | } 1170 | 1171 | avgConnectedSpan := 0.0 1172 | for i := 0; i < sp.numColumns; i++ { 1173 | avgConnectedSpan += avgConnectedSpanForColumnND(i) 1174 | } 1175 | avgConnectedSpan = avgConnectedSpan / float64(sp.numColumns) 1176 | 1177 | columnsPerInput := avgColumnsPerInput() 1178 | diameter := avgConnectedSpan * columnsPerInput 1179 | radius := (diameter - 1) / 2.0 1180 | radius = math.Max(1.0, radius) 1181 | 1182 | sp.inhibitionRadius = int(utils.RoundPrec(radius, 0)) 1183 | } 1184 | 1185 | /* 1186 | Updates the minimum duty cycles defining normal activity for a column. A 1187 | column with activity duty cycle below this minimum threshold is boosted. 1188 | */ 1189 | func (sp *SpatialPooler) updateMinDutyCycles() { 1190 | if sp.GlobalInhibition || sp.inhibitionRadius > sp.numInputs { 1191 | sp.updateMinDutyCyclesGlobal() 1192 | } else { 1193 | sp.updateMinDutyCyclesLocal(sp.getNeighborsND) 1194 | } 1195 | 1196 | } 1197 | 1198 | /* 1199 | Updates the minimum duty cycles in a global fashion. Sets the minimum duty 1200 | cycles for the overlap and activation of all columns to be a percent of the 1201 | maximum in the region, specified by minPctOverlapDutyCycle and 1202 | minPctActiveDutyCycle respectively. Functionaly it is equivalent to 1203 | updateMinDutyCyclesLocal, but this function exploits the globalilty of the 1204 | compuation to perform it in a straightforward, and more efficient manner. 
1205 | */ 1206 | func (sp *SpatialPooler) updateMinDutyCyclesGlobal() { 1207 | minOverlap := sp.MinPctOverlapDutyCycles * utils.MaxSliceFloat64(sp.overlapDutyCycles) 1208 | utils.FillSliceFloat64(sp.minOverlapDutyCycles, minOverlap) 1209 | minActive := sp.MinPctActiveDutyCycles * utils.MaxSliceFloat64(sp.activeDutyCycles) 1210 | utils.FillSliceFloat64(sp.minActiveDutyCycles, minActive) 1211 | } 1212 | 1213 | type getNeighborsNDFunc func(int, []int, int, bool) []int 1214 | 1215 | /* 1216 | Updates the minimum duty cycles. The minimum duty cycles are determined 1217 | locally. Each column's minimum duty cycles are set to be a percent of the 1218 | maximum duty cycles in the column's neighborhood. Unlike 1219 | updateMinDutyCyclesGlobal, here the values can be quite different for 1220 | different columns. 1221 | */ 1222 | func (sp *SpatialPooler) updateMinDutyCyclesLocal(getNeighborsND getNeighborsNDFunc) { 1223 | 1224 | for i := 0; i < sp.numColumns; i++ { 1225 | maskNeighbors := getNeighborsND(i, sp.ColumnDimensions, sp.inhibitionRadius, false) 1226 | maskNeighbors = append(maskNeighbors, i) 1227 | 1228 | maxOverlap := utils.MaxSliceFloat64(utils.SubsetSliceFloat64(sp.overlapDutyCycles, maskNeighbors)) 1229 | sp.minOverlapDutyCycles[i] = maxOverlap * sp.MinPctOverlapDutyCycles 1230 | 1231 | maxActive := utils.MaxSliceFloat64(utils.SubsetSliceFloat64(sp.activeDutyCycles, maskNeighbors)) 1232 | sp.minActiveDutyCycles[i] = maxActive * sp.MinPctActiveDutyCycles 1233 | } 1234 | 1235 | } 1236 | 1237 | /* 1238 | Removes the set of columns who have never been active from the set of 1239 | active columns selected in the inhibition round. Such columns cannot 1240 | represent learned pattern and are therefore meaningless if only inference 1241 | is required. 
1242 | 1243 | Parameters: 1244 | ---------------------------- 1245 | activeColumns: An array containing the indices of the active columns 1246 | */ 1247 | func (sp *SpatialPooler) stripNeverLearned(activeColumns []int) []int { 1248 | var result []int 1249 | for i := 0; i < len(activeColumns); i++ { 1250 | if sp.activeDutyCycles[activeColumns[i]] != 0 { 1251 | result = append(result, activeColumns[i]) 1252 | } 1253 | } 1254 | 1255 | return result 1256 | } 1257 | 1258 | func (sp *SpatialPooler) printParameters() { 1259 | fmt.Println("numInputs", sp.numInputs) 1260 | fmt.Println("numColumns", sp.numColumns) 1261 | 1262 | } 1263 | 1264 | //----- Helper functions ---- 1265 | 1266 | /* 1267 | Updates a duty cycle estimate with a new value. This is a helper 1268 | function that is used to update several duty cycle variables in 1269 | the Column class, such as: overlapDutyCucle, activeDutyCycle, 1270 | minPctDutyCycleBeforeInh, minPctDutyCycleAfterInh, etc. returns 1271 | the updated duty cycle. Duty cycles are updated according to the following 1272 | formula: 1273 | 1274 | (period - 1)*dutyCycle + newValue 1275 | dutyCycle := ---------------------------------- 1276 | period 1277 | 1278 | Parameters: 1279 | ---------------------------- 1280 | dutyCycles: An array containing one or more duty cycle values that need 1281 | to be updated 1282 | newInput: A new numerical value used to update the duty cycle 1283 | period: The period of the duty cycle 1284 | */ 1285 | func updateDutyCyclesHelper(dutyCycles []float64, newInput []int, period int) []float64 { 1286 | if period < 1.0 { 1287 | panic("period can't be less than 1") 1288 | } 1289 | pf := float64(period) 1290 | result := make([]float64, len(dutyCycles)) 1291 | for i, val := range dutyCycles { 1292 | result[i] = (val*(pf-1.0) + float64(newInput[i])) / pf 1293 | } 1294 | return result 1295 | } 1296 | --------------------------------------------------------------------------------