├── .github ├── FUNDING.yml ├── actions │ └── setup-massdns │ │ └── action.yml └── workflows │ ├── build.yml │ └── release.yml ├── .gitignore ├── .vscode └── settings.json ├── CHANGELOG.md ├── LICENSE ├── Makefile ├── README.md ├── assets ├── become-sponsor.jpg ├── puredns-logo.png ├── puredns-operation.gif └── puredns-terminal.png ├── go.mod ├── go.sum ├── internal ├── app │ ├── cmd │ │ ├── bruteforce.go │ │ ├── bruteforce_test.go │ │ ├── resolve.go │ │ ├── resolve_test.go │ │ ├── root.go │ │ ├── root_test.go │ │ ├── sponsors.go │ │ └── sponsors_test.go │ ├── ctx │ │ ├── ctx.go │ │ ├── ctx_test.go │ │ ├── options.go │ │ └── options_test.go │ ├── stdin.go │ ├── stdin_test.go │ └── version.go ├── pkg │ └── console │ │ ├── color.go │ │ ├── console.go │ │ └── console_test.go └── usecase │ ├── programbanner │ ├── programbanner.go │ └── programbanner_test.go │ ├── resolve │ ├── cachereader.go │ ├── cachereader_test.go │ ├── domainreader.go │ ├── domainreader_test.go │ ├── massresolver.go │ ├── massresolver_test.go │ ├── requirementchecker.go │ ├── requirementchecker_test.go │ ├── resolve.go │ ├── resolve_test.go │ ├── resolverloader.go │ ├── resolverloader_test.go │ ├── resultsaver.go │ ├── resultsaver_test.go │ ├── sanitizer.go │ ├── sanitizer_test.go │ ├── stubs_test.go │ ├── wildcardfilter.go │ ├── wildcardfilter_test.go │ ├── workfilecreator.go │ └── workfilecreator_test.go │ └── sponsors │ ├── service.go │ └── service_test.go ├── main.go ├── main_test.go └── pkg ├── fileoperation ├── appendlines.go ├── appendlines_test.go ├── appendword.go ├── appendword_test.go ├── cat.go ├── cat_test.go ├── copy.go ├── copy_test.go ├── countlines.go ├── countlines_test.go ├── doc.go ├── fileexists.go ├── fileexists_test.go ├── readlines.go ├── readlines_test.go ├── writelines.go └── writelines_test.go ├── filetest ├── doc.go ├── file.go ├── file_test.go ├── stdin.go ├── stubreader.go ├── stubreader_test.go ├── stubwriter.go └── stubwriter_test.go ├── massdns ├── callback.go ├── 
callback_test.go ├── doc.go ├── linereader.go ├── linereader_test.go ├── resolver.go ├── resolver_test.go ├── runner.go ├── runner_test.go ├── stdouthandler.go ├── stdouthandler_test.go └── type.go ├── procreader ├── doc.go ├── procreader.go └── procreader_test.go ├── progressbar ├── doc.go ├── movingrate.go ├── movingrate_test.go ├── options.go ├── progressbar.go ├── progressbar_test.go └── style.go ├── shellexecutor ├── doc.go ├── shellexecutor.go └── shellexecutor_test.go ├── threadpool ├── doc.go ├── threadpool.go ├── threadpool_test.go └── worker.go └── wildcarder ├── answercache.go ├── answercache_test.go ├── clientdns.go ├── clientdns_test.go ├── detectiontask.go ├── detectiontask_test.go ├── dnscache.go ├── dnscache_test.go ├── doc.go ├── gather.go ├── gather_test.go ├── hashing_hash.go ├── hashing_string.go ├── randomsub.go ├── wildcarder.go └── wildcarder_test.go /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: d3mondev 2 | -------------------------------------------------------------------------------- /.github/actions/setup-massdns/action.yml: -------------------------------------------------------------------------------- 1 | name: 'Setup massdns' 2 | runs: 3 | using: "composite" 4 | steps: 5 | - run: | 6 | sudo git clone https://github.com/blechschmidt/massdns.git /usr/local/src/massdns 7 | cd /usr/local/src/massdns 8 | if [[ "${{ runner.os }}" == "Linux" ]]; then 9 | sudo make 10 | else 11 | sudo make nolinux 12 | fi 13 | sudo make install 14 | shell: bash 15 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | push: 5 | branches: 6 | - '**' 7 | tags-ignore: 8 | - '**' 9 | pull_request: 10 | branches: 11 | - '**' 12 | 13 | jobs: 14 | lint_and_test: 15 | name: Lint and test 16 | runs-on: ubuntu-latest 17 | steps: 18 
| - name: Checkout code 19 | uses: actions/checkout@v3 20 | 21 | - name: Setup Go 22 | uses: actions/setup-go@v4 23 | with: 24 | go-version: 1.20.x 25 | 26 | - name: Lint 27 | run: | 28 | go install golang.org/x/lint/golint@latest 29 | go install honnef.co/go/tools/cmd/staticcheck@latest 30 | make lint 31 | 32 | - name: Setup massdns 33 | uses: ./.github/actions/setup-massdns 34 | 35 | - name: Test 36 | run: make test 37 | 38 | - name: Code coverage 39 | run: make cover 40 | 41 | - name: Upload coverage to codecov.io 42 | uses: codecov/codecov-action@v3 43 | 44 | build: 45 | name: Build 46 | runs-on: ubuntu-latest 47 | strategy: 48 | matrix: 49 | go: ["1.18.x", "1.19.x", "1.20.x"] 50 | 51 | steps: 52 | - name: Checkout code 53 | uses: actions/checkout@v3 54 | 55 | - name: Setup massdns 56 | uses: ./.github/actions/setup-massdns 57 | 58 | - name: Setup Go 59 | uses: actions/setup-go@v4 60 | with: 61 | go-version: ${{ matrix.go }} 62 | 63 | - name: Load cached dependencies 64 | uses: actions/cache@v3 65 | with: 66 | path: ~/go/pkg/mod 67 | key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} 68 | 69 | - name: Download dependencies 70 | run: go mod download 71 | 72 | - name: Build 73 | run: make 74 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: 6 | - created 7 | 8 | jobs: 9 | lint_and_test: 10 | name: Lint and test 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout code 14 | uses: actions/checkout@v3 15 | 16 | - name: Setup Go 17 | uses: actions/setup-go@v4 18 | with: 19 | go-version: 1.20.x 20 | 21 | - name: Cache dependencies 22 | uses: actions/cache@v3 23 | with: 24 | path: ~/go/pkg/mod 25 | key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} 26 | restore-keys: | 27 | ${{ runner.os }}-go- 28 | 29 | - name: Download dependencies 30 | run: go mod 
download 31 | 32 | - name: Lint 33 | run: | 34 | go install golang.org/x/lint/golint@latest 35 | go install honnef.co/go/tools/cmd/staticcheck@latest 36 | make lint 37 | 38 | - name: Setup massdns 39 | uses: ./.github/actions/setup-massdns 40 | 41 | - name: Test 42 | run: make test 43 | 44 | build: 45 | name: Build 46 | runs-on: ${{ matrix.os }} 47 | strategy: 48 | matrix: 49 | include: 50 | - os: ubuntu-latest 51 | platform: amd64 52 | - os: ubuntu-latest 53 | platform: arm64 54 | - os: macos-latest 55 | platform: amd64 56 | - os: macos-latest 57 | platform: arm64 58 | 59 | steps: 60 | - name: Checkout code 61 | uses: actions/checkout@v3 62 | 63 | - name: Setup Go 64 | uses: actions/setup-go@v4 65 | with: 66 | go-version: 1.20.x 67 | 68 | - name: Cache dependencies 69 | uses: actions/cache@v3 70 | with: 71 | path: ~/go/pkg/mod 72 | key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} 73 | restore-keys: | 74 | ${{ runner.os }}-go- 75 | 76 | - name: Download dependencies 77 | run: go mod download 78 | 79 | - name: Build 80 | env: 81 | CGO_ENABLED: 0 82 | GOARCH: ${{ matrix.platform }} 83 | run: | 84 | make 85 | tar czf puredns-${{ runner.os }}-${{ matrix.platform }}.tgz puredns 86 | 87 | - name: Upload binaries 88 | uses: actions/upload-artifact@v3 89 | with: 90 | name: binaries 91 | path: puredns-${{ runner.os }}-${{ matrix.platform }}.tgz 92 | 93 | release: 94 | name: Release 95 | needs: [lint_and_test, build] 96 | runs-on: ubuntu-latest 97 | steps: 98 | - name: Download binaries 99 | uses: actions/download-artifact@v3 100 | 101 | - name: Create Release 102 | id: create_release 103 | uses: softprops/action-gh-release@v1 104 | with: 105 | files: binaries/puredns-*.tgz 106 | env: 107 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 108 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /puredns 2 | /__debug_bin 3 | /.vscode/launch.json 4 | *.swp 5 | *.prof
6 | *.mprof 7 | cover.out 8 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "gopls": { 3 | "buildFlags": [ 4 | "-tags=no_hashing" 5 | ], 6 | }, 7 | "cmake.configureOnOpen": false 8 | } -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | All notable changes to this project will be documented in this file. 3 | 4 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 5 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 6 | 7 | ## [Unreleased] 8 | ### Fixed 9 | - Wildcard filtering not working when wildcard tests put the domain over the 253 character limit. 10 | - Sanitizer accepts subdomains with underscores in them. 11 | 12 | ## [2.1.1] - 2023-04-11 13 | ### Fixed 14 | - Wrong version number in binary releases 15 | 16 | ## [2.1.0] - 2023-04-11 17 | ### Added 18 | - Added the ability to bruteforce multiple domains simultaneously using the `-d`, `--domains` option with the bruteforce command, rather than providing just one domain as an argument. Now, executing `puredns bruteforce wordlist.txt -d domains.txt` will bruteforce all domains listed in the domains.txt file. [#13](https://github.com/d3mondev/puredns/issues/13) 19 | - Added a new option to use trusted resolvers only: `--trusted-only`. This can help quickly validate small domain lists with less risk of errors due to bad public resolvers. When this option is set, `--skip-validation` is also implied. [#11](https://github.com/d3mondev/puredns/issues/11) 20 | - Introduced the ability to use the * wildcard character when bruteforcing subdomains, enabling users to specify the desired location for word substitution. 
For example, executing `puredns bruteforce wordlist.txt "www.*.example.com"` will replace * with words from the wordlist, rather than appending the word to the beginning of the domain. 21 | - Added a `--debug` global flag to keep intermediate files. Useful to debug massdns or resolver issues. 22 | 23 | ### Changed 24 | - Resolvers are now loaded from `~/.config/puredns/resolvers.txt` and `~/.config/puredns/resolvers-trusted.txt` by default. If there is a `resolvers.txt` file present in the current directory, it still takes precedence. [#35](https://github.com/d3mondev/puredns/issues/35) 25 | 26 | ### Fixed 27 | - Number of domains found was not displayed when the `--skip-validation` option was set. 28 | - Domain sanitization now strips any remaining `*.` prefix at the beginning of a domain instead of skipping the domain entirely. For example, puredns will try to resolve `*.example.com` as `example.com`. 29 | - Support running massdns as root. [#17](https://github.com/d3mondev/puredns/issues/17) [#27](https://github.com/d3mondev/puredns/issues/27) 30 | 31 | ## [2.0.1] - 2021-06-25 32 | ### Fixed 33 | - Wildcard subdomains with only CNAME records were not being filtered properly. [#14](https://github.com/d3mondev/puredns/issues/14) 34 | 35 | ## [2.0.0] - 2021-05-03 36 | ### Added 37 | - Stdin can be used in place of the domain list or wordlist files. See help for examples. 38 | - Quiet flag (`-q`, `--quiet`) to silence output. Only valid domains are output to stdout when quiet mode is on. [#4](https://github.com/d3mondev/puredns/issues/4) 39 | - Attempt to detect DNS load balancing during wildcard detection. Use flag `-n`, `--wildcard-tests` to specify the number of DNS queries to perform to detect all the possible IPs for a subdomain. 40 | - Add ability to specify a maximum batch size of domains to process at once during wildcard detection with `--wildcard-batch`. This is to help prevent memory issues that can happen on very large lists (70M+ wildcard subdomains). 
41 | - Progress bar during wildcard detection. 42 | - Selected options are displayed at the start of the program. 43 | - Add sponsors command to view active [GitHub sponsors](https://github.com/sponsors/d3mondev). 44 | 45 | ### Changed 46 | - Complete rewrite in Go for more stability and to prepare new features. 47 | - Some command line flags have changed to be POSIX compliant; use `--help` on commands to see the changes. 48 | - Rewrite wildcard detection algorithm to be more robust. 49 | - Remove dependency on 'pv' and do progress bar and rate limiting internally instead. 50 | - Massdns output file is now written in `-o Snl` format. 51 | - A default list of public resolvers is no longer provided as a reference. Best results will be obtained by curating your own list, for example using [public-dns.info](https://public-dns.info/nameservers-all.txt) and [DNS Validator](https://github.com/vortexau/dnsvalidator). 52 | - Remove `--write-answers` command line option since the full wildcard answers are no longer kept in memory to optimize for large files. This might come back in a future release if requested. 53 | 54 | ### Fixed 55 | - Massdns and wildcard detection will retry on SERVFAIL errors. 56 | - Add missing entries in the massdns cache that resulted in a higher number of DNS queries being made during wildcard detection. 57 | - Fix many edge cases happening around wildcard detection. 58 | 59 | ## [1.0.3] - 2021-04-12 60 | ### Fixed 61 | - Remove Cloudflare DNS from the list of trusted resolvers. [Here's why](https://twitter.com/d3mondev/status/1381678504450924552?s=20). 62 | - Increase the default rate limit per trusted resolver to 50. 63 | - Adjust massdns command line parameter `-s` to limit the size of the initial burst of queries sent to the trusted resolvers. 64 | 65 | ## [1.0.2] - 2021-03-22 66 | ### Fixed 67 | - Fix a badly handled exception during wildcard detection that was halting the process.
68 | 69 | ## [1.0.1] - 2020-10-15 70 | ### Fixed 71 | - Fix a bug where valid subdomains were not saved to a file. [#1](https://github.com/d3mondev/puredns/issues/1) 72 | 73 | ## [1.0.0] - 2020-08-02 74 | ### Added 75 | - Initial implementation. 76 | 77 | [Unreleased]: https://github.com/d3mondev/puredns/compare/v2.1.1...HEAD 78 | [2.1.1]: https://github.com/d3mondev/puredns/compare/v2.1.0...v2.1.1 79 | [2.1.0]: https://github.com/d3mondev/puredns/compare/v2.0.1...v2.1.0 80 | [2.0.1]: https://github.com/d3mondev/puredns/compare/v2.0.0...v2.0.1 81 | [2.0.0]: https://github.com/d3mondev/puredns/compare/v1.0.3...v2.0.0 82 | [1.0.3]: https://github.com/d3mondev/puredns/compare/v1.0.2...v1.0.3 83 | [1.0.2]: https://github.com/d3mondev/puredns/compare/v1.0.1...v1.0.2 84 | [1.0.1]: https://github.com/d3mondev/puredns/compare/v1.0.0...v1.0.1 85 | [1.0.0]: https://github.com/d3mondev/puredns/releases/tag/v1.0.0 86 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PKG := github.com/d3mondev/puredns/v2 2 | PKG_LIST := $(shell go list ./... | grep -v /vendor/) 3 | BRANCH := $(shell git rev-parse --abbrev-ref HEAD | tr -d '\040\011\012\015\n') 4 | REVISION := $(shell git rev-parse --short HEAD) 5 | 6 | .SILENT: ; 7 | .PHONY: all lint test msan build cover clean help 8 | 9 | all: build 10 | 11 | lint: ## Lint the files 12 | golint -set_exit_status $(PKG_LIST) 13 | staticcheck ./...
14 | 15 | test: ## Run unit tests 16 | go fmt $(PKG_LIST) 17 | go vet $(PKG_LIST) 18 | go test -race -timeout 30s -cover -count 1 $(PKG_LIST) 19 | 20 | msan: ## Run memory sanitizer 21 | go test -msan $(PKG_LIST) 22 | 23 | build: ## Build the binary file 24 | go build -trimpath -ldflags="-s -w" 25 | 26 | cover: ## Code coverage 27 | go test -coverprofile=cover.out $(PKG_LIST) 28 | 29 | clean: ## Remove previous build 30 | rm -f cover.out 31 | go clean 32 | 33 | help: ## Display this help screen 34 | grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 35 | -------------------------------------------------------------------------------- /assets/become-sponsor.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d3mondev/puredns/d82341557f0bb1efcfacf15533464be720ae4be8/assets/become-sponsor.jpg -------------------------------------------------------------------------------- /assets/puredns-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d3mondev/puredns/d82341557f0bb1efcfacf15533464be720ae4be8/assets/puredns-logo.png -------------------------------------------------------------------------------- /assets/puredns-operation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d3mondev/puredns/d82341557f0bb1efcfacf15533464be720ae4be8/assets/puredns-operation.gif -------------------------------------------------------------------------------- /assets/puredns-terminal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d3mondev/puredns/d82341557f0bb1efcfacf15533464be720ae4be8/assets/puredns-terminal.png -------------------------------------------------------------------------------- /go.mod: 
-------------------------------------------------------------------------------- 1 | module github.com/d3mondev/puredns/v2 2 | 3 | go 1.15 4 | 5 | require ( 6 | github.com/d3mondev/resolvermt v0.3.2 7 | github.com/kr/pretty v0.1.0 // indirect 8 | github.com/miekg/dns v1.1.53 // indirect 9 | github.com/spf13/cobra v1.7.0 10 | github.com/spf13/pflag v1.0.5 11 | github.com/stretchr/testify v1.7.0 12 | golang.org/x/tools v0.8.0 // indirect 13 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect 14 | ) 15 | -------------------------------------------------------------------------------- /internal/app/cmd/bruteforce.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "os" 5 | 6 | "github.com/d3mondev/puredns/v2/internal/app" 7 | "github.com/d3mondev/puredns/v2/internal/usecase/programbanner" 8 | "github.com/d3mondev/puredns/v2/internal/usecase/resolve" 9 | "github.com/spf13/cobra" 10 | ) 11 | 12 | func newCmdBruteforce() *cobra.Command { 13 | cmdBruteforce := &cobra.Command{ 14 | Use: "bruteforce domain [flags]\n puredns bruteforce -d domains.txt [flags]", 15 | Short: "Bruteforce subdomains using a wordlist", 16 | Long: `Bruteforce takes a file containing words to test as subdomains against the 17 | domain specified. It will invoke massdns using public resolvers for 18 | a quick first pass, then attempt to filter out any wildcard subdomains found. 19 | Finally, it will ensure the results are free of DNS poisoning by resolving 20 | the remaining domains using trusted resolvers. 
21 | 22 | The argument can be omitted if the wordlist is read from stdin.`, 23 | RunE: runBruteforce, 24 | } 25 | 26 | cmdBruteforce.Flags().StringVarP(&resolveOptions.DomainFile, "domains", "d", resolveOptions.DomainFile, "text file containing domains to bruteforce") 27 | 28 | cmdBruteforce.Flags().AddFlagSet(resolveFlags) 29 | cmdBruteforce.Flags().SortFlags = false 30 | 31 | return cmdBruteforce 32 | } 33 | 34 | func runBruteforce(cmd *cobra.Command, args []string) error { 35 | parseBruteforceArgs(args) 36 | 37 | if err := resolveOptions.Validate(); err != nil { 38 | return err 39 | } 40 | 41 | bannerService := programbanner.NewService(context) 42 | resolveService := resolve.NewService(context, resolveOptions) 43 | 44 | err := resolveService.Initialize() 45 | if err != nil { 46 | return err 47 | } 48 | defer resolveService.Close(context.Options.Debug) 49 | 50 | bannerService.PrintWithResolveOptions(resolveOptions) 51 | 52 | return resolveService.Resolve() 53 | } 54 | 55 | func parseBruteforceArgs(args []string) error { 56 | if app.HasStdin() { 57 | context.Stdin = os.Stdin 58 | 59 | if len(args) >= 1 { 60 | if resolveOptions.DomainFile == "" { 61 | resolveOptions.Domain = args[0] 62 | } 63 | } 64 | } else { 65 | if len(args) == 1 { 66 | resolveOptions.Wordlist = args[0] 67 | } else if len(args) >= 2 { 68 | resolveOptions.Wordlist = args[0] 69 | resolveOptions.Domain = args[1] 70 | } 71 | } 72 | 73 | resolveOptions.Mode = 1 74 | 75 | return nil 76 | } 77 | -------------------------------------------------------------------------------- /internal/app/cmd/bruteforce_test.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "io/fs" 5 | "os" 6 | "testing" 7 | 8 | "github.com/d3mondev/puredns/v2/internal/app/ctx" 9 | "github.com/d3mondev/puredns/v2/pkg/filetest" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestParseBruteforceArgs_TwoArgs(t 
*testing.T) { 15 | context = ctx.NewCtx() 16 | resolveOptions = ctx.DefaultResolveOptions() 17 | 18 | parseBruteforceArgs([]string{"wordlist.txt", "example.com"}) 19 | 20 | err := resolveOptions.Validate() 21 | assert.Nil(t, err) 22 | } 23 | 24 | func TestParseBruteforceArgs_NoArgs(t *testing.T) { 25 | context = ctx.NewCtx() 26 | resolveOptions = ctx.DefaultResolveOptions() 27 | 28 | parseBruteforceArgs([]string{}) 29 | 30 | err := resolveOptions.Validate() 31 | assert.ErrorIs(t, err, ctx.ErrNoDomain) 32 | } 33 | 34 | func TestParseBruteforceArgs_NoDomain(t *testing.T) { 35 | context = ctx.NewCtx() 36 | resolveOptions = ctx.DefaultResolveOptions() 37 | 38 | parseBruteforceArgs([]string{"wordlist.txt"}) 39 | 40 | err := resolveOptions.Validate() 41 | assert.ErrorIs(t, err, ctx.ErrNoDomain) 42 | } 43 | 44 | func TestParseBruteforceArgs_DomainFile(t *testing.T) { 45 | context = ctx.NewCtx() 46 | resolveOptions = ctx.DefaultResolveOptions() 47 | resolveOptions.DomainFile = "domains.txt" 48 | 49 | parseBruteforceArgs([]string{"wordlist.txt"}) 50 | 51 | err := resolveOptions.Validate() 52 | assert.Nil(t, err) 53 | } 54 | 55 | func TestParseBruteforceArgs_Stdin(t *testing.T) { 56 | context = ctx.NewCtx() 57 | resolveOptions = ctx.DefaultResolveOptions() 58 | 59 | r, w, err := os.Pipe() 60 | require.Nil(t, err) 61 | require.Nil(t, w.Close()) 62 | filetest.OverrideStdin(t, r) 63 | 64 | parseBruteforceArgs([]string{"domain.com"}) 65 | 66 | err = resolveOptions.Validate() 67 | assert.Nil(t, err) 68 | } 69 | 70 | func TestParseBruteforceArgs_StdinNoDomain(t *testing.T) { 71 | context = ctx.NewCtx() 72 | resolveOptions = ctx.DefaultResolveOptions() 73 | 74 | r, w, err := os.Pipe() 75 | require.Nil(t, err) 76 | require.Nil(t, w.Close()) 77 | filetest.OverrideStdin(t, r) 78 | 79 | parseBruteforceArgs([]string{}) 80 | 81 | err = resolveOptions.Validate() 82 | assert.ErrorIs(t, err, ctx.ErrNoDomain) 83 | } 84 | 85 | func TestParseBruteforceArgs_StdinDomainFile(t *testing.T) { 86 | 
context = ctx.NewCtx() 87 | resolveOptions = ctx.DefaultResolveOptions() 88 | resolveOptions.DomainFile = "domains.txt" 89 | 90 | r, w, err := os.Pipe() 91 | require.Nil(t, err) 92 | require.Nil(t, w.Close()) 93 | filetest.OverrideStdin(t, r) 94 | 95 | parseBruteforceArgs([]string{}) 96 | 97 | err = resolveOptions.Validate() 98 | assert.Nil(t, err) 99 | } 100 | 101 | func TestRunBruteforce_OK(t *testing.T) { 102 | resolvers := filetest.CreateFile(t, "8.8.8.8\n") 103 | wordlist := filetest.CreateFile(t, "") 104 | 105 | context = ctx.NewCtx() 106 | resolveOptions = ctx.DefaultResolveOptions() 107 | resolveOptions.ResolverFile = resolvers.Name() 108 | 109 | cmd := newCmdBruteforce() 110 | err := runBruteforce(cmd, []string{wordlist.Name(), "example.com"}) 111 | assert.Nil(t, err) 112 | } 113 | 114 | func TestRunBruteforce_ValidateError(t *testing.T) { 115 | context = ctx.NewCtx() 116 | resolveOptions = ctx.DefaultResolveOptions() 117 | 118 | cmd := newCmdBruteforce() 119 | err := runBruteforce(cmd, []string{}) 120 | assert.ErrorIs(t, err, ctx.ErrNoDomain) 121 | } 122 | 123 | func TestRunBruteforce_InitializeError(t *testing.T) { 124 | context = ctx.NewCtx() 125 | resolveOptions = ctx.DefaultResolveOptions() 126 | 127 | cmd := newCmdBruteforce() 128 | err := runBruteforce(cmd, []string{"wordlist.txt", "example.com"}) 129 | assert.ErrorIs(t, err, fs.ErrNotExist) 130 | } 131 | -------------------------------------------------------------------------------- /internal/app/cmd/resolve.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "os" 7 | 8 | "github.com/d3mondev/puredns/v2/internal/app" 9 | "github.com/d3mondev/puredns/v2/internal/app/ctx" 10 | "github.com/d3mondev/puredns/v2/internal/usecase/programbanner" 11 | "github.com/d3mondev/puredns/v2/internal/usecase/resolve" 12 | "github.com/spf13/cobra" 13 | "github.com/spf13/pflag" 14 | ) 15 | 16 | var ( 17 | resolveFlags 
*pflag.FlagSet 18 | resolveOptions *ctx.ResolveOptions 19 | ) 20 | 21 | func newCmdResolve() *cobra.Command { 22 | resolveOptions = ctx.DefaultResolveOptions() 23 | 24 | cmdResolve := &cobra.Command{ 25 | Use: "resolve [flags]", 26 | Short: "Resolve a list of domains", 27 | Long: `Resolve takes a file containing a list of domains and performs DNS queries 28 | to resolve each domain. It will invoke massdns using public resolvers for 29 | a quick first pass, then attempt to filter out any wildcard subdomains found. 30 | Finally, it will ensure the results are free of DNS poisoning by resolving 31 | the remaining domains using trusted resolvers. 32 | 33 | The argument can be omitted if the domains to resolve are read from stdin.`, 34 | Args: cobra.MinimumNArgs(0), 35 | RunE: runResolve, 36 | } 37 | 38 | resolveFlags = pflag.NewFlagSet("resolve", pflag.ExitOnError) 39 | resolveFlags.StringVarP(&resolveOptions.BinPath, "bin", "b", resolveOptions.BinPath, "path to massdns binary file") 40 | resolveFlags.IntVarP(&resolveOptions.RateLimit, "rate-limit", "l", resolveOptions.RateLimit, "limit total queries per second for public resolvers (0 = unlimited) (default unlimited)") 41 | resolveFlags.IntVar(&resolveOptions.RateLimitTrusted, "rate-limit-trusted", resolveOptions.RateLimitTrusted, "limit total queries per second for trusted resolvers (0 = unlimited)") 42 | resolveFlags.StringVarP(&resolveOptions.ResolverFile, "resolvers", "r", resolveOptions.ResolverFile, "text file containing public resolvers") 43 | resolveFlags.StringVar(&resolveOptions.ResolverTrustedFile, "resolvers-trusted", resolveOptions.ResolverTrustedFile, "text file containing trusted resolvers") 44 | resolveFlags.IntVarP(&resolveOptions.WildcardThreads, "threads", "t", resolveOptions.WildcardThreads, "number of threads to use while filtering wildcards") 45 | resolveFlags.BoolVar(&resolveOptions.TrustedOnly, "trusted-only", resolveOptions.TrustedOnly, "use only trusted resolvers (implies --skip-validation)") 
46 | resolveFlags.IntVarP(&resolveOptions.WildcardTests, "wildcard-tests", "n", resolveOptions.WildcardTests, "number of tests to perform to detect DNS load balancing") 47 | resolveFlags.IntVar(&resolveOptions.WildcardBatchSize, "wildcard-batch", resolveOptions.WildcardBatchSize, "number of subdomains to test for wildcards in a single batch (0 = unlimited) (default unlimited)") 48 | resolveFlags.StringVarP(&resolveOptions.WriteDomainsFile, "write", "w", resolveOptions.WriteDomainsFile, "write found domains to a file") 49 | resolveFlags.StringVar(&resolveOptions.WriteMassdnsFile, "write-massdns", resolveOptions.WriteMassdnsFile, "write massdns database to a file (-o Snl format)") 50 | resolveFlags.StringVar(&resolveOptions.WriteWildcardsFile, "write-wildcards", resolveOptions.WriteWildcardsFile, "write wildcard subdomain roots to a file") 51 | resolveFlags.BoolVar(&resolveOptions.SkipSanitize, "skip-sanitize", resolveOptions.SkipSanitize, "do not sanitize the list of domains to test") 52 | resolveFlags.BoolVar(&resolveOptions.SkipWildcard, "skip-wildcard-filter", resolveOptions.SkipWildcard, "do not perform wildcard detection and filtering") 53 | resolveFlags.BoolVar(&resolveOptions.SkipValidation, "skip-validation", resolveOptions.SkipValidation, "do not validate results with trusted resolvers") 54 | 55 | must(cobra.MarkFlagFilename(resolveFlags, "bin")) 56 | must(cobra.MarkFlagFilename(resolveFlags, "resolvers")) 57 | must(cobra.MarkFlagFilename(resolveFlags, "resolvers-trusted")) 58 | must(cobra.MarkFlagFilename(resolveFlags, "write")) 59 | must(cobra.MarkFlagFilename(resolveFlags, "write-massdns")) 60 | must(cobra.MarkFlagFilename(resolveFlags, "write-wildcards")) 61 | 62 | cmdResolve.Flags().AddFlagSet(resolveFlags) 63 | cmdResolve.Flags().SortFlags = false 64 | 65 | return cmdResolve 66 | } 67 | 68 | func runResolve(cmd *cobra.Command, args []string) error { 69 | if len(args) == 0 { 70 | if !app.HasStdin() { 71 | fmt.Println(cmd.UsageString()) 72 | return 
errors.New("requires a list of domains to resolve") 73 | } 74 | context.Stdin = os.Stdin 75 | } else { 76 | resolveOptions.DomainFile = args[0] 77 | } 78 | 79 | resolveOptions.Mode = 0 80 | 81 | if err := resolveOptions.Validate(); err != nil { 82 | return err 83 | } 84 | 85 | bannerService := programbanner.NewService(context) 86 | resolveService := resolve.NewService(context, resolveOptions) 87 | 88 | err := resolveService.Initialize() 89 | if err != nil { 90 | return err 91 | } 92 | defer resolveService.Close(context.Options.Debug) 93 | 94 | bannerService.PrintWithResolveOptions(resolveOptions) 95 | 96 | return resolveService.Resolve() 97 | } 98 | -------------------------------------------------------------------------------- /internal/app/cmd/resolve_test.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/d3mondev/puredns/v2/internal/app/ctx" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestRunResolve(t *testing.T) { 11 | t.Run("no argument", func(t *testing.T) { 12 | context = ctx.NewCtx() 13 | cmd := newCmdResolve() 14 | 15 | err := runResolve(cmd, []string{}) 16 | 17 | assert.NotNil(t, err) 18 | }) 19 | 20 | t.Run("file that does not exist", func(t *testing.T) { 21 | context = ctx.NewCtx() 22 | cmd := newCmdResolve() 23 | 24 | err := runResolve(cmd, []string{"thisfiledoesnotexist.txt"}) 25 | 26 | assert.NotNil(t, err) 27 | }) 28 | } 29 | -------------------------------------------------------------------------------- /internal/app/cmd/root.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "io/ioutil" 5 | 6 | "github.com/d3mondev/puredns/v2/internal/app/ctx" 7 | "github.com/d3mondev/puredns/v2/internal/pkg/console" 8 | "github.com/spf13/cobra" 9 | ) 10 | 11 | var ( 12 | context *ctx.Ctx 13 | ) 14 | 15 | func must(err error) { 16 | if err != nil { 17 | panic(err) 18 | } 19 | 
} 20 | 21 | func newCmdRoot() *cobra.Command { 22 | rootCmd := &cobra.Command{ 23 | Use: context.ProgramName, 24 | Short: context.ProgramTagline, 25 | Long: context.ProgramName + " " + context.ProgramVersion + ` 26 | 27 | A subdomain bruteforce tool that wraps around massdns to quickly resolve 28 | a massive number of DNS queries. Using its heuristic algorithm, it can filter out 29 | wildcard subdomains and validate that the results are free of DNS poisoning 30 | by using trusted resolvers.`, 31 | Example: ` puredns resolve domains.txt 32 | puredns bruteforce wordlist.txt domain.com --resolvers public.txt 33 | cat domains.txt | puredns resolve`, 34 | Version: context.ProgramVersion, 35 | } 36 | 37 | rootCmd.CompletionOptions.DisableDefaultCmd = true 38 | rootCmd.PersistentFlags().BoolVarP(&context.Options.Quiet, "quiet", "q", context.Options.Quiet, "quiet mode") 39 | rootCmd.PersistentFlags().BoolVar(&context.Options.Debug, "debug", context.Options.Debug, "keep intermediate files") 40 | rootCmd.Flags().SortFlags = false 41 | 42 | cmdResolve := newCmdResolve() 43 | cmdBruteforce := newCmdBruteforce() 44 | cmdSponsors := newCmdSponsors() 45 | rootCmd.AddCommand(cmdResolve) 46 | rootCmd.AddCommand(cmdBruteforce) 47 | rootCmd.AddCommand(cmdSponsors) 48 | 49 | rootCmd.SetHelpCommand(&cobra.Command{Hidden: true}) 50 | rootCmd.SilenceErrors = true 51 | 52 | rootCmd.PersistentPreRun = preRun 53 | 54 | return rootCmd 55 | } 56 | 57 | func preRun(cmd *cobra.Command, args []string) { 58 | cmd.SilenceUsage = true 59 | 60 | if context.Options.Quiet { 61 | console.Output = ioutil.Discard 62 | } 63 | } 64 | 65 | // Execute executes the root command. 
66 | func Execute(c *ctx.Ctx) error { 67 | context = c 68 | cmdRoot := newCmdRoot() 69 | 70 | return cmdRoot.Execute() 71 | } 72 | -------------------------------------------------------------------------------- /internal/app/cmd/root_test.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "testing" 7 | 8 | "github.com/d3mondev/puredns/v2/internal/app/ctx" 9 | "github.com/d3mondev/puredns/v2/internal/pkg/console" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestNewCmdRoot(t *testing.T) { 14 | context = ctx.NewCtx() // newCmdRoot reads context, so don't rely on an earlier test having set it 15 | cmd := newCmdRoot() 16 | assert.NotNil(t, cmd) 17 | } 18 | 19 | func TestPreRun_Quiet(t *testing.T) { 20 | context = ctx.NewCtx() 21 | context.Options.Quiet = true 22 | 23 | cmd := newCmdResolve() 24 | preRun(cmd, []string{}) 25 | 26 | assert.Equal(t, io.Discard, console.Output) 27 | } 28 | 29 | func TestMust_OK(t *testing.T) { 30 | must(nil) 31 | } 32 | 33 | func TestMust_Panics(t *testing.T) { 34 | assert.Panics(t, func() { must(errors.New("error")) }) 35 | } 36 | -------------------------------------------------------------------------------- /internal/app/cmd/sponsors.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "github.com/d3mondev/puredns/v2/internal/app" 5 | "github.com/d3mondev/puredns/v2/internal/usecase/sponsors" 6 | "github.com/spf13/cobra" 7 | ) 8 | // newCmdSponsors builds the "sponsors" subcommand. 9 | func newCmdSponsors() *cobra.Command { 10 | cmdSponsors := &cobra.Command{ 11 | Use: "sponsors", 12 | Short: "Show the active sponsors <3", 13 | Long: `Show the very kind-hearted people who support my work as sponsors. 14 | 15 | This software is made by me, @d3mondev. I'm on a mission to make free and open-source 16 | software for the bug bounty community and infosec professionals. 17 | 18 | As you know, free doesn't help pay the bills. If my work is earning you money, 19 | consider becoming a sponsor!
It would mean A WHOLE LOT as it would allow me to continue 20 | working for free for the community: https://github.com/sponsors/d3mondev`, 21 | RunE: runSponsors, 22 | } 23 | 24 | return cmdSponsors 25 | } 26 | // runSponsors shows the sponsors list through the sponsors service, using app.AppSponsorsURL as the source. 27 | func runSponsors(cmd *cobra.Command, args []string) error { 28 | service := sponsors.NewService() 29 | 30 | return service.Show(app.AppSponsorsURL) 31 | } 32 | -------------------------------------------------------------------------------- /internal/app/cmd/sponsors_test.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestNewCmdSponsors(t *testing.T) { 10 | cmd := newCmdSponsors() 11 | assert.NotNil(t, cmd) 12 | } 13 | -------------------------------------------------------------------------------- /internal/app/ctx/ctx.go: -------------------------------------------------------------------------------- 1 | package ctx 2 | 3 | import ( 4 | "os" 5 | 6 | "github.com/d3mondev/puredns/v2/internal/app" 7 | ) 8 | 9 | // Ctx is the program context. It contains the necessary parameters for a command to run. 10 | type Ctx struct { 11 | ProgramName string 12 | ProgramVersion string 13 | ProgramTagline string 14 | GitBranch string 15 | GitRevision string 16 | 17 | Options *GlobalOptions 18 | Stdin *os.File 19 | } 20 | 21 | // NewCtx creates a new context populated from the app build constants and the default global options.
22 | func NewCtx() *Ctx { 23 | return &Ctx{ 24 | ProgramName: app.AppName, 25 | ProgramVersion: app.AppVersion, 26 | ProgramTagline: app.AppDesc, 27 | 28 | GitBranch: app.GitBranch, 29 | GitRevision: app.GitRevision, 30 | // Stdin is left nil here; runResolve assigns os.Stdin to it when no domain file argument is given. 31 | Options: DefaultGlobalOptions(), 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /internal/app/ctx/ctx_test.go: -------------------------------------------------------------------------------- 1 | package ctx 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestNewCtx(t *testing.T) { 10 | ctx := NewCtx() 11 | assert.NotNil(t, ctx) 12 | } 13 | -------------------------------------------------------------------------------- /internal/app/ctx/options.go: -------------------------------------------------------------------------------- 1 | package ctx 2 | 3 | import ( 4 | "errors" 5 | "os/user" 6 | "path/filepath" 7 | 8 | "github.com/d3mondev/puredns/v2/internal/app" 9 | "github.com/d3mondev/puredns/v2/pkg/fileoperation" 10 | ) 11 | 12 | // ResolveMode selects whether puredns resolves a list of domains or bruteforces subdomains. 13 | type ResolveMode int 14 | 15 | const ( 16 | // Resolve resolves domains. 17 | Resolve ResolveMode = iota 18 | 19 | // Bruteforce bruteforces subdomains. 20 | Bruteforce 21 | ) 22 | 23 | var ( 24 | // ErrNoDomain is returned by Validate in bruteforce mode when neither a domain nor a domain file is specified. 25 | ErrNoDomain error = errors.New("no domain specified") 26 | 27 | // ErrNoWordlist is returned by Validate in bruteforce mode when no wordlist is specified and stdin is not a pipe. 28 | ErrNoWordlist error = errors.New("no wordlist specified") 29 | ) 30 | 31 | // GlobalOptions contains the program's global options. 32 | type GlobalOptions struct { 33 | TrustedResolvers []string 34 | 35 | Quiet bool 36 | Debug bool 37 | } 38 | 39 | // DefaultGlobalOptions creates a new GlobalOptions struct with default values.
40 | func DefaultGlobalOptions() *GlobalOptions { 41 | return &GlobalOptions{ 42 | TrustedResolvers: []string{ 43 | "8.8.8.8", 44 | "8.8.4.4", 45 | }, 46 | 47 | Quiet: false, 48 | Debug: false, 49 | } 50 | } 51 | 52 | // ResolveOptions contains a resolve command's options. 53 | type ResolveOptions struct { 54 | BinPath string 55 | 56 | ResolverFile string 57 | ResolverTrustedFile string 58 | TrustedOnly bool 59 | 60 | RateLimit int 61 | RateLimitTrusted int 62 | 63 | WildcardThreads int 64 | WildcardTests int 65 | WildcardBatchSize int 66 | 67 | SkipSanitize bool 68 | SkipWildcard bool 69 | SkipValidation bool 70 | 71 | WriteDomainsFile string 72 | WriteMassdnsFile string 73 | WriteWildcardsFile string 74 | 75 | Mode ResolveMode 76 | Domain string 77 | Wordlist string 78 | DomainFile string 79 | } 80 | 81 | // DefaultResolveOptions returns ResolveOptions with default values. The resolver list is ./resolvers.txt if it exists; otherwise, when the current user can be looked up, ~/.config/puredns/resolvers.txt is used together with ~/.config/puredns/resolvers-trusted.txt if that file exists. 82 | func DefaultResolveOptions() *ResolveOptions { 83 | resolversPath := "resolvers.txt" 84 | trustedResolversPath := "" 85 | 86 | if !fileoperation.FileExists(resolversPath) { 87 | usr, err := user.Current() 88 | if err == nil { 89 | resolversPath = filepath.Join(usr.HomeDir, ".config", "puredns", "resolvers.txt") 90 | trustedResolversPath = filepath.Join(usr.HomeDir, ".config", "puredns", "resolvers-trusted.txt") 91 | 92 | if !fileoperation.FileExists(trustedResolversPath) { 93 | trustedResolversPath = "" 94 | } 95 | } 96 | } 97 | 98 | return &ResolveOptions{ 99 | BinPath: "massdns", 100 | 101 | ResolverFile: resolversPath, 102 | ResolverTrustedFile: trustedResolversPath, 103 | TrustedOnly: false, 104 | 105 | RateLimit: 0, // 0 = unlimited 106 | RateLimitTrusted: 500, 107 | 108 | WildcardThreads: 100, 109 | WildcardTests: 3, 110 | WildcardBatchSize: 0, 111 | 112 | SkipSanitize: false, 113 | SkipWildcard: false, 114 | SkipValidation: false, 115 | 116 | WriteDomainsFile: "", 117 | WriteMassdnsFile: "", 118 | WriteWildcardsFile: "", 119 | 120 | Mode: Resolve, 121 | Domain: "", 122 | Wordlist: "",
123 | DomainFile: "", 124 | } 125 | } 126 | 127 | // Validate normalizes and checks the options: it forces SkipValidation when TrustedOnly is set and, in bruteforce mode, returns ErrNoDomain if neither Domain nor DomainFile is set, or ErrNoWordlist if Wordlist is empty and stdin is not piped. 128 | func (o *ResolveOptions) Validate() error { 129 | // Enforce --skip-validation when --trusted-only is set 130 | if o.TrustedOnly { 131 | o.SkipValidation = true 132 | } 133 | 134 | // Validate that a wordlist and a domain are present in bruteforce mode 135 | if o.Mode == Bruteforce { 136 | if o.Domain == "" && o.DomainFile == "" { 137 | return ErrNoDomain 138 | } 139 | 140 | if o.Wordlist == "" && !app.HasStdin() { 141 | return ErrNoWordlist 142 | } 143 | } 144 | 145 | return nil 146 | } 147 | -------------------------------------------------------------------------------- /internal/app/ctx/options_test.go: -------------------------------------------------------------------------------- 1 | package ctx 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/d3mondev/puredns/v2/pkg/filetest" 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestDefaultGlobalOptions(t *testing.T) { 13 | opts := DefaultGlobalOptions() 14 | assert.NotNil(t, opts) 15 | } 16 | 17 | func TestResolveOptionsValidate_OK(t *testing.T) { 18 | have := DefaultResolveOptions() 19 | want := DefaultResolveOptions() 20 | 21 | err := have.Validate() 22 | 23 | assert.Nil(t, err) 24 | assert.Equal(t, want, have) 25 | } 26 | 27 | func TestResolveOptionsValidate_NoPublic(t *testing.T) { 28 | have := DefaultResolveOptions() 29 | have.TrustedOnly = true 30 | 31 | want := DefaultResolveOptions() 32 | want.TrustedOnly = true 33 | want.SkipValidation = true 34 | 35 | err := have.Validate() 36 | 37 | assert.Nil(t, err) 38 | assert.Equal(t, want, have) 39 | } 40 | 41 | func TestResolveOptionsValidate_BruteforceNoDomain(t *testing.T) { 42 | have := DefaultResolveOptions() 43 | have.Mode = Bruteforce 44 | have.Wordlist = "wordlist.txt" 45 | 46 | err := have.Validate() 47 | assert.ErrorIs(t, err, ErrNoDomain) 48 | } 49 | 50 | func
TestResolveOptionsValidate_BruteforceDomain(t *testing.T) { 51 | have := DefaultResolveOptions() 52 | have.Mode = Bruteforce 53 | have.Wordlist = "wordlist.txt" 54 | have.Domain = "example.com" 55 | 56 | err := have.Validate() 57 | 58 | assert.Nil(t, err) 59 | } 60 | 61 | func TestResolveOptionsValidate_BruteforceDomainFile(t *testing.T) { 62 | have := DefaultResolveOptions() 63 | have.Mode = Bruteforce 64 | have.Wordlist = "wordlist.txt" 65 | have.DomainFile = "domains.txt" 66 | 67 | err := have.Validate() 68 | 69 | assert.Nil(t, err) 70 | } 71 | 72 | func TestResolveOptionsValidate_BruteforceNoWordlist(t *testing.T) { 73 | have := DefaultResolveOptions() 74 | have.Mode = Bruteforce 75 | have.Domain = "example.com" 76 | 77 | err := have.Validate() 78 | 79 | assert.ErrorIs(t, err, ErrNoWordlist) 80 | } 81 | 82 | func TestResolveOptionsValidate_BruteforceWordlistStdin(t *testing.T) { 83 | have := DefaultResolveOptions() 84 | have.Mode = Bruteforce 85 | have.Domain = "example.com" 86 | 87 | r, w, err := os.Pipe() 88 | require.Nil(t, err) 89 | require.Nil(t, w.Close()) 90 | filetest.OverrideStdin(t, r) 91 | 92 | err = have.Validate() 93 | 94 | assert.Nil(t, err) 95 | } 96 | -------------------------------------------------------------------------------- /internal/app/stdin.go: -------------------------------------------------------------------------------- 1 | package app 2 | 3 | import "os" 4 | 5 | // HasStdin reports whether os.Stdin is a named pipe, i.e. whether input is being piped into the program.
6 | func HasStdin() bool { 7 | stat, err := os.Stdin.Stat() 8 | if err != nil { 9 | return false 10 | } 11 | 12 | if stat.Mode()&os.ModeNamedPipe == 0 { 13 | return false 14 | } 15 | 16 | return true 17 | } 18 | -------------------------------------------------------------------------------- /internal/app/stdin_test.go: -------------------------------------------------------------------------------- 1 | package app 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/d3mondev/puredns/v2/pkg/filetest" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestHasStdin_Default(t *testing.T) { 12 | got := HasStdin() 13 | assert.Equal(t, false, got) 14 | } 15 | 16 | func TestHasStdin_With(t *testing.T) { 17 | r, _, err := os.Pipe() 18 | if err != nil { 19 | t.Fatal(err) 20 | } 21 | 22 | os.Stdin = r 23 | 24 | got := HasStdin() 25 | assert.Equal(t, true, got) 26 | } 27 | 28 | func TestHasStdin_File(t *testing.T) { 29 | file := filetest.CreateFile(t, "") 30 | filetest.OverrideStdin(t, file) 31 | 32 | got := HasStdin() 33 | assert.Equal(t, false, got) 34 | } 35 | 36 | func TestHasStdin_Nil(t *testing.T) { 37 | filetest.OverrideStdin(t, nil) 38 | 39 | got := HasStdin() 40 | assert.Equal(t, false, got) 41 | } 42 | -------------------------------------------------------------------------------- /internal/app/version.go: -------------------------------------------------------------------------------- 1 | package app 2 | 3 | const ( 4 | // AppName is the name of the application. 5 | AppName string = "puredns" 6 | 7 | // AppDesc is a short description of the application. 8 | AppDesc string = "Very accurate massdns resolving and bruteforcing." 9 | 10 | // AppVersion is the program version. 11 | AppVersion string = "v2.1.2" 12 | 13 | // AppSponsorsURL is the text file containing sponsors information. 
14 | AppSponsorsURL string = "https://gist.githubusercontent.com/d3mondev/0bfff529a4dad627bdb684ad1ef2506d/raw/sponsors.txt" 15 | ) 16 | 17 | // GitBranch is the current git branch. 18 | var GitBranch string 19 | 20 | // GitRevision is the current git commit. 21 | var GitRevision string 22 | -------------------------------------------------------------------------------- /internal/pkg/console/color.go: -------------------------------------------------------------------------------- 1 | package console 2 | 3 | const ( 4 | // ColorBlack is the black foreground color. 5 | ColorBlack = "\033[0;30m" 6 | 7 | // ColorGray is the gray foreground color. 8 | ColorGray = "\033[1;30m" 9 | 10 | // ColorRed is the red foreground color. 11 | ColorRed = "\033[0;31m" 12 | 13 | // ColorBrightRed is the bright red foreground color. 14 | ColorBrightRed = "\033[1;31m" 15 | 16 | // ColorGreen is the green foreground color. 17 | ColorGreen = "\033[0;32m" 18 | 19 | // ColorBrightGreen is the bright green foreground color. 20 | ColorBrightGreen = "\033[1;32m" 21 | 22 | // ColorYellow is the yellow foreground color. 23 | ColorYellow = "\033[0;33m" 24 | 25 | // ColorBrightYellow is the yellow foreground color. 26 | ColorBrightYellow = "\033[1;33m" 27 | 28 | // ColorBlue is the blue foreground color. 29 | ColorBlue = "\033[0;34m" 30 | 31 | // ColorBrightBlue is the bright blue foreground color. 32 | ColorBrightBlue = "\033[1;34m" 33 | 34 | // ColorMagenta is the magenta foreground color. 35 | ColorMagenta = "\033[0;35m" 36 | 37 | // ColorBrightMagenta is the bright magenta foreground color. 38 | ColorBrightMagenta = "\033[1;35m" 39 | 40 | // ColorCyan is the cyan foreground color. 41 | ColorCyan = "\033[0;36m" 42 | 43 | // ColorBrightCyan is the bright cyan foreground color. 44 | ColorBrightCyan = "\033[1;36m" 45 | 46 | // ColorWhite is the white foreground color. 47 | ColorWhite = "\033[0;37m" 48 | 49 | // ColorBrightWhite is the bright white foreground color. 
50 | ColorBrightWhite = "\033[1;37m" 51 | 52 | // ColorReset is the code to reset all attributes. 53 | ColorReset = "\033[0m" 54 | ) 55 | -------------------------------------------------------------------------------- /internal/pkg/console/console.go: -------------------------------------------------------------------------------- 1 | package console 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "os" 7 | ) 8 | 9 | const ( 10 | colorMeta = ColorBrightWhite 11 | colorMessage = ColorCyan 12 | colorMessageText = ColorWhite 13 | colorSuccess = ColorBrightGreen 14 | colorSuccessText = ColorWhite 15 | colorWarning = ColorBrightYellow 16 | colorWarningText = ColorWhite 17 | colorError = ColorRed 18 | colorErrorText = ColorRed 19 | ) 20 | 21 | var ( 22 | // Output is a writer where the messages are sent. 23 | Output io.Writer = os.Stderr 24 | 25 | // ExitHandler is a function called to exit the process. 26 | ExitHandler = os.Exit 27 | ) 28 | 29 | // Message displays an informative message. 30 | func Message(format string, a ...interface{}) { 31 | message := fmt.Sprintf("%s[%s*%s]%s %s%s\n", colorMeta, colorMessage, colorMeta, colorMessageText, format, ColorReset) 32 | fmt.Fprintf(Output, message, a...) 33 | } 34 | 35 | // Success displays a success message. 36 | func Success(format string, a ...interface{}) { 37 | message := fmt.Sprintf("%s[%s+%s]%s %s%s\n", colorMeta, colorSuccess, colorMeta, colorSuccessText, format, ColorReset) 38 | fmt.Fprintf(Output, message, a...) 39 | } 40 | 41 | // Warning displays a warning message. 42 | func Warning(format string, a ...interface{}) { 43 | message := fmt.Sprintf("%s[%s!%s]%s %s%s\n", colorMeta, colorWarning, colorMeta, colorWarningText, format, ColorReset) 44 | fmt.Fprintf(Output, message, a...) 45 | } 46 | 47 | // Error displays an error message. 
48 | func Error(format string, a ...interface{}) { 49 | message := fmt.Sprintf("%s[%sX%s]%s %s%s\n", colorMeta, colorError, colorMeta, colorErrorText, format, ColorReset) 50 | fmt.Fprintf(Output, message, a...) 51 | } 52 | 53 | // Printf writes a printf-style formatted message to Output without adding any color or prefix. 54 | func Printf(format string, a ...interface{}) { 55 | fmt.Fprintf(Output, format, a...) 56 | } 57 | 58 | // Fatal prints a fatal error message in red and exits the process with status 1 through ExitHandler (os.Exit recommends codes in [0, 125] for portability). 59 | func Fatal(format string, a ...interface{}) { 60 | fmt.Fprintf(Output, ColorRed+format+ColorReset, a...) 61 | ExitHandler(1) 62 | } 63 | -------------------------------------------------------------------------------- /internal/pkg/console/console_test.go: -------------------------------------------------------------------------------- 1 | package console 2 | 3 | import ( 4 | "bytes" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func redirectOutput() *bytes.Buffer { 12 | buffer := bytes.NewBuffer([]byte{}) 13 | Output = buffer 14 | 15 | return buffer 16 | } 17 | 18 | type spyExitHandler struct { 19 | seen int 20 | } 21 | 22 | func (s *spyExitHandler) Exit(int) { 23 | s.seen++ 24 | } 25 | 26 | func TestMessage(t *testing.T) { 27 | buffer := redirectOutput() 28 | Message("foo %s", "bar") 29 | assert.True(t, strings.Contains(buffer.String(), "foo bar")) 30 | } 31 | 32 | func TestSuccess(t *testing.T) { 33 | buffer := redirectOutput() 34 | Success("foo %s", "bar") 35 | assert.True(t, strings.Contains(buffer.String(), "foo bar")) 36 | } 37 | 38 | func TestWarning(t *testing.T) { 39 | buffer := redirectOutput() 40 | Warning("foo %s", "bar") 41 | assert.True(t, strings.Contains(buffer.String(), "foo bar")) 42 | } 43 | 44 | func TestError(t *testing.T) { 45 | buffer := redirectOutput() 46 | Error("foo %s", "bar") 47 | assert.True(t, strings.Contains(buffer.String(), "foo bar")) 48 | } 49 | 50 | func TestPrintf(t *testing.T) { 51 | buffer := redirectOutput() 52 | Printf("foo %s",
"bar") 53 | assert.True(t, strings.Contains(buffer.String(), "foo bar")) 54 | } 55 | 56 | func TestFatal(t *testing.T) { 57 | buffer := redirectOutput() 58 | spyExitHandler, prevExitHandler := spyExitHandler{}, ExitHandler 59 | ExitHandler = spyExitHandler.Exit 60 | defer func() { ExitHandler = prevExitHandler }() 61 | Fatal("foo %s", "bar") 62 | assert.True(t, strings.Contains(buffer.String(), "foo bar")) 63 | assert.Equal(t, 1, spyExitHandler.seen) 64 | } 65 | -------------------------------------------------------------------------------- /internal/usecase/programbanner/programbanner.go: -------------------------------------------------------------------------------- 1 | package programbanner 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/d3mondev/puredns/v2/internal/app/ctx" 8 | "github.com/d3mondev/puredns/v2/internal/pkg/console" 9 | ) 10 | 11 | // Service prints the program banner and version number. 12 | type Service struct { 13 | ctx *ctx.Ctx 14 | } 15 | 16 | // NewService returns a new Service object. 17 | func NewService(ctx *ctx.Ctx) Service { 18 | return Service{ 19 | ctx: ctx, 20 | } 21 | } 22 | 23 | // Print prints the program logo along with its name, its version (or git branch and revision when GitBranch is set) and a fixed tagline.
24 | func (s Service) Print() { 25 | version := s.ctx.ProgramVersion 26 | if s.ctx.GitBranch != "" { 27 | version = fmt.Sprintf("%s-%s", s.ctx.GitBranch, s.ctx.GitRevision) 28 | } 29 | width := 34 - len(version) - len(s.ctx.ProgramName) 30 | if width < 0 { // strings.Repeat panics on a negative count, e.g. with a long git branch or revision 31 | width = 0 32 | } 33 | padding := strings.Repeat(" ", width) 34 | console.Printf(console.ColorBrightBlue) 35 | console.Printf(" _ \n") 36 | console.Printf(" | | \n") 37 | console.Printf(" _ __ _ _ _ __ ___ __| |_ __ ___ \n") 38 | console.Printf("| '_ \\| | | | '__/ _ \\/ _` | '_ \\/ __|\n") 39 | console.Printf("| |_) | |_| | | | __/ (_| | | | \\__ \\\n") 40 | console.Printf("| .__/ \\__,_|_| \\___|\\__,_|_| |_|___/\n") 41 | console.Printf("| | \n") 42 | console.Printf("|_|%s%s%s %s%s\n\n", padding, console.ColorBrightCyan, s.ctx.ProgramName, console.ColorBrightBlue, version) 43 | console.Printf("%sFast and accurate DNS resolving and bruteforcing\n\n", console.ColorBrightWhite) 44 | console.Printf("%sCrafted with %s<3%s by @d3mondev\n", console.ColorBrightWhite, console.ColorBrightRed, console.ColorBrightWhite) 45 | console.Printf("https://github.com/sponsors/d3mondev\n") 46 | console.Printf(console.ColorReset + "\n") 47 | } 48 | 49 | // PrintWithResolveOptions prints the program's logo, along with the options selected 50 | // for the resolve command.
51 | func (s Service) PrintWithResolveOptions(opts *ctx.ResolveOptions) { 52 | s.Print() 53 | console.Printf(console.ColorBrightWhite + "------------------------------------------------------------\n" + console.ColorReset) 54 | 55 | defaultOptions := ctx.DefaultResolveOptions() 56 | // file is the input shown in the banner: "stdin" when the context has Stdin set, otherwise the wordlist (bruteforce) or the domain file (resolve). 57 | var file string 58 | if s.ctx.Stdin != nil { 59 | file = "stdin" 60 | } else { 61 | if opts.Mode == ctx.Bruteforce { 62 | file = opts.Wordlist 63 | } else { 64 | file = opts.DomainFile 65 | } 66 | } 67 | 68 | colorOptionLabel := console.ColorBrightWhite 69 | colorOptionSkipLabel := console.ColorBrightYellow 70 | colorOptionValue := console.ColorWhite 71 | colorOptionTick := console.ColorBrightBlue 72 | colorOptionTickWrite := console.ColorBrightGreen 73 | 74 | tickSymbol := fmt.Sprintf("%s[%s+%s]", colorOptionLabel, colorOptionTick, colorOptionLabel) 75 | tickSymbolWrite := fmt.Sprintf("%s[%s+%s]", colorOptionLabel, colorOptionTickWrite, colorOptionLabel) 76 | 77 | if opts.Mode == ctx.Bruteforce { 78 | console.Printf("%s Mode :%s bruteforce\n", tickSymbol, colorOptionValue) 79 | 80 | if opts.DomainFile != "" { 81 | console.Printf("%s Domains :%s %s\n", tickSymbol, colorOptionValue, opts.DomainFile) 82 | } else { 83 | console.Printf("%s Domain :%s %s\n", tickSymbol, colorOptionValue, opts.Domain) 84 | } 85 | 86 | console.Printf("%s Wordlist :%s %s\n", tickSymbol, colorOptionValue, file) 87 | } else { 88 | console.Printf("%s Mode :%s resolve\n", tickSymbol, colorOptionValue) 89 | console.Printf("%s File :%s %s\n", tickSymbol, colorOptionValue, file) 90 | } 91 | 92 | if opts.TrustedOnly { 93 | console.Printf("%s Trusted Only :%s true\n", tickSymbol, colorOptionValue) 94 | } 95 | 96 | if !opts.TrustedOnly { 97 | console.Printf("%s Resolvers :%s %s\n", tickSymbol, colorOptionValue, opts.ResolverFile) 98 | } 99 | 100 | if opts.ResolverTrustedFile != "" { 101 | console.Printf("%s Trusted Resolvers :%s %s\n", tickSymbol, colorOptionValue, opts.ResolverTrustedFile) 102 | } 103 | 104 | if !opts.TrustedOnly { 105 |
rate := "unlimited" 106 | if opts.RateLimit != 0 { 107 | rate = fmt.Sprintf("%d qps", opts.RateLimit) 108 | } 109 | console.Printf("%s Rate Limit :%s %s\n", tickSymbol, colorOptionValue, rate) 110 | } 111 | 112 | console.Printf("%s Rate Limit (Trusted) :%s %d qps\n", tickSymbol, colorOptionValue, opts.RateLimitTrusted) 113 | console.Printf("%s Wildcard Threads :%s %d\n", tickSymbol, colorOptionValue, opts.WildcardThreads) 114 | console.Printf("%s Wildcard Tests :%s %d\n", tickSymbol, colorOptionValue, opts.WildcardTests) 115 | // The batch size is only listed when it differs from the default. 116 | if opts.WildcardBatchSize != defaultOptions.WildcardBatchSize { 117 | console.Printf("%s Wildcard Batch Size :%s %d\n", tickSymbol, colorOptionValue, opts.WildcardBatchSize) 118 | } 119 | 120 | if opts.WriteDomainsFile != "" { 121 | console.Printf("%s Write Domains :%s %s\n", tickSymbolWrite, colorOptionValue, opts.WriteDomainsFile) 122 | } 123 | 124 | if opts.WriteMassdnsFile != "" { 125 | console.Printf("%s Write Massdns :%s %s\n", tickSymbolWrite, colorOptionValue, opts.WriteMassdnsFile) 126 | } 127 | 128 | if opts.WriteWildcardsFile != "" { 129 | console.Printf("%s Write Wildcards :%s %s\n", tickSymbolWrite, colorOptionValue, opts.WriteWildcardsFile) 130 | } 131 | 132 | if opts.SkipSanitize { 133 | console.Printf("%s[+] Skip Sanitize\n", colorOptionSkipLabel) 134 | } 135 | 136 | if opts.SkipWildcard { 137 | console.Printf("%s[+] Skip Wildcard Detection\n", colorOptionSkipLabel) 138 | } 139 | // Skip Validation is not listed in trusted-only mode, where ResolveOptions.Validate forces SkipValidation on. 140 | if !opts.TrustedOnly { 141 | if opts.SkipValidation { 142 | console.Printf("%s[+] Skip Validation\n", colorOptionSkipLabel) 143 | } 144 | } 145 | 146 | console.Printf(console.ColorBrightWhite + "------------------------------------------------------------\n" + console.ColorReset) 147 | console.Printf("\n") 148 | } 149 | -------------------------------------------------------------------------------- /internal/usecase/programbanner/programbanner_test.go: -------------------------------------------------------------------------------- 1 |
package programbanner 2 | 3 | import ( 4 | "bytes" 5 | "os" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/d3mondev/puredns/v2/internal/app/ctx" 10 | "github.com/d3mondev/puredns/v2/internal/pkg/console" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func TestPrint(t *testing.T) { 16 | buffer := new(bytes.Buffer) 17 | console.Output = buffer 18 | 19 | ctx := ctx.NewCtx() 20 | service := NewService(ctx) 21 | service.Print() 22 | 23 | assert.True(t, strings.Contains(buffer.String(), ctx.ProgramName)) 24 | assert.True(t, strings.Contains(buffer.String(), ctx.ProgramVersion)) 25 | } 26 | 27 | func TestPrintGit(t *testing.T) { 28 | buffer := new(bytes.Buffer) 29 | console.Output = buffer 30 | 31 | ctx := ctx.NewCtx() 32 | ctx.GitBranch = "master" 33 | ctx.GitRevision = "revision" 34 | service := NewService(ctx) 35 | service.Print() 36 | 37 | assert.True(t, strings.Contains(buffer.String(), ctx.ProgramName)) 38 | assert.True(t, strings.Contains(buffer.String(), ctx.GitBranch)) 39 | assert.True(t, strings.Contains(buffer.String(), ctx.GitRevision)) 40 | } 41 | 42 | func TestPrintWithResolveOptions(t *testing.T) { 43 | tests := []struct { 44 | name string 45 | haveCtx ctx.Ctx 46 | haveOpts ctx.ResolveOptions 47 | want string 48 | }{ 49 | {name: "stdin", haveCtx: ctx.Ctx{Stdin: os.Stdin}, haveOpts: ctx.ResolveOptions{}, want: "stdin"}, 50 | {name: "resolve mode", haveCtx: ctx.Ctx{}, haveOpts: ctx.ResolveOptions{DomainFile: "domains.txt", Mode: ctx.Resolve}, want: "domains.txt"}, 51 | {name: "bruteforce mode", haveCtx: ctx.Ctx{}, haveOpts: ctx.ResolveOptions{Domain: "example.com", Wordlist: "wordlist.txt", Mode: ctx.Bruteforce}, want: "wordlist.txt"}, 52 | {name: "bruteforce mode multiple domains", haveCtx: ctx.Ctx{}, haveOpts: ctx.ResolveOptions{DomainFile: "domains.txt", Wordlist: "wordlist.txt", Mode: ctx.Bruteforce}, want: "domains.txt"}, 53 | {name: "trusted resolvers", haveCtx: ctx.Ctx{}, haveOpts: ctx.ResolveOptions{ResolverTrustedFile:
"trusted.txt"}, want: "trusted.txt"}, 54 | {name: "rate", haveCtx: ctx.Ctx{}, haveOpts: ctx.ResolveOptions{RateLimit: 777}, want: "777"}, 55 | {name: "batch size", haveCtx: ctx.Ctx{}, haveOpts: ctx.ResolveOptions{WildcardBatchSize: 5555}, want: "5555"}, 56 | {name: "write domains", haveCtx: ctx.Ctx{}, haveOpts: ctx.ResolveOptions{WriteDomainsFile: "domains_out.txt"}, want: "domains_out.txt"}, 57 | {name: "write massdns", haveCtx: ctx.Ctx{}, haveOpts: ctx.ResolveOptions{WriteMassdnsFile: "massdns_out.txt"}, want: "massdns_out.txt"}, 58 | {name: "write wildcards", haveCtx: ctx.Ctx{}, haveOpts: ctx.ResolveOptions{WriteWildcardsFile: "wildcards_out.txt"}, want: "wildcards_out.txt"}, 59 | {name: "skip sanitize", haveCtx: ctx.Ctx{}, haveOpts: ctx.ResolveOptions{SkipSanitize: true}, want: "Skip Sanitize"}, 60 | {name: "skip wildcard", haveCtx: ctx.Ctx{}, haveOpts: ctx.ResolveOptions{SkipWildcard: true}, want: "Skip Wildcard"}, 61 | {name: "skip validation", haveCtx: ctx.Ctx{}, haveOpts: ctx.ResolveOptions{SkipValidation: true}, want: "Skip Validation"}, 62 | } 63 | 64 | for _, test := range tests { 65 | t.Run(test.name, func(t *testing.T) { 66 | buffer := new(bytes.Buffer) 67 | console.Output = buffer 68 | 69 | require.Nil(t, test.haveOpts.Validate()) 70 | service := NewService(&test.haveCtx) 71 | service.PrintWithResolveOptions(&test.haveOpts) 72 | 73 | assert.Truef(t, strings.Contains(buffer.String(), test.want), "%s not found in output", test.want) 74 | }) 75 | } 76 | } 77 | 78 | func TestPrintWithResolveOptions_NoPublic(t *testing.T) { 79 | haveCtx := ctx.Ctx{} 80 | haveOpts := ctx.ResolveOptions{} 81 | haveOpts.TrustedOnly = true 82 | 83 | buffer := new(bytes.Buffer) 84 | console.Output = buffer 85 | 86 | require.Nil(t, haveOpts.Validate()) 87 | service := NewService(&haveCtx) 88 | service.PrintWithResolveOptions(&haveOpts) 89 | 90 | assert.True(t, strings.Contains(buffer.String(), "] Trusted Only"), "should appear in output") 91 | assert.False(t,
strings.Contains(buffer.String(), "] Resolvers"), "should not appear in output") 92 | assert.Equal(t, 1, strings.Count(buffer.String(), "] Rate Limit"), "only the trusted rate limit should appear in output") 93 | assert.False(t, strings.Contains(buffer.String(), "] Skip Validation"), "should not appear in output") 94 | } 95 | -------------------------------------------------------------------------------- /internal/usecase/resolve/cachereader.go: -------------------------------------------------------------------------------- 1 | package resolve 2 | 3 | import ( 4 | "bufio" 5 | "io" 6 | "strings" 7 | 8 | "github.com/d3mondev/puredns/v2/pkg/wildcarder" 9 | "github.com/d3mondev/resolvermt" 10 | ) 11 | 12 | // CacheReader reads a DNS cache from a file and can fill a wildcarder.DNSCache object, 13 | // save valid domains to a file, and count valid domains. The number of items processed can be 14 | // limited to a specific number, and subsequent calls to Read will resume without starting over. 15 | type CacheReader struct { 16 | reader io.ReadCloser 17 | scanner *bufio.Scanner 18 | } 19 | 20 | // NewCacheReader returns a new CacheReader. 21 | func NewCacheReader(r io.ReadCloser) *CacheReader { 22 | return &CacheReader{ 23 | reader: r, 24 | scanner: bufio.NewScanner(r), 25 | } 26 | } 27 | 28 | // Read reads a massdns cache from a file (created with -o Snl), can save the valid domains to a writer, 29 | // fill a wildcarder.DNSCache object, and return the number of valid domains in the cache. 30 | // Subsequent calls resume where the previous one stopped; a positive maxCount stops reading at the next answer-section boundary once that many domains have been found.
31 | func (c CacheReader) Read(w io.Writer, cache *wildcarder.DNSCache, maxCount int) (count int, err error) { 32 | type state int 33 | const ( 34 | stateNewAnswerSection state = iota 35 | stateSaveAnswer 36 | stateSkip 37 | ) 38 | 39 | var curDomain string 40 | var curState state 41 | var domainSaved bool 42 | var found int 43 | 44 | for c.scanner.Scan() { 45 | line := c.scanner.Text() 46 | 47 | // If we receive an empty line, it's the beginning of a new answer 48 | if line == "" { 49 | curState = stateNewAnswerSection 50 | 51 | // Break from the loop if we have reached the maximum number of elements to process 52 | if maxCount > 0 && found == maxCount { 53 | break 54 | } 55 | 56 | continue 57 | } 58 | 59 | switch curState { 60 | // We're at the beginning of a new answer section, look for the domain name 61 | case stateNewAnswerSection: 62 | // Records should be in the form "domain RRTYPE answer" 63 | parts := strings.Split(line, " ") 64 | if len(parts) != 3 { 65 | curState = stateSkip 66 | continue 67 | } 68 | 69 | domain := strings.TrimSuffix(parts[0], ".") 70 | if domain == "" { 71 | curState = stateSkip 72 | continue 73 | } 74 | 75 | curDomain = domain 76 | domainSaved = false 77 | curState = stateSaveAnswer 78 | 79 | fallthrough 80 | 81 | // Save the answer record found 82 | case stateSaveAnswer: 83 | parts := strings.Split(line, " ") 84 | if len(parts) != 3 { 85 | curState = stateSkip 86 | continue 87 | } 88 | 89 | domain := curDomain 90 | rrtypeStr := parts[1] 91 | answer := parts[2] 92 | 93 | var rrtype resolvermt.RRtype 94 | switch rrtypeStr { 95 | case "A": 96 | rrtype = resolvermt.TypeA 97 | case "AAAA": 98 | rrtype = resolvermt.TypeAAAA 99 | case "CNAME": 100 | answer = strings.TrimSuffix(answer, ".") 101 | rrtype = resolvermt.TypeCNAME 102 | default: 103 | continue 104 | } 105 | 106 | // Save valid domain just once 107 | if !domainSaved { 108 | found++ 109 | domainSaved = true 110 | if w != nil { 111 | if _, werr := w.Write([]byte(domain + "\n")); werr != nil { 112 | return found, werr // report write failures instead of silently losing saved domains 113 | } 114
| } 115 | } 116 | // Valid record found, add it to the cache 117 | if cache != nil { 118 | cacheAnswer := wildcarder.DNSAnswer{ 119 | Type: rrtype, 120 | Answer: answer, 121 | } 122 | cache.Add(domain, []wildcarder.DNSAnswer{cacheAnswer}) 123 | } 124 | 125 | // If we're just counting valid domains, we can skip the rest of the records 126 | if cache == nil && w == nil { 127 | curState = stateSkip 128 | } 129 | 130 | // Answer was invalid, skip until we receive a new answer section 131 | case stateSkip: 132 | continue 133 | } 134 | } 135 | 136 | return found, c.scanner.Err() 137 | } 138 | -------------------------------------------------------------------------------- /internal/usecase/resolve/cachereader_test.go: -------------------------------------------------------------------------------- 1 | package resolve 2 | 3 | import ( 4 | "io" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/d3mondev/puredns/v2/pkg/filetest" 9 | "github.com/d3mondev/puredns/v2/pkg/wildcarder" 10 | "github.com/d3mondev/resolvermt" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestCacheReaderRead(t *testing.T) { 15 | type cacheEntry struct { 16 | question string 17 | answers []wildcarder.AnswerHash 18 | } 19 | 20 | tests := []struct { 21 | name string 22 | haveData string 23 | wantCache []cacheEntry 24 | wantDomain []string 25 | wantErr error 26 | }{ 27 | { 28 | name: "single record", 29 | haveData: `example.com. A 127.0.0.1`, 30 | wantCache: []cacheEntry{ 31 | { 32 | question: "example.com", 33 | answers: []wildcarder.AnswerHash{ 34 | wildcarder.HashAnswer(wildcarder.DNSAnswer{Type: resolvermt.TypeA, Answer: "127.0.0.1"}), 35 | }, 36 | }, 37 | }, 38 | wantDomain: []string{ 39 | "example.com", 40 | }, 41 | }, 42 | { 43 | name: "multiple record", 44 | haveData: `www.example.com. CNAME example.com. 45 | example.com. A 127.0.0.1 46 | example.com.
AAAA ::1`, 47 | wantCache: []cacheEntry{ 48 | { 49 | question: "www.example.com", 50 | answers: []wildcarder.AnswerHash{ 51 | wildcarder.HashAnswer(wildcarder.DNSAnswer{Type: resolvermt.TypeCNAME, Answer: "example.com"}), 52 | wildcarder.HashAnswer(wildcarder.DNSAnswer{Type: resolvermt.TypeA, Answer: "127.0.0.1"}), 53 | wildcarder.HashAnswer(wildcarder.DNSAnswer{Type: resolvermt.TypeAAAA, Answer: "::1"}), 54 | }, 55 | }, 56 | }, 57 | wantDomain: []string{ 58 | "www.example.com", 59 | }, 60 | }, 61 | { 62 | name: "invalid record type", 63 | haveData: `example.com. NS ns.example.com.`, 64 | wantCache: []cacheEntry{}, 65 | wantDomain: []string{}, 66 | }, 67 | { 68 | name: "save domain after valid record is found", 69 | haveData: `example.com. NS ns.example.com. 70 | example.com. AAAA ::1`, 71 | wantCache: []cacheEntry{ 72 | { 73 | question: "example.com", 74 | answers: []wildcarder.AnswerHash{ 75 | wildcarder.HashAnswer(wildcarder.DNSAnswer{Type: resolvermt.TypeAAAA, Answer: "::1"}), 76 | }, 77 | }, 78 | }, 79 | wantDomain: []string{ 80 | "example.com", 81 | }, 82 | }, 83 | { 84 | name: "multiple answer sections", 85 | haveData: ` 86 | example.com. A 127.0.0.1 87 | 88 | www.test.com. CNAME test.com. 89 | test.com. A 127.0.0.1 90 | test.com. 
AAAA ::1 91 | `, 92 | wantCache: []cacheEntry{ 93 | { 94 | question: "example.com", 95 | answers: []wildcarder.AnswerHash{ 96 | wildcarder.HashAnswer(wildcarder.DNSAnswer{Type: resolvermt.TypeA, Answer: "127.0.0.1"}), 97 | }, 98 | }, 99 | { 100 | question: "www.test.com", 101 | answers: []wildcarder.AnswerHash{ 102 | wildcarder.HashAnswer(wildcarder.DNSAnswer{Type: resolvermt.TypeCNAME, Answer: "test.com"}), 103 | wildcarder.HashAnswer(wildcarder.DNSAnswer{Type: resolvermt.TypeA, Answer: "127.0.0.1"}), 104 | wildcarder.HashAnswer(wildcarder.DNSAnswer{Type: resolvermt.TypeAAAA, Answer: "::1"}), 105 | }, 106 | }, 107 | }, 108 | wantDomain: []string{ 109 | "example.com", 110 | "www.test.com", 111 | }, 112 | }, 113 | { 114 | name: "skip if domain name can't be parsed", 115 | haveData: `garbage 116 | example.com. A 127.0.0.1`, 117 | wantCache: []cacheEntry{}, 118 | wantDomain: []string{}, 119 | }, 120 | { 121 | name: "skip answer section containing bad data", 122 | haveData: `example.com. A 127.0.0.1 123 | garbage`, 124 | wantCache: []cacheEntry{ 125 | { 126 | question: "example.com", 127 | answers: []wildcarder.AnswerHash{ 128 | wildcarder.HashAnswer(wildcarder.DNSAnswer{Type: resolvermt.TypeA, Answer: "127.0.0.1"}), 129 | }, 130 | }, 131 | }, 132 | wantDomain: []string{"example.com"}, 133 | }, 134 | { 135 | name: "empty domain", 136 | haveData: `. 
A 127.0.0.1`, 137 | wantCache: []cacheEntry{}, 138 | wantDomain: []string{}, 139 | }, 140 | } 141 | 142 | for _, test := range tests { 143 | t.Run(test.name, func(t *testing.T) { 144 | domainFile := filetest.CreateFile(t, "") 145 | cache := wildcarder.NewDNSCache() 146 | 147 | loader := NewCacheReader(io.NopCloser(strings.NewReader(test.haveData))) 148 | count, err := loader.Read(domainFile, cache, 0) 149 | assert.ErrorIs(t, err, test.wantErr) 150 | assert.Equal(t, len(test.wantDomain), count) 151 | 152 | gotDomain := filetest.ReadFile(t, domainFile.Name()) 153 | 154 | assert.Equal(t, test.wantDomain, gotDomain) 155 | 156 | for _, cacheTest := range test.wantCache { 157 | got := cache.Find(cacheTest.question) 158 | assert.ElementsMatch(t, cacheTest.answers, got) 159 | } 160 | }) 161 | } 162 | } 163 | 164 | func TestCacheReaderRead_WithMax(t *testing.T) { 165 | domainFile := filetest.CreateFile(t, "") 166 | data := ` 167 | example.com. A 127.0.0.1 168 | 169 | example.net. AAAA ::1 170 | 171 | example.org. CNAME example.net. 172 | example.net. AAAA ::1` 173 | 174 | r := io.NopCloser(strings.NewReader(data)) 175 | loader := NewCacheReader(r) 176 | 177 | _, err := loader.Read(domainFile, nil, 2) 178 | gotDomain := filetest.ReadFile(t, domainFile.Name()) 179 | 180 | assert.Nil(t, err) 181 | assert.Equal(t, []string{"example.com", "example.net"}, gotDomain) 182 | 183 | _, err = loader.Read(domainFile, nil, 2) 184 | gotDomain = filetest.ReadFile(t, domainFile.Name()) 185 | 186 | assert.Nil(t, err) 187 | assert.Equal(t, []string{"example.com", "example.net", "example.org"}, gotDomain) 188 | } 189 | 190 | func TestCacheReaderRead_CountOnly(t *testing.T) { 191 | data := ` 192 | example.com. A 127.0.0.1 193 | 194 | example.net. AAAA ::1 195 | 196 | example.org. CNAME example.net. 197 | example.net. 
AAAA ::1 198 | ` 199 | 200 | r := io.NopCloser(strings.NewReader(data)) 201 | loader := NewCacheReader(r) 202 | 203 | count, _ := loader.Read(nil, nil, 0) 204 | assert.Equal(t, 3, count) 205 | } 206 | -------------------------------------------------------------------------------- /internal/usecase/resolve/domainreader.go: -------------------------------------------------------------------------------- 1 | package resolve 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "fmt" 7 | "io" 8 | "strings" 9 | 10 | "github.com/d3mondev/puredns/v2/pkg/procreader" 11 | ) 12 | 13 | // DomainReader implements an io.Reader interface that generates subdomains to resolve. 14 | // It reads data line by line from a source scanner. This data is either words that will be 15 | // prefixed to a domain to create subdomains, or a straight list of subdomains to resolve. 16 | // The DomainReader will also discard any generated domains that do not pass the specified 17 | // domain sanitizer filter if present. 18 | type DomainReader struct { 19 | source io.ReadCloser 20 | sourceScanner *bufio.Scanner 21 | subdomainReader *procreader.ProcReader 22 | 23 | domains []string 24 | sanitizer DomainSanitizer 25 | } 26 | 27 | var _ io.Reader = (*DomainReader)(nil) 28 | 29 | // DomainSanitizer is a function that sanitizes a domain, typically removing invalid characters. 30 | // If the domain cannot be sanitized or is invalid, an empty string is expected. 31 | type DomainSanitizer func(domain string) string 32 | 33 | // NewDomainReader creates a new DomainReader. If domains is not empty, the source 34 | // reader is expected to contain words that will be prefixed to the domains to create subdomains. 
35 | func NewDomainReader(source io.ReadCloser, domains []string, sanitizer DomainSanitizer) *DomainReader { 36 | domainReader := &DomainReader{ 37 | source: source, 38 | sourceScanner: bufio.NewScanner(source), 39 | domains: domains, 40 | sanitizer: sanitizer, 41 | } 42 | 43 | domainReader.subdomainReader = procreader.New(domainReader.nextSubdomains) 44 | 45 | return domainReader 46 | } 47 | 48 | // Read fills p with the next generated subdomains, each terminated by a newline. 49 | func (r *DomainReader) Read(p []byte) (int, error) { 50 | return r.subdomainReader.Read(p) 51 | } 52 | 53 | // nextSubdomains is the procreader callback that generates the subdomains for the next source line (size is unused). 54 | func (r *DomainReader) nextSubdomains(size int) ([]byte, error) { 55 | if !r.sourceScanner.Scan() { 56 | // Make sure to close the source, discarding its error 57 | // as we want to report the error from the scanner 58 | r.source.Close() 59 | 60 | // Return the error from the scanner 61 | if err := r.sourceScanner.Err(); err != nil { 62 | return nil, err 63 | } 64 | 65 | // Return EOF 66 | return nil, io.EOF 67 | } 68 | 69 | var output bytes.Buffer 70 | word := r.sourceScanner.Text() 71 | 72 | if len(r.domains) == 0 { 73 | // Single domain was read from reader 74 | domain := word 75 | domain = r.processDomain(domain) 76 | output.WriteString(domain) 77 | } else { 78 | // Generate a subdomain from the word and each domain: a '*' marks where the word goes, otherwise the word is prepended 79 | for _, domain := range r.domains { 80 | if strings.ContainsRune(domain, '*') { 81 | domain = strings.ReplaceAll(domain, "*", word) 82 | } else { 83 | domain = fmt.Sprintf("%s.%s", word, domain) 84 | } 85 | 86 | domain = r.processDomain(domain) 87 | output.WriteString(domain) 88 | } 89 | } 90 | 91 | return output.Bytes(), nil 92 | } 93 | 94 | // processDomain applies the sanitizer, if any, to the domain and terminates it with a newline. 95 | func (r *DomainReader) processDomain(domain string) string { 96 | // Sanitize the domain 97 | if r.sanitizer != nil { 98 | domain = r.sanitizer(domain) 99 | } 100 | 101 | // Append a newline even if the sanitized domain is empty, to keep the progress bar accurate
102 | domain = domain + "\n" 103 | 104 | return domain 105 | } 106 | -------------------------------------------------------------------------------- /internal/usecase/resolve/domainreader_test.go: -------------------------------------------------------------------------------- 1 | package resolve 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "strings" 7 | "testing" 8 | "testing/iotest" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestNewDomainReader(t *testing.T) { 14 | r := NewDomainReader(io.NopCloser(strings.NewReader("test")), nil, nil) 15 | assert.NotNil(t, r) 16 | } 17 | 18 | func TestDomainReaderRead(t *testing.T) { 19 | tests := []struct { 20 | name string 21 | haveData string 22 | haveDomains []string 23 | haveSanitizer DomainSanitizer 24 | want string 25 | wantErr error 26 | }{ 27 | {name: "domain list", haveData: "example.com\nwww.example.com\nftp.example.com", want: "example.com\nwww.example.com\nftp.example.com\n", wantErr: io.EOF}, 28 | {name: "words", haveData: "www\nftp\nmail", haveDomains: []string{"example.com"}, want: "www.example.com\nftp.example.com\nmail.example.com\n", wantErr: io.EOF}, 29 | {name: "wildcard", haveData: "www\nftp\nmail", haveDomains: []string{"www.*.example.com"}, want: "www.www.example.com\nwww.ftp.example.com\nwww.mail.example.com\n", wantErr: io.EOF}, 30 | {name: "multiple wildcards", haveData: "word", haveDomains: []string{"www.*.*.example.com"}, want: "www.word.word.example.com\n", wantErr: io.EOF}, 31 | {name: "words multiple domains", haveData: "www\nftp\nmail", haveDomains: []string{"example.com", "example.org"}, want: "www.example.com\nwww.example.org\nftp.example.com\nftp.example.org\nmail.example.com\nmail.example.org\n", wantErr: io.EOF}, 32 | {name: "sanitize", haveData: "+", haveDomains: []string{"example.com"}, haveSanitizer: DefaultSanitizer, want: "\n", wantErr: io.EOF}, 33 | {name: "sanitize domain list", haveData: "*.WWW.Example.com", haveSanitizer: DefaultSanitizer, want: "www.example.com\n", wantErr: io.EOF}, 34 | } 35 | 36 | for _, test := range tests { 37 | t.Run(test.name, func(t *testing.T) { 38 | r
:= NewDomainReader(io.NopCloser(strings.NewReader(test.haveData)), test.haveDomains, test.haveSanitizer) 39 | 40 | buf := make([]byte, 1024) 41 | n, err := r.Read(buf) 42 | 43 | assert.ErrorIs(t, err, test.wantErr) 44 | assert.Equal(t, test.want, string(buf[:n])) 45 | }) 46 | } 47 | } 48 | 49 | func TestDomainReaderRead_ScannerError(t *testing.T) { 50 | wantErr := errors.New("error") 51 | 52 | r := NewDomainReader(io.NopCloser(iotest.ErrReader(wantErr)), nil, nil) 53 | buf := make([]byte, 1024) 54 | _, err := r.Read(buf) 55 | 56 | assert.ErrorIs(t, err, wantErr) 57 | } 58 | -------------------------------------------------------------------------------- /internal/usecase/resolve/massresolver.go: -------------------------------------------------------------------------------- 1 | package resolve 2 | 3 | import ( 4 | "io" 5 | 6 | "github.com/d3mondev/puredns/v2/internal/pkg/console" 7 | "github.com/d3mondev/puredns/v2/pkg/massdns" 8 | "github.com/d3mondev/puredns/v2/pkg/progressbar" 9 | ) 10 | 11 | // DefaultMassResolver implements the MassResolver interface. 12 | type DefaultMassResolver struct { 13 | massdns *massdns.Resolver 14 | } 15 | 16 | // NewDefaultMassResolver creates a new DefaultMassResolver. 17 | func NewDefaultMassResolver(binPath string) *DefaultMassResolver { 18 | return &DefaultMassResolver{ 19 | massdns: massdns.NewResolver(binPath), 20 | } 21 | } 22 | 23 | // Resolve calls massdns to resolve the domains read from r while displaying a progress bar.
24 | func (m *DefaultMassResolver) Resolve(r io.Reader, output string, total int, resolversFilename string, qps int) error { 25 | // When no total is given (total == 0), show a plain counter instead of an ETA and a bar. 26 | tmpl := "[ETA {{ eta }}] {{ bar }} {{ current }}/{{ total }} rate: {{ rate }} qps (time: {{ time }})" 27 | if total == 0 { 28 | tmpl = "Processed: {{ current }} Rate: {{ rate }} Elapsed: {{ time }}" 29 | } 30 | 31 | bar := progressbar.New(m.updateProgressBar, int64(total), progressbar.WithTemplate(tmpl), progressbar.WithWriter(console.Output)) 32 | 33 | // The bar only runs for the duration of the massdns call. 34 | bar.Start() 35 | resolveErr := m.massdns.Resolve(r, output, resolversFilename, qps) 36 | bar.Stop() 37 | 38 | return resolveErr 39 | } 40 | 41 | // updateProgressBar is the progress bar update callback; it copies the massdns 42 | // progress counter into the bar. 43 | func (m *DefaultMassResolver) updateProgressBar(bar *progressbar.ProgressBar) { 44 | bar.SetCurrent(int64(m.massdns.Current())) 45 | } 46 | -------------------------------------------------------------------------------- /internal/usecase/resolve/massresolver_test.go: -------------------------------------------------------------------------------- 1 | package resolve 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestMassResolverNew(t *testing.T) { 11 | r := NewDefaultMassResolver("") 12 | assert.NotNil(t, r) 13 | } 14 | 15 | func TestMassResolverResolve_OK(t *testing.T) { 16 | r := NewDefaultMassResolver("") 17 | 18 | err := r.Resolve(strings.NewReader("example.com"), "", 0, "", 10) 19 | assert.EqualError(t, err, "exec: no command", "should not call massdns because of invalid path") 20 | } 21 | 22 | func TestMassResolverResolve_WithTotal(t *testing.T) { 23 | r := NewDefaultMassResolver("") 24 | 25 | err := r.Resolve(strings.NewReader("example.com"), "", 100, "", 10) 26 | assert.EqualError(t, err, "exec: no command") 27 | } 28 | --------------------------------------------------------------------------------
/internal/usecase/resolve/requirementchecker.go: -------------------------------------------------------------------------------- 1 | package resolve 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/d3mondev/puredns/v2/internal/app/ctx" 7 | ) 8 | 9 | // Executor is a simple interface to execute shell commands. 10 | type Executor interface { 11 | Shell(name string, arg ...string) error 12 | } 13 | 14 | // DefaultRequirementChecker checks that the required binaries are present. 15 | type DefaultRequirementChecker struct { 16 | executor Executor 17 | } 18 | 19 | // NewDefaultRequirementChecker returns a new checker object used to validate whether the required binaries can be run. 20 | func NewDefaultRequirementChecker(executor Executor) DefaultRequirementChecker { 21 | return DefaultRequirementChecker{executor: executor} 22 | } 23 | 24 | // Check makes sure that massdns can be executed on the system. 25 | // If not, it displays a message to help the user fix the issue. 26 | func (c DefaultRequirementChecker) Check(opt *ctx.ResolveOptions) error { 27 | if err := c.executor.Shell(opt.BinPath, "--help"); err != nil { 28 | fmt.Printf("Unable to execute massdns.
 Make sure it is present and that the\n") 29 | fmt.Printf("path to the binary is added to the PATH environment variable.\n\n") 30 | 31 | fmt.Printf("Alternatively, specify the path to massdns using --bin\n\n") 32 | 33 | return fmt.Errorf("unable to execute massdns: %w", err) 34 | } 35 | 36 | return nil 37 | } 38 | -------------------------------------------------------------------------------- /internal/usecase/resolve/requirementchecker_test.go: -------------------------------------------------------------------------------- 1 | package resolve 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/d3mondev/puredns/v2/internal/app/ctx" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | type stubExecutor struct { 12 | index int 13 | returnValues []error 14 | } 15 | 16 | func (s *stubExecutor) Shell(name string, arg ...string) error { 17 | ret := s.returnValues[s.index] 18 | s.index++ 19 | 20 | return ret 21 | } 22 | 23 | func TestCheck(t *testing.T) { 24 | wantErr := errors.New("error") 25 | 26 | tests := []struct { 27 | name string 28 | haveError []error 29 | wantErr error 30 | }{ 31 | {name: "ok", haveError: []error{nil, nil}, wantErr: nil}, 32 | {name: "massdns error handling", haveError: []error{wantErr, nil}, wantErr: wantErr}, 33 | } 34 | 35 | for _, test := range tests { 36 | t.Run(test.name, func(t *testing.T) { 37 | executor := &stubExecutor{} 38 | executor.returnValues = test.haveError 39 | checker := NewDefaultRequirementChecker(executor) 40 | 41 | err := checker.Check(&ctx.ResolveOptions{}) 42 | 43 | assert.ErrorIs(t, err, test.wantErr) 44 | }) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /internal/usecase/resolve/resolverloader.go: -------------------------------------------------------------------------------- 1 | package resolve 2 | 3 | import ( 4 | "bufio" 5 | "io" 6 | "os" 7 | "strings" 8 | 9 | "github.com/d3mondev/puredns/v2/internal/app/ctx" 10 | ) 11 | 12 | //
DefaultResolverLoader loads resolvers from a text file. 13 | type DefaultResolverLoader struct{} 14 | 15 | // NewDefaultResolverFileLoader creates a new DefaultResolverLoader instance. 16 | func NewDefaultResolverFileLoader() *DefaultResolverLoader { 17 | return &DefaultResolverLoader{} 18 | } 19 | 20 | // Load reads resolvers from the specified file, one per line, and saves them to the program context. 21 | // An empty filename is a no-op, and the context is only updated when at least one resolver was read. 22 | func (l *DefaultResolverLoader) Load(ctx *ctx.Ctx, filename string) error { 23 | if filename == "" { 24 | return nil 25 | } 26 | 27 | file, err := os.Open(filename) 28 | if err != nil { 29 | return err 30 | } 31 | // The file is only read from, so a Close error has nothing useful to report. 32 | defer file.Close() 33 | 34 | resolvers, err := load(file) 35 | 36 | if len(resolvers) > 0 { 37 | ctx.Options.TrustedResolvers = resolvers 38 | } 39 | 40 | return err 41 | } 42 | 43 | // load reads one resolver per line from r, trimming surrounding whitespace and skipping 44 | // blank lines. On a read error, it returns a nil slice along with the error. 45 | func load(r io.Reader) ([]string, error) { 46 | resolvers := []string{} 47 | 48 | scanner := bufio.NewScanner(r) 49 | for scanner.Scan() { 50 | resolver := strings.TrimSpace(scanner.Text()) 51 | if resolver == "" { 52 | continue 53 | } 54 | 55 | resolvers = append(resolvers, resolver) 56 | } 57 | 58 | if err := scanner.Err(); err != nil { 59 | return nil, err 60 | } 61 | 62 | return resolvers, nil 63 | } 64 | -------------------------------------------------------------------------------- /internal/usecase/resolve/resolverloader_test.go: -------------------------------------------------------------------------------- 1 | package resolve 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | "testing/iotest" 7 | 8 | "github.com/d3mondev/puredns/v2/internal/app/ctx" 9 | "github.com/d3mondev/puredns/v2/pkg/filetest" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestResolverLoader(t *testing.T) { 15 | file := filetest.CreateFile(t, "") 16 | _, err := file.WriteString("8.8.8.8\n \n1.1.1.1\n4.4.4.4") 17 | require.Nil(t, err) 18 | 19 | ctx := ctx.NewCtx() 20 | 21 | loader := NewDefaultResolverFileLoader() 22 | err = loader.Load(ctx, file.Name()) 23 | 24 |
assert.Nil(t, err) 25 | assert.ElementsMatch(t, ctx.Options.TrustedResolvers, []string{"8.8.8.8", "1.1.1.1", "4.4.4.4"}) 26 | } 27 | 28 | func TestResolverLoaderFileOpenError(t *testing.T) { 29 | ctx := ctx.NewCtx() 30 | loader := NewDefaultResolverFileLoader() 31 | 32 | err := loader.Load(ctx, "thisfiledoesnotexist.txt") 33 | 34 | assert.NotNil(t, err) 35 | } 36 | 37 | func TestResolverScannerError(t *testing.T) { 38 | reader := iotest.ErrReader(errors.New("read error")) 39 | 40 | _, err := load(reader) 41 | 42 | assert.NotNil(t, err) 43 | } 44 | -------------------------------------------------------------------------------- /internal/usecase/resolve/resultsaver.go: -------------------------------------------------------------------------------- 1 | package resolve 2 | 3 | import ( 4 | "github.com/d3mondev/puredns/v2/internal/app/ctx" 5 | "github.com/d3mondev/puredns/v2/pkg/fileoperation" 6 | ) 7 | 8 | // ResultFileSaver is responsible for saving the results of the resolve operation to files. 9 | type ResultFileSaver struct { 10 | fileCopy func(src string, dest string) error 11 | } 12 | 13 | // NewResultFileSaver creates a new ResultFileSaver object. 14 | func NewResultFileSaver() *ResultFileSaver { 15 | return &ResultFileSaver{ 16 | fileCopy: fileoperation.Copy, 17 | } 18 | } 19 | 20 | // Save copies each working file to the output file set in the options, skipping unset outputs and stopping at the first copy error.
21 | func (s *ResultFileSaver) Save(workfiles *Workfiles, opt *ctx.ResolveOptions) error { 22 | if opt.WriteDomainsFile != "" { 23 | if err := s.fileCopy(workfiles.Domains, opt.WriteDomainsFile); err != nil { 24 | return err 25 | } 26 | } 27 | 28 | if opt.WriteMassdnsFile != "" { 29 | if err := s.fileCopy(workfiles.MassdnsPublic, opt.WriteMassdnsFile); err != nil { 30 | return err 31 | } 32 | } 33 | 34 | if opt.WriteWildcardsFile != "" { 35 | if err := s.fileCopy(workfiles.WildcardRoots, opt.WriteWildcardsFile); err != nil { 36 | return err 37 | } 38 | } 39 | 40 | return nil 41 | } 42 | -------------------------------------------------------------------------------- /internal/usecase/resolve/resultsaver_test.go: -------------------------------------------------------------------------------- 1 | package resolve 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/d3mondev/puredns/v2/internal/app/ctx" 7 | "github.com/d3mondev/puredns/v2/pkg/filetest" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestSave(t *testing.T) { 12 | domainFile := filetest.CreateFile(t, "domain.com") 13 | massdnsFile := filetest.CreateFile(t, "domain.com.
A 127.0.0.1") 14 | wildcardFile := filetest.CreateFile(t, "*.domain.com") 15 | 16 | domainOutputFile := filetest.CreateFile(t, "") 17 | massdnsOutputFile := filetest.CreateFile(t, "") 18 | wildcardOutputFile := filetest.CreateFile(t, "") 19 | 20 | tests := []struct { 21 | name string 22 | haveReadDomainFile string 23 | haveWriteDomainFile string 24 | haveReadMassdnsFile string 25 | haveWriteMassdnsFile string 26 | haveReadWildcardFile string 27 | haveWriteWildcardFile string 28 | wantDomainContent []string 29 | wantMassdnsContent []string 30 | wantWildcardContent []string 31 | wantErr bool 32 | }{ 33 | { 34 | name: "don't save", 35 | wantErr: false, 36 | }, 37 | { 38 | name: "save all", 39 | haveReadDomainFile: domainFile.Name(), 40 | haveWriteDomainFile: domainOutputFile.Name(), 41 | haveReadMassdnsFile: massdnsFile.Name(), 42 | haveWriteMassdnsFile: massdnsOutputFile.Name(), 43 | haveReadWildcardFile: wildcardFile.Name(), 44 | haveWriteWildcardFile: wildcardOutputFile.Name(), 45 | wantDomainContent: []string{"domain.com"}, 46 | wantMassdnsContent: []string{"domain.com.
A 127.0.0.1"}, 47 | wantWildcardContent: []string{"*.domain.com"}, 48 | }, 49 | { 50 | name: "domain file error handling", 51 | haveReadDomainFile: "thisfiledoesntexist.txt", 52 | haveWriteDomainFile: domainOutputFile.Name(), 53 | wantErr: true, 54 | }, 55 | { 56 | name: "massdns answers file error handling", 57 | haveReadMassdnsFile: "thisfiledoesntexist.txt", 58 | haveWriteMassdnsFile: massdnsOutputFile.Name(), 59 | wantErr: true, 60 | }, 61 | { 62 | name: "wildcard roots file error handling", 63 | haveReadWildcardFile: "thisfiledoesntexist.txt", 64 | haveWriteWildcardFile: wildcardOutputFile.Name(), 65 | wantErr: true, 66 | }, 67 | } 68 | 69 | for _, test := range tests { 70 | t.Run(test.name, func(t *testing.T) { 71 | filetest.ClearFile(t, domainOutputFile) 72 | filetest.ClearFile(t, massdnsOutputFile) 73 | filetest.ClearFile(t, wildcardOutputFile) 74 | 75 | opt := &ctx.ResolveOptions{} 76 | opt.WriteDomainsFile = test.haveWriteDomainFile 77 | opt.WriteMassdnsFile = test.haveWriteMassdnsFile 78 | opt.WriteWildcardsFile = test.haveWriteWildcardFile 79 | 80 | workfiles := &Workfiles{} 81 | workfiles.Domains = test.haveReadDomainFile 82 | workfiles.MassdnsPublic = test.haveReadMassdnsFile 83 | workfiles.WildcardRoots = test.haveReadWildcardFile 84 | 85 | saver := NewResultFileSaver() 86 | 87 | gotErr := saver.Save(workfiles, opt) 88 | 89 | assert.Equal(t, test.wantErr, gotErr != nil) 90 | assert.ElementsMatch(t, test.wantDomainContent, filetest.ReadFile(t, test.haveWriteDomainFile)) 91 | assert.ElementsMatch(t, test.wantMassdnsContent, filetest.ReadFile(t, test.haveWriteMassdnsFile)) 92 | assert.ElementsMatch(t, test.wantWildcardContent, filetest.ReadFile(t, test.haveWriteWildcardFile)) 93 | }) 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /internal/usecase/resolve/sanitizer.go: -------------------------------------------------------------------------------- 1 | package resolve 2 | 3 | import ( 4 | "strings" 5
| ) 6 | 7 | // DefaultSanitizer is the default sanitizer function. It transforms the domain to lowercase, 8 | // strips a single leading "*." label, and ensures only valid characters ([a-z0-9._-]) are present. 9 | // Returns an empty string if the domain fails sanitization. 10 | func DefaultSanitizer(domain string) string { 11 | // Set to lowercase 12 | domain = strings.ToLower(domain) 13 | 14 | // Remove a single leading "*." wildcard label 15 | domain = strings.TrimPrefix(domain, "*.") 16 | 17 | // Keep only domains made of [a-z0-9._-] 18 | // Faster than using a regular expression 19 | for i := 0; i < len(domain); i++ { 20 | char := domain[i] 21 | 22 | if (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9') || (char == '-') || (char == '_') || (char == '.') { 23 | continue 24 | } 25 | 26 | domain = "" 27 | break 28 | } 29 | 30 | return domain 31 | } 32 | -------------------------------------------------------------------------------- /internal/usecase/resolve/sanitizer_test.go: -------------------------------------------------------------------------------- 1 | package resolve 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestDefaultSanitizer(t *testing.T) { 10 | tests := []struct { 11 | name string 12 | haveDomain string 13 | wantDomain string 14 | }{ 15 | {name: "valid domain", haveDomain: "example.com", wantDomain: "example.com"}, 16 | {name: "tolower transform", haveDomain: "EXAMPLE.COM", wantDomain: "example.com"}, 17 | {name: "invalid characters", haveDomain: "example+.com", wantDomain: ""}, 18 | {name: "wildcard", haveDomain: "*.example.com", wantDomain: "example.com"}, 19 | {name: "underscore and hyphen", haveDomain: "_dmarc.my-host.example.com", wantDomain: "_dmarc.my-host.example.com"}, 20 | {name: "only one wildcard label removed", haveDomain: "*.*.example.com", wantDomain: ""}, 21 | } 22 | 23 | for _, test := range tests { 24 | t.Run(test.name, func(t *testing.T) { 25 | got := DefaultSanitizer(test.haveDomain) 26 | assert.Equal(t, test.wantDomain, got) 27 | }) 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /internal/usecase/resolve/stubs_test.go: -------------------------------------------------------------------------------- 1 |
package resolve 2 | 3 | import ( 4 | "io" 5 | "testing" 6 | 7 | "github.com/d3mondev/puredns/v2/internal/app/ctx" 8 | "github.com/d3mondev/puredns/v2/pkg/filetest" 9 | ) 10 | 11 | // stubs groups the test doubles that newStubService wires into a Service, 12 | // so that tests can configure them and inspect how they were called. 13 | type stubs struct { 14 | spyRequirementChecker *spyRequirementChecker 15 | fakeWorkfileCreator *fakeWorkfileCreator 16 | spyResolverLoader *spyResolverLoader 17 | stubDomainSanitizer *stubDomainSanitizer 18 | spyMassResolver *spyMassResolver 19 | stubWildcardFilter *stubWildcardFilter 20 | stubResultSaver *stubResultSaver 21 | } 22 | 23 | // newStubService builds a Service backed entirely by test doubles and returns it 24 | // together with those doubles. The service is closed when the test finishes. 25 | func newStubService(t *testing.T) (*Service, stubs) { 26 | doubles := stubs{ 27 | spyRequirementChecker: new(spyRequirementChecker), 28 | fakeWorkfileCreator: newFakeWorkfileCreator(t), 29 | spyResolverLoader: new(spyResolverLoader), 30 | stubDomainSanitizer: new(stubDomainSanitizer), 31 | spyMassResolver: new(spyMassResolver), 32 | stubWildcardFilter: new(stubWildcardFilter), 33 | stubResultSaver: new(stubResultSaver), 34 | } 35 | 36 | svc := &Service{ 37 | Context: ctx.NewCtx(), 38 | Options: ctx.DefaultResolveOptions(), 39 | 40 | RequirementChecker: doubles.spyRequirementChecker, 41 | WorkfileCreator: doubles.fakeWorkfileCreator, 42 | ResolverLoader: doubles.spyResolverLoader, 43 | MassResolver: doubles.spyMassResolver, 44 | WildcardFilter: doubles.stubWildcardFilter, 45 | ResultSaver: doubles.stubResultSaver, 46 | } 47 | 48 | // Point the service at a real resolver file containing a single resolver. 49 | resolverFile := filetest.CreateFile(t, "8.8.8.8") 50 | svc.Options.ResolverFile = resolverFile.Name() 51 | 52 | t.Cleanup(func() { svc.Close(false) }) 53 | 54 | return svc, doubles 55 | } 56 | 57 | // spyRequirementChecker counts Check calls and returns a preset error. 58 | type spyRequirementChecker struct { 59 | called int 60 | returns error 61 | } 62 | 63 | func (s *spyRequirementChecker) Check(opt *ctx.ResolveOptions) error { 64 | s.called++ 65 | return s.returns 66 | } 67 | 68 | // newFakeWorkfileCreator returns a fakeWorkfileCreator bound to t. 69 | func newFakeWorkfileCreator(t *testing.T) *fakeWorkfileCreator { 70 | return &fakeWorkfileCreator{t: t} 71 | } 72 | 73 | // fakeWorkfileCreator creates real work files through the default creator, 74 | // unless err is set, in which case Create fails with it. 75 | type fakeWorkfileCreator struct { 76 | t *testing.T 77 | 78 | workfiles *Workfiles 79 | called int 80 | 81 | err error 82 | } 83 |
84 | func (f *fakeWorkfileCreator) Create() (*Workfiles, error) { 85 | f.called++ 86 | if f.err != nil { 87 | return nil, f.err 88 | } 89 | 90 | files, err := NewDefaultWorkfileCreator().Create() 91 | if err != nil { 92 | f.t.Fatal(err) 93 | } 94 | 95 | f.workfiles = files 96 | return files, nil 97 | } 98 | 99 | // spyResolverLoader counts Load calls and returns a preset error. 100 | type spyResolverLoader struct { 101 | called int 102 | err error 103 | } 104 | 105 | func (s *spyResolverLoader) Load(*ctx.Ctx, string) error { 106 | s.called++ 107 | return s.err 108 | } 109 | 110 | // spyMassResolver records the resolvers file and rate limit it was called with. 111 | type spyMassResolver struct { 112 | called int 113 | resolvers string 114 | ratelimit int 115 | } 116 | 117 | func (s *spyMassResolver) Resolve(r io.Reader, output string, total int, resolvers string, qps int) error { 118 | s.called++ 119 | s.resolvers, s.ratelimit = resolvers, qps 120 | return nil 121 | } 122 | 123 | func (s *spyMassResolver) Current() int { 124 | return 0 125 | } 126 | 127 | func (s *spyMassResolver) Rate() float64 { 128 | return 0 129 | } 130 | 131 | // stubWildcardFilter reports len(domains) found domains along with the preset roots and error. 132 | type stubWildcardFilter struct { 133 | called int 134 | err error 135 | domains []string 136 | roots []string 137 | } 138 | 139 | func (s *stubWildcardFilter) Filter(WildcardFilterOptions, int) (found int, roots []string, err error) { 140 | s.called++ 141 | return len(s.domains), s.roots, s.err 142 | } 143 | 144 | // stubDomainSanitizer counts Sanitize calls and returns a preset error. 145 | type stubDomainSanitizer struct { 146 | called int 147 | returns error 148 | } 149 | 150 | func (s *stubDomainSanitizer) Sanitize(string, string) error { 151 | s.called++ 152 | return s.returns 153 | } 154 | 155 | // stubResultSaver counts Save calls and returns a preset error. 156 | type stubResultSaver struct { 157 | called int 158 | returns error 159 | } 160 | 161 | func (s *stubResultSaver) Save(*Workfiles, *ctx.ResolveOptions) error { 162 | s.called++ 163 | return s.returns 164 | } 165 | -------------------------------------------------------------------------------- /internal/usecase/resolve/wildcardfilter.go: -------------------------------------------------------------------------------- 1 |
package resolve

import (
	"fmt"
	"io"
	"io/ioutil"
	"os"
	"sort"

	"github.com/d3mondev/puredns/v2/internal/pkg/console"
	"github.com/d3mondev/puredns/v2/pkg/fileoperation"
	"github.com/d3mondev/puredns/v2/pkg/progressbar"
	"github.com/d3mondev/puredns/v2/pkg/wildcarder"
)

// WildcardFilterOptions defines the options for the Filter function.
type WildcardFilterOptions struct {
	// Input files
	CacheFilename string

	// Output files
	DomainOutputFilename string
	RootOutputFilename   string

	// Filtering parameters
	Resolvers        []string
	QueriesPerSecond int
	ThreadCount      int
	ResolveTestCount int
	BatchSize        int
}

// DefaultWildcardFilter implements the WildcardFilter interface used to filter wildcards.
type DefaultWildcardFilter struct {
	wc *wildcarder.Wildcarder
}

// NewDefaultWildcardFilter returns a new DefaultWildcardFilter object.
func NewDefaultWildcardFilter() *DefaultWildcardFilter {
	return &DefaultWildcardFilter{}
}

// Filter returns the number of domains that are not wildcards along with the wildcard roots found.
// It uses the massdns cache file to prepopulate a cache of DNS responses to optimize the number of
// DNS queries to perform. Non-wildcard domains are appended to opt.DomainOutputFilename batch by
// batch; the unique wildcard roots are appended, sorted, to opt.RootOutputFilename at the end.
// totalCount is only used as the progress bar total. On error, found is 0 and roots is nil.
func (f *DefaultWildcardFilter) Filter(opt WildcardFilterOptions, totalCount int) (found int, roots []string, err error) {
	// Create the cache file reader
	cacheReader, err := createCacheReader(opt.CacheFilename)
	if err != nil {
		return 0, nil, err
	}
	defer cacheReader.Close()

	// Create wildcarder
	f.wc = createWildcarder(opt)

	// Create the temporary file whose name is reused for every batch's domain list
	tempFile, err := ioutil.TempFile("", "")
	if err != nil {
		return 0, nil, err
	}
	defer func() { tempFile.Close(); os.Remove(tempFile.Name()) }()

	// Start progress bar, and stop it on every exit path (including errors) so its
	// asynchronous updates do not outlive Filter.
	tmpl := "[ETA {{ eta }}] {{ bar }} {{ current }}/{{ total }} queries: {{ queries }} (time: {{ time }})"
	bar := progressbar.New(f.updateProgressBar, int64(totalCount), progressbar.WithTemplate(tmpl), progressbar.WithWriter(console.Output))
	bar.Start()
	defer bar.Stop()

	// Process entries in batch to prevent the precache from taking too much memory on very large (70M+)
	// domain lists. The wildcard cache and DNS cache stay intact and keep growing between batches for now.
	rootMap := make(map[string]struct{})
	for {
		domains, batchRoots, count, err := f.filterBatch(cacheReader, tempFile.Name(), opt.BatchSize)
		if err != nil {
			return 0, nil, err
		}

		// Nothing to process, we're done!
		if count == 0 {
			break
		}

		found += len(domains)

		// Save domains found
		if err := fileoperation.AppendLines(domains, opt.DomainOutputFilename); err != nil {
			return 0, nil, err
		}

		// Keep unique roots in map
		for _, root := range batchRoots {
			rootMap[root] = struct{}{}
		}
	}

	// Save roots found, sorted so the output does not depend on map iteration order
	var rootList []string
	for root := range rootMap {
		rootList = append(rootList, root)
	}
	sort.Strings(rootList)

	if err := fileoperation.AppendLines(rootList, opt.RootOutputFilename); err != nil {
		return 0, nil, err
	}

	return found, rootList, nil
}

// filterBatch loads the next batch of at most batchSize cache entries, filters the wildcard
// domains out of it and returns the remaining domains, the wildcard roots found and the count
// reported by prepareCache. A count of 0 means the cache has been fully consumed.
func (f *DefaultWildcardFilter) filterBatch(cacheReader *CacheReader, tempFilename string, batchSize int) ([]string, []string, int, error) {
	precache, domainFile, count, err := prepareCache(cacheReader, tempFilename, batchSize)
	if err != nil {
		return nil, nil, 0, err
	}
	// Closing via defer covers every exit, including the empty final batch.
	defer domainFile.Close()

	if count == 0 {
		return nil, nil, 0, nil
	}

	// Set current precache, then filter wildcards from this batch's domains
	f.wc.SetPreCache(precache)
	domains, roots := f.wc.Filter(domainFile)

	return domains, roots, count, nil
}

// updateProgressBar is called asynchronously by the progress bar. It sets the bar's current
// value from f.wc.Current() and its "queries" field from f.wc.QueryCount().
func (f *DefaultWildcardFilter) updateProgressBar(bar *progressbar.ProgressBar) {
	current := f.wc.Current()

	bar.SetCurrent(int64(current))
	bar.Set("queries", fmt.Sprintf("%d", f.wc.QueryCount()))
}

// createCacheReader creates a new cache reader. The reader needs to be closed
// by the caller in order to free the file.
func createCacheReader(filename string) (*CacheReader, error) {
	// Open the cache file
	cacheFile, err := os.Open(filename)
	if err != nil {
		return nil, err
	}

	cacheReader := NewCacheReader(cacheFile)

	return cacheReader, nil
}

// createWildcarder creates a new wildcarder.Wildcarder object from the options specified.
// The global opt.QueriesPerSecond budget is split across opt.Resolvers by qpsPerResolver.
func createWildcarder(opt WildcardFilterOptions) *wildcarder.Wildcarder {
	// Convert global QPS to a QPS per resolver
	qps := qpsPerResolver(len(opt.Resolvers), opt.QueriesPerSecond)

	// Create a custom resolver
	// NOTE(review): the meaning of the literal 10 and 100 arguments is defined by
	// wildcarder.NewClientDNS — confirm and consider naming them.
	resolver := wildcarder.NewClientDNS(opt.Resolvers, 10, qps, 100)

	// Create the wildcarder with the custom resolver
	wc := wildcarder.New(opt.ThreadCount, opt.ResolveTestCount, wildcarder.WithResolver(resolver))

	return wc
}

// prepareCache loads up to batchSize entries from a massdns cache reader, saves the valid domains
// they contain to tempFilename (truncating it), and returns a populated wildcarder.DNSCache, the
// domain file rewound to its beginning, and the count reported by CacheReader.Read.
// On success the caller is responsible for closing the returned file. On error no file is
// returned and the handle has already been closed.
func prepareCache(cacheReader *CacheReader, tempFilename string, batchSize int) (*wildcarder.DNSCache, *os.File, int, error) {
	// Create the temporary file that will hold domains
	domainFile, err := os.Create(tempFilename)
	if err != nil {
		return nil, nil, 0, err
	}

	// Load cache and save found domains to file
	precache := wildcarder.NewDNSCache()
	count, err := cacheReader.Read(domainFile, precache, batchSize)

	// Make sure domain data is written to disk and seek to the beginning of the file
	if err == nil {
		err = domainFile.Sync()
	}

	if err == nil {
		_, err = domainFile.Seek(0, io.SeekStart)
	}

	if err != nil {
		domainFile.Close()
		return nil, nil, 0, err
	}

	return precache, domainFile, count, nil
}

// qpsPerResolver transforms a global number of queries per second into a number of queries per second per resolver.
func qpsPerResolver(resolverCount, globalQPS int) int {
	// Avoid dividing by zero when no resolvers are configured.
	if resolverCount == 0 {
		return 0
	}

	// Integer division: any remainder of globalQPS is dropped.
	qps := globalQPS / resolverCount

	// Set a minimum of 1 query per second
	if qps == 0 {
		qps = 1
	}

	return qps
}

--------------------------------------------------------------------------------
/internal/usecase/resolve/wildcardfilter_test.go:
--------------------------------------------------------------------------------
package resolve

import (
	"testing"

	"github.com/d3mondev/puredns/v2/pkg/filetest"
	"github.com/stretchr/testify/assert"
)

// TestWildcardFilter checks that filtering an empty massdns cache file succeeds.
func TestWildcardFilter(t *testing.T) {
	cacheFile := filetest.CreateFile(t, "")
	domainFile := filetest.CreateFile(t, "")
	rootFile := filetest.CreateFile(t, "")

	wc := NewDefaultWildcardFilter()

	opt := WildcardFilterOptions{
		CacheFilename:        cacheFile.Name(),
		DomainOutputFilename: domainFile.Name(),
		RootOutputFilename:   rootFile.Name(),
		Resolvers:            []string{},
		QueriesPerSecond:     10,
		ThreadCount:          1,
	}
	_, _, err := wc.Filter(opt, 1)

	assert.Nil(t, err)
}

// TestWildcardFilter_Files checks that Filter reports errors for a missing cache file
// and for output paths that cannot be written.
func TestWildcardFilter_Files(t *testing.T) {
	// badFile is a directory path, so using it as an output file fails.
	badFile := filetest.CreateDir(t)
	// A single cache entry; presumably massdns "name type data" output — confirm format.
	cacheFile := filetest.CreateFile(t, "example.com. A 127.0.0.1")
	domainFile := filetest.CreateFile(t, "")
	rootFile := filetest.CreateFile(t, "")

	tests := []struct {
		name                 string
		haveCacheFile        string
		haveOutputDomainFile string
		haveOutputRootFile   string
		wantErr              bool
	}{
		{
			name:                 "valid files",
			haveCacheFile:        cacheFile.Name(),
			haveOutputDomainFile: domainFile.Name(),
			haveOutputRootFile:   rootFile.Name(),
			wantErr:              false,
		},
		{
			name:                 "cache file error handling",
			haveCacheFile:        "",
			haveOutputDomainFile: domainFile.Name(),
			haveOutputRootFile:   rootFile.Name(),
			wantErr:              true,
		},
		{
			name:                 "output file error handling",
			haveCacheFile:        cacheFile.Name(),
			haveOutputDomainFile: badFile,
			haveOutputRootFile:   rootFile.Name(),
			wantErr:              true,
		},
		{
			name:                 "root file error handling",
			haveCacheFile:        cacheFile.Name(),
			haveOutputDomainFile: domainFile.Name(),
			haveOutputRootFile:   badFile,
			wantErr:              true,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			filter := NewDefaultWildcardFilter()

			opt := WildcardFilterOptions{
				CacheFilename:        test.haveCacheFile,
				DomainOutputFilename: test.haveOutputDomainFile,
				RootOutputFilename:   test.haveOutputRootFile,
				Resolvers:            []string{},
				QueriesPerSecond:     10,
				ThreadCount:          1,
			}

			_, _, err := filter.Filter(opt, 0)

			assert.Equal(t, test.wantErr, err != nil)
		})
	}
}

// TestQPSPerResolver covers the zero-resolver guard, even division and the 1 QPS floor.
func TestQPSPerResolver(t *testing.T) {
	tests := []struct {
		name              string
		haveResolverCount int
		haveGlobalQPS     int
		want              int
	}{
		{name: "no resolvers", haveResolverCount: 0, haveGlobalQPS: 10, want: 0},
		{name: "two resolvers", haveResolverCount: 2, haveGlobalQPS: 10, want: 5},
		{name: "many resolvers", haveResolverCount: 10, haveGlobalQPS: 1, want: 1},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			got := qpsPerResolver(test.haveResolverCount, test.haveGlobalQPS)

			assert.Equal(t, test.want, got)
		})
	}
}

--------------------------------------------------------------------------------
/internal/usecase/resolve/workfilecreator.go:
--------------------------------------------------------------------------------
package resolve

import (
	"fmt"
	"io/ioutil"
	"os"
)

// Workfiles are temporary files used during the program execution.
// Each field holds the path of one file inside TempDirectory.
type Workfiles struct {
	TempDirectory string

	Domains        string
	MassdnsPublic  string
	MassdnsTrusted string
	Temporary      string

	PublicResolvers  string
	TrustedResolvers string

	WildcardRoots string
}

// Close deletes the temporary directory and every file in it. It does nothing when
// TempDirectory is empty; errors from os.RemoveAll are ignored.
func (w *Workfiles) Close() {
	if w.TempDirectory != "" {
		os.RemoveAll(w.TempDirectory)
	}
}

// DefaultWorkfileCreator is a service that creates a set of workfiles on disk.
// The function fields are indirections so tests can inject directory and file creation failures.
type DefaultWorkfileCreator struct {
	osMkdirTemp func(dir string, pattern string) (string, error)
	osCreate    func(name string) (*os.File, error)
}

// NewDefaultWorkfileCreator returns a DefaultWorkfileCreator backed by ioutil.TempDir and os.Create.
// It does not touch the disk itself: call Create to make the files, and Workfiles.Close to remove them.
func NewDefaultWorkfileCreator() *DefaultWorkfileCreator {
	return &DefaultWorkfileCreator{
		osMkdirTemp: ioutil.TempDir,
		osCreate:    os.Create,
	}
}

// Create creates a new set of workfiles.
47 | func (w *DefaultWorkfileCreator) Create() (*Workfiles, error) { 48 | files := &Workfiles{} 49 | 50 | dir, err := w.osMkdirTemp("", "puredns.") 51 | if err != nil { 52 | return nil, fmt.Errorf("unable to create temporary work directory: %w", err) 53 | } 54 | 55 | files.TempDirectory = dir 56 | 57 | if files.Domains, err = w.createFile(files.TempDirectory + "/" + "domains.txt"); err != nil { 58 | return nil, err 59 | } 60 | 61 | if files.MassdnsPublic, err = w.createFile(files.TempDirectory + "/" + "massdns_public.txt"); err != nil { 62 | return nil, err 63 | } 64 | 65 | if files.MassdnsTrusted, err = w.createFile(files.TempDirectory + "/" + "massdns_trusted.txt"); err != nil { 66 | return nil, err 67 | } 68 | 69 | if files.Temporary, err = w.createFile(files.TempDirectory + "/" + "temporary.txt"); err != nil { 70 | return nil, err 71 | } 72 | 73 | if files.PublicResolvers, err = w.createFile(files.TempDirectory + "/" + "resolvers.txt"); err != nil { 74 | return nil, err 75 | } 76 | 77 | if files.TrustedResolvers, err = w.createFile(files.TempDirectory + "/" + "trusted.txt"); err != nil { 78 | return nil, err 79 | } 80 | 81 | if files.WildcardRoots, err = w.createFile(files.TempDirectory + "/" + "wildcards.txt"); err != nil { 82 | return nil, err 83 | } 84 | 85 | return files, nil 86 | } 87 | 88 | func (w *DefaultWorkfileCreator) createFile(filepath string) (string, error) { 89 | file, err := w.osCreate(filepath) 90 | if err != nil { 91 | return "", fmt.Errorf("unable to create temporary file %s: %w", filepath, err) 92 | } 93 | defer file.Close() 94 | 95 | return file.Name(), nil 96 | } 97 | -------------------------------------------------------------------------------- /internal/usecase/resolve/workfilecreator_test.go: -------------------------------------------------------------------------------- 1 | package resolve 2 | 3 | import ( 4 | "errors" 5 | "os" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | type stubFileCreator struct { 
12 | successes int 13 | returnsOnFailure error 14 | } 15 | 16 | func (s *stubFileCreator) Create(filepath string) (*os.File, error) { 17 | s.successes-- 18 | 19 | if s.successes < 0 { 20 | return nil, s.returnsOnFailure 21 | } 22 | 23 | return os.Create(filepath) 24 | } 25 | 26 | type stubDirCreator struct { 27 | returns error 28 | } 29 | 30 | func (s *stubDirCreator) MkdirTemp(dir string, pattern string) (string, error) { 31 | if s.returns != nil { 32 | return "", s.returns 33 | } 34 | 35 | return os.MkdirTemp("", "") 36 | } 37 | 38 | func TestCreate(t *testing.T) { 39 | createError := errors.New("create failed") 40 | mkDirTempError := errors.New("mkdirtemp failed") 41 | 42 | tests := []struct { 43 | name string 44 | haveMkdirTempError error 45 | haveCreateSuccesses int 46 | haveCreateError error 47 | wantErr error 48 | }{ 49 | {name: "success", haveCreateSuccesses: 100}, 50 | {name: "mkdirtemp error handling", haveMkdirTempError: mkDirTempError, wantErr: mkDirTempError}, 51 | {name: "first create error handling", haveCreateSuccesses: 0, haveCreateError: createError, wantErr: createError}, 52 | {name: "second create error handling", haveCreateSuccesses: 1, haveCreateError: createError, wantErr: createError}, 53 | {name: "third create error handling", haveCreateSuccesses: 2, haveCreateError: createError, wantErr: createError}, 54 | {name: "fourth create error handling", haveCreateSuccesses: 3, haveCreateError: createError, wantErr: createError}, 55 | {name: "fifth create error handling", haveCreateSuccesses: 4, haveCreateError: createError, wantErr: createError}, 56 | {name: "sixth create error handling", haveCreateSuccesses: 5, haveCreateError: createError, wantErr: createError}, 57 | } 58 | 59 | for _, test := range tests { 60 | t.Run(test.name, func(t *testing.T) { 61 | createFile := stubFileCreator{successes: test.haveCreateSuccesses, returnsOnFailure: test.haveCreateError} 62 | createDir := stubDirCreator{returns: test.haveMkdirTempError} 63 | 64 | creator := 
NewDefaultWorkfileCreator() 65 | creator.osCreate = createFile.Create 66 | creator.osMkdirTemp = createDir.MkdirTemp 67 | 68 | gotFiles, gotErr := creator.Create() 69 | 70 | if gotFiles != nil { 71 | defer gotFiles.Close() 72 | } 73 | 74 | assert.ErrorIs(t, gotErr, test.wantErr) 75 | }) 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /internal/usecase/sponsors/service.go: -------------------------------------------------------------------------------- 1 | package sponsors 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net/http" 7 | "os" 8 | ) 9 | 10 | // Service is a Sponsors service. 11 | type Service struct{} 12 | 13 | // NewService creates a new Service. 14 | func NewService() *Service { 15 | return &Service{} 16 | } 17 | 18 | // Show downloads the sponsors text file and sends it to the console. 19 | func (s Service) Show(fileURL string) error { 20 | resp, err := http.Get(fileURL) 21 | if err != nil { 22 | return err 23 | } 24 | defer resp.Body.Close() 25 | 26 | _, err = io.Copy(os.Stdout, resp.Body) 27 | fmt.Println() 28 | 29 | return err 30 | } 31 | -------------------------------------------------------------------------------- /internal/usecase/sponsors/service_test.go: -------------------------------------------------------------------------------- 1 | package sponsors 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/http/httptest" 7 | "os" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestShow_OK(t *testing.T) { 15 | // Restore stdout on exit 16 | stdOut := os.Stdout 17 | defer func() { os.Stdout = stdOut }() 18 | 19 | // Override stdout 20 | r, w, err := os.Pipe() 21 | require.Nil(t, err) 22 | os.Stdout = w 23 | 24 | // Stub http server 25 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 26 | fmt.Fprintf(w, "test") 27 | })) 28 | defer ts.Close() 29 | 30 | service := NewService() 31 | err = 
service.Show(ts.URL) 32 | assert.Nil(t, err) 33 | 34 | buf := make([]byte, 1024) 35 | n, err := r.Read(buf) 36 | require.Nil(t, err) 37 | 38 | assert.Equal(t, "test\n", string(buf[:n])) 39 | } 40 | 41 | func TestShow_HTTPError(t *testing.T) { 42 | service := NewService() 43 | err := service.Show("") 44 | 45 | assert.NotNil(t, err) 46 | } 47 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "github.com/d3mondev/puredns/v2/internal/app/cmd" 8 | "github.com/d3mondev/puredns/v2/internal/app/ctx" 9 | ) 10 | 11 | var exitHandler func(int) = os.Exit 12 | 13 | func main() { 14 | ctx := ctx.NewCtx() 15 | 16 | if err := cmd.Execute(ctx); err != nil { 17 | fmt.Fprintf(os.Stderr, "puredns error: %s\n", err) 18 | exitHandler(1) 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /main_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | type spyExitHandler struct { 11 | count int 12 | lastCode int 13 | } 14 | 15 | func (s *spyExitHandler) Exit(code int) { 16 | s.count++ 17 | s.lastCode = code 18 | } 19 | 20 | func TestMain(t *testing.T) { 21 | spyExit := spyExitHandler{} 22 | 23 | os.Args = []string{os.Args[0], "--version"} 24 | exitHandler = spyExit.Exit 25 | 26 | main() 27 | 28 | assert.Equal(t, 0, spyExit.count) 29 | assert.Equal(t, 0, spyExit.lastCode) 30 | } 31 | 32 | func TestMainError(t *testing.T) { 33 | spyExit := spyExitHandler{} 34 | 35 | os.Args = []string{os.Args[0], "invalid-command"} 36 | exitHandler = spyExit.Exit 37 | 38 | main() 39 | 40 | assert.Equal(t, 1, spyExit.count) 41 | assert.Equal(t, 1, spyExit.lastCode) 42 | } 43 | 
-------------------------------------------------------------------------------- /pkg/fileoperation/appendlines.go: -------------------------------------------------------------------------------- 1 | package fileoperation 2 | 3 | import ( 4 | "os" 5 | ) 6 | 7 | // AppendLines appends all lines to a text file, creating a new file if it doesn't exist. 8 | func AppendLines(lines []string, filename string) error { 9 | file, err := os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) 10 | if err != nil { 11 | return err 12 | } 13 | defer file.Close() 14 | 15 | return writelines(lines, file, 64*1024) 16 | } 17 | -------------------------------------------------------------------------------- /pkg/fileoperation/appendlines_test.go: -------------------------------------------------------------------------------- 1 | package fileoperation 2 | 3 | import ( 4 | "os" 5 | "syscall" 6 | "testing" 7 | 8 | "github.com/d3mondev/puredns/v2/pkg/filetest" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestAppendLines(t *testing.T) { 14 | tests := []struct { 15 | name string 16 | haveFileContent string 17 | haveLines []string 18 | wantContent []string 19 | wantErr bool 20 | }{ 21 | {name: "empty file", haveLines: []string{"foo", "bar"}, wantContent: []string{"foo", "bar"}}, 22 | {name: "file with content", haveFileContent: "one\ntwo\n", haveLines: []string{"foo", "bar"}, wantContent: []string{"one", "two", "foo", "bar"}}, 23 | {name: "no lines", haveFileContent: "one\ntwo\n", wantContent: []string{"one", "two"}}, 24 | } 25 | 26 | for _, test := range tests { 27 | t.Run(test.name, func(t *testing.T) { 28 | dir := filetest.CreateDir(t) 29 | filename := dir + "/testfile.txt" 30 | 31 | if test.haveFileContent != "" { 32 | file, err := os.Create(filename) 33 | require.Nil(t, err) 34 | _, err = file.WriteString(test.haveFileContent) 35 | require.Nil(t, err) 36 | require.Nil(t, file.Sync()) 37 | require.Nil(t, 
file.Close()) 38 | } 39 | 40 | gotErr := AppendLines(test.haveLines, filename) 41 | gotContent := filetest.ReadFile(t, filename) 42 | 43 | assert.Equal(t, test.wantErr, gotErr != nil) 44 | assert.Equal(t, test.wantContent, gotContent) 45 | }) 46 | } 47 | } 48 | 49 | func TestAppendLines_OpenError(t *testing.T) { 50 | dir := filetest.CreateDir(t) 51 | err := AppendLines([]string{}, dir) 52 | assert.ErrorIs(t, err, syscall.Errno(21)) 53 | } 54 | -------------------------------------------------------------------------------- /pkg/fileoperation/appendword.go: -------------------------------------------------------------------------------- 1 | package fileoperation 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "io" 7 | "os" 8 | ) 9 | 10 | // AppendWord appends a separator and a word to each lines present in a text file. 11 | func AppendWord(src string, dest string, sep string, word string) error { 12 | srcFile, err := os.Open(src) 13 | if err != nil { 14 | return err 15 | } 16 | defer srcFile.Close() 17 | 18 | destFile, err := os.Create(dest) 19 | if err != nil { 20 | return err 21 | } 22 | defer destFile.Close() 23 | 24 | return appendWordIO(srcFile, destFile, sep, word) 25 | } 26 | 27 | func appendWordIO(r io.Reader, w io.Writer, sep string, word string) error { 28 | scanner := bufio.NewScanner(r) 29 | writer := bufio.NewWriter(w) 30 | 31 | for scanner.Scan() { 32 | line := scanner.Text() + sep + word 33 | fmt.Fprintln(writer, line) 34 | } 35 | 36 | if err := scanner.Err(); err != nil { 37 | return err 38 | } 39 | 40 | return writer.Flush() 41 | } 42 | -------------------------------------------------------------------------------- /pkg/fileoperation/appendword_test.go: -------------------------------------------------------------------------------- 1 | package fileoperation 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/d3mondev/puredns/v2/pkg/filetest" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestAppendWord(t *testing.T) { 12 | 
srcFile := filetest.CreateFile(t, "A\nB\nC\n") 13 | destFile := filetest.CreateFile(t, "") 14 | 15 | tests := []struct { 16 | name string 17 | haveSource string 18 | haveDest string 19 | wantErr bool 20 | }{ 21 | {name: "ok", haveSource: srcFile.Name(), haveDest: destFile.Name(), wantErr: false}, 22 | {name: "source error handling", haveSource: "thisfiledoesnotexist.txt", haveDest: destFile.Name(), wantErr: true}, 23 | {name: "dest error handling", haveSource: srcFile.Name(), haveDest: "", wantErr: true}, 24 | } 25 | 26 | for _, test := range tests { 27 | t.Run(test.name, func(t *testing.T) { 28 | gotErr := AppendWord(test.haveSource, test.haveDest, ":", "word") 29 | 30 | assert.Equal(t, test.wantErr, gotErr != nil) 31 | }) 32 | } 33 | } 34 | 35 | func TestAppendWordIO(t *testing.T) { 36 | tests := []struct { 37 | name string 38 | haveReader *filetest.StubReader 39 | haveWriter *filetest.StubWriter 40 | haveSep string 41 | haveWord string 42 | wantBuffer []byte 43 | wantErr bool 44 | }{ 45 | {name: "ok", haveReader: filetest.NewStubReader([]byte("A\nB\nC\n"), nil), haveWriter: filetest.NewStubWriter(nil), haveSep: ":", haveWord: "word", wantBuffer: []byte("A:word\nB:word\nC:word\n"), wantErr: false}, 46 | {name: "reader error handling", haveReader: filetest.NewStubReader(nil, errors.New("read error")), haveWriter: filetest.NewStubWriter(nil), haveSep: ":", haveWord: "word", wantErr: true}, 47 | {name: "writer error handling", haveReader: filetest.NewStubReader([]byte("A\nB\nC\n"), nil), haveWriter: filetest.NewStubWriter(errors.New("write error")), haveSep: ":", haveWord: "word", wantErr: true}, 48 | } 49 | 50 | for _, test := range tests { 51 | t.Run(test.name, func(t *testing.T) { 52 | gotErr := appendWordIO(test.haveReader, test.haveWriter, test.haveSep, test.haveWord) 53 | 54 | assert.Equal(t, test.wantErr, gotErr != nil) 55 | 56 | if gotErr == nil { 57 | assert.ElementsMatch(t, test.wantBuffer, test.haveWriter.Buffer) 58 | } 59 | }) 60 | } 61 | } 62 | 
--------------------------------------------------------------------------------
/pkg/fileoperation/cat.go:
--------------------------------------------------------------------------------
package fileoperation

import (
	"bufio"
	"fmt"
	"io"
	"os"
)

// Cat reads files sequentially and sends the output to the specified writer.
// All files are opened before anything is written, so a missing file is reported
// with no output produced; every opened file stays open until Cat returns.
func Cat(filenames []string, w io.Writer) error {
	readers := []io.Reader{}

	for _, name := range filenames {
		file, err := os.Open(name)
		if err != nil {
			return err
		}
		// Deferred inside the loop on purpose: the files must stay open until CatIO has read them.
		defer file.Close()

		readers = append(readers, file)
	}

	return CatIO(readers, w)
}

// CatIO reads sequentially from readers and sends the output to the specified writer.
// Input is processed line by line with bufio.Scanner, so every output line ends with "\n"
// (a missing final newline is added) and lines longer than bufio.MaxScanTokenSize fail.
// It returns the first read or write error encountered.
func CatIO(readers []io.Reader, w io.Writer) error {
	for _, r := range readers {
		scanner := bufio.NewScanner(r)
		for scanner.Scan() {
			line := scanner.Text()
			if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
				return err
			}
		}

		if err := scanner.Err(); err != nil {
			return err
		}
	}

	return nil
}

--------------------------------------------------------------------------------
/pkg/fileoperation/cat_test.go:
--------------------------------------------------------------------------------
package fileoperation

import (
	"errors"
	"io"
	"strings"
	"testing"
	"testing/iotest"

	"github.com/d3mondev/puredns/v2/pkg/filetest"
	"github.com/stretchr/testify/assert"
)

// TestCat checks concatenation of one and two files and the missing-file error.
func TestCat(t *testing.T) {
	testFileA := filetest.CreateFile(t, "contentA")
	testFileB := filetest.CreateFile(t, "contentB")

	tests := []struct {
		name          string
		haveFilenames []string
		wantBuffer    string
		wantErr       bool
	}{
		{name: "cat single file", haveFilenames: []string{testFileA.Name()}, wantBuffer: "contentA\n", wantErr: false},
		{name: "cat two files", haveFilenames: []string{testFileA.Name(), testFileB.Name()}, wantBuffer: "contentA\ncontentB\n", wantErr: false},
		{name: "file error handling", haveFilenames: []string{"thisfiledoesnotexist.txt"}, wantErr: true},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			stubWriter := filetest.NewStubWriter(nil)

			err := Cat(test.haveFilenames, stubWriter)

			assert.Equal(t, test.wantErr, err != nil)
			assert.Equal(t, test.wantBuffer, string(stubWriter.Buffer))
		})
	}
}

// TestCatIO checks the happy path and propagation of reader and writer errors.
func TestCatIO(t *testing.T) {
	tests := []struct {
		name       string
		haveReader io.Reader
		haveWriter *filetest.StubWriter
		wantBuffer []byte
		wantErr    bool
	}{
		{name: "ok", haveReader: strings.NewReader("test\nfile\n"), haveWriter: filetest.NewStubWriter(nil), wantBuffer: []byte("test\nfile\n"), wantErr: false},
		{name: "read error handling", haveReader: iotest.ErrReader(errors.New("reader error")), haveWriter: filetest.NewStubWriter(nil), wantBuffer: []byte{}, wantErr: true},
		{name: "write error handling", haveReader: strings.NewReader("test\nfile\n"), haveWriter: filetest.NewStubWriter(errors.New("write error")), wantBuffer: []byte{}, wantErr: true},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			gotErr := CatIO([]io.Reader{test.haveReader}, test.haveWriter)

			assert.ElementsMatch(t, test.wantBuffer, test.haveWriter.Buffer)
			assert.Equal(t, test.wantErr, gotErr != nil)
		})
	}
}

--------------------------------------------------------------------------------
/pkg/fileoperation/copy.go:
--------------------------------------------------------------------------------
package fileoperation

import (
	"bufio"
	"io"
	"os"
)

// Copy copies a file from the source filename to the destination filename.
10 | func Copy(src string, dest string) error { 11 | source, err := os.Open(src) 12 | if err != nil { 13 | return err 14 | } 15 | defer source.Close() 16 | 17 | destination, err := os.Create(dest) 18 | if err != nil { 19 | return err 20 | } 21 | defer destination.Close() 22 | 23 | return copyByBuffer(source, destination, 64*1024) 24 | } 25 | 26 | func copyByBuffer(r io.Reader, w io.Writer, bufferSize int) error { 27 | buf := make([]byte, bufferSize) 28 | writer := bufio.NewWriterSize(w, bufferSize) 29 | 30 | for { 31 | n, err := r.Read(buf) 32 | 33 | if err != nil && err != io.EOF { 34 | return err 35 | } 36 | 37 | if n == 0 { 38 | break 39 | } 40 | 41 | if _, err := writer.Write(buf[:n]); err != nil { 42 | return err 43 | } 44 | } 45 | 46 | return writer.Flush() 47 | } 48 | -------------------------------------------------------------------------------- /pkg/fileoperation/copy_test.go: -------------------------------------------------------------------------------- 1 | package fileoperation 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/d3mondev/puredns/v2/pkg/filetest" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestCopy(t *testing.T) { 12 | srcFile := filetest.CreateFile(t, "file content") 13 | destFile := filetest.CreateFile(t, "") 14 | 15 | tests := []struct { 16 | name string 17 | haveSrc string 18 | haveDest string 19 | wantErr bool 20 | }{ 21 | {name: "ok", haveSrc: srcFile.Name(), haveDest: destFile.Name(), wantErr: false}, 22 | {name: "source file error handling", haveSrc: "thisfiledoesnotexist.txt", haveDest: destFile.Name(), wantErr: true}, 23 | {name: "dest file error handling", haveSrc: srcFile.Name(), haveDest: "", wantErr: true}, 24 | } 25 | 26 | for _, test := range tests { 27 | t.Run(test.name, func(t *testing.T) { 28 | gotErr := Copy(test.haveSrc, test.haveDest) 29 | 30 | assert.Equal(t, test.wantErr, gotErr != nil) 31 | }) 32 | } 33 | } 34 | 35 | func TestCopyByBuffer(t *testing.T) { 36 | data := "this is test 
data" 37 | 38 | tests := []struct { 39 | name string 40 | haveReader *filetest.StubReader 41 | haveWriter *filetest.StubWriter 42 | haveBuffer int 43 | wantErr bool 44 | }{ 45 | {name: "ok", haveReader: filetest.NewStubReader([]byte(data), nil), haveWriter: filetest.NewStubWriter(nil), haveBuffer: 32768, wantErr: false}, 46 | {name: "small buffer", haveReader: filetest.NewStubReader([]byte(data), nil), haveWriter: filetest.NewStubWriter(nil), haveBuffer: 1, wantErr: false}, 47 | {name: "read error handling", haveReader: filetest.NewStubReader(nil, errors.New("read error")), haveWriter: filetest.NewStubWriter(nil), haveBuffer: 32768, wantErr: true}, 48 | {name: "write error handling", haveReader: filetest.NewStubReader([]byte(data), nil), haveWriter: filetest.NewStubWriter(errors.New("write error")), haveBuffer: 1, wantErr: true}, 49 | } 50 | 51 | for _, test := range tests { 52 | t.Run(test.name, func(t *testing.T) { 53 | gotErr := copyByBuffer(test.haveReader, test.haveWriter, test.haveBuffer) 54 | 55 | assert.Equal(t, test.wantErr, gotErr != nil) 56 | 57 | if gotErr == nil { 58 | assert.ElementsMatch(t, test.haveReader.Buffer, test.haveWriter.Buffer) 59 | } 60 | }) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /pkg/fileoperation/countlines.go: -------------------------------------------------------------------------------- 1 | package fileoperation 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "os" 7 | ) 8 | 9 | // CountLines counts the number of lines in a file. 
10 | func CountLines(filename string) (int, error) { 11 | file, err := os.Open(filename) 12 | if err != nil { 13 | return 0, err 14 | } 15 | defer file.Close() 16 | 17 | return count(file) 18 | } 19 | 20 | func count(reader io.Reader) (int, error) { 21 | buf := make([]byte, 64*1024) 22 | count := 0 23 | sep := []byte{'\n'} 24 | 25 | for { 26 | c, err := reader.Read(buf) 27 | count += bytes.Count(buf[:c], sep) 28 | 29 | if err == io.EOF { 30 | return count, nil 31 | } 32 | 33 | if err != nil { 34 | return 0, err 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /pkg/fileoperation/countlines_test.go: -------------------------------------------------------------------------------- 1 | package fileoperation 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "os" 7 | "testing" 8 | 9 | "github.com/d3mondev/puredns/v2/pkg/filetest" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestCountLines(t *testing.T) { 14 | file, err := os.CreateTemp("", "") 15 | if err != nil { 16 | t.Fatal() 17 | } 18 | defer func() { file.Close(); os.Remove(file.Name()) }() 19 | 20 | _, err = file.WriteString("line1\nline2\nline3\n") 21 | if err != nil { 22 | t.Fatal() 23 | } 24 | 25 | tests := []struct { 26 | name string 27 | haveFilename string 28 | wantCount int 29 | wantErr bool 30 | }{ 31 | {name: "existing file", haveFilename: file.Name(), wantCount: 3}, 32 | {name: "file error handling", haveFilename: "thisfiledoesnotexist.txt", wantErr: true}, 33 | } 34 | 35 | for _, test := range tests { 36 | t.Run(test.name, func(t *testing.T) { 37 | gotCount, gotErr := CountLines(test.haveFilename) 38 | 39 | assert.Equal(t, test.wantCount, gotCount) 40 | assert.Equal(t, test.wantErr, gotErr != nil) 41 | }) 42 | } 43 | } 44 | 45 | func TestCount(t *testing.T) { 46 | lines := "line1\nline2\n" 47 | readerError := errors.New("reader error") 48 | 49 | tests := []struct { 50 | name string 51 | haveReader io.ReadCloser 52 | wantCount int 53 | 
wantErr error 54 | }{ 55 | {name: "success", haveReader: filetest.NewStubReader([]byte(lines), nil), wantCount: 2, wantErr: nil}, 56 | {name: "empty reader", haveReader: filetest.NewStubReader(nil, nil), wantCount: 0, wantErr: nil}, 57 | {name: "reader error handling", haveReader: filetest.NewStubReader(nil, readerError), wantCount: 0, wantErr: readerError}, 58 | } 59 | 60 | for _, test := range tests { 61 | t.Run(test.name, func(t *testing.T) { 62 | gotCount, gotErr := count(test.haveReader) 63 | 64 | assert.ErrorIs(t, gotErr, test.wantErr) 65 | assert.Equal(t, test.wantCount, gotCount) 66 | }) 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /pkg/fileoperation/doc.go: -------------------------------------------------------------------------------- 1 | // Package fileoperation provides a set of functions commonly used on files. 2 | package fileoperation 3 | -------------------------------------------------------------------------------- /pkg/fileoperation/fileexists.go: -------------------------------------------------------------------------------- 1 | package fileoperation 2 | 3 | import ( 4 | "os" 5 | ) 6 | 7 | // FileExists returns true if a file exists on disk. 
8 | func FileExists(path string) bool { 9 | _, err := os.Stat(path) 10 | if os.IsNotExist(err) { 11 | return false 12 | } 13 | return err == nil 14 | } 15 | -------------------------------------------------------------------------------- /pkg/fileoperation/fileexists_test.go: -------------------------------------------------------------------------------- 1 | package fileoperation 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestFileExists(t *testing.T) { 11 | file, err := os.CreateTemp("", "") 12 | if err != nil { 13 | t.Fatal() 14 | } 15 | defer func() { file.Close(); os.Remove(file.Name()) }() 16 | 17 | tests := []struct { 18 | name string 19 | haveFilename string 20 | want bool 21 | }{ 22 | {name: "existing file", haveFilename: file.Name(), want: true}, 23 | {name: "non-existing file", haveFilename: "thisfiledoesnotexist.txt", want: false}, 24 | } 25 | 26 | for _, test := range tests { 27 | t.Run(test.name, func(t *testing.T) { 28 | got := FileExists(test.haveFilename) 29 | 30 | assert.Equal(t, test.want, got) 31 | }) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /pkg/fileoperation/readlines.go: -------------------------------------------------------------------------------- 1 | package fileoperation 2 | 3 | import ( 4 | "bufio" 5 | "io" 6 | "os" 7 | ) 8 | 9 | // ReadLines read lines from a text file and returns them as a slice. 
10 | func ReadLines(filename string) ([]string, error) { 11 | file, err := os.Open(filename) 12 | if err != nil { 13 | return nil, err 14 | } 15 | defer file.Close() 16 | 17 | return readlines(file) 18 | } 19 | 20 | func readlines(r io.Reader) ([]string, error) { 21 | var lines []string 22 | scanner := bufio.NewScanner(r) 23 | 24 | for scanner.Scan() { 25 | lines = append(lines, scanner.Text()) 26 | } 27 | 28 | return lines, scanner.Err() 29 | } 30 | -------------------------------------------------------------------------------- /pkg/fileoperation/readlines_test.go: -------------------------------------------------------------------------------- 1 | package fileoperation 2 | 3 | import ( 4 | "io/fs" 5 | "testing" 6 | 7 | "github.com/d3mondev/puredns/v2/pkg/filetest" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestReadLines(t *testing.T) { 12 | file := filetest.CreateFile(t, "foo\nbar") 13 | 14 | lines, err := ReadLines(file.Name()) 15 | assert.Nil(t, err) 16 | assert.Equal(t, []string{"foo", "bar"}, lines) 17 | } 18 | 19 | func TestReadLines_FileNotFound(t *testing.T) { 20 | _, err := ReadLines("thisfiledoesnotexist.txt") 21 | assert.ErrorIs(t, err, fs.ErrNotExist) 22 | } 23 | -------------------------------------------------------------------------------- /pkg/fileoperation/writelines.go: -------------------------------------------------------------------------------- 1 | package fileoperation 2 | 3 | import ( 4 | "bufio" 5 | "io" 6 | "os" 7 | ) 8 | 9 | // WriteLines writes all lines to a text file, truncating the file if it already exists. 
10 | func WriteLines(lines []string, filename string) error { 11 | file, err := os.Create(filename) 12 | if err != nil { 13 | return err 14 | } 15 | defer file.Close() 16 | 17 | return writelines(lines, file, 64*1024) 18 | } 19 | 20 | func writelines(lines []string, w io.Writer, bufferSize int) error { 21 | writer := bufio.NewWriterSize(w, bufferSize) 22 | 23 | for _, line := range lines { 24 | if _, err := writer.Write([]byte(line + "\n")); err != nil { 25 | return err 26 | } 27 | } 28 | 29 | return writer.Flush() 30 | } 31 | -------------------------------------------------------------------------------- /pkg/fileoperation/writelines_test.go: -------------------------------------------------------------------------------- 1 | package fileoperation 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/d3mondev/puredns/v2/pkg/filetest" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestWriteLines(t *testing.T) { 12 | tests := []struct { 13 | name string 14 | haveLines []string 15 | haveWriter *filetest.StubWriter 16 | haveBuffer int 17 | wantBuffer []byte 18 | wantErr bool 19 | }{ 20 | {name: "single line", haveLines: []string{"foo"}, haveWriter: filetest.NewStubWriter(nil), haveBuffer: 1024, wantBuffer: []byte("foo\n"), wantErr: false}, 21 | {name: "multiple lines", haveLines: []string{"foo", "bar"}, haveWriter: filetest.NewStubWriter(nil), haveBuffer: 1024, wantBuffer: []byte("foo\nbar\n"), wantErr: false}, 22 | {name: "no lines", haveLines: []string{}, haveWriter: filetest.NewStubWriter(nil), haveBuffer: 1024, wantBuffer: nil, wantErr: false}, 23 | {name: "write error handling", haveLines: []string{"foo"}, haveWriter: filetest.NewStubWriter(errors.New("write error")), haveBuffer: 1, wantBuffer: nil, wantErr: true}, 24 | } 25 | 26 | for _, test := range tests { 27 | t.Run(test.name, func(t *testing.T) { 28 | gotErr := writelines(test.haveLines, test.haveWriter, test.haveBuffer) 29 | 30 | assert.Equal(t, test.wantErr, gotErr != nil) 31 | 
assert.Equal(t, test.wantBuffer, test.haveWriter.Buffer) 32 | }) 33 | } 34 | } 35 | 36 | func TestWriteLinesFileError(t *testing.T) { 37 | file := filetest.CreateFile(t, "") 38 | 39 | tests := []struct { 40 | name string 41 | haveFilename string 42 | haveLines []string 43 | wantErr bool 44 | }{ 45 | {name: "valid output file", haveFilename: file.Name(), wantErr: false}, 46 | {name: "file error handling", haveFilename: "", wantErr: true}, 47 | } 48 | 49 | for _, test := range tests { 50 | t.Run(test.name, func(t *testing.T) { 51 | gotErr := WriteLines([]string{"foo", "bar"}, test.haveFilename) 52 | 53 | assert.Equal(t, test.wantErr, gotErr != nil) 54 | }) 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /pkg/filetest/doc.go: -------------------------------------------------------------------------------- 1 | // Package filetest provides utility functions that can be used during unit testing. 2 | // Any resource created by the package during a test is cleaned up automatically when the test is done. 3 | package filetest 4 | -------------------------------------------------------------------------------- /pkg/filetest/file.go: -------------------------------------------------------------------------------- 1 | package filetest 2 | 3 | import ( 4 | "bufio" 5 | "io/ioutil" 6 | "os" 7 | "testing" 8 | ) 9 | 10 | // CreateFile creates a temporary file with the content specified used during the test. 11 | // The file is closed and deleted after the test is done running. 
func CreateFile(t *testing.T, content string) *os.File {
	t.Helper()

	file, err := ioutil.TempFile("", "")
	if err != nil {
		t.Fatal(err)
	}

	// Register cleanup right away so the file is removed even if a later step
	// fails the test.
	t.Cleanup(func() {
		file.Close()
		os.Remove(file.Name())
	})

	if _, err := file.WriteString(content); err != nil {
		t.Fatal(err)
	}

	// Rewind so callers can read the content back through the same handle.
	if _, err := file.Seek(0, 0); err != nil {
		t.Fatal(err)
	}

	return file
}

// CreateDir creates a temporary directory.
// The directory and everything created inside it are deleted after the test
// is done running.
func CreateDir(t *testing.T) string {
	t.Helper()

	dir, err := ioutil.TempDir("", "")
	if err != nil {
		t.Fatal(err)
	}

	// RemoveAll rather than Remove: os.Remove fails on a non-empty directory,
	// which would leak anything the test wrote into it.
	t.Cleanup(func() {
		os.RemoveAll(dir)
	})

	return dir
}

// ReadFile reads a text file and returns each line in a slice.
// If the file name is empty, returns an empty slice.
// The test fails immediately if the file cannot be opened or read.
func ReadFile(t *testing.T, name string) []string {
	t.Helper()

	lines := []string{}

	if name == "" {
		return lines
	}

	file, err := os.Open(name)
	if err != nil {
		t.Fatal(err)
	}
	defer file.Close()

	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		lines = append(lines, scanner.Text())
	}

	if err := scanner.Err(); err != nil {
		t.Fatal(err)
	}

	return lines
}

// ClearFile truncates the content of a file.
77 | func ClearFile(t *testing.T, file *os.File) { 78 | if err := file.Truncate(0); err != nil { 79 | t.Fatal(err) 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /pkg/filetest/file_test.go: -------------------------------------------------------------------------------- 1 | package filetest 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestCreateFile_Empty(t *testing.T) { 10 | file := CreateFile(t, "") 11 | lines := ReadFile(t, file.Name()) 12 | 13 | assert.NotNil(t, file) 14 | assert.Equal(t, []string{}, lines) 15 | } 16 | 17 | func TestCreateFile_Content(t *testing.T) { 18 | file := CreateFile(t, "foo\nbar") 19 | lines := ReadFile(t, file.Name()) 20 | 21 | assert.NotNil(t, file) 22 | assert.Equal(t, []string{"foo", "bar"}, lines) 23 | } 24 | 25 | func TestCreateDir(t *testing.T) { 26 | dir := CreateDir(t) 27 | assert.NotEqual(t, "", dir) 28 | } 29 | 30 | func TestReadFile_OK(t *testing.T) { 31 | file := CreateFile(t, "line1\nline2\nline3") 32 | lines := ReadFile(t, file.Name()) 33 | assert.Equal(t, []string{"line1", "line2", "line3"}, lines) 34 | } 35 | 36 | func TestReadFile_Empty(t *testing.T) { 37 | lines := ReadFile(t, "") 38 | assert.Equal(t, []string{}, lines) 39 | } 40 | 41 | func TestClearFile(t *testing.T) { 42 | file := CreateFile(t, "foo\nbar") 43 | 44 | ClearFile(t, file) 45 | 46 | lines := ReadFile(t, file.Name()) 47 | assert.Equal(t, []string{}, lines) 48 | } 49 | -------------------------------------------------------------------------------- /pkg/filetest/stdin.go: -------------------------------------------------------------------------------- 1 | package filetest 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | ) 7 | 8 | // OverrideStdin overrides os.Stdin with the specified file and restores it after the test. 
9 | func OverrideStdin(t *testing.T, f *os.File) { 10 | old := os.Stdin 11 | os.Stdin = f 12 | 13 | t.Cleanup(func() { 14 | os.Stdin = old 15 | }) 16 | } 17 | -------------------------------------------------------------------------------- /pkg/filetest/stubreader.go: -------------------------------------------------------------------------------- 1 | package filetest 2 | 3 | import "io" 4 | 5 | // StubReader implements the io.ReadCloser interface. 6 | // It is used during tests to return the content of its internal buffer, 7 | // or to return an error as needed. 8 | type StubReader struct { 9 | Buffer []byte 10 | Index int 11 | Err error 12 | } 13 | 14 | // NewStubReader returns a new StubReader object that reads from the buffer specified. 15 | // If err is not nil, it will return an error on the first read. If the end of the buffer is 16 | // reached, it will return an io.EOF error. 17 | func NewStubReader(buffer []byte, err error) *StubReader { 18 | return &StubReader{ 19 | Buffer: buffer, 20 | Err: err, 21 | } 22 | } 23 | 24 | // Read implements the io.Reader interface. 25 | func (r *StubReader) Read(p []byte) (int, error) { 26 | if r.Err != nil { 27 | return 0, r.Err 28 | } 29 | 30 | count := copy(p, r.Buffer[r.Index:]) 31 | r.Index += count 32 | 33 | if r.Err == nil && r.Index >= len(r.Buffer) { 34 | r.Err = io.EOF 35 | } 36 | 37 | return count, r.Err 38 | } 39 | 40 | // Close implements the io.Closer interface. 
41 | func (r *StubReader) Close() error { 42 | return nil 43 | } 44 | -------------------------------------------------------------------------------- /pkg/filetest/stubreader_test.go: -------------------------------------------------------------------------------- 1 | package filetest 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestStubReaderRead(t *testing.T) { 12 | t.Run("read full buffer", func(t *testing.T) { 13 | buffer := []byte("foo") 14 | r := NewStubReader(buffer, nil) 15 | 16 | output := make([]byte, len(buffer)) 17 | count, err := r.Read(output) 18 | 19 | assert.Equal(t, err, io.EOF) 20 | assert.Equal(t, len(output), count) 21 | assert.Equal(t, buffer, output) 22 | }) 23 | 24 | t.Run("read part of buffer", func(t *testing.T) { 25 | buffer := []byte("foo") 26 | r := NewStubReader(buffer, nil) 27 | 28 | output := make([]byte, 1) 29 | count, err := r.Read(output) 30 | 31 | assert.Nil(t, err) 32 | assert.Equal(t, len(output), count) 33 | assert.Equal(t, []byte("f"), output) 34 | }) 35 | 36 | t.Run("generate read error", func(t *testing.T) { 37 | buffer := []byte("foo") 38 | readErr := errors.New("read error") 39 | r := NewStubReader(buffer, readErr) 40 | 41 | output := make([]byte, 1) 42 | count, err := r.Read(output) 43 | 44 | assert.Equal(t, readErr, err) 45 | assert.Equal(t, 0, count) 46 | assert.Equal(t, []byte{0x0}, output) 47 | }) 48 | } 49 | 50 | func TestStubReaderClose(t *testing.T) { 51 | r := NewStubReader(nil, nil) 52 | err := r.Close() 53 | assert.Nil(t, err) 54 | } 55 | -------------------------------------------------------------------------------- /pkg/filetest/stubwriter.go: -------------------------------------------------------------------------------- 1 | package filetest 2 | 3 | import "io" 4 | 5 | // StubWriter implements the io.Writer interface. 
6 | // It is used during tests to examine the content of the data written, 7 | // and to return fake errors as needed. 8 | type StubWriter struct { 9 | Buffer []byte 10 | Count int 11 | Err error 12 | } 13 | 14 | var _ io.WriteCloser = (*StubWriter)(nil) 15 | 16 | // NewStubWriter returns a new StubWriter. 17 | func NewStubWriter(err error) *StubWriter { 18 | return &StubWriter{ 19 | Err: err, 20 | } 21 | } 22 | 23 | // Write implements the io.Writer interface. 24 | func (w *StubWriter) Write(p []byte) (n int, err error) { 25 | if w.Err != nil { 26 | return 0, w.Err 27 | } 28 | 29 | w.Buffer = append(w.Buffer, p...) 30 | w.Count += len(p) 31 | 32 | return len(p), nil 33 | } 34 | 35 | // Close implements the io.Closer interface. 36 | func (w *StubWriter) Close() error { 37 | return nil 38 | } 39 | -------------------------------------------------------------------------------- /pkg/filetest/stubwriter_test.go: -------------------------------------------------------------------------------- 1 | package filetest 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestStubWriter(t *testing.T) { 11 | t.Run("internal buffer correctly updated", func(t *testing.T) { 12 | w := NewStubWriter(nil) 13 | 14 | buf := []byte("test") 15 | n, err := w.Write(buf) 16 | 17 | assert.Nil(t, err) 18 | assert.Equal(t, len(buf), n) 19 | assert.Equal(t, buf, w.Buffer) 20 | }) 21 | 22 | t.Run("generate write error", func(t *testing.T) { 23 | wantErr := errors.New("error") 24 | w := NewStubWriter(wantErr) 25 | 26 | buf := []byte("test") 27 | n, err := w.Write(buf) 28 | 29 | assert.Equal(t, wantErr, err) 30 | assert.Equal(t, 0, n) 31 | assert.Equal(t, []byte(nil), w.Buffer) 32 | }) 33 | } 34 | -------------------------------------------------------------------------------- /pkg/massdns/callback.go: -------------------------------------------------------------------------------- 1 | package massdns 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 
"strings" 7 | ) 8 | 9 | type state int 10 | 11 | const ( 12 | stateNewAnswerSection state = iota 13 | stateSaveAnswer 14 | stateSkip 15 | ) 16 | 17 | // DefaultWriteCallback is a callback that can save massdns results to files. 18 | // It can save the valid domains found and the massdns results that gave valid domains. 19 | type DefaultWriteCallback struct { 20 | massdnsFile *os.File 21 | domainFile *os.File 22 | 23 | curState state 24 | curDomain string 25 | domainSaved bool 26 | 27 | found int 28 | } 29 | 30 | var _ Callback = (*DefaultWriteCallback)(nil) 31 | 32 | // NewDefaultWriteCallback creates a new DefaultWriteCallback. 33 | // The file names can be empty to disable saving to a file. 34 | func NewDefaultWriteCallback(massdnsFilename string, domainFilename string) (*DefaultWriteCallback, error) { 35 | cb := &DefaultWriteCallback{} 36 | 37 | // Create a writer that writes massdns answers 38 | if massdnsFilename != "" { 39 | file, err := os.Create(massdnsFilename) 40 | if err != nil { 41 | return nil, err 42 | } 43 | 44 | cb.massdnsFile = file 45 | } 46 | 47 | // Create a writer that writes valid domains found 48 | if domainFilename != "" { 49 | file, err := os.Create(domainFilename) 50 | if err != nil { 51 | return nil, err 52 | } 53 | 54 | cb.domainFile = file 55 | } 56 | 57 | return cb, nil 58 | } 59 | 60 | // Callback reads a line from the massdns stdout handler, parses the output and 61 | // saves the relevant data. 
62 | func (c *DefaultWriteCallback) Callback(line string) error { 63 | // Don't parse JSON if we're not saving anything 64 | if c.domainFile == nil && c.massdnsFile == nil { 65 | return nil 66 | } 67 | 68 | // If we receive an empty line, it's the start of a new answer 69 | if line == "" { 70 | c.curState = stateNewAnswerSection 71 | return nil 72 | } 73 | 74 | switch c.curState { 75 | // We're at the beginning of a new answer section, look for the domain name 76 | case stateNewAnswerSection: 77 | parts := strings.Split(line, " ") 78 | if len(parts) != 3 { 79 | c.curState = stateSkip 80 | return nil 81 | } 82 | 83 | domain := strings.TrimSuffix(parts[0], ".") 84 | if domain == "" { 85 | c.curState = stateSkip 86 | return nil 87 | } 88 | 89 | c.curDomain = domain 90 | c.curState = stateSaveAnswer 91 | c.domainSaved = false 92 | fallthrough 93 | 94 | // Save the answer record found 95 | case stateSaveAnswer: 96 | parts := strings.Split(line, " ") 97 | if len(parts) != 3 { 98 | c.curState = stateSkip 99 | return nil 100 | } 101 | 102 | domain := c.curDomain 103 | rrType := parts[1] 104 | answer := strings.TrimSuffix(parts[2], ".") 105 | 106 | // Only look for A, AAAA, and CNAME records 107 | if rrType != "A" && rrType != "AAAA" && rrType != "CNAME" { 108 | return nil 109 | } 110 | 111 | // If we haven't saved the domain yet, save it 112 | if !c.domainSaved { 113 | c.saveDomain(c.curDomain) 114 | c.domainSaved = true 115 | c.found++ 116 | } 117 | 118 | // Valid record found, save it 119 | return c.saveLine(fmt.Sprintf("%s %s %s", domain, rrType, answer)) 120 | 121 | // Answer was invalid, skip until we receive a new answer section 122 | case stateSkip: 123 | return nil 124 | } 125 | 126 | return nil 127 | } 128 | 129 | // saveLine saves a line to the massdns file. 
130 | func (c *DefaultWriteCallback) saveLine(line string) error { 131 | if c.massdnsFile != nil { 132 | _, err := c.massdnsFile.WriteString(line + "\n") 133 | return err 134 | } 135 | 136 | return nil 137 | } 138 | 139 | // saveDomain saves a domain to the domain file. 140 | func (c *DefaultWriteCallback) saveDomain(domain string) error { 141 | if c.domainFile != nil { 142 | _, err := c.domainFile.WriteString(domain + "\n") 143 | return err 144 | } 145 | 146 | return nil 147 | } 148 | 149 | // Close closes the writers. 150 | func (c *DefaultWriteCallback) Close() { 151 | if c.massdnsFile != nil { 152 | c.massdnsFile.Sync() 153 | c.massdnsFile.Close() 154 | } 155 | 156 | if c.domainFile != nil { 157 | c.domainFile.Sync() 158 | c.domainFile.Close() 159 | } 160 | } 161 | -------------------------------------------------------------------------------- /pkg/massdns/callback_test.go: -------------------------------------------------------------------------------- 1 | package massdns 2 | 3 | import ( 4 | "syscall" 5 | "testing" 6 | 7 | "github.com/d3mondev/puredns/v2/pkg/filetest" 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestNewDefaultWriteCallback(t *testing.T) { 13 | domainFile := filetest.CreateFile(t, "") 14 | massdnsFile := filetest.CreateFile(t, "") 15 | badFile := filetest.CreateDir(t) 16 | 17 | tests := []struct { 18 | name string 19 | haveMassdnsFile string 20 | haveDomainFile string 21 | want error 22 | }{ 23 | {name: "ok", haveMassdnsFile: massdnsFile.Name(), haveDomainFile: domainFile.Name()}, 24 | {name: "massdns file error handling", haveMassdnsFile: badFile, want: syscall.Errno(21)}, 25 | {name: "domain file error handling", haveDomainFile: badFile, want: syscall.Errno(21)}, 26 | } 27 | 28 | for _, test := range tests { 29 | t.Run(test.name, func(t *testing.T) { 30 | _, err := NewDefaultWriteCallback(test.haveMassdnsFile, test.haveDomainFile) 31 | assert.ErrorIs(t, err, test.want) 32 | }) 33 | 
} 34 | } 35 | 36 | func TestDefaultWriteCallback(t *testing.T) { 37 | massdnsFile := filetest.CreateFile(t, "") 38 | domainFile := filetest.CreateFile(t, "") 39 | 40 | tests := []struct { 41 | name string 42 | haveLines []string 43 | wantMassdns []string 44 | wantDomain []string 45 | wantErr error 46 | }{ 47 | { 48 | name: "single record", 49 | haveLines: []string{ 50 | "example.com. A 127.0.0.1", 51 | }, 52 | wantMassdns: []string{ 53 | "example.com A 127.0.0.1", 54 | }, 55 | wantDomain: []string{ 56 | "example.com", 57 | }, 58 | }, 59 | { 60 | name: "multiple record", 61 | haveLines: []string{ 62 | "www.example.com. CNAME example.com.", 63 | "example.com. A 127.0.0.1", 64 | "example.com. AAAA ::1", 65 | }, 66 | wantMassdns: []string{ 67 | "www.example.com CNAME example.com", 68 | "www.example.com A 127.0.0.1", 69 | "www.example.com AAAA ::1", 70 | }, 71 | wantDomain: []string{ 72 | "www.example.com", 73 | }, 74 | }, 75 | { 76 | name: "invalid record type", 77 | haveLines: []string{ 78 | "example.com. NS ns.example.com.", 79 | }, 80 | wantMassdns: []string{}, 81 | wantDomain: []string{}, 82 | }, 83 | { 84 | name: "save domain after valid record is found", 85 | haveLines: []string{ 86 | "example.com. NS ns.example.com.", 87 | "example.com. AAAA ::1", 88 | }, 89 | wantMassdns: []string{ 90 | "example.com AAAA ::1", 91 | }, 92 | wantDomain: []string{ 93 | "example.com", 94 | }, 95 | }, 96 | { 97 | name: "multiple answer sections", 98 | haveLines: []string{ 99 | "", 100 | "example.com. A 127.0.0.1", 101 | "", 102 | "", 103 | "www.test.com. CNAME test.com.", 104 | "test.com. A 127.0.0.1", 105 | "test.com. 
AAAA ::1", 106 | "", 107 | }, 108 | wantMassdns: []string{ 109 | "example.com A 127.0.0.1", 110 | "www.test.com CNAME test.com", 111 | "www.test.com A 127.0.0.1", 112 | "www.test.com AAAA ::1", 113 | }, 114 | wantDomain: []string{ 115 | "example.com", 116 | "www.test.com", 117 | }, 118 | }, 119 | { 120 | name: "skip answer section containing bad data", 121 | haveLines: []string{ 122 | "garbage", 123 | "example.com. A 127.0.0.1", 124 | }, 125 | wantMassdns: []string{}, 126 | wantDomain: []string{}, 127 | }, 128 | { 129 | name: "empty domain", 130 | haveLines: []string{ 131 | ". A 127.0.0.1", 132 | }, 133 | wantMassdns: []string{}, 134 | wantDomain: []string{}, 135 | }, 136 | } 137 | 138 | for _, test := range tests { 139 | t.Run(test.name, func(t *testing.T) { 140 | cb, err := NewDefaultWriteCallback(massdnsFile.Name(), domainFile.Name()) 141 | require.Nil(t, err) 142 | 143 | for _, line := range test.haveLines { 144 | err := cb.Callback(line) 145 | assert.ErrorIs(t, err, test.wantErr) 146 | 147 | if err != nil { 148 | break 149 | } 150 | } 151 | 152 | gotMassdns := filetest.ReadFile(t, massdnsFile.Name()) 153 | gotDomain := filetest.ReadFile(t, domainFile.Name()) 154 | 155 | assert.Equal(t, test.wantMassdns, gotMassdns) 156 | assert.Equal(t, test.wantDomain, gotDomain) 157 | }) 158 | } 159 | } 160 | 161 | func TestDefaultWriteCallback_NoWriter(t *testing.T) { 162 | cb, err := NewDefaultWriteCallback("", "") 163 | require.Nil(t, err) 164 | 165 | gotErr := cb.Callback("") 166 | assert.Nil(t, gotErr) 167 | } 168 | 169 | func TestDefaultWriteCallbackClose(t *testing.T) { 170 | massdnsFile := filetest.CreateFile(t, "") 171 | domainFile := filetest.CreateFile(t, "") 172 | 173 | cb, err := NewDefaultWriteCallback(massdnsFile.Name(), domainFile.Name()) 174 | assert.Nil(t, err) 175 | 176 | cb.Close() 177 | 178 | assert.Equal(t, uintptr(0xffffffffffffffff), cb.massdnsFile.Fd()) 179 | assert.Equal(t, uintptr(0xffffffffffffffff), cb.domainFile.Fd()) 180 | } 181 | 
-------------------------------------------------------------------------------- /pkg/massdns/doc.go: -------------------------------------------------------------------------------- 1 | // Package massdns provides a Resolver object used to invoke the massdns binary file. 2 | // 3 | // The package contains a LineReader struct that implements the io.Reader interface. It is used 4 | // to read strings line by line from an io.Reader while throttling the results according to a 5 | // rate-limit specified. The LineReader is passed to the stdin of massdns, allowing it to 6 | // approximately respect the number of DNS queries per second wanted. 7 | package massdns 8 | -------------------------------------------------------------------------------- /pkg/massdns/linereader.go: -------------------------------------------------------------------------------- 1 | package massdns 2 | 3 | import ( 4 | "bufio" 5 | "errors" 6 | "io" 7 | "math" 8 | "sync/atomic" 9 | "time" 10 | ) 11 | 12 | // ErrNotStarted is an error happening when the LineReader hasn't been started. 13 | var ErrNotStarted = errors.New("not started") 14 | 15 | // LineReader is a line reader that limits the number of line per second read. 16 | type LineReader struct { 17 | now func() time.Time 18 | since func(time.Time) time.Duration 19 | 20 | reader io.Reader 21 | readerBuffer *bufio.Reader 22 | 23 | rate float64 24 | startTime time.Time 25 | lineCount int32 26 | } 27 | 28 | var _ io.Reader = (*LineReader)(nil) 29 | 30 | // NewLineReader creates a new RateLimitLineReader. 31 | func NewLineReader(r io.Reader, rate int) *LineReader { 32 | readerBuffer := bufio.NewReader(r) 33 | 34 | return &LineReader{ 35 | now: time.Now, 36 | since: time.Since, 37 | 38 | reader: r, 39 | readerBuffer: readerBuffer, 40 | rate: float64(rate), 41 | } 42 | } 43 | 44 | // Read reads from the reader, counting the number of lines read and applying rate limiting. 
45 | func (r *LineReader) Read(p []byte) (n int, err error) { 46 | const nl = byte('\n') 47 | var lines int 48 | 49 | canSend := r.canSend() 50 | 51 | for n = 0; n < len(p) && canSend > 0; n++ { 52 | var b byte 53 | if b, err = r.readerBuffer.ReadByte(); err != nil { 54 | break 55 | } 56 | 57 | if b == nl { 58 | lines++ 59 | canSend-- 60 | } 61 | 62 | p[n] = b 63 | } 64 | 65 | if r.rate > 0 { 66 | time.Sleep(100 * time.Millisecond) 67 | } 68 | 69 | atomic.AddInt32(&r.lineCount, int32(lines)) 70 | 71 | return n, err 72 | } 73 | 74 | // Count returns the number of lines read. 75 | func (r *LineReader) Count() int { 76 | return int(atomic.LoadInt32(&r.lineCount)) 77 | } 78 | 79 | // canSend calculates the number of lines that can be sent while respecting the rate limit. 80 | func (r *LineReader) canSend() int { 81 | var canSend int 82 | if r.rate == 0 { 83 | canSend = math.MaxInt32 84 | } else { 85 | if r.startTime.IsZero() { 86 | r.startTime = r.now() 87 | } 88 | 89 | delta := r.since(r.startTime) 90 | canSend = int(r.rate*(delta.Seconds()+1)) - int(r.lineCount) 91 | } 92 | 93 | return canSend 94 | } 95 | -------------------------------------------------------------------------------- /pkg/massdns/linereader_test.go: -------------------------------------------------------------------------------- 1 | package massdns 2 | 3 | import ( 4 | "io" 5 | "strings" 6 | "testing" 7 | "time" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | type stubClock struct { 13 | now time.Time 14 | } 15 | 16 | func newStubClock() *stubClock { 17 | return &stubClock{ 18 | now: time.Now(), 19 | } 20 | } 21 | 22 | func (c *stubClock) advance(d time.Duration) { 23 | c.now = c.now.Add(d) 24 | } 25 | 26 | func (c *stubClock) Now() time.Time { 27 | return c.now 28 | } 29 | 30 | func (c *stubClock) Since(t time.Time) time.Duration { 31 | return c.now.Sub(t) 32 | } 33 | 34 | func newWithClock(r io.Reader, rate int) (*LineReader, *stubClock) { 35 | clock := newStubClock() 36 | reader := 
NewLineReader(r, rate) 37 | reader.now = clock.Now 38 | reader.since = clock.Since 39 | 40 | return reader, clock 41 | } 42 | 43 | func TestLineReaderNew(t *testing.T) { 44 | r := NewLineReader(strings.NewReader("test"), 1) 45 | assert.NotNil(t, r) 46 | } 47 | 48 | func TestLineReaderRead_Unlimited(t *testing.T) { 49 | testString := "line1\nline2\nline3\n" 50 | r, _ := newWithClock(strings.NewReader(testString), 0) 51 | 52 | buf := make([]byte, 1) 53 | var got string 54 | var gotErr error 55 | 56 | for { 57 | var n int 58 | if n, gotErr = r.Read(buf); gotErr != nil { 59 | break 60 | } 61 | 62 | got = got + string(buf[:n]) 63 | } 64 | 65 | assert.ErrorIs(t, gotErr, io.EOF) 66 | assert.Equal(t, testString, got) 67 | assert.Equal(t, 3, r.Count()) 68 | } 69 | 70 | func TestLineReaderRead_Limited(t *testing.T) { 71 | r, clock := newWithClock(strings.NewReader("line1\nline2\n"), 1) 72 | 73 | buf := make([]byte, 4096) 74 | var got string 75 | 76 | // First read 77 | n, gotErr := r.Read(buf) 78 | got = got + string(buf[:n]) 79 | assert.Nil(t, gotErr) 80 | assert.Equal(t, "line1\n", got, "should return 1 line") 81 | 82 | n, gotErr = r.Read(buf) 83 | assert.Nil(t, gotErr) 84 | assert.Equal(t, 0, n, "should not return bytes before advancing clock") 85 | 86 | // Second read 87 | clock.advance(time.Second) 88 | 89 | n, gotErr = r.Read(buf) 90 | got = got + string(buf[:n]) 91 | assert.Nil(t, gotErr) 92 | assert.Equal(t, "line1\nline2\n", got) 93 | 94 | // EOF 95 | clock.advance(time.Second) 96 | 97 | n, gotErr = r.Read(buf) 98 | assert.ErrorIs(t, gotErr, io.EOF) 99 | assert.Equal(t, 0, n) 100 | } 101 | -------------------------------------------------------------------------------- /pkg/massdns/resolver.go: -------------------------------------------------------------------------------- 1 | package massdns 2 | 3 | import ( 4 | "io" 5 | "os" 6 | ) 7 | 8 | // Runner is an interface that runs the commands required to execute the massdns binary. 
9 | type Runner interface { 10 | Run(reader io.Reader, output string, resolversFile string, qps int) error 11 | } 12 | 13 | // Resolver uses massdns to resolve a batch of domain names. 14 | type Resolver struct { 15 | osOpen func(file string) (*os.File, error) 16 | osCreate func(file string) (*os.File, error) 17 | 18 | runner Runner 19 | 20 | reader *LineReader 21 | } 22 | 23 | // NewResolver creates a new Resolver. 24 | func NewResolver(binPath string) *Resolver { 25 | return &Resolver{ 26 | runner: newDefaultRunner(binPath), 27 | 28 | osOpen: os.Open, 29 | osCreate: os.Create, 30 | } 31 | } 32 | 33 | // Resolve reads domain names from the reader and saves the answers to a file. 34 | // It uses the resolvers and queries per second limit specified. 35 | func (r *Resolver) Resolve(reader io.Reader, output string, resolversFile string, qps int) error { 36 | r.reader = NewLineReader(reader, qps) 37 | 38 | if err := r.runner.Run(r.reader, output, resolversFile, qps); err != nil { 39 | return err 40 | } 41 | 42 | return nil 43 | } 44 | 45 | // Current returns the index of the last domain processed. 
46 | func (r *Resolver) Current() int { 47 | if r.reader == nil { 48 | return 0 49 | } 50 | 51 | return r.reader.Count() 52 | } 53 | -------------------------------------------------------------------------------- /pkg/massdns/resolver_test.go: -------------------------------------------------------------------------------- 1 | package massdns 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | type stubRunner struct { 13 | returns error 14 | } 15 | 16 | func (r stubRunner) Run(lineReader io.Reader, output string, resolvers string, qps int) error { 17 | buf := make([]byte, 1024) 18 | 19 | for { 20 | _, err := lineReader.Read(buf) 21 | if err != nil { 22 | break 23 | } 24 | } 25 | 26 | return r.returns 27 | } 28 | 29 | func TestResolve(t *testing.T) { 30 | tests := []struct { 31 | name string 32 | haveRunnerError error 33 | wantErr bool 34 | }{ 35 | {name: "success"}, 36 | {name: "runner error handling", haveRunnerError: errors.New("runner error"), wantErr: true}, 37 | } 38 | 39 | for _, test := range tests { 40 | t.Run(test.name, func(t *testing.T) { 41 | resolver := NewResolver("massdns") 42 | resolver.runner = stubRunner{returns: test.haveRunnerError} 43 | 44 | gotErr := resolver.Resolve(strings.NewReader(""), "", "", 10) 45 | 46 | assert.Equal(t, test.wantErr, gotErr != nil, gotErr) 47 | }) 48 | } 49 | } 50 | 51 | func TestCurrent(t *testing.T) { 52 | resolver := NewResolver("massdns") 53 | resolver.runner = stubRunner{} 54 | 55 | gotCurrent := resolver.Current() 56 | assert.Equal(t, 0, gotCurrent) 57 | 58 | resolver.Resolve(strings.NewReader("example.com\n"), "", "", 0) 59 | gotCurrent = resolver.Current() 60 | assert.Equal(t, 1, gotCurrent) 61 | } 62 | -------------------------------------------------------------------------------- /pkg/massdns/runner.go: -------------------------------------------------------------------------------- 1 | package massdns 2 | 3 | import ( 4 | "io" 5 | 
"os/exec" 6 | "strconv" 7 | ) 8 | 9 | type defaultRunner struct { 10 | binPath string 11 | execCommand func(name string, arg ...string) *exec.Cmd 12 | } 13 | 14 | func newDefaultRunner(binPath string) *defaultRunner { 15 | return &defaultRunner{ 16 | binPath: binPath, 17 | execCommand: exec.Command, 18 | } 19 | } 20 | 21 | // Run executes massdns on the specified domains files and saves the results to the output file. 22 | func (runner *defaultRunner) Run(r io.Reader, output string, resolvers string, qps int) error { 23 | // Create massdns program arguments 24 | massdnsArgs := runner.createMassdnsArgs(output, resolvers, qps) 25 | 26 | // Create a new exec.Cmd and set Stdin and Stdout to our custom handlers to avoid file operations 27 | massdns := runner.execCommand(runner.binPath, massdnsArgs...) 28 | massdns.Stdin = r 29 | 30 | // Run massdns and block until it's done 31 | if err := massdns.Run(); err != nil { 32 | return err 33 | } 34 | 35 | return nil 36 | } 37 | 38 | // createMassdnsArgs creates the command line arguments for massdns. 
39 | func (runner *defaultRunner) createMassdnsArgs(output string, resolvers string, qps int) []string { 40 | // Default command line 41 | args := []string{"-q", "-r", resolvers, "-o", "Snl", "-t", "A", "--root", "--retry", "REFUSED", "--retry", "SERVFAIL", "-w", output} 42 | 43 | // Set the massdns hashmap size manually to prevent it from accumulating DNS query on start 44 | if qps > 0 { 45 | args = append(args, "-s", strconv.Itoa(qps)) 46 | } 47 | 48 | return args 49 | } 50 | -------------------------------------------------------------------------------- /pkg/massdns/runner_test.go: -------------------------------------------------------------------------------- 1 | package massdns 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "os/exec" 7 | "strconv" 8 | "strings" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | var stubMassdnsExitCode int 15 | 16 | func stubExecCommand(command string, args ...string) *exec.Cmd { 17 | cs := []string{"-test.run=TestHelperProcess", "--", command} 18 | cs = append(cs, args...) 19 | 20 | cmd := exec.Command(os.Args[0], cs...) 
21 | cmd.Env = []string{ 22 | "GO_WANT_HELPER_PROCESS=1", 23 | fmt.Sprintf("MASSDNS_EXIT_CODE=%d", stubMassdnsExitCode), 24 | } 25 | 26 | return cmd 27 | } 28 | 29 | func TestHelperProcess(t *testing.T) { 30 | if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" { 31 | return 32 | } 33 | 34 | if os.Args[3] == "massdns" { 35 | fmt.Fprintf(os.Stderr, "massdns: %v\n", os.Args) 36 | exitCode, _ := strconv.Atoi(os.Getenv("MASSDNS_EXIT_CODE")) 37 | os.Exit(exitCode) 38 | } 39 | 40 | fmt.Fprintf(os.Stderr, "%v\n", os.Args) 41 | } 42 | 43 | func TestDefaultRunnerRun(t *testing.T) { 44 | tests := []struct { 45 | name string 46 | haveMassdnsExitCode int 47 | haveCommandAutoStart bool 48 | wantErr bool 49 | }{ 50 | {name: "success"}, 51 | {name: "massdns error handling", haveMassdnsExitCode: 1, wantErr: true}, 52 | } 53 | 54 | for _, test := range tests { 55 | t.Run(test.name, func(t *testing.T) { 56 | stubMassdnsExitCode = test.haveMassdnsExitCode 57 | 58 | runner := newDefaultRunner("massdns") 59 | runner.execCommand = stubExecCommand 60 | 61 | gotErr := runner.Run(strings.NewReader(""), "", "", 10) 62 | 63 | assert.Equal(t, test.wantErr, gotErr != nil, gotErr) 64 | }) 65 | } 66 | } 67 | 68 | func TestCreateMassdnsArgs_DefaultRateLimit(t *testing.T) { 69 | runner := defaultRunner{} 70 | gotArgs := runner.createMassdnsArgs("output.txt", "resolvers.txt", 0) 71 | assert.ElementsMatch(t, []string{"-q", "-r", "resolvers.txt", "-o", "Snl", "-t", "A", "--root", "--retry", "REFUSED", "--retry", "SERVFAIL", "-w", "output.txt"}, gotArgs) 72 | } 73 | 74 | func TestCreateMassdnsArgs_CustomRateLimit(t *testing.T) { 75 | runner := defaultRunner{} 76 | gotArgs := runner.createMassdnsArgs("output.txt", "resolvers.txt", 100) 77 | assert.ElementsMatch(t, []string{"-q", "-r", "resolvers.txt", "-o", "Snl", "-t", "A", "--root", "--retry", "REFUSED", "--retry", "SERVFAIL", "-w", "output.txt", "-s", "100"}, gotArgs) 78 | } 79 | 
-------------------------------------------------------------------------------- /pkg/massdns/stdouthandler.go: -------------------------------------------------------------------------------- 1 | package massdns 2 | 3 | import ( 4 | "io" 5 | "strings" 6 | ) 7 | 8 | // StdoutHandler read complete lines from the massdns output and sends them one by 9 | // one to the callback function. 10 | type StdoutHandler struct { 11 | callback Callback 12 | remainder string 13 | } 14 | 15 | var _ io.Writer = (*StdoutHandler)(nil) 16 | 17 | // Callback is a callback function that receives lines from the massdns output. 18 | type Callback interface { 19 | Callback(line string) error 20 | Close() 21 | } 22 | 23 | // NewStdoutHandler returns a new OutputHandler that can be used to receive massdns' stdout. 24 | func NewStdoutHandler(callback Callback) *StdoutHandler { 25 | return &StdoutHandler{ 26 | callback: callback, 27 | } 28 | } 29 | 30 | // Write detects strings terminated by a \n character in the specified buffer and 31 | // sends them to the callback function. 32 | func (w *StdoutHandler) Write(p []byte) (n int, err error) { 33 | var builder strings.Builder 34 | builder.WriteString(w.remainder) 35 | 36 | for n = 0; n < len(p); n++ { 37 | // If we reach the end of a line, send the line to the callback function 38 | // even if the line is empty 39 | if p[n] == byte('\n') { 40 | line := builder.String() 41 | builder.Reset() 42 | 43 | if err := w.callback.Callback(line); err != nil { 44 | return n + 1, err 45 | } 46 | 47 | continue 48 | } 49 | 50 | // Build the string from the current data 51 | builder.WriteByte(p[n]) 52 | } 53 | 54 | // Keep the remainder of the line for the next call to Write 55 | w.remainder = builder.String() 56 | 57 | return n, nil 58 | } 59 | 60 | // Close closes the callback interface. 
61 | func (w *StdoutHandler) Close() { 62 | w.callback.Close() 63 | } 64 | -------------------------------------------------------------------------------- /pkg/massdns/stdouthandler_test.go: -------------------------------------------------------------------------------- 1 | package massdns 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | type stubCallback struct { 11 | data []string 12 | returns error 13 | closed bool 14 | } 15 | 16 | func (c *stubCallback) Callback(line string) error { 17 | if c.returns != nil { 18 | return c.returns 19 | } 20 | 21 | c.data = append(c.data, line) 22 | 23 | return nil 24 | } 25 | 26 | func (c *stubCallback) Close() { 27 | c.closed = true 28 | } 29 | 30 | func TestOutputHandlerNew(t *testing.T) { 31 | var cb stubCallback 32 | handler := NewStdoutHandler(&cb) 33 | assert.NotNil(t, handler) 34 | } 35 | 36 | func TestOutputHandlerWrite(t *testing.T) { 37 | callbackError := errors.New("error") 38 | 39 | tests := []struct { 40 | name string 41 | haveBuffers [][]byte 42 | haveError error 43 | want []string 44 | wantErr error 45 | }{ 46 | { 47 | name: "empty write", 48 | }, 49 | { 50 | name: "no newline", 51 | haveBuffers: [][]byte{ 52 | []byte("line"), 53 | }, 54 | }, 55 | { 56 | name: "with newline", 57 | haveBuffers: [][]byte{ 58 | []byte("line\n"), 59 | }, 60 | want: []string{"line"}, 61 | }, 62 | { 63 | name: "multiple lines", 64 | haveBuffers: [][]byte{ 65 | []byte("line1\nline2\nline3\n"), 66 | }, 67 | want: []string{"line1", "line2", "line3"}, 68 | }, 69 | { 70 | name: "partial line", 71 | haveBuffers: [][]byte{ 72 | []byte("line1\nli"), 73 | []byte("ne2\nline3\n"), 74 | }, 75 | want: []string{"line1", "line2", "line3"}, 76 | }, 77 | { 78 | name: "callback error", 79 | haveBuffers: [][]byte{ 80 | []byte("line\n"), 81 | }, 82 | haveError: callbackError, 83 | wantErr: callbackError, 84 | }, 85 | } 86 | 87 | for _, test := range tests { 88 | t.Run(test.name, func(t *testing.T) { 
89 | cb := &stubCallback{returns: test.haveError} 90 | handler := NewStdoutHandler(cb) 91 | 92 | for _, buf := range test.haveBuffers { 93 | n, err := handler.Write(buf) 94 | assert.Equal(t, len(buf), n) 95 | assert.ErrorIs(t, err, test.wantErr) 96 | } 97 | 98 | assert.Equal(t, test.want, cb.data) 99 | }) 100 | } 101 | } 102 | 103 | func TestOutputHandlerClose(t *testing.T) { 104 | var cb stubCallback 105 | handler := NewStdoutHandler(&cb) 106 | handler.Close() 107 | assert.True(t, cb.closed) 108 | } 109 | -------------------------------------------------------------------------------- /pkg/massdns/type.go: -------------------------------------------------------------------------------- 1 | package massdns 2 | 3 | // JSONRecord contains a record from a massdns output file. 4 | type JSONRecord struct { 5 | TTL int `json:"ttl"` 6 | Type string `json:"type"` 7 | Class string `json:"class"` 8 | Name string `json:"name"` 9 | Data string `json:"data"` 10 | } 11 | 12 | // JSONResponseData contains the response data from a massdns output file. 13 | type JSONResponseData struct { 14 | Answers []JSONRecord `json:"answers"` 15 | Authorities []JSONRecord `json:"authorities"` 16 | Additionals []JSONRecord `json:"additionals"` 17 | } 18 | 19 | // JSONResponse contains the response from a massdns output file. 20 | type JSONResponse struct { 21 | Name string `json:"name"` 22 | Type string `json:"type"` 23 | Class string `json:"class"` 24 | Status string `json:"status"` 25 | Data JSONResponseData `json:"data"` 26 | Resolver string `json:"resolver"` 27 | } 28 | -------------------------------------------------------------------------------- /pkg/procreader/doc.go: -------------------------------------------------------------------------------- 1 | // Package procreader provides a ProcReader object that implements the io.Reader interface and generates 2 | // its data from a user-specified callback function. 
3 | // 4 | // Use the New function to create a new ProcReader and pass it a callback function that is invoked to 5 | // create new data when the ProcReader's buffers are empty. 6 | package procreader 7 | -------------------------------------------------------------------------------- /pkg/procreader/procreader.go: -------------------------------------------------------------------------------- 1 | package procreader 2 | 3 | import ( 4 | "io" 5 | ) 6 | 7 | // ProcReader is a procedural reader that generates its data from a callback function. 8 | type ProcReader struct { 9 | callback Callback 10 | remainder []byte 11 | err error 12 | } 13 | 14 | var _ io.Reader = (*ProcReader)(nil) 15 | 16 | // Callback is a callback function that generates data in the form of a slice of bytes. 17 | // Size is a hint as to how much data is requested. If the callback returns more data, the 18 | // excess will be buffered by the reader for subsequent Read calls. If no data is left, 19 | // the Callback function must returns an io.EOF error. 20 | type Callback func(size int) ([]byte, error) 21 | 22 | // New creates a new ProcReader. 23 | func New(callback Callback) *ProcReader { 24 | return &ProcReader{ 25 | callback: callback, 26 | } 27 | } 28 | 29 | // Read requests data from the callback until either the buffer is full, or an error like EOF occurs. 
30 | func (r *ProcReader) Read(p []byte) (int, error) { 31 | // Buffer cannot be nil, otherwise an error like EOF will never be returned 32 | if p == nil { 33 | return 0, io.ErrShortBuffer 34 | } 35 | 36 | var written int 37 | total := len(p) 38 | 39 | for { 40 | var data []byte 41 | 42 | // Get the data to write to the buffer 43 | if r.remainder != nil { 44 | data = r.remainder 45 | r.remainder = nil 46 | } else { 47 | data, r.err = r.callback(total - written) 48 | } 49 | 50 | // Write to buffer 51 | n := copy(p[written:], data) 52 | written += n 53 | 54 | // Could not write entire data, save remainder and exit 55 | if n < len(data) { 56 | r.remainder = data[n:] 57 | return written, nil 58 | } 59 | 60 | // Could save entire data, but buffer is full 61 | if written == len(p) { 62 | return written, r.err 63 | } 64 | 65 | // Error or EOF while creating data, return 66 | if r.err != nil { 67 | return written, r.err 68 | } 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /pkg/procreader/procreader_test.go: -------------------------------------------------------------------------------- 1 | package procreader 2 | 3 | import ( 4 | "io" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestNew(t *testing.T) { 11 | r := New(func(int) ([]byte, error) { return []byte{}, nil }) 12 | assert.NotNil(t, r) 13 | } 14 | 15 | func TestRead_Buffer(t *testing.T) { 16 | var cb = func(int) ([]byte, error) { 17 | return []byte("this is a test"), io.EOF 18 | } 19 | 20 | tests := []struct { 21 | name string 22 | haveBuffer []byte 23 | wantRead []byte 24 | wantErr error 25 | }{ 26 | {name: "no buffer", haveBuffer: nil, wantRead: nil, wantErr: io.ErrShortBuffer}, 27 | {name: "small buffer", haveBuffer: make([]byte, 1), wantRead: []byte("this is a test"), wantErr: io.EOF}, 28 | {name: "big buffer", haveBuffer: make([]byte, 256), wantRead: []byte("this is a test"), wantErr: io.EOF}, 29 | } 30 | 31 | for _, test := 
range tests { 32 | t.Run(test.name, func(t *testing.T) { 33 | r := New(cb) 34 | 35 | var err error 36 | var readBuf []byte 37 | 38 | for err == nil { 39 | var n int 40 | n, err = r.Read(test.haveBuffer) 41 | readBuf = append(readBuf, test.haveBuffer[:n]...) 42 | } 43 | 44 | assert.ErrorIs(t, err, test.wantErr) 45 | assert.Equal(t, test.wantRead, readBuf) 46 | }) 47 | } 48 | } 49 | 50 | func TestRead_MultipleCallbacks(t *testing.T) { 51 | data := [][]byte{ 52 | []byte("first callback"), 53 | []byte("second callback"), 54 | } 55 | 56 | var cb = func(int) ([]byte, error) { 57 | val := data[0] 58 | data = data[1:] 59 | 60 | var err error 61 | if len(data) == 0 { 62 | err = io.EOF 63 | } 64 | 65 | return val, err 66 | } 67 | 68 | r := New(cb) 69 | 70 | var err error 71 | var readBuf []byte 72 | 73 | buffer := make([]byte, 1) 74 | for err == nil { 75 | var n int 76 | n, err = r.Read(buffer) 77 | readBuf = append(readBuf, buffer[:n]...) 78 | } 79 | 80 | assert.ErrorIs(t, err, io.EOF) 81 | assert.Equal(t, "first callbacksecond callback", string(readBuf)) 82 | } 83 | -------------------------------------------------------------------------------- /pkg/progressbar/doc.go: -------------------------------------------------------------------------------- 1 | // Package progressbar implements a basic asynchronous and thread-safe progress bar that can perform polling updates. 2 | // 3 | // Create a new progress bar by calling the New function and pass it a callback function that will be called every time the 4 | // progress bar updates. This callback can update the progress bar's current count and other variables as needed. 5 | // 6 | // The progress bar can be customized with different options at creation. 
7 | package progressbar 8 | -------------------------------------------------------------------------------- /pkg/progressbar/movingrate.go: -------------------------------------------------------------------------------- 1 | package progressbar 2 | 3 | import ( 4 | "errors" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | var ( 10 | // ErrNotStarted is an error happening when the MovingRate hasn't been started. 11 | ErrNotStarted = errors.New("rate is not started") 12 | 13 | // ErrAlreadyStarted is an error happening when the MovingRate has already been started. 14 | ErrAlreadyStarted = errors.New("rate is already started") 15 | 16 | // ErrStopped is an error happening when the MovingRate has been stopped. 17 | ErrStopped = errors.New("rate has been stopped") 18 | 19 | // ErrAlreadyStopped is an error happening when the MovingRate has already been stopped. 20 | ErrAlreadyStopped = errors.New("rate is already stopped") 21 | ) 22 | 23 | // MovingRate calculates the rate of elements sampled using a moving average. 24 | type MovingRate struct { 25 | now func() time.Time 26 | since func(time.Time) time.Duration 27 | 28 | mu sync.Mutex 29 | 30 | movingAvgSamplingRate time.Duration 31 | movingAvgMaxSamples int 32 | movingAvgSamples []float64 33 | 34 | accumCounter float64 35 | accumCounterTime time.Time 36 | 37 | startTime time.Time 38 | stopTime time.Time 39 | totalCounter float64 40 | } 41 | 42 | // NewMovingRate creates a new MovingRate object with the specified sampling rate and number of samples 43 | // to consider in the moving average. 44 | func NewMovingRate(samplingRate time.Duration, samples int) *MovingRate { 45 | return &MovingRate{ 46 | now: time.Now, 47 | since: time.Since, 48 | 49 | movingAvgSamplingRate: samplingRate, 50 | movingAvgMaxSamples: samples, 51 | } 52 | } 53 | 54 | // Start starts the MovingRate object. 
55 | func (r *MovingRate) Start() error { 56 | r.mu.Lock() 57 | defer r.mu.Unlock() 58 | 59 | if !r.startTime.IsZero() { 60 | return ErrAlreadyStarted 61 | } 62 | 63 | r.startTime = r.now() 64 | 65 | return nil 66 | } 67 | 68 | // Stop stops gathering 69 | func (r *MovingRate) Stop() error { 70 | r.mu.Lock() 71 | defer r.mu.Unlock() 72 | 73 | if r.startTime.IsZero() { 74 | return ErrNotStarted 75 | } 76 | 77 | if !r.stopTime.IsZero() { 78 | return ErrAlreadyStopped 79 | } 80 | 81 | r.stopTime = r.now() 82 | 83 | return nil 84 | } 85 | 86 | // Sample records new data in the moving average. If there is not enough time elapsed between 87 | // the previous call of Sample, the data is accumulated into a buffer until a proper rate can 88 | // be calculated. 89 | func (r *MovingRate) Sample(count float64) error { 90 | r.mu.Lock() 91 | defer r.mu.Unlock() 92 | 93 | // Return an error if the sampler hasn't been started 94 | if r.startTime.IsZero() { 95 | return ErrNotStarted 96 | } 97 | 98 | // Return an error if the sampler has been stopped 99 | if !r.stopTime.IsZero() { 100 | return ErrStopped 101 | } 102 | 103 | // Set initial value 104 | if r.accumCounterTime.IsZero() { 105 | r.totalCounter += count 106 | r.movingAvgSamples = append(r.movingAvgSamples, count) 107 | r.accumCounterTime = r.now() 108 | return nil 109 | } 110 | 111 | // Accumulate values 112 | r.accumCounter += count 113 | r.totalCounter += count 114 | 115 | // Don't update the rates if we're below our sampling rate 116 | delta := r.since(r.accumCounterTime).Seconds() 117 | if delta < r.movingAvgSamplingRate.Seconds() { 118 | return nil 119 | } 120 | 121 | // Calculate the current rate and add it to the moving average 122 | curRate := r.accumCounter / delta 123 | r.movingAvgSamples = append(r.movingAvgSamples, curRate) 124 | r.accumCounter = 0 125 | 126 | // Trim moving average values if we have too many samples 127 | if len(r.movingAvgSamples) > r.movingAvgMaxSamples { 128 | r.movingAvgSamples = 
r.movingAvgSamples[1:] 129 | } 130 | 131 | r.accumCounterTime = r.now() 132 | 133 | return nil 134 | } 135 | 136 | // Current returns the current rate based on the moving average. 137 | // If the MovingRate object has been stopped, return the global rate instead. 138 | func (r *MovingRate) Current() (float64, error) { 139 | r.mu.Lock() 140 | defer r.mu.Unlock() 141 | 142 | // If the counter hasn't been started, return an error 143 | if r.startTime.IsZero() { 144 | return 0, ErrNotStarted 145 | } 146 | 147 | // If the counter has been stopped, calculate the global rate 148 | if !r.stopTime.IsZero() { 149 | delta := r.stopTime.Sub(r.startTime).Seconds() 150 | return r.totalCounter / delta, nil 151 | } 152 | 153 | // If we don't have data yet, calculate the global rate 154 | // using the current time 155 | if len(r.movingAvgSamples) == 0 { 156 | delta := r.since(r.startTime).Seconds() 157 | return r.totalCounter / delta, nil 158 | } 159 | 160 | // Calculate the moving average 161 | var total float64 162 | for _, rate := range r.movingAvgSamples { 163 | total += rate 164 | } 165 | total = total / float64(len(r.movingAvgSamples)) 166 | 167 | return total, nil 168 | } 169 | -------------------------------------------------------------------------------- /pkg/progressbar/movingrate_test.go: -------------------------------------------------------------------------------- 1 | package progressbar 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | type stubClock struct { 12 | now time.Time 13 | } 14 | 15 | func newStubClock() *stubClock { 16 | return &stubClock{ 17 | now: time.Now(), 18 | } 19 | } 20 | 21 | func (c *stubClock) advance(d time.Duration) { 22 | c.now = c.now.Add(d) 23 | } 24 | 25 | func (c *stubClock) Now() time.Time { 26 | return c.now 27 | } 28 | 29 | func (c *stubClock) Since(t time.Time) time.Duration { 30 | return c.now.Sub(t) 31 | } 32 | 33 | func newWithClock(d 
time.Duration, samples int) (*MovingRate, *stubClock) { 34 | clock := newStubClock() 35 | rate := NewMovingRate(d, samples) 36 | rate.now = clock.Now 37 | rate.since = clock.Since 38 | 39 | return rate, clock 40 | } 41 | 42 | func TestNewMovingRate(t *testing.T) { 43 | rate := NewMovingRate(time.Second, 10) 44 | assert.NotNil(t, rate) 45 | } 46 | 47 | func TestMovingRateStart(t *testing.T) { 48 | rate := NewMovingRate(time.Second, 10) 49 | 50 | err := rate.Start() 51 | assert.Nil(t, err) 52 | 53 | err = rate.Start() 54 | assert.ErrorIs(t, err, ErrAlreadyStarted) 55 | } 56 | 57 | func TestMovingRateStop(t *testing.T) { 58 | rate := NewMovingRate(time.Second, 10) 59 | 60 | err := rate.Stop() 61 | assert.ErrorIs(t, err, ErrNotStarted) 62 | 63 | require.Nil(t, rate.Start()) 64 | err = rate.Stop() 65 | assert.Nil(t, err) 66 | 67 | err = rate.Stop() 68 | assert.ErrorIs(t, err, ErrAlreadyStopped) 69 | } 70 | 71 | func TestMovingRateSample_NotStarted(t *testing.T) { 72 | rate := NewMovingRate(time.Second, 10) 73 | got := rate.Sample(1) 74 | assert.ErrorIs(t, got, ErrNotStarted) 75 | } 76 | 77 | func TestMovingRateSample_Stopped(t *testing.T) { 78 | rate := NewMovingRate(time.Second, 10) 79 | require.Nil(t, rate.Start()) 80 | require.Nil(t, rate.Stop()) 81 | 82 | got := rate.Sample(1) 83 | 84 | assert.ErrorIs(t, got, ErrStopped) 85 | } 86 | 87 | func TestMovingRateSample(t *testing.T) { 88 | rate, clock := newWithClock(time.Second, 2) 89 | require.Nil(t, rate.Start()) 90 | 91 | rate.Sample(10) 92 | got, _ := rate.Current() 93 | assert.Equal(t, 10.0, got, "first sample taken") 94 | 95 | clock.advance(time.Second) 96 | rate.Sample(20) 97 | got, _ = rate.Current() 98 | assert.Equal(t, 15.0, got, "second sample taken") 99 | 100 | clock.advance(time.Second) 101 | rate.Sample(2) 102 | got, _ = rate.Current() 103 | assert.Equal(t, 11.0, got, "third sample taken, first discarded") 104 | 105 | clock.advance(500 * time.Millisecond) 106 | rate.Sample(4) 107 | got, _ = rate.Current() 
108 | assert.Equal(t, 11.0, got, "fourth sample sent to accumulator since it's smaller than our interval") 109 | 110 | clock.advance(500 * time.Millisecond) 111 | rate.Sample(4) 112 | got, _ = rate.Current() 113 | assert.Equal(t, 5.0, got, "fourth + fifth sample taken together") 114 | } 115 | 116 | func TestMovingRateCurrent_NotStarted(t *testing.T) { 117 | rate := NewMovingRate(time.Second, 10) 118 | 119 | _, gotErr := rate.Current() 120 | 121 | assert.ErrorIs(t, gotErr, ErrNotStarted) 122 | } 123 | 124 | func TestMovingRateCurrent_Initial(t *testing.T) { 125 | rate := NewMovingRate(time.Second, 10) 126 | 127 | rate.Start() 128 | 129 | got, gotErr := rate.Current() 130 | assert.Nil(t, gotErr) 131 | assert.Equal(t, 0.0, got) 132 | 133 | rate.Sample(1) 134 | 135 | got, gotErr = rate.Current() 136 | 137 | assert.Nil(t, gotErr) 138 | assert.Equal(t, 1.0, got) 139 | } 140 | 141 | func TestMovingRateCurrent_GlobalRate(t *testing.T) { 142 | rate, clock := newWithClock(time.Second, 10) 143 | 144 | rate.Start() 145 | rate.Sample(1) 146 | clock.advance(time.Second) 147 | rate.Sample(1) 148 | clock.advance(time.Second) 149 | rate.Stop() 150 | 151 | gotCurrent, gotErr := rate.Current() 152 | 153 | assert.Nil(t, gotErr) 154 | assert.Equal(t, 1.0, gotCurrent) 155 | } 156 | 157 | func TestMovingRateCurrent_WhileGathering(t *testing.T) { 158 | rate, clock := newWithClock(time.Second, 10) 159 | 160 | rate.Start() 161 | rate.Sample(1) 162 | clock.advance(time.Second) 163 | rate.Sample(1) 164 | clock.advance(time.Second) 165 | 166 | gotCurrent, gotErr := rate.Current() 167 | 168 | assert.Nil(t, gotErr) 169 | assert.Equal(t, 1.0, gotCurrent) 170 | } 171 | 172 | func TestMovingRateCurrent_Decay(t *testing.T) { 173 | rate, clock := newWithClock(time.Second, 1) 174 | 175 | rate.Start() 176 | rate.Sample(10) 177 | clock.advance(time.Second) 178 | clock.advance(time.Second) 179 | clock.advance(time.Second) 180 | clock.advance(time.Second) 181 | clock.advance(time.Second) 182 | rate.Stop() 
183 | 184 | got, _ := rate.Current() 185 | 186 | assert.Equal(t, 2.0, got) 187 | } 188 | -------------------------------------------------------------------------------- /pkg/progressbar/options.go: -------------------------------------------------------------------------------- 1 | package progressbar 2 | 3 | import ( 4 | "io" 5 | "time" 6 | ) 7 | 8 | // Option configures a progress bar. 9 | type Option interface { 10 | apply(*options) 11 | } 12 | 13 | type options struct { 14 | updateInterval time.Duration 15 | template string 16 | writer io.Writer 17 | style Style 18 | } 19 | 20 | // WithTemplate provides a template string to the progress bar. 21 | func WithTemplate(t string) Option { 22 | return templateOption(t) 23 | } 24 | 25 | type templateOption string 26 | 27 | func (t templateOption) apply(opts *options) { 28 | opts.template = string(t) 29 | } 30 | 31 | // WithWriter provides a custom writer to the progress bar. 32 | func WithWriter(w io.Writer) Option { 33 | return writerOption{w: w} 34 | } 35 | 36 | type writerOption struct { 37 | w io.Writer 38 | } 39 | 40 | func (w writerOption) apply(opts *options) { 41 | opts.writer = w.w 42 | } 43 | 44 | // WithInterval provides an update interval for the progress bar. 45 | func WithInterval(d time.Duration) Option { 46 | return intervalOption(d) 47 | } 48 | 49 | type intervalOption time.Duration 50 | 51 | func (i intervalOption) apply(opts *options) { 52 | opts.updateInterval = time.Duration(i) 53 | } 54 | 55 | // WithStyle provides a custom styling for the progress bar. 
56 | func WithStyle(style Style) Option { 57 | return styleOption(style) 58 | } 59 | 60 | type styleOption Style 61 | 62 | func (s styleOption) apply(opts *options) { 63 | opts.style = Style(s) 64 | } 65 | -------------------------------------------------------------------------------- /pkg/progressbar/progressbar_test.go: -------------------------------------------------------------------------------- 1 | package progressbar 2 | 3 | import ( 4 | "io" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func updateFn(bar *ProgressBar) {} 13 | 14 | func TestNew_Default(t *testing.T) { 15 | pb := New(updateFn, 100) 16 | assert.NotNil(t, pb) 17 | } 18 | 19 | func TestNew_Options(t *testing.T) { 20 | pb := New( 21 | updateFn, 22 | 100, 23 | WithTemplate(""), 24 | WithWriter(io.Discard), 25 | WithInterval(100*time.Millisecond), 26 | WithStyle(DefaultStyle()), 27 | ) 28 | assert.NotNil(t, pb) 29 | } 30 | 31 | func TestStart_OK(t *testing.T) { 32 | pb := New(updateFn, 100, WithWriter(io.Discard)) 33 | pb.Start() 34 | } 35 | 36 | func TestStop(t *testing.T) { 37 | pb := New(updateFn, 100, WithWriter(io.Discard)) 38 | pb.Start() 39 | pb.Stop() 40 | } 41 | 42 | func TestGetSet(t *testing.T) { 43 | pb := New(updateFn, 100) 44 | 45 | got := pb.Get("key") 46 | assert.Nil(t, got, "should not exist") 47 | 48 | pb.Set("key", "value") 49 | got = pb.Get("key") 50 | assert.Equal(t, "value", got, "should exist") 51 | } 52 | 53 | func TestIncrement(t *testing.T) { 54 | pb := New(updateFn, 100) 55 | 56 | got := pb.Current() 57 | assert.EqualValues(t, 0, got) 58 | 59 | pb.Increment(1) 60 | 61 | got = pb.Current() 62 | assert.EqualValues(t, 1, got) 63 | } 64 | 65 | func TestSetCurrent(t *testing.T) { 66 | pb := New(updateFn, 100) 67 | 68 | pb.SetCurrent(10) 69 | 70 | got := pb.Current() 71 | assert.EqualValues(t, 10, got) 72 | } 73 | 74 | func TestSetCurrent_Lower(t *testing.T) { 75 | pb := New(updateFn, 100) 76 | 
77 | pb.SetCurrent(10) 78 | pb.SetCurrent(5) 79 | 80 | got := pb.Current() 81 | assert.EqualValues(t, 10, got) 82 | } 83 | 84 | func TestTotal(t *testing.T) { 85 | pb := New(updateFn, 100) 86 | got := pb.Total() 87 | assert.EqualValues(t, 100, got) 88 | } 89 | 90 | func TestRate_Initial(t *testing.T) { 91 | pb := New(updateFn, 100, WithWriter(io.Discard)) 92 | pb.Start() 93 | 94 | got := pb.Rate() 95 | assert.Equal(t, 0.0, got) 96 | } 97 | 98 | func TestRate_Increment(t *testing.T) { 99 | pb := New(updateFn, 100, WithWriter(io.Discard)) 100 | 101 | pb.Start() 102 | pb.Increment(1) 103 | 104 | got := pb.Rate() 105 | assert.Equal(t, 1.0, got) 106 | } 107 | 108 | func TestRate_SetCurrent(t *testing.T) { 109 | pb := New(updateFn, 100, WithWriter(io.Discard)) 110 | 111 | pb.Start() 112 | pb.SetCurrent(1) 113 | 114 | got := pb.Rate() 115 | assert.Equal(t, 1.0, got) 116 | } 117 | 118 | func TestETA_OK(t *testing.T) { 119 | pb := New(updateFn, 7777, WithWriter(io.Discard)) 120 | pb.Start() 121 | pb.SetCurrent(1) 122 | 123 | require.Equal(t, 1.0, pb.Rate(), "rate should be 1/sec") 124 | 125 | gotH, gotM, gotS := pb.ETA() 126 | 127 | assert.Equal(t, 2, gotH) 128 | assert.Equal(t, 9, gotM) 129 | assert.Equal(t, 36, gotS) 130 | } 131 | 132 | func TestETA_NoTotal(t *testing.T) { 133 | pb := New(updateFn, 0, WithWriter(io.Discard)) 134 | pb.Start() 135 | pb.SetCurrent(1) 136 | 137 | gotH, gotM, gotS := pb.ETA() 138 | 139 | assert.Equal(t, 0, gotH) 140 | assert.Equal(t, 0, gotM) 141 | assert.Equal(t, 0, gotS) 142 | } 143 | 144 | func TestETA_NoRate(t *testing.T) { 145 | pb := New(updateFn, 100, WithWriter(io.Discard)) 146 | pb.Start() 147 | pb.SetCurrent(0) 148 | 149 | require.Equal(t, 0.0, pb.Rate(), "rate should be 0/sec") 150 | 151 | gotH, gotM, gotS := pb.ETA() 152 | 153 | assert.Equal(t, 99, gotH) 154 | assert.Equal(t, 59, gotM) 155 | assert.Equal(t, 59, gotS) 156 | } 157 | 158 | func TestETA_Done(t *testing.T) { 159 | pb := New(updateFn, 100, WithWriter(io.Discard)) 160 | 
pb.Start() 161 | pb.SetCurrent(100) 162 | 163 | gotH, gotM, gotS := pb.ETA() 164 | 165 | assert.Equal(t, 0, gotH) 166 | assert.Equal(t, 0, gotM) 167 | assert.Equal(t, 0, gotS) 168 | } 169 | 170 | func TestTime_Initial(t *testing.T) { 171 | pb := New(updateFn, 100) 172 | 173 | gotH, gotM, gotS := pb.Time() 174 | 175 | assert.Equal(t, 0, gotH) 176 | assert.Equal(t, 0, gotM) 177 | assert.Equal(t, 0, gotS) 178 | } 179 | 180 | func TestTime_Running(t *testing.T) { 181 | pb := New(updateFn, 100, WithWriter(io.Discard)) 182 | pb.Start() 183 | pb.startTime = time.Now().Add(-30 * time.Second) 184 | 185 | gotH, gotM, gotS := pb.Time() 186 | 187 | assert.Equal(t, 0, gotH) 188 | assert.Equal(t, 0, gotM) 189 | assert.Greater(t, gotS, 0) 190 | } 191 | 192 | func TestTime_Finished(t *testing.T) { 193 | pb := New(updateFn, 100, WithWriter(io.Discard)) 194 | pb.Start() 195 | pb.startTime = time.Now().Add(-30 * time.Second) 196 | pb.Stop() 197 | 198 | gotH, gotM, gotS := pb.Time() 199 | 200 | assert.Equal(t, 0, gotH) 201 | assert.Equal(t, 0, gotM) 202 | assert.Greater(t, gotS, 0) 203 | } 204 | -------------------------------------------------------------------------------- /pkg/progressbar/style.go: -------------------------------------------------------------------------------- 1 | package progressbar 2 | 3 | // Color is a terminal escape string that defines a color. 4 | type Color string 5 | 6 | const ( 7 | // ColorBlack is the black foreground color. 8 | ColorBlack Color = "\033[0;30m" 9 | 10 | // ColorGray is the gray foreground color. 11 | ColorGray Color = "\033[1;30m" 12 | 13 | // ColorRed is the red foreground color. 14 | ColorRed Color = "\033[0;31m" 15 | 16 | // ColorBrightRed is the bright red foreground color. 17 | ColorBrightRed Color = "\033[1;31m" 18 | 19 | // ColorGreen is the green foreground color. 20 | ColorGreen Color = "\033[0;32m" 21 | 22 | // ColorBrightGreen is the bright green foreground color. 
23 | ColorBrightGreen Color = "\033[1;32m" 24 | 25 | // ColorYellow is the yellow foreground color. 26 | ColorYellow Color = "\033[0;33m" 27 | 28 | // ColorBrightYellow is the yellow foreground color. 29 | ColorBrightYellow Color = "\033[1;33m" 30 | 31 | // ColorBlue is the blue foreground color. 32 | ColorBlue Color = "\033[0;34m" 33 | 34 | // ColorBrightBlue is the bright blue foreground color. 35 | ColorBrightBlue Color = "\033[1;34m" 36 | 37 | // ColorMagenta is the magenta foreground color. 38 | ColorMagenta Color = "\033[0;35m" 39 | 40 | // ColorBrightMagenta is the bright magenta foreground color. 41 | ColorBrightMagenta Color = "\033[1;35m" 42 | 43 | // ColorCyan is the cyan foreground color. 44 | ColorCyan Color = "\033[0;36m" 45 | 46 | // ColorBrightCyan is the bright cyan foreground color. 47 | ColorBrightCyan Color = "\033[1;36m" 48 | 49 | // ColorWhite is the white foreground color. 50 | ColorWhite Color = "\033[0;37m" 51 | 52 | // ColorBrightWhite is the bright white foreground color. 53 | ColorBrightWhite Color = "\033[1;37m" 54 | 55 | // ColorReset is the code to reset all attributes. 56 | ColorReset Color = "\033[0m" 57 | ) 58 | 59 | // Style defines a progress bar style. 60 | type Style struct { 61 | BarPrefix rune 62 | BarSuffix rune 63 | BarFull rune 64 | BarEmpty rune 65 | 66 | BarPrefixColor Color 67 | BarSuffixColor Color 68 | BarFullColor Color 69 | BarEmptyColor Color 70 | } 71 | 72 | // DefaultStyle create a Style object with the default styling. 
73 | func DefaultStyle() Style { 74 | return Style{ 75 | BarPrefix: '|', 76 | BarSuffix: '|', 77 | BarEmpty: '░', 78 | BarFull: '█', 79 | 80 | BarPrefixColor: ColorWhite, 81 | BarSuffixColor: ColorWhite, 82 | BarEmptyColor: ColorGray, 83 | BarFullColor: ColorWhite, 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /pkg/shellexecutor/doc.go: -------------------------------------------------------------------------------- 1 | // Package shellexecutor offers a basic ShellExecutor object that can be used to execute command-line applications. 2 | package shellexecutor 3 | -------------------------------------------------------------------------------- /pkg/shellexecutor/shellexecutor.go: -------------------------------------------------------------------------------- 1 | package shellexecutor 2 | 3 | import "os/exec" 4 | 5 | // ShellExecutor is a shell executor object. 6 | type ShellExecutor struct { 7 | execCommand func(name string, arg ...string) *exec.Cmd 8 | } 9 | 10 | // NewShellExecutor returns a new ShellExecutor object. 11 | func NewShellExecutor() *ShellExecutor { 12 | return &ShellExecutor{ 13 | execCommand: exec.Command, 14 | } 15 | } 16 | 17 | // Shell executes a program with the specified arguments. 18 | // The execution is silent, and an error is returned if the execution ends with an error code. 19 | func (e ShellExecutor) Shell(name string, arg ...string) error { 20 | cmd := e.execCommand(name, arg...) 
21 | 22 | return cmd.Run() 23 | } 24 | -------------------------------------------------------------------------------- /pkg/shellexecutor/shellexecutor_test.go: -------------------------------------------------------------------------------- 1 | package shellexecutor 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "os/exec" 7 | "strconv" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | var stubExitCode int 14 | 15 | func stubExecCommand(command string, args ...string) *exec.Cmd { 16 | cs := []string{"-test.run=TestHelperProcess", "--", command} 17 | cs = append(cs, args...) 18 | 19 | cmd := exec.Command(os.Args[0], cs...) 20 | cmd.Env = []string{ 21 | "GO_WANT_HELPER_PROCESS=1", 22 | fmt.Sprintf("EXIT_CODE=%d", stubExitCode), 23 | } 24 | 25 | return cmd 26 | } 27 | 28 | func TestHelperProcess(t *testing.T) { 29 | if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" { 30 | return 31 | } 32 | 33 | exitCode, err := strconv.Atoi(os.Getenv("EXIT_CODE")) 34 | 35 | if err != nil { 36 | t.Fatal(err) 37 | } 38 | 39 | os.Exit(exitCode) 40 | } 41 | 42 | func TestShell(t *testing.T) { 43 | tests := []struct { 44 | name string 45 | haveExitCode int 46 | wantErr bool 47 | }{ 48 | {name: "success", haveExitCode: 0, wantErr: false}, 49 | {name: "exit code 1", haveExitCode: 1, wantErr: true}, 50 | } 51 | 52 | for _, test := range tests { 53 | t.Run(test.name, func(t *testing.T) { 54 | stubExitCode = test.haveExitCode 55 | 56 | executor := NewShellExecutor() 57 | executor.execCommand = stubExecCommand 58 | 59 | gotErr := executor.Shell("dummy") 60 | 61 | assert.Equal(t, test.wantErr, gotErr != nil) 62 | }) 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /pkg/threadpool/doc.go: -------------------------------------------------------------------------------- 1 | // Package threadpool offers a pool of workers implemented using goroutines. 
2 | // 3 | // Create a new thread pool by calling NewThreadPool and specify the number of workers wanted and a work queue size. 4 | // If the work queue is full when a new task is pushed, the thread pool will block until another task finishes. 5 | // 6 | // Tasks can be any objects that implement the Runnable interface. 7 | package threadpool 8 | -------------------------------------------------------------------------------- /pkg/threadpool/threadpool.go: -------------------------------------------------------------------------------- 1 | package threadpool 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "sync/atomic" 7 | "time" 8 | ) 9 | 10 | // ThreadPool is a thread pool object used to execute tasks in parallel. 11 | type ThreadPool struct { 12 | taskCounter int64 13 | taskDoneCounter int64 14 | 15 | taskChan chan Runnable 16 | doneChan chan bool 17 | 18 | wg sync.WaitGroup 19 | ctx context.Context 20 | cancel context.CancelFunc 21 | } 22 | 23 | // Runnable defines an interface with a Run function to be executed by a thread in a thread pool. 24 | type Runnable interface { 25 | Run() 26 | } 27 | 28 | // NewThreadPool creates a new ThreadPool object and starts the worker threads. 29 | func NewThreadPool(threads int, queueSize int) *ThreadPool { 30 | p := &ThreadPool{} 31 | p.taskChan = make(chan Runnable, queueSize) 32 | p.doneChan = make(chan bool, queueSize) 33 | 34 | p.ctx, p.cancel = context.WithCancel(context.Background()) 35 | 36 | p.createPool(threads) 37 | p.createSentinel() 38 | 39 | return p 40 | } 41 | 42 | // Execute adds a task to the task queue to be picked by a worker thread. It will block if the queue is full. 43 | func (p *ThreadPool) Execute(task Runnable) { 44 | atomic.AddInt64(&p.taskCounter, 1) 45 | p.taskChan <- task 46 | } 47 | 48 | // Done returns true if there are no tasks in flight. 
49 | func (p *ThreadPool) Done() bool { 50 | current := atomic.LoadInt64(&p.taskCounter) 51 | done := atomic.LoadInt64(&p.taskDoneCounter) 52 | return current == done 53 | } 54 | 55 | // Wait waits for all the tasks in flight to be processed. 56 | func (p *ThreadPool) Wait() { 57 | for !p.Done() { 58 | time.Sleep(1 * time.Millisecond) 59 | } 60 | } 61 | 62 | // Close closes the threadpool and frees up the threads. 63 | func (p *ThreadPool) Close() { 64 | p.Wait() 65 | 66 | p.cancel() 67 | p.wg.Wait() 68 | 69 | close(p.taskChan) 70 | close(p.doneChan) 71 | } 72 | 73 | // CurrentCount returns the current number of tasks processed. 74 | func (p *ThreadPool) CurrentCount() int { 75 | done := atomic.LoadInt64(&p.taskDoneCounter) 76 | return int(done) 77 | } 78 | 79 | func (p *ThreadPool) createPool(threads int) { 80 | for i := 0; i < threads; i++ { 81 | worker := newWorker(p.ctx, &p.wg, p.taskChan, p.doneChan) 82 | worker.start() 83 | } 84 | } 85 | 86 | func (p *ThreadPool) createSentinel() { 87 | go func(ctx context.Context, counter *int64) { 88 | for { 89 | select { 90 | case <-ctx.Done(): 91 | return 92 | case <-p.doneChan: 93 | atomic.AddInt64(counter, 1) 94 | } 95 | } 96 | }(p.ctx, &p.taskDoneCounter) 97 | } 98 | -------------------------------------------------------------------------------- /pkg/threadpool/threadpool_test.go: -------------------------------------------------------------------------------- 1 | package threadpool_test 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | 7 | "github.com/d3mondev/puredns/v2/pkg/threadpool" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | type charCountTask struct { 12 | value string 13 | count int 14 | 15 | wantCount int 16 | } 17 | 18 | func (t *charCountTask) Run() { 19 | t.count = len(t.value) 20 | } 21 | 22 | func TestThreadPool(t *testing.T) { 23 | emptyList := []charCountTask{} 24 | 25 | singleList := []charCountTask{ 26 | {value: "test", wantCount: 4}, 27 | } 28 | 29 | smallList := []charCountTask{ 30 | 
{value: "hello", wantCount: 5}, 31 | {value: "world", wantCount: 5}, 32 | {value: "foo", wantCount: 3}, 33 | {value: "bar", wantCount: 3}, 34 | {value: "test", wantCount: 4}, 35 | } 36 | 37 | var bigList []charCountTask = make([]charCountTask, 1000) 38 | for i := range bigList { 39 | bigList[i].value = strings.Repeat("a", i+1) 40 | bigList[i].wantCount = i + 1 41 | } 42 | 43 | tests := []struct { 44 | name string 45 | haveThreadCount int 46 | haveQueueSize int 47 | haveTasks []charCountTask 48 | wantQueries int 49 | }{ 50 | {name: "single worker", haveThreadCount: 1, haveQueueSize: 10, haveTasks: smallList}, 51 | {name: "multiple workers", haveThreadCount: 3, haveQueueSize: 10, haveTasks: smallList}, 52 | {name: "single queue", haveThreadCount: 3, haveQueueSize: 1, haveTasks: smallList}, 53 | {name: "big list", haveThreadCount: 5, haveQueueSize: 1000, haveTasks: bigList}, 54 | {name: "no tasks", haveThreadCount: 5, haveQueueSize: 10, haveTasks: emptyList}, 55 | {name: "single task", haveThreadCount: 5, haveQueueSize: 10, haveTasks: singleList}, 56 | } 57 | 58 | for _, test := range tests { 59 | t.Run(test.name, func(t *testing.T) { 60 | pool := threadpool.NewThreadPool(test.haveThreadCount, test.haveQueueSize) 61 | defer pool.Close() 62 | 63 | for i := range test.haveTasks { 64 | pool.Execute(&test.haveTasks[i]) 65 | } 66 | 67 | pool.Wait() 68 | 69 | for _, task := range test.haveTasks { 70 | assert.Equal(t, task.wantCount, task.count) 71 | } 72 | 73 | gotTotal := pool.CurrentCount() 74 | 75 | assert.Equal(t, len(test.haveTasks), gotTotal) 76 | }) 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /pkg/threadpool/worker.go: -------------------------------------------------------------------------------- 1 | package threadpool 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | type worker struct { 8 | taskChan chan Runnable 9 | doneChan chan bool 10 | ctx context.Context 11 | wg signaler 12 | } 13 | 14 | type signaler 
interface { 15 | Add(delta int) 16 | Done() 17 | Wait() 18 | } 19 | 20 | func newWorker(ctx context.Context, wg signaler, taskChan chan Runnable, doneChan chan bool) *worker { 21 | wg.Add(1) 22 | 23 | return &worker{ 24 | ctx: ctx, 25 | wg: wg, 26 | taskChan: taskChan, 27 | doneChan: doneChan, 28 | } 29 | } 30 | 31 | func (w *worker) start() { 32 | go func(ctx context.Context, wg signaler, taskChan chan Runnable, doneChan chan bool) { 33 | for { 34 | select { 35 | case <-ctx.Done(): 36 | wg.Done() 37 | return 38 | 39 | case task := <-taskChan: 40 | task.Run() 41 | doneChan <- true 42 | } 43 | } 44 | }(w.ctx, w.wg, w.taskChan, w.doneChan) 45 | } 46 | -------------------------------------------------------------------------------- /pkg/wildcarder/answercache.go: -------------------------------------------------------------------------------- 1 | package wildcarder 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | type answerCache struct { 8 | cache map[AnswerHash][]string 9 | mu sync.Mutex 10 | } 11 | 12 | func newAnswerCache() *answerCache { 13 | cache := answerCache{} 14 | cache.cache = make(map[AnswerHash][]string) 15 | 16 | return &cache 17 | } 18 | 19 | // add adds DNS answers to the list and associate a root domain to them. 20 | func (c *answerCache) add(root string, answers []DNSAnswer) { 21 | c.mu.Lock() 22 | answerHashes := []AnswerHash{} 23 | for _, answer := range answers { 24 | answerHashes = append(answerHashes, HashAnswer(answer)) 25 | } 26 | c.mu.Unlock() 27 | 28 | c.addHash(root, answerHashes) 29 | } 30 | 31 | // addHash adds DNS answer hashes to the list and associate a root domain to them. 
32 | func (c *answerCache) addHash(root string, answers []AnswerHash) { 33 | c.mu.Lock() 34 | defer c.mu.Unlock() 35 | 36 | for _, answer := range answers { 37 | if _, ok := c.cache[answer]; !ok { 38 | c.cache[answer] = []string{} 39 | } 40 | 41 | found := false 42 | for _, r := range c.cache[answer] { 43 | if root == r { 44 | found = true 45 | break 46 | } 47 | } 48 | 49 | if !found { 50 | c.cache[answer] = append(c.cache[answer], root) 51 | } 52 | } 53 | } 54 | 55 | // find returns the root domains associated with DNS answers. 56 | func (c *answerCache) find(answers []DNSAnswer) []string { 57 | c.mu.Lock() 58 | 59 | answerHashes := []AnswerHash{} 60 | for _, answer := range answers { 61 | answerHashes = append(answerHashes, HashAnswer(answer)) 62 | } 63 | 64 | c.mu.Unlock() 65 | 66 | return c.findHash(answerHashes) 67 | } 68 | 69 | // findHash returns the root domains associated with DNS answer hashes. 70 | func (c *answerCache) findHash(answers []AnswerHash) []string { 71 | c.mu.Lock() 72 | defer c.mu.Unlock() 73 | 74 | for _, answer := range answers { 75 | if roots, ok := c.cache[answer]; ok { 76 | return roots 77 | } 78 | } 79 | 80 | return []string{} 81 | } 82 | 83 | // count returns the number of answers in the cache. 
84 | func (c *answerCache) count() int { 85 | var count int 86 | 87 | for _, answers := range c.cache { 88 | count += len(answers) 89 | } 90 | 91 | return count 92 | } 93 | -------------------------------------------------------------------------------- /pkg/wildcarder/answercache_test.go: -------------------------------------------------------------------------------- 1 | package wildcarder 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/d3mondev/resolvermt" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | type cacheData struct { 11 | root string 12 | answers []DNSAnswer 13 | } 14 | 15 | func TestAnswerCache(t *testing.T) { 16 | singleAnswer := []DNSAnswer{ 17 | { 18 | Type: resolvermt.TypeA, 19 | Answer: "127.0.0.1", 20 | }, 21 | } 22 | 23 | singleCNAME := []DNSAnswer{ 24 | { 25 | Type: resolvermt.TypeCNAME, 26 | Answer: "127.0.0.1", 27 | }, 28 | } 29 | 30 | multipleAnswers := []DNSAnswer{ 31 | { 32 | Type: resolvermt.TypeA, 33 | Answer: "127.0.0.1", 34 | }, 35 | { 36 | Type: resolvermt.TypeAAAA, 37 | Answer: "::1", 38 | }, 39 | { 40 | Type: resolvermt.TypeCNAME, 41 | Answer: "cname", 42 | }, 43 | } 44 | 45 | tests := []struct { 46 | name string 47 | haveCacheData []cacheData 48 | haveSearch []DNSAnswer 49 | want []string 50 | }{ 51 | {name: "empty cache", haveCacheData: nil, haveSearch: singleAnswer, want: []string{}}, 52 | {name: "empty search", haveCacheData: []cacheData{{root: "root", answers: singleAnswer}}, haveSearch: nil, want: []string{}}, 53 | {name: "find single record", haveCacheData: []cacheData{{root: "root", answers: multipleAnswers}}, haveSearch: singleAnswer, want: []string{"root"}}, 54 | {name: "same root", haveCacheData: []cacheData{{root: "root", answers: singleAnswer}, {root: "root", answers: singleCNAME}}, haveSearch: singleAnswer, want: []string{"root"}}, 55 | {name: "duplicate answer", haveCacheData: []cacheData{{root: "root", answers: singleAnswer}, {root: "root", answers: singleAnswer}}, haveSearch: singleAnswer, want: 
[]string{"root"}}, 56 | {name: "multiple roots", 57 | haveCacheData: []cacheData{ 58 | { 59 | root: "root A", 60 | answers: singleAnswer, 61 | }, 62 | { 63 | root: "root B", 64 | answers: singleAnswer, 65 | }, 66 | }, 67 | haveSearch: multipleAnswers, 68 | want: []string{"root A", "root B"}, 69 | }, 70 | {name: "different types", haveCacheData: []cacheData{{root: "root", answers: multipleAnswers}}, haveSearch: singleCNAME, want: []string{}}, 71 | } 72 | 73 | for _, test := range tests { 74 | t.Run(test.name, func(t *testing.T) { 75 | cache := newAnswerCache() 76 | 77 | for _, data := range test.haveCacheData { 78 | cache.add(data.root, data.answers) 79 | } 80 | 81 | got := cache.find(test.haveSearch) 82 | assert.ElementsMatch(t, test.want, got) 83 | }) 84 | } 85 | } 86 | 87 | func TestAnswerCacheCount(t *testing.T) { 88 | cache := newAnswerCache() 89 | assert.Equal(t, 0, cache.count(), "empty cache count is 0") 90 | 91 | cache.add("root", []DNSAnswer{}) 92 | assert.Equal(t, 0, cache.count(), "empty record cache count is 0") 93 | 94 | cache.add("root", []DNSAnswer{{Type: resolvermt.TypeA, Answer: "127.0.0.1"}}) 95 | assert.Equal(t, 1, cache.count(), "add single answer cache count is 1") 96 | 97 | cache.add("root", []DNSAnswer{{Type: resolvermt.TypeA, Answer: "192.168.0.1"}}) 98 | assert.Equal(t, 2, cache.count(), "add new answer cache count is 2") 99 | } 100 | -------------------------------------------------------------------------------- /pkg/wildcarder/clientdns.go: -------------------------------------------------------------------------------- 1 | package wildcarder 2 | 3 | import "github.com/d3mondev/resolvermt" 4 | 5 | // DNSAnswer represents a DNS answer without the question. 6 | type DNSAnswer struct { 7 | Type resolvermt.RRtype 8 | Answer string 9 | } 10 | 11 | // ClientDNS is a DNS client that implements the Resolver interface. 
12 | type ClientDNS struct { 13 | client resolver 14 | } 15 | 16 | type resolver interface { 17 | Resolve(domains []string, rrtype resolvermt.RRtype) []resolvermt.Record 18 | QueryCount() int 19 | } 20 | 21 | // NewClientDNS creates a new ResolverDNS object to use with a Wildcarder object. 22 | func NewClientDNS(resolvers []string, retryCount int, qps int, concurrency int) *ClientDNS { 23 | return &ClientDNS{ 24 | client: resolvermt.New(resolvers, retryCount, qps, concurrency), 25 | } 26 | } 27 | 28 | // Resolve resolves A records from a list of domain names and returns the answers. 29 | func (r *ClientDNS) Resolve(domains []string) []DNSAnswer { 30 | records := r.client.Resolve(domains, resolvermt.TypeA) 31 | 32 | // Removed AAAA records as those are not being handled by massdns right now 33 | // records = append(records, r.client.Resolve(domains, resolvermt.TypeAAAA)...) 34 | 35 | answers := []DNSAnswer{} 36 | for _, record := range records { 37 | answers = append(answers, DNSAnswer{Type: record.Type, Answer: record.Answer}) 38 | } 39 | 40 | return answers 41 | } 42 | 43 | // QueryCount returns the number of DNS queries really performed. 
44 | func (r *ClientDNS) QueryCount() int { 45 | return r.client.QueryCount() 46 | } 47 | -------------------------------------------------------------------------------- /pkg/wildcarder/clientdns_test.go: -------------------------------------------------------------------------------- 1 | package wildcarder 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/d3mondev/resolvermt" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | type stubClient struct { 11 | empty bool 12 | queries int 13 | } 14 | 15 | func (r *stubClient) Resolve(domains []string, rrtype resolvermt.RRtype) []resolvermt.Record { 16 | if r.empty { 17 | return []resolvermt.Record{} 18 | } 19 | 20 | records := []resolvermt.Record{} 21 | for _, domain := range domains { 22 | var answer string 23 | 24 | if rrtype == resolvermt.TypeA { 25 | answer = "127.0.0.1" 26 | } 27 | 28 | records = append(records, resolvermt.Record{Question: domain, Type: rrtype, Answer: answer}) 29 | r.queries++ 30 | } 31 | 32 | return records 33 | } 34 | 35 | func (r *stubClient) QueryCount() int { 36 | return r.queries 37 | } 38 | 39 | func TestResolverDNS(t *testing.T) { 40 | tests := []struct { 41 | name string 42 | haveRecords bool 43 | want []DNSAnswer 44 | }{ 45 | {name: "empty answer", haveRecords: false, want: []DNSAnswer{}}, 46 | {name: "non-empty answer", haveRecords: true, want: []DNSAnswer{{Type: resolvermt.TypeA, Answer: "127.0.0.1"}}}, 47 | } 48 | 49 | for _, test := range tests { 50 | client := &stubClient{empty: !test.haveRecords} 51 | resolver := ClientDNS{} 52 | resolver.client = client 53 | 54 | got := resolver.Resolve([]string{"test"}) 55 | 56 | assert.ElementsMatch(t, test.want, got) 57 | } 58 | } 59 | 60 | func TestResolverDNSQueryCount(t *testing.T) { 61 | client := &stubClient{empty: false} 62 | 63 | resolver := ClientDNS{} 64 | resolver.client = client 65 | 66 | got := resolver.QueryCount() 67 | assert.Equal(t, 0, got, "initial query count should be 0") 68 | 69 | resolver.Resolve([]string{"test A", 
"test B"}) 70 | got = resolver.QueryCount() 71 | assert.Equal(t, 2, got, "query count should increment by 1 for each query") 72 | } 73 | -------------------------------------------------------------------------------- /pkg/wildcarder/dnscache.go: -------------------------------------------------------------------------------- 1 | package wildcarder 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | // DNSCache represents a cache of DNS queries and answers. 8 | type DNSCache struct { 9 | mu sync.Mutex 10 | cache map[QuestionHash][]AnswerHash 11 | } 12 | 13 | // NewDNSCache creates an empty cache. 14 | func NewDNSCache() *DNSCache { 15 | cache := DNSCache{} 16 | cache.cache = make(map[QuestionHash][]AnswerHash) 17 | return &cache 18 | } 19 | 20 | // Add adds an answer to the DNS cache. The answer will be appended to the list of 21 | // existing answers for a question if they already exist. 22 | func (c *DNSCache) Add(question string, answers []DNSAnswer) { 23 | c.mu.Lock() 24 | defer c.mu.Unlock() 25 | 26 | questionHash := HashQuestion(question) 27 | 28 | if _, ok := c.cache[questionHash]; !ok { 29 | c.cache[questionHash] = []AnswerHash{} 30 | } 31 | 32 | for _, answer := range answers { 33 | answerHash := HashAnswer(answer) 34 | 35 | found := false 36 | for _, answer := range c.cache[questionHash] { 37 | if answer == answerHash { 38 | found = true 39 | break 40 | } 41 | } 42 | 43 | if !found { 44 | c.cache[questionHash] = append(c.cache[questionHash], answerHash) 45 | } 46 | } 47 | } 48 | 49 | // Find returns the answers for a given DNS query from the cache. 50 | // The list of answers returned can be empty if the question is in the cache but 51 | // no results were found, or nil if the question is not in the cache. 
52 | func (c *DNSCache) Find(question string) []AnswerHash { 53 | c.mu.Lock() 54 | defer c.mu.Unlock() 55 | 56 | questionHash := HashQuestion(question) 57 | 58 | if questionMap, ok := c.cache[questionHash]; ok { 59 | answers := []AnswerHash{} 60 | answers = append(answers, questionMap...) 61 | 62 | return answers 63 | } 64 | 65 | return nil 66 | } 67 | -------------------------------------------------------------------------------- /pkg/wildcarder/dnscache_test.go: -------------------------------------------------------------------------------- 1 | package wildcarder 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/d3mondev/resolvermt" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestDNSCacheAdd(t *testing.T) { 11 | answerA := []DNSAnswer{ 12 | {Type: resolvermt.TypeA, Answer: "127.0.0.1"}, 13 | } 14 | 15 | answerB := []DNSAnswer{ 16 | {Type: resolvermt.TypeA, Answer: "127.0.0.1"}, 17 | {Type: resolvermt.TypeAAAA, Answer: "::1"}, 18 | {Type: resolvermt.TypeCNAME, Answer: "test"}, 19 | {Type: resolvermt.TypeA, Answer: "127.0.0.1"}, 20 | } 21 | 22 | wantA := []AnswerHash{ 23 | HashAnswer(answerA[0]), 24 | } 25 | 26 | cache := NewDNSCache() 27 | 28 | cache.Add("question", answerA) 29 | got := cache.Find("question") 30 | assert.ElementsMatch(t, wantA, got, "element added to internal cache") 31 | 32 | wantB := []AnswerHash{ 33 | HashAnswer(answerA[0]), 34 | HashAnswer(answerB[1]), 35 | HashAnswer(answerB[2]), 36 | } 37 | 38 | cache.Add("question", answerB) 39 | got = cache.Find("question") 40 | assert.ElementsMatch(t, wantB, got, "no element duplicated") 41 | } 42 | 43 | func TestDNSCacheAddDifferentQuestion(t *testing.T) { 44 | answerA := []DNSAnswer{{Type: resolvermt.TypeA, Answer: "127.0.0.1"}} 45 | answerB := []DNSAnswer{{Type: resolvermt.TypeAAAA, Answer: "::1"}} 46 | 47 | wantA := []AnswerHash{HashAnswer(answerA[0])} 48 | wantB := []AnswerHash{HashAnswer(answerB[0])} 49 | 50 | cache := NewDNSCache() 51 | cache.Add("question 1", answerA) 52 | 
cache.Add("question 2", answerB) 53 | 54 | got1 := cache.Find("question 1") 55 | got2 := cache.Find("question 2") 56 | 57 | assert.ElementsMatch(t, wantA, got1) 58 | assert.ElementsMatch(t, wantB, got2) 59 | } 60 | 61 | func TestDNSCacheFind(t *testing.T) { 62 | answers := []DNSAnswer{ 63 | {Type: resolvermt.TypeA, Answer: "127.0.0.1"}, 64 | {Type: resolvermt.TypeAAAA, Answer: "::1"}, 65 | {Type: resolvermt.TypeCNAME, Answer: "test"}, 66 | } 67 | 68 | hashes := []AnswerHash{ 69 | HashAnswer(answers[0]), 70 | HashAnswer(answers[1]), 71 | HashAnswer(answers[2]), 72 | } 73 | 74 | tests := []struct { 75 | name string 76 | haveAnswers []DNSAnswer 77 | haveQuestion string 78 | want []AnswerHash 79 | }{ 80 | {name: "existing question", haveQuestion: "question", haveAnswers: answers, want: hashes}, 81 | {name: "existing question without answers", haveQuestion: "question", haveAnswers: []DNSAnswer{}, want: []AnswerHash{}}, 82 | {name: "inexistent question", haveQuestion: "invalid", haveAnswers: answers, want: nil}, 83 | } 84 | 85 | for _, test := range tests { 86 | t.Run(test.name, func(t *testing.T) { 87 | cache := NewDNSCache() 88 | cache.Add("question", test.haveAnswers) 89 | 90 | got := cache.Find(test.haveQuestion) 91 | assert.ElementsMatch(t, test.want, got) 92 | 93 | if got == nil || test.want == nil { 94 | assert.Equal(t, test.want, got) 95 | } 96 | }) 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /pkg/wildcarder/doc.go: -------------------------------------------------------------------------------- 1 | // Package wildcarder provides a subdomain wildcard detection algorithm. It operates on the DNS records of the target domains. 2 | // It can use the result file produced by "massdns -o Snl" to optimize the process and lower the number of DNS queries made. 
3 | package wildcarder 4 | -------------------------------------------------------------------------------- /pkg/wildcarder/gather.go: -------------------------------------------------------------------------------- 1 | package wildcarder 2 | 3 | func gatherRoots(cache *answerCache) []string { 4 | rootMap := make(map[string]struct{}) 5 | 6 | for _, roots := range cache.cache { 7 | for _, root := range roots { 8 | rootMap[root] = struct{}{} 9 | } 10 | } 11 | 12 | found := []string{} 13 | for root := range rootMap { 14 | found = append(found, root) 15 | } 16 | 17 | return found 18 | } 19 | -------------------------------------------------------------------------------- /pkg/wildcarder/gather_test.go: -------------------------------------------------------------------------------- 1 | package wildcarder 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/d3mondev/resolvermt" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestGatherRoots(t *testing.T) { 11 | emptyCache := newAnswerCache() 12 | 13 | rootCache := newAnswerCache() 14 | rootCache.add("root A", []DNSAnswer{{Type: resolvermt.TypeA}}) 15 | rootCache.add("root B", []DNSAnswer{{Type: resolvermt.TypeA}}) 16 | rootCache.add("root C", []DNSAnswer{{Type: resolvermt.TypeA}}) 17 | 18 | tests := []struct { 19 | name string 20 | haveCache *answerCache 21 | want []string 22 | }{ 23 | {name: "empty cache", haveCache: emptyCache, want: []string{}}, 24 | {name: "root detected", haveCache: rootCache, want: []string{"root A", "root B", "root C"}}, 25 | } 26 | 27 | for _, test := range tests { 28 | t.Run(test.name, func(t *testing.T) { 29 | got := gatherRoots(test.haveCache) 30 | 31 | assert.ElementsMatch(t, test.want, got) 32 | }) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /pkg/wildcarder/hashing_hash.go: -------------------------------------------------------------------------------- 1 | //go:build !no_hashing 2 | // +build !no_hashing 3 | 4 | package 
wildcarder 5 | 6 | import ( 7 | "hash/maphash" 8 | 9 | "github.com/d3mondev/resolvermt" 10 | ) 11 | 12 | var hashSeed maphash.Seed = maphash.MakeSeed() 13 | 14 | // QuestionHash is the type of the question stored in the cache. 15 | type QuestionHash uint64 16 | 17 | // AnswerHash is the type of an answer stored in the cache. 18 | type AnswerHash struct { 19 | Type resolvermt.RRtype 20 | Hash uint64 21 | } 22 | 23 | // HashQuestion hashes a question and returns a QuestionHash. 24 | func HashQuestion(question string) QuestionHash { 25 | var hasher maphash.Hash 26 | hasher.SetSeed(hashSeed) 27 | hasher.WriteString(question) 28 | 29 | return QuestionHash(hasher.Sum64()) 30 | } 31 | 32 | // HashAnswer hashes a DNSAnswer and returns a AnswerHash. 33 | func HashAnswer(answer DNSAnswer) AnswerHash { 34 | var hasher maphash.Hash 35 | hasher.SetSeed(hashSeed) 36 | hasher.WriteString(answer.Answer) 37 | 38 | return AnswerHash{ 39 | Type: answer.Type, 40 | Hash: hasher.Sum64(), 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /pkg/wildcarder/hashing_string.go: -------------------------------------------------------------------------------- 1 | //go:build no_hashing 2 | // +build no_hashing 3 | 4 | package wildcarder 5 | 6 | // QuestionHash is the type of the question stored in the cache. 7 | type QuestionHash string 8 | 9 | // AnswerHash is the type of an answer stored in the cache. 10 | type AnswerHash DNSAnswer 11 | 12 | // HashQuestion hashes a question and returns a QuestionHash. 13 | func HashQuestion(question string) QuestionHash { 14 | return QuestionHash(question) 15 | } 16 | 17 | // HashAnswer hashes a DNSAnswer and returns a AnswerHash. 
18 | func HashAnswer(answer DNSAnswer) AnswerHash { 19 | return AnswerHash(answer) 20 | } 21 | -------------------------------------------------------------------------------- /pkg/wildcarder/randomsub.go: -------------------------------------------------------------------------------- 1 | package wildcarder 2 | 3 | import ( 4 | "math/rand" 5 | "time" 6 | ) 7 | 8 | const randomSubdomainLength = 16 9 | 10 | func newRandomSubdomains(count int) []string { 11 | const letters = "abcdefghijklmnopqrstuvwxyz1234567890" 12 | 13 | rng := rand.New(rand.NewSource(time.Now().UnixNano())) 14 | 15 | var subs []string 16 | 17 | for i := 0; i < count; i++ { 18 | b := make([]byte, randomSubdomainLength) 19 | 20 | for i := range b { 21 | b[i] = letters[rng.Intn(len(letters))] 22 | } 23 | 24 | subs = append(subs, string(b)) 25 | } 26 | 27 | return subs 28 | } 29 | -------------------------------------------------------------------------------- /pkg/wildcarder/wildcarder.go: -------------------------------------------------------------------------------- 1 | package wildcarder 2 | 3 | import ( 4 | "bufio" 5 | "io" 6 | "strings" 7 | "sync" 8 | 9 | "github.com/d3mondev/puredns/v2/pkg/threadpool" 10 | ) 11 | 12 | var defaultResolvers []string = []string{ 13 | "8.8.8.8", 14 | "8.8.4.4", 15 | } 16 | 17 | // Wildcarder filters out wildcard subdomains from a list. 18 | type Wildcarder struct { 19 | resolver Resolver 20 | threadCount int 21 | 22 | answerCache *answerCache 23 | preCache *DNSCache 24 | dnsCache *DNSCache 25 | 26 | tpool *threadpool.ThreadPool 27 | tpoolMutex sync.Mutex 28 | total int 29 | 30 | randomSubdomains []string 31 | } 32 | 33 | // Resolver resolves domain names A and AAAA records and returns the DNS answers found. 34 | type Resolver interface { 35 | Resolve(domains []string) []DNSAnswer 36 | QueryCount() int 37 | } 38 | 39 | type result struct { 40 | mu sync.Mutex 41 | domains []string 42 | } 43 | 44 | // New returns a Wildcarder object used to filter out wildcards. 
45 | func New(threadCount int, testCount int, options ...Option) *Wildcarder { 46 | config := buildConfig(options) 47 | 48 | resolver := config.resolver 49 | if resolver == nil { 50 | resolver = NewClientDNS(defaultResolvers, 3, 100, 10) 51 | } 52 | 53 | precache := config.precache 54 | if precache == nil { 55 | precache = NewDNSCache() 56 | } 57 | 58 | wc := &Wildcarder{ 59 | threadCount: threadCount, 60 | resolver: resolver, 61 | 62 | answerCache: newAnswerCache(), 63 | preCache: precache, 64 | dnsCache: NewDNSCache(), 65 | 66 | randomSubdomains: newRandomSubdomains(testCount), 67 | } 68 | 69 | return wc 70 | } 71 | 72 | // Filter reads subdomains from a reader and returns a list of domains that are not wildcards, 73 | // along with the wildcard subdomain roots found. 74 | func (wc *Wildcarder) Filter(r io.Reader) (domains, roots []string) { 75 | // Mutex used because a progress bar could be trying to access wc.tpool through wc.Current(), 76 | // creating a benign race condition that can make tests fail 77 | wc.tpoolMutex.Lock() 78 | if wc.tpool != nil { 79 | panic("concurrent executions of Filter on the same Wildcarder object is not supported") 80 | } 81 | wc.tpool = threadpool.NewThreadPool(wc.threadCount, 1000) 82 | wc.tpoolMutex.Unlock() 83 | 84 | results := &result{} 85 | 86 | scanner := bufio.NewScanner(r) 87 | for scanner.Scan() { 88 | domain := strings.TrimSpace(scanner.Text()) 89 | if domain == "" { 90 | continue 91 | } 92 | 93 | // If the domain is too long to test for wildcards, it won't appear in the results 94 | if len(domain)+randomSubdomainLength > 253 { 95 | continue 96 | } 97 | 98 | ctx := detectionTaskContext{ 99 | results: results, 100 | 101 | resolver: wc.resolver, 102 | wildcardCache: wc.answerCache, 103 | preCache: wc.preCache, 104 | dnsCache: wc.dnsCache, 105 | 106 | randomSubs: wc.randomSubdomains, 107 | queryCount: len(wc.randomSubdomains), 108 | } 109 | 110 | task := newDetectionTask(ctx, domain) 111 | wc.tpool.Execute(task) 112 | } 113 
| 114 | wc.tpool.Wait() 115 | 116 | wc.tpoolMutex.Lock() 117 | wc.total += wc.tpool.CurrentCount() 118 | wc.tpool.Close() 119 | wc.tpool = nil 120 | wc.tpoolMutex.Unlock() 121 | 122 | domains = results.domains 123 | roots = gatherRoots(wc.answerCache) 124 | 125 | return domains, roots 126 | } 127 | 128 | // QueryCount returns the total number of DNS queries made so far to detect wildcards. 129 | func (wc *Wildcarder) QueryCount() int { 130 | return wc.resolver.QueryCount() 131 | } 132 | 133 | // Current returns the current number of domains that have been processed. 134 | func (wc *Wildcarder) Current() int { 135 | wc.tpoolMutex.Lock() 136 | defer wc.tpoolMutex.Unlock() 137 | 138 | if wc.tpool == nil { 139 | return wc.total 140 | } 141 | 142 | return wc.total + wc.tpool.CurrentCount() 143 | } 144 | 145 | // SetPreCache sets the precache after the Wildcarder object has been created. 146 | func (wc *Wildcarder) SetPreCache(precache *DNSCache) { 147 | wc.preCache = precache 148 | } 149 | 150 | // Option configures a wildcarder. 151 | type Option interface { 152 | apply(c *config) 153 | } 154 | 155 | // WithPreCache returns an option that provides a pre-populated DNS cache used to 156 | // optimize the number of DNS queries made during the wildcard detection phase. 157 | // This DNS cache is not trusted, and the results will be validated as needed using trusted resolvers. 158 | func WithPreCache(cache *DNSCache) Option { 159 | return precacheOption{precache: cache} 160 | } 161 | 162 | type precacheOption struct { 163 | precache *DNSCache 164 | } 165 | 166 | func (o precacheOption) apply(c *config) { 167 | c.precache = o.precache 168 | } 169 | 170 | // WithResolver returns an option that provides a custom resolver to use while performing wildcard detection. 
171 | func WithResolver(resolver Resolver) Option { 172 | return resolverOption{resolver: resolver} 173 | } 174 | 175 | type resolverOption struct { 176 | resolver Resolver 177 | } 178 | 179 | func (o resolverOption) apply(c *config) { 180 | c.resolver = o.resolver 181 | } 182 | 183 | type config struct { 184 | precache *DNSCache 185 | resolver Resolver 186 | } 187 | 188 | func buildConfig(options []Option) config { 189 | config := config{} 190 | 191 | for _, opt := range options { 192 | opt.apply(&config) 193 | } 194 | 195 | return config 196 | } 197 | --------------------------------------------------------------------------------