├── .github └── workflows │ └── main.yml ├── .gitignore ├── .golangci.yml ├── LICENSE ├── Makefile ├── README.md ├── bamboozle_unit_test.go ├── banman ├── codec.go ├── codec_test.go ├── reason.go ├── store.go ├── store_test.go ├── util.go └── util_test.go ├── batch_spend_reporter.go ├── blockmanager.go ├── blockmanager_test.go ├── blockntfns ├── log.go ├── manager.go ├── manager_test.go └── notification.go ├── cache ├── cache.go ├── go.mod ├── go.sum └── lru │ ├── list.go │ ├── list_test.go │ ├── lru.go │ ├── lru_test.go │ └── sync_map.go ├── cache_test.go ├── cacheable_block.go ├── cacheable_filter.go ├── chainsync ├── filtercontrol.go └── filtercontrol_test.go ├── chanutils ├── batch_writer.go ├── batch_writer_test.go ├── log.go └── queue.go ├── errors.go ├── filterdb ├── db.go ├── db_test.go └── log.go ├── go.mod ├── go.sum ├── headerfs ├── file.go ├── index.go ├── index_test.go ├── store.go ├── store_test.go ├── truncate.go └── truncate_windows.go ├── headerlist ├── bounded_header_list.go ├── bounded_header_list_test.go └── header_list.go ├── headerlogger.go ├── log.go ├── mock_store.go ├── neutrino.go ├── notifications.go ├── pushtx ├── broadcaster.go ├── broadcaster_test.go ├── error.go ├── error_test.go └── log.go ├── query.go ├── query ├── interface.go ├── log.go ├── peer_rank.go ├── peer_rank_test.go ├── worker.go ├── worker_test.go ├── workmanager.go ├── workmanager_test.go ├── workqueue.go └── workqueue_test.go ├── query_test.go ├── rescan.go ├── rescan_test.go ├── sync_test.go ├── testdata └── blocks1-256.bz2 ├── tools ├── Dockerfile ├── go.mod ├── go.sum └── tools.go ├── utxoscanner.go ├── utxoscanner_test.go ├── verification.go └── verification_test.go /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - "master" 7 | pull_request: 8 | branches: 9 | - "*" 10 | 11 | defaults: 12 | run: 13 | shell: bash 14 | 15 | env: 16 | GOCACHE: /home/runner/work/go/pkg/build 17 | GOPATH: /home/runner/work/go 18 | GO_VERSION: 1.23.6 19 | 20 | jobs: 21 | ######################## 22 | # compilation check 23 | ######################## 24 | rpc-check: 25 | name: RPC and mobile compilation check 26 | runs-on: ubuntu-latest 27 | steps: 28 | - name: git checkout 29 | uses: actions/checkout@v2 30 | 31 | - name: setup go ${{ env.GO_VERSION }} 32 | uses: actions/setup-go@v5 33 | with: 34 | go-version: '${{ env.GO_VERSION }}' 35 | 36 | - name: run check 37 | run: make build 38 | 39 | ######################## 40 | # lint code 41 | ######################## 42 | lint: 43 | name: lint code 44 | runs-on: ubuntu-latest 45 | steps: 46 | - name: git checkout 47 | uses: actions/checkout@v4 48 | with: 49 | # The same as "git fetch --unshallow" but also works when running the 50 | # action locally with "act". 51 | fetch-depth: 0 52 | 53 | - name: setup go ${{ env.GO_VERSION }} 54 | uses: actions/setup-go@v5 55 | with: 56 | go-version: '${{ env.GO_VERSION }}' 57 | 58 | - name: lint 59 | run: make lint 60 | 61 | ######################## 62 | # run unit tests 63 | ######################## 64 | unit-test: 65 | name: run unit tests 66 | runs-on: ubuntu-latest 67 | strategy: 68 | # Allow other tests in the matrix to continue if one fails. 
69 | fail-fast: false 70 | matrix: 71 | unit_type: 72 | - unit-cover 73 | - unit-race 74 | steps: 75 | - name: git checkout 76 | uses: actions/checkout@v4 77 | 78 | - name: setup go ${{ env.GO_VERSION }} 79 | uses: actions/setup-go@v5 80 | with: 81 | go-version: '${{ env.GO_VERSION }}' 82 | 83 | - name: run ${{ matrix.unit_type }} 84 | run: make ${{ matrix.unit_type }} 85 | 86 | - name: Send coverage 87 | uses: shogo82148/actions-goveralls@v1 88 | if: matrix.unit_type == 'unit-cover' 89 | with: 90 | path-to-profile: coverage.txt 91 | parallel: true 92 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.dll 4 | *.so 5 | *.dylib 6 | 7 | # Test binary, build with `go test -c` 8 | *.test 9 | 10 | # Output of the go coverage tool, specifically when used with LiteIDE 11 | *.out 12 | 13 | # Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736 14 | .glide/ 15 | 16 | # Glide vendor subdirectory 17 | vendor/ 18 | 19 | #GoLand config files 20 | .idea 21 | *.DS_Store 22 | 23 | # vim swap files 24 | *.swp 25 | 26 | # delve breakpoints 27 | breakpoints.txt 28 | 29 | # coverage output 30 | coverage.txt 31 | 32 | # go workspace 33 | go.work 34 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | run: 2 | # timeout for analysis 3 | timeout: 10m 4 | 5 | linters-settings: 6 | govet: 7 | # Don't report about shadowed variables 8 | shadow: false 9 | gofmt: 10 | # simplify code: gofmt with `-s` option, true by default 11 | simplify: true 12 | whitespace: 13 | multi-func: true 14 | multi-if: true 15 | gosec: 16 | excludes: 17 | - G115 # Integer overflow conversion. 18 | lll: 19 | # Max line length, lines longer will be reported. 20 | line-length: 80 21 | # Tab width in spaces. 22 | tab-width: 8 23 | 24 | linters: 25 | enable-all: true 26 | disable: 27 | # Global variables are used in many places throughout the code base. 28 | - gochecknoglobals 29 | 30 | # We want to allow short variable names. 31 | - varnamelen 32 | 33 | # We want to allow TODOs. 34 | - godox 35 | 36 | # We have long functions, especially in tests. Moving or renaming those would 37 | # trigger funlen problems that we may not want to solve at that time. 38 | - funlen 39 | 40 | # Disable for now as we haven't yet tuned the sensitivity to our codebase 41 | # yet. Enabling by default for example, would also force new contributors to 42 | # potentially extensively refactor code, when they want to smaller change to 43 | # land. 44 | - gocyclo 45 | - gocognit 46 | - cyclop 47 | 48 | # Instances of table driven tests that don't pre-allocate shouldn't trigger 49 | # the linter. 50 | - prealloc 51 | 52 | # Init functions are used by loggers throughout the codebase. 53 | - gochecknoinits 54 | 55 | # Causes stack overflow, see https://github.com/polyfloyd/go-errorlint/issues/19. 56 | - errorlint 57 | 58 | # New linters that need a code adjustment first. 
59 | - wrapcheck 60 | - nolintlint 61 | - paralleltest 62 | - tparallel 63 | - testpackage 64 | - gofumpt 65 | - gomoddirectives 66 | - ireturn 67 | - maintidx 68 | - nlreturn 69 | - dogsled 70 | - gci 71 | - containedctx 72 | - contextcheck 73 | - errname 74 | - err113 75 | - mnd 76 | - noctx 77 | - nestif 78 | - wsl 79 | - exhaustive 80 | - forcetypeassert 81 | - nilerr 82 | - nilnil 83 | - stylecheck 84 | - thelper 85 | - exhaustruct 86 | - intrange 87 | - inamedparam 88 | - depguard 89 | - recvcheck 90 | - perfsprint 91 | - revive 92 | 93 | issues: 94 | # Only show newly introduced problems. 95 | new-from-rev: c932ae495eeedc20f58a521d0b3e08889348b06c 96 | 97 | exclude-rules: 98 | # Exclude gosec from running for tests so that tests with weak randomness 99 | # (math/rand) will pass the linter. 100 | - path: _test\.go 101 | linters: 102 | - gosec 103 | - errcheck 104 | - dupl 105 | - staticcheck 106 | 107 | # Instances of table driven tests that don't pre-allocate shouldn't 108 | # trigger the linter. 109 | - prealloc 110 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017-2022 Lightning Labs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PKG := github.com/lightninglabs/neutrino 2 | TOOLS_DIR := tools 3 | 4 | BTCD_PKG := github.com/btcsuite/btcd 5 | LINT_PKG := github.com/golangci/golangci-lint/cmd/golangci-lint 6 | GOACC_PKG := github.com/ory/go-acc 7 | GOIMPORTS_PKG := github.com/rinchsan/gosimports/cmd/gosimports 8 | 9 | GO_BIN := ${GOPATH}/bin 10 | LINT_BIN := $(GO_BIN)/golangci-lint 11 | GOACC_BIN := $(GO_BIN)/go-acc 12 | 13 | GOBUILD := go build -v 14 | GOINSTALL := go install -v 15 | GOTEST := go test 16 | 17 | GOLIST := go list -deps $(PKG)/... | grep '$(PKG)' 18 | GOLIST_COVER := $$(go list -deps $(PKG)/... | grep '$(PKG)') 19 | GOFILES_NOVENDOR = $(shell find . -type f -name '*.go' -not -path "./vendor/*") 20 | 21 | RM := rm -f 22 | CP := cp 23 | MAKE := make 24 | XARGS := xargs -L 1 25 | DOCKER_TOOLS = docker run -v $$(pwd):/build neutrino-tools 26 | 27 | # Linting uses a lot of memory, so keep it under control by limiting the number 28 | # of workers if requested. 
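# For example, `make lint workers=4` caps golangci-lint at four concurrent
# workers by passing --concurrency=4 through the LINT_WORKERS variable below.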
29 | ifneq ($(workers),) 30 | LINT_WORKERS = --concurrency=$(workers) 31 | endif 32 | 33 | GREEN := "\\033[0;32m" 34 | NC := "\\033[0m" 35 | define print 36 | echo $(GREEN)$1$(NC) 37 | endef 38 | 39 | default: build 40 | 41 | all: build check 42 | 43 | # ============ 44 | # DEPENDENCIES 45 | # ============ 46 | 47 | btcd: 48 | @$(call print, "Installing btcd.") 49 | cd $(TOOLS_DIR); go install -trimpath -tags=tools $(BTCD_PKG) 50 | 51 | $(GOACC_BIN): 52 | @$(call print, "Fetching go-acc") 53 | cd $(TOOLS_DIR); go install -trimpath -tags=tools $(GOACC_PKG) 54 | 55 | goimports: 56 | @$(call print, "Installing goimports.") 57 | cd $(TOOLS_DIR); go install -trimpath -tags=tools $(GOIMPORTS_PKG) 58 | 59 | # ============ 60 | # INSTALLATION 61 | # ============ 62 | 63 | build: 64 | @$(call print, "Compiling neutrino.") 65 | $(GOBUILD) $(PKG)/... 66 | 67 | # ======= 68 | # TESTING 69 | # ======= 70 | 71 | check: unit 72 | 73 | unit: btcd 74 | @$(call print, "Running unit tests.") 75 | $(GOLIST) | $(XARGS) env $(GOTEST) 76 | 77 | unit-cover: btcd $(GOACC_BIN) 78 | @$(call print, "Running unit coverage tests.") 79 | $(GOACC_BIN) $(GOLIST_COVER) 80 | 81 | unit-race: btcd 82 | @$(call print, "Running unit race tests.") 83 | env CGO_ENABLED=1 GORACE="history_size=7 halt_on_errors=1" $(GOLIST) | $(XARGS) env $(GOTEST) -race 84 | 85 | # ========= 86 | # UTILITIES 87 | # ========= 88 | 89 | docker-tools: 90 | @$(call print, "Building tools docker image.") 91 | docker build -q -t neutrino-tools $(TOOLS_DIR) 92 | 93 | fmt: goimports 94 | @$(call print, "Fixing imports.") 95 | gosimports -w $(GOFILES_NOVENDOR) 96 | @$(call print, "Formatting source.") 97 | gofmt -l -w -s $(GOFILES_NOVENDOR) 98 | 99 | lint: docker-tools 100 | @$(call print, "Linting source.") 101 | $(DOCKER_TOOLS) golangci-lint run -v $(LINT_WORKERS) 102 | 103 | clean: 104 | @$(call print, "Cleaning source.$(NC)") 105 | $(RM) coverage.txt 106 | 107 | .PHONY: all \ 108 | btcd \ 109 | default \ 110 | build \ 111 | check \ 112 | unit \ 113 | unit-cover \ 114 | unit-race \ 115 | fmt \ 116 | lint \ 117 | clean 118 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neutrino: Privacy-Preserving Bitcoin Light Client 2 | 3 | [![Build Status](https://github.com/lightninglabs/neutrino/actions/workflows/main.yml/badge.svg)](https://github.com/lightninglabs/neutrino/actions/workflows/main.yml) 4 | [![Godoc](https://godoc.org/github.com/lightninglabs/neutrino?status.svg)](https://godoc.org/github.com/lightninglabs/neutrino) 5 | [![Coverage Status](https://coveralls.io/repos/github/lightninglabs/neutrino/badge.svg?branch=master)](https://coveralls.io/github/lightninglabs/neutrino?branch=master) 6 | 7 | Neutrino is a Bitcoin light client written in Go and designed with mobile 8 | Lightning Network clients in mind. It uses a 9 | [new proposal](https://lists.linuxfoundation.org/pipermail/bitcoin-dev/2017-June/014474.html) 10 | for compact block filters to minimize bandwidth and storage use on the client 11 | side, while attempting to preserve privacy and minimize processor load on full 12 | nodes serving light clients. 13 | 14 | ## Mechanism of operation 15 | The light client synchronizes only block headers and a chain of compact block 16 | filter headers specifying the correct filters for each block. Filters are loaded 17 | lazily and stored in the database upon request; blocks are loaded lazily and not 18 | saved. 
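For illustration, a minimal client setup might look like the sketch below. The
exact `Config` fields should be checked against the godoc before use; the
walletdb backend and parameters shown here are assumptions for the sake of the
example (see the Usage section that follows):

```go
package main

import (
	"time"

	"github.com/btcsuite/btcd/chaincfg"
	"github.com/btcsuite/btcwallet/walletdb"
	_ "github.com/btcsuite/btcwallet/walletdb/bdb"
	"github.com/lightninglabs/neutrino"
)

func main() {
	// Neutrino stores its block headers and filter headers in a walletdb
	// backend.
	db, err := walletdb.Create("bdb", "neutrino.db", true, 10*time.Second)
	if err != nil {
		panic(err)
	}
	defer db.Close()

	// Instantiate the chain service and start syncing.
	cs, err := neutrino.NewChainService(neutrino.Config{
		DataDir:     ".",
		Database:    db,
		ChainParams: chaincfg.TestNet3Params,
	})
	if err != nil {
		panic(err)
	}
	if err := cs.Start(); err != nil {
		panic(err)
	}
	defer func() { _ = cs.Stop() }()
}
```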
## Usage
The client is instantiated as an object using `NewChainService` and then
started. Upon start, the client sets up its database and other relevant files
and connects to the p2p network. At this point, it becomes possible to query the
client.

### Queries
There are various types of queries supported by the client. There are many ways
to access the database, for example, to get block headers by height and hash; in
addition, it's possible to get a full block from the network using
`GetBlockFromNetwork` by hash. However, the most useful methods are specifically
tailored to scan the blockchain for data relevant to a wallet or a smart
contract platform such as a [Lightning Network node like
`lnd`](https://github.com/lightningnetwork/lnd). These are described below.

#### Rescan
`Rescan` allows a wallet to scan a chain for specific TXIDs, outputs, and
addresses. A start and end block may be specified along with other options. If
no end block is specified, the rescan continues until stopped. If no start block
is specified, the rescan begins with the latest known block. While a rescan
runs, it notifies the client of each connected and disconnected block; the
notifications follow the
[btcjson](https://github.com/btcsuite/btcd/blob/master/btcjson/chainsvrwsntfns.go)
format with the option to use any of the relevant notifications. It's important
to note that "recvtx" and "redeemingtx" notifications are only sent when a
transaction is confirmed, not when it enters the mempool; the client does not
currently support accepting 0-confirmation transactions.

#### GetUtxo
`GetUtxo` allows a wallet or smart contract platform to check that a UTXO exists
on the blockchain and has not been spent. It is **highly recommended** to
specify a start block; otherwise, in the event that the UTXO doesn't exist on
the blockchain, the client will download all the filters back to block 1
searching for it. The client scans from the tip of the chain backwards, stopping
when it finds the UTXO having been either spent or created; if it finds neither,
it keeps scanning backwards until it hits the specified start block or, if a
start block isn't specified, the first block in the blockchain. It returns a
`SpendReport` containing either a `TxOut` including the `PkScript` required to
spend the output, or containing information about the spending transaction,
spending input, and block height in which the spending transaction was seen.

### Stopping the client
Calling `Stop` on the `ChainService` client allows the user to stop the client;
the method doesn't return until the `ChainService` is cleanly shut down.
--------------------------------------------------------------------------------
/banman/codec.go:
--------------------------------------------------------------------------------
package banman

import (
	"io"
	"net"
)

// ipType represents the different types of IP addresses supported by the
// BanStore interface.
type ipType = byte

const (
	// ipv4 represents an IP address of type IPv4.
	ipv4 ipType = 0

	// ipv6 represents an IP address of type IPv6.
	ipv6 ipType = 1
)

// encodeIPNet serializes the IP network into the given writer.
func encodeIPNet(w io.Writer, ipNet *net.IPNet) error {
	// Determine the appropriate IP type for the IP address contained in the
	// network.
	var (
		ip     []byte
		ipType ipType
	)
	switch {
	case ipNet.IP.To4() != nil:
		ip = ipNet.IP.To4()
		ipType = ipv4
	case ipNet.IP.To16() != nil:
		ip = ipNet.IP.To16()
		ipType = ipv6
	default:
		return ErrUnsupportedIP
	}

	// Write the IP type first in order to properly identify it when
	// deserializing it, followed by the IP itself and its mask.
	if _, err := w.Write([]byte{ipType}); err != nil {
		return err
	}
	if _, err := w.Write(ip); err != nil {
		return err
	}
	if _, err := w.Write([]byte(ipNet.Mask)); err != nil {
		return err
	}

	return nil
}

// decodeIPNet deserializes an IP network from the given reader.
func decodeIPNet(r io.Reader) (*net.IPNet, error) {
	// Read the IP address type and determine whether it is supported.
	var ipType [1]byte
	if _, err := r.Read(ipType[:]); err != nil {
		return nil, err
	}

	var ipLen int
	switch ipType[0] {
	case ipv4:
		ipLen = net.IPv4len
	case ipv6:
		ipLen = net.IPv6len
	default:
		return nil, ErrUnsupportedIP
	}

	// Once we have the type and its corresponding length, attempt to read
	// it and its mask.
	ip := make([]byte, ipLen)
	if _, err := r.Read(ip); err != nil {
		return nil, err
	}
	mask := make([]byte, ipLen)
	if _, err := r.Read(mask); err != nil {
		return nil, err
	}
	return &net.IPNet{IP: ip, Mask: mask}, nil
}
--------------------------------------------------------------------------------
/banman/codec_test.go:
--------------------------------------------------------------------------------
package banman

import (
	"bytes"
	"net"
	"reflect"
	"testing"
)

// TestIPNetSerialization ensures that we can serialize different supported IP
// networks and deserialize them into their expected result.
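//
// A successful round-trip through these helpers looks like the sketch below;
// the encoded form is a single type byte, followed by the IP bytes and then
// the mask bytes:
//
//	var b bytes.Buffer
//	if err := encodeIPNet(&b, ipNet); err != nil { /* handle */ }
//	decoded, err := decodeIPNet(&b)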
func TestIPNetSerialization(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		name  string
		ipNet *net.IPNet
		err   error
	}{
		{
			name: "ipv4 without mask",
			ipNet: &net.IPNet{
				IP:   net.ParseIP("172.217.6.46"),
				Mask: net.IPv4Mask(0x00, 0x00, 0x00, 0x00),
			},
		},
		{
			name: "ipv4 with default mask",
			ipNet: &net.IPNet{
				IP:   net.ParseIP("172.217.6.46"),
				Mask: defaultIPv4Mask,
			},
		},
		{
			name: "ipv4 with non-default mask",
			ipNet: &net.IPNet{
				IP:   net.ParseIP("172.217.6.46"),
				Mask: net.IPv4Mask(0xff, 0xff, 0x00, 0x00),
			},
		},
		{
			name: "ipv6 without mask",
			ipNet: &net.IPNet{
				IP:   net.ParseIP("2001:db8:a0b:12f0::1"),
				Mask: net.IPMask(make([]byte, net.IPv6len)),
			},
		},
		{
			name: "ipv6 with default mask",
			ipNet: &net.IPNet{
				IP:   net.ParseIP("2001:db8:a0b:12f0::1"),
				Mask: defaultIPv6Mask,
			},
		},
		{
			name: "ipv6 with non-default mask",
			ipNet: &net.IPNet{
				IP: net.ParseIP("2001:db8:a0b:12f0::1"),
				Mask: net.IPMask([]byte{
					0xff, 0xff, 0x00, 0x00, 0x00, 0xff,
					0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
					0x00, 0x00, 0x00, 0x00,
				}),
			},
		},
	}

	for _, testCase := range testCases {
		success := t.Run(testCase.name, func(t *testing.T) {
			// Serialize the IP network and deserialize it back.
			// We'll do this to ensure we are properly serializing
			// and deserializing them.
			var b bytes.Buffer
			err := encodeIPNet(&b, testCase.ipNet)
			if testCase.err != nil && err != testCase.err {
				t.Fatalf("encoding IP network %v expected "+
					"error \"%v\", got \"%v\"",
					testCase.ipNet, testCase.err, err)
			}
			ipNet, err := decodeIPNet(&b)
			if testCase.err != nil && err != testCase.err {
				t.Fatalf("decoding IP network %v expected "+
					"error \"%v\", got \"%v\"",
					testCase.ipNet, testCase.err, err)
			}

			// If the test did not expect a result, i.e., an invalid
			// IP network, then we can exit now.
			if testCase.err != nil {
				return
			}

			// Otherwise, ensure the result is what we expect.
			if !ipNet.IP.Equal(testCase.ipNet.IP) {
				t.Fatalf("expected IP %v, got %v",
					testCase.ipNet.IP, ipNet.IP)
			}
			if !reflect.DeepEqual(ipNet.Mask, testCase.ipNet.Mask) {
				t.Fatalf("expected mask %#v, got %#v",
					testCase.ipNet.Mask, ipNet.Mask)
			}
		})
		if !success {
			return
		}
	}
}
--------------------------------------------------------------------------------
/banman/reason.go:
--------------------------------------------------------------------------------
package banman

// Reason includes the different possible reasons which caused us to ban a peer.
type Reason uint8

// We prevent using `iota` to ensure reordering the declarations does not
// change their values, since these are serialized within the database.
const (
	// ExceededBanThreshold signals that a peer exceeded its ban threshold.
	ExceededBanThreshold Reason = 1

	// NoCompactFilters signals that a peer was unable to serve us compact
	// filters.
	NoCompactFilters Reason = 2

	// InvalidFilterHeader signals that a peer served us an invalid filter
	// header.
	InvalidFilterHeader Reason = 3

	// InvalidFilterHeaderCheckpoint signals that a peer served us an
	// invalid filter header checkpoint.
22 | InvalidFilterHeaderCheckpoint Reason = 4 23 | 24 | // InvalidBlock signals that a peer served us a bad block. 25 | InvalidBlock Reason = 5 26 | ) 27 | 28 | // String returns a human-readable description for the reason a peer was banned. 29 | func (r Reason) String() string { 30 | switch r { 31 | case ExceededBanThreshold: 32 | return "peer exceeded ban threshold" 33 | 34 | case NoCompactFilters: 35 | return "peer was unable to serve compact filters" 36 | 37 | case InvalidFilterHeader: 38 | return "peer served invalid filter header" 39 | 40 | case InvalidFilterHeaderCheckpoint: 41 | return "peer served invalid filter header checkpoint" 42 | 43 | case InvalidBlock: 44 | return "peer served an invalid block" 45 | 46 | default: 47 | return "unknown reason" 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /banman/store.go: -------------------------------------------------------------------------------- 1 | package banman 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "errors" 7 | "fmt" 8 | "net" 9 | "time" 10 | 11 | "github.com/btcsuite/btcwallet/walletdb" 12 | ) 13 | 14 | var ( 15 | // byteOrder is the preferred byte order in which we should write things 16 | // to disk. 17 | byteOrder = binary.BigEndian 18 | 19 | // banStoreBucket is the top level bucket of the Store that will contain 20 | // all relevant sub-buckets. 21 | banStoreBucket = []byte("ban-store") 22 | 23 | // banBucket is the main index in which we keep track of IP networks and 24 | // their absolute expiration time. 25 | // 26 | // The key is the IP network host and the value is the absolute 27 | // expiration time. 28 | banBucket = []byte("ban-index") 29 | 30 | // reasonBucket is an index in which we keep track of why an IP network 31 | // was banned. 32 | // 33 | // The key is the IP network and the value is the Reason. 34 | reasonBucket = []byte("reason-index") 35 | 36 | // ErrCorruptedStore is an error returned when we attempt to locate any 37 | // of the ban-related buckets in the database but are unable to. 38 | ErrCorruptedStore = errors.New("corrupted ban store") 39 | 40 | // ErrUnsupportedIP is an error returned when we attempt to parse an 41 | // unsupported IP address type. 42 | ErrUnsupportedIP = errors.New("unsupported IP type") 43 | ) 44 | 45 | // Status gathers all of the details regarding an IP network's ban status. 46 | type Status struct { 47 | // Banned determines whether the IP network is currently banned. 48 | Banned bool 49 | 50 | // Reason is the reason for which the IP network was banned. 51 | Reason Reason 52 | 53 | // Expiration is the absolute time in which the ban will expire. 54 | Expiration time.Time 55 | } 56 | 57 | // Store is the store responsible for maintaining records of banned IP networks. 58 | // It uses IP networks, rather than single IP addresses, in order to coalesce 59 | // multiple IP addresses that are likely to be correlated. 60 | type Store interface { 61 | // BanIPNet creates a ban record for the IP network within the store for 62 | // the given duration. A reason can also be provided to note why the IP 63 | // network is being banned. The record will exist until a call to Status 64 | // is made after the ban expiration. 65 | BanIPNet(*net.IPNet, Reason, time.Duration) error 66 | 67 | // Status returns the ban status for a given IP network. 68 | Status(*net.IPNet) (Status, error) 69 | 70 | // UnbanIPNet removes the ban imposed on the specified peer. 
71 | UnbanIPNet(ipNet *net.IPNet) error 72 | } 73 | 74 | // NewStore returns a Store backed by a database. 75 | func NewStore(db walletdb.DB) (Store, error) { 76 | return newBanStore(db) 77 | } 78 | 79 | // banStore is a concrete implementation of the Store interface backed by a 80 | // database. 81 | type banStore struct { 82 | db walletdb.DB 83 | } 84 | 85 | // A compile-time constraint to ensure banStore satisfies the Store interface. 86 | var _ Store = (*banStore)(nil) 87 | 88 | // newBanStore creates a concrete implementation of the Store interface backed 89 | // by a database. 90 | func newBanStore(db walletdb.DB) (*banStore, error) { 91 | s := &banStore{db: db} 92 | 93 | // We'll ensure the expected buckets are created upon initialization. 94 | err := walletdb.Update(db, func(tx walletdb.ReadWriteTx) error { 95 | banStore, err := tx.CreateTopLevelBucket(banStoreBucket) 96 | if err != nil { 97 | return err 98 | } 99 | _, err = banStore.CreateBucketIfNotExists(banBucket) 100 | if err != nil { 101 | return err 102 | } 103 | _, err = banStore.CreateBucketIfNotExists(reasonBucket) 104 | return err 105 | }) 106 | if err != nil && err != walletdb.ErrBucketExists { 107 | return nil, err 108 | } 109 | 110 | return s, nil 111 | } 112 | 113 | // BanIPNet creates a ban record for the IP network within the store for the 114 | // given duration. A reason can also be provided to note why the IP network is 115 | // being banned. The record will exist until a call to Status is made after the 116 | // ban expiration. 117 | func (s *banStore) BanIPNet(ipNet *net.IPNet, reason Reason, duration time.Duration) error { 118 | return walletdb.Update(s.db, func(tx walletdb.ReadWriteTx) error { 119 | banStore := tx.ReadWriteBucket(banStoreBucket) 120 | if banStore == nil { 121 | return ErrCorruptedStore 122 | } 123 | banIndex := banStore.NestedReadWriteBucket(banBucket) 124 | if banIndex == nil { 125 | return ErrCorruptedStore 126 | } 127 | reasonIndex := banStore.NestedReadWriteBucket(reasonBucket) 128 | if reasonIndex == nil { 129 | return ErrCorruptedStore 130 | } 131 | 132 | var ipNetBuf bytes.Buffer 133 | if err := encodeIPNet(&ipNetBuf, ipNet); err != nil { 134 | return fmt.Errorf("unable to encode %v: %v", ipNet, err) 135 | } 136 | k := ipNetBuf.Bytes() 137 | 138 | return addBannedIPNet(banIndex, reasonIndex, k, reason, duration) 139 | }) 140 | } 141 | 142 | // UnbanIPNet removes a ban record for the IP network within the store. 143 | func (s *banStore) UnbanIPNet(ipNet *net.IPNet) error { 144 | err := walletdb.Update(s.db, func(tx walletdb.ReadWriteTx) error { 145 | banStore := tx.ReadWriteBucket(banStoreBucket) 146 | if banStore == nil { 147 | return ErrCorruptedStore 148 | } 149 | 150 | banIndex := banStore.NestedReadWriteBucket(banBucket) 151 | if banIndex == nil { 152 | return ErrCorruptedStore 153 | } 154 | 155 | reasonIndex := banStore.NestedReadWriteBucket(reasonBucket) 156 | if reasonIndex == nil { 157 | return ErrCorruptedStore 158 | } 159 | 160 | var ipNetBuf bytes.Buffer 161 | if err := encodeIPNet(&ipNetBuf, ipNet); err != nil { 162 | return fmt.Errorf("unable to encode %v: %v", ipNet, 163 | err) 164 | } 165 | 166 | k := ipNetBuf.Bytes() 167 | 168 | return removeBannedIPNet(banIndex, reasonIndex, k) 169 | }) 170 | 171 | return err 172 | } 173 | 174 | // addBannedIPNet adds an entry to the ban store for the given IP network. 
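//
// The records written below have a simple layout (shown as a sketch):
//
//	banIndex:    serialized IP network => 8-byte big-endian unix expiration
//	reasonIndex: serialized IP network => 1-byte ban Reason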
175 | func addBannedIPNet(banIndex, reasonIndex walletdb.ReadWriteBucket, 176 | ipNetKey []byte, reason Reason, duration time.Duration) error { 177 | 178 | var v [8]byte 179 | banExpiration := time.Now().Add(duration) 180 | byteOrder.PutUint64(v[:], uint64(banExpiration.Unix())) 181 | 182 | if err := banIndex.Put(ipNetKey, v[:]); err != nil { 183 | return err 184 | } 185 | return reasonIndex.Put(ipNetKey, []byte{byte(reason)}) 186 | } 187 | 188 | // Status returns the ban status for a given IP network. 189 | func (s *banStore) Status(ipNet *net.IPNet) (Status, error) { 190 | var banStatus Status 191 | err := walletdb.Update(s.db, func(tx walletdb.ReadWriteTx) error { 192 | banStore := tx.ReadWriteBucket(banStoreBucket) 193 | if banStore == nil { 194 | return ErrCorruptedStore 195 | } 196 | banIndex := banStore.NestedReadWriteBucket(banBucket) 197 | if banIndex == nil { 198 | return ErrCorruptedStore 199 | } 200 | reasonIndex := banStore.NestedReadWriteBucket(reasonBucket) 201 | if reasonIndex == nil { 202 | return ErrCorruptedStore 203 | } 204 | 205 | var ipNetBuf bytes.Buffer 206 | if err := encodeIPNet(&ipNetBuf, ipNet); err != nil { 207 | return fmt.Errorf("unable to encode %v: %v", ipNet, err) 208 | } 209 | k := ipNetBuf.Bytes() 210 | 211 | status := fetchStatus(banIndex, reasonIndex, k) 212 | 213 | // If the IP network's ban duration has expired, we can remove 214 | // its entry from the store. 215 | if !time.Now().Before(status.Expiration) { 216 | return removeBannedIPNet(banIndex, reasonIndex, k) 217 | } 218 | 219 | banStatus = status 220 | return nil 221 | }) 222 | if err != nil { 223 | return Status{}, err 224 | } 225 | 226 | return banStatus, nil 227 | } 228 | 229 | // fetchStatus retrieves the ban status of the given IP network. 230 | func fetchStatus(banIndex, reasonIndex walletdb.ReadWriteBucket, 231 | ipNetKey []byte) Status { 232 | 233 | v := banIndex.Get(ipNetKey) 234 | if v == nil { 235 | return Status{} 236 | } 237 | reason := Reason(reasonIndex.Get(ipNetKey)[0]) 238 | banExpiration := time.Unix(int64(byteOrder.Uint64(v)), 0) 239 | 240 | return Status{ 241 | Banned: true, 242 | Reason: reason, 243 | Expiration: banExpiration, 244 | } 245 | } 246 | 247 | // removeBannedIPNet removes all references to a banned IP network within the 248 | // ban store. 249 | func removeBannedIPNet(banIndex, reasonIndex walletdb.ReadWriteBucket, 250 | ipNetKey []byte) error { 251 | 252 | if err := banIndex.Delete(ipNetKey); err != nil { 253 | return err 254 | } 255 | return reasonIndex.Delete(ipNetKey) 256 | } 257 | -------------------------------------------------------------------------------- /banman/store_test.go: -------------------------------------------------------------------------------- 1 | package banman_test 2 | 3 | import ( 4 | "io/ioutil" 5 | "net" 6 | "os" 7 | "path/filepath" 8 | "testing" 9 | "time" 10 | 11 | "github.com/btcsuite/btcwallet/walletdb" 12 | _ "github.com/btcsuite/btcwallet/walletdb/bdb" 13 | "github.com/lightninglabs/neutrino/banman" 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | // createTestBanStore creates a test Store backed by a boltdb instance. 
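// The returned cleanup closure closes the database and removes the temporary
// directory, so callers should run it (typically via defer) when done with
// the store.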
func createTestBanStore(t *testing.T) (banman.Store, func()) {
	t.Helper()

	dbDir, err := ioutil.TempDir("", "")
	if err != nil {
		t.Fatalf("unable to create db dir: %v", err)
	}
	dbPath := filepath.Join(dbDir, "test.db")

	db, err := walletdb.Create("bdb", dbPath, true, time.Second*10)
	if err != nil {
		os.RemoveAll(dbDir)
		t.Fatalf("unable to create db: %v", err)
	}

	cleanUp := func() {
		db.Close()
		os.RemoveAll(dbDir)
	}

	banStore, err := banman.NewStore(db)
	if err != nil {
		cleanUp()
		t.Fatalf("unable to create ban store: %v", err)
	}

	return banStore, cleanUp
}

// TestBanStore ensures that the BanStore's state correctly reflects the
// BanStatus of IP networks.
func TestBanStore(t *testing.T) {
	t.Parallel()

	// We'll start by creating our test BanStore backed by a boltdb
	// instance.
	banStore, cleanUp := createTestBanStore(t)
	defer cleanUp()

	// checkBanStore is a helper closure to ensure the IP network's ban
	// status is correctly reflected within the BanStore.
	checkBanStore := func(ipNet *net.IPNet, banned bool,
		reason banman.Reason, duration time.Duration) {

		t.Helper()

		banStatus, err := banStore.Status(ipNet)
		if err != nil {
			t.Fatalf("unable to determine %v's ban status: %v",
				ipNet, err)
		}
		if banned && !banStatus.Banned {
			t.Fatalf("expected %v to be banned", ipNet)
		}
		if !banned && banStatus.Banned {
			t.Fatalf("expected %v to not be banned", ipNet)
		}

		// If the IP network isn't banned, the remaining fields are not
		// relevant, so we can exit early.
		if !banned {
			return
		}

		if banStatus.Reason != reason {
			t.Fatalf("expected ban reason \"%v\", got \"%v\"",
				reason, banStatus.Reason)
		}
		banDuration := time.Until(banStatus.Expiration)
		if banDuration > duration {
			t.Fatalf("expected ban duration to be within %v, got %v",
				duration, banDuration)
		}
	}

	// We'll create two IP networks, the first banned for an hour and the
	// second for a second.
	addr1 := "127.0.0.1:8333"
	ipNet1, err := banman.ParseIPNet(addr1, nil)
	if err != nil {
		t.Fatalf("unable to parse IP network from %v: %v", addr1, err)
	}
	err = banStore.BanIPNet(ipNet1, banman.NoCompactFilters, time.Hour)
	if err != nil {
		t.Fatalf("unable to ban IP network: %v", err)
	}
	addr2 := "192.168.1.1:8333"
	ipNet2, err := banman.ParseIPNet(addr2, nil)
	if err != nil {
		t.Fatalf("unable to parse IP network from %v: %v", addr2, err)
	}
	err = banStore.BanIPNet(ipNet2, banman.ExceededBanThreshold, time.Second)
	if err != nil {
		t.Fatalf("unable to ban IP network: %v", err)
	}

	// Both IP networks should be found within the BanStore with their
	// expected reason since their ban has yet to expire.
	checkBanStore(ipNet1, true, banman.NoCompactFilters, time.Hour)
	checkBanStore(ipNet2, true, banman.ExceededBanThreshold, time.Second)

	// Wait long enough for the second IP network's ban to expire.
	<-time.After(time.Second)

	// We should only find the first IP network within the BanStore.
	checkBanStore(ipNet1, true, banman.NoCompactFilters, time.Hour)
	checkBanStore(ipNet2, false, 0, 0)

	// We'll query for the second IP network again as it should now be
	// unknown to the BanStore. We should expect not to find anything
	// regarding it.
126 | checkBanStore(ipNet2, false, 0, 0) 127 | 128 | // Test UnbanIPNet. 129 | require.NoError(t, banStore.UnbanIPNet(ipNet1)) 130 | 131 | // We would now check that ipNet1 is indeed unbanned. 132 | checkBanStore(ipNet1, false, 0, 0) 133 | } 134 | -------------------------------------------------------------------------------- /banman/util.go: -------------------------------------------------------------------------------- 1 | package banman 2 | 3 | import ( 4 | "net" 5 | ) 6 | 7 | var ( 8 | // defaultIPv4Mask is the default IPv4 mask used when parsing IP 9 | // networks from an address. This ensures that the IP network only 10 | // contains *one* IP address -- the one specified. 11 | defaultIPv4Mask = net.CIDRMask(32, 32) 12 | 13 | // defaultIPv6Mask is the default IPv6 mask used when parsing IP 14 | // networks from an address. This ensures that the IP network only 15 | // contains *one* IP address -- the one specified. 16 | defaultIPv6Mask = net.CIDRMask(128, 128) 17 | ) 18 | 19 | // ParseIPNet parses the IP network that contains the given address. An optional 20 | // mask can be provided, to expand the scope of the IP network, otherwise the 21 | // IP's default is used. 22 | // 23 | // NOTE: This assumes that the address has already been resolved. 24 | func ParseIPNet(addr string, mask net.IPMask) (*net.IPNet, error) { 25 | // If the address includes a port, we'll remove it. 26 | host, _, err := net.SplitHostPort(addr) 27 | if err != nil { 28 | // Address doesn't include a port. 29 | host = addr 30 | } 31 | 32 | // Parse the IP from the host to ensure it is supported. 33 | ip := net.ParseIP(host) 34 | switch { 35 | case ip.To4() != nil: 36 | if mask == nil { 37 | mask = defaultIPv4Mask 38 | } 39 | case ip.To16() != nil: 40 | if mask == nil { 41 | mask = defaultIPv6Mask 42 | } 43 | default: 44 | return nil, ErrUnsupportedIP 45 | } 46 | 47 | return &net.IPNet{IP: ip.Mask(mask), Mask: mask}, nil 48 | } 49 | -------------------------------------------------------------------------------- /banman/util_test.go: -------------------------------------------------------------------------------- 1 | package banman 2 | 3 | import ( 4 | "net" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | // TestParseIPNet ensures that we can parse different combinations of 10 | // IPs/addresses and masks. 11 | func TestParseIPNet(t *testing.T) { 12 | t.Parallel() 13 | 14 | testCases := []struct { 15 | name string 16 | addr string 17 | mask net.IPMask 18 | result *net.IPNet 19 | }{ 20 | { 21 | name: "ipv4 with default mask", 22 | addr: "192.168.1.1", 23 | mask: nil, 24 | result: &net.IPNet{ 25 | IP: net.ParseIP("192.168.1.1"), 26 | Mask: defaultIPv4Mask, 27 | }, 28 | }, 29 | { 30 | name: "ipv4 with port and non-default mask", 31 | addr: "192.168.1.1:80", 32 | mask: net.CIDRMask(16, 32), 33 | result: &net.IPNet{ 34 | IP: net.ParseIP("192.168.0.0"), 35 | Mask: net.CIDRMask(16, 32), 36 | }, 37 | }, 38 | { 39 | name: "ipv6 with port and default mask", 40 | addr: "[2001:db8:a0b:12f0::1]:80", 41 | mask: nil, 42 | result: &net.IPNet{ 43 | IP: net.ParseIP("2001:db8:a0b:12f0::1"), 44 | Mask: defaultIPv6Mask, 45 | }, 46 | }, 47 | { 48 | name: "ipv6 with non-default mask", 49 | addr: "2001:db8:a0b:12f0::1", 50 | mask: net.CIDRMask(32, 128), 51 | result: &net.IPNet{ 52 | IP: net.ParseIP("2001:db8::"), 53 | Mask: net.CIDRMask(32, 128), 54 | }, 55 | }, 56 | } 57 | 58 | for _, testCase := range testCases { 59 | success := t.Run(testCase.name, func(t *testing.T) { 60 | // Parse the IP network from each test's address and 61 | // mask. 
62 | ipNet, err := ParseIPNet(testCase.addr, testCase.mask) 63 | if testCase.result != nil && err != nil { 64 | t.Fatalf("unable to parse IP network for "+ 65 | "addr=%v and mask=%v: %v", 66 | testCase.addr, testCase.mask, err) 67 | } 68 | 69 | // If the test did not expect a result, i.e., an invalid 70 | // IP network, then we can exit now. 71 | if testCase.result == nil { 72 | return 73 | } 74 | 75 | // Otherwise, ensure the result is what we expect. 76 | if !ipNet.IP.Equal(testCase.result.IP) { 77 | t.Fatalf("expected IP %v, got %v", 78 | testCase.result.IP, ipNet.IP) 79 | } 80 | if !reflect.DeepEqual(ipNet.Mask, testCase.result.Mask) { 81 | t.Fatalf("expected mask %#v, got %#v", 82 | testCase.result.Mask, ipNet.Mask) 83 | } 84 | }) 85 | if !success { 86 | return 87 | } 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /batch_spend_reporter.go: -------------------------------------------------------------------------------- 1 | package neutrino 2 | 3 | import ( 4 | "github.com/btcsuite/btcd/chaincfg/chainhash" 5 | "github.com/btcsuite/btcd/wire" 6 | ) 7 | 8 | // batchSpendReporter orchestrates the delivery of spend reports to 9 | // GetUtxoRequests processed by the UtxoScanner. The reporter expects a sequence 10 | // of blocks consisting of those containing a UTXO to watch, or any whose 11 | // filter generates a match using current filterEntries. This instance supports 12 | // multiple requests for the same outpoint. 13 | type batchSpendReporter struct { 14 | // requests maps an outpoint to list of GetUtxoRequests waiting for that 15 | // UTXO's spend report. 16 | requests map[wire.OutPoint][]*GetUtxoRequest 17 | 18 | // initialTxns contains a map from an outpoint to the "unspent" version 19 | // of it's spend report. This value is populated by fetching the output 20 | // from the block in the request's start height. This spend report will 21 | // be returned in the case that the output remained unspent for the 22 | // duration of the scan. 23 | initialTxns map[wire.OutPoint]*SpendReport 24 | 25 | // outpoints caches the filter entry for each outpoint, conserving 26 | // allocations when reconstructing the current filterEntries. 27 | outpoints map[wire.OutPoint][]byte 28 | 29 | // filterEntries holds the current set of watched outpoint, and is 30 | // applied to cfilters to gauge whether we should download the block. 31 | // 32 | // NOTE: This watchlist is updated during each call to ProcessBlock. 33 | filterEntries [][]byte 34 | } 35 | 36 | // newBatchSpendReporter instantiates a fresh batchSpendReporter. 37 | func newBatchSpendReporter() *batchSpendReporter { 38 | return &batchSpendReporter{ 39 | requests: make(map[wire.OutPoint][]*GetUtxoRequest), 40 | initialTxns: make(map[wire.OutPoint]*SpendReport), 41 | outpoints: make(map[wire.OutPoint][]byte), 42 | } 43 | } 44 | 45 | // NotifyProgress notifies all requests with the last processed height. 46 | func (b *batchSpendReporter) NotifyProgress(blockHeight uint32) { 47 | for _, requests := range b.requests { 48 | for _, r := range requests { 49 | if r.onProgress != nil { 50 | r.onProgress(blockHeight) 51 | } 52 | } 53 | } 54 | } 55 | 56 | // NotifyUnspentAndUnfound iterates through any requests for which no spends 57 | // were detected. If we were able to find the initial output, this will be 58 | // delivered signaling that no spend was detected. If the original output could 59 | // not be found, a nil spend report is returned. 
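//
// In short, for each outpoint still tracked at the end of the batch:
//
//	initial output found and never spent => the cached "unspent" SpendReport
//	initial output never found           => a nil SpendReport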
60 | func (b *batchSpendReporter) NotifyUnspentAndUnfound() { 61 | log.Debugf("Finished batch, %d unspent outpoints", len(b.requests)) 62 | 63 | for outpoint, requests := range b.requests { 64 | op := outpoint 65 | 66 | // A nil SpendReport indicates the output was not found. 67 | tx, ok := b.initialTxns[outpoint] 68 | if !ok { 69 | log.Warnf("Unknown initial txn for getuxo request %v", 70 | outpoint) 71 | } 72 | 73 | b.notifyRequests(&op, requests, tx, nil) 74 | } 75 | } 76 | 77 | // FailRemaining will return an error to all remaining requests in the event we 78 | // experience a critical rescan error. The error is threaded through to allow 79 | // the syntax: 80 | // 81 | // return reporter.FailRemaining(err) 82 | func (b *batchSpendReporter) FailRemaining(err error) error { 83 | for outpoint, requests := range b.requests { 84 | op := outpoint 85 | b.notifyRequests(&op, requests, nil, err) 86 | } 87 | return err 88 | } 89 | 90 | // notifyRequests delivers the same final response to the given requests, and 91 | // cleans up any remaining state for the outpoint. 92 | // 93 | // NOTE: AT MOST ONE of `report` or `err` may be non-nil. 94 | func (b *batchSpendReporter) notifyRequests( 95 | outpoint *wire.OutPoint, 96 | requests []*GetUtxoRequest, 97 | report *SpendReport, 98 | err error) { 99 | 100 | delete(b.requests, *outpoint) 101 | delete(b.initialTxns, *outpoint) 102 | delete(b.outpoints, *outpoint) 103 | 104 | for _, request := range requests { 105 | request.deliver(report, err) 106 | } 107 | } 108 | 109 | // ProcessBlock accepts a block, block height, and any new requests whose start 110 | // height matches the provided height. If a non-zero number of new requests are 111 | // presented, the block will first be checked for the initial outputs from which 112 | // spends may occur. Afterwards, any spends detected in the block are 113 | // immediately dispatched, and the watchlist updated in preparation of filtering 114 | // the next block. 115 | func (b *batchSpendReporter) ProcessBlock(blk *wire.MsgBlock, 116 | newReqs []*GetUtxoRequest, height uint32) { 117 | 118 | // If any requests want the UTXOs at this height, scan the block to find 119 | // the original outputs that might be spent from. 120 | if len(newReqs) > 0 { 121 | b.addNewRequests(newReqs) 122 | b.findInitialTransactions(blk, newReqs, height) 123 | } 124 | 125 | // Next, filter the block for any spends using the current set of 126 | // watched outpoints. This will include any new requests added above. 127 | spends := b.notifySpends(blk, height) 128 | 129 | // Finally, rebuild filter entries from cached entries remaining in 130 | // outpoints map. This will provide an updated watchlist used to scan 131 | // the subsequent filters. 132 | rebuildWatchlist := len(newReqs) > 0 || len(spends) > 0 133 | if rebuildWatchlist { 134 | b.filterEntries = b.filterEntries[:0] 135 | for _, entry := range b.outpoints { 136 | b.filterEntries = append(b.filterEntries, entry) 137 | } 138 | } 139 | } 140 | 141 | // addNewRequests adds a set of new GetUtxoRequests to the spend reporter's 142 | // state. This method immediately adds the request's outpoints to the reporter's 143 | // watchlist. 
func (b *batchSpendReporter) addNewRequests(reqs []*GetUtxoRequest) {
	for _, req := range reqs {
		outpoint := req.Input.OutPoint

		log.Debugf("Adding outpoint=%s height=%d to watchlist",
			outpoint, req.BirthHeight)

		b.requests[outpoint] = append(b.requests[outpoint], req)

		// Build the filter entry only if it is the first time seeing
		// the outpoint.
		if _, ok := b.outpoints[outpoint]; !ok {
			entry := req.Input.PkScript
			b.outpoints[outpoint] = entry
			b.filterEntries = append(b.filterEntries, entry)
		}
	}
}

// findInitialTransactions searches the given block for the creation of the
// UTXOs that are supposed to be birthed in this block. If any are found, a
// spend report containing the initial outpoint will be saved in case the
// outpoint is not spent later on. Requests corresponding to outpoints that are
// not found in the block will return a nil spend report to indicate that the
// UTXO was not found.
func (b *batchSpendReporter) findInitialTransactions(block *wire.MsgBlock,
	newReqs []*GetUtxoRequest, height uint32) map[wire.OutPoint]*SpendReport {

	// First, construct a reverse index from txid to the list of requests
	// whose outputs share that txid.
	txidReverseIndex := make(map[chainhash.Hash][]*GetUtxoRequest)
	for _, req := range newReqs {
		txidReverseIndex[req.Input.OutPoint.Hash] = append(
			txidReverseIndex[req.Input.OutPoint.Hash], req,
		)
	}

	// Iterate over the transactions in this block, hashing each and
	// querying our reverse index to see if any requests depend on the txn.
	initialTxns := make(map[wire.OutPoint]*SpendReport)
	for i, tx := range block.Transactions {
		// If our reverse index has been cleared, we are done.
		if len(txidReverseIndex) == 0 {
			break
		}

		hash := tx.TxHash()
		txidReqs, ok := txidReverseIndex[hash]
		if !ok {
			continue
		}
		delete(txidReverseIndex, hash)

		// For all requests that are watching this txid, use the output
		// index of each to grab the initial output.
		txOuts := tx.TxOut
		for _, req := range txidReqs {
			op := req.Input.OutPoint

			// Ensure that the outpoint's index references an actual
			// output on the transaction. If not, we will be unable
			// to find the initial output.
			if op.Index >= uint32(len(txOuts)) {
				log.Errorf("Failed to find outpoint %s -- "+
					"invalid output index", op)
				initialTxns[op] = nil
				continue
			}

			h := block.BlockHash()

			initialTxns[op] = &SpendReport{
				Output:      txOuts[op.Index],
				BlockHash:   &h,
				BlockHeight: height,
				BlockIndex:  uint32(i),
			}
		}
	}

	// Finally, we must reconcile any requests for which the txid did not
	// exist in this block. A nil spend report is saved for every initial
	// txn that could not be found, otherwise the result is copied from the
	// scan above. The copied values can include valid initial txns, as
	// well as a nil spend report if the output index was invalid.
	for _, req := range newReqs {
		tx, ok := initialTxns[req.Input.OutPoint]
		switch {
		case !ok:
			log.Debugf("Outpoint %v not found in block %d ",
				req.Input.OutPoint, height)
			initialTxns[req.Input.OutPoint] = nil
		case tx != nil:
			log.Tracef("Block %d creates output %s",
				height, req.Input.OutPoint)
		default:
		}

		b.initialTxns[req.Input.OutPoint] = tx
	}

	return initialTxns
}

// notifySpends finds any transactions in the block that spend from our watched
// outpoints. If a spend is detected, it is immediately delivered and cleaned up
// from the reporter's internal state.
func (b *batchSpendReporter) notifySpends(block *wire.MsgBlock,
	height uint32) map[wire.OutPoint]*SpendReport {

	spends := make(map[wire.OutPoint]*SpendReport)
	for _, tx := range block.Transactions {
		// Check each input to see if this transaction spends one of our
		// watched outpoints.
		for i, ti := range tx.TxIn {
			outpoint := ti.PreviousOutPoint

			// Find the requests this spend relates to.
			requests, ok := b.requests[outpoint]
			if !ok {
				continue
			}

			log.Debugf("UTXO %v spent by txn %v", outpoint,
				tx.TxHash())

			spend := &SpendReport{
				SpendingTx:         tx,
				SpendingInputIndex: uint32(i),
				SpendingTxHeight:   height,
			}

			spends[outpoint] = spend

			// With the requests located, we remove this outpoint
			// from the requests, outpoints, and initial txns maps.
			// This ensures we don't continue watching this
			// outpoint.
			b.notifyRequests(&outpoint, requests, spend, nil)
		}
	}

	return spends
}
--------------------------------------------------------------------------------
/blockntfns/log.go:
--------------------------------------------------------------------------------
package blockntfns

import "github.com/btcsuite/btclog"

// log is a logger that is initialized with no output filters. This
// means the package will not perform any logging by default until the caller
// requests it.
var log btclog.Logger

// The default amount of logging is none.
func init() {
	DisableLog()
}

// DisableLog disables all library log output. Logging output is disabled
// by default until either UseLogger or SetLogWriter are called.
func DisableLog() {
	UseLogger(btclog.Disabled)
}

// UseLogger uses a specified Logger to output package logging info.
// This should be used in preference to SetLogWriter if the caller is also
// using btclog.
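//
// For example, a caller already using btclog could wire this package up as
// follows (a sketch; the backend writer and subsystem tag are the caller's
// choice):
//
//	backend := btclog.NewBackend(os.Stdout)
//	UseLogger(backend.Logger("BNTF"))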
24 | func UseLogger(logger btclog.Logger) { 25 | log = logger 26 | } 27 | -------------------------------------------------------------------------------- /blockntfns/manager_test.go: -------------------------------------------------------------------------------- 1 | package blockntfns_test 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | "time" 7 | 8 | "github.com/btcsuite/btcd/wire" 9 | "github.com/lightninglabs/neutrino/blockntfns" 10 | ) 11 | 12 | var emptyHeader wire.BlockHeader 13 | 14 | type mockNtfnSource struct { 15 | blockChan chan blockntfns.BlockNtfn 16 | blocksSinceHeight func(uint32) ([]blockntfns.BlockNtfn, uint32, error) 17 | } 18 | 19 | func newMockBlockSource() *mockNtfnSource { 20 | return &mockNtfnSource{ 21 | blockChan: make(chan blockntfns.BlockNtfn), 22 | } 23 | } 24 | 25 | func (s *mockNtfnSource) Notifications() <-chan blockntfns.BlockNtfn { 26 | return s.blockChan 27 | } 28 | 29 | func (s *mockNtfnSource) NotificationsSinceHeight( 30 | height uint32) ([]blockntfns.BlockNtfn, uint32, error) { 31 | 32 | if s.blocksSinceHeight != nil { 33 | return s.blocksSinceHeight(height) 34 | } 35 | 36 | return nil, 0, nil 37 | } 38 | 39 | // TestManagerNewSubscription ensures that a client properly receives new 40 | // block notifications once it successfully registers for a subscription. 41 | func TestManagerNewSubscription(t *testing.T) { 42 | t.Parallel() 43 | 44 | // We'll start by creating a subscription manager backed by our mocked 45 | // block source. 46 | blockSource := newMockBlockSource() 47 | subMgr := blockntfns.NewSubscriptionManager(blockSource) 48 | subMgr.Start() 49 | defer subMgr.Stop() 50 | 51 | // We'll create some notifications that will be delivered to a 52 | // registered client. 53 | tipHeight := uint32(0) 54 | newNotifications := make([]blockntfns.BlockNtfn, 0, 20) 55 | for i := uint32(0); i < 20; i++ { 56 | newNotifications = append(newNotifications, blockntfns.NewBlockConnected( 57 | emptyHeader, i+1, 58 | )) 59 | } 60 | staleNotifications := make([]blockntfns.BlockNtfn, 0, 10) 61 | for i := len(newNotifications) - 1; i >= 10; i-- { 62 | staleNotifications = append(staleNotifications, blockntfns.NewBlockDisconnected( 63 | emptyHeader, newNotifications[i].Height(), emptyHeader, 64 | )) 65 | } 66 | 67 | // We'll register a client and proceed to deliver the notifications to 68 | // the SubscriptionManager. This should act as if the client is being 69 | // told of notifications following the tip of the chain. 70 | sub, err := subMgr.NewSubscription(0) 71 | if err != nil { 72 | t.Fatalf("unable to register new subscription: %v", err) 73 | } 74 | 75 | go func() { 76 | for _, block := range newNotifications { 77 | blockSource.blockChan <- block 78 | } 79 | for _, block := range staleNotifications { 80 | blockSource.blockChan <- block 81 | } 82 | }() 83 | 84 | // Then, we'll attempt to process these notifications in order from the 85 | // client's point of view. 20 successive block connected notifications 86 | // should be received, followed by 10 block disconnected notifications. 
87 | for i := 0; i < len(newNotifications)+len(staleNotifications); i++ { 88 | select { 89 | case ntfn := <-sub.Notifications: 90 | switch ntfn := ntfn.(type) { 91 | case *blockntfns.Connected: 92 | if ntfn.Height() != tipHeight+1 { 93 | t.Fatalf("expected new block with "+ 94 | "height %d, got %d", 95 | tipHeight+1, ntfn.Height()) 96 | } 97 | tipHeight++ 98 | 99 | case *blockntfns.Disconnected: 100 | if ntfn.Height() != tipHeight { 101 | t.Fatalf("expected stale block with "+ 102 | "height %d, got %d", tipHeight, 103 | ntfn.Height()) 104 | } 105 | tipHeight-- 106 | } 107 | 108 | case <-time.After(time.Second): 109 | t.Fatal("expected to receive block notification") 110 | } 111 | } 112 | 113 | // Finally, the client's height should match as expected after 114 | // processing all the notifications in order. 115 | if tipHeight != 10 { 116 | t.Fatalf("expected chain tip with height %d, got %d", 10, 117 | tipHeight) 118 | } 119 | } 120 | 121 | // TestManagerCancelSubscription ensures that when a client desires to cancel 122 | // their subscription, that they are no longer delivered any new notifications 123 | // after the fact. 124 | func TestManagerCancelSubscription(t *testing.T) { 125 | t.Parallel() 126 | 127 | // We'll start by creating a subscription manager backed by our mocked 128 | // block source. 129 | blockSource := newMockBlockSource() 130 | subMgr := blockntfns.NewSubscriptionManager(blockSource) 131 | subMgr.Start() 132 | defer subMgr.Stop() 133 | 134 | // We'll create two client subscriptions to ensure subscription 135 | // cancellation works as intended. We'll be canceling the second 136 | // subscription only. 137 | sub1, err := subMgr.NewSubscription(0) 138 | if err != nil { 139 | t.Fatalf("unable to register new subscription: %v", err) 140 | } 141 | sub2, err := subMgr.NewSubscription(0) 142 | if err != nil { 143 | t.Fatalf("unable to register new subscription: %v", err) 144 | } 145 | 146 | // We'll send a single block connected notification to both clients. 147 | go func() { 148 | blockSource.blockChan <- blockntfns.NewBlockConnected( 149 | emptyHeader, 1, 150 | ) 151 | }() 152 | 153 | // Both of them should receive it. 154 | subs := []*blockntfns.Subscription{sub1, sub2} 155 | for _, sub := range subs { 156 | select { 157 | case _, ok := <-sub.Notifications: 158 | if !ok { 159 | t.Fatal("expected to continue receiving " + 160 | "notifications") 161 | } 162 | case <-time.After(time.Second): 163 | t.Fatalf("expected block connected notification") 164 | } 165 | } 166 | 167 | // Now, we'll attempt to deliver another block connected notification, 168 | // but this time we'll cancel the second subscription. 169 | sub2.Cancel() 170 | 171 | go func() { 172 | blockSource.blockChan <- blockntfns.NewBlockConnected( 173 | emptyHeader, 2, 174 | ) 175 | }() 176 | 177 | // The first subscription should still see the new notification come 178 | // through. 179 | select { 180 | case _, ok := <-sub1.Notifications: 181 | if !ok { 182 | t.Fatalf("expected to continue receiving notifications") 183 | } 184 | case <-time.After(time.Second): 185 | t.Fatalf("expected block connected notification") 186 | } 187 | 188 | // However, the second subscription shouldn't. 
189 | select { 190 | case _, ok := <-sub2.Notifications: 191 | if ok { 192 | t.Fatalf("expected closed NotificationsConnected channel") 193 | } 194 | case <-time.After(time.Second): 195 | t.Fatalf("expected closed NotificationsConnected channel") 196 | } 197 | } 198 | 199 | // TestManagerHistoricalBacklog ensures that when a client registers for a 200 | // subscription with a best known height lower than the current tip of the 201 | // chain, that a historical backlog of notifications is delivered from that 202 | // point forwards. 203 | func TestManagerHistoricalBacklog(t *testing.T) { 204 | t.Parallel() 205 | 206 | // We'll start by creating a subscription manager backed by our mocked 207 | // block source. 208 | blockSource := newMockBlockSource() 209 | subMgr := blockntfns.NewSubscriptionManager(blockSource) 210 | subMgr.Start() 211 | defer subMgr.Stop() 212 | 213 | // We'll make NotificationsSinceHeight return an error to ensure that a 214 | // client registration fails if it returns an error. 215 | blockSource.blocksSinceHeight = func(uint32) ([]blockntfns.BlockNtfn, 216 | uint32, error) { 217 | 218 | return nil, 0, errors.New("") 219 | } 220 | _, err := subMgr.NewSubscription(0) 221 | if err == nil { 222 | t.Fatal("expected registration to fail due to not delivering " + 223 | "backlog") 224 | } 225 | 226 | // We'll go with the assumption that the tip of the chain is at height 227 | // 20, while the client's best known height is 10. 228 | // NotificationsSinceHeight should then return notifications for blocks 229 | // 11-20. 230 | const chainTip uint32 = 20 231 | subCurrentHeight := chainTip / 2 232 | numBacklog := chainTip - subCurrentHeight 233 | blockSource.blocksSinceHeight = func(uint32) ([]blockntfns.BlockNtfn, 234 | uint32, error) { 235 | 236 | blocks := make([]blockntfns.BlockNtfn, 0, numBacklog) 237 | for i := subCurrentHeight + 1; i <= chainTip; i++ { 238 | blocks = append(blocks, blockntfns.NewBlockConnected( 239 | emptyHeader, i, 240 | )) 241 | } 242 | 243 | return blocks, chainTip, nil 244 | } 245 | 246 | // Register a new client with the expected current height. 247 | sub, err := subMgr.NewSubscription(subCurrentHeight) 248 | if err != nil { 249 | t.Fatalf("unable to register new subscription: %v", err) 250 | } 251 | 252 | // Then, we'll attempt to retrieve all of the expected notifications 253 | // from the historical backlog delivered. 254 | for i := uint32(0); i < numBacklog; i++ { 255 | select { 256 | case ntfn := <-sub.Notifications: 257 | if ntfn, ok := ntfn.(*blockntfns.Connected); !ok { 258 | t.Fatalf("expected *blockntfns.Connected "+ 259 | "notification, got %T", ntfn) 260 | } 261 | if ntfn.Height() != subCurrentHeight+1 { 262 | t.Fatalf("expected new block with height %d, "+ 263 | "got %d", subCurrentHeight+1, 264 | ntfn.Height()) 265 | } 266 | subCurrentHeight++ 267 | 268 | case <-time.After(time.Second): 269 | t.Fatal("expected to receive historical block " + 270 | "notification") 271 | } 272 | } 273 | 274 | // Finally, the client should now be caught up with the chain. 
275 | if subCurrentHeight != chainTip { 276 | t.Fatalf("expected client to be caught up to height %d, got %d", 277 | chainTip, subCurrentHeight) 278 | } 279 | } 280 | -------------------------------------------------------------------------------- /blockntfns/notification.go: -------------------------------------------------------------------------------- 1 | package blockntfns 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/btcsuite/btcd/wire" 7 | ) 8 | 9 | // BlockNtfn is an interface that coalesces all the different types of block 10 | // notifications. 11 | type BlockNtfn interface { 12 | // Header returns the header of the block for which this notification is 13 | // for. 14 | Header() wire.BlockHeader 15 | 16 | // Height returns the height of the block for which this notification is 17 | // for. 18 | Height() uint32 19 | 20 | // ChainTip returns the header of the new tip of the chain after 21 | // processing the block being connected/disconnected. 22 | ChainTip() wire.BlockHeader 23 | } 24 | 25 | // Connected is a block notification that gets dispatched to clients when the 26 | // filter header of a new block has been found that extends the current chain. 27 | type Connected struct { 28 | header wire.BlockHeader 29 | height uint32 30 | } 31 | 32 | // A compile-time check to ensure Connected satisfies the BlockNtfn interface. 33 | var _ BlockNtfn = (*Connected)(nil) 34 | 35 | // NewBlockConnected creates a new Connected notification for the given block. 36 | func NewBlockConnected(header wire.BlockHeader, height uint32) *Connected { 37 | return &Connected{header: header, height: height} 38 | } 39 | 40 | // Header returns the header of the block extending the chain. 41 | func (n *Connected) Header() wire.BlockHeader { 42 | return n.header 43 | } 44 | 45 | // Height returns the height of the block extending the chain. 46 | func (n *Connected) Height() uint32 { 47 | return n.height 48 | } 49 | 50 | // ChainTip returns the header of the new tip of the chain after processing the 51 | // block being connected. 52 | func (n *Connected) ChainTip() wire.BlockHeader { 53 | return n.header 54 | } 55 | 56 | // String returns the string representation of a Connected notification. 57 | func (n *Connected) String() string { 58 | return fmt.Sprintf("block connected (height=%d, hash=%v)", n.height, 59 | n.header.BlockHash()) 60 | } 61 | 62 | // Disconnected if a notification that gets dispatched to clients when a reorg 63 | // has been detected at the tip of the chain. 64 | type Disconnected struct { 65 | headerDisconnected wire.BlockHeader 66 | heightDisconnected uint32 67 | chainTip wire.BlockHeader 68 | } 69 | 70 | // A compile-time check to ensure Disconnected satisfies the BlockNtfn 71 | // interface. 72 | var _ BlockNtfn = (*Disconnected)(nil) 73 | 74 | // NewBlockDisconnected creates a Disconnected notification for the given block. 75 | func NewBlockDisconnected(headerDisconnected wire.BlockHeader, 76 | heightDisconnected uint32, chainTip wire.BlockHeader) *Disconnected { 77 | 78 | return &Disconnected{ 79 | headerDisconnected: headerDisconnected, 80 | heightDisconnected: heightDisconnected, 81 | chainTip: chainTip, 82 | } 83 | } 84 | 85 | // Header returns the header of the block being disconnected. 86 | func (n *Disconnected) Header() wire.BlockHeader { 87 | return n.headerDisconnected 88 | } 89 | 90 | // Height returns the height of the block being disconnected. 
91 | func (n *Disconnected) Height() uint32 { 92 | return n.heightDisconnected 93 | } 94 | 95 | // ChainTip returns the header of the new tip of the chain after processing the 96 | // block being disconnected. 97 | func (n *Disconnected) ChainTip() wire.BlockHeader { 98 | return n.chainTip 99 | } 100 | 101 | // String returns the string representation of a Disconnected notification. 102 | func (n *Disconnected) String() string { 103 | return fmt.Sprintf("block disconnected (height=%d, hash=%v)", 104 | n.heightDisconnected, n.headerDisconnected.BlockHash()) 105 | } 106 | -------------------------------------------------------------------------------- /cache/cache.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import "fmt" 4 | 5 | var ( 6 | // ErrElementNotFound is returned when element isn't found in the cache. 7 | ErrElementNotFound = fmt.Errorf("unable to find element") 8 | ) 9 | 10 | // Value represents a value stored in the Cache. 11 | type Value interface { 12 | // Size determines how big this entry would be in the cache. For 13 | // example, for a filter, it could be the size of the filter in bytes. 14 | Size() (uint64, error) 15 | } 16 | 17 | // Cache represents a generic cache. 18 | type Cache[K comparable, V Value] interface { 19 | // Put stores the given (key,value) pair, replacing existing value if 20 | // key already exists. The return value indicates whether items had to 21 | // be evicted to make room for the new element. 22 | Put(key K, value V) (bool, error) 23 | 24 | // Get returns the value for a given key. 25 | Get(key K) (V, error) 26 | 27 | // Len returns number of elements in the cache. 28 | Len() int 29 | } 30 | -------------------------------------------------------------------------------- /cache/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/lightninglabs/neutrino/cache 2 | 3 | go 1.18 4 | 5 | require github.com/stretchr/testify v1.8.1 6 | 7 | require ( 8 | github.com/davecgh/go-spew v1.1.1 // indirect 9 | github.com/pmezard/go-difflib v1.0.0 // indirect 10 | gopkg.in/yaml.v3 v3.0.1 // indirect 11 | ) 12 | 13 | replace github.com/lightninglabs/neutrino/cache => ./cache 14 | -------------------------------------------------------------------------------- /cache/go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 5 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 6 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 7 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 8 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 9 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 10 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 11 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 12 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 13 | 
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 14 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 15 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 16 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 17 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 18 | -------------------------------------------------------------------------------- /cache/lru/list.go: -------------------------------------------------------------------------------- 1 | // Copyright 2009 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // Copyright (c) 2017-2022 Lightning Labs 5 | 6 | // Package list implements a doubly linked list. 7 | // 8 | // To iterate over a list (where l is a *List): 9 | // 10 | // for e := l.Front(); e != nil; e = e.Next() { 11 | // // do something with e.Value 12 | // } 13 | package lru 14 | 15 | // Element is an element of a linked list. 16 | type Element[V any] struct { 17 | // Next and previous pointers in the doubly-linked list of elements. 18 | // To simplify the implementation, internally a list l is implemented 19 | // as a ring, such that &l.root is both the next element of the last 20 | // list element (l.Back()) and the previous element of the first list 21 | // element (l.Front()). 22 | next, prev *Element[V] 23 | 24 | // The list to which this element belongs. 25 | list *List[V] 26 | 27 | // The value stored with this element. 28 | Value V 29 | } 30 | 31 | // Next returns the next list element or nil. 32 | func (e *Element[V]) Next() *Element[V] { 33 | if p := e.next; e.list != nil && p != &e.list.root { 34 | return p 35 | } 36 | return nil 37 | } 38 | 39 | // Prev returns the previous list element or nil. 40 | func (e *Element[V]) Prev() *Element[V] { 41 | if p := e.prev; e.list != nil && p != &e.list.root { 42 | return p 43 | } 44 | return nil 45 | } 46 | 47 | // List represents a doubly linked list. 48 | // The zero value for List is an empty list ready to use. 49 | type List[V any] struct { 50 | root Element[V] // sentinel list element, only &root, root.prev, and root.next are used 51 | len int // current list length excluding (this) sentinel element 52 | } 53 | 54 | // Init initializes or clears list l. 55 | func (l *List[V]) Init() *List[V] { 56 | l.root.next = &l.root 57 | l.root.prev = &l.root 58 | l.len = 0 59 | return l 60 | } 61 | 62 | // NewList returns an initialized list. 63 | func NewList[V any]() *List[V] { return new(List[V]).Init() } 64 | 65 | // Len returns the number of elements of list l. 66 | // The complexity is O(1). 67 | func (l *List[V]) Len() int { return l.len } 68 | 69 | // Front returns the first element of list l or nil if the list is empty. 70 | func (l *List[V]) Front() *Element[V] { 71 | if l.len == 0 { 72 | return nil 73 | } 74 | return l.root.next 75 | } 76 | 77 | // Back returns the last element of list l or nil if the list is empty. 78 | func (l *List[V]) Back() *Element[V] { 79 | if l.len == 0 { 80 | return nil 81 | } 82 | return l.root.prev 83 | } 84 | 85 | // lazyInit lazily initializes a zero List value. 86 | func (l *List[V]) lazyInit() { 87 | if l.root.next == nil { 88 | l.Init() 89 | } 90 | } 91 | 92 | // insert inserts e after at, increments l.len, and returns e. 
93 | func (l *List[V]) insert(e, at *Element[V]) *Element[V] { 94 | e.prev = at 95 | e.next = at.next 96 | e.prev.next = e 97 | e.next.prev = e 98 | e.list = l 99 | l.len++ 100 | return e 101 | } 102 | 103 | // insertValue is a convenience wrapper for insert(&Element{Value: v}, at). 104 | func (l *List[V]) insertValue(v V, at *Element[V]) *Element[V] { 105 | return l.insert(&Element[V]{Value: v}, at) 106 | } 107 | 108 | // remove removes e from its list, decrements l.len 109 | func (l *List[V]) remove(e *Element[V]) { 110 | e.prev.next = e.next 111 | e.next.prev = e.prev 112 | e.next = nil // avoid memory leaks 113 | e.prev = nil // avoid memory leaks 114 | e.list = nil 115 | l.len-- 116 | } 117 | 118 | // move moves e to next to at. 119 | func (l *List[V]) move(e, at *Element[V]) { 120 | if e == at { 121 | return 122 | } 123 | e.prev.next = e.next 124 | e.next.prev = e.prev 125 | 126 | e.prev = at 127 | e.next = at.next 128 | e.prev.next = e 129 | e.next.prev = e 130 | } 131 | 132 | // Remove removes e from l if e is an element of list l. 133 | // It returns the element value e.Value. 134 | // The element must not be nil. 135 | func (l *List[V]) Remove(e *Element[V]) any { 136 | if e.list == l { 137 | // if e.list == l, l must have been initialized when e was inserted 138 | // in l or l == nil (e is a zero Element) and l.remove will crash 139 | l.remove(e) 140 | } 141 | return e.Value 142 | } 143 | 144 | // PushFront inserts a new element e with value v at the front of list l and returns e. 145 | func (l *List[V]) PushFront(v V) *Element[V] { 146 | l.lazyInit() 147 | return l.insertValue(v, &l.root) 148 | } 149 | 150 | // PushBack inserts a new element e with value v at the back of list l and returns e. 151 | func (l *List[V]) PushBack(v V) *Element[V] { 152 | l.lazyInit() 153 | return l.insertValue(v, l.root.prev) 154 | } 155 | 156 | // InsertBefore inserts a new element e with value v immediately before mark and returns e. 157 | // If mark is not an element of l, the list is not modified. 158 | // The mark must not be nil. 159 | func (l *List[V]) InsertBefore(v V, mark *Element[V]) *Element[V] { 160 | if mark.list != l { 161 | return nil 162 | } 163 | // see comment in List.Remove about initialization of l 164 | return l.insertValue(v, mark.prev) 165 | } 166 | 167 | // InsertAfter inserts a new element e with value v immediately after mark and returns e. 168 | // If mark is not an element of l, the list is not modified. 169 | // The mark must not be nil. 170 | func (l *List[V]) InsertAfter(v V, mark *Element[V]) *Element[V] { 171 | if mark.list != l { 172 | return nil 173 | } 174 | // see comment in List.Remove about initialization of l 175 | return l.insertValue(v, mark) 176 | } 177 | 178 | // MoveToFront moves element e to the front of list l. 179 | // If e is not an element of l, the list is not modified. 180 | // The element must not be nil. 181 | func (l *List[V]) MoveToFront(e *Element[V]) { 182 | if e.list != l || l.root.next == e { 183 | return 184 | } 185 | // see comment in List.Remove about initialization of l 186 | l.move(e, &l.root) 187 | } 188 | 189 | // MoveToBack moves element e to the back of list l. 190 | // If e is not an element of l, the list is not modified. 191 | // The element must not be nil. 
192 | func (l *List[V]) MoveToBack(e *Element[V]) { 193 | if e.list != l || l.root.prev == e { 194 | return 195 | } 196 | // see comment in List.Remove about initialization of l 197 | l.move(e, l.root.prev) 198 | } 199 | 200 | // MoveBefore moves element e to its new position before mark. 201 | // If e or mark is not an element of l, or e == mark, the list is not modified. 202 | // The element and mark must not be nil. 203 | func (l *List[V]) MoveBefore(e, mark *Element[V]) { 204 | if e.list != l || e == mark || mark.list != l { 205 | return 206 | } 207 | l.move(e, mark.prev) 208 | } 209 | 210 | // MoveAfter moves element e to its new position after mark. 211 | // If e or mark is not an element of l, or e == mark, the list is not modified. 212 | // The element and mark must not be nil. 213 | func (l *List[V]) MoveAfter(e, mark *Element[V]) { 214 | if e.list != l || e == mark || mark.list != l { 215 | return 216 | } 217 | l.move(e, mark) 218 | } 219 | 220 | // PushBackList inserts a copy of another list at the back of list l. 221 | // The lists l and other may be the same. They must not be nil. 222 | func (l *List[V]) PushBackList(other *List[V]) { 223 | l.lazyInit() 224 | for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { 225 | l.insertValue(e.Value, l.root.prev) 226 | } 227 | } 228 | 229 | // PushFrontList inserts a copy of another list at the front of list l. 230 | // The lists l and other may be the same. They must not be nil. 231 | func (l *List[V]) PushFrontList(other *List[V]) { 232 | l.lazyInit() 233 | for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { 234 | l.insertValue(e.Value, &l.root) 235 | } 236 | } 237 | -------------------------------------------------------------------------------- /cache/lru/list_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2009 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | // Copyright (c) 2017-2022 Lightning Labs 5 | 6 | package lru 7 | 8 | import "testing" 9 | 10 | func checkListLen[V any](t *testing.T, l *List[V], len int) bool { 11 | if n := l.Len(); n != len { 12 | t.Errorf("l.Len() = %d, want %d", n, len) 13 | return false 14 | } 15 | return true 16 | } 17 | 18 | func checkListPointers[V any](t *testing.T, l *List[V], es []*Element[V]) { 19 | root := &l.root 20 | 21 | if !checkListLen(t, l, len(es)) { 22 | return 23 | } 24 | 25 | // zero length lists must be the zero value or properly initialized (sentinel circle) 26 | if len(es) == 0 { 27 | if l.root.next != nil && l.root.next != root || l.root.prev != nil && l.root.prev != root { 28 | t.Errorf("l.root.next = %p, l.root.prev = %p; both should both be nil or %p", l.root.next, l.root.prev, root) 29 | } 30 | return 31 | } 32 | // len(es) > 0 33 | 34 | // check internal and external prev/next connections 35 | for i, e := range es { 36 | prev := root 37 | Prev := (*Element[V])(nil) 38 | if i > 0 { 39 | prev = es[i-1] 40 | Prev = prev 41 | } 42 | if p := e.prev; p != prev { 43 | t.Errorf("elt[%d](%p).prev = %p, want %p", i, e, p, prev) 44 | } 45 | if p := e.Prev(); p != Prev { 46 | t.Errorf("elt[%d](%p).Prev() = %p, want %p", i, e, p, Prev) 47 | } 48 | 49 | next := root 50 | Next := (*Element[V])(nil) 51 | if i < len(es)-1 { 52 | next = es[i+1] 53 | Next = next 54 | } 55 | if n := e.next; n != next { 56 | t.Errorf("elt[%d](%p).next = %p, want %p", i, e, n, next) 57 | } 58 | if n := e.Next(); n != Next { 59 | t.Errorf("elt[%d](%p).Next() = %p, want %p", i, e, n, Next) 60 | } 61 | } 62 | } 63 | 64 | func TestList(t *testing.T) { 65 | l := NewList[int]() 66 | checkListPointers(t, l, []*Element[int]{}) 67 | 68 | // Single element list 69 | e := l.PushFront(5) 70 | checkListPointers(t, l, []*Element[int]{e}) 71 | l.MoveToFront(e) 72 | checkListPointers(t, l, []*Element[int]{e}) 73 | l.MoveToBack(e) 74 | checkListPointers(t, l, []*Element[int]{e}) 75 | l.Remove(e) 76 | checkListPointers(t, l, []*Element[int]{}) 77 | 78 | // Bigger list 79 | e2 := l.PushFront(2) 80 | e1 := l.PushFront(1) 81 | e3 := l.PushBack(3) 82 | e4 := l.PushBack(0) 83 | checkListPointers(t, l, []*Element[int]{e1, e2, e3, e4}) 84 | 85 | l.Remove(e2) 86 | checkListPointers(t, l, []*Element[int]{e1, e3, e4}) 87 | 88 | l.MoveToFront(e3) // move from middle 89 | checkListPointers(t, l, []*Element[int]{e3, e1, e4}) 90 | 91 | l.MoveToFront(e1) 92 | l.MoveToBack(e3) // move from middle 93 | checkListPointers(t, l, []*Element[int]{e1, e4, e3}) 94 | 95 | l.MoveToFront(e3) // move from back 96 | checkListPointers(t, l, []*Element[int]{e3, e1, e4}) 97 | l.MoveToFront(e3) // should be no-op 98 | checkListPointers(t, l, []*Element[int]{e3, e1, e4}) 99 | 100 | l.MoveToBack(e3) // move from front 101 | checkListPointers(t, l, []*Element[int]{e1, e4, e3}) 102 | l.MoveToBack(e3) // should be no-op 103 | checkListPointers(t, l, []*Element[int]{e1, e4, e3}) 104 | 105 | e2 = l.InsertBefore(2, e1) // insert before front 106 | checkListPointers(t, l, []*Element[int]{e2, e1, e4, e3}) 107 | l.Remove(e2) 108 | e2 = l.InsertBefore(2, e4) // insert before middle 109 | checkListPointers(t, l, []*Element[int]{e1, e2, e4, e3}) 110 | l.Remove(e2) 111 | e2 = l.InsertBefore(2, e3) // insert before back 112 | checkListPointers(t, l, []*Element[int]{e1, e4, e2, e3}) 113 | l.Remove(e2) 114 | 115 | e2 = l.InsertAfter(2, e1) // insert after front 116 | checkListPointers(t, l, []*Element[int]{e1, e2, e4, e3}) 117 | l.Remove(e2) 118 | e2 = l.InsertAfter(2, e4) // insert after 
middle 119 | checkListPointers(t, l, []*Element[int]{e1, e4, e2, e3}) 120 | l.Remove(e2) 121 | e2 = l.InsertAfter(2, e3) // insert after back 122 | checkListPointers(t, l, []*Element[int]{e1, e4, e3, e2}) 123 | l.Remove(e2) 124 | 125 | // Check standard iteration. 126 | sum := 0 127 | for e := l.Front(); e != nil; e = e.Next() { 128 | sum += e.Value 129 | } 130 | if sum != 4 { 131 | t.Errorf("sum over l = %d, want 4", sum) 132 | } 133 | 134 | // Clear all elements by iterating 135 | var next *Element[int] 136 | for e := l.Front(); e != nil; e = next { 137 | next = e.Next() 138 | l.Remove(e) 139 | } 140 | checkListPointers(t, l, []*Element[int]{}) 141 | } 142 | 143 | func checkList[V comparable](t *testing.T, l *List[V], es []V) { 144 | if !checkListLen(t, l, len(es)) { 145 | return 146 | } 147 | 148 | i := 0 149 | for e := l.Front(); e != nil; e = e.Next() { 150 | le := e.Value 151 | if le != es[i] { 152 | t.Errorf("elt[%d].Value = %v, want %v", i, le, es[i]) 153 | } 154 | i++ 155 | } 156 | } 157 | 158 | func TestExtending(t *testing.T) { 159 | l1 := NewList[int]() 160 | l2 := NewList[int]() 161 | 162 | l1.PushBack(1) 163 | l1.PushBack(2) 164 | l1.PushBack(3) 165 | 166 | l2.PushBack(4) 167 | l2.PushBack(5) 168 | 169 | l3 := NewList[int]() 170 | l3.PushBackList(l1) 171 | checkList(t, l3, []int{1, 2, 3}) 172 | l3.PushBackList(l2) 173 | checkList(t, l3, []int{1, 2, 3, 4, 5}) 174 | 175 | l3 = NewList[int]() 176 | l3.PushFrontList(l2) 177 | checkList(t, l3, []int{4, 5}) 178 | l3.PushFrontList(l1) 179 | checkList(t, l3, []int{1, 2, 3, 4, 5}) 180 | 181 | checkList(t, l1, []int{1, 2, 3}) 182 | checkList(t, l2, []int{4, 5}) 183 | 184 | l3 = NewList[int]() 185 | l3.PushBackList(l1) 186 | checkList(t, l3, []int{1, 2, 3}) 187 | l3.PushBackList(l3) 188 | checkList(t, l3, []int{1, 2, 3, 1, 2, 3}) 189 | 190 | l3 = NewList[int]() 191 | l3.PushFrontList(l1) 192 | checkList(t, l3, []int{1, 2, 3}) 193 | l3.PushFrontList(l3) 194 | checkList(t, l3, []int{1, 2, 3, 1, 2, 3}) 195 | 196 | l3 = NewList[int]() 197 | l1.PushBackList(l3) 198 | checkList(t, l1, []int{1, 2, 3}) 199 | l1.PushFrontList(l3) 200 | checkList(t, l1, []int{1, 2, 3}) 201 | } 202 | 203 | func TestRemove(t *testing.T) { 204 | l := NewList[int]() 205 | e1 := l.PushBack(1) 206 | e2 := l.PushBack(2) 207 | checkListPointers(t, l, []*Element[int]{e1, e2}) 208 | e := l.Front() 209 | l.Remove(e) 210 | checkListPointers(t, l, []*Element[int]{e2}) 211 | l.Remove(e) 212 | checkListPointers(t, l, []*Element[int]{e2}) 213 | } 214 | 215 | func TestIssue4103(t *testing.T) { 216 | l1 := NewList[int]() 217 | l1.PushBack(1) 218 | l1.PushBack(2) 219 | 220 | l2 := NewList[int]() 221 | l2.PushBack(3) 222 | l2.PushBack(4) 223 | 224 | e := l1.Front() 225 | l2.Remove(e) // l2 should not change because e is not an element of l2 226 | if n := l2.Len(); n != 2 { 227 | t.Errorf("l2.Len() = %d, want 2", n) 228 | } 229 | 230 | l1.InsertBefore(8, e) 231 | if n := l1.Len(); n != 3 { 232 | t.Errorf("l1.Len() = %d, want 3", n) 233 | } 234 | } 235 | 236 | func TestIssue6349(t *testing.T) { 237 | l := NewList[int]() 238 | l.PushBack(1) 239 | l.PushBack(2) 240 | 241 | e := l.Front() 242 | l.Remove(e) 243 | if e.Value != 1 { 244 | t.Errorf("e.value = %d, want 1", e.Value) 245 | } 246 | if e.Next() != nil { 247 | t.Errorf("e.Next() != nil") 248 | } 249 | if e.Prev() != nil { 250 | t.Errorf("e.Prev() != nil") 251 | } 252 | } 253 | 254 | func TestMove(t *testing.T) { 255 | l := NewList[int]() 256 | e1 := l.PushBack(1) 257 | e2 := l.PushBack(2) 258 | e3 := l.PushBack(3) 259 | e4 := 
l.PushBack(4) 260 | 261 | l.MoveAfter(e3, e3) 262 | checkListPointers(t, l, []*Element[int]{e1, e2, e3, e4}) 263 | l.MoveBefore(e2, e2) 264 | checkListPointers(t, l, []*Element[int]{e1, e2, e3, e4}) 265 | 266 | l.MoveAfter(e3, e2) 267 | checkListPointers(t, l, []*Element[int]{e1, e2, e3, e4}) 268 | l.MoveBefore(e2, e3) 269 | checkListPointers(t, l, []*Element[int]{e1, e2, e3, e4}) 270 | 271 | l.MoveBefore(e2, e4) 272 | checkListPointers(t, l, []*Element[int]{e1, e3, e2, e4}) 273 | e2, e3 = e3, e2 274 | 275 | l.MoveBefore(e4, e1) 276 | checkListPointers(t, l, []*Element[int]{e4, e1, e2, e3}) 277 | e1, e2, e3, e4 = e4, e1, e2, e3 278 | 279 | l.MoveAfter(e4, e1) 280 | checkListPointers(t, l, []*Element[int]{e1, e4, e2, e3}) 281 | e2, e3, e4 = e4, e2, e3 282 | 283 | l.MoveAfter(e2, e3) 284 | checkListPointers(t, l, []*Element[int]{e1, e3, e2, e4}) 285 | } 286 | 287 | // Test PushFront, PushBack, PushFrontList, PushBackList with uninitialized List 288 | func TestZeroList(t *testing.T) { 289 | var l1 = new(List[int]) 290 | l1.PushFront(1) 291 | checkList(t, l1, []int{1}) 292 | 293 | var l2 = new(List[int]) 294 | l2.PushBack(1) 295 | checkList(t, l2, []int{1}) 296 | 297 | var l3 = new(List[int]) 298 | l3.PushFrontList(l1) 299 | checkList(t, l3, []int{1}) 300 | 301 | var l4 = new(List[int]) 302 | l4.PushBackList(l2) 303 | checkList(t, l4, []int{1}) 304 | } 305 | 306 | // Test that a list l is not modified when calling InsertBefore with a mark that is not an element of l. 307 | func TestInsertBeforeUnknownMark(t *testing.T) { 308 | var l List[int] 309 | l.PushBack(1) 310 | l.PushBack(2) 311 | l.PushBack(3) 312 | l.InsertBefore(1, new(Element[int])) 313 | checkList(t, &l, []int{1, 2, 3}) 314 | } 315 | 316 | // Test that a list l is not modified when calling InsertAfter with a mark that is not an element of l. 317 | func TestInsertAfterUnknownMark(t *testing.T) { 318 | var l List[int] 319 | l.PushBack(1) 320 | l.PushBack(2) 321 | l.PushBack(3) 322 | l.InsertAfter(1, new(Element[int])) 323 | checkList(t, &l, []int{1, 2, 3}) 324 | } 325 | 326 | // Test that a list l is not modified when calling MoveAfter or MoveBefore with a mark that is not an element of l. 327 | func TestMoveUnknownMark(t *testing.T) { 328 | var l1 List[int] 329 | e1 := l1.PushBack(1) 330 | 331 | var l2 List[int] 332 | e2 := l2.PushBack(2) 333 | 334 | l1.MoveAfter(e1, e2) 335 | checkList(t, &l1, []int{1}) 336 | checkList(t, &l2, []int{2}) 337 | 338 | l1.MoveBefore(e1, e2) 339 | checkList(t, &l1, []int{1}) 340 | checkList(t, &l2, []int{2}) 341 | } 342 | -------------------------------------------------------------------------------- /cache/lru/lru.go: -------------------------------------------------------------------------------- 1 | package lru 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | 7 | "github.com/lightninglabs/neutrino/cache" 8 | ) 9 | 10 | // entry represents a (key,value) pair entry in the Cache. The Cache's list 11 | // stores entries which let us get the cache key when an entry is evicted. 12 | type entry[K comparable, V cache.Value] struct { 13 | key K 14 | value V 15 | } 16 | 17 | // Cache provides a generic thread-safe lru cache that can be used for 18 | // storing filters, blocks, etc. 19 | type Cache[K comparable, V cache.Value] struct { 20 | // capacity represents how much this cache can hold. It could be number 21 | // of elements or a number of bytes, decided by the cache.Value's Size. 22 | capacity uint64 23 | 24 | // size represents the size of all the elements currently in the cache. 
25 | size uint64 26 | 27 | // ll is a doubly linked list which keeps track of recency of used 28 | // elements by moving them to the front. 29 | ll *List[entry[K, V]] 30 | 31 | // cache is a generic cache which allows us to find an elements position 32 | // in the ll list from a given key. 33 | cache syncMap[K, *Element[entry[K, V]]] 34 | 35 | // mtx is used to make sure the Cache is thread-safe. 36 | mtx sync.RWMutex 37 | } 38 | 39 | // NewCache return a cache with specified capacity, the cache's size can't 40 | // exceed that given capacity. 41 | func NewCache[K comparable, V cache.Value](capacity uint64) *Cache[K, V] { 42 | return &Cache[K, V]{ 43 | capacity: capacity, 44 | ll: NewList[entry[K, V]](), 45 | cache: syncMap[K, *Element[entry[K, V]]]{}, 46 | } 47 | } 48 | 49 | // evict will evict as many elements as necessary to make enough space for a new 50 | // element with size needed to be inserted. 51 | func (c *Cache[K, V]) evict(needed uint64) (bool, error) { 52 | if needed > c.capacity { 53 | return false, fmt.Errorf("can't evict %v elements in size, "+ 54 | "since capacity is %v", needed, c.capacity) 55 | } 56 | 57 | evicted := false 58 | for c.capacity-c.size < needed { 59 | // We still need to evict some more elements. 60 | if c.ll.Len() == 0 { 61 | // We should never reach here. 62 | return false, fmt.Errorf("all elements got evicted, "+ 63 | "yet still need to evict %v, likelihood of "+ 64 | "error during size calculation", 65 | needed-(c.capacity-c.size)) 66 | } 67 | 68 | // Find the least recently used item. 69 | if elr := c.ll.Back(); elr != nil { 70 | // Determine lru item's size. 71 | ce := elr.Value 72 | es, err := ce.value.Size() 73 | if err != nil { 74 | return false, fmt.Errorf("couldn't determine "+ 75 | "size of existing cache value %v", err) 76 | } 77 | 78 | // Account for that element's removal in evicted and 79 | // cache size. 80 | c.size -= es 81 | 82 | // Remove the element from the cache. 83 | c.ll.Remove(elr) 84 | c.cache.Delete(ce.key) 85 | evicted = true 86 | } 87 | } 88 | 89 | return evicted, nil 90 | } 91 | 92 | // Put inserts a given (key,value) pair into the cache, if the key already 93 | // exists, it will replace value and update it to be most recent item in cache. 94 | // The return value indicates whether items had to be evicted to make room for 95 | // the new element. 96 | func (c *Cache[K, V]) Put(key K, value V) (bool, error) { 97 | vs, err := value.Size() 98 | if err != nil { 99 | return false, fmt.Errorf("couldn't determine size of cache "+ 100 | "value: %v", err) 101 | } 102 | 103 | if vs > c.capacity { 104 | return false, fmt.Errorf("can't insert entry of size %v into "+ 105 | "cache with capacity %v", vs, c.capacity) 106 | } 107 | 108 | // Load the element. 109 | el, ok := c.cache.Load(key) 110 | 111 | // Update the internal list inside a lock. 112 | c.mtx.Lock() 113 | 114 | // If the element already exists, remove it and decrease cache's size. 115 | if ok { 116 | es, err := el.Value.value.Size() 117 | if err != nil { 118 | c.mtx.Unlock() 119 | 120 | return false, fmt.Errorf("couldn't determine size of "+ 121 | "existing cache value %v", err) 122 | } 123 | 124 | c.ll.Remove(el) 125 | c.size -= es 126 | } 127 | 128 | // Then we need to make sure we have enough space for the element, evict 129 | // elements if we need more space. 130 | evicted, err := c.evict(vs) 131 | if err != nil { 132 | return false, err 133 | } 134 | 135 | // We have made enough space in the cache, so just insert it. 
136 | el = c.ll.PushFront(entry[K, V]{key, value}) 137 | c.size += vs 138 | 139 | // Release the lock. 140 | c.mtx.Unlock() 141 | 142 | // Update the cache. 143 | c.cache.Store(key, el) 144 | 145 | return evicted, nil 146 | } 147 | 148 | // Get will return value for a given key, making the element the most recently 149 | // accessed item in the process. Will return nil if the key isn't found. 150 | func (c *Cache[K, V]) Get(key K) (V, error) { 151 | var defaultVal V 152 | 153 | el, ok := c.cache.Load(key) 154 | if !ok { 155 | // Element not found in the cache. 156 | return defaultVal, cache.ErrElementNotFound 157 | } 158 | 159 | c.mtx.Lock() 160 | defer c.mtx.Unlock() 161 | 162 | // When the cache needs to evict a element to make space for another 163 | // one, it starts eviction from the back, so by moving this element to 164 | // the front, it's eviction is delayed because it's recently accessed. 165 | c.ll.MoveToFront(el) 166 | return el.Value.value, nil 167 | } 168 | 169 | // Len returns number of elements in the cache. 170 | func (c *Cache[K, V]) Len() int { 171 | c.mtx.RLock() 172 | defer c.mtx.RUnlock() 173 | 174 | return c.ll.Len() 175 | } 176 | 177 | // Delete removes an item from the cache. 178 | func (c *Cache[K, V]) Delete(key K) { 179 | c.LoadAndDelete(key) 180 | } 181 | 182 | // LoadAndDelete queries an item and deletes it from the cache using the 183 | // specified key. 184 | func (c *Cache[K, V]) LoadAndDelete(key K) (V, bool) { 185 | var defaultVal V 186 | 187 | // Noop if the element doesn't exist. 188 | el, ok := c.cache.LoadAndDelete(key) 189 | if !ok { 190 | return defaultVal, false 191 | } 192 | 193 | c.mtx.Lock() 194 | defer c.mtx.Unlock() 195 | 196 | // Get its size. 197 | vs, err := el.Value.value.Size() 198 | if err != nil { 199 | return defaultVal, false 200 | } 201 | 202 | // Remove the element from the list and update the cache's size. 203 | c.ll.Remove(el) 204 | c.size -= vs 205 | 206 | return el.Value.value, true 207 | } 208 | 209 | // Range iterates the cache without any ordering. 210 | func (c *Cache[K, V]) Range(visitor func(K, V) bool) { 211 | // valueVisitor is a closure to help unwrap the value from the cache. 212 | valueVisitor := func(key K, value *Element[entry[K, V]]) bool { 213 | return visitor(key, value.Value.value) 214 | } 215 | 216 | c.cache.Range(valueVisitor) 217 | } 218 | 219 | // RangeFILO iterates the items with FILO order, behaving like a stack. 220 | func (c *Cache[K, V]) RangeFILO(visitor func(K, V) bool) { 221 | for e := c.ll.Front(); e != nil; e = e.Next() { 222 | next := visitor(e.Value.key, e.Value.value) 223 | 224 | // Stops the iteration if the visitor returns false to mimick 225 | // the same behavior of `Range`. 226 | if !next { 227 | return 228 | } 229 | } 230 | } 231 | 232 | // RangeFIFO iterates the items with FIFO order, behaving like a queue. 233 | func (c *Cache[K, V]) RangeFIFO(visitor func(K, V) bool) { 234 | for e := c.ll.Back(); e != nil; e = e.Prev() { 235 | next := visitor(e.Value.key, e.Value.value) 236 | 237 | // Stops the iteration if the visitor returns false to mimick 238 | // the same behavior of `Range`. 
239 | if !next { 240 | return 241 | } 242 | } 243 | } 244 | -------------------------------------------------------------------------------- /cache/lru/sync_map.go: -------------------------------------------------------------------------------- 1 | package lru 2 | 3 | import "sync" 4 | 5 | // syncMap wraps a sync.Map with type parameters such that it's easier to 6 | // access the items stored in the map since no type assertion is needed. It 7 | // also requires explicit type definition when declaring and initiating the 8 | // variables, which helps us understanding what's stored in a given map. 9 | // 10 | // NOTE: this is unexported to avoid confusion with `lnd`'s `SyncMap`. 11 | type syncMap[K comparable, V any] struct { 12 | sync.Map 13 | } 14 | 15 | // Store puts an item in the map. 16 | func (m *syncMap[K, V]) Store(key K, value V) { 17 | m.Map.Store(key, value) 18 | } 19 | 20 | // Load queries an item from the map using the specified key. If the item 21 | // cannot be found, an empty value and false will be returned. If the stored 22 | // item fails the type assertion, a nil value and false will be returned. 23 | func (m *syncMap[K, V]) Load(key K) (V, bool) { 24 | result, ok := m.Map.Load(key) 25 | if !ok { 26 | return *new(V), false // nolint: gocritic 27 | } 28 | 29 | item, ok := result.(V) 30 | return item, ok 31 | } 32 | 33 | // Delete removes an item from the map specified by the key. 34 | func (m *syncMap[K, V]) Delete(key K) { 35 | m.Map.Delete(key) 36 | } 37 | 38 | // LoadAndDelete queries an item and deletes it from the map using the 39 | // specified key. 40 | func (m *syncMap[K, V]) LoadAndDelete(key K) (V, bool) { 41 | result, loaded := m.Map.LoadAndDelete(key) 42 | if !loaded { 43 | return *new(V), loaded // nolint: gocritic 44 | } 45 | 46 | item, ok := result.(V) 47 | return item, ok 48 | } 49 | 50 | // Range iterates the map. 51 | func (m *syncMap[K, V]) Range(visitor func(K, V) bool) { 52 | m.Map.Range(func(k any, v any) bool { 53 | return visitor(k.(K), v.(V)) 54 | }) 55 | } 56 | 57 | // Len returns the number of items in the map. 58 | func (m *syncMap[K, V]) Len() int { 59 | var count int 60 | m.Range(func(K, V) bool { 61 | count++ 62 | return true 63 | }) 64 | 65 | return count 66 | } 67 | -------------------------------------------------------------------------------- /cache_test.go: -------------------------------------------------------------------------------- 1 | package neutrino 2 | 3 | import ( 4 | "crypto/rand" 5 | "testing" 6 | 7 | "github.com/btcsuite/btcd/btcutil" 8 | "github.com/btcsuite/btcd/btcutil/gcs" 9 | "github.com/btcsuite/btcd/chaincfg/chainhash" 10 | "github.com/btcsuite/btcd/wire" 11 | "github.com/lightninglabs/neutrino/cache" 12 | "github.com/lightninglabs/neutrino/cache/lru" 13 | "github.com/lightninglabs/neutrino/filterdb" 14 | ) 15 | 16 | // TestBlockFilterCaches tests that we can put and retrieve elements from all 17 | // implementations of the filter and block caches. 18 | func TestBlockFilterCaches(t *testing.T) { 19 | t.Parallel() 20 | 21 | const filterType = filterdb.RegularFilter 22 | 23 | // Create a cache large enough to not evict any item. We do this so we 24 | // don't have to worry about the eviction strategy of the tested 25 | // caches. 26 | const numElements = 10 27 | const cacheSize = 100000 28 | 29 | // Initialize all types of caches we want to test, for both filters and 30 | // blocks. Currently the LRU cache is the only implementation. 
31 | filterCaches := []cache.Cache[FilterCacheKey, *CacheableFilter]{ 32 | lru.NewCache[FilterCacheKey, *CacheableFilter](cacheSize), 33 | } 34 | blockCaches := []cache.Cache[wire.InvVect, *CacheableBlock]{ 35 | lru.NewCache[wire.InvVect, *CacheableBlock](cacheSize), 36 | } 37 | 38 | // Generate a list of hashes, filters and blocks that we will use as 39 | // cache keys an values. 40 | var ( 41 | blockHashes []chainhash.Hash 42 | filters []*gcs.Filter 43 | blocks []*btcutil.Block 44 | ) 45 | for i := 0; i < numElements; i++ { 46 | var blockHash chainhash.Hash 47 | if _, err := rand.Read(blockHash[:]); err != nil { 48 | t.Fatalf("unable to read rand: %v", err) 49 | } 50 | 51 | blockHashes = append(blockHashes, blockHash) 52 | 53 | filter, err := gcs.FromBytes( 54 | uint32(i), uint8(i), uint64(i), []byte{byte(i)}, 55 | ) 56 | if err != nil { 57 | t.Fatalf("unable to create filter: %v", err) 58 | } 59 | filters = append(filters, filter) 60 | 61 | // Put the generated filter in the filter caches. 62 | cacheKey := FilterCacheKey{blockHash, filterType} 63 | for _, c := range filterCaches { 64 | _, _ = c.Put(cacheKey, &CacheableFilter{Filter: filter}) 65 | } 66 | 67 | msgBlock := &wire.MsgBlock{} 68 | block := btcutil.NewBlock(msgBlock) 69 | blocks = append(blocks, block) 70 | 71 | // Add the block to the block caches, using the block INV 72 | // vector as key. 73 | blockKey := wire.NewInvVect( 74 | wire.InvTypeWitnessBlock, &blockHash, 75 | ) 76 | for _, c := range blockCaches { 77 | _, _ = c.Put(*blockKey, &CacheableBlock{block}) 78 | } 79 | } 80 | 81 | // Now go through the list of block hashes, and make sure we can 82 | // retrieve all elements from the caches. 83 | for i, blockHash := range blockHashes { 84 | // Check filter caches. 85 | cacheKey := FilterCacheKey{blockHash, filterType} 86 | for _, c := range filterCaches { 87 | e, err := c.Get(cacheKey) 88 | if err != nil { 89 | t.Fatalf("Unable to get filter: %v", err) 90 | } 91 | 92 | // Ensure we got the correct filter. 93 | filter := e.Filter 94 | if filter != filters[i] { 95 | t.Fatalf("Filters not equal: %v vs %v ", 96 | filter, filters[i]) 97 | } 98 | } 99 | 100 | // Check block caches. 101 | blockKey := wire.NewInvVect( 102 | wire.InvTypeWitnessBlock, &blockHash, 103 | ) 104 | for _, c := range blockCaches { 105 | b, err := c.Get(*blockKey) 106 | if err != nil { 107 | t.Fatalf("Unable to get block: %v", err) 108 | } 109 | 110 | // Ensure it is the same block. 111 | block := b.Block 112 | if block != blocks[i] { 113 | t.Fatalf("Not equal: %v vs %v ", 114 | block, blocks[i]) 115 | } 116 | } 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /cacheable_block.go: -------------------------------------------------------------------------------- 1 | package neutrino 2 | 3 | import "github.com/btcsuite/btcd/btcutil" 4 | 5 | // CacheableBlock is a wrapper around the btcutil.Block type which provides a 6 | // Size method used by the cache to target certain memory usage. 7 | type CacheableBlock struct { 8 | *btcutil.Block 9 | } 10 | 11 | // Size returns size of this block in bytes. 
12 | func (c *CacheableBlock) Size() (uint64, error) { 13 | return uint64(c.Block.MsgBlock().SerializeSize()), nil 14 | } 15 | -------------------------------------------------------------------------------- /cacheable_filter.go: -------------------------------------------------------------------------------- 1 | package neutrino 2 | 3 | import ( 4 | "github.com/btcsuite/btcd/btcutil/gcs" 5 | "github.com/btcsuite/btcd/chaincfg/chainhash" 6 | "github.com/lightninglabs/neutrino/filterdb" 7 | ) 8 | 9 | // FilterCacheKey represents the key used to access filters in the FilterCache. 10 | type FilterCacheKey struct { 11 | BlockHash chainhash.Hash 12 | FilterType filterdb.FilterType 13 | } 14 | 15 | // CacheableFilter is a wrapper around Filter type which provides a Size method 16 | // used by the cache to target certain memory usage. 17 | type CacheableFilter struct { 18 | *gcs.Filter 19 | } 20 | 21 | // Size returns size of this filter in bytes. 22 | func (c *CacheableFilter) Size() (uint64, error) { 23 | f, err := c.Filter.NBytes() 24 | if err != nil { 25 | return 0, err 26 | } 27 | return uint64(len(f)), nil 28 | } 29 | -------------------------------------------------------------------------------- /chainsync/filtercontrol.go: -------------------------------------------------------------------------------- 1 | package chainsync 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/btcsuite/btcd/chaincfg" 7 | "github.com/btcsuite/btcd/chaincfg/chainhash" 8 | "github.com/btcsuite/btcd/wire" 9 | ) 10 | 11 | // ErrCheckpointMismatch is returned if given filter headers don't pass our 12 | // control check. 13 | var ErrCheckpointMismatch = fmt.Errorf("checkpoint doesn't match") 14 | 15 | // filterHeaderCheckpoints holds a mapping from heights to filter headers for 16 | // various heights. We use them to check whether peers are serving us the 17 | // expected filter headers. 18 | var filterHeaderCheckpoints = map[wire.BitcoinNet]map[uint32]*chainhash.Hash{ 19 | // Mainnet filter header checkpoints. 20 | chaincfg.MainNetParams.Net: { 21 | 100000: hashFromStr("f28cbc1ab369eb01b7b5fe8bf59763abb73a31471fe404a26a06be4153aa7fa5"), 22 | 200000: hashFromStr("e5031471732f4fbfe7a25f6a03acc1413300d5c56ae8e06b95046b8e4c0f32b3"), 23 | 300000: hashFromStr("1bd50220fcdde929ca3143c91d2dd9a9bfedb38c452ba98dbb51e719bff8aa5b"), 24 | 400000: hashFromStr("5d973ab1f1c569c70deec1c1a8fb2e317a260f1656edb3b262c65f78ef192e3a"), 25 | 500000: hashFromStr("5d16ca293c9bdc0a9bc279b63f99fb661be38b095a59a44200a807caaa631a3c"), 26 | 600000: hashFromStr("bde0854d0b2f4386a860462547140e0c6817f5b4b2ab515ef70e204e377598f8"), 27 | 660000: hashFromStr("08312375fabc082b17fa8ee88443feb350c19a34bb7483f94f7478fa4ad33032"), 28 | }, 29 | 30 | // Testnet filter header checkpoints. 
31 | chaincfg.TestNet3Params.Net: { 32 | 100000: hashFromStr("97c0633f14625627fcd133250ad8cc525937e776b5f3fd272b06d02c58b65a1c"), 33 | 200000: hashFromStr("51aa817e5abe3acdcf103616b1a5736caf84bc3773a7286e9081108ecc38cc87"), 34 | 400000: hashFromStr("4aab9b3d4312cd85cfcd48a08b36c4402bfdc1e8395dcf4236c3029dfa837c48"), 35 | 600000: hashFromStr("713d9c9198e2dba0739e85aab6875cb951c36297b95a2d51131aa6919753b55d"), 36 | 800000: hashFromStr("0dafdff27269a70293c120b14b1f5e9a72a5e8688098cfc6140b9d64f8325b99"), 37 | 1000000: hashFromStr("c2043fa2f6eb5f8f8d2c5584f743187f36302ed86b62c302e31155f378da9c5f"), 38 | 1400000: hashFromStr("f9ae1750483d4c8ce82512616b1ded932886af46decb8d3e575907930542d9b3"), 39 | 1500000: hashFromStr("dc0cfa13daf09df9b8dbe7532f75ebdb4255860b295016b2ca4b789394bc5090"), 40 | 1800000: hashFromStr("67083b2d5dfc9ca1415bffa14e43a5bbe595e2e8b7ffbcc7a4ea78fa069a9c8d"), 41 | 1900000: hashFromStr("96a31467f9edcaa3297770bc6cdf66926d5d17dfad70cb0cac285bfe9075c494"), 42 | }, 43 | 44 | // Testnet4 filter header checkpoints. 45 | chaincfg.TestNet4Params.Net: { 46 | 10000: hashFromStr("5bf92ba99cc9e4971e705ab3c4a2a78ef0ea40986ab20f2c06ebfe7751e3fbb8"), 47 | 50000: hashFromStr("66592214f388e315256e5c6362750b858d580dec269500ecca7a4fbb8042b8e3"), 48 | }, 49 | } 50 | 51 | // ControlCFHeader controls the given filter header against our list of 52 | // checkpoints. It returns ErrCheckpointMismatch if we have a checkpoint at the 53 | // given height, and it doesn't match. 54 | func ControlCFHeader(params chaincfg.Params, fType wire.FilterType, 55 | height uint32, filterHeader *chainhash.Hash) error { 56 | 57 | if fType != wire.GCSFilterRegular { 58 | return fmt.Errorf("unsupported filter type %v", fType) 59 | } 60 | 61 | control, ok := filterHeaderCheckpoints[params.Net] 62 | if !ok { 63 | return nil 64 | } 65 | 66 | hash, ok := control[height] 67 | if !ok { 68 | return nil 69 | } 70 | 71 | if *filterHeader != *hash { 72 | return ErrCheckpointMismatch 73 | } 74 | 75 | return nil 76 | } 77 | 78 | // hashFromStr makes a chainhash.Hash from a valid hex string. If the string is 79 | // invalid, a nil pointer will be returned. 80 | func hashFromStr(hexStr string) *chainhash.Hash { 81 | hash, _ := chainhash.NewHashFromStr(hexStr) 82 | return hash 83 | } 84 | -------------------------------------------------------------------------------- /chainsync/filtercontrol_test.go: -------------------------------------------------------------------------------- 1 | package chainsync 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/btcsuite/btcd/chaincfg" 7 | "github.com/btcsuite/btcd/chaincfg/chainhash" 8 | "github.com/btcsuite/btcd/wire" 9 | ) 10 | 11 | func TestControlCFHeader(t *testing.T) { 12 | t.Parallel() 13 | 14 | // We'll modify our backing list of checkpoints for this test. 15 | height := uint32(999) 16 | header := hashFromStr( 17 | "4a242283a406a7c089f671bb8df7671e5d5e9ba577cea1047d30a7f4919df193", 18 | ) 19 | filterHeaderCheckpoints = map[wire.BitcoinNet]map[uint32]*chainhash.Hash{ 20 | chaincfg.MainNetParams.Net: { 21 | height: header, 22 | }, 23 | } 24 | 25 | // Expect the control at height to succeed. 26 | err := ControlCFHeader( 27 | chaincfg.MainNetParams, wire.GCSFilterRegular, height, header, 28 | ) 29 | if err != nil { 30 | t.Fatalf("error checking height: %v", err) 31 | } 32 | 33 | // Pass an invalid header, this should return an error. 
34 | header = hashFromStr( 35 | "000000000006a7c089f671bb8df7671e5d5e9ba577cea1047d30a7f4919df193", 36 | ) 37 | err = ControlCFHeader( 38 | chaincfg.MainNetParams, wire.GCSFilterRegular, height, header, 39 | ) 40 | if err != ErrCheckpointMismatch { 41 | t.Fatalf("expected ErrCheckpointMismatch, got %v", err) 42 | } 43 | 44 | // Finally, control an unknown height. This should also pass since we 45 | // don't have the checkpoint stored. 46 | err = ControlCFHeader( 47 | chaincfg.MainNetParams, wire.GCSFilterRegular, 99, header, 48 | ) 49 | if err != nil { 50 | t.Fatalf("error checking height: %v", err) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /chanutils/batch_writer.go: -------------------------------------------------------------------------------- 1 | package chanutils 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | ) 7 | 8 | // BatchWriterConfig holds the configuration options for BatchWriter. 9 | type BatchWriterConfig[T any] struct { 10 | // QueueBufferSize sets the buffer size of the output channel of the 11 | // concurrent queue used by the BatchWriter. 12 | QueueBufferSize int 13 | 14 | // MaxBatch is the maximum number of filters to be persisted to the DB 15 | // in one go. 16 | MaxBatch int 17 | 18 | // DBWritesTickerDuration is the time after receiving a filter that the 19 | // writer will wait for more filters before writing the current batch 20 | // to the DB. 21 | DBWritesTickerDuration time.Duration 22 | 23 | // PutItems will be used by the BatchWriter to persist filters in 24 | // batches. 25 | PutItems func(...T) error 26 | } 27 | 28 | // BatchWriter manages writing Filters to the DB and tries to batch the writes 29 | // as much as possible. 30 | type BatchWriter[T any] struct { 31 | started sync.Once 32 | stopped sync.Once 33 | 34 | cfg *BatchWriterConfig[T] 35 | 36 | queue *ConcurrentQueue[T] 37 | 38 | quit chan struct{} 39 | wg sync.WaitGroup 40 | } 41 | 42 | // NewBatchWriter constructs a new BatchWriter using the given 43 | // BatchWriterConfig. 44 | func NewBatchWriter[T any](cfg *BatchWriterConfig[T]) *BatchWriter[T] { 45 | return &BatchWriter[T]{ 46 | cfg: cfg, 47 | queue: NewConcurrentQueue[T](cfg.QueueBufferSize), 48 | quit: make(chan struct{}), 49 | } 50 | } 51 | 52 | // Start starts the BatchWriter. 53 | func (b *BatchWriter[T]) Start() { 54 | b.started.Do(func() { 55 | b.queue.Start() 56 | 57 | b.wg.Add(1) 58 | go b.manageNewItems() 59 | }) 60 | } 61 | 62 | // Stop stops the BatchWriter. 63 | func (b *BatchWriter[T]) Stop() { 64 | b.stopped.Do(func() { 65 | close(b.quit) 66 | b.wg.Wait() 67 | 68 | b.queue.Stop() 69 | }) 70 | } 71 | 72 | // AddItem adds a given item to the BatchWriter queue. 73 | func (b *BatchWriter[T]) AddItem(item T) { 74 | b.queue.ChanIn() <- item 75 | } 76 | 77 | // manageNewItems manages collecting filters and persisting them to the DB. 78 | // There are two conditions for writing a batch of filters to the DB: the first 79 | // is if a certain threshold (MaxBatch) of filters has been collected and the 80 | // other is if at least one filter has been collected and a timeout has been 81 | // reached. 82 | // 83 | // NOTE: this must be run in a goroutine. 84 | func (b *BatchWriter[T]) manageNewItems() { 85 | defer b.wg.Done() 86 | 87 | batch := make([]T, 0, b.cfg.MaxBatch) 88 | 89 | // writeBatch writes the current contents of the batch slice to the 90 | // filters DB. 91 | writeBatch := func() { 92 | if len(batch) == 0 { 93 | return 94 | } 95 | 96 | err := b.cfg.PutItems(batch...) 
97 | if err != nil { 98 | log.Errorf("Could not write filters to filterDB: %v", 99 | err) 100 | } 101 | 102 | // Empty the batch slice. 103 | batch = make([]T, 0, b.cfg.MaxBatch) 104 | } 105 | 106 | ticker := time.NewTicker(b.cfg.DBWritesTickerDuration) 107 | defer ticker.Stop() 108 | 109 | // Stop the ticker since we don't want it to tick unless there is at 110 | // least one item in the queue. 111 | ticker.Stop() 112 | 113 | for { 114 | select { 115 | case filter, ok := <-b.queue.ChanOut(): 116 | if !ok { 117 | return 118 | } 119 | 120 | batch = append(batch, filter) 121 | 122 | switch len(batch) { 123 | // If the batch slice is full, we stop the ticker and 124 | // write the batch contents to disk. 125 | case b.cfg.MaxBatch: 126 | ticker.Stop() 127 | writeBatch() 128 | 129 | // If an item is added to the batch, we reset the timer. 130 | // This ensures that if the batch threshold is not met 131 | // then items are still persisted in a timely manner. 132 | default: 133 | ticker.Reset(b.cfg.DBWritesTickerDuration) 134 | } 135 | 136 | case <-ticker.C: 137 | // If the ticker ticks, then we stop it and write the 138 | // current batch contents to the db. If any more items 139 | // are added, the ticker will be reset. 140 | ticker.Stop() 141 | writeBatch() 142 | 143 | case <-b.quit: 144 | writeBatch() 145 | 146 | return 147 | } 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /chanutils/batch_writer_test.go: -------------------------------------------------------------------------------- 1 | package chanutils 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "sync" 7 | "testing" 8 | "time" 9 | 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | const waitTime = time.Second * 5 14 | 15 | // TestBatchWriter tests that the BatchWriter behaves as expected. 16 | func TestBatchWriter(t *testing.T) { 17 | t.Parallel() 18 | rand.Seed(time.Now().UnixNano()) 19 | 20 | // waitForItems is a helper function that will wait for a given set of 21 | // items to appear in the db. 22 | waitForItems := func(db *mockItemsDB, items ...*item) { 23 | err := waitFor(func() bool { 24 | return db.hasItems(items...) 25 | }, waitTime) 26 | require.NoError(t, err) 27 | } 28 | 29 | t.Run("filters persisted after ticker", func(t *testing.T) { 30 | t.Parallel() 31 | 32 | // Create a mock filters DB. 33 | db := newMockItemsDB() 34 | 35 | // Construct a new BatchWriter backed by the mock db. 36 | b := NewBatchWriter[*item](&BatchWriterConfig[*item]{ 37 | QueueBufferSize: 10, 38 | MaxBatch: 20, 39 | DBWritesTickerDuration: time.Millisecond * 500, 40 | PutItems: db.PutItems, 41 | }) 42 | b.Start() 43 | t.Cleanup(b.Stop) 44 | 45 | fs := genFilterSet(5) 46 | for _, f := range fs { 47 | b.AddItem(f) 48 | } 49 | waitForItems(db, fs...) 50 | }) 51 | 52 | t.Run("write once threshold is reached", func(t *testing.T) { 53 | t.Parallel() 54 | 55 | // Create a mock filters DB. 56 | db := newMockItemsDB() 57 | 58 | // Construct a new BatchWriter backed by the mock db. 59 | // Make the DB writes ticker duration extra long so that we 60 | // can explicitly test that the batch gets persisted if the 61 | // MaxBatch threshold is reached. 62 | b := NewBatchWriter[*item](&BatchWriterConfig[*item]{ 63 | QueueBufferSize: 10, 64 | MaxBatch: 20, 65 | DBWritesTickerDuration: time.Hour, 66 | PutItems: db.PutItems, 67 | }) 68 | b.Start() 69 | t.Cleanup(b.Stop) 70 | 71 | // Generate 30 filters and add each one to the batch writer. 
72 | fs := genFilterSet(30) 73 | for _, f := range fs { 74 | b.AddItem(f) 75 | } 76 | 77 | // Since the MaxBatch threshold has been reached, we expect the 78 | // first 20 filters to be persisted. 79 | waitForItems(db, fs[:20]...) 80 | 81 | // Since the last 10 filters don't reach the threshold and since 82 | // the ticker has definitely not ticked yet, we don't expect the 83 | // last 10 filters to be in the db yet. 84 | require.False(t, db.hasItems(fs[21:]...)) 85 | }) 86 | 87 | t.Run("stress test", func(t *testing.T) { 88 | t.Parallel() 89 | 90 | // Create a mock filters DB. 91 | db := newMockItemsDB() 92 | 93 | // Construct a new BatchWriter backed by the mock db. 94 | // Make the DB writes ticker duration extra long so that we 95 | // can explicitly test that the batch gets persisted if the 96 | // MaxBatch threshold is reached. 97 | b := NewBatchWriter[*item](&BatchWriterConfig[*item]{ 98 | QueueBufferSize: 5, 99 | MaxBatch: 5, 100 | DBWritesTickerDuration: time.Millisecond * 2, 101 | PutItems: db.PutItems, 102 | }) 103 | b.Start() 104 | t.Cleanup(b.Stop) 105 | 106 | // Generate lots of filters and add each to the batch writer. 107 | // Sleep for a bit between each filter to ensure that we 108 | // sometimes hit the timeout write and sometimes the threshold 109 | // write. 110 | fs := genFilterSet(1000) 111 | for _, f := range fs { 112 | b.AddItem(f) 113 | 114 | n := rand.Intn(3) 115 | time.Sleep(time.Duration(n) * time.Millisecond) 116 | } 117 | 118 | // Since the MaxBatch threshold has been reached, we expect the 119 | // first 20 filters to be persisted. 120 | waitForItems(db, fs...) 121 | }) 122 | } 123 | 124 | type item struct { 125 | i int 126 | } 127 | 128 | // mockItemsDB is a mock DB that holds a set of items. 129 | type mockItemsDB struct { 130 | items map[int]bool 131 | mu sync.Mutex 132 | } 133 | 134 | // newMockItemsDB constructs a new mockItemsDB. 135 | func newMockItemsDB() *mockItemsDB { 136 | return &mockItemsDB{ 137 | items: make(map[int]bool), 138 | } 139 | } 140 | 141 | // hasItems returns true if the db contains all the given items. 142 | func (m *mockItemsDB) hasItems(items ...*item) bool { 143 | m.mu.Lock() 144 | defer m.mu.Unlock() 145 | 146 | for _, i := range items { 147 | _, ok := m.items[i.i] 148 | if !ok { 149 | return false 150 | } 151 | } 152 | 153 | return true 154 | } 155 | 156 | // PutItems adds a set of items to the db. 157 | func (m *mockItemsDB) PutItems(items ...*item) error { 158 | m.mu.Lock() 159 | defer m.mu.Unlock() 160 | 161 | for _, i := range items { 162 | m.items[i.i] = true 163 | } 164 | 165 | return nil 166 | } 167 | 168 | // genItemSet generates a set of numFilters items. 169 | func genFilterSet(numFilters int) []*item { 170 | res := make([]*item, numFilters) 171 | for i := 0; i < numFilters; i++ { 172 | res[i] = &item{i: i} 173 | } 174 | 175 | return res 176 | } 177 | 178 | // pollInterval is a constant specifying a 200 ms interval. 179 | const pollInterval = 200 * time.Millisecond 180 | 181 | // waitFor is a helper test function that will wait for a timeout period of 182 | // time until the passed predicate returns true. This function is helpful as 183 | // timing doesn't always line up well when running integration tests with 184 | // several running lnd nodes. This function gives callers a way to assert that 185 | // some property is upheld within a particular time frame. 
186 | func waitFor(pred func() bool, timeout time.Duration) error { 187 | exitTimer := time.After(timeout) 188 | result := make(chan bool, 1) 189 | 190 | for { 191 | <-time.After(pollInterval) 192 | 193 | go func() { 194 | result <- pred() 195 | }() 196 | 197 | // Each time we call the pred(), we expect a result to be 198 | // returned otherwise it will timeout. 199 | select { 200 | case <-exitTimer: 201 | return fmt.Errorf("predicate not satisfied after " + 202 | "time out") 203 | 204 | case succeed := <-result: 205 | if succeed { 206 | return nil 207 | } 208 | } 209 | } 210 | } 211 | -------------------------------------------------------------------------------- /chanutils/log.go: -------------------------------------------------------------------------------- 1 | package chanutils 2 | 3 | import "github.com/btcsuite/btclog" 4 | 5 | // log is a logger that is initialized with no output filters. This 6 | // means the package will not perform any logging by default until the caller 7 | // requests it. 8 | var log btclog.Logger 9 | 10 | // The default amount of logging is none. 11 | func init() { 12 | DisableLog() 13 | } 14 | 15 | // DisableLog disables all library log output. Logging output is disabled 16 | // by default until either UseLogger or SetLogWriter are called. 17 | func DisableLog() { 18 | UseLogger(btclog.Disabled) 19 | } 20 | 21 | // UseLogger uses a specified Logger to output package logging info. 22 | // This should be used in preference to SetLogWriter if the caller is also 23 | // using btclog. 24 | func UseLogger(logger btclog.Logger) { 25 | log = logger 26 | } 27 | -------------------------------------------------------------------------------- /chanutils/queue.go: -------------------------------------------------------------------------------- 1 | package chanutils 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/lightninglabs/neutrino/cache/lru" 7 | ) 8 | 9 | const ( 10 | // DefaultQueueSize is the default size to use for concurrent queues. 11 | DefaultQueueSize = 10 12 | ) 13 | 14 | // ConcurrentQueue is a typed concurrent-safe FIFO queue with unbounded 15 | // capacity. Clients interact with the queue by pushing items into the in 16 | // channel and popping items from the out channel. There is a goroutine that 17 | // manages moving items from the in channel to the out channel in the correct 18 | // order that must be started by calling Start(). 19 | type ConcurrentQueue[T any] struct { 20 | started sync.Once 21 | stopped sync.Once 22 | 23 | chanIn chan T 24 | chanOut chan T 25 | overflow *lru.List[T] 26 | 27 | wg sync.WaitGroup 28 | quit chan struct{} 29 | } 30 | 31 | // NewConcurrentQueue constructs a ConcurrentQueue. The bufferSize parameter is 32 | // the capacity of the output channel. When the size of the queue is below this 33 | // threshold, pushes do not incur the overhead of the less efficient overflow 34 | // structure. 35 | func NewConcurrentQueue[T any](bufferSize int) *ConcurrentQueue[T] { 36 | return &ConcurrentQueue[T]{ 37 | chanIn: make(chan T), 38 | chanOut: make(chan T, bufferSize), 39 | overflow: lru.NewList[T](), 40 | quit: make(chan struct{}), 41 | } 42 | } 43 | 44 | // ChanIn returns a channel that can be used to push new items into the queue. 45 | func (cq *ConcurrentQueue[T]) ChanIn() chan<- T { 46 | return cq.chanIn 47 | } 48 | 49 | // ChanOut returns a channel that can be used to pop items from the queue. 
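//
// A minimal producer/consumer sketch using the queue's own API:
//
//	q := NewConcurrentQueue[int](DefaultQueueSize)
//	q.Start()
//	defer q.Stop()
//
//	q.ChanIn() <- 7
//	item := <-q.ChanOut()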
50 | func (cq *ConcurrentQueue[T]) ChanOut() <-chan T {
51 | return cq.chanOut
52 | }
53 |
54 | // Start begins a goroutine that manages moving items from the in channel to the
55 | // out channel. The queue tries to move items directly to the out channel to
56 | // minimize overhead, but if the out channel is full it pushes items to an
57 | // overflow queue. This must be called before using the queue.
58 | func (cq *ConcurrentQueue[T]) Start() {
59 | cq.started.Do(cq.start)
60 | }
61 |
62 | func (cq *ConcurrentQueue[T]) start() {
63 | cq.wg.Add(1)
64 | go func() {
65 | defer cq.wg.Done()
66 |
67 | readLoop:
68 | for {
69 | nextElement := cq.overflow.Front()
70 | if nextElement == nil {
71 | // Overflow queue is empty so incoming items can
72 | // be pushed directly to the output channel. If
73 | // output channel is full though, push to
74 | // overflow.
75 | select {
76 | case item, ok := <-cq.chanIn:
77 | if !ok {
78 | log.Warnf("ConcurrentQueue " +
79 | "has exited due to " +
80 | "the input channel " +
81 | "being closed")
82 |
83 | break readLoop
84 | }
85 | select {
86 | case cq.chanOut <- item:
87 | // Optimistically push directly
88 | // to chanOut.
89 | default:
90 | cq.overflow.PushBack(item)
91 | }
92 | case <-cq.quit:
93 | return
94 | }
95 | } else {
96 | // Overflow queue is not empty, so any new items
97 | // get pushed to the back to preserve order.
98 | select {
99 | case item, ok := <-cq.chanIn:
100 | if !ok {
101 | log.Warnf("ConcurrentQueue " +
102 | "has exited due to " +
103 | "the input channel " +
104 | "being closed")
105 |
106 | break readLoop
107 | }
108 | cq.overflow.PushBack(item)
109 | case cq.chanOut <- nextElement.Value:
110 | cq.overflow.Remove(nextElement)
111 | case <-cq.quit:
112 | return
113 | }
114 | }
115 | }
116 |
117 | // Incoming channel has been closed. Empty overflow queue into
118 | // the outgoing channel.
119 | nextElement := cq.overflow.Front()
120 | for nextElement != nil {
121 | select {
122 | case cq.chanOut <- nextElement.Value:
123 | cq.overflow.Remove(nextElement)
124 | case <-cq.quit:
125 | return
126 | }
127 | nextElement = cq.overflow.Front()
128 | }
129 |
130 | // Close outgoing channel.
131 | close(cq.chanOut)
132 | }()
133 | }
134 |
135 | // Stop ends the goroutine that moves items from the in channel to the out
136 | // channel. This does not clear the queue state, so the queue can be restarted
137 | // without dropping items.
138 | func (cq *ConcurrentQueue[T]) Stop() {
139 | cq.stopped.Do(func() {
140 | close(cq.quit)
141 | cq.wg.Wait()
142 | })
143 | }
144 |
-------------------------------------------------------------------------------- /errors.go: --------------------------------------------------------------------------------
1 | package neutrino
2 |
3 | import "errors"
4 |
5 | var (
6 | // ErrGetUtxoCancelled signals that a GetUtxo request was cancelled.
7 | ErrGetUtxoCancelled = errors.New("get utxo request cancelled")
8 |
9 | // ErrShuttingDown signals that neutrino received a shutdown request.
10 | ErrShuttingDown = errors.New("neutrino shutting down")
11 | )
12 |
-------------------------------------------------------------------------------- /filterdb/db.go: --------------------------------------------------------------------------------
1 | package filterdb
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/btcsuite/btcd/btcutil/gcs"
7 | "github.com/btcsuite/btcd/btcutil/gcs/builder"
8 | "github.com/btcsuite/btcd/chaincfg"
9 | "github.com/btcsuite/btcd/chaincfg/chainhash"
10 | "github.com/btcsuite/btcwallet/walletdb"
11 | )
12 |
13 | var (
14 | // filterBucket is the name of the root bucket for this package. Within
15 | // this bucket, sub-buckets are stored which themselves store the
16 | // actual filters.
17 | filterBucket = []byte("filter-store")
18 |
19 | // regBucket is the bucket that stores the regular filters.
20 | regBucket = []byte("regular")
21 |
22 | // ErrFilterNotFound is returned when a filter for a target block hash
23 | // is unable to be located.
24 | ErrFilterNotFound = fmt.Errorf("unable to find filter")
25 | )
26 |
27 | // FilterType is an enum-like type that represents the various filter types
28 | // currently defined.
29 | type FilterType uint8
30 |
31 | const (
32 | // RegularFilter is the filter type of regular filters which contain
33 | // outputs and pkScript data pushes.
34 | RegularFilter FilterType = iota
35 | )
36 |
37 | // FilterData holds all the info about a filter required to store it.
38 | type FilterData struct {
39 | // Filter is the actual filter to be stored.
40 | Filter *gcs.Filter
41 |
42 | // BlockHash is the block header hash of the block associated with the
43 | // Filter.
44 | BlockHash *chainhash.Hash
45 |
46 | // Type is the filter type.
47 | Type FilterType
48 | }
49 |
50 | // FilterDatabase is an interface which represents an object that is capable of
51 | // storing and retrieving filters according to their corresponding block hash
52 | // and also their filter type.
53 | //
54 | // TODO(roasbeef): similar interface for headerfs?
55 | type FilterDatabase interface {
56 | // PutFilters stores a set of filters to persistent storage.
57 | PutFilters(...*FilterData) error
58 |
59 | // FetchFilter attempts to fetch a filter with the given hash and type
60 | // from persistent storage. In the case that a filter matching the
61 | // target block hash cannot be found, then ErrFilterNotFound is to be
62 | // returned.
63 | FetchFilter(*chainhash.Hash, FilterType) (*gcs.Filter, error)
64 |
65 | // PurgeFilters purges all filters with a given type from persistent
66 | // storage.
67 | PurgeFilters(FilterType) error
68 | }
69 |
70 | // FilterStore is an implementation of the FilterDatabase interface which is
71 | // backed by boltdb.
72 | type FilterStore struct {
73 | db walletdb.DB
74 | }
75 |
76 | // A compile-time check to ensure the FilterStore adheres to the FilterDatabase
77 | // interface.
78 | var _ FilterDatabase = (*FilterStore)(nil)
79 |
80 | // New creates a new instance of the FilterStore given an already open
81 | // database, and the target chain parameters.
82 | func New(db walletdb.DB, params chaincfg.Params) (*FilterStore, error) {
83 | err := walletdb.Update(db, func(tx walletdb.ReadWriteTx) error {
84 | // As part of our initial setup, we'll try to create the top
85 | // level filter bucket. If this already exists, then we can
86 | // exit early.
87 | filters, err := tx.CreateTopLevelBucket(filterBucket)
88 | if err != nil {
89 | return err
90 | }
91 |
92 | // If the main bucket doesn't already exist, then we'll need to
93 | // create the sub-buckets, and also initialize them with the
94 | // genesis filters.
95 | genesisBlock := params.GenesisBlock
96 | genesisHash := params.GenesisHash
97 |
98 | // First we'll create the bucket for the regular filters.
99 | regFilters, err := filters.CreateBucketIfNotExists(regBucket)
100 | if err != nil {
101 | return err
102 | }
103 |
104 | // With the bucket created, we'll now construct the initial
105 | // basic genesis filter and store it within the database.
106 | basicFilter, err := builder.BuildBasicFilter(genesisBlock, nil)
107 | if err != nil {
108 | return err
109 | }
110 |
111 | return putFilter(regFilters, genesisHash, basicFilter)
112 | })
113 | if err != nil && err != walletdb.ErrBucketExists {
114 | return nil, err
115 | }
116 |
117 | return &FilterStore{
118 | db: db,
119 | }, nil
120 | }
121 |
122 | // PurgeFilters purges all filters with a given type from persistent storage.
123 | //
124 | // NOTE: This method is a part of the FilterDatabase interface.
125 | func (f *FilterStore) PurgeFilters(fType FilterType) error {
126 | return walletdb.Update(f.db, func(tx walletdb.ReadWriteTx) error {
127 | filters := tx.ReadWriteBucket(filterBucket)
128 |
129 | switch fType {
130 | case RegularFilter:
131 | err := filters.DeleteNestedBucket(regBucket)
132 | if err != nil {
133 | return err
134 | }
135 |
136 | _, err = filters.CreateBucket(regBucket)
137 | if err != nil {
138 | return err
139 | }
140 | default:
141 | return fmt.Errorf("unknown filter type: %v", fType)
142 | }
143 |
144 | return nil
145 | })
146 | }
147 |
148 | // putFilter stores a filter in the database according to the corresponding
149 | // block hash. The passed bucket is expected to be the proper bucket for the
150 | // passed filter type.
151 | func putFilter(bucket walletdb.ReadWriteBucket, hash *chainhash.Hash,
152 | filter *gcs.Filter) error {
153 |
154 | if filter == nil {
155 | return bucket.Put(hash[:], nil)
156 | }
157 |
158 | bytes, err := filter.NBytes()
159 | if err != nil {
160 | return err
161 | }
162 |
163 | return bucket.Put(hash[:], bytes)
164 | }
165 |
166 | // PutFilters stores a set of filters to persistent storage.
167 | //
168 | // NOTE: This method is a part of the FilterDatabase interface.
169 | func (f *FilterStore) PutFilters(filterList ...*FilterData) error {
170 | var updateErr error
171 | err := walletdb.Batch(f.db, func(tx walletdb.ReadWriteTx) error {
172 | filters := tx.ReadWriteBucket(filterBucket)
173 | regularFilterBkt := filters.NestedReadWriteBucket(regBucket)
174 |
175 | for _, filterData := range filterList {
176 | var targetBucket walletdb.ReadWriteBucket
177 | switch filterData.Type {
178 | case RegularFilter:
179 | targetBucket = regularFilterBkt
180 | default:
181 | updateErr = fmt.Errorf("unknown filter "+
182 | "type: %v", filterData.Type)
183 |
184 | return nil
185 | }
186 |
187 | err := putFilter(
188 | targetBucket, filterData.BlockHash,
189 | filterData.Filter,
190 | )
191 | if err != nil {
192 | return err
193 | }
194 |
195 | log.Tracef("Wrote filter for block %s, type %d",
196 | filterData.BlockHash, filterData.Type)
197 | }
198 |
199 | return nil
200 | })
201 | if err != nil {
202 | return err
203 | }
204 |
205 | return updateErr
206 | }
207 |
208 | // FetchFilter attempts to fetch a filter with the given hash and type from
209 | // persistent storage.
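// In the case that no matching filter exists, ErrFilterNotFound is
// returned. A caller sketch (store here is any FilterDatabase):
//
//	filter, err := store.FetchFilter(&blockHash, RegularFilter)
//	if err == ErrFilterNotFound {
//		// Fall back to fetching the filter from the network.
//	}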
210 | //
211 | // NOTE: This method is a part of the FilterDatabase interface.
212 | func (f *FilterStore) FetchFilter(blockHash *chainhash.Hash,
213 | filterType FilterType) (*gcs.Filter, error) {
214 |
215 | var filter *gcs.Filter
216 |
217 | err := walletdb.View(f.db, func(tx walletdb.ReadTx) error {
218 | filters := tx.ReadBucket(filterBucket)
219 |
220 | var targetBucket walletdb.ReadBucket
221 | switch filterType {
222 | case RegularFilter:
223 | targetBucket = filters.NestedReadBucket(regBucket)
224 | default:
225 | return fmt.Errorf("unknown filter type: %v", filterType)
226 | }
227 |
228 | filterBytes := targetBucket.Get(blockHash[:])
229 | if filterBytes == nil {
230 | return ErrFilterNotFound
231 | }
232 | if len(filterBytes) == 0 {
233 | return nil
234 | }
235 |
236 | dbFilter, err := gcs.FromNBytes(
237 | builder.DefaultP, builder.DefaultM, filterBytes,
238 | )
239 | if err != nil {
240 | return err
241 | }
242 |
243 | filter = dbFilter
244 | return nil
245 | })
246 | if err != nil {
247 | return nil, err
248 | }
249 |
250 | return filter, nil
251 | }
252 |
-------------------------------------------------------------------------------- /filterdb/db_test.go: --------------------------------------------------------------------------------
1 | package filterdb
2 |
3 | import (
4 | "math/rand"
5 | "testing"
6 | "time"
7 |
8 | "github.com/btcsuite/btcd/btcutil/gcs"
9 | "github.com/btcsuite/btcd/btcutil/gcs/builder"
10 | "github.com/btcsuite/btcd/chaincfg"
11 | "github.com/btcsuite/btcd/chaincfg/chainhash"
12 | "github.com/btcsuite/btcwallet/walletdb"
13 | _ "github.com/btcsuite/btcwallet/walletdb/bdb"
14 | "github.com/stretchr/testify/require"
15 | )
16 |
17 | func createTestDatabase(t *testing.T) FilterDatabase {
18 | tempDir := t.TempDir()
19 |
20 | db, err := walletdb.Create(
21 | "bdb", tempDir+"/test.db", true, time.Second*10,
22 | )
23 | require.NoError(t, err)
24 | t.Cleanup(func() {
25 | require.NoError(t, db.Close())
26 | })
27 |
28 | filterDB, err := New(db, chaincfg.SimNetParams)
29 | require.NoError(t, err)
30 |
31 | return filterDB
32 | }
33 |
34 | // TestGenesisFilterCreation tests the fetching of the genesis block filter.
35 | func TestGenesisFilterCreation(t *testing.T) {
36 | var (
37 | database = createTestDatabase(t)
38 | genesisHash = chaincfg.SimNetParams.GenesisHash
39 | )
40 |
41 | // With the database initialized, we should be able to fetch the
42 | // regular filter for the genesis block.
43 | regGenesisFilter, err := database.FetchFilter(
44 | genesisHash, RegularFilter,
45 | )
46 | require.NoError(t, err)
47 |
48 | // The regular filter should be non-nil as the genesis block's output
49 | // and the coinbase txid should be indexed.
50 | require.NotNil(t, regGenesisFilter)
51 | }
52 |
53 | func genRandFilter(t *testing.T, numElements uint32) *gcs.Filter {
54 | elements := make([][]byte, numElements)
55 | for i := uint32(0); i < numElements; i++ {
56 | var elem [20]byte
57 | _, err := rand.Read(elem[:])
58 | require.NoError(t, err)
59 |
60 | elements[i] = elem[:]
61 | }
62 |
63 | var key [16]byte
64 | _, err := rand.Read(key[:])
65 | require.NoError(t, err)
66 |
67 | filter, err := gcs.BuildGCSFilter(
68 | builder.DefaultP, builder.DefaultM, key, elements,
69 | )
70 | require.NoError(t, err)
71 |
72 | return filter
73 | }
74 |
75 | // TestFilterStorage tests writing to and reading from the filter DB.
76 | func TestFilterStorage(t *testing.T) {
77 | database := createTestDatabase(t)
78 |
79 | // We'll generate a random block hash to create our test filters
80 | // against.
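// (Filters are keyed purely by their 32-byte block hash, so any random
// value serves as a stand-in for a real hash here.)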
81 | var randHash chainhash.Hash
82 | _, err := rand.Read(randHash[:])
83 | require.NoError(t, err)
84 |
85 | // First, we'll create and store a random filter for the regular filter
86 | // type for the block hash generated above.
87 | regFilter := genRandFilter(t, 100)
88 |
89 | filter := &FilterData{
90 | Filter: regFilter,
91 | BlockHash: &randHash,
92 | Type: RegularFilter,
93 | }
94 |
95 | err = database.PutFilters(filter)
96 | require.NoError(t, err)
97 |
98 | // With the filter stored, we should be able to retrieve the filter
99 | // without any issue, and it should match the stored filter exactly.
100 | regFilterDB, err := database.FetchFilter(&randHash, RegularFilter)
101 | require.NoError(t, err)
102 | require.Equal(t, regFilter, regFilterDB)
103 | }
104 |
-------------------------------------------------------------------------------- /filterdb/log.go: --------------------------------------------------------------------------------
1 | package filterdb
2 |
3 | import "github.com/btcsuite/btclog"
4 |
5 | // log is a logger that is initialized with no output filters. This
6 | // means the package will not perform any logging by default until the caller
7 | // requests it.
8 | var log btclog.Logger
9 |
10 | // The default amount of logging is none.
11 | func init() {
12 | DisableLog()
13 | }
14 |
15 | // DisableLog disables all library log output. Logging output is disabled
16 | // by default until either UseLogger or SetLogWriter are called.
17 | func DisableLog() {
18 | UseLogger(btclog.Disabled)
19 | }
20 |
21 | // UseLogger uses a specified Logger to output package logging info.
22 | // This should be used in preference to SetLogWriter if the caller is also
23 | // using btclog.
24 | func UseLogger(logger btclog.Logger) {
25 | log = logger
26 | }
27 |
-------------------------------------------------------------------------------- /go.mod: --------------------------------------------------------------------------------
1 | module github.com/lightninglabs/neutrino
2 |
3 | require (
4 | github.com/btcsuite/btcd v0.24.3-0.20250318170759-4f4ea81776d6
5 | github.com/btcsuite/btcd/btcec/v2 v2.3.4
6 | github.com/btcsuite/btcd/btcutil v1.1.5
7 | github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0
8 | github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f
9 | github.com/btcsuite/btcwallet/wallet/txauthor v1.2.3
10 | github.com/btcsuite/btcwallet/walletdb v1.3.5
11 | github.com/btcsuite/btcwallet/wtxmgr v1.5.0
12 | github.com/davecgh/go-spew v1.1.1
13 | github.com/lightninglabs/neutrino/cache v1.1.2
14 | github.com/lightningnetwork/lnd/queue v1.0.1
15 | github.com/stretchr/testify v1.9.0
16 | )
17 |
18 | require (
19 | github.com/aead/siphash v1.0.1 // indirect
20 | github.com/btcsuite/btcwallet/wallet/txrules v1.2.0 // indirect
21 | github.com/btcsuite/btcwallet/wallet/txsizes v1.1.0 // indirect
22 | github.com/btcsuite/go-socks v0.0.0-20170105172521-4720035b7bfd // indirect
23 | github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792 // indirect
24 | github.com/decred/dcrd/crypto/blake256 v1.0.1 // indirect
25 | github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect
26 | github.com/decred/dcrd/lru v1.1.2 // indirect
27 | github.com/kkdai/bstream v1.0.0 // indirect
28 | github.com/kr/pretty v0.3.0 // indirect
29 | github.com/lightningnetwork/lnd/clock v1.0.1 // indirect
30 | github.com/lightningnetwork/lnd/ticker v1.0.0 // indirect
31 | github.com/pmezard/go-difflib v1.0.0 // indirect
32 | github.com/rogpeppe/go-internal v1.12.0 // indirect
33 |
github.com/stretchr/objx v0.5.2 // indirect 34 | go.etcd.io/bbolt v1.3.7 // indirect 35 | golang.org/x/crypto v0.22.0 // indirect 36 | golang.org/x/sys v0.19.0 // indirect 37 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect 38 | gopkg.in/yaml.v3 v3.0.1 // indirect 39 | ) 40 | 41 | go 1.23.6 42 | -------------------------------------------------------------------------------- /headerfs/file.go: -------------------------------------------------------------------------------- 1 | package headerfs 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "os" 7 | 8 | "github.com/btcsuite/btcd/chaincfg/chainhash" 9 | "github.com/btcsuite/btcd/wire" 10 | ) 11 | 12 | // ErrHeaderNotFound is returned when a target header on disk (flat file) can't 13 | // be found. 14 | type ErrHeaderNotFound struct { 15 | error 16 | } 17 | 18 | // appendRaw appends a new raw header to the end of the flat file. 19 | func (h *headerStore) appendRaw(header []byte) error { 20 | if _, err := h.file.Write(header); err != nil { 21 | return err 22 | } 23 | 24 | return nil 25 | } 26 | 27 | // readRaw reads a raw header from disk from a particular seek distance. The 28 | // amount of bytes read past the seek distance is determined by the specified 29 | // header type. 30 | func (h *headerStore) readRaw(seekDist uint64) ([]byte, error) { 31 | var headerSize uint32 32 | 33 | // Based on the defined header type, we'll determine the number of 34 | // bytes that we need to read past the sync point. 35 | switch h.indexType { 36 | case Block: 37 | headerSize = 80 38 | 39 | case RegularFilter: 40 | headerSize = 32 41 | 42 | default: 43 | return nil, fmt.Errorf("unknown index type: %v", h.indexType) 44 | } 45 | 46 | // TODO(roasbeef): add buffer pool 47 | 48 | // With the number of bytes to read determined, we'll create a slice 49 | // for that number of bytes, and read directly from the file into the 50 | // buffer. 51 | rawHeader := make([]byte, headerSize) 52 | if _, err := h.file.ReadAt(rawHeader, int64(seekDist)); err != nil { 53 | return nil, &ErrHeaderNotFound{err} 54 | } 55 | 56 | return rawHeader, nil 57 | } 58 | 59 | // readHeaderRange will attempt to fetch a series of block headers within the 60 | // target height range. This method batches a set of reads into a single system 61 | // call thereby increasing performance when reading a set of contiguous 62 | // headers. 63 | // 64 | // NOTE: The end height is _inclusive_ so we'll fetch all headers from the 65 | // startHeight up to the end height, including the final header. 66 | func (h *blockHeaderStore) readHeaderRange(startHeight uint32, 67 | endHeight uint32) ([]wire.BlockHeader, error) { 68 | 69 | // Based on the defined header type, we'll determine the number of 70 | // bytes that we need to read from the file. 71 | headerReader, err := readHeadersFromFile( 72 | h.file, BlockHeaderSize, startHeight, endHeight, 73 | ) 74 | if err != nil { 75 | return nil, err 76 | } 77 | 78 | // We'll now incrementally parse out the set of individual headers from 79 | // our set of serialized contiguous raw headers. 80 | numHeaders := endHeight - startHeight + 1 81 | headers := make([]wire.BlockHeader, 0, numHeaders) 82 | for headerReader.Len() != 0 { 83 | var nextHeader wire.BlockHeader 84 | if err := nextHeader.Deserialize(headerReader); err != nil { 85 | return nil, err 86 | } 87 | 88 | headers = append(headers, nextHeader) 89 | } 90 | 91 | return headers, nil 92 | } 93 | 94 | // readHeader reads a full block header from the flat-file. 
The header read is
95 | // determined by the height value.
96 | func (h *blockHeaderStore) readHeader(height uint32) (wire.BlockHeader, error) {
97 | var header wire.BlockHeader
98 |
99 | // Each header is 80 bytes, so using this information, we'll seek a
100 | // distance to cover that height based on the size of block headers.
101 | seekDistance := uint64(height) * 80
102 |
103 | // With the distance calculated, we'll read a raw header starting at
104 | // that offset.
105 | rawHeader, err := h.readRaw(seekDistance)
106 | if err != nil {
107 | return header, err
108 | }
109 | headerReader := bytes.NewReader(rawHeader)
110 |
111 | // Finally, decode the raw bytes into a proper bitcoin header.
112 | if err := header.Deserialize(headerReader); err != nil {
113 | return header, err
114 | }
115 |
116 | return header, nil
117 | }
118 |
119 | // readHeader reads a single filter header at the specified height from the
120 | // flat files on disk.
121 | func (f *FilterHeaderStore) readHeader(height uint32) (*chainhash.Hash, error) {
122 | seekDistance := uint64(height) * 32
123 |
124 | rawHeader, err := f.readRaw(seekDistance)
125 | if err != nil {
126 | return nil, err
127 | }
128 |
129 | return chainhash.NewHash(rawHeader)
130 | }
131 |
132 | // readHeaderRange will attempt to fetch a series of filter headers within the
133 | // target height range. This method batches a set of reads into a single system
134 | // call thereby increasing performance when reading a set of contiguous
135 | // headers.
136 | //
137 | // NOTE: The end height is _inclusive_ so we'll fetch all headers from the
138 | // startHeight up to the end height, including the final header.
139 | func (f *FilterHeaderStore) readHeaderRange(startHeight uint32,
140 | endHeight uint32) ([]chainhash.Hash, error) {
141 |
142 | // Based on the defined header type, we'll determine the number of
143 | // bytes that we need to read from the file.
144 | headerReader, err := readHeadersFromFile(
145 | f.file, RegularFilterHeaderSize, startHeight, endHeight,
146 | )
147 | if err != nil {
148 | return nil, err
149 | }
150 |
151 | // We'll now incrementally parse out the set of individual headers from
152 | // our set of serialized contiguous raw headers.
153 | numHeaders := endHeight - startHeight + 1
154 | headers := make([]chainhash.Hash, 0, numHeaders)
155 | for headerReader.Len() != 0 {
156 | var nextHeader chainhash.Hash
157 | if _, err := headerReader.Read(nextHeader[:]); err != nil {
158 | return nil, err
159 | }
160 |
161 | headers = append(headers, nextHeader)
162 | }
163 |
164 | return headers, nil
165 | }
166 |
167 | // readHeadersFromFile reads a chunk of headers, each of size headerSize, from
168 | // the given file, from startHeight to endHeight.
169 | func readHeadersFromFile(f *os.File, headerSize, startHeight,
170 | endHeight uint32) (*bytes.Reader, error) {
171 |
172 | // Each header is headerSize bytes, so using this information, we'll
173 | // seek a distance to cover that height based on the size of the headers.
174 | seekDistance := uint64(startHeight) * uint64(headerSize)
175 |
176 | // Based on the number of headers in the range, we'll allocate a single
177 | // slice that's able to hold the entire range of headers.
178 | numHeaders := endHeight - startHeight + 1
179 | rawHeaderBytes := make([]byte, headerSize*numHeaders)
180 |
181 | // Now that we have our slice allocated, we'll read out the entire
182 | // range of headers with a single system call.
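// Note that ReadAt returns io.EOF whenever it reads fewer than
// len(rawHeaderBytes) bytes, so a range extending past the end of the
// file surfaces as an error here rather than as short data.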
183 | _, err := f.ReadAt(rawHeaderBytes, int64(seekDistance)) 184 | if err != nil { 185 | return nil, err 186 | } 187 | 188 | return bytes.NewReader(rawHeaderBytes), nil 189 | } 190 | -------------------------------------------------------------------------------- /headerfs/truncate.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | // +build !windows 3 | 4 | package headerfs 5 | 6 | import "fmt" 7 | 8 | // singleTruncate truncates a single header from the end of the header file. 9 | // This can be used in the case of a re-org to remove the last header from the 10 | // end of the main chain. 11 | // 12 | // TODO(roasbeef): define this and the two methods above on a headerFile 13 | // struct? 14 | func (h *headerStore) singleTruncate() error { 15 | // In order to truncate the file, we'll need to grab the absolute size 16 | // of the file as it stands currently. 17 | fileInfo, err := h.file.Stat() 18 | if err != nil { 19 | return err 20 | } 21 | fileSize := fileInfo.Size() 22 | 23 | // Next, we'll determine the number of bytes we need to truncate from 24 | // the end of the file. 25 | var truncateLength int64 26 | switch h.indexType { 27 | case Block: 28 | truncateLength = 80 29 | case RegularFilter: 30 | truncateLength = 32 31 | default: 32 | return fmt.Errorf("unknown index type: %v", h.indexType) 33 | } 34 | 35 | // Finally, we'll use both of these values to calculate the new size of 36 | // the file and truncate it accordingly. 37 | newSize := fileSize - truncateLength 38 | return h.file.Truncate(newSize) 39 | } 40 | -------------------------------------------------------------------------------- /headerfs/truncate_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | // +build windows 3 | 4 | package headerfs 5 | 6 | import ( 7 | "fmt" 8 | "os" 9 | ) 10 | 11 | // singleTruncate truncates a single header from the end of the header file. 12 | // This can be used in the case of a re-org to remove the last header from the 13 | // end of the main chain. 14 | // 15 | // TODO(roasbeef): define this and the two methods above on a headerFile 16 | // struct? 17 | func (h *headerStore) singleTruncate() error { 18 | // In order to truncate the file, we'll need to grab the absolute size 19 | // of the file as it stands currently. 20 | fileInfo, err := h.file.Stat() 21 | if err != nil { 22 | return err 23 | } 24 | fileSize := fileInfo.Size() 25 | 26 | // Next, we'll determine the number of bytes we need to truncate from 27 | // the end of the file. 28 | var truncateLength int64 29 | switch h.indexType { 30 | case Block: 31 | truncateLength = 80 32 | case RegularFilter: 33 | truncateLength = 32 34 | default: 35 | return fmt.Errorf("unknown index type: %v", h.indexType) 36 | } 37 | 38 | // Finally, we'll use both of these values to calculate the new size of 39 | // the file. 40 | newSize := fileSize - truncateLength 41 | 42 | // On Windows, a file can't be truncated while open, even if using a 43 | // file handle to truncate it. This means we have to close, truncate, 44 | // and reopen it. 
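// The file is reopened below with O_APPEND so that subsequent header
// writes continue at the new, truncated end of the file.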
45 | fileName := h.file.Name()
46 | if err = h.file.Close(); err != nil {
47 | return err
48 | }
49 |
50 | if err = os.Truncate(fileName, newSize); err != nil {
51 | return err
52 | }
53 |
54 | fileFlags := os.O_RDWR | os.O_APPEND | os.O_CREATE
55 | h.file, err = os.OpenFile(fileName, fileFlags, 0644)
56 | return err
57 | }
58 |
-------------------------------------------------------------------------------- /headerlist/bounded_header_list.go: --------------------------------------------------------------------------------
1 | package headerlist
2 |
3 | // BoundedMemoryChain is an implementation of the headerlist.Chain interface
4 | // which has a bounded size. The chain will be stored purely in memory. This is
5 | // useful for enforcing that only the past N headers are stored in memory, or
6 | // even as the primary header store. If an element inserted to the end of the
7 | // chain exceeds the size limit, then the head of the chain will be moved
8 | // forward removing a single entry from the head of the chain.
9 | type BoundedMemoryChain struct {
10 | // headPtr points to the "front" of the chain. If the tailPtr is less
11 | // than this value, then we've wrapped around once. This value can
12 | // never exceed the maxSize value.
13 | headPtr int32
14 |
15 | // tailPtr points to the "tail" of the chain. This indexes into the
16 | // main chain slice which stores each node. This value can never exceed
17 | // the maxSize value.
18 | tailPtr int32
19 |
20 | // len is the length of the chain. This will be incremented for each
21 | // item inserted. This value can never exceed the maxSize value.
22 | len int32
23 |
24 | // maxSize is the max number of elements that should be kept in the
25 | // BoundedMemoryChain. Once we exceed this size, we'll start to wrap
26 | // the chain around.
27 | maxSize int32
28 |
29 | // chain is the primary store of the chain.
30 | chain []Node
31 | }
32 |
33 | // NewBoundedMemoryChain returns a new instance of the BoundedMemoryChain with
34 | // a target max number of nodes.
35 | func NewBoundedMemoryChain(maxNodes uint32) *BoundedMemoryChain {
36 | return &BoundedMemoryChain{
37 | headPtr: -1,
38 | tailPtr: -1,
39 | maxSize: int32(maxNodes),
40 | chain: make([]Node, maxNodes),
41 | }
42 | }
43 |
44 | // A compile-time check to ensure that BoundedMemoryChain meets the Chain
45 | // interface.
46 | var _ Chain = (*BoundedMemoryChain)(nil)
47 |
48 | // ResetHeaderState resets the state of all nodes. After this method, it will
49 | // be as if the chain was just newly created.
50 | //
51 | // NOTE: Part of the Chain interface.
52 | func (b *BoundedMemoryChain) ResetHeaderState(n Node) {
53 | b.headPtr = -1
54 | b.tailPtr = -1
55 | b.len = 0
56 |
57 | b.PushBack(n)
58 | }
59 |
60 | // Back returns the end of the chain. If the chain is empty, then this returns
61 | // nil.
62 | //
63 | // NOTE: Part of the Chain interface.
64 | func (b *BoundedMemoryChain) Back() *Node {
65 | if b.tailPtr == -1 && b.headPtr == -1 {
66 | return nil
67 | }
68 |
69 | return &b.chain[b.tailPtr]
70 | }
71 |
72 | // Front returns the head of the chain. If the chain is empty, then this
73 | // returns nil.
74 | //
75 | // NOTE: Part of the Chain interface.
76 | func (b *BoundedMemoryChain) Front() *Node {
77 | if b.tailPtr == -1 && b.headPtr == -1 {
78 | return nil
79 | }
80 |
81 | return &b.chain[b.headPtr]
82 | }
83 |
84 | // PushBack will push a new entry to the end of the chain. The entry added to
85 | // the chain is also returned in place.
As the chain is bounded, if the length
86 | // of the chain is exceeded, then the front of the chain will be walked forward
87 | // one element.
88 | //
89 | // NOTE: Part of the Chain interface.
90 | func (b *BoundedMemoryChain) PushBack(n Node) *Node {
91 | // Before we do any insertion, we'll fetch the prior element to be able
92 | // to easily set the prev pointer of the new entry.
93 | var prevElem *Node
94 | if b.tailPtr != -1 {
95 | prevElem = &b.chain[b.tailPtr]
96 | }
97 |
98 | // As we're adding to the chain, we'll increment the tail pointer and
99 | // clamp it down to the max size.
100 | b.tailPtr++
101 | b.tailPtr %= b.maxSize
102 |
103 | // If we've wrapped around, or this is the first insertion, then we'll
104 | // increment the head pointer as well so it tracks the "start" of the
105 | // queue properly.
106 | if b.tailPtr <= b.headPtr || b.headPtr == -1 {
107 | b.headPtr++
108 | b.headPtr %= b.maxSize
109 |
110 | // As this is the new head of the chain, we'll set its prev
111 | // pointer to nil.
112 | b.chain[b.headPtr].prev = nil
113 | }
114 |
115 | // Now that we've updated the head and tail pointers, we can add the
116 | // new element to our backing slice, and also update its index within
117 | // the current chain.
118 | chainIndex := b.tailPtr
119 | b.chain[chainIndex] = n
120 |
121 | // Now that we've inserted this new element, we'll set its prev pointer
122 | // to the prior element. If this is the very first element we're
123 | // inserting, then prevElem will be nil, so we'll just set the nil
124 | // value again.
125 | b.chain[chainIndex].prev = prevElem
126 |
127 | // Finally, we'll increment the length of the chain, and clamp down the
128 | // size if needed to the max possible length.
129 | b.len++
130 | if b.len > b.maxSize {
131 | b.len = b.maxSize
132 | }
133 |
134 | return &b.chain[chainIndex]
135 | }
136 |
-------------------------------------------------------------------------------- /headerlist/bounded_header_list_test.go: --------------------------------------------------------------------------------
1 | package headerlist
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/davecgh/go-spew/spew"
7 | )
8 |
9 | // TestBoundedMemoryChainEmptyList tests the expected functionality of an empty
10 | // list w.r.t. which methods return a nil pointer and which do not.
11 | func TestBoundedMemoryChainEmptyList(t *testing.T) {
12 | t.Parallel()
13 |
14 | memChain := NewBoundedMemoryChain(5)
15 |
16 | // An empty list should have a nil Back() pointer.
17 | if memChain.Back() != nil {
18 | t.Fatalf("back of chain should be nil but isn't")
19 | }
20 |
21 | // An empty list should have a nil Front() pointer.
22 | if memChain.Front() != nil {
23 | t.Fatalf("front of chain should be nil but isn't")
24 | }
25 |
26 | // The length of the chain at this point should be zero.
27 | if memChain.len != 0 {
28 | t.Fatalf("length of chain should be zero, is instead: %v",
29 | memChain.len)
30 | }
31 |
32 | // After we push back a single element to the empty list, the Front()
33 | // and Back() pointers should be identical.
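// (A single push makes the node both the head and the tail.)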
34 | memChain.PushBack(Node{
35 | Height: 1,
36 | })
37 |
38 | if memChain.Front() != memChain.Back() {
39 | t.Fatalf("back and front of chain of length 1 should be " +
40 | "identical")
41 | }
42 | }
43 |
44 | // TestBoundedMemoryChainResetHeaderState tests that if we insert a number of
45 | // elements, then reset the chain to a single node, the result is identical to
46 | // a newly created chain containing only that element.
47 | func TestBoundedMemoryChainResetHeaderState(t *testing.T) {
48 | t.Parallel()
49 |
50 | memChain := NewBoundedMemoryChain(5)
51 |
52 | // We'll start out by inserting 3 elements into the chain.
53 | const numElements = 3
54 | for i := 0; i < numElements; i++ {
55 | memChain.PushBack(Node{
56 | Height: int32(i),
57 | })
58 | }
59 |
60 | // With the set of elements inserted, we'll now pick a new element to
61 | // serve as the very head of the chain, with all other items removed.
62 | newNode := Node{
63 | Height: 4,
64 | }
65 | memChain.ResetHeaderState(newNode)
66 |
67 | // At this point, the front and back of the chain should be identical.
68 | if memChain.Front() != memChain.Back() {
69 | t.Fatalf("back and front of chain of length 1 should be " +
70 | "identical")
71 | }
72 |
73 | // Additionally, both the front and back of the chain should be
74 | // identical to the node above.
75 | if *memChain.Front() != newNode {
76 | t.Fatalf("wrong node, expected %v, got %v", newNode,
77 | memChain.Front())
78 | }
79 | if *memChain.Back() != newNode {
80 | t.Fatalf("wrong node, expected %v, got %v", newNode,
81 | memChain.Back())
82 | }
83 | }
84 |
85 | // TestBoundedMemoryChainSizeLimit tests that if we add elements until the size
86 | // of the list is exceeded, then the list is properly bounded.
87 | func TestBoundedMemoryChainSizeLimit(t *testing.T) {
88 | t.Parallel()
89 |
90 | memChain := NewBoundedMemoryChain(5)
91 |
92 | // We'll start out by inserting 20 elements into the memChain. As this
93 | // is greater than the chain's maximum size, we should end up with
94 | // the chain bounded at the end of the set of insertions.
95 | const numElements = 20
96 | var totalElems []Node
97 | for i := 0; i < numElements; i++ {
98 | node := Node{
99 | Height: int32(i),
100 | }
101 | memChain.PushBack(node)
102 |
103 | totalElems = append(totalElems, node)
104 | }
105 |
106 | // At this point, the length of the chain should still be 5, the
107 | // maximum number of elements.
108 | if memChain.len != 5 {
109 | t.Fatalf("wrong length, expected %v, got %v", 5, memChain.len)
110 | }
111 |
112 | // If we attempt to get the prev element of the front of the chain, we
113 | // should get a nil value.
114 | if memChain.Front().Prev() != nil {
115 | t.Fatalf("expected prev of head to be nil, is instead: %v",
116 | spew.Sdump(memChain.Front().Prev()))
117 | }
118 |
119 | // The prev element of the back of the chain should be the element
120 | // inserted directly before it.
121 | expectedPrev := totalElems[len(totalElems)-2]
122 | if memChain.Back().Prev().Height != expectedPrev.Height {
123 | t.Fatalf("wrong node, expected %v, got %v", expectedPrev,
124 | memChain.Back().Prev())
125 | }
126 |
127 | // We'll now confirm that the remaining elements within the chain are
128 | // as we expect, and that they have the proper prev element.
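// After 20 pushes into a 5-slot ring, the backing slice should hold
// exactly the last five nodes, i.e. heights 15 through 19.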
129 | for i, node := range memChain.chain {
130 | if node.Height != totalElems[15+i].Height {
131 | t.Fatalf("wrong node: expected %v, got %v",
132 | spew.Sdump(node),
133 | spew.Sdump(totalElems[15+i]))
134 | }
135 |
136 | if i == 0 {
137 | if node.Prev() != nil {
138 | t.Fatalf("prev of first elem should be nil")
139 | }
140 | } else {
141 | expectedPrevElem := memChain.chain[i-1]
142 | if node.Prev().Height != expectedPrevElem.Height {
143 | t.Fatalf("wrong node: expected %v, got %v",
144 | spew.Sdump(expectedPrevElem),
145 | spew.Sdump(node.Prev()))
146 | }
147 | }
148 | }
149 | }
150 |
151 | // TestBoundedMemoryChainPrevIteration tests that once we insert elements, we
152 | // can properly traverse the entire chain backwards, starting from the final
153 | // element.
154 | func TestBoundedMemoryChainPrevIteration(t *testing.T) {
155 | t.Parallel()
156 |
157 | memChain := NewBoundedMemoryChain(5)
158 |
159 | // We'll start out by inserting 3 elements into the chain.
160 | const numElements = 3
161 | for i := 0; i < numElements; i++ {
162 | memChain.PushBack(Node{
163 | Height: int32(i),
164 | })
165 | }
166 |
167 | // We'll now add an additional element to the chain.
168 | iterNode := memChain.PushBack(Node{
169 | Height: 99,
170 | })
171 |
172 | // We'll now walk backwards with the iterNode until we run into the nil
173 | // pointer.
174 | for iterNode != nil {
175 | nextNode := iterNode
176 | iterNode = iterNode.Prev()
177 |
178 | if iterNode != nil && nextNode.Prev().Height != iterNode.Height {
179 | t.Fatalf("expected %v, got %v",
180 | spew.Sdump(nextNode.Prev()),
181 | spew.Sdump(iterNode))
182 | }
183 | }
184 | }
185 |
-------------------------------------------------------------------------------- /headerlist/header_list.go: --------------------------------------------------------------------------------
1 | package headerlist
2 |
3 | import "github.com/btcsuite/btcd/wire"
4 |
5 | // Chain is an interface that stores a list of Nodes. Each node represents a
6 | // header in the main chain and also includes a height along with it. This is
7 | // meant to serve as a replacement to list.List which provides similar
8 | // functionality, but allows implementations to use custom storage backends and
9 | // semantics.
10 | type Chain interface {
11 | // ResetHeaderState resets the state of all nodes. After this method, it will
12 | // be as if the chain was just newly created.
13 | ResetHeaderState(Node)
14 |
15 | // Back returns the end of the chain. If the chain is empty, then this
16 | // returns nil.
17 | Back() *Node
18 |
19 | // Front returns the head of the chain. If the chain is empty, then
20 | // this returns nil.
21 | Front() *Node
22 |
23 | // PushBack will push a new entry to the end of the chain. The entry
24 | // added to the chain is also returned in place.
25 | PushBack(Node) *Node
26 | }
27 |
28 | // Node is a node within the Chain. Each node stores a header as well as a
29 | // height. Nodes can also be used to traverse the chain backwards via their
30 | // Prev() method.
31 | type Node struct {
32 | // Height is the height of this node within the main chain.
33 | Height int32
34 |
35 | // Header is the header that this node represents.
36 | Header wire.BlockHeader
37 |
38 | prev *Node
39 | }
40 |
41 | // Prev attempts to access the prior node within the header chain relative to
42 | // this node. If this is the start of the chain, then this method will return
43 | // nil.
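//
// A typical backwards walk over a Chain (illustrative; chain is any Chain
// implementation and process is a placeholder):
//
//	for n := chain.Back(); n != nil; n = n.Prev() {
//		process(n.Height, n.Header)
//	}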
44 | func (n *Node) Prev() *Node { 45 | return n.prev 46 | } 47 | -------------------------------------------------------------------------------- /headerlogger.go: -------------------------------------------------------------------------------- 1 | package neutrino 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | 7 | "github.com/btcsuite/btclog" 8 | ) 9 | 10 | // headerProgressLogger provides periodic logging for other services in order 11 | // to show users progress of certain "actions" involving some or all current 12 | // blocks. Ex: syncing to best chain, indexing all blocks, etc. 13 | type headerProgressLogger struct { 14 | receivedLogBlocks int64 15 | lastBlockLogTime time.Time 16 | 17 | entityType string 18 | 19 | subsystemLogger btclog.Logger 20 | progressAction string 21 | sync.Mutex 22 | } 23 | 24 | // newBlockProgressLogger returns a new block progress logger. 25 | // The progress message is templated as follows: 26 | // 27 | // {progressAction} {numProcessed} {blocks|block} in the last {timePeriod} 28 | // ({numTxs}, height {lastBlockHeight}, {lastBlockTimeStamp}) 29 | func newBlockProgressLogger(progressMessage string, 30 | entityType string, logger btclog.Logger) *headerProgressLogger { 31 | 32 | return &headerProgressLogger{ 33 | entityType: entityType, 34 | lastBlockLogTime: time.Now(), 35 | progressAction: progressMessage, 36 | subsystemLogger: logger, 37 | } 38 | } 39 | 40 | // LogBlockHeight logs a new block height as an information message to show 41 | // progress to the user. In order to prevent spam, it limits logging to one 42 | // message every 10 seconds with duration and totals included. 43 | func (b *headerProgressLogger) LogBlockHeight(timestamp time.Time, height int32) { 44 | b.Lock() 45 | defer b.Unlock() 46 | 47 | b.receivedLogBlocks++ 48 | 49 | // TODO(roasbeef): have diff logger for fetching blocks to can eye ball 50 | // false positive 51 | 52 | now := time.Now() 53 | duration := now.Sub(b.lastBlockLogTime) 54 | if duration < time.Second*10 { 55 | return 56 | } 57 | 58 | // Truncate the duration to 10s of milliseconds. 59 | durationMillis := int64(duration / time.Millisecond) 60 | tDuration := 10 * time.Millisecond * time.Duration(durationMillis/10) 61 | 62 | // Log information about new block height. 63 | entityStr := b.entityType 64 | if b.receivedLogBlocks > 1 { 65 | entityStr += "s" 66 | } 67 | b.subsystemLogger.Infof("%s %d %s in the last %s (height %d, %s)", 68 | b.progressAction, b.receivedLogBlocks, entityStr, tDuration, 69 | height, timestamp) 70 | 71 | b.receivedLogBlocks = 0 72 | b.lastBlockLogTime = now 73 | } 74 | 75 | func (b *headerProgressLogger) SetLastLogTime(time time.Time) { 76 | b.lastBlockLogTime = time 77 | } 78 | -------------------------------------------------------------------------------- /log.go: -------------------------------------------------------------------------------- 1 | package neutrino 2 | 3 | import ( 4 | "github.com/btcsuite/btcd/addrmgr" 5 | "github.com/btcsuite/btcd/blockchain" 6 | "github.com/btcsuite/btcd/connmgr" 7 | "github.com/btcsuite/btcd/peer" 8 | "github.com/btcsuite/btcd/txscript" 9 | "github.com/btcsuite/btclog" 10 | "github.com/lightninglabs/neutrino/blockntfns" 11 | "github.com/lightninglabs/neutrino/chanutils" 12 | "github.com/lightninglabs/neutrino/filterdb" 13 | "github.com/lightninglabs/neutrino/pushtx" 14 | "github.com/lightninglabs/neutrino/query" 15 | ) 16 | 17 | // log is a logger that is initialized with no output filters. 
This
18 | // means the package will not perform any logging by default until the caller
19 | // requests it.
20 | var log btclog.Logger
21 |
22 | // The default amount of logging is none.
23 | func init() {
24 | DisableLog()
25 | }
26 |
27 | // DisableLog disables all library log output. Logging output is disabled
28 | // by default until either UseLogger or SetLogWriter are called.
29 | func DisableLog() {
30 | log = btclog.Disabled
31 | }
32 |
33 | // UseLogger uses a specified Logger to output package logging info.
34 | // This should be used in preference to SetLogWriter if the caller is also
35 | // using btclog.
36 | func UseLogger(logger btclog.Logger) {
37 | log = logger
38 | blockchain.UseLogger(logger)
39 | txscript.UseLogger(logger)
40 | peer.UseLogger(logger)
41 | addrmgr.UseLogger(logger)
42 | blockntfns.UseLogger(logger)
43 | pushtx.UseLogger(logger)
44 | connmgr.UseLogger(logger)
45 | query.UseLogger(logger)
46 | filterdb.UseLogger(logger)
47 | chanutils.UseLogger(logger)
48 | }
49 |
-------------------------------------------------------------------------------- /mock_store.go: --------------------------------------------------------------------------------
1 | package neutrino
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/btcsuite/btcd/blockchain"
7 | "github.com/btcsuite/btcd/chaincfg/chainhash"
8 | "github.com/btcsuite/btcd/wire"
9 | "github.com/lightninglabs/neutrino/headerfs"
10 | )
11 |
12 | // mockBlockHeaderStore is an implementation of the BlockHeaderStore backed by
13 | // a simple map.
14 | type mockBlockHeaderStore struct {
15 | headers map[chainhash.Hash]wire.BlockHeader
16 | heights map[uint32]wire.BlockHeader
17 | }
18 |
19 | // A compile-time check to ensure the mockBlockHeaderStore adheres to the
20 | // BlockHeaderStore interface.
21 | var _ headerfs.BlockHeaderStore = (*mockBlockHeaderStore)(nil)
22 |
23 | // newMockBlockHeaderStore returns a version of the BlockHeaderStore that's
24 | // backed by an in-memory map. This instance is meant to be used within
25 | // tests to exercise components that require a BlockHeaderStore
26 | // interface.
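//
// For example (illustrative; someHash is a placeholder):
//
//	store := newMockBlockHeaderStore()
//	_, _, err := store.FetchHeader(&someHash) // "not found" until written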
27 | func newMockBlockHeaderStore() *mockBlockHeaderStore { 28 | return &mockBlockHeaderStore{ 29 | headers: make(map[chainhash.Hash]wire.BlockHeader), 30 | heights: make(map[uint32]wire.BlockHeader), 31 | } 32 | } 33 | 34 | func (m *mockBlockHeaderStore) ChainTip() (*wire.BlockHeader, 35 | uint32, error) { 36 | 37 | return nil, 0, nil 38 | } 39 | func (m *mockBlockHeaderStore) LatestBlockLocator() ( 40 | blockchain.BlockLocator, error) { 41 | 42 | return nil, nil 43 | } 44 | 45 | func (m *mockBlockHeaderStore) FetchHeaderByHeight(height uint32) ( 46 | *wire.BlockHeader, error) { 47 | 48 | if header, ok := m.heights[height]; ok { 49 | return &header, nil 50 | } 51 | 52 | return nil, headerfs.ErrHeightNotFound 53 | } 54 | 55 | func (m *mockBlockHeaderStore) FetchHeaderAncestors(uint32, 56 | *chainhash.Hash) ([]wire.BlockHeader, uint32, error) { 57 | 58 | return nil, 0, nil 59 | } 60 | func (m *mockBlockHeaderStore) HeightFromHash(*chainhash.Hash) (uint32, error) { 61 | return 0, nil 62 | } 63 | func (m *mockBlockHeaderStore) RollbackLastBlock() (*headerfs.BlockStamp, 64 | error) { 65 | 66 | return nil, nil 67 | } 68 | 69 | func (m *mockBlockHeaderStore) FetchHeader(h *chainhash.Hash) ( 70 | *wire.BlockHeader, uint32, error) { 71 | 72 | if header, ok := m.headers[*h]; ok { 73 | return &header, 0, nil 74 | } 75 | return nil, 0, fmt.Errorf("not found") 76 | } 77 | 78 | func (m *mockBlockHeaderStore) WriteHeaders(headers ...headerfs.BlockHeader) error { 79 | for _, h := range headers { 80 | m.headers[h.BlockHash()] = *h.BlockHeader 81 | } 82 | 83 | return nil 84 | } 85 | -------------------------------------------------------------------------------- /pushtx/broadcaster.go: -------------------------------------------------------------------------------- 1 | package pushtx 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "sync" 7 | "time" 8 | 9 | "github.com/btcsuite/btcd/chaincfg/chainhash" 10 | "github.com/btcsuite/btcd/wire" 11 | "github.com/btcsuite/btcwallet/wtxmgr" 12 | "github.com/lightninglabs/neutrino/blockntfns" 13 | ) 14 | 15 | var ( 16 | // ErrBroadcasterStopped is an error returned when we attempt to process 17 | // a request to broadcast a transaction but the Broadcaster has already 18 | // been stopped. 19 | ErrBroadcasterStopped = errors.New("broadcaster has been stopped") 20 | ) 21 | 22 | const ( 23 | // DefaultBroadcastTimeout is the default timeout used when broadcasting 24 | // transactions to network peers. 25 | DefaultBroadcastTimeout = 5 * time.Second 26 | 27 | // DefaultRebroadcastInterval is the default period that we'll wait 28 | // between blocks to attempt another rebroadcast. 29 | DefaultRebroadcastInterval = time.Minute 30 | ) 31 | 32 | // broadcastReq is an internal message the Broadcaster will use to process 33 | // transaction broadcast requests. 34 | type broadcastReq struct { 35 | tx *wire.MsgTx 36 | errChan chan error 37 | } 38 | 39 | // Config contains all of the external dependencies required for the Broadcaster 40 | // to properly carry out its duties. 41 | type Config struct { 42 | // Broadcast broadcasts a transaction to the network. We expect certain 43 | // BroadcastError's to be returned to handle special cases, namely 44 | // errors with the codes Mempool and Confirmed. 45 | Broadcast func(*wire.MsgTx) error 46 | 47 | // SubscribeBlocks returns a block subscription that delivers block 48 | // notifications in order. This will be used to rebroadcast all 49 | // transactions once a new block arrives. 
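// The returned subscription is cancelled by the broadcast handler
// when it exits.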
50 | SubscribeBlocks func() (*blockntfns.Subscription, error)
51 |
52 | // RebroadcastInterval is the interval that we'll continually try to
53 | // re-broadcast transactions in-between new block arrival.
54 | RebroadcastInterval time.Duration
55 |
56 | // MapCustomBroadcastError allows the Rebroadcaster to map broadcast
57 | // errors from other backends to the neutrino internal BroadcastError.
58 | // This allows the Rebroadcaster to behave consistently over different
59 | // backends.
60 | MapCustomBroadcastError func(error) error
61 | }
62 |
63 | // Broadcaster is a subsystem responsible for reliably broadcasting transactions
64 | // to the network. Each transaction will be rebroadcast upon every new block
65 | // being connected/disconnected to/from the chain.
66 | type Broadcaster struct {
67 | start sync.Once
68 | stop sync.Once
69 |
70 | cfg Config
71 |
72 | // broadcastReqs is a channel through which new transaction broadcast
73 | // requests from external callers will be streamed.
74 | broadcastReqs chan *broadcastReq
75 |
76 | // confChan is a channel used to notify the broadcast handler about
77 | // confirmed transactions.
78 | confChan chan chainhash.Hash
79 |
80 | quit chan struct{}
81 | wg sync.WaitGroup
82 | }
83 |
84 | // NewBroadcaster creates a new Broadcaster backed by the given config.
85 | func NewBroadcaster(cfg *Config) *Broadcaster {
86 | b := &Broadcaster{
87 | cfg: *cfg,
88 | broadcastReqs: make(chan *broadcastReq),
89 | confChan: make(chan chainhash.Hash),
90 | quit: make(chan struct{}),
91 | }
92 |
93 | return b
94 | }
95 |
96 | // Start starts all of the necessary steps for the Broadcaster to begin properly
97 | // carrying out its duties.
98 | func (b *Broadcaster) Start() error {
99 | var returnErr error
100 | b.start.Do(func() {
101 | sub, err := b.cfg.SubscribeBlocks()
102 | if err != nil {
103 | returnErr = fmt.Errorf("unable to subscribe for block "+
104 | "notifications: %v", err)
105 | return
106 | }
107 |
108 | b.wg.Add(1)
109 | go b.broadcastHandler(sub)
110 | })
111 | return returnErr
112 | }
113 |
114 | // Stop halts the Broadcaster from rebroadcasting pending transactions.
115 | func (b *Broadcaster) Stop() {
116 | b.stop.Do(func() {
117 | close(b.quit)
118 | b.wg.Wait()
119 | })
120 | }
121 |
122 | // broadcastHandler is the main event handler of the Broadcaster responsible for
123 | // handling new broadcast requests, rebroadcasting transactions upon every new
124 | // block, etc.
125 | //
126 | // NOTE: This must be run as a goroutine.
127 | func (b *Broadcaster) broadcastHandler(sub *blockntfns.Subscription) {
128 | defer b.wg.Done()
129 | defer sub.Cancel()
130 |
131 | log.Infof("Broadcaster now active")
132 |
133 | // transactions is the set of transactions we have broadcast so far,
134 | // and are still not confirmed.
135 | transactions := make(map[chainhash.Hash]*wire.MsgTx)
136 |
137 | // The rebroadcast semaphore is used to ensure we have only one
138 | // rebroadcast running at a time.
139 | rebroadcastSem := make(chan struct{}, 1)
140 | rebroadcastSem <- struct{}{}
141 |
142 | // triggerRebroadcast is a helper method that checks whether the
143 | // rebroadcast semaphore is available, and if it is, spawns a goroutine
144 | // to rebroadcast all pending transactions.
145 | triggerRebroadcast := func() {
146 | select {
147 | // If the rebroadcast semaphore is available, start a
148 | // new goroutine to execute a rebroadcast.
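// Receiving takes the single token; it is returned to the
// semaphore once the spawned rebroadcast below finishes.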
149 | case <-rebroadcastSem:
150 | default:
151 | log.Tracef("Existing rebroadcast still in " +
152 | "progress")
153 | return
154 | }
155 |
156 | // Make a copy of the current set of transactions to hand to
157 | // the goroutine.
158 | txs := make(map[chainhash.Hash]*wire.MsgTx)
159 | for k, v := range transactions {
160 | txs[k] = v.Copy()
161 | }
162 |
163 | b.wg.Add(1)
164 | go func() {
165 | defer b.wg.Done()
166 |
167 | b.rebroadcast(txs, b.confChan)
168 | rebroadcastSem <- struct{}{}
169 | }()
170 | }
171 |
172 | reBroadcastTicker := time.NewTicker(b.cfg.RebroadcastInterval)
173 | defer reBroadcastTicker.Stop()
174 |
175 | for {
176 | select {
177 | // A new broadcast request was submitted by an external caller.
178 | case req := <-b.broadcastReqs:
179 | err := b.cfg.Broadcast(req.tx)
180 | if err != nil {
181 | // We apply the custom err mapping function if
182 | // it was supplied, which allows us to map other
183 | // backend errors to the neutrino BroadcastError.
184 | if b.cfg.MapCustomBroadcastError != nil {
185 | err = b.cfg.MapCustomBroadcastError(err)
186 | }
187 | if !IsBroadcastError(err, Mempool) {
188 | log.Errorf("Broadcast attempt "+
189 | "failed: %v", err)
190 | req.errChan <- err
191 | continue
192 | }
193 | }
194 |
195 | transactions[req.tx.TxHash()] = req.tx
196 | req.errChan <- nil
197 |
198 | // A tx was confirmed, and we can remove it from our set of
199 | // transactions.
200 | case txHash := <-b.confChan:
201 | delete(transactions, txHash)
202 |
203 | // A new block notification has arrived, so we'll rebroadcast
204 | // all of our pending transactions.
205 | case _, ok := <-sub.Notifications:
206 | if !ok {
207 | log.Warn("Unable to rebroadcast transactions: " +
208 | "block subscription was canceled")
209 | continue
210 | }
211 | triggerRebroadcast()
212 |
213 | // Between blocks, we'll also attempt additional re-broadcasts
214 | // to ensure a timely confirmation.
215 | case <-reBroadcastTicker.C:
216 | triggerRebroadcast()
217 |
218 | case <-b.quit:
219 | return
220 | }
221 | }
222 | }
223 |
224 | // rebroadcast rebroadcasts all of the currently pending transactions. Care has
225 | // been taken to ensure that the transactions are sorted in their dependency
226 | // order to prevent peers from deeming our transactions as invalid due to
227 | // broadcasting them before their pending dependencies.
228 | func (b *Broadcaster) rebroadcast(txs map[chainhash.Hash]*wire.MsgTx,
229 | confChan chan<- chainhash.Hash) {
230 |
231 | // Return immediately if there are no transactions to re-broadcast.
232 | if len(txs) == 0 {
233 | return
234 | }
235 |
236 | log.Debugf("Re-broadcasting %d transactions", len(txs))
237 |
238 | sortedTxs := wtxmgr.DependencySort(txs)
239 | for _, tx := range sortedTxs {
240 | // Before attempting to broadcast this transaction, we check
241 | // whether we are shutting down.
242 | select {
243 | case <-b.quit:
244 | return
245 | default:
246 | }
247 |
248 | err := b.cfg.Broadcast(tx)
249 | // We apply the custom err mapping function if it was supplied,
250 | // which allows us to map other backend errors to the neutrino
251 | // BroadcastError.
252 | if err != nil && b.cfg.MapCustomBroadcastError != nil {
253 | err = b.cfg.MapCustomBroadcastError(err)
254 | }
255 | switch {
256 | // If the transaction has already confirmed on-chain, we can
257 | // stop broadcasting it further.
258 | //
259 | // TODO(wilmer): This should ideally be implemented by checking
260 | // the chain ourselves rather than trusting our peers.
261 | 		case IsBroadcastError(err, Confirmed):
262 | 			log.Debugf("Re-broadcast of txid=%v, now confirmed!",
263 | 				tx.TxHash())
264 | 
265 | 			select {
266 | 			case confChan <- tx.TxHash():
267 | 			case <-b.quit:
268 | 				return
269 | 			}
270 | 			continue
271 | 
272 | 		// If the transaction already exists within our peers' mempool,
273 | 		// we'll continue to rebroadcast it to ensure it actually
274 | 		// propagates throughout the network.
275 | 		//
276 | 		// TODO(wilmer): Rate limit peers that have already accepted our
277 | 		// transaction into their mempool to prevent resending to them
278 | 		// every time.
279 | 		case IsBroadcastError(err, Mempool):
280 | 			log.Debugf("Re-broadcast of txid=%v, still "+
281 | 				"pending...", tx.TxHash())
282 | 
283 | 			continue
284 | 
285 | 		case err != nil:
286 | 			log.Errorf("Unable to rebroadcast transaction %v: %v",
287 | 				tx.TxHash(), err)
288 | 			continue
289 | 		}
290 | 	}
291 | }
292 | 
293 | // Broadcast submits a request to the Broadcaster to reliably broadcast the
294 | // given transaction. An error won't be returned if the transaction already
295 | // exists within the mempool. Any transaction broadcast through this method will
296 | // be rebroadcast upon every change of the tip of the chain.
297 | func (b *Broadcaster) Broadcast(tx *wire.MsgTx) error {
298 | 	errChan := make(chan error, 1)
299 | 
300 | 	select {
301 | 	case b.broadcastReqs <- &broadcastReq{
302 | 		tx:      tx,
303 | 		errChan: errChan,
304 | 	}:
305 | 	case <-b.quit:
306 | 		return ErrBroadcasterStopped
307 | 	}
308 | 
309 | 	select {
310 | 	case err := <-errChan:
311 | 		return err
312 | 	case <-b.quit:
313 | 		return ErrBroadcasterStopped
314 | 	}
315 | }
316 | 
317 | // MarkAsConfirmed is used to tell the broadcaster that a transaction has been
318 | // confirmed and that it is no longer necessary to rebroadcast this transaction.
319 | func (b *Broadcaster) MarkAsConfirmed(txHash chainhash.Hash) {
320 | 	b.confChan <- txHash
321 | }
322 | 
--------------------------------------------------------------------------------
/pushtx/broadcaster_test.go:
--------------------------------------------------------------------------------
1 | package pushtx
2 | 
3 | import (
4 | 	"fmt"
5 | 	"math/rand"
6 | 	"strings"
7 | 	"testing"
8 | 	"time"
9 | 
10 | 	"github.com/btcsuite/btcd/btcjson"
11 | 	"github.com/btcsuite/btcd/wire"
12 | 	"github.com/lightninglabs/neutrino/blockntfns"
13 | )
14 | 
15 | // createTx is a helper function to create random transactions that spend
16 | // particular inputs.
17 | func createTx(t *testing.T, numOutputs int, inputs ...wire.OutPoint) *wire.MsgTx {
18 | 	t.Helper()
19 | 
20 | 	tx := wire.NewMsgTx(1)
21 | 	if len(inputs) == 0 {
22 | 		tx.AddTxIn(&wire.TxIn{})
23 | 	} else {
24 | 		for _, input := range inputs {
25 | 			tx.AddTxIn(&wire.TxIn{PreviousOutPoint: input})
26 | 		}
27 | 	}
28 | 	for i := 0; i < numOutputs; i++ {
29 | 		var pkScript [32]byte
30 | 		if _, err := rand.Read(pkScript[:]); err != nil {
31 | 			t.Fatal(err)
32 | 		}
33 | 
34 | 		tx.AddTxOut(&wire.TxOut{
35 | 			Value:    rand.Int63(),
36 | 			PkScript: pkScript[:],
37 | 		})
38 | 	}
39 | 
40 | 	return tx
41 | }
42 | 
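// As an illustrative sketch of the call pattern the tests below exercise
// (relayToPeers and subscribeBlocks are hypothetical hooks, not defined in
// this package), a caller typically wires the Broadcaster up like so:
//
//	b := NewBroadcaster(&Config{
//		Broadcast:           relayToPeers,
//		SubscribeBlocks:     subscribeBlocks,
//		RebroadcastInterval: DefaultRebroadcastInterval,
//	})
//	if err := b.Start(); err != nil {
//		// Handle the subscription error.
//	}
//	defer b.Stop()
//
//	// Broadcast blocks until the handler accepts the transaction; it is
//	// then rebroadcast on every new block until marked as confirmed.
//	if err := b.Broadcast(tx); err != nil {
//		// Handle the broadcast error.
//	}
//	b.MarkAsConfirmed(tx.TxHash())

43 | // TestBroadcaster ensures that we can broadcast transactions while it is
44 | // active.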
45 | func TestBroadcaster(t *testing.T) { 46 | t.Parallel() 47 | 48 | cfg := &Config{ 49 | Broadcast: func(*wire.MsgTx) error { 50 | return nil 51 | }, 52 | SubscribeBlocks: func() (*blockntfns.Subscription, error) { 53 | return &blockntfns.Subscription{ 54 | Notifications: make(chan blockntfns.BlockNtfn), 55 | Cancel: func() {}, 56 | }, nil 57 | }, 58 | RebroadcastInterval: DefaultRebroadcastInterval, 59 | } 60 | 61 | broadcaster := NewBroadcaster(cfg) 62 | 63 | if err := broadcaster.Start(); err != nil { 64 | t.Fatalf("unable to start broadcaster: %v", err) 65 | } 66 | 67 | tx := &wire.MsgTx{} 68 | if err := broadcaster.Broadcast(tx); err != nil { 69 | t.Fatalf("unable to broadcast transaction: %v", err) 70 | } 71 | 72 | broadcaster.Stop() 73 | 74 | if err := broadcaster.Broadcast(tx); err != ErrBroadcasterStopped { 75 | t.Fatalf("expected ErrBroadcasterStopped, got %v", err) 76 | } 77 | } 78 | 79 | // TestRebroadcast ensures that we properly rebroadcast transactions upon every 80 | // new block. Transactions that have confirmed should no longer be broadcast. 81 | func TestRebroadcast(t *testing.T) { 82 | t.Parallel() 83 | 84 | const numTxs = 5 85 | 86 | // We'll start by setting up the broadcaster with channels to mock the 87 | // behavior of its external dependencies. 88 | broadcastChan := make(chan *wire.MsgTx, numTxs) 89 | ntfnChan := make(chan blockntfns.BlockNtfn) 90 | 91 | cfg := &Config{ 92 | Broadcast: func(tx *wire.MsgTx) error { 93 | broadcastChan <- tx 94 | return nil 95 | }, 96 | SubscribeBlocks: func() (*blockntfns.Subscription, error) { 97 | return &blockntfns.Subscription{ 98 | Notifications: ntfnChan, 99 | Cancel: func() {}, 100 | }, nil 101 | }, 102 | RebroadcastInterval: DefaultRebroadcastInterval, 103 | } 104 | 105 | broadcaster := NewBroadcaster(cfg) 106 | 107 | if err := broadcaster.Start(); err != nil { 108 | t.Fatalf("unable to start broadcaster: %v", err) 109 | } 110 | defer broadcaster.Stop() 111 | 112 | // We'll then create some test transactions such that they all depend on 113 | // the previous one, creating a dependency chain. We'll do this to 114 | // ensure transactions are rebroadcast in the order of their 115 | // dependencies. 116 | txs := make([]*wire.MsgTx, 0, numTxs) 117 | for i := 0; i < numTxs; i++ { 118 | var tx *wire.MsgTx 119 | if i == 0 { 120 | tx = createTx(t, 1) 121 | } else { 122 | prevOut := wire.OutPoint{ 123 | Hash: txs[i-1].TxHash(), 124 | Index: 0, 125 | } 126 | tx = createTx(t, 1, prevOut) 127 | } 128 | txs = append(txs, tx) 129 | } 130 | 131 | // assertBroadcastOrder is a helper closure to ensure that the 132 | // transactions rebroadcast match the expected order. 133 | assertBroadcastOrder := func(expectedOrder []*wire.MsgTx) { 134 | t.Helper() 135 | 136 | for i := 0; i < len(expectedOrder); i++ { 137 | tx := <-broadcastChan 138 | if tx.TxHash() != expectedOrder[i].TxHash() { 139 | t.Fatalf("expected transaction %v, got %v", 140 | expectedOrder[i].TxHash(), tx.TxHash()) 141 | } 142 | } 143 | } 144 | 145 | // Broadcast the transactions. We'll be broadcasting them in order so 146 | // assertBroadcastOrder is more of a sanity check to ensure that all of 147 | // the transactions were actually broadcast. 
148 | 	for _, tx := range txs {
149 | 		if err := broadcaster.Broadcast(tx); err != nil {
150 | 			t.Fatalf("unable to broadcast transaction %v: %v",
151 | 				tx.TxHash(), err)
152 | 		}
153 | 	}
154 | 
155 | 	assertBroadcastOrder(txs)
156 | 
157 | 	// Now, we'll modify the Broadcast method to mark the first transaction
158 | 	// as confirmed, and the second as accepted into the mempool.
159 | 	broadcaster.cfg.Broadcast = func(tx *wire.MsgTx) error {
160 | 		broadcastChan <- tx
161 | 		if tx.TxHash() == txs[0].TxHash() {
162 | 			return &BroadcastError{Code: Confirmed}
163 | 		}
164 | 		if tx.TxHash() == txs[1].TxHash() {
165 | 			return &BroadcastError{Code: Mempool}
166 | 		}
167 | 		return nil
168 | 	}
169 | 
170 | 	// Trigger a new block notification to rebroadcast the transactions.
171 | 	ntfnChan <- blockntfns.NewBlockConnected(wire.BlockHeader{}, 100)
172 | 
173 | 	// They should all be broadcast in their expected dependency order.
174 | 	assertBroadcastOrder(txs)
175 | 
176 | 	// Trigger another block notification simulating a reorg in the chain.
177 | 	// The transactions should be rebroadcast again to ensure they properly
178 | 	// propagate throughout the network.
179 | 	ntfnChan <- blockntfns.NewBlockDisconnected(
180 | 		wire.BlockHeader{}, 100, wire.BlockHeader{},
181 | 	)
182 | 
183 | 	// This time however, only the last four transactions will be
184 | 	// rebroadcast since the first one confirmed in the previous
185 | 	// rebroadcast attempt.
186 | 	assertBroadcastOrder(txs[1:])
187 | 
188 | 	// We now manually mark one of the transactions as confirmed.
189 | 	broadcaster.MarkAsConfirmed(txs[1].TxHash())
190 | 
191 | 	// Trigger a new block notification to rebroadcast the transactions.
192 | 	ntfnChan <- blockntfns.NewBlockConnected(wire.BlockHeader{}, 101)
193 | 
194 | 	// We assert that only the last three transactions are rebroadcast.
195 | 	assertBroadcastOrder(txs[2:])
196 | 
197 | 	// Manually mark the third transaction as confirmed.
198 | 	broadcaster.MarkAsConfirmed(txs[2].TxHash())
199 | 
200 | 	// Now we inject a custom error mapping function for errors returned by
201 | 	// backends other than neutrino.
202 | 	broadcaster.cfg.MapCustomBroadcastError = func(err error) error {
203 | 		// match is a helper closure to easily string match on the error
204 | 		// message.
205 | 		match := func(err error, s string) bool {
206 | 			return strings.Contains(strings.ToLower(err.Error()), s)
207 | 		}
208 | 
209 | 		switch {
210 | 		case match(err, "mempool min fee not met"):
211 | 			return &BroadcastError{
212 | 				Code:   Mempool,
213 | 				Reason: err.Error(),
214 | 			}
215 | 
216 | 		case match(err, "transaction already exists"):
217 | 			return &BroadcastError{
218 | 				Code:   Confirmed,
219 | 				Reason: err.Error(),
220 | 			}
221 | 
222 | 		default:
223 | 			return fmt.Errorf("unmatched backend error: %v", err)
224 | 		}
225 | 	}
226 | 
227 | 	// Now, we'll modify the Broadcast method to mark the fourth transaction
228 | 	// as confirmed, but with a bitcoind backend error, to test that
229 | 	// the mapping between different backend errors and the neutrino
230 | 	// BroadcastError works as expected. We also mark the last transaction
231 | 	// with the bitcoind backend error for not having enough fees to be
232 | 	// included in the mempool. We expect it to be rebroadcast too.
233 | 	broadcaster.cfg.Broadcast = func(tx *wire.MsgTx) error {
234 | 		broadcastChan <- tx
235 | 		if tx.TxHash() == txs[3].TxHash() {
236 | 			return &btcjson.RPCError{
237 | 				Code:    btcjson.ErrRPCVerifyAlreadyInChain,
238 | 				Message: "transaction already exists",
239 | 			}
240 | 		}
241 | 		if tx.TxHash() == txs[4].TxHash() {
242 | 			return &btcjson.RPCError{
243 | 				Code:    btcjson.ErrRPCTxRejected,
244 | 				Message: "mempool min fee not met",
245 | 			}
246 | 		}
247 | 
248 | 		return nil
249 | 	}
250 | 
251 | 	// Trigger a new block notification.
252 | 	ntfnChan <- blockntfns.NewBlockConnected(wire.BlockHeader{}, 102)
253 | 
254 | 	// We assert that only the last two transactions are rebroadcast.
255 | 	assertBroadcastOrder(txs[3:])
256 | 
257 | 	// Trigger another block notification simulating a reorg in the chain.
258 | 	// The transactions should be rebroadcast again to ensure they
259 | 	// properly propagate throughout the network.
260 | 	ntfnChan <- blockntfns.NewBlockDisconnected(
261 | 		wire.BlockHeader{}, 102, wire.BlockHeader{},
262 | 	)
263 | 
264 | 	// We assert that only the last transaction is rebroadcast.
265 | 	assertBroadcastOrder(txs[4:])
266 | 
267 | 	// Manually mark the last transaction as confirmed.
268 | 	broadcaster.MarkAsConfirmed(txs[4].TxHash())
269 | 
270 | 	// Trigger a new block notification.
271 | 	ntfnChan <- blockntfns.NewBlockConnected(wire.BlockHeader{}, 103)
272 | 
273 | 	// Assert that no transactions were rebroadcast.
274 | 	select {
275 | 	case tx := <-broadcastChan:
276 | 		t.Fatalf("unexpected rebroadcast of tx %s", tx.TxHash())
277 | 	case <-time.After(100 * time.Millisecond):
278 | 	}
279 | }
280 | 
--------------------------------------------------------------------------------
/pushtx/error.go:
--------------------------------------------------------------------------------
1 | package pushtx
2 | 
3 | import (
4 | 	"fmt"
5 | 	"strings"
6 | 
7 | 	"github.com/btcsuite/btcd/wire"
8 | )
9 | 
10 | // BroadcastErrorCode uniquely identifies the broadcast error.
11 | type BroadcastErrorCode uint8
12 | 
13 | const (
14 | 	// Unknown is the code used when a transaction has been rejected for
15 | 	// some unknown reason by a peer.
16 | 	Unknown BroadcastErrorCode = iota
17 | 
18 | 	// Invalid is the code used when a transaction has been deemed invalid
19 | 	// by a peer.
20 | 	Invalid
21 | 
22 | 	// InsufficientFee is the code used when a transaction has been deemed
23 | 	// as having an insufficient fee by a peer.
24 | 	InsufficientFee
25 | 
26 | 	// Mempool is the code used when a transaction already exists in a
27 | 	// peer's mempool.
28 | 	Mempool
29 | 
30 | 	// Confirmed is the code used when a transaction has been deemed as
31 | 	// confirmed in the chain by a peer.
32 | 	Confirmed
33 | )
34 | 
35 | func (c BroadcastErrorCode) String() string {
36 | 	switch c {
37 | 	case Invalid:
38 | 		return "Invalid"
39 | 	case InsufficientFee:
40 | 		return "InsufficientFee"
41 | 	case Mempool:
42 | 		return "Mempool"
43 | 	case Confirmed:
44 | 		return "Confirmed"
45 | 	default:
46 | 		return "Unknown"
47 | 	}
48 | }
49 | 
50 | // BroadcastError is an error type that encompasses the different possible
51 | // broadcast errors returned by the network.
52 | type BroadcastError struct {
53 | 	// Code is the uniquely identifying code of the broadcast error.
54 | 	Code BroadcastErrorCode
55 | 
56 | 	// Reason is the string detailing the reason as to why the transaction
57 | 	// was rejected.
58 | 	Reason string
59 | }
60 | 
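// As an illustrative sketch (not part of the upstream API), a custom mapping
// function of the kind consumed by the broadcaster's
// Config.MapCustomBroadcastError would translate backend-specific errors
// into this type, mirroring the mapping used in broadcaster_test.go:
//
//	mapErr := func(err error) error {
//		if strings.Contains(err.Error(), "txn-already-in-mempool") {
//			return &BroadcastError{Code: Mempool, Reason: err.Error()}
//		}
//		return err
//	}

61 | // A compile-time constraint to ensure BroadcastError satisfies the error
62 | // interface.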
63 | var _ error = (*BroadcastError)(nil) 64 | 65 | // Error returns the reason of the broadcast error. 66 | func (e *BroadcastError) Error() string { 67 | return e.Reason 68 | } 69 | 70 | // IsBroadcastError is a helper function that can be used to determine whether 71 | // an error is a BroadcastError that matches any of the specified codes. 72 | func IsBroadcastError(err error, codes ...BroadcastErrorCode) bool { 73 | broadcastErr, ok := err.(*BroadcastError) 74 | if !ok { 75 | return false 76 | } 77 | 78 | for _, code := range codes { 79 | if broadcastErr.Code == code { 80 | return true 81 | } 82 | } 83 | 84 | return false 85 | } 86 | 87 | // ParseBroadcastError maps a peer's reject message for a transaction to a 88 | // BroadcastError. 89 | func ParseBroadcastError(msg *wire.MsgReject, peerAddr string) *BroadcastError { 90 | // We'll determine the appropriate broadcast error code by looking at 91 | // the reject's message code and reason. The only reject codes returned 92 | // from peers (bitcoind and btcd) when attempting to accept a 93 | // transaction into their mempool are: 94 | // RejectInvalid, RejectNonstandard, RejectInsufficientFee, 95 | // RejectDuplicate 96 | var code BroadcastErrorCode 97 | switch { 98 | // The cases below apply for reject messages sent from any kind of peer. 99 | case msg.Code == wire.RejectInvalid || msg.Code == wire.RejectNonstandard: 100 | code = Invalid 101 | 102 | case msg.Code == wire.RejectInsufficientFee: 103 | code = InsufficientFee 104 | 105 | // The cases below apply for reject messages sent from bitcoind peers. 106 | // 107 | // If the transaction double spends an unconfirmed transaction in the 108 | // peer's mempool, then we'll deem it as invalid. 109 | case msg.Code == wire.RejectDuplicate && 110 | strings.Contains(msg.Reason, "txn-mempool-conflict"): 111 | code = Invalid 112 | 113 | // If the transaction was rejected due to it already existing in the 114 | // peer's mempool, then return an error signaling so. 115 | case msg.Code == wire.RejectDuplicate && 116 | strings.Contains(msg.Reason, "txn-already-in-mempool"): 117 | code = Mempool 118 | 119 | // If the transaction was rejected due to it already existing in the 120 | // chain according to our peer, then we'll return an error signaling so. 121 | case msg.Code == wire.RejectDuplicate && 122 | strings.Contains(msg.Reason, "txn-already-known"): 123 | code = Confirmed 124 | 125 | // The cases below apply for reject messages sent from btcd peers. 126 | // 127 | // If the transaction double spends an unconfirmed transaction in the 128 | // peer's mempool, then we'll deem it as invalid. 129 | case msg.Code == wire.RejectDuplicate && 130 | strings.Contains(msg.Reason, "already spent"): 131 | code = Invalid 132 | 133 | // If the transaction was rejected due to it already existing in the 134 | // peer's mempool, then return an error signaling so. 135 | case msg.Code == wire.RejectDuplicate && 136 | strings.Contains(msg.Reason, "already have transaction"): 137 | code = Mempool 138 | 139 | // If the transaction was rejected due to it already existing in the 140 | // chain according to our peer, then we'll return an error signaling so. 141 | case msg.Code == wire.RejectDuplicate && 142 | strings.Contains(msg.Reason, "transaction already exists"): 143 | code = Confirmed 144 | 145 | // Any other reject messages will use the unknown code. 
146 | default: 147 | code = Unknown 148 | } 149 | 150 | reason := fmt.Sprintf("rejected by %v: %v", peerAddr, msg.Reason) 151 | return &BroadcastError{Code: code, Reason: reason} 152 | } 153 | -------------------------------------------------------------------------------- /pushtx/error_test.go: -------------------------------------------------------------------------------- 1 | package pushtx_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/btcsuite/btcd/wire" 7 | "github.com/lightninglabs/neutrino/pushtx" 8 | ) 9 | 10 | // TestParseBroadcastErrorCode ensures that we properly construct a 11 | // BroadcastError with the appropriate error code from a wire.MsgReject. 12 | func TestParseBroadcastErrorCode(t *testing.T) { 13 | t.Parallel() 14 | 15 | testCases := []struct { 16 | name string 17 | msg *wire.MsgReject 18 | code pushtx.BroadcastErrorCode 19 | }{ 20 | { 21 | name: "dust transaction", 22 | msg: &wire.MsgReject{ 23 | Code: wire.RejectDust, 24 | }, 25 | }, 26 | { 27 | name: "invalid transaction", 28 | msg: &wire.MsgReject{ 29 | Code: wire.RejectInvalid, 30 | Reason: "spends inexistent output", 31 | }, 32 | code: pushtx.Invalid, 33 | }, 34 | { 35 | name: "nonstandard transaction", 36 | msg: &wire.MsgReject{ 37 | Code: wire.RejectNonstandard, 38 | Reason: "", 39 | }, 40 | code: pushtx.Invalid, 41 | }, 42 | { 43 | name: "insufficient fee transaction", 44 | msg: &wire.MsgReject{ 45 | Code: wire.RejectInsufficientFee, 46 | Reason: "", 47 | }, 48 | code: pushtx.InsufficientFee, 49 | }, 50 | { 51 | name: "bitcoind mempool double spend", 52 | msg: &wire.MsgReject{ 53 | Code: wire.RejectDuplicate, 54 | Reason: "txn-mempool-conflict", 55 | }, 56 | code: pushtx.Invalid, 57 | }, 58 | { 59 | name: "bitcoind transaction in mempool", 60 | msg: &wire.MsgReject{ 61 | Code: wire.RejectDuplicate, 62 | Reason: "txn-already-in-mempool", 63 | }, 64 | code: pushtx.Mempool, 65 | }, 66 | { 67 | name: "bitcoind transaction in chain", 68 | msg: &wire.MsgReject{ 69 | Code: wire.RejectDuplicate, 70 | Reason: "txn-already-known", 71 | }, 72 | code: pushtx.Confirmed, 73 | }, 74 | { 75 | name: "btcd mempool double spend", 76 | msg: &wire.MsgReject{ 77 | Code: wire.RejectDuplicate, 78 | Reason: "already spent", 79 | }, 80 | code: pushtx.Invalid, 81 | }, 82 | { 83 | name: "btcd transaction in mempool", 84 | msg: &wire.MsgReject{ 85 | Code: wire.RejectDuplicate, 86 | Reason: "already have transaction", 87 | }, 88 | code: pushtx.Mempool, 89 | }, 90 | { 91 | name: "btcd transaction in chain", 92 | msg: &wire.MsgReject{ 93 | Code: wire.RejectDuplicate, 94 | Reason: "transaction already exists", 95 | }, 96 | code: pushtx.Confirmed, 97 | }, 98 | } 99 | 100 | for _, testCase := range testCases { 101 | test := testCase 102 | t.Run(test.name, func(t *testing.T) { 103 | t.Parallel() 104 | 105 | broadcastErr := pushtx.ParseBroadcastError( 106 | test.msg, "127.0.0.1:8333", 107 | ) 108 | if broadcastErr.Code != test.code { 109 | t.Fatalf("expected BroadcastErrorCode %v, got "+ 110 | "%v", test.code, broadcastErr.Code) 111 | } 112 | }) 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /pushtx/log.go: -------------------------------------------------------------------------------- 1 | package pushtx 2 | 3 | import "github.com/btcsuite/btclog" 4 | 5 | // log is a logger that is initialized with no output filters. This 6 | // means the package will not perform any logging by default until the caller 7 | // requests it. 
8 | var log btclog.Logger
9 | 
10 | // The default amount of logging is none.
11 | func init() {
12 | 	DisableLog()
13 | }
14 | 
15 | // DisableLog disables all library log output. Logging output is disabled
16 | // by default until either UseLogger or SetLogWriter are called.
17 | func DisableLog() {
18 | 	UseLogger(btclog.Disabled)
19 | }
20 | 
21 | // UseLogger uses a specified Logger to output package logging info.
22 | // This should be used in preference to SetLogWriter if the caller is also
23 | // using btclog.
24 | func UseLogger(logger btclog.Logger) {
25 | 	log = logger
26 | }
27 | 
--------------------------------------------------------------------------------
/query/interface.go:
--------------------------------------------------------------------------------
1 | package query
2 | 
3 | import (
4 | 	"time"
5 | 
6 | 	"github.com/btcsuite/btcd/wire"
7 | )
8 | 
9 | const (
10 | 	// defaultQueryTimeout specifies the default total time a query is
11 | 	// allowed to be retried before it will fail.
12 | 	defaultQueryTimeout = time.Second * 30
13 | 
14 | 	// defaultQueryEncoding specifies the default encoding (witness or not)
15 | 	// for `getdata` and other similar messages.
16 | 	defaultQueryEncoding = wire.WitnessEncoding
17 | 
18 | 	// defaultNumRetries is the default number of times that a query job
19 | 	// will be retried.
20 | 	defaultNumRetries = 2
21 | )
22 | 
23 | // queryOptions is a set of options that can be modified per-query, unlike
24 | // global options.
25 | type queryOptions struct {
26 | 	// timeout specifies the total time a query is allowed to
27 | 	// be retried before it will fail.
28 | 	timeout time.Duration
29 | 
30 | 	// encoding lets the query know which encoding to use when queueing
31 | 	// messages to a peer.
32 | 	encoding wire.MessageEncoding
33 | 
34 | 	// cancelChan is an optional channel that can be closed to indicate
35 | 	// that the query should be canceled.
36 | 	cancelChan chan struct{}
37 | 
38 | 	// numRetries is the number of times that a query should be retried
39 | 	// before failing.
40 | 	numRetries uint8
41 | 
42 | 	// noRetryMax is set if no cap should be applied to the number of times
43 | 	// that a query can be retried. If this is set then numRetries has no
44 | 	// effect.
45 | 	noRetryMax bool
46 | }
47 | 
48 | // QueryOption is a functional option argument to any of the network query
49 | // methods, such as GetBlock and GetCFilter (when that resorts to a network
50 | // query). These are always processed in order, with later options overriding
51 | // earlier ones.
52 | type QueryOption func(*queryOptions) // nolint
53 | 
54 | // defaultQueryOptions returns a queryOptions set to package-level defaults.
55 | func defaultQueryOptions() *queryOptions {
56 | 	return &queryOptions{
57 | 		timeout:    defaultQueryTimeout,
58 | 		encoding:   defaultQueryEncoding,
59 | 		numRetries: defaultNumRetries,
60 | 	}
61 | }
62 | 
63 | // applyQueryOptions updates a queryOptions set with functional options.
64 | func (qo *queryOptions) applyQueryOptions(options ...QueryOption) {
65 | 	for _, option := range options {
66 | 		option(qo)
67 | 	}
68 | }
69 | 
70 | // NumRetries is a query option that specifies the number of times a query
71 | // should be retried.
72 | func NumRetries(num uint8) QueryOption {
73 | 	return func(qo *queryOptions) {
74 | 		qo.numRetries = num
75 | 	}
76 | }
77 | 
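// As a brief illustration (a sketch, assuming a Dispatcher instance named
// dispatcher and a prepared reqs slice), options compose left to right, with
// later options overriding earlier ones:
//
//	errChan := dispatcher.Query(
//		reqs,
//		NumRetries(3),
//		Timeout(time.Minute),
//		Encoding(wire.WitnessEncoding),
//	)

78 | // NoRetryMax is a query option that can be used to disable the cap on the
79 | // number of retries. If this is set then NumRetries has no effect.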
80 | func NoRetryMax() QueryOption {
81 | 	return func(qo *queryOptions) {
82 | 		qo.noRetryMax = true
83 | 	}
84 | }
85 | 
86 | // Timeout is a query option that specifies the total time a query is allowed
87 | // to be retried before it fails.
88 | func Timeout(timeout time.Duration) QueryOption {
89 | 	return func(qo *queryOptions) {
90 | 		qo.timeout = timeout
91 | 	}
92 | }
93 | 
94 | // Encoding is a query option that allows the caller to set a message encoding
95 | // for the query messages.
96 | func Encoding(encoding wire.MessageEncoding) QueryOption {
97 | 	return func(qo *queryOptions) {
98 | 		qo.encoding = encoding
99 | 	}
100 | }
101 | 
102 | // Cancel takes a channel that can be closed to indicate that the query should
103 | // be canceled.
104 | func Cancel(cancel chan struct{}) QueryOption {
105 | 	return func(qo *queryOptions) {
106 | 		qo.cancelChan = cancel
107 | 	}
108 | }
109 | 
110 | // Progress encloses the result of handling a response for a given Request,
111 | // determining whether the response did progress the query.
112 | type Progress struct {
113 | 	// Finished is true if the query was finished as a result of the
114 | 	// received response.
115 | 	Finished bool
116 | 
117 | 	// Progressed is true if the query made progress towards fully
118 | 	// answering the request as a result of the received response. This is
119 | 	// used for the request types where more than one response is
120 | 	// expected.
121 | 	Progressed bool
122 | }
123 | 
124 | // Request is the main struct that defines a bitcoin network query to be sent to
125 | // connected peers.
126 | type Request struct {
127 | 	// Req is the message request to send.
128 | 	Req wire.Message
129 | 
130 | 	// HandleResp is a response handler that will be called for every
131 | 	// message received from the peer that the request was made to. It
132 | 	// should validate the response against the request made, and return a
133 | 	// Progress indicating whether the request was answered by this
134 | 	// particular response.
135 | 	//
136 | 	// NOTE: Since the worker's job queue will be stalled while this method
137 | 	// is running, it should not be doing any expensive operations. It
138 | 	// should validate the response and immediately return the progress.
139 | 	// The response should be handed off to another goroutine for
140 | 	// processing.
141 | 	HandleResp func(req, resp wire.Message, peer string) Progress
142 | }
143 | 
144 | // WorkManager defines an API for a manager that dispatches queries to bitcoin
145 | // peers that must be started and stopped in order to perform these queries.
146 | type WorkManager interface {
147 | 	Dispatcher
148 | 
149 | 	// Start sets up any resources that the WorkManager requires. It must
150 | 	// be called before any of the Dispatcher calls can be made.
151 | 	Start() error
152 | 
153 | 	// Stop cleans up the resources held by the WorkManager.
154 | 	Stop() error
155 | }
156 | 
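// To illustrate the Request/Progress contract above (a minimal sketch with a
// hypothetical single-response query, not taken from the package itself):
//
//	req := &Request{
//		Req: wire.NewMsgGetData(),
//		HandleResp: func(req, resp wire.Message, _ string) Progress {
//			// Validate cheaply here; hand heavy processing off to
//			// another goroutine.
//			_, ok := resp.(*wire.MsgTx)
//			return Progress{Finished: ok, Progressed: ok}
//		},
//	}

157 | // Dispatcher is an interface defining the API for dispatching queries to
158 | // bitcoin peers.
159 | type Dispatcher interface {
160 | 	// Query distributes the slice of requests to the set of connected
161 | 	// peers. It returns an error channel where the final result of the
162 | 	// batch of queries will be sent. Responses for the individual queries
163 | 	// should be handled by the response handler of each Request.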
164 | 	Query(reqs []*Request, options ...QueryOption) chan error
165 | }
166 | 
167 | // Peer is the interface that defines the methods needed by the query package
168 | // to be able to make requests and receive responses from a network peer.
169 | type Peer interface {
170 | 	// QueueMessageWithEncoding adds the passed bitcoin message to the peer
171 | 	// send queue.
172 | 	QueueMessageWithEncoding(msg wire.Message, doneChan chan<- struct{},
173 | 		encoding wire.MessageEncoding)
174 | 
175 | 	// SubscribeRecvMsg adds an OnRead subscription to the peer. All bitcoin
176 | 	// messages received from this peer will be sent on the returned
177 | 	// channel. A closure is also returned, that should be called to cancel
178 | 	// the subscription.
179 | 	SubscribeRecvMsg() (<-chan wire.Message, func())
180 | 
181 | 	// Addr returns the address of this peer.
182 | 	Addr() string
183 | 
184 | 	// OnDisconnect returns a channel that will be closed when this peer is
185 | 	// disconnected.
186 | 	OnDisconnect() <-chan struct{}
187 | }
188 | 
--------------------------------------------------------------------------------
/query/log.go:
--------------------------------------------------------------------------------
1 | package query
2 | 
3 | import "github.com/btcsuite/btclog"
4 | 
5 | // log is a logger that is initialized with no output filters. This
6 | // means the package will not perform any logging by default until the caller
7 | // requests it.
8 | var log btclog.Logger
9 | 
10 | // The default amount of logging is none.
11 | func init() {
12 | 	DisableLog()
13 | }
14 | 
15 | // DisableLog disables all library log output. Logging output is disabled
16 | // by default until either UseLogger or SetLogWriter are called.
17 | func DisableLog() {
18 | 	UseLogger(btclog.Disabled)
19 | }
20 | 
21 | // UseLogger uses a specified Logger to output package logging info.
22 | // This should be used in preference to SetLogWriter if the caller is also
23 | // using btclog.
24 | func UseLogger(logger btclog.Logger) {
25 | 	log = logger
26 | }
27 | 
--------------------------------------------------------------------------------
/query/peer_rank.go:
--------------------------------------------------------------------------------
1 | package query
2 | 
3 | import (
4 | 	"sort"
5 | )
6 | 
7 | const (
8 | 	// bestScore is the best score a peer can get after multiple rewards.
9 | 	bestScore = 0
10 | 
11 | 	// defaultScore is the score given to a peer when it hasn't been
12 | 	// rewarded or punished.
13 | 	defaultScore = 4
14 | 
15 | 	// worstScore is the worst score a peer can get after multiple
16 | 	// punishments.
17 | 	worstScore = 8
18 | )
19 | 
20 | // peerRanking keeps a history of each peer's previous query success rate,
21 | // and uses that to prioritise which peers to give the next queries to.
22 | type peerRanking struct {
23 | 	// rank keeps track of the current set of peers and their score. A
24 | 	// lower score is better.
25 | 	rank map[string]uint64
26 | }
27 | 
28 | // A compile-time check to ensure peerRanking satisfies the PeerRanking
29 | // interface.
30 | var _ PeerRanking = (*peerRanking)(nil)
31 | 
32 | // NewPeerRanking returns a new, empty ranking.
33 | func NewPeerRanking() PeerRanking {
34 | 	return &peerRanking{
35 | 		rank: make(map[string]uint64),
36 | 	}
37 | }
38 | 
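// To illustrate the scoring model above (a sketch, not upstream code): with
// the constants as defined, punishing one of two fresh peers decides their
// order, since a lower score sorts first:
//
//	r := NewPeerRanking()
//	r.AddPeer("a")
//	r.AddPeer("b")
//	r.Punish("a") // "a" now scores 5; "b" keeps the defaultScore of 4.
//
//	peers := []string{"a", "b"}
//	r.Order(peers) // peers == []string{"b", "a"}

39 | // Order sorts the given slice of peers based on their current score. If a
40 | // peer has no current score given, the default will be used.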
41 | func (p *peerRanking) Order(peers []string) { 42 | sort.Slice(peers, func(i, j int) bool { 43 | score1, ok := p.rank[peers[i]] 44 | if !ok { 45 | score1 = defaultScore 46 | } 47 | 48 | score2, ok := p.rank[peers[j]] 49 | if !ok { 50 | score2 = defaultScore 51 | } 52 | return score1 < score2 53 | }) 54 | } 55 | 56 | // AddPeer adds a new peer to the ranking, starting out with the default score. 57 | func (p *peerRanking) AddPeer(peer string) { 58 | if _, ok := p.rank[peer]; ok { 59 | return 60 | } 61 | p.rank[peer] = defaultScore 62 | } 63 | 64 | // Punish increases the score of the given peer. 65 | func (p *peerRanking) Punish(peer string) { 66 | score, ok := p.rank[peer] 67 | if !ok { 68 | return 69 | } 70 | 71 | // Cannot punish more. 72 | if score == worstScore { 73 | return 74 | } 75 | 76 | p.rank[peer] = score + 1 77 | } 78 | 79 | // Reward decreases the score of the given peer. 80 | // TODO(halseth): use actual response time when ranking peers. 81 | func (p *peerRanking) Reward(peer string) { 82 | score, ok := p.rank[peer] 83 | if !ok { 84 | return 85 | } 86 | 87 | // Cannot reward more. 88 | if score == bestScore { 89 | return 90 | } 91 | 92 | p.rank[peer] = score - 1 93 | } 94 | 95 | // ResetRanking sets the score of the passed peer to the defaultScore. 96 | func (p *peerRanking) ResetRanking(peer string) { 97 | _, ok := p.rank[peer] 98 | if !ok { 99 | return 100 | } 101 | 102 | p.rank[peer] = defaultScore 103 | } 104 | -------------------------------------------------------------------------------- /query/peer_rank_test.go: -------------------------------------------------------------------------------- 1 | package query 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | // TestPeerRank checks that the peerRanking correctly orders peers according to 9 | // how they are rewarded and punished. 10 | func TestPeerRank(t *testing.T) { 11 | const numPeers = 8 12 | 13 | ranking := NewPeerRanking() 14 | var peers []string 15 | for i := 0; i < numPeers; i++ { 16 | p := fmt.Sprintf("peer%d", i) 17 | peers = append(peers, p) 18 | ranking.AddPeer(p) 19 | } 20 | 21 | // We'll try to order half of the peers. 22 | peers = peers[:numPeers/2] 23 | ranking.Order(peers) 24 | 25 | // Since no peer was rewarded or punished, their order 26 | // should be unchanged. 27 | for i := 0; i < numPeers/2; i++ { 28 | p := fmt.Sprintf("peer%d", i) 29 | if peers[i] != p { 30 | t.Fatalf("expected %v, got %v", p, peers[i]) 31 | } 32 | } 33 | 34 | // Punish the first ones more, which should flip the order. 35 | for i := 0; i < numPeers/2; i++ { 36 | for j := 0; j <= i; j++ { 37 | ranking.Punish(peers[j]) 38 | } 39 | } 40 | 41 | ranking.Order(peers) 42 | for i := 0; i < numPeers/2; i++ { 43 | p := fmt.Sprintf("peer%d", numPeers/2-i-1) 44 | if peers[i] != p { 45 | t.Fatalf("expected %v, got %v", p, peers[i]) 46 | } 47 | } 48 | 49 | // This is the lowest scored peer after punishment. 50 | const lowestScoredPeer = "peer0" 51 | 52 | // Reward the lowest scored one a bunch, which should move it 53 | // to the front. 54 | for i := 0; i < 10; i++ { 55 | ranking.Reward(lowestScoredPeer) 56 | } 57 | 58 | ranking.Order(peers) 59 | if peers[0] != lowestScoredPeer { 60 | t.Fatalf("peer0 was not first") 61 | } 62 | 63 | // Punish the peer a bunch to make it the lowest scored one. 64 | for i := 0; i < 10; i++ { 65 | ranking.Punish(lowestScoredPeer) 66 | } 67 | 68 | ranking.Order(peers) 69 | if peers[len(peers)-1] != lowestScoredPeer { 70 | t.Fatalf("peer0 should be last") 71 | } 72 | 73 | // Reset its ranking. 
It should have the default score now
74 | 	// and should not be the lowest ranked peer.
75 | 	ranking.ResetRanking(lowestScoredPeer)
76 | 	ranking.Order(peers)
77 | 	if peers[len(peers)-1] == lowestScoredPeer {
78 | 		t.Fatalf("peer0 should not be last.")
79 | 	}
80 | }
81 | 
--------------------------------------------------------------------------------
/query/worker.go:
--------------------------------------------------------------------------------
1 | package query
2 | 
3 | import (
4 | 	"errors"
5 | 	"time"
6 | 
7 | 	"github.com/btcsuite/btcd/wire"
8 | )
9 | 
10 | var (
11 | 	// ErrQueryTimeout is an error returned if the worker doesn't respond
12 | 	// with a valid response to the request within the timeout.
13 | 	ErrQueryTimeout = errors.New("did not get response before timeout")
14 | 
15 | 	// ErrPeerDisconnected is returned if the worker's peer disconnects
16 | 	// before the query has been answered.
17 | 	ErrPeerDisconnected = errors.New("peer disconnected")
18 | 
19 | 	// ErrJobCanceled is returned if the job is canceled before the query
20 | 	// has been answered.
21 | 	ErrJobCanceled = errors.New("job canceled")
22 | )
23 | 
24 | // queryJob is the internal struct that wraps the Query to work on, in
25 | // addition to some information about the query.
26 | type queryJob struct {
27 | 	tries              uint8
28 | 	index              uint64
29 | 	timeout            time.Duration
30 | 	encoding           wire.MessageEncoding
31 | 	cancelChan         <-chan struct{}
32 | 	internalCancelChan <-chan struct{}
33 | 	*Request
34 | }
35 | 
36 | // queryJob should satisfy the Task interface in order to be sorted by the
37 | // workQueue.
38 | var _ Task = (*queryJob)(nil)
39 | 
40 | // Index returns the queryJob's index within the work queue.
41 | //
42 | // NOTE: Part of the Task interface.
43 | func (q *queryJob) Index() uint64 {
44 | 	return q.index
45 | }
46 | 
47 | // jobResult is the final result of the worker's handling of the queryJob.
48 | type jobResult struct {
49 | 	job  *queryJob
50 | 	peer Peer
51 | 	err  error
52 | }
53 | 
54 | // worker is responsible for polling work from its work queue, and handing it
55 | // to the associated peer. It validates incoming responses with the current
56 | // query's response handler, and polls more work for the peer when it has
57 | // successfully received a response to the request.
58 | type worker struct {
59 | 	peer Peer
60 | 
61 | 	// nextJob is a channel of queries to be distributed, where the worker
62 | 	// will poll new work from.
63 | 	nextJob chan *queryJob
64 | }
65 | 
66 | // A compile-time check to ensure worker satisfies the Worker interface.
67 | var _ Worker = (*worker)(nil)
68 | 
69 | // NewWorker creates a new worker associated with the given peer.
70 | func NewWorker(peer Peer) Worker {
71 | 	return &worker{
72 | 		peer:    peer,
73 | 		nextJob: make(chan *queryJob),
74 | 	}
75 | }
76 | 
77 | // Run starts the worker. The worker will supply its peer with queries, and
78 | // handle responses from it. Results for any query handled by this worker will
79 | // be delivered on the results channel. quit can be closed to immediately make
80 | // the worker exit.
81 | //
82 | // The method is blocking, and should be started in a goroutine. It will run
83 | // until the peer disconnects or the worker is told to quit.
84 | //
85 | // NOTE: Part of the Worker interface.
86 | func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) {
87 | 	peer := w.peer
88 | 
89 | 	// Subscribe to messages from the peer.
90 | msgChan, cancel := peer.SubscribeRecvMsg() 91 | defer cancel() 92 | 93 | for { 94 | log.Tracef("Worker %v waiting for more work", peer.Addr()) 95 | 96 | var job *queryJob 97 | select { 98 | // Poll a new job from the nextJob channel. 99 | case job = <-w.nextJob: 100 | log.Tracef("Worker %v picked up job with index %v", 101 | peer.Addr(), job.Index()) 102 | 103 | // Ignore any message received while not working on anything. 104 | case msg := <-msgChan: 105 | log.Tracef("Worker %v ignoring received msg %T "+ 106 | "since no job active", peer.Addr(), msg) 107 | continue 108 | 109 | // If the peer disconnected, we can exit immediately, as we 110 | // weren't working on a query. 111 | case <-peer.OnDisconnect(): 112 | log.Debugf("Peer %v for worker disconnected", 113 | peer.Addr()) 114 | return 115 | 116 | case <-quit: 117 | return 118 | } 119 | 120 | select { 121 | // There is no point in queueing the request if the job already 122 | // is canceled, so we check this quickly. 123 | case <-job.cancelChan: 124 | log.Tracef("Worker %v found job with index %v "+ 125 | "already canceled", peer.Addr(), job.Index()) 126 | 127 | // We break to the below loop, where we'll check the 128 | // cancel channel again and the ErrJobCanceled 129 | // result will be sent back. 130 | break 131 | 132 | case <-job.internalCancelChan: 133 | log.Tracef("Worker %v found job with index %v "+ 134 | "already internally canceled (batch timed out)", 135 | peer.Addr(), job.Index()) 136 | 137 | // We break to the below loop, where we'll check the 138 | // internal cancel channel again and the ErrJobCanceled 139 | // result will be sent back. 140 | break 141 | 142 | // We received a non-canceled query job, send it to the peer. 143 | default: 144 | log.Tracef("Worker %v queuing job %T with index %v", 145 | peer.Addr(), job.Req, job.Index()) 146 | 147 | peer.QueueMessageWithEncoding(job.Req, nil, job.encoding) 148 | } 149 | 150 | // Wait for the correct response to be received from the peer, 151 | // or an error happening. 152 | var ( 153 | jobErr error 154 | timeout = time.NewTimer(job.timeout) 155 | ) 156 | 157 | Loop: 158 | for { 159 | select { 160 | // A message was received from the peer, use the 161 | // response handler to check whether it was answering 162 | // our request. 163 | case resp := <-msgChan: 164 | progress := job.HandleResp( 165 | job.Req, resp, peer.Addr(), 166 | ) 167 | 168 | log.Tracef("Worker %v handled msg %T while "+ 169 | "waiting for response to %T (job=%v). "+ 170 | "Finished=%v, progressed=%v", 171 | peer.Addr(), resp, job.Req, job.Index(), 172 | progress.Finished, progress.Progressed) 173 | 174 | // If the response did not answer our query, we 175 | // check whether it did progress it. 176 | if !progress.Finished { 177 | // If it did make progress we reset the 178 | // timeout. This ensures that the 179 | // queries with multiple responses 180 | // expected won't timeout before all 181 | // responses have been handled. 182 | // TODO(halseth): separate progress 183 | // timeout value. 184 | if progress.Progressed { 185 | timeout.Stop() 186 | timeout = time.NewTimer( 187 | job.timeout, 188 | ) 189 | } 190 | continue Loop 191 | } 192 | 193 | // We did get a valid response, and can break 194 | // the loop. 195 | break Loop 196 | 197 | // If the timeout is reached before a valid response 198 | // has been received, we exit with an error. 199 | case <-timeout.C: 200 | // The query did experience a timeout and will 201 | // be given to someone else. 
202 | 				jobErr = ErrQueryTimeout
203 | 				log.Tracef("Worker %v timeout for request %T "+
204 | 					"with job index %v", peer.Addr(),
205 | 					job.Req, job.Index())
206 | 
207 | 				break Loop
208 | 
209 | 			// If the peer disconnects before giving us a valid
210 | 			// answer, we'll also exit with an error.
211 | 			case <-peer.OnDisconnect():
212 | 				log.Debugf("Peer %v for worker disconnected, "+
213 | 					"cancelling job %v", peer.Addr(),
214 | 					job.Index())
215 | 
216 | 				jobErr = ErrPeerDisconnected
217 | 				break Loop
218 | 
219 | 			// If the job was canceled, we report this back to the
220 | 			// work manager.
221 | 			case <-job.cancelChan:
222 | 				log.Tracef("Worker %v job %v canceled",
223 | 					peer.Addr(), job.Index())
224 | 
225 | 				jobErr = ErrJobCanceled
226 | 				break Loop
227 | 
228 | 			case <-job.internalCancelChan:
229 | 				log.Tracef("Worker %v job %v internally "+
230 | 					"canceled", peer.Addr(), job.Index())
231 | 
232 | 				jobErr = ErrJobCanceled
233 | 				break Loop
234 | 
235 | 			case <-quit:
236 | 				return
237 | 			}
238 | 		}
239 | 
240 | 		// Stop to allow garbage collection.
241 | 		timeout.Stop()
242 | 
243 | 		// We have a result ready for the query, hand it off before
244 | 		// getting a new job.
245 | 		select {
246 | 		case results <- &jobResult{
247 | 			job:  job,
248 | 			peer: peer,
249 | 			err:  jobErr,
250 | 		}:
251 | 		case <-quit:
252 | 			return
253 | 		}
254 | 
255 | 		// If the peer disconnected, we can exit immediately.
256 | 		if jobErr == ErrPeerDisconnected {
257 | 			return
258 | 		}
259 | 	}
260 | }
261 | 
262 | // NewJob returns a channel where work that is to be handled by the worker can
263 | // be sent. If the worker reads a queryJob from this channel, it is guaranteed
264 | // that a response will eventually be delivered on the results channel (except
265 | // when the quit channel has been closed).
266 | //
267 | // NOTE: Part of the Worker interface.
268 | func (w *worker) NewJob() chan<- *queryJob {
269 | 	return w.nextJob
270 | }
271 | 
--------------------------------------------------------------------------------
/query/workqueue.go:
--------------------------------------------------------------------------------
1 | package query
2 | 
3 | // Task is an interface that has a method for returning its index in the
4 | // work queue.
5 | type Task interface {
6 | 	// Index returns this Task's index in the work queue.
7 | 	Index() uint64
8 | }
9 | 
10 | // workQueue is a struct implementing the heap interface, and is used to keep
11 | // a list of remaining queryTasks in order.
12 | type workQueue struct {
13 | 	tasks []Task
14 | }
15 | 
16 | // Len returns the number of nodes in the priority queue.
17 | //
18 | // NOTE: This is part of the heap.Interface implementation.
19 | func (w *workQueue) Len() int { return len(w.tasks) }
20 | 
21 | // Less returns whether the item in the priority queue with index i should sort
22 | // before the item with index j.
23 | //
24 | // NOTE: This is part of the heap.Interface implementation.
25 | func (w *workQueue) Less(i, j int) bool {
26 | 	return w.tasks[i].Index() < w.tasks[j].Index()
27 | }
28 | 
29 | // Swap swaps the nodes at the passed indices in the priority queue.
30 | //
31 | // NOTE: This is part of the heap.Interface implementation.
32 | func (w *workQueue) Swap(i, j int) {
33 | 	w.tasks[i], w.tasks[j] = w.tasks[j], w.tasks[i]
34 | }
35 | 
36 | // Push adds x as element Len().
37 | //
38 | // NOTE: This is part of the heap.Interface implementation.
39 | func (w *workQueue) Push(x interface{}) {
40 | 	w.tasks = append(w.tasks, x.(Task))
41 | }
42 | 
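// As a short usage sketch (someTask is a hypothetical Task implementation;
// the test file below drives the queue the same way), the queue is operated
// through container/heap and yields tasks lowest Index() first:
//
//	q := &workQueue{}
//	heap.Init(q)
//	heap.Push(q, someTask)
//	next := heap.Pop(q).(Task)

43 | // Pop removes and returns element Len()-1.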
44 | // 45 | // NOTE: This is part of the heap.Interface implementation. 46 | func (w *workQueue) Pop() interface{} { 47 | n := len(w.tasks) 48 | x := w.tasks[n-1] 49 | w.tasks[n-1] = nil 50 | w.tasks = w.tasks[0 : n-1] 51 | return x 52 | } 53 | 54 | // Peek returns the first item in the queue. 55 | func (w *workQueue) Peek() interface{} { 56 | return w.tasks[0] 57 | } 58 | -------------------------------------------------------------------------------- /query/workqueue_test.go: -------------------------------------------------------------------------------- 1 | package query 2 | 3 | import ( 4 | "container/heap" 5 | "testing" 6 | ) 7 | 8 | type task struct { 9 | index uint64 10 | } 11 | 12 | var _ Task = (*task)(nil) 13 | 14 | func (t *task) Index() uint64 { 15 | return t.index 16 | } 17 | 18 | // TestWorkQueue makes sure workQueue implements the desired behaviour. 19 | func TestWorkQueue(t *testing.T) { 20 | t.Parallel() 21 | 22 | const numTasks = 20 23 | 24 | // Create a workQueue. 25 | q := &workQueue{} 26 | heap.Init(q) 27 | 28 | // Create a simple list of tasks and add them all to the queue. 29 | var tasks []*task 30 | for i := uint64(0); i < numTasks; i++ { 31 | tasks = append(tasks, &task{ 32 | index: i, 33 | }) 34 | } 35 | 36 | for _, t := range tasks { 37 | heap.Push(q, t) 38 | } 39 | 40 | // Check that it reports the expected number of elements. 41 | l := q.Len() 42 | if l != numTasks { 43 | t.Fatalf("expected %d length, was %d", numTasks, l) 44 | } 45 | 46 | // Pop half, and make sure they arrive in the right order. 47 | for i := uint64(0); i < numTasks/2; i++ { 48 | peek := q.Peek().(*task) 49 | pop := heap.Pop(q) 50 | 51 | // We expect the peeked and popped element to be the same. 52 | if peek != pop { 53 | t.Fatalf("peek and pop mismatch") 54 | } 55 | 56 | if peek.index != i { 57 | t.Fatalf("wrong index: %d", peek.index) 58 | } 59 | } 60 | 61 | // Insert 3 elements with index 0. 62 | for j := 0; j < 3; j++ { 63 | heap.Push(q, tasks[0]) 64 | } 65 | 66 | for i := uint64(numTasks/2 - 3); i < numTasks; i++ { 67 | peek := q.Peek().(*task) 68 | pop := heap.Pop(q) 69 | 70 | // We expect the peeked and popped element to be the same. 71 | if peek != pop { 72 | t.Fatalf("peek and pop mismatch") 73 | } 74 | 75 | // First three element should have index 0, rest should have 76 | // index i. 77 | exp := i 78 | if i < numTasks/2 { 79 | exp = 0 80 | } 81 | 82 | if peek.index != exp { 83 | t.Fatalf("wrong index: %d", peek.index) 84 | } 85 | } 86 | 87 | // Finally, the queue should be empty. 88 | l = q.Len() 89 | if l != 0 { 90 | t.Fatalf("expected %d length, was %d", 0, l) 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /testdata/blocks1-256.bz2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lightninglabs/neutrino/ef73743b717195b30481ad16d7fd948867da11e2/testdata/blocks1-256.bz2 -------------------------------------------------------------------------------- /tools/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:1.23.6-bookworm 2 | 3 | RUN apt-get update && apt-get install -y git 4 | ENV GOCACHE=/tmp/build/.cache 5 | ENV GOMODCACHE=/tmp/build/.modcache 6 | 7 | COPY . 
/tmp/tools
8 | 
9 | RUN cd /tmp \
10 | 	&& mkdir -p /tmp/build/.cache \
11 | 	&& mkdir -p /tmp/build/.modcache \
12 | 	&& cd /tmp/tools \
13 | 	&& go install -trimpath -tags=tools github.com/golangci/golangci-lint/cmd/golangci-lint \
14 | 	&& chmod -R 777 /tmp/build/ \
15 | 	&& git config --global --add safe.directory /build
16 | 
17 | 
18 | WORKDIR /build
19 | 
--------------------------------------------------------------------------------
/tools/tools.go:
--------------------------------------------------------------------------------
1 | //go:build tools
2 | // +build tools
3 | 
4 | package neutrino
5 | 
6 | // The other imports represent our build tools. Instead of defining a commit we
7 | // want to use for those golang-based tools, we use the go mod versioning system
8 | // to unify the way we manage dependencies. So we define our build tool
9 | // dependencies here and pin the version in go.mod.
10 | import (
11 | 	_ "github.com/btcsuite/btcd"
12 | 	_ "github.com/golangci/golangci-lint/cmd/golangci-lint"
13 | 	_ "github.com/ory/go-acc"
14 | 	_ "github.com/rinchsan/gosimports/cmd/gosimports"
15 | )
16 | 
--------------------------------------------------------------------------------
/verification.go:
--------------------------------------------------------------------------------
1 | package neutrino
2 | 
3 | import (
4 | 	"fmt"
5 | 
6 | 	"github.com/btcsuite/btcd/btcutil"
7 | 	"github.com/btcsuite/btcd/btcutil/gcs"
8 | 	"github.com/btcsuite/btcd/btcutil/gcs/builder"
9 | 	"github.com/btcsuite/btcd/txscript"
10 | )
11 | 
12 | // VerifyBasicBlockFilter asserts that a given block filter was constructed
13 | // correctly and according to the rules of BIP-0158 to contain both the
14 | // outputs' pk scripts as well as the pk scripts the inputs are spending.
15 | func VerifyBasicBlockFilter(filter *gcs.Filter, block *btcutil.Block) (int,
16 | 	error) {
17 | 
18 | 	var (
19 | 		opReturnMatches int
20 | 		key             = builder.DeriveKey(block.Hash())
21 | 	)
22 | 	for idx, tx := range block.Transactions() {
23 | 		// Skip coinbase transaction.
24 | 		if idx == 0 {
25 | 			continue
26 | 		}
27 | 
28 | 		// Check outputs first.
29 | 		for outIdx, txOut := range tx.MsgTx().TxOut {
30 | 			switch {
31 | 			// If the script itself is blank, then we'll skip this
32 | 			// as it doesn't contain any useful information.
33 | 			case len(txOut.PkScript) == 0:
34 | 				continue
35 | 
36 | 			// We'll also skip any OP_RETURN scripts as well since
37 | 			// we don't index these in order to avoid a circular
38 | 			// dependency.
39 | 			case txOut.PkScript[0] == txscript.OP_RETURN:
40 | 				// Previous versions of the filters did include
41 | 				// OP_RETURNs. To be able to disconnect bad peers
42 | 				// still serving these old filters we attempt to
43 | 				// check if there's an unexpected match. Since
44 | 				// there might be false positives, an OP_RETURN
45 | 				// can still match filters not including them.
46 | 				// Therefore, we count the number of such
47 | 				// unexpected matches for each peer, such that
48 | 				// we can ban peers matching more than the rest.
49 | 				match, err := filter.Match(key, txOut.PkScript)
50 | 				if err != nil {
51 | 					// Mark peer bad if we cannot match on
52 | 					// its filter.
53 | 					return 0, fmt.Errorf("error "+
54 | 						"validating block %v outpoint "+
55 | 						"%v:%d script %x: %v",
56 | 						block.Hash(), tx.Hash(), outIdx,
57 | 						txOut.PkScript, err)
58 | 				}
59 | 
60 | 				// If it matches on the OP_RETURN output, we
61 | 				// increase the op return counter.
62 | 				if match {
63 | 					opReturnMatches++
64 | 				}
65 | 
66 | 				continue
67 | 			}
68 | 
69 | 			// This is a "normal" script where we definitely expect
70 | 			// a match.
71 | 			match, err := filter.Match(key, txOut.PkScript)
72 | 			if err != nil {
73 | 				return 0, fmt.Errorf("error validating block "+
74 | 					"%v outpoint %v:%d script %x: %v",
75 | 					block.Hash(), tx.Hash(), outIdx,
76 | 					txOut.PkScript, err)
77 | 			}
78 | 
79 | 			if !match {
80 | 				return 0, fmt.Errorf("filter for block %v is "+
81 | 					"invalid, outpoint %v:%d script %x "+
82 | 					"wasn't matched by filter",
83 | 					block.Hash(), tx.Hash(), outIdx,
84 | 					txOut.PkScript)
85 | 			}
86 | 		}
87 | 
88 | 		// Now we can go through all inputs and check that the filter
89 | 		// also included any pk scripts of the outputs being _spent_.
90 | 		// We can do this for witness items since the witness always
91 | 		// contains the full script as the last element on the stack.
92 | 		for inIdx, in := range tx.MsgTx().TxIn {
93 | 			// There are too many edge cases to cover for non-
94 | 			// witness scripts. And in LN land we're interested in
95 | 			// witness spends only anyway. Therefore let's skip any
96 | 			// input that has no witness.
97 | 			//
98 | 			// TODO(guggero): Add all those edge cases to
99 | 			// ComputePkScript?
100 | 			if len(in.Witness) == 0 {
101 | 				continue
102 | 			}
103 | 
104 | 			// The only input type that has both set is a nested
105 | 			// P2WPKH (P2SH-P2WPKH). We can verify that one because
106 | 			// the script hash has to be HASH160(OP_PUSH32 ).
107 | 			script, err := txscript.ComputePkScript(
108 | 				in.SignatureScript, in.Witness,
109 | 			)
110 | 
111 | 			// Just skip any inputs that we can't derive the pk
112 | 			// script from.
113 | 			if err == txscript.ErrUnsupportedScriptType {
114 | 				log.Tracef("Skipping filter validation for "+
115 | 					"input %d of tx %v in block %v "+
116 | 					"because script type is not supported "+
117 | 					"for validating against filter", inIdx,
118 | 					tx.Hash(), block.Hash())
119 | 
120 | 				continue
121 | 			}
122 | 
123 | 			// Something else went wrong. We can't really say the
124 | 			// filter is faulty though so we also just skip over
125 | 			// this input.
126 | 			if err != nil {
127 | 				log.Debugf("Skipping filter validation for "+
128 | 					"input %d of tx %v in block %v "+
129 | 					"because computing the script failed: "+
130 | 					"%v", inIdx, tx.Hash(), block.Hash(), err)
131 | 
132 | 				continue
133 | 			}
134 | 
135 | 			match, err := filter.Match(key, script.Script())
136 | 			if err != nil {
137 | 				return 0, fmt.Errorf("error validating block "+
138 | 					"%v input %d of tx %v script %x: %v",
139 | 					block.Hash(), inIdx, tx.Hash(),
140 | 					script.Script(), err)
141 | 			}
142 | 
143 | 			if !match {
144 | 				log.Debugf("filter for block %v might be "+
145 | 					"invalid, input %d of tx %v spends "+
146 | 					"pk script %x which wasn't matched by "+
147 | 					"filter. The input likely spends a "+
148 | 					"taproot output which is not yet "+
149 | 					"supported", block.Hash(), inIdx,
150 | 					tx.Hash(), script.Script())
151 | 			}
152 | 		}
153 | 	}
154 | 
155 | 	return opReturnMatches, nil
156 | }
157 | 
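// As a minimal usage sketch (assuming the caller has already fetched a block
// and the corresponding peer-served filter; variable names are hypothetical):
//
//	opReturnMatches, err := VerifyBasicBlockFilter(filter, block)
//	if err != nil {
//		// The filter provably misses a script it must contain, so the
//		// serving peer is a candidate for banning.
//	}
//	// Otherwise, compare opReturnMatches across peers: filters that keep
//	// matching OP_RETURN outputs are likely stale filters from before
//	// OP_RETURN outputs were excluded.
--------------------------------------------------------------------------------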