├── README.md ├── bencode.go ├── crawler └── main.go ├── handler.go ├── krpc.go ├── protocol.go ├── routing.go ├── test ├── test_bencode.go ├── test_krpc.go ├── test_routing.go └── test_token.go └── token.go /README.md: -------------------------------------------------------------------------------- 1 | # dht 2 | 3 | ## 介绍 4 | 5 | DHT是去中心化P2P下载的重要技术,它避免了BT下载依赖中心tracker节点来获取拥有资源的节点列表。 6 | 7 | DHT通过P2P的方式传播资源的拥有者信息,而不在依靠tracker,而这个传播的算法就是DHT。 8 | 9 | DHT并不是下载协议,最终资源下载仍旧是BT协议(种子),DHT是在帮助我们在P2P网络种找到下载地址。 10 | 11 | 具体参考官方论文:[DHT Protocol](http://www.bittorrent.org/beps/bep_0005.html) 12 | 13 | ## 计划 14 | 15 | 分步骤实现一个DHT协议的种子爬虫,因为涉及的知识点比较多,一次性实现也不是很有数,所以暂定一个计划: 16 | 17 | * 实现bencode协议的序列化/反序列化(bencode.go) 18 | * 创建UDP SOCKET,尝试向大型的DHT节点发送4种协议的请求,并接受应答进行观察(krpc.go) 19 | * 实现路由表Routing table,利用UDP请求/应答得到的其他Node,维护自己的亲近朋友列表(routing.go) 20 | * 接受外部应答,更新Routing table中活跃状态,或增加节点(不做了,因为爬虫没必要) 21 | * 接受外部调用,返回Routing table信息,更新活跃状态,增加节点(不做了,因为爬虫没必要) 22 | * 将收到的announce peers中的infohash与peer下载地址,先打印成日志保存(已实现) 23 | 24 | -------------------------------------------------------------------------------- /bencode.go: -------------------------------------------------------------------------------- 1 | package dht 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "bytes" 7 | "sort" 8 | "unicode" 9 | "unicode/utf8" 10 | "strconv" 11 | ) 12 | 13 | /** 14 | 维基百科:https://zh.wikipedia.org/wiki/Bencode 15 | 16 | 字符串(允许二进制): 长度:字符串 17 | 整形:i数字e 18 | 列表:l嵌套内容e, 嵌套内容为其他bencode编码的对象 19 | 字典:d嵌套内容e, 嵌套内容为成对出现的beancode编码的string key和对象 20 | */ 21 | 22 | func encodeString(data string) (encData []byte, err error) { 23 | encData = []byte(fmt.Sprintf("%d:%s", len(data), data)) 24 | return 25 | } 26 | 27 | func encodeInt(data int) (encData []byte, err error) { 28 | encData = []byte(fmt.Sprintf("i%de", data)) 29 | return 30 | } 31 | 32 | func encodeList(data []interface{}) (encData []byte, err error){ 33 | var ( 34 | encList = [][]byte{[]byte("l")} 35 | encElem []byte 36 | ) 37 | 38 | for _, elem := range data { 39 | if encElem, err = Encode(elem); err != nil { 40 | return 41 | } 42 | encList = append(encList, encElem) 43 | } 44 | encList = append(encList, []byte("e")) 45 | encData = bytes.Join(encList, []byte("")) 46 | return 47 | } 48 | 49 | func encodeDict(data map[string]interface{}) (encData []byte, err error) { 50 | var ( 51 | encMap = map[string][]byte{} 52 | encKey []byte 53 | encValue []byte 54 | ) 55 | for key, value := range data { 56 | if encValue, err = Encode(value); err != nil { 57 | return 58 | } 59 | encMap[key] = encValue 60 | } 61 | 62 | sortedKeys := make([]string, 0, len(encMap)) 63 | for key, _ := range data { 64 | sortedKeys = append(sortedKeys, key) 65 | } 66 | sort.Strings(sortedKeys) 67 | 68 | encList := [][]byte{[]byte("d")} 69 | for _, key := range sortedKeys { 70 | if encKey, err = Encode(key); err != nil { 71 | return 72 | } 73 | encList = append(encList, encKey, encMap[key]) 74 | } 75 | encList = append(encList, []byte("e")) 76 | encData = bytes.Join(encList, []byte("")) 77 | return 78 | } 79 | 80 | /** 81 | 编码函数 82 | */ 83 | func Encode(data interface{}) ([]byte, error) { 84 | switch data.(type) { 85 | case string: 86 | return encodeString(data.(string)) 87 | case int: 88 | return encodeInt(data.(int)) 89 | case []interface{}: 90 | return encodeList(data.([]interface{})) 91 | case map[string]interface{}: 92 | return encodeDict(data.(map[string]interface{})) 93 | default: 94 | return nil, errors.New("invalid type") 95 | } 96 | } 97 | 98 | func decodeDict(data []byte) (decData interface{}, size int, err error) { 99 | var ( 100 | curIndex int 101 | elemMap map[string]interface{} = map[string]interface{}{} 102 | key interface{} 103 | strKey string 104 | value interface{} 105 | keySize int 106 | valueSize int 107 | isString bool 108 | ) 109 | if len(data) < 2 || data[0] != 'd' { 110 | goto ERROR 111 | } 112 | 113 | curIndex = 1 114 | for curIndex < len(data) { 115 | // 判断下一个字节是否为字典结束符 116 | if data[curIndex] == 'e' { 117 | break 118 | } 119 | // 解析string key 120 | if key, keySize, err = decode(data[curIndex:]); err != nil { 121 | goto ERROR 122 | } 123 | if strKey, isString = key.(string); !isString { 124 | goto ERROR 125 | } 126 | curIndex += keySize 127 | // 解析value 128 | if value, valueSize, err = decode(data[curIndex:]); err != nil { 129 | goto ERROR 130 | } 131 | elemMap[strKey] = value 132 | curIndex += valueSize 133 | } 134 | if curIndex == len(data) { // 未找到e结束符 135 | goto ERROR 136 | } 137 | return elemMap, curIndex + 1, nil 138 | 139 | ERROR: 140 | return nil, 0, errors.New("invalid dict") 141 | } 142 | 143 | func decodeList(data []byte) (decData interface{}, size int, err error) { 144 | var ( 145 | curIndex int 146 | elemList []interface{} 147 | elem interface{} 148 | elemSize int 149 | ) 150 | if len(data) < 2 || data[0] != 'l' { 151 | goto ERROR 152 | } 153 | 154 | curIndex = 1 155 | for curIndex < len(data) { 156 | // 判断下一个字节是否为列表结束符 157 | if data[curIndex] == 'e' { 158 | break 159 | } 160 | if elem, elemSize, err = decode(data[curIndex:]); err != nil { 161 | goto ERROR 162 | } 163 | elemList = append(elemList, elem) 164 | curIndex += elemSize 165 | } 166 | if curIndex == len(data) { // 未找到e结束符 167 | goto ERROR 168 | } 169 | return elemList, curIndex + 1, nil 170 | 171 | ERROR: 172 | return nil, 0, errors.New("invalid list") 173 | } 174 | 175 | func decodeInt(data []byte) (decData interface{}, size int, err error) { 176 | var ( 177 | value int 178 | endIndex int 179 | ) 180 | if len(data) < 3 || data[0] != 'i' { 181 | goto ERROR 182 | } 183 | 184 | // 找出utf-8字符串序列中的字母e(必须使用rune,因为utf-8的字符由多字节组成,可能包含e) 185 | if endIndex = bytes.IndexRune(data, 'e'); endIndex == -1 { 186 | goto ERROR 187 | } 188 | 189 | // 解析中间部分为整形 190 | if value, err = strconv.Atoi(string(data[1:endIndex])); err != nil { 191 | goto ERROR 192 | } 193 | return value, endIndex + 1, nil 194 | ERROR: 195 | return nil, 0, errors.New("invalid int") 196 | } 197 | 198 | func decodeString(data []byte) (decData interface{}, size int, err error) { 199 | var ( 200 | value string 201 | valueLen int 202 | endIndex int 203 | ) 204 | if len(data) < 2 { 205 | goto ERROR 206 | } 207 | 208 | // 找出utf-8字符串序列中的字母: 209 | if endIndex = bytes.IndexRune(data, ':'); endIndex == -1 { 210 | goto ERROR 211 | } 212 | 213 | // :左侧解析为字符串长度 214 | if valueLen, err = strconv.Atoi(string(data[:endIndex])); err != nil { 215 | goto ERROR 216 | } 217 | 218 | // :右侧必须有valueLen个字节 219 | if endIndex + valueLen + 1 > len(data) { 220 | goto ERROR 221 | } 222 | 223 | value = string(data[endIndex + 1 : endIndex + 1 + valueLen]) 224 | size = endIndex + 1 + len(value) 225 | return value, size, nil 226 | ERROR: 227 | return nil, 0, errors.New("invalid string") 228 | } 229 | 230 | func decode(data []byte) (decData interface{}, size int, err error) { 231 | if len(data) != 0 { 232 | dataType, _ := utf8.DecodeRune(data) 233 | if dataType == 'd' { 234 | return decodeDict(data) 235 | } else if dataType == 'l' { 236 | return decodeList(data) 237 | } else if dataType == 'i' { 238 | return decodeInt(data) 239 | } else if unicode.IsDigit(dataType) { 240 | return decodeString(data) 241 | } 242 | } 243 | return nil, 0, errors.New("invalid data") 244 | } 245 | 246 | /** 247 | 解码函数 248 | */ 249 | func Decode(data []byte) (decData interface{}, err error) { 250 | var size int 251 | decData, size, err = decode(data) 252 | if size != len(data) { 253 | return nil, errors.New("invalid data") 254 | } 255 | return decData, err 256 | } 257 | 258 | -------------------------------------------------------------------------------- /crawler/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/owenliang/dht" 5 | "os" 6 | "fmt" 7 | "context" 8 | "time" 9 | ) 10 | 11 | func main() { 12 | var ( 13 | krpc *dht.KRPC 14 | err error 15 | nodes = make(chan *dht.CompactNode, 10000) 16 | bootstrap = "router.bittorrent.com:6881" 17 | ) 18 | if krpc, err = dht.CreateKPRC(); err != nil { 19 | fmt.Println(err) 20 | os.Exit(1) 21 | } 22 | 23 | // 实际上, 做一个爬虫并不需要维护路由表, 而只需要尽快加入到更多节点的路由表中 24 | 25 | // 不停的find_node, 让更多人认识我 26 | nodes <- &dht.CompactNode{bootstrap, ""} 27 | for i := 0; i < 3000; i++ { 28 | go func() { 29 | var ( 30 | err error 31 | node *dht.CompactNode 32 | findNodeReq *dht.FindNodeRequest 33 | findNodeResp *dht.FindNodeResponse 34 | compactNode *dht.CompactNode 35 | ) 36 | for { 37 | node = nil 38 | select { 39 | case node = <-nodes: 40 | default: 41 | } 42 | if node == nil { 43 | node = &dht.CompactNode{bootstrap, ""} 44 | } 45 | 46 | findNodeReq = dht.NewFindNodeRequest() 47 | findNodeReq.Target = dht.GenNodeId() 48 | 49 | if findNodeResp, err = krpc.FindNode(context.Background(), findNodeReq, node.Address); err == nil { 50 | for _, compactNode = range findNodeResp.Nodes { 51 | // 没必要维护路由表 52 | // dht.GetRoutingTable().InsertNode(compactNode) 53 | select { 54 | case nodes <- compactNode: 55 | default: 56 | } 57 | } 58 | } 59 | } 60 | }() 61 | } 62 | for { 63 | time.Sleep(1 * time.Second) 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /handler.go: -------------------------------------------------------------------------------- 1 | package dht 2 | 3 | import ( 4 | "net" 5 | "errors" 6 | "fmt" 7 | "encoding/hex" 8 | ) 9 | 10 | func ActiveNode(addDict map[string]interface{}, packetFrom *net.UDPAddr) { 11 | var ( 12 | iField interface{} 13 | id string 14 | exist bool 15 | typeOk bool 16 | ) 17 | if iField, exist = addDict["id"]; !exist { 18 | return 19 | } 20 | if id, typeOk = iField.(string); !typeOk { 21 | return 22 | } 23 | GetRoutingTable().InsertNode(NewCompactNode(id, packetFrom)) 24 | } 25 | 26 | func HandlePing(transactionId string, addDict map[string]interface{}, packetFrom *net.UDPAddr) ([]byte, error) { 27 | ActiveNode(addDict, packetFrom) 28 | 29 | resp := &PingResponse{} 30 | resp.TransactionId = transactionId 31 | return resp.Serialize() 32 | } 33 | 34 | func HandleFindNode(transactionId string, addDict map[string]interface{}, packetFrom *net.UDPAddr) ([]byte, error) { 35 | var ( 36 | iField interface{} 37 | target string 38 | exist bool 39 | typeOk bool 40 | targetNode *CompactNode 41 | ) 42 | 43 | ActiveNode(addDict, packetFrom) 44 | 45 | resp := &FindNodeResponse{} 46 | resp.TransactionId = transactionId 47 | 48 | if iField, exist = addDict["target"]; !exist { 49 | return nil, errors.New("missing target field") 50 | } 51 | 52 | if target, typeOk = iField.(string); !typeOk { 53 | return nil, errors.New("target type invalid") 54 | } 55 | 56 | if targetNode = GetRoutingTable().FindNode(target); targetNode != nil { 57 | resp.Nodes = make([]*CompactNode, 1) 58 | resp.Nodes[0] = targetNode 59 | } else { 60 | resp.Nodes = GetRoutingTable().ClosestNodes(target) 61 | } 62 | return resp.Serialize() 63 | } 64 | 65 | func HandleGetPeer(transactionId string, addDict map[string]interface{}, packetFrom *net.UDPAddr) ([]byte, error) { 66 | var ( 67 | iField interface{} 68 | infoHash string 69 | exist bool 70 | typeOk bool 71 | ) 72 | 73 | ActiveNode(addDict, packetFrom) 74 | 75 | resp := &GetPeersResponse{} 76 | resp.TransactionId = transactionId 77 | 78 | if iField, exist = addDict["info_hash"]; !exist { 79 | return nil, errors.New("missing info_hash field") 80 | } 81 | 82 | if infoHash, typeOk = iField.(string); !typeOk { 83 | return nil, errors.New("info_hash type invalid") 84 | } 85 | 86 | // 暂时没保存peer,只能找到nodes 87 | resp.Nodes = GetRoutingTable().ClosestNodes(infoHash) 88 | 89 | return resp.Serialize() 90 | } 91 | 92 | func HandleAnnouncePeer(transactionId string, addDict map[string]interface{}, packetFrom *net.UDPAddr) ([]byte, error) { 93 | var ( 94 | iField interface{} 95 | infoHash string 96 | token string 97 | impliedPort int = 0 98 | port int 99 | exist bool 100 | typeOk bool 101 | ) 102 | 103 | ActiveNode(addDict, packetFrom) 104 | 105 | resp := &AnnouncePeerResponse{} 106 | resp.TransactionId = transactionId 107 | 108 | if iField, exist = addDict["info_hash"]; !exist { 109 | return nil, errors.New("missing info_hash field") 110 | } 111 | if infoHash, typeOk = iField.(string); !typeOk { 112 | return nil, errors.New("info_hash type invalid") 113 | } 114 | 115 | if iField, exist = addDict["token"]; !exist { 116 | return nil, errors.New("missing token field") 117 | } 118 | if token, typeOk = iField.(string); !typeOk { 119 | return nil, errors.New("token type invalid") 120 | } 121 | 122 | if iField, exist = addDict["implied_port"]; exist { 123 | if impliedPort, typeOk = iField.(int); !typeOk { 124 | return nil, errors.New("implied_port type invalid") 125 | } 126 | } 127 | 128 | // 解析port 129 | if impliedPort == 1 { 130 | if iField, exist = addDict["port"]; !exist { 131 | return nil, errors.New("missing port field") 132 | } 133 | if port, typeOk = iField.(int); !typeOk { 134 | return nil, errors.New("port type invalid") 135 | } 136 | } else { 137 | port = int(packetFrom.Port) 138 | } 139 | 140 | // 校验token 141 | if !GetTokenManager().ValidateToken(token) { 142 | return nil, errors.New("token invalid") 143 | } 144 | 145 | // 保存peerinfo, 后续用于抓取种子 146 | HandlePeerInfo(infoHash, packetFrom.IP, port) 147 | 148 | return resp.Serialize() 149 | } 150 | 151 | func HandlePeerInfo(infoHash string, ip net.IP, port int) { 152 | fmt.Println( "magnet:?xt=urn:btih:" + hex.EncodeToString([]byte(infoHash)), ip, ":", port) 153 | } 154 | -------------------------------------------------------------------------------- /krpc.go: -------------------------------------------------------------------------------- 1 | package dht 2 | 3 | import ( 4 | "net" 5 | "sync" 6 | "context" 7 | "time" 8 | "errors" 9 | "runtime" 10 | ) 11 | 12 | type KRPCContext struct { 13 | transactionId string // 请求ID 14 | request interface{} // 请求protocol对象 15 | encoded []byte // 序列化请求 16 | requestTo *net.UDPAddr // 目标地址 17 | 18 | errCode int // 错误码 19 | errMsg string // 错误信息 20 | resDict map[string]interface{} // r字典 21 | responseFrom *net.UDPAddr // 发送应答的地址 22 | 23 | finishNotify chan byte // 收到应答后唤醒 24 | } 25 | 26 | type KRPCResponse struct { 27 | encoded []byte // 序列化应答 28 | responseTo *net.UDPAddr // 回复地址 29 | } 30 | 31 | type KRPCPacket struct { 32 | encoded []byte // 序列化的包 33 | packetFrom *net.UDPAddr // 来源地址 34 | } 35 | 36 | type KRPC struct { 37 | conn *net.UDPConn 38 | 39 | mutex sync.Mutex 40 | reqContext map[string]*KRPCContext // 等待应答的请求 41 | 42 | reqQueue chan *KRPCContext // 发送请求队列 43 | resQueue chan *KRPCResponse // 发送应答队列 44 | 45 | procQueue chan *KRPCPacket // 处理外来包队列 46 | procPending chan byte // 请求处理堆积控制 47 | } 48 | 49 | func (krpc *KRPC)HandleResponse(transactionId string, benDict map[string]interface{}, packetFrom *net.UDPAddr) { 50 | var ( 51 | ctx *KRPCContext 52 | resDict map[string]interface{} 53 | iField interface{} 54 | exist bool 55 | typeOk bool 56 | ) 57 | 58 | if iField, exist = benDict["r"]; !exist { 59 | return 60 | } 61 | if resDict, typeOk = iField.(map[string]interface{}); !typeOk { 62 | return 63 | } 64 | 65 | // 寻找请求上下文 66 | { 67 | krpc.mutex.Lock() 68 | if ctx, exist = krpc.reqContext[transactionId]; exist { 69 | delete(krpc.reqContext, transactionId) 70 | } 71 | krpc.mutex.Unlock() 72 | } 73 | // 唤醒调用者进一步处理 74 | if ctx != nil { 75 | ctx.resDict = resDict 76 | ctx.responseFrom = packetFrom 77 | ctx.finishNotify <- 1 78 | } 79 | } 80 | 81 | func (krpc *KRPC)HandleError(transactionId string, benDict map[string]interface{}, packetFrom *net.UDPAddr) { 82 | var ( 83 | ctx *KRPCContext 84 | exist bool 85 | iField interface{} 86 | iList []interface{} 87 | typeOk bool 88 | 89 | errCode int 90 | errMsg string 91 | ) 92 | 93 | if iField, exist = benDict["e"]; !exist { 94 | return 95 | } 96 | if iList, typeOk = iField.([]interface{}); !typeOk { 97 | return 98 | } 99 | if len(iList) < 2 { 100 | return 101 | } 102 | if errCode, typeOk = iList[0].(int); !typeOk { 103 | return 104 | } 105 | if errMsg, typeOk = iList[1].(string); !typeOk { 106 | return 107 | } 108 | 109 | // 寻找请求上下文 110 | { 111 | krpc.mutex.Lock() 112 | if ctx, exist = krpc.reqContext[transactionId]; exist { 113 | delete(krpc.reqContext, transactionId) 114 | } 115 | krpc.mutex.Unlock() 116 | } 117 | // 唤醒调用者进一步处理 118 | if ctx != nil { 119 | ctx.errCode = errCode 120 | ctx.errMsg = errMsg 121 | ctx.resDict = nil 122 | ctx.responseFrom = packetFrom 123 | ctx.finishNotify <- 1 124 | } 125 | } 126 | 127 | func (krpc *KRPC)HandleRequest(transactionId string, benDict map[string]interface{}, packetFrom *net.UDPAddr) { 128 | var ( 129 | iField interface{} 130 | method string 131 | exist bool 132 | typeOk bool 133 | addDict map[string]interface{} 134 | respBytes []byte 135 | err error 136 | ) 137 | 138 | if iField, exist = benDict["q"]; !exist { 139 | return 140 | } 141 | if method, typeOk = iField.(string); !typeOk { 142 | return 143 | } 144 | 145 | if iField, exist = benDict["a"]; !exist { 146 | return 147 | } 148 | if addDict, typeOk = iField.(map[string]interface{}); !typeOk { 149 | return 150 | } 151 | 152 | select { 153 | case krpc.procPending <- 1: // 增加1个处理中的请求 154 | default: 155 | return 156 | } 157 | // 并发协程处理 158 | go func() { 159 | if method == "ping" { 160 | respBytes, err = HandlePing(transactionId, addDict, packetFrom) 161 | } else if method == "find_node" { 162 | respBytes, err = HandleFindNode(transactionId, addDict, packetFrom) 163 | } else if method == "get_peers" { 164 | respBytes, err = HandleGetPeer(transactionId, addDict, packetFrom) 165 | } else if method == "announce_peer" { 166 | respBytes, err = HandleAnnouncePeer(transactionId, addDict, packetFrom) 167 | } else { 168 | goto END 169 | } 170 | if err == nil { 171 | krpc.resQueue <- &KRPCResponse{encoded: respBytes, responseTo: packetFrom} 172 | } 173 | END: 174 | <- krpc.procPending // 处理完释放计数 175 | }() 176 | } 177 | 178 | func (krpc *KRPC)HandlePacket(data []byte, packetFrom *net.UDPAddr) { 179 | var ( 180 | err error 181 | 182 | bencode interface{} 183 | benDict map[string]interface{} 184 | 185 | transactionId string 186 | msgType string 187 | 188 | iField interface{} 189 | exist bool 190 | typeOk bool 191 | ) 192 | 193 | if bencode, err = Decode(data); err != nil { 194 | goto INVALID 195 | } 196 | 197 | // 提取: t(请求ID),y(请求,应答,错误) 198 | if benDict, typeOk = bencode.(map[string]interface{}); !typeOk { 199 | goto INVALID 200 | } 201 | 202 | if iField, exist = benDict["t"]; !exist { 203 | goto INVALID 204 | } 205 | if transactionId, typeOk = iField.(string); !typeOk { 206 | goto INVALID 207 | } 208 | 209 | if iField, exist = benDict["y"]; !exist { 210 | goto INVALID 211 | } 212 | if msgType, typeOk = iField.(string); !typeOk { 213 | goto INVALID 214 | } 215 | 216 | // 应答 217 | if msgType == "r" { 218 | krpc.HandleResponse(transactionId, benDict, packetFrom) 219 | } else if msgType == "e" { // 错误 220 | krpc.HandleError(transactionId, benDict, packetFrom) 221 | } else if msgType == "q" { // 请求 222 | krpc.HandleRequest(transactionId, benDict, packetFrom) 223 | } else { // 未知 224 | goto INVALID 225 | } 226 | return 227 | 228 | INVALID: 229 | // fmt.Println("INVALID", string(data)) 230 | } 231 | 232 | func (krpc *KRPC)ProcLoop() { 233 | var ( 234 | packet *KRPCPacket 235 | ) 236 | for { 237 | packet = <- krpc.procQueue 238 | krpc.HandlePacket(packet.encoded, packet.packetFrom) 239 | } 240 | } 241 | 242 | func (krpc *KRPC)ReadLoop() { 243 | var ( 244 | err error 245 | 246 | packetFrom *net.UDPAddr 247 | buffer []byte = make([]byte, 10000) 248 | bufSize int 249 | ) 250 | for { 251 | if bufSize, packetFrom, err = krpc.conn.ReadFromUDP(buffer); err != nil || bufSize == 0 { 252 | continue 253 | } 254 | 255 | data := make([]byte, bufSize) 256 | copy(data, buffer[:bufSize]) 257 | 258 | packet := &KRPCPacket{encoded: data, packetFrom: packetFrom} 259 | 260 | krpc.procQueue <- packet 261 | } 262 | } 263 | 264 | func (krpc *KRPC) SendLoop() { 265 | var ( 266 | ctx *KRPCContext 267 | resp *KRPCResponse 268 | ) 269 | for { 270 | select { 271 | case ctx = <-krpc.reqQueue: 272 | krpc.conn.WriteToUDP(ctx.encoded, ctx.requestTo) 273 | case resp = <- krpc.resQueue: 274 | krpc.conn.WriteToUDP(resp.encoded, resp.responseTo) 275 | } 276 | } 277 | } 278 | 279 | func CreateKPRC() (krpc *KRPC, err error){ 280 | krpc = &KRPC{} 281 | addr := net.UDPAddr{net.IPv4(0, 0, 0,0), 6881, ""} 282 | if krpc.conn, err = net.ListenUDP("udp4", &addr); err != nil { 283 | return nil, err 284 | } 285 | krpc.reqContext = make(map[string]*KRPCContext) 286 | krpc.reqQueue = make(chan *KRPCContext, 100000) 287 | krpc.resQueue = make(chan *KRPCResponse, 100000) 288 | krpc.procQueue = make(chan *KRPCPacket, 100000) 289 | krpc.procPending = make(chan byte, 100000) 290 | go krpc.SendLoop() 291 | go krpc.ReadLoop() 292 | for i := 0; i < runtime.NumCPU(); i++ { 293 | go krpc.ProcLoop() 294 | } 295 | return krpc, nil 296 | } 297 | 298 | func (krpc *KRPC) BurstRequest(userCtx context.Context, transactionId string, request interface{}, encoded []byte, address string) (ctxt *KRPCContext, err error) { 299 | var ( 300 | requestTo *net.UDPAddr 301 | isTimeout bool = false 302 | ) 303 | // 域名解析 304 | if requestTo, err = net.ResolveUDPAddr("udp4", address); err != nil { 305 | return 306 | } 307 | // 生成调用上下文 308 | ctx := &KRPCContext{ 309 | transactionId: transactionId, 310 | request: request, 311 | encoded: encoded, 312 | requestTo: requestTo, 313 | finishNotify: make(chan byte, 1), 314 | } 315 | // 注册调用 316 | { 317 | krpc.mutex.Lock() 318 | krpc.reqContext[transactionId] = ctx 319 | krpc.mutex.Unlock() 320 | } 321 | // 启动RPC超时 322 | timeoutCtx, cancelFunc := context.WithTimeout(userCtx, time.Duration(1) * time.Second) 323 | defer cancelFunc() 324 | select { 325 | case krpc.reqQueue <- ctx: // 排队请求 326 | case <- timeoutCtx.Done(): // 等待超时 327 | isTimeout = true 328 | } 329 | // 排队成功,等待应答 330 | if !isTimeout { 331 | select { 332 | case <- ctx.finishNotify: 333 | case <- timeoutCtx.Done(): 334 | isTimeout = true 335 | } 336 | } 337 | if isTimeout { 338 | { // 超时取消注册的上下文 339 | krpc.mutex.Lock() 340 | if _, exist := krpc.reqContext[transactionId]; exist { 341 | delete(krpc.reqContext, transactionId) 342 | } 343 | krpc.mutex.Unlock() 344 | } 345 | return nil, errors.New("request timeout") 346 | } 347 | return ctx, nil 348 | } 349 | 350 | func (krpc *KRPC) Ping(userCtx context.Context, request *PingRequest, address string) (response *PingResponse, err error) { 351 | var ( 352 | ctx *KRPCContext 353 | bytes []byte 354 | ) 355 | 356 | // 序列化 357 | protobuf := map[string]interface{}{} 358 | protobuf["t"] = request.TransactionId 359 | protobuf["y"] = request.Type 360 | protobuf["q"] = request.Method 361 | protobuf["a"] = map[string]interface{}{ 362 | "id": MyNodeId(), 363 | } 364 | if bytes, err = Encode(protobuf); err != nil { 365 | return 366 | } 367 | 368 | if ctx, err = krpc.BurstRequest(userCtx, request.TransactionId, request, bytes, address); err != nil { 369 | return 370 | } 371 | if ctx.errCode != 0 { 372 | return nil, errors.New(ctx.errMsg) 373 | } 374 | response, err = UnserializePingResponse(ctx.transactionId, ctx.resDict) 375 | return 376 | } 377 | 378 | func (krpc *KRPC) FindNode(userCtx context.Context, request *FindNodeRequest, address string) (response *FindNodeResponse, err error) { 379 | var ( 380 | ctx *KRPCContext 381 | bytes []byte 382 | ) 383 | 384 | // 序列化 385 | protobuf := map[string]interface{}{} 386 | protobuf["t"] = request.TransactionId 387 | protobuf["y"] = request.Type 388 | protobuf["q"] = request.Method 389 | protobuf["a"] = map[string]interface{}{ 390 | "id": MyNodeId(), 391 | "target": request.Target, 392 | } 393 | if bytes, err = Encode(protobuf); err != nil { 394 | return 395 | } 396 | 397 | if ctx, err = krpc.BurstRequest(userCtx, request.TransactionId, request, bytes, address); err != nil { 398 | return 399 | } 400 | if ctx.errCode != 0 { 401 | return nil, errors.New(ctx.errMsg) 402 | } 403 | response, err = UnserializeFindNodeResponse(ctx.transactionId, ctx.resDict) 404 | return 405 | } 406 | 407 | func (krpc *KRPC) GetPeers(userCtx context.Context, request *GetPeersRequest, address string) (response *GetPeersResponse, err error) { 408 | var ( 409 | ctx *KRPCContext 410 | bytes []byte 411 | ) 412 | 413 | // 序列化 414 | protobuf := map[string]interface{}{} 415 | protobuf["t"] = request.TransactionId 416 | protobuf["y"] = request.Type 417 | protobuf["q"] = request.Method 418 | protobuf["a"] = map[string]interface{}{ 419 | "id": MyNodeId(), 420 | "info_hash": request.InfoHash, 421 | } 422 | if bytes, err = Encode(protobuf); err != nil { 423 | return 424 | } 425 | if ctx, err = krpc.BurstRequest(userCtx, request.TransactionId, request, bytes, address); err != nil { 426 | return 427 | } 428 | response, err = UnserializeGetPeersResponse(ctx.transactionId, ctx.resDict) 429 | return 430 | } 431 | 432 | func (krpc *KRPC) AnnouncePeer(userCtx context.Context, request *AnnouncePeerRequest, address string) (response *AnnouncePeerResponse, err error) { 433 | var ( 434 | ctx *KRPCContext 435 | bytes []byte 436 | addition map[string]interface{} 437 | ) 438 | 439 | // 序列化 440 | protobuf := map[string]interface{}{} 441 | protobuf["t"] = request.TransactionId 442 | protobuf["y"] = request.Type 443 | protobuf["q"] = request.Method 444 | addition = map[string]interface{}{ 445 | "id": MyNodeId(), 446 | "implied_port": request.ImpliedPort, 447 | "info_hash": request.InfoHash, 448 | } 449 | if request.ImpliedPort != 0 { 450 | addition["port"] = request.Port 451 | } 452 | if len(request.Token) != 0 { 453 | addition["token"] = request.Token 454 | } 455 | protobuf["a"] = addition 456 | if bytes, err = Encode(protobuf); err != nil { 457 | return 458 | } 459 | if ctx, err = krpc.BurstRequest(userCtx, request.TransactionId, request, bytes, address); err != nil { 460 | return 461 | } 462 | response, err = UnserializeAnnouncePeerResponse(ctx.transactionId, ctx.resDict) 463 | return 464 | } -------------------------------------------------------------------------------- /protocol.go: -------------------------------------------------------------------------------- 1 | package dht 2 | 3 | import ( 4 | "crypto/rand" 5 | "crypto/sha1" 6 | "sync/atomic" 7 | "strconv" 8 | "sync" 9 | "time" 10 | "errors" 11 | "encoding/binary" 12 | "fmt" 13 | "encoding/hex" 14 | "net" 15 | ) 16 | 17 | type CompactNode struct { 18 | Address string 19 | Id string 20 | } 21 | 22 | type ProtocolBase struct { 23 | TransactionId string // t: 请求唯一ID标识 24 | Type string // y: 消息类型(q,r,e) 25 | Id string // id(request是请求方id, response是应答方id) 26 | } 27 | 28 | type BaseRequest struct { 29 | ProtocolBase 30 | Method string // 请求方法名 31 | } 32 | 33 | type BaseResponse struct { 34 | ProtocolBase 35 | } 36 | 37 | // PING 38 | type PingRequest struct { 39 | BaseRequest 40 | } 41 | 42 | type PingResponse struct { 43 | BaseResponse 44 | } 45 | 46 | // FIND NODE 47 | type FindNodeRequest struct { 48 | BaseRequest 49 | Target string 50 | } 51 | 52 | type FindNodeResponse struct { 53 | BaseResponse 54 | Nodes []*CompactNode 55 | } 56 | 57 | // GET PEERS 58 | type GetPeersRequest struct { 59 | BaseRequest 60 | InfoHash string 61 | } 62 | 63 | type GetPeersResponse struct { 64 | BaseResponse 65 | Token string 66 | Nodes []*CompactNode 67 | Values []string // ip:port address.. 68 | } 69 | 70 | // ANNOUNCE PEER 71 | type AnnouncePeerRequest struct { 72 | BaseRequest 73 | ImpliedPort int 74 | InfoHash string 75 | Port int 76 | Token string 77 | } 78 | 79 | type AnnouncePeerResponse struct { 80 | BaseResponse 81 | } 82 | 83 | func NewCompactNode(id string, addr *net.UDPAddr) (compactNode *CompactNode) { 84 | compactNode = &CompactNode{} 85 | compactNode.Id = id 86 | compactNode.Address = fmt.Sprintf("%d.%d.%d.%d:%d", addr.IP[0], addr.IP[1], addr.IP[2], addr.IP[3], addr.Port) 87 | return compactNode 88 | } 89 | 90 | func NewPingRequest() (request *PingRequest) { 91 | request = &PingRequest{} 92 | request.TransactionId = GenTransactionId() 93 | request.Type = "q" 94 | request.Method = "ping" 95 | request.Id = MyNodeId() 96 | return request 97 | } 98 | 99 | func NewFindNodeRequest() (request *FindNodeRequest) { 100 | request = &FindNodeRequest{} 101 | request.TransactionId = GenTransactionId() 102 | request.Type = "q" 103 | request.Method = "find_node" 104 | request.Id = MyNodeId() 105 | return request 106 | } 107 | 108 | func NewGetPeersRequest() (request *GetPeersRequest) { 109 | request = &GetPeersRequest{} 110 | request.TransactionId = GenTransactionId() 111 | request.Type = "q" 112 | request.Method = "get_peers" 113 | request.Id = MyNodeId() 114 | return request 115 | } 116 | 117 | func NewAnnouncePeerRequest() (request *AnnouncePeerRequest) { 118 | request = &AnnouncePeerRequest{} 119 | request.TransactionId = GenTransactionId() 120 | request.Type = "q" 121 | request.Method = "announce_peer" 122 | request.ImpliedPort = 0 123 | request.Id = MyNodeId() 124 | return request 125 | } 126 | 127 | func UnserializeCompactNode(nodeInfo string) (*CompactNode, error) { 128 | if len(nodeInfo) != 26 { 129 | return nil, errors.New("compact node invalid") 130 | } 131 | compactNode := &CompactNode{} 132 | compactNode.Id = nodeInfo[0:20] 133 | port := binary.BigEndian.Uint16([]byte(nodeInfo[24:26])) 134 | compactNode.Address = fmt.Sprintf("%d.%d.%d.%d:%d", nodeInfo[20], nodeInfo[21], nodeInfo[22], nodeInfo[23], port) 135 | return compactNode, nil 136 | } 137 | 138 | func UnserializePeerInfo(peerInfo string) (string, error) { 139 | if len(peerInfo) != 6 { 140 | return "", errors.New("compact peer invalid") 141 | } 142 | port := binary.BigEndian.Uint16([]byte(peerInfo[4:6])) 143 | return fmt.Sprintf("%d.%d.%d.%d:%d", peerInfo[0], peerInfo[1], peerInfo[2], peerInfo[3], port), nil 144 | } 145 | 146 | func UnserializePingResponse(transactionId string, resDict map[string]interface{}) (response *PingResponse, err error) { 147 | var ( 148 | iField interface{} 149 | exist bool 150 | typeOk bool 151 | ) 152 | 153 | response = &PingResponse{} 154 | response.TransactionId = transactionId 155 | response.Type = "r" 156 | 157 | if iField, exist = resDict["id"]; !exist { 158 | goto ERROR 159 | } 160 | if response.Id, typeOk = iField.(string); !typeOk { 161 | goto ERROR 162 | } 163 | return response, nil 164 | ERROR: 165 | return nil, errors.New("invalid ping response") 166 | } 167 | 168 | func UnserializeFindNodeResponse(transactionId string, resDict map[string]interface{}) (response *FindNodeResponse, err error) { 169 | var ( 170 | iField interface{} 171 | exist bool 172 | typeOk bool 173 | nodes string 174 | compactNode *CompactNode 175 | nodesSplit string 176 | ) 177 | 178 | response = &FindNodeResponse{} 179 | response.TransactionId = transactionId 180 | response.Type = "r" 181 | response.Nodes = make([]*CompactNode, 0) 182 | 183 | if iField, exist = resDict["id"]; !exist { 184 | goto ERROR 185 | } 186 | if response.Id, typeOk = iField.(string); !typeOk { 187 | goto ERROR 188 | } 189 | 190 | if iField, exist = resDict["nodes"]; exist { 191 | if nodes, typeOk = iField.(string); !typeOk { 192 | goto ERROR 193 | } 194 | if len(nodes) % 26 != 0 { 195 | goto ERROR 196 | } 197 | for i := 0; i <= len(nodes); i++ { 198 | if i % 26 == 0 && i != 0 { 199 | nodesSplit = nodes[i - 26:i] 200 | // closest nodes解析compactNode 201 | if compactNode, err = UnserializeCompactNode(nodesSplit); err != nil { 202 | goto ERROR 203 | } 204 | response.Nodes = append(response.Nodes, compactNode) 205 | } 206 | } 207 | } 208 | return response, nil 209 | ERROR: 210 | return nil, errors.New("invalid find_node response") 211 | } 212 | 213 | func UnserializeGetPeersResponse(transactionId string, resDict map[string]interface{}) (response *GetPeersResponse, err error) { 214 | var ( 215 | iField interface{} 216 | exist bool 217 | typeOk bool 218 | nodes string 219 | nodesSplit string 220 | compactNode *CompactNode 221 | peers []interface{} 222 | peerInfo string 223 | address string 224 | ) 225 | 226 | response = &GetPeersResponse{} 227 | response.TransactionId = transactionId 228 | response.Type = "r" 229 | response.Nodes = make([]*CompactNode, 0) 230 | response.Values = make([]string, 0) 231 | 232 | if iField, exist = resDict["id"]; !exist { 233 | goto ERROR 234 | } 235 | if response.Id, typeOk = iField.(string); !typeOk { 236 | goto ERROR 237 | } 238 | 239 | if iField, exist = resDict["values"]; exist { 240 | if peers, typeOk = iField.([]interface{}); !typeOk { 241 | goto ERROR 242 | } 243 | for i := 0; i < len(peers); i++ { 244 | if peerInfo, typeOk = peers[i].(string); !typeOk { 245 | goto ERROR 246 | } 247 | address, err = UnserializePeerInfo(peerInfo) 248 | if err != nil { 249 | goto ERROR 250 | } 251 | response.Values = append(response.Values, address) 252 | } 253 | } 254 | 255 | if iField, exist = resDict["nodes"]; exist { 256 | if nodes, typeOk = iField.(string); !typeOk { 257 | goto ERROR 258 | } 259 | if len(nodes) % 26 != 0 { 260 | goto ERROR 261 | } 262 | for i := 0; i <= len(nodes); i++ { 263 | if i % 26 == 0 && i != 0 { 264 | nodesSplit = nodes[i - 26:i] 265 | // target解析compactNode 266 | if compactNode, err = UnserializeCompactNode(nodesSplit); err != nil { 267 | goto ERROR 268 | } 269 | response.Nodes = append(response.Nodes, compactNode) 270 | } 271 | } 272 | } 273 | return response, nil 274 | ERROR: 275 | return nil, errors.New("invalid find_node response") 276 | } 277 | 278 | func UnserializeAnnouncePeerResponse(transactionId string, benDict map[string]interface{}) (response *AnnouncePeerResponse, err error) { 279 | err = errors.New("not implement") 280 | return 281 | } 282 | 283 | func (node *CompactNode) Serialize() (bytes []byte, err error) { 284 | var ( 285 | addr *net.UDPAddr 286 | ) 287 | if addr, err = net.ResolveUDPAddr("udp", node.Address); err != nil { 288 | return 289 | } 290 | bytes = make([]byte, 26) 291 | copy(bytes, node.Id) 292 | copy(bytes[20:24], addr.IP) 293 | binary.BigEndian.PutUint16(bytes[24:26], uint16(addr.Port)) 294 | return bytes, nil 295 | } 296 | 297 | func (response *PingResponse) Serialize() ([]byte, error){ 298 | resp := map[string]interface{}{} 299 | 300 | resp["t"] = response.TransactionId 301 | resp["y"] = "r" 302 | 303 | r := map[string]interface{}{} 304 | r["id"] = MyNodeId() 305 | 306 | resp["r"] = r 307 | return Encode(resp) 308 | } 309 | 310 | func (response *FindNodeResponse) Serialize() (bytes []byte, err error){ 311 | var ( 312 | compactNode *CompactNode 313 | compactNodeBytes []byte 314 | resp = map[string]interface{}{} 315 | r = map[string]interface{}{} 316 | nodesBytes []byte = nil 317 | ) 318 | 319 | resp["t"] = response.TransactionId 320 | resp["y"] = "r" 321 | 322 | r["id"] = MyNodeId() 323 | for _, compactNode = range response.Nodes { 324 | if compactNodeBytes, err = compactNode.Serialize(); err == nil { 325 | nodesBytes = append( nodesBytes, compactNodeBytes...) 326 | } 327 | } 328 | r["nodes"] = string(nodesBytes) 329 | 330 | resp["r"] = r 331 | return Encode(resp) 332 | } 333 | 334 | func (response *GetPeersResponse) Serialize() (bytes []byte, err error) { 335 | var ( 336 | addr *net.UDPAddr 337 | compactNode *CompactNode 338 | compactNodeBytes []byte 339 | resp = map[string]interface{}{} 340 | r = map[string]interface{}{} 341 | nodesBytes []byte = nil 342 | peerInfos = make([]string, 0) 343 | peerInfo string 344 | ) 345 | resp["t"] = response.TransactionId 346 | resp["y"] = "r" 347 | 348 | var compactPeerInfo [6]byte 349 | for _, peerInfo = range response.Values { 350 | if addr, err = net.ResolveUDPAddr("udp", peerInfo); err != nil { 351 | return 352 | } 353 | copy(compactPeerInfo[0:4], addr.IP) 354 | binary.BigEndian.PutUint16(bytes[4:6], uint16(addr.Port)) 355 | peerInfos = append(peerInfos, string(compactPeerInfo[:])) 356 | } 357 | if len(peerInfos) > 0 { 358 | r["values"] = peerInfos 359 | } else { 360 | for _, compactNode = range response.Nodes { 361 | if compactNodeBytes, err = compactNode.Serialize(); err == nil { 362 | nodesBytes = append( nodesBytes, compactNodeBytes...) 363 | } 364 | } 365 | r["nodes"] = string(nodesBytes) 366 | } 367 | 368 | r["id"] = MyNodeId() 369 | r["token"] = GetTokenManager().GetToken() 370 | 371 | resp["r"] = r 372 | return Encode(resp) 373 | } 374 | 375 | func (response *AnnouncePeerResponse) Serialize() ([]byte, error) { 376 | var ( 377 | resp = map[string]interface{}{} 378 | r = map[string]interface{}{} 379 | ) 380 | resp["t"] = response.TransactionId 381 | resp["y"] = "r" 382 | r["id"] = MyNodeId() 383 | 384 | resp["r"] = r 385 | return Encode(resp) 386 | } 387 | 388 | func (response *PingResponse) String() string { 389 | ret := "\n---PingResponse---\n" 390 | ret += "T=" + hex.EncodeToString([]byte(response.TransactionId)) + "\n" 391 | ret += "Id=" + hex.EncodeToString([]byte(response.Id)) + "\n" 392 | ret += "---------------------\n" 393 | return ret 394 | } 395 | 396 | func (compactNode *CompactNode) String() string { 397 | ret := "ID=" + hex.EncodeToString([]byte(compactNode.Id)) + " Addr=" + compactNode.Address + "\n" 398 | return ret 399 | } 400 | 401 | func (response *FindNodeResponse) String() string { 402 | ret := "\n---FindNodeResponse---\n" 403 | ret += "T=" + hex.EncodeToString([]byte(response.TransactionId)) + "\n" 404 | ret += "Id=" + hex.EncodeToString([]byte(response.Id)) + "\n" 405 | if len(response.Nodes) != 0 { 406 | ret += "Nodes=\n" 407 | for _, node := range response.Nodes { 408 | ret += "->" + node.String() 409 | } 410 | } 411 | ret += "---------------------\n" 412 | return ret 413 | } 414 | 415 | 416 | func (response *GetPeersResponse) String() string { 417 | ret := "\n---GetPeersResponse---\n" 418 | ret += "T=" + hex.EncodeToString([]byte(response.TransactionId)) + "\n" 419 | ret += "Id=" + hex.EncodeToString([]byte(response.Id)) + "\n" 420 | if len(response.Values) != 0 { 421 | ret += "Values=\n" 422 | for _, address := range response.Values { 423 | ret += "->" + address 424 | } 425 | } 426 | if len(response.Nodes) != 0 { 427 | ret += "Nodes=\n" 428 | for _, node := range response.Nodes { 429 | ret += "->" + node.String() 430 | } 431 | } 432 | ret += "---------------------\n" 433 | return ret 434 | } 435 | 436 | // 生成随机DHT NODE ID 437 | func GenNodeId() string { 438 | randBytes := make([]byte, 160) // 随机160字节, 然后sha1计算20字节二进制ID 439 | for { 440 | if _, err := rand.Read(randBytes); err == nil { 441 | sha1Bytes := sha1.Sum(randBytes) 442 | return string(sha1Bytes[:]) 443 | } 444 | } 445 | } 446 | 447 | // 我的DHT NODE ID 448 | var myNodeId string 449 | var initMyNodeId sync.Once 450 | 451 | func MyNodeId() string { 452 | initMyNodeId.Do(func() { 453 | myNodeId = GenNodeId() 454 | }) 455 | return myNodeId 456 | } 457 | 458 | // 生成请求唯一ID 459 | var myTransactionId int64 = 0 460 | var initTransactionId sync.Once 461 | 462 | func GenTransactionId() string { 463 | initTransactionId.Do(func() { 464 | myTransactionId = time.Now().UnixNano() 465 | }) 466 | return strconv.Itoa(int(atomic.AddInt64(&myTransactionId, 1))) 467 | } 468 | -------------------------------------------------------------------------------- /routing.go: -------------------------------------------------------------------------------- 1 | package dht 2 | 3 | import ( 4 | "time" 5 | "math/big" 6 | "sync" 7 | "sort" 8 | ) 9 | 10 | const ( 11 | KNODES = 8 // 每个桶保存8个节点 12 | MAX_FAIL_TIMES = 3 // 3次连续fail则标记bad 13 | 14 | // 节点状态 15 | NODE_STATUS_GOOD = 1 16 | NODE_STATUS_BAD = 2 17 | ) 18 | 19 | type Node struct { 20 | info *CompactNode // 节点地址 21 | lastActive int64 // 上次活跃时间 22 | failTimes int // 连续访问失败的次数, 超过3次就标记为bad 23 | status int // 状态: good, bad, questionable 24 | } 25 | 26 | type Bucket struct { 27 | nodes map[string]*Node 28 | min, max *big.Int 29 | lastActive int64 30 | } 31 | 32 | type RoutingTable struct { 33 | buckets []*Bucket 34 | mutex sync.Mutex 35 | } 36 | 37 | type ClosestNodes struct { 38 | target string 39 | nodes []*CompactNode 40 | } 41 | 42 | func (closest ClosestNodes) Len() int { 43 | return len(closest.nodes) 44 | } 45 | 46 | func (closest ClosestNodes) Swap(i, j int) { 47 | closest.nodes[i], closest.nodes[j] = closest.nodes[j], closest.nodes[i] 48 | } 49 | 50 | func (closest ClosestNodes) Less(i, j int) bool { 51 | leftId := nodeId2Int(closest.nodes[i].Id) 52 | rightId := nodeId2Int(closest.nodes[j].Id) 53 | targetId := nodeId2Int(closest.target) 54 | 55 | // 计算异或距离, 比较大小 56 | cmp := new(big.Int).Xor(leftId, targetId).Cmp( new(big.Int).Xor(rightId, targetId) ) 57 | return cmp < 0 58 | } 59 | 60 | func newBucket(min, max *big.Int) (bucket *Bucket) { 61 | bucket = &Bucket{} 62 | 63 | bucket.min = min 64 | bucket.max = max 65 | bucket.nodes = make(map[string]*Node) 66 | bucket.lastActive = time.Now().Unix() 67 | return 68 | } 69 | 70 | func nodeId2Int(nodeId string) *big.Int { 71 | return new(big.Int).SetBytes([]byte(nodeId)) 72 | } 73 | 74 | func (bucket *Bucket) size() int { 75 | return len(bucket.nodes) 76 | } 77 | 78 | func (bucket *Bucket) inRange(nodeId string) bool { 79 | intId := nodeId2Int(nodeId) 80 | return intId.Cmp(bucket.min) >= 0 && intId.Cmp(bucket.max) < 0 81 | } 82 | 83 | func (bucket *Bucket) insertNode(nodeInfo *CompactNode) bool { 84 | var ( 85 | node *Node 86 | exist bool 87 | ) 88 | if node, exist = bucket.nodes[nodeInfo.Id]; exist { 89 | goto REPLACE 90 | } 91 | for nodeId, node := range bucket.nodes { 92 | if node.status == NODE_STATUS_BAD { 93 | delete(bucket.nodes, nodeId) // 虽然是bad, 但删1个就好了 94 | break 95 | } 96 | } 97 | if len(bucket.nodes) == KNODES { 98 | return false 99 | } 100 | REPLACE: 101 | node = &Node{} 102 | node.info = nodeInfo 103 | node.status = NODE_STATUS_GOOD 104 | node.lastActive = time.Now().Unix() 105 | node.failTimes = 0 106 | bucket.nodes[nodeInfo.Id] = node 107 | bucket.lastActive = time.Now().Unix() 108 | return true 109 | } 110 | 111 | func rootBucket() (root *Bucket) { 112 | minId := big.NewInt(0) 113 | maxId := new(big.Int).Exp(big.NewInt(2), big.NewInt(160), nil) 114 | root = newBucket(minId, maxId) 115 | root.insertNode(&CompactNode{"", MyNodeId()}) 116 | return 117 | } 118 | 119 | // 路由表单例 120 | var routingTable *RoutingTable 121 | var initRoutingTableOnce sync.Once 122 | 123 | func GetRoutingTable() (*RoutingTable) { 124 | initRoutingTableOnce.Do(func () { 125 | routingTable = &RoutingTable{} 126 | routingTable.buckets = append(routingTable.buckets, rootBucket()) 127 | }) 128 | return routingTable 129 | } 130 | 131 | func (rt *RoutingTable) splitBucket(idx int) { 132 | toSplit := rt.buckets[idx] 133 | 134 | sumRange := new(big.Int).Add(toSplit.min, toSplit.max) 135 | mid := new(big.Int).Div(sumRange, big.NewInt(2)) 136 | 137 | rightBucket := newBucket(mid, toSplit.max) 138 | toSplit.max = mid 139 | 140 | // 原桶分裂成2个新桶 141 | for nodeId, node := range toSplit.nodes { 142 | if !toSplit.inRange(nodeId) { 143 | delete(toSplit.nodes, nodeId) 144 | rightBucket.nodes[nodeId] = node 145 | } 146 | } 147 | rightBucket.lastActive = toSplit.lastActive 148 | 149 | // 插入分裂后的桶 150 | rt.buckets = append(rt.buckets, nil) // 扩容 151 | insertIdx := idx + 1 152 | copy(rt.buckets[insertIdx + 1:], rt.buckets[insertIdx:]) 153 | rt.buckets[insertIdx] = rightBucket 154 | } 155 | 156 | func (rt *RoutingTable) findBucket(nodeId string) int { 157 | for i := 0; i < len(rt.buckets); i++ { 158 | if !rt.buckets[i].inRange(nodeId) { 159 | continue 160 | } 161 | return i 162 | } 163 | return -1 164 | } 165 | 166 | func (rt *RoutingTable) insertNode(nodeInfo *CompactNode) bool { 167 | if nodeInfo.Id == MyNodeId() { 168 | return true 169 | } 170 | 171 | idx := rt.findBucket(nodeInfo.Id) 172 | if idx < 0 { 173 | return false // never reach 174 | } 175 | 176 | if rt.buckets[idx].insertNode(nodeInfo) { // bucket没满插入成功 177 | return true 178 | } 179 | if !rt.buckets[idx].inRange(MyNodeId()) { // bucket不包含自身,无法分裂 180 | return false 181 | } 182 | rt.splitBucket(idx) 183 | return rt.insertNode(nodeInfo) 184 | } 185 | 186 | func (rt *RoutingTable) InsertNode(nodeInfo *CompactNode) bool { 187 | rt.mutex.Lock() 188 | defer rt.mutex.Unlock() 189 | return rt.insertNode(nodeInfo) 190 | } 191 | 192 | func (rt *RoutingTable) Size() int { 193 | rt.mutex.Lock() 194 | defer rt.mutex.Unlock() 195 | return len(rt.buckets) 196 | } 197 | 198 | func (rt *RoutingTable) Fail(nodeId string) { 199 | rt.mutex.Lock() 200 | defer rt.mutex.Unlock() 201 | 202 | if nodeId == MyNodeId() { 203 | return 204 | } 205 | 206 | idx := rt.findBucket(nodeId) 207 | if idx < 0 { 208 | return 209 | } 210 | 211 | if node, exist := rt.buckets[idx].nodes[nodeId]; exist { 212 | if node.status != NODE_STATUS_BAD { 213 | node.failTimes++ 214 | if node.failTimes >= MAX_FAIL_TIMES { 215 | node.status = NODE_STATUS_BAD 216 | } 217 | } 218 | } 219 | } 220 | 221 | func (rt *RoutingTable) FindNode(nodeId string) *CompactNode { 222 | rt.mutex.Lock() 223 | defer rt.mutex.Unlock() 224 | 225 | // 永远不返回自己 226 | if nodeId == MyNodeId() { 227 | return nil 228 | } 229 | 230 | idx := rt.findBucket(nodeId) 231 | if idx < 0 { 232 | return nil 233 | } 234 | 235 | nodes := rt.buckets[idx].nodes 236 | if node, exist := nodes[nodeId]; exist { 237 | return node.info 238 | } 239 | return nil 240 | } 241 | 242 | func (rt *RoutingTable) ClosestNodes(nodeId string) (nodes []*CompactNode) { 243 | nodes = make([]*CompactNode, 0) 244 | 245 | rt.mutex.Lock() 246 | defer rt.mutex.Unlock() 247 | 248 | idx := rt.findBucket(nodeId) 249 | if idx < 0 { 250 | return 251 | } 252 | 253 | for _, node := range rt.buckets[idx].nodes { 254 | if node.info.Id != nodeId { 255 | nodes = append(nodes, node.info) 256 | } 257 | } 258 | 259 | // 不足8个, 找周边的bucket 260 | if len(nodes) < KNODES { 261 | leftIdx := idx - 1 262 | rightIdx := idx + 1 263 | for len(nodes) < KNODES && (leftIdx >= 0 || rightIdx < len(rt.buckets)){ // 从左边和右边的邻居桶补一些进来 264 | if leftIdx >= 0 { 265 | for _, node := range rt.buckets[leftIdx].nodes { 266 | nodes = append(nodes, node.info) 267 | } 268 | } 269 | if rightIdx < len(rt.buckets) { 270 | for _, node := range rt.buckets[rightIdx].nodes { 271 | nodes = append(nodes, node.info) 272 | } 273 | } 274 | leftIdx-- 275 | rightIdx++ 276 | } 277 | } 278 | 279 | // 按距离排序 280 | closestNodes := ClosestNodes{} 281 | closestNodes.target = nodeId 282 | closestNodes.nodes = nodes 283 | sort.Sort(closestNodes) 284 | // 取最近的8个 285 | if len(nodes) > KNODES { 286 | nodes = nodes[:KNODES] 287 | } 288 | return 289 | } -------------------------------------------------------------------------------- /test/test_bencode.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "github.com/owenliang/dht" 6 | ) 7 | 8 | func main() { 9 | strData := "你好吗" 10 | if encData, err := dht.Encode(strData); err == nil { 11 | fmt.Println("encData=", string(encData)) 12 | } 13 | 14 | intData := 1024 15 | if encData, err := dht.Encode(intData); err == nil { 16 | fmt.Println("encData=", string(encData)) 17 | } 18 | 19 | listData := []interface{}{"你好吗", 1024,} 20 | if encData, err := dht.Encode(listData); err == nil { 21 | fmt.Println("encData=", string(encData)) 22 | } 23 | 24 | dictData := map[string]interface{}{"t":"aa", "y":"q", "q":"ping", "a": map[string]interface{}{"id":"abcdefghij0123456789"}} 25 | if encData, err := dht.Encode(dictData); err == nil { 26 | fmt.Println("encData=", string(encData)) 27 | } 28 | 29 | encIntData := []byte("i-12345e") 30 | if decData, err := dht.Decode(encIntData); err == nil { 31 | fmt.Println("decData=", decData) 32 | } 33 | 34 | encStrData := []byte("2:ab") 35 | if decData, err := dht.Decode(encStrData); err == nil { 36 | fmt.Println("decData=", decData) 37 | } 38 | 39 | encListData := []byte("l2:abl3:mmm1:ai5123eee") 40 | if decData, err := dht.Decode(encListData); err == nil { 41 | fmt.Println("decData=", decData) 42 | } 43 | 44 | encDictData := []byte("d2:abd2:cdl2:fgi5ed9:小电影i0eeeee") 45 | if decData, err := dht.Decode(encDictData); err == nil { 46 | fmt.Println("decData=", decData) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /test/test_krpc.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/owenliang/dht" 5 | 6 | "os" 7 | "fmt" 8 | "context" 9 | ) 10 | 11 | func main() { 12 | var ( 13 | krpc *dht.KRPC 14 | 15 | pingRequest *dht.PingRequest 16 | pingResponse *dht.PingResponse 17 | 18 | findNodeRequest *dht.FindNodeRequest 19 | findNodeResponse *dht.FindNodeResponse 20 | 21 | getPeersRequest *dht.GetPeersRequest 22 | getPeersResponse *dht.GetPeersResponse 23 | 24 | announcePeerRequest *dht.AnnouncePeerRequest 25 | announcePeerResponse *dht.AnnouncePeerResponse 26 | 27 | address = "router.bittorrent.com:6881" 28 | //address = "dht.transmissionbt.com:6881" 29 | err error 30 | ) 31 | 32 | // krpc 33 | if krpc, err = dht.CreateKPRC(); err != nil { 34 | fmt.Println(err) 35 | os.Exit(1) 36 | } 37 | 38 | for { 39 | // ping 40 | pingRequest = dht.NewPingRequest() 41 | pingResponse, err = krpc.Ping(context.Background(), pingRequest, address) 42 | fmt.Println("Ping", pingResponse, err) 43 | 44 | // find node 45 | findNodeRequest = dht.NewFindNodeRequest() 46 | findNodeRequest.Target = dht.GenNodeId() 47 | findNodeResponse, err = krpc.FindNode(context.Background(), findNodeRequest, address) 48 | fmt.Println("FindNode", findNodeResponse, err) 49 | 50 | // get peers 51 | getPeersRequest = dht.NewGetPeersRequest() 52 | getPeersRequest.InfoHash = dht.GenNodeId() // 随机仿造一个20字节的info_hash 53 | getPeersResponse, err = krpc.GetPeers(context.Background(), getPeersRequest, address) 54 | fmt.Println("GetPeers", getPeersResponse, err) 55 | 56 | // announce peer 57 | announcePeerRequest = dht.NewAnnouncePeerRequest() 58 | announcePeerRequest.InfoHash = dht.GenNodeId() // 随机仿造一个20字节的info_hash 59 | announcePeerRequest.Token = dht.GenNodeId() // 随机伪造一个token 60 | announcePeerResponse, err = krpc.AnnouncePeer(context.Background(), announcePeerRequest, address) 61 | fmt.Println("AnnouncePeer", announcePeerResponse, err) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /test/test_routing.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/owenliang/dht" 5 | "fmt" 6 | ) 7 | 8 | func main() { 9 | rt := dht.GetRoutingTable() 10 | size := 0 11 | for { 12 | rt.InsertNode(&dht.CompactNode{"", dht.GenNodeId()}) 13 | if size != rt.Size() { 14 | size = rt.Size() 15 | fmt.Println(size) 16 | fmt.Println(rt.ClosestNodes(dht.MyNodeId())) 17 | } 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /test/test_token.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/owenliang/dht" 5 | ) 6 | 7 | func main() { 8 | mgr := dht.CreateTokenManager() 9 | token := mgr.GetToken() 10 | for { 11 | if !mgr.ValidateToken(token) { 12 | break 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /token.go: -------------------------------------------------------------------------------- 1 | package dht 2 | 3 | import ( 4 | "sync" 5 | "math/rand" 6 | "time" 7 | ) 8 | 9 | type TokenManager struct { 10 | mutex sync.Mutex 11 | tokens [2]string 12 | } 13 | 14 | func genToken() string { 15 | randBytes := make([]byte, 160) 16 | for { 17 | if _, err := rand.Read(randBytes); err == nil { 18 | return string(randBytes) 19 | } 20 | } 21 | } 22 | 23 | func (mgr *TokenManager)refreshToken() { 24 | for { 25 | time.Sleep(time.Duration(5) * time.Minute) 26 | 27 | mgr.mutex.Lock() 28 | mgr.tokens[0] = mgr.tokens[1] 29 | mgr.tokens[1] = genToken() 30 | mgr.mutex.Unlock() 31 | } 32 | } 33 | 34 | var myTokenMgr *TokenManager 35 | var initTokenMgrOnce sync.Once 36 | 37 | // 5分钟刷新一次token, 生成的token10分钟内有效 38 | func GetTokenManager() *TokenManager { 39 | initTokenMgrOnce.Do(func() { 40 | myTokenMgr = &TokenManager{} 41 | myTokenMgr.tokens[0] = genToken() 42 | myTokenMgr.tokens[1] = genToken() 43 | go myTokenMgr.refreshToken() 44 | }) 45 | return myTokenMgr 46 | } 47 | 48 | func (mgr *TokenManager) ValidateToken(token string) bool { 49 | mgr.mutex.Lock() 50 | defer mgr.mutex.Unlock() 51 | 52 | for _, myToken := range mgr.tokens { 53 | if token == myToken { 54 | return true 55 | } 56 | } 57 | return false 58 | } 59 | 60 | func (mgr *TokenManager) GetToken() string { 61 | mgr.mutex.Lock() 62 | defer mgr.mutex.Unlock() 63 | 64 | return mgr.tokens[1] 65 | } --------------------------------------------------------------------------------