├── README.md ├── hamt.go ├── hamt_test.go └── popcount.go /README.md: -------------------------------------------------------------------------------- 1 | hamt 2 | ==== 3 | 4 | Golang Hash Array Map Trie -------------------------------------------------------------------------------- /hamt.go: -------------------------------------------------------------------------------- 1 | package hamt 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "errors" 7 | "hash/fnv" 8 | ) 9 | 10 | const ( 11 | fanoutLog2 = 6 12 | fanout uint = 1 << fanoutLog2 13 | fanMask uint = fanout -1 14 | maxDepth = 60/fanoutLog2 15 | keyNotFound = "Key not found" 16 | ) 17 | 18 | type Key []byte 19 | 20 | type node interface { 21 | assoc(shift int, hash uint64, key Key, value interface{}) (last node, leaf *valueNode) 22 | without(shift int, hash uint64, key Key) node 23 | find(shift int, hash uint64, key Key) (value interface{}, err error) 24 | pos() uint64 25 | } 26 | 27 | type PersistentMap struct { 28 | root *bitmapNode 29 | // collision map[uint]interface{} 30 | } 31 | 32 | type valueNode struct { 33 | key Key 34 | hash uint64 35 | value interface{} 36 | bitpos uint64 37 | } 38 | 39 | type bitmapNode struct { 40 | childBitmap uint64 41 | children []node 42 | bitpos uint64 43 | } 44 | 45 | func (n *valueNode) assoc(shift int, hash uint64, key Key, val interface{}) (last node, leaf *valueNode) { 46 | if n.hash == hash { 47 | n.value = val 48 | last = n 49 | leaf = n 50 | } else { 51 | nn := &bitmapNode{0, make([]node, 2, 2), n.pos()} 52 | last = nn 53 | nn.assoc(shift, n.hash, key, n.value) 54 | _, leaf = nn.assoc(shift, hash, key, val) 55 | } 56 | 57 | return last, leaf 58 | } 59 | 60 | func (n *valueNode) without(shift int, hash uint64, key Key) node { 61 | return n 62 | } 63 | 64 | func (n *valueNode) find(shift int, hash uint64, key Key) (value interface{}, err error) { 65 | if hash == n.hash { 66 | value = n.value 67 | } else { 68 | err = errors.New(keyNotFound) 69 | } 70 | return value, err 71 | } 72 | 73 | func (n *valueNode) pos() uint64 { 74 | return n.bitpos 75 | } 76 | 77 | func (n *bitmapNode) assoc(shift int, hash uint64, key Key, val interface{}) (last node, leaf *valueNode) { 78 | bitsToShift := uint(shift*fanoutLog2) 79 | pos := bitpos(hash, bitsToShift) 80 | 81 | if (pos & n.childBitmap) == 0 { //nothing in slot, not found 82 | //mark our slot taken and xpand our children 83 | n.childBitmap |= pos 84 | newChildren := make([]node, (len(n.children) + 1)) 85 | 86 | newChildIndex := n.index(pos) 87 | newChild := &valueNode{key, hash, val, pos} 88 | newChildren [newChildIndex] = newChild 89 | 90 | for _, c := range n.children { 91 | if c != nil { 92 | oldChildNewIndex := n.index(c.pos()) 93 | newChildren[oldChildNewIndex] = c 94 | } 95 | } 96 | 97 | n.children = newChildren 98 | last = n 99 | leaf = newChild 100 | } else { 101 | index := n.index(pos) 102 | nodeAtIndex := n.children[index] 103 | last, leaf = nodeAtIndex.assoc(shift +1, hash, key, val) 104 | 105 | if _, isValNode := nodeAtIndex.(*valueNode); isValNode { 106 | n.children[index] = last 107 | } 108 | } 109 | return last, leaf 110 | } 111 | 112 | func (n *bitmapNode) without(shift int, hash uint64, key Key) node { 113 | return n 114 | } 115 | 116 | func (n *bitmapNode) find(shift int, hash uint64, key Key) (value interface{}, err error) { 117 | bitsToShift := uint(shift*fanoutLog2) 118 | pos := bitpos(hash, bitsToShift) 119 | if cMap := n.childBitmap; (pos & cMap) == 0 { //nothing in slot, not found 120 | err = errors.New(keyNotFound) 121 | } else { 122 | index := n.index(pos) 123 | if int(index) >= len(n.children) { 124 | err = errors.New("Keys computed index is larger than children") 125 | } else { 126 | value, err = n.children[index].find(shift + 1, hash, key) 127 | } 128 | } 129 | 130 | return value, err 131 | } 132 | 133 | func (n *bitmapNode) pos() uint64 { 134 | return n.bitpos 135 | } 136 | 137 | //Shift key hash until leaf with matching key is found or key is not found 138 | func (t *PersistentMap) Get(key Key) (value interface{}, err error) { 139 | //Hash our key and look for it in the root 140 | hash := hash(key) 141 | value, err = t.root.find(0, hash, key) 142 | 143 | return value, err 144 | } 145 | 146 | func (t *PersistentMap) Insert(key Key, value interface{}) (n node) { 147 | hash := hash(key) 148 | _, n = t.root.assoc(0, hash, key, value) 149 | 150 | return n 151 | } 152 | 153 | func New() *PersistentMap { 154 | return &PersistentMap { 155 | root: &bitmapNode{}, 156 | } 157 | } 158 | 159 | func StringKey(k string) Key { 160 | return Key([]byte(k)) 161 | } 162 | 163 | func IntKey(i int) Key { 164 | buf := new(bytes.Buffer) 165 | binary.Write(buf, binary.LittleEndian, i) 166 | return Key(buf.Bytes()) 167 | } 168 | 169 | func hash(a []byte) uint64 { 170 | h := fnv.New64() 171 | h.Write(a) 172 | 173 | return h.Sum64() 174 | } 175 | 176 | func shift(hash uint64, shift uint) uint64 { 177 | if shift == 0 { 178 | return hash 179 | } 180 | return hash >> shift 181 | } 182 | 183 | func mask(hash uint64, bshift uint) uint { 184 | return uint(shift(hash, bshift) & uint64(fanMask)) 185 | } 186 | 187 | func bitpos(hash uint64, bshift uint) uint64 { 188 | return 1 << mask(hash, bshift) 189 | } 190 | 191 | func (n *bitmapNode) index(onebitset uint64) uint { 192 | return popcount_2(n.childBitmap & (onebitset - 1)) 193 | } 194 | -------------------------------------------------------------------------------- /hamt_test.go: -------------------------------------------------------------------------------- 1 | package hamt 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "fmt" 7 | "testing" 8 | ) 9 | 10 | func TestEmptyTrie(t *testing.T) { 11 | trie := New() 12 | 13 | _,e := trie.Get([]byte("EMPTY")) 14 | 15 | if e==nil { 16 | t.Errorf("Not finding a key, which we wont find in an empty trie, must return an err") 17 | } 18 | } 19 | 20 | func TestInsertIntoTrie(t *testing.T) { 21 | key1 := []byte("store key1") 22 | key2 := []byte("store key2") 23 | 24 | v1 := "value 1" 25 | v2 := "value 2" 26 | 27 | 28 | root := New() 29 | 30 | root.Insert( key1, v1 ) 31 | root.Insert( key2, v2 ) 32 | 33 | vg1, e1 := root.Get( key1 ) 34 | vg2, e2 := root.Get( key2 ) 35 | 36 | if vg1 != v1 || vg2 != v2 { 37 | t.Errorf("Set values for keys[%v=%v,%v=%v] do not match returned [%v=%v,%v=%v]",key1,v1,key2,v2,key1,vg1,key2,vg2) 38 | t.Errorf("error return: %v, %v", e1, e2) 39 | } 40 | 41 | intKeys := make([][]byte, 256) 42 | for i := 0; i < 256; i++ { 43 | intKeys[i] = IntKey(i) 44 | } 45 | 46 | stringKeys := make([][]byte, 256) 47 | for i := 0; i < 256; i++ { 48 | stringKeys[i] = StringKey(fmt.Sprint("String key", i)) 49 | } 50 | 51 | insertAndAssureStorageForKeys(intKeys, t) 52 | insertAndAssureStorageForKeys(stringKeys, t) 53 | } 54 | 55 | func insertAndAssureStorageForKeys(keys [][]byte, t *testing.T) { 56 | tree := New() 57 | makeValueForKey := func (key interface{}) string { return fmt.Sprintf("Value for %v", key) } 58 | 59 | for _, k := range keys { 60 | v := makeValueForKey(k) 61 | tree.Insert(k, v) 62 | gv, ge := tree.Get(k) 63 | if ge != nil || gv != v { 64 | t.Errorf("We blewit! %v not equal to %v for %v", v, gv, k) 65 | } 66 | } 67 | 68 | for _, k := range keys { 69 | expectedValue := makeValueForKey(k) 70 | gv, ge := tree.Get(k) 71 | if ge != nil || gv != expectedValue { 72 | t.Errorf("After full insert of keys we got inequality for key %v. Expected %v but got %v", k, expectedValue, gv) 73 | } 74 | } 75 | } 76 | 77 | 78 | func BenchmarkGoMapIntInsert(b *testing.B) { 79 | m := make(map[int] int) 80 | for i := 0; i < b.N; i++ { 81 | m[i] = i 82 | } 83 | } 84 | 85 | func BenchmarkGoMapStringInsert(b *testing.B) { 86 | m := make(map[string] string) 87 | for i := 0; i < b.N; i++ { 88 | k := fmt.Sprintf("String Key %v", i) 89 | v := fmt.Sprintf("String Val %v", i) 90 | m[k] = v 91 | } 92 | } 93 | 94 | 95 | func BenchmarkHamtIntInsert(b *testing.B) { 96 | t := New() 97 | for i := 0; i < b.N; i++ { 98 | buf := new(bytes.Buffer) 99 | binary.Write(buf, binary.LittleEndian, i) 100 | t.Insert(buf.Bytes(), i) 101 | } 102 | } 103 | 104 | func BenchmarkHamtStringInsert(b *testing.B) { 105 | t := New() 106 | for i := 0; i < b.N; i++ { 107 | k := fmt.Sprintf("String Key %v", i) 108 | v := fmt.Sprintf("String Val %v", i) 109 | t.Insert([]byte(k), v) 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /popcount.go: -------------------------------------------------------------------------------- 1 | package hamt 2 | 3 | // Hamming weight pulled from wikipedia 4 | // http://en.wikipedia.org/wiki/Hamming_weight 5 | // this includes the crazy uint64s and popcount 6 | const ( 7 | m1 = 0x5555555555555555 //binary: 0101... 8 | m2 = 0x3333333333333333 //binary: 00110011.. 9 | m4 = 0x0f0f0f0f0f0f0f0f //binary: 4 zeros, 4 ones ... 10 | m8 = 0x00ff00ff00ff00ff //binary: 8 zeros, 8 ones ... 11 | m16 = 0x0000ffff0000ffff //binary: 16 zeros, 16 ones ... 12 | m32 = 0x00000000ffffffff //binary: 32 zeros, 32 ones 13 | hff = 0xffffffffffffffff //binary: all ones 14 | h01 = 0x0101010101010101 //the sum of 256 to the power of 0,1,2,3... 15 | ) 16 | 17 | //This uses fewer arithmetic operations than any other known 18 | //implementation on machines with slow multiplication. 19 | //It uses 17 arithmetic operations. 20 | func popcount_2(x uint64) uint { 21 | x -= (x >> 1) & m1 //put count of each 2 bits into those 2 bits 22 | x = (x & m2) + ((x >> 2) & m2) //put count of each 4 bits into those 4 bits 23 | x = (x + (x >> 4)) & m4 //put count of each 8 bits into those 8 bits 24 | x += x >> 8 //put count of each 16 bits into their lowest 8 bits 25 | x += x >> 16 //put count of each 32 bits into their lowest 8 bits 26 | x += x >> 32 //put count of each 64 bits into their lowest 8 bits 27 | return uint(x & 0x7f) 28 | } 29 | 30 | //This uses fewer arithmetic operations than any other known 31 | //implementation on machines with fast multiplication. 32 | //It uses 12 arithmetic operations, one of which is a multiply. 33 | func popcount_3(x uint64) uint{ 34 | x -= (x >> 1) & m1 //put count of each 2 bits into those 2 bits 35 | x = (x & m2) + ((x >> 2) & m2) //put count of each 4 bits into those 4 bits 36 | x = (x + (x >> 4)) & m4 //put count of each 8 bits into those 8 bits 37 | return uint((x * h01)>>56) //returns left 8 bits of x + (x<<8) + (x<<16) + (x<<24) + ... 38 | } 39 | 40 | --------------------------------------------------------------------------------